From 9b375831b55f6e3ba62af7931095a5c8774b4e1c Mon Sep 17 00:00:00 2001 From: David Stevens <394663+drstevens@users.noreply.github.com> Date: Mon, 13 Apr 2026 14:29:38 -0700 Subject: [PATCH] DSPT-4967 provide alternate method for encoding and passing ext JSON --- README.md | 13 ++++-- java/build.gradle.kts | 8 ++++ .../BidRequestEvaluatorOnRuleBasedModel.java | 9 ++-- .../evaluation/evaluator/Response.java | 14 +++++- .../evaluation/evaluator/Slot.java | 14 +++++- .../util/ResponseUtil.java | 35 ++++++++++++++- .../evaluator/evaluator_response.proto | 21 +++++++++ ...dRequestEvaluatorOnRuleBasedModelTest.java | 45 +++++++++++++------ ...dRequestEvaluatorOnRuleBasedModelTest.java | 29 ++++++++++-- .../util/ResponseUtilTest.java | 45 +++++++++++++++---- 10 files changed, 197 insertions(+), 36 deletions(-) create mode 100644 java/src/main/proto/com/amazon/demanddriventrafficevaluator/evaluation/evaluator/evaluator_response.proto diff --git a/README.md b/README.md index 99d77df..385fac3 100644 --- a/README.md +++ b/README.md @@ -360,7 +360,7 @@ This section describes structure of the experiment configuration json. The exper | 2.1.7.2 | idStart | These fields specifies a range of allocationIds. Requests associated with allocation Ids, that fall within this range are associated with this treatment. **It has to be between 0 and 4095 (inclusive).** *We use first 3 chars of the hexdecimal of the request id which is 16^3, as the allocation id.* | | 2.1.7.3 | idEnd | See above | | 2.1.7.4 | weight | [By Default] It is used when you use **TreatmentAllocatorOnRandom** in the **provideTreatmentAllocator**. Specifies the probability that one request is allocated to one group. | -| 2.1.7.5 | learning | An integer flag to determine if the given treatment (traffic group) is used for learning purposes. Traffic allocated to a learning value of 1 should not be subject to filtering, while traffic allocated to a learning value of 0 is subject to filtering. We still expect Sellers to add extension fields in the bid-requests based on the model filtering evaluation. *See Section 2.2.5* | +| 2.1.7.5 | learning | An integer flag to determine if the given treatment (traffic group) is used for learning purposes. Traffic allocated to a learning value of 1 should not be subject to filtering, while traffic allocated to a learning value of 0 is subject to filtering. We still expect Sellers to add extension fields in the bid-requests based on the model filtering evaluation or send encoded via header. *See Section 2.2.5* | | 3 | modelToExperiment | A map of models to experiments, to specify which experiment to use when making filtering decisions for a given model. | | 3.1 | [model-identifier] | The key in this map is the model identifier defined in the model configuration. *See Section 2.1.1.1* | @@ -647,7 +647,9 @@ requestOutput := requestEvaluator.Evaluate(&evaluation.BidRequestEvaluatorInput{ Once the Seller receives filtering recommendations from DTE evaluator library, Sellers are responsible for enforcing the decision on behalf of the Buyer to either filter or forward the bid request based on the value of `Response.slots[*].filterDecision`. If the `filterDecision` value is 0.0, DTE recommends the Seller filter the request, and if the value is 1.0, DTE recommends the Seller to forward the request. -The Seller will also need to send the following custom extension fields: +The Seller will also need to send the following custom extension fields either in JSON or via HTTP header: + +If adding ext to bid request JSON 1. `slots[*].ext.amazontest.decision`: Recommended filter decision for the slot based on Buyer's signals. 1. 0.0 = filter slot (low-value request) @@ -656,10 +658,15 @@ The Seller will also need to send the following custom extension fields: 1. 0 if request is in treatment, Seller evaluates request and filter/forward based on filter decision; 2. 1 if request is in control, Seller evaluates request, but ALWAYS forward the request regardless of filter decision. -These extensions are returned in the DTE Response object under the `Response.ext` and `Response.slots[*].ext` fields and can be appended as-is to the OpenRTB request forwarded to the Buyer. +These extensions are returned in the DTE Response object under the `Response.getExt` and `Response.slots[*].getExt` fields and can be appended as-is to the OpenRTB request forwarded to the Buyer. Note that if the request is in control (learning=1), the `Response.slots[*].filterDecision` value will always be 1.0, regardless of the model result. If the request is in treatment (learning=0), the `Response.slots[*].filterDecision` value can be either 0.0 or 1.0, based on the model result. +If sending via HTTP header, currently only available in Java implementation +1. retrieve protobuf generated message from Response, `Response.getExtProto` +2. Use this to generate a URI safe string using `ResponseUtil.encodedResponseMetadata` +3. Send this string as value for HTTP Header named `XAmazonTest: ` + ## 2.4. DTE Failure Handling When the current hour's model output is not available to be fetched, the DTE library uses the latest successfully loaded model results to evaluate requests for up to 24 hours (the default local cache TTL). After the 24 hour window, all entries in the cache will expire, and no requests will be evaluated as low-value/to be filtered. diff --git a/java/build.gradle.kts b/java/build.gradle.kts index b07015c..cb9c406 100644 --- a/java/build.gradle.kts +++ b/java/build.gradle.kts @@ -14,6 +14,7 @@ plugins { // Apply the java-library plugin for API and implementation separation. `java-library` + id("com.google.protobuf") version "0.9.4" id("io.freefair.lombok") version "8.6" id("com.github.johnrengelman.shadow") version "7.1.2" jacoco @@ -31,6 +32,7 @@ dependencies { // This dependency is exported to consumers, that is to say found on their compile classpath. api(libs.commons.math3) api("com.google.guava:guava:33.1.0-jre") + api("com.google.protobuf:protobuf-java:3.25.3") // AWS S3 Start: https://docs.aws.amazon.com/sdk-for-java/latest/developer-guide/setup-project-gradle.html api(platform("software.amazon.awssdk:bom:2.25.41")) @@ -69,6 +71,12 @@ dependencies { spotbugsSlf4j("org.slf4j:slf4j-simple:2.0.12") } +protobuf { + protoc { + artifact = "com.google.protobuf:protoc:3.25.3" + } +} + jacoco { toolVersion = "0.8.9" // Use the latest version available } diff --git a/java/src/main/java/com/amazon/demanddriventrafficevaluator/evaluation/evaluator/BidRequestEvaluatorOnRuleBasedModel.java b/java/src/main/java/com/amazon/demanddriventrafficevaluator/evaluation/evaluator/BidRequestEvaluatorOnRuleBasedModel.java index 1665115..6922c06 100644 --- a/java/src/main/java/com/amazon/demanddriventrafficevaluator/evaluation/evaluator/BidRequestEvaluatorOnRuleBasedModel.java +++ b/java/src/main/java/com/amazon/demanddriventrafficevaluator/evaluation/evaluator/BidRequestEvaluatorOnRuleBasedModel.java @@ -19,11 +19,11 @@ import java.util.List; import java.util.Map; import java.util.UUID; + import lombok.extern.log4j.Log4j2; import org.apache.commons.lang3.StringUtils; import static com.amazon.demanddriventrafficevaluator.util.ResponseUtil.EXTENSION_KEYWORD_DECISION; -import static com.amazon.demanddriventrafficevaluator.util.ResponseUtil.EXTENSION_KEYWORD_LEARNING; /** * This class implements the BidRequestEvaluator interface to evaluate bid requests @@ -45,9 +45,9 @@ public class BidRequestEvaluatorOnRuleBasedModel implements BidRequestEvaluator static final Response DEFAULT_RESPONSE = Response.builder() .slots(List.of(Slot.builder() .filterDecision(DEFAULT_FILTER_RECOMMENDATION) - .ext(ResponseUtil.buildExtension(Map.of(EXTENSION_KEYWORD_DECISION, DEFAULT_FILTER_RECOMMENDATION))) + .decision(DEFAULT_FILTER_RECOMMENDATION) .build())) - .ext(ResponseUtil.buildExtension(Map.of(EXTENSION_KEYWORD_LEARNING, DEFAULT_LEARNING))) + .learning(DEFAULT_LEARNING) .build(); private final String sspIdentifier; @@ -187,10 +187,9 @@ private List getModelDefinitions(EvaluationContext context) { } private Response buildResponse(EvaluationContext context) { - AggregatedModelEvaluationResult aggregatedModelEvaluationResult = context.getAggregatedModelEvaluationResult(); return Response.builder() .slots(ResponseUtil.buildSlots(context)) - .ext(ResponseUtil.buildExtension(Map.of(EXTENSION_KEYWORD_LEARNING, aggregatedModelEvaluationResult.getTreatmentCodeInInt()))) + .learning(context.getAggregatedModelEvaluationResult().getTreatmentCodeInInt()) .build(); } } diff --git a/java/src/main/java/com/amazon/demanddriventrafficevaluator/evaluation/evaluator/Response.java b/java/src/main/java/com/amazon/demanddriventrafficevaluator/evaluation/evaluator/Response.java index 0802ea1..2618ebc 100644 --- a/java/src/main/java/com/amazon/demanddriventrafficevaluator/evaluation/evaluator/Response.java +++ b/java/src/main/java/com/amazon/demanddriventrafficevaluator/evaluation/evaluator/Response.java @@ -3,14 +3,26 @@ package com.amazon.demanddriventrafficevaluator.evaluation.evaluator; +import com.amazon.demanddriventrafficevaluator.evaluation.evaluator.protobuf.ResponseMetadata; +import com.amazon.demanddriventrafficevaluator.util.ResponseUtil; import lombok.Builder; import lombok.Data; import java.util.List; +import java.util.Map; @Builder @Data public class Response { private final List slots; - private final String ext; + private final int learning; + + public String toExt() { + return ResponseUtil.buildExtension(Map.of(ResponseUtil.EXTENSION_KEYWORD_LEARNING, learning)); + } + + public ResponseMetadata toExtProto() { + return ResponseMetadata.newBuilder().setLearning(learning).addAllSlots(slots.stream().map(s -> s.toExtProto()).toList()).build(); + } + } diff --git a/java/src/main/java/com/amazon/demanddriventrafficevaluator/evaluation/evaluator/Slot.java b/java/src/main/java/com/amazon/demanddriventrafficevaluator/evaluation/evaluator/Slot.java index 714ee16..0977b39 100644 --- a/java/src/main/java/com/amazon/demanddriventrafficevaluator/evaluation/evaluator/Slot.java +++ b/java/src/main/java/com/amazon/demanddriventrafficevaluator/evaluation/evaluator/Slot.java @@ -3,12 +3,24 @@ package com.amazon.demanddriventrafficevaluator.evaluation.evaluator; +import com.amazon.demanddriventrafficevaluator.evaluation.evaluator.protobuf.SlotMetadata; +import com.amazon.demanddriventrafficevaluator.util.ResponseUtil; import lombok.Builder; import lombok.Data; +import java.util.Map; + @Builder @Data public class Slot { private final double filterDecision; - private final String ext; + private final double decision; + + public String toExt() { + return ResponseUtil.buildExtension(Map.of(ResponseUtil.EXTENSION_KEYWORD_DECISION, decision)); + } + + public SlotMetadata toExtProto() { + return SlotMetadata.newBuilder().setDecision(decision).build(); + } } diff --git a/java/src/main/java/com/amazon/demanddriventrafficevaluator/util/ResponseUtil.java b/java/src/main/java/com/amazon/demanddriventrafficevaluator/util/ResponseUtil.java index fa65aa5..f1f08ab 100644 --- a/java/src/main/java/com/amazon/demanddriventrafficevaluator/util/ResponseUtil.java +++ b/java/src/main/java/com/amazon/demanddriventrafficevaluator/util/ResponseUtil.java @@ -9,13 +9,17 @@ import com.amazon.demanddriventrafficevaluator.evaluation.evaluator.ModelEvaluatorOutput; import com.amazon.demanddriventrafficevaluator.evaluation.evaluator.Signal; import com.amazon.demanddriventrafficevaluator.evaluation.evaluator.Slot; +import com.amazon.demanddriventrafficevaluator.evaluation.evaluator.protobuf.ResponseMetadata; import com.fasterxml.jackson.core.JsonProcessingException; import com.fasterxml.jackson.databind.ObjectMapper; import com.fasterxml.jackson.databind.node.ObjectNode; import java.util.ArrayList; +import java.util.Base64; import java.util.List; import java.util.Map; + +import com.google.protobuf.InvalidProtocolBufferException; import lombok.extern.log4j.Log4j2; /** @@ -35,6 +39,9 @@ public class ResponseUtil { private static final ObjectMapper mapper = new ObjectMapper(); + private final static Base64.Encoder b64Encoder = Base64.getUrlEncoder().withoutPadding(); + private final static Base64.Decoder b64Decoder = Base64.getUrlDecoder(); + /** * Builds a list of Slot objects based on the evaluation context. * @@ -45,7 +52,7 @@ public static List buildSlots(EvaluationContext context) { AggregatedModelEvaluationResult aggregatedModelEvaluationResult = context.getAggregatedModelEvaluationResult(); return List.of(Slot.builder() .filterDecision(aggregatedModelEvaluationResult.getScoreWithTreatment()) - .ext(buildExtension(Map.of(EXTENSION_KEYWORD_DECISION, aggregatedModelEvaluationResult.getScore()))) + .decision(aggregatedModelEvaluationResult.getScore()) .build()); } @@ -69,6 +76,8 @@ public static String buildExtension(Map extensionMapping) { } } + + /** * Builds a list of Signal objects from model evaluator outputs. * @@ -109,4 +118,28 @@ public static String getDebugInfo(ModelEvaluationContext context) { return requestLevelDebugInfo + modelLevelDebugInfo; } + + /** + * Encodes Response as uri safe base64 string to provide alternate method of passing amazonTest data + * @param response protobuf representation of evaluation response + * @return uri safe base64 encoded string + */ + public static String encodedResponseMetadata(ResponseMetadata response) { + return b64Encoder.encodeToString(response.toByteArray()); + } + + /** + * Decode base64 encoded string to protobuf representation of evaluation response + * @throws IllegalArgumentException when invalid string is encountered + * @param response string to parse + * @return decoded and parsed response + */ + public static ResponseMetadata decodeResponseMetadata(String response){ + try { + return ResponseMetadata.parseFrom(b64Decoder.decode(response)); + } catch (InvalidProtocolBufferException e) { + throw new IllegalArgumentException("Encoded bytes are not a valid base64 representation of a ResponseMetadata", e); + } + } + } diff --git a/java/src/main/proto/com/amazon/demanddriventrafficevaluator/evaluation/evaluator/evaluator_response.proto b/java/src/main/proto/com/amazon/demanddriventrafficevaluator/evaluation/evaluator/evaluator_response.proto new file mode 100644 index 0000000..da7caff --- /dev/null +++ b/java/src/main/proto/com/amazon/demanddriventrafficevaluator/evaluation/evaluator/evaluator_response.proto @@ -0,0 +1,21 @@ +syntax = "proto3"; + +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package com.amazon.demanddriventrafficevaluator.evaluation.evaluator; + +option java_multiple_files = true; +option java_package = "com.amazon.demanddriventrafficevaluator.evaluation.evaluator.protobuf"; +option java_outer_classname = "ResponseProto"; + + +message SlotMetadata { + double decision = 1; +} + +// Structured equivalent of Java Response: zero or more slots (List) plus response-level ext. +message ResponseMetadata { + int32 learning = 1; + repeated SlotMetadata slots = 2; +} diff --git a/java/src/test/java/com/amazon/demanddriventrafficevaluator/evaluation/evaluator/BidRequestEvaluatorOnRuleBasedModelTest.java b/java/src/test/java/com/amazon/demanddriventrafficevaluator/evaluation/evaluator/BidRequestEvaluatorOnRuleBasedModelTest.java index 6b55318..a3de687 100644 --- a/java/src/test/java/com/amazon/demanddriventrafficevaluator/evaluation/evaluator/BidRequestEvaluatorOnRuleBasedModelTest.java +++ b/java/src/test/java/com/amazon/demanddriventrafficevaluator/evaluation/evaluator/BidRequestEvaluatorOnRuleBasedModelTest.java @@ -4,12 +4,15 @@ package com.amazon.demanddriventrafficevaluator.evaluation.evaluator; import com.amazon.demanddriventrafficevaluator.BaseTestCase; +import com.amazon.demanddriventrafficevaluator.evaluation.evaluator.protobuf.SlotMetadata; import com.amazon.demanddriventrafficevaluator.evaluation.experiment.ExperimentContext; import com.amazon.demanddriventrafficevaluator.evaluation.experiment.ExperimentManager; import com.amazon.demanddriventrafficevaluator.repository.entity.ExperimentConfiguration; import com.amazon.demanddriventrafficevaluator.repository.entity.ModelConfiguration; import com.amazon.demanddriventrafficevaluator.repository.provider.configuration.ConfigurationProvider; -import java.util.Collections; + +import java.util.*; + import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.ExtendWith; @@ -19,10 +22,6 @@ import org.mockito.junit.jupiter.MockitoExtension; import software.amazon.awssdk.utils.ImmutableMap; -import java.util.HashMap; -import java.util.List; -import java.util.Map; - import static com.amazon.demanddriventrafficevaluator.evaluation.evaluator.BidRequestEvaluatorOnRuleBasedModel.DEFAULT_RESPONSE; import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertNotNull; @@ -119,8 +118,13 @@ void testEvaluateSuccess() { assertEquals(1, output.getResponse().getSlots().size()); Slot slot = output.getResponse().getSlots().get(0); assertEquals(1.0, slot.getFilterDecision()); - assertTrue(slot.getExt().contains("{\"decision\":0.0}")); - assertTrue(output.getResponse().getExt().contains("{\"amazontest\":{\"learning\":1}}")); + assertTrue(slot.toExt().contains("{\"decision\":0.0}")); + assertTrue(output.getResponse().toExt().contains("{\"amazontest\":{\"learning\":1}}")); + var responseProto = output.getResponse().toExtProto(); + assertEquals(1, responseProto.getLearning()); + assertEquals( + Arrays.asList(SlotMetadata.newBuilder().setDecision(0.0).build()), + responseProto.getSlotsList()); verify(experimentManager).setupExperimentContext(any(EvaluationContext.class)); verify(modelConfigurationProvider).provide(); @@ -173,8 +177,13 @@ void testEvaluateSuccessWithMap() { assertEquals(1, output.getResponse().getSlots().size()); Slot slot = output.getResponse().getSlots().get(0); assertEquals(1.0, slot.getFilterDecision()); - assertTrue(slot.getExt().contains("{\"decision\":0.0}")); - assertTrue(output.getResponse().getExt().contains("{\"amazontest\":{\"learning\":1}}")); + assertTrue(slot.toExt().contains("{\"decision\":0.0}")); + assertTrue(output.getResponse().toExt().contains("{\"amazontest\":{\"learning\":1}}")); + var responseProto = output.getResponse().toExtProto(); + assertEquals(1, responseProto.getLearning()); + assertEquals( + Arrays.asList(SlotMetadata.newBuilder().setDecision(0.0).build()), + responseProto.getSlotsList()); verify(experimentManager).setupExperimentContext(any(EvaluationContext.class)); verify(modelConfigurationProvider).provide(); @@ -236,8 +245,13 @@ void testEvaluateSuccessWithTwoModels() { assertEquals(1, output.getResponse().getSlots().size()); Slot slot = output.getResponse().getSlots().get(0); assertEquals(1.0, slot.getFilterDecision()); - assertTrue(slot.getExt().contains("{\"decision\":0.0}")); - assertTrue(output.getResponse().getExt().contains("{\"amazontest\":{\"learning\":1}}")); + assertTrue(slot.toExt().contains("{\"decision\":0.0}")); + assertTrue(output.getResponse().toExt().contains("{\"amazontest\":{\"learning\":1}}")); + var responseProto = output.getResponse().toExtProto(); + assertEquals(1, responseProto.getLearning()); + assertEquals( + Arrays.asList(SlotMetadata.newBuilder().setDecision(0.0).build()), + responseProto.getSlotsList()); verify(experimentManager).setupExperimentContext(any(EvaluationContext.class)); verify(modelConfigurationProvider).provide(); @@ -300,8 +314,13 @@ void testEvaluateSuccessWithMapTwoModels() { assertEquals(1, output.getResponse().getSlots().size()); Slot slot = output.getResponse().getSlots().get(0); assertEquals(1.0, slot.getFilterDecision()); - assertTrue(slot.getExt().contains("{\"decision\":0.0}")); - assertTrue(output.getResponse().getExt().contains("{\"amazontest\":{\"learning\":1}}")); + assertTrue(slot.toExt().contains("{\"decision\":0.0}")); + assertTrue(output.getResponse().toExt().contains("{\"amazontest\":{\"learning\":1}}")); + var responseProto = output.getResponse().toExtProto(); + assertEquals(1, responseProto.getLearning()); + assertEquals( + Arrays.asList(SlotMetadata.newBuilder().setDecision(0.0).build()), + responseProto.getSlotsList()); verify(experimentManager).setupExperimentContext(any(EvaluationContext.class)); verify(modelConfigurationProvider).provide(); diff --git a/java/src/test/java/com/amazon/demanddriventrafficevaluator/functional/evaluation/BidRequestEvaluatorOnRuleBasedModelTest.java b/java/src/test/java/com/amazon/demanddriventrafficevaluator/functional/evaluation/BidRequestEvaluatorOnRuleBasedModelTest.java index a6726c4..0d6ab40 100644 --- a/java/src/test/java/com/amazon/demanddriventrafficevaluator/functional/evaluation/BidRequestEvaluatorOnRuleBasedModelTest.java +++ b/java/src/test/java/com/amazon/demanddriventrafficevaluator/functional/evaluation/BidRequestEvaluatorOnRuleBasedModelTest.java @@ -12,6 +12,7 @@ import com.amazon.demanddriventrafficevaluator.evaluation.evaluator.Response; import com.amazon.demanddriventrafficevaluator.evaluation.evaluator.RuleBasedModelEvaluator; import com.amazon.demanddriventrafficevaluator.evaluation.evaluator.Slot; +import com.amazon.demanddriventrafficevaluator.evaluation.evaluator.protobuf.SlotMetadata; import com.amazon.demanddriventrafficevaluator.factory.DefaultLocalCacheRegistryFactory; import com.amazon.demanddriventrafficevaluator.factory.ExperimentManagerFactory; import com.amazon.demanddriventrafficevaluator.factory.ExtractorRegistryFactory; @@ -59,6 +60,7 @@ import java.time.ZoneId; import java.time.ZonedDateTime; import java.time.format.DateTimeFormatter; +import java.util.Arrays; import java.util.List; import java.util.concurrent.ScheduledThreadPoolExecutor; @@ -140,10 +142,20 @@ public void testEvaluateMultiModel_returnExpectedResponse() { Response expectedResponse = Response.builder() .slots(List.of(Slot.builder() .filterDecision(1.0) - .ext("{\"amazontest\":{\"decision\":1.0}}") + .decision(1.0) .build())) - .ext("{\"amazontest\":{\"learning\":1}}") + .learning(1) .build(); + + String expectedResponseJson = "{\"amazontest\":{\"learning\":1}}"; + List expectedSlotJson = Arrays.asList("{\"amazontest\":{\"decision\":1.0}}"); + assertEquals(expectedResponseJson, output.getResponse().toExt()); + assertEquals(expectedSlotJson, output.getResponse().getSlots().stream().map(s -> s.toExt()).toList()); + var responseProto = output.getResponse().toExtProto(); + assertEquals(1, responseProto.getLearning()); + assertEquals( + Arrays.asList(SlotMetadata.newBuilder().setDecision(1.0).build()), + responseProto.getSlotsList()); assertEquals(expectedResponse, output.getResponse()); } @@ -161,10 +173,19 @@ public void testEvaluate_returnExpectedResponse() { Response expectedResponse = Response.builder() .slots(List.of(Slot.builder() .filterDecision(1.0) - .ext("{\"amazontest\":{\"decision\":0.0}}") + .decision(0.0) .build())) - .ext("{\"amazontest\":{\"learning\":1}}") + .learning(1) .build(); + String expectedResponseJson = "{\"amazontest\":{\"learning\":1}}"; + List expectedSlotJson = Arrays.asList("{\"amazontest\":{\"decision\":0.0}}"); + assertEquals(expectedResponseJson, output.getResponse().toExt()); + assertEquals(expectedSlotJson, output.getResponse().getSlots().stream().map(s -> s.toExt()).toList()); + var responseProto = output.getResponse().toExtProto(); + assertEquals(1, responseProto.getLearning()); + assertEquals( + Arrays.asList(SlotMetadata.newBuilder().setDecision(0.0).build()), + responseProto.getSlotsList()); assertEquals(expectedResponse, output.getResponse()); } diff --git a/java/src/test/java/com/amazon/demanddriventrafficevaluator/util/ResponseUtilTest.java b/java/src/test/java/com/amazon/demanddriventrafficevaluator/util/ResponseUtilTest.java index c064e2c..1a7b804 100644 --- a/java/src/test/java/com/amazon/demanddriventrafficevaluator/util/ResponseUtilTest.java +++ b/java/src/test/java/com/amazon/demanddriventrafficevaluator/util/ResponseUtilTest.java @@ -3,13 +3,9 @@ package com.amazon.demanddriventrafficevaluator.util; -import com.amazon.demanddriventrafficevaluator.evaluation.evaluator.AggregatedModelEvaluationResult; -import com.amazon.demanddriventrafficevaluator.evaluation.evaluator.EvaluationContext; -import com.amazon.demanddriventrafficevaluator.evaluation.evaluator.ModelEvaluationContext; -import com.amazon.demanddriventrafficevaluator.evaluation.evaluator.ModelEvaluationStatus; -import com.amazon.demanddriventrafficevaluator.evaluation.evaluator.ModelEvaluatorOutput; -import com.amazon.demanddriventrafficevaluator.evaluation.evaluator.Signal; -import com.amazon.demanddriventrafficevaluator.evaluation.evaluator.Slot; +import com.amazon.demanddriventrafficevaluator.evaluation.evaluator.*; +import com.amazon.demanddriventrafficevaluator.evaluation.evaluator.protobuf.ResponseMetadata; +import com.amazon.demanddriventrafficevaluator.evaluation.evaluator.protobuf.SlotMetadata; import com.amazon.demanddriventrafficevaluator.repository.entity.ModelDefinition; import com.fasterxml.jackson.databind.JsonNode; import com.fasterxml.jackson.databind.ObjectMapper; @@ -26,6 +22,7 @@ import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertNotNull; import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.junit.jupiter.api.Assertions.assertThrows; import static org.mockito.Mockito.when; @ExtendWith(MockitoExtension.class) @@ -59,7 +56,7 @@ void testBuildSlots() { assertEquals(1, slots.size()); Slot slot = slots.get(0); assertEquals(0.8, slot.getFilterDecision()); - assertTrue(slot.getExt().contains("\"decision\":0.7")); + assertTrue(slot.toExt().contains("\"decision\":0.7")); } @Test @@ -151,4 +148,36 @@ void testGetDebugInfo() { assertTrue(debugInfo.contains("Model Debug 1")); assertTrue(debugInfo.contains("Model Debug 2")); } + + @Test + public void testEncodedDecodeProto() { + ResponseMetadata expectedResponse = + ResponseMetadata + .newBuilder() + .addSlots(SlotMetadata.newBuilder().setDecision(99)) + .setLearning(111) + .build(); + ResponseMetadata response; + var encoded = ResponseUtil.encodedResponseMetadata(expectedResponse); + response = ResponseUtil.decodeResponseMetadata(encoded); + assertEquals(expectedResponse, response); + } + + @Test + public void testDecodeFailureOnInvalidBase64() { + var exception = assertThrows(IllegalArgumentException.class, () -> ResponseUtil.decodeResponseMetadata("*&*&")); + assertTrue(exception.getMessage().contains("Illegal base64 character")); + } + + @Test + public void testDecodeFailureOnInvalidBytes() { + var exception = assertThrows(IllegalArgumentException.class, () -> ResponseUtil.decodeResponseMetadata("08983490832")); + assertTrue(exception.getMessage().contains("Encoded bytes are not a valid base64 representation of a ResponseMetadata")); + } + + @Test + public void testDecodeEmptyArray() { + ResponseUtil.decodeResponseMetadata(""); + assertEquals(ResponseMetadata.getDefaultInstance(), ResponseUtil.decodeResponseMetadata("")); + } }