From 2203ac83b0fb19ad56fe7b9bfb4a963034886bd3 Mon Sep 17 00:00:00 2001 From: Sokwhan Huh Date: Tue, 14 Apr 2026 18:59:59 -0700 Subject: [PATCH] Support parsed-only evaluation for lists extensions, remove check for heterogeneous numeric comparisons for sorting PiperOrigin-RevId: 899880149 --- .../cel/extensions/CelListsExtensions.java | 66 +++----- .../extensions/CelListsExtensionsTest.java | 146 +++++++++--------- 2 files changed, 99 insertions(+), 113 deletions(-) diff --git a/extensions/src/main/java/dev/cel/extensions/CelListsExtensions.java b/extensions/src/main/java/dev/cel/extensions/CelListsExtensions.java index a91edd822..79539b008 100644 --- a/extensions/src/main/java/dev/cel/extensions/CelListsExtensions.java +++ b/extensions/src/main/java/dev/cel/extensions/CelListsExtensions.java @@ -128,7 +128,8 @@ public enum Function { "list_sort", "Sorts a list with comparable elements.", ListType.create(TypeParamType.create("T")), - ListType.create(TypeParamType.create("T"))))), + ListType.create(TypeParamType.create("T")))), + CelFunctionBinding.from("list_sort", Collection.class, CelListsExtensions::sort)), SORT_BY( CelFunctionDecl.newFunctionDeclaration( "lists.@sortByAssociatedKeys", @@ -136,7 +137,11 @@ public enum Function { "list_sortByAssociatedKeys", "Sorts a list by a key value. Used by the 'sortBy' macro", ListType.create(TypeParamType.create("T")), - ListType.create(TypeParamType.create("T"))))); + ListType.create(TypeParamType.create("T")))), + CelFunctionBinding.from( + "list_sortByAssociatedKeys", + Collection.class, + CelListsExtensions::sortByAssociatedKeys)); private final CelFunctionDecl functionDecl; private final ImmutableSet functionBindings; @@ -147,7 +152,10 @@ String getFunction() { Function(CelFunctionDecl functionDecl, CelFunctionBinding... functionBindings) { this.functionDecl = functionDecl; - this.functionBindings = ImmutableSet.copyOf(functionBindings); + this.functionBindings = + functionBindings.length > 0 + ? CelFunctionBinding.fromOverloads(functionDecl.name(), functionBindings) + : ImmutableSet.of(); } } @@ -240,32 +248,13 @@ public void setRuntimeOptions(CelRuntimeBuilder runtimeBuilder) { @Override public void setRuntimeOptions( CelRuntimeBuilder runtimeBuilder, RuntimeEquality runtimeEquality, CelOptions celOptions) { - for (Function function : functions) { - runtimeBuilder.addFunctionBindings(function.functionBindings); - for (CelOverloadDecl overload : function.functionDecl.overloads()) { - switch (overload.overloadId()) { - case "list_distinct": - runtimeBuilder.addFunctionBindings( - CelFunctionBinding.from( - "list_distinct", Collection.class, (list) -> distinct(list, runtimeEquality))); - break; - case "list_sort": - runtimeBuilder.addFunctionBindings( - CelFunctionBinding.from( - "list_sort", Collection.class, (list) -> sort(list, celOptions))); - break; - case "list_sortByAssociatedKeys": - runtimeBuilder.addFunctionBindings( - CelFunctionBinding.from( - "list_sortByAssociatedKeys", - Collection.class, - (list) -> sortByAssociatedKeys(list, celOptions))); - break; - default: - // Nothing to add - } - } - } + functions.forEach(function -> runtimeBuilder.addFunctionBindings(function.functionBindings)); + + runtimeBuilder.addFunctionBindings( + CelFunctionBinding.fromOverloads( + "distinct", + CelFunctionBinding.from( + "list_distinct", Collection.class, (list) -> distinct(list, runtimeEquality)))); } private static ImmutableList slice(Collection list, long from, long to) { @@ -369,22 +358,18 @@ private static List reverse(Collection list) { } } - private static ImmutableList sort(Collection objects, CelOptions options) { - return ImmutableList.sortedCopyOf( - new CelObjectComparator(options.enableHeterogeneousNumericComparisons()), objects); + private static ImmutableList sort(Collection objects) { + return ImmutableList.sortedCopyOf(new CelObjectComparator(), objects); } private static class CelObjectComparator implements Comparator { - private final boolean enableHeterogeneousNumericComparisons; - CelObjectComparator(boolean enableHeterogeneousNumericComparisons) { - this.enableHeterogeneousNumericComparisons = enableHeterogeneousNumericComparisons; - } + CelObjectComparator() {} @SuppressWarnings({"unchecked"}) @Override public int compare(Object o1, Object o2) { - if (o1 instanceof Number && o2 instanceof Number && enableHeterogeneousNumericComparisons) { + if (o1 instanceof Number && o2 instanceof Number) { return ComparisonFunctions.numericCompare((Number) o1, (Number) o2); } @@ -444,12 +429,9 @@ private static Optional sortByMacro( @SuppressWarnings({"unchecked", "rawtypes"}) private static ImmutableList sortByAssociatedKeys( - Collection> keyValuePairs, CelOptions options) { + Collection> keyValuePairs) { List[] array = keyValuePairs.toArray(new List[0]); - Arrays.sort( - array, - new CelObjectByKeyComparator( - new CelObjectComparator(options.enableHeterogeneousNumericComparisons()))); + Arrays.sort(array, new CelObjectByKeyComparator(new CelObjectComparator())); ImmutableList.Builder builder = ImmutableList.builderWithExpectedSize(array.length); for (List pair : array) { builder.add(pair.get(1)); diff --git a/extensions/src/test/java/dev/cel/extensions/CelListsExtensionsTest.java b/extensions/src/test/java/dev/cel/extensions/CelListsExtensionsTest.java index c4739b18b..2083ccc42 100644 --- a/extensions/src/test/java/dev/cel/extensions/CelListsExtensionsTest.java +++ b/extensions/src/test/java/dev/cel/extensions/CelListsExtensionsTest.java @@ -19,41 +19,38 @@ import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableSortedMultiset; import com.google.common.collect.ImmutableSortedSet; +import com.google.testing.junit.testparameterinjector.TestParameter; import com.google.testing.junit.testparameterinjector.TestParameterInjector; import com.google.testing.junit.testparameterinjector.TestParameters; import dev.cel.bundle.Cel; -import dev.cel.bundle.CelFactory; +import dev.cel.bundle.CelBuilder; +import dev.cel.common.CelAbstractSyntaxTree; import dev.cel.common.CelContainer; -import dev.cel.common.CelOptions; import dev.cel.common.CelValidationException; import dev.cel.common.types.SimpleType; import dev.cel.expr.conformance.test.SimpleTest; import dev.cel.parser.CelStandardMacro; import dev.cel.runtime.CelEvaluationException; +import dev.cel.testing.CelRuntimeFlavor; +import java.util.Map; +import org.junit.Assume; +import org.junit.Before; import org.junit.Test; import org.junit.runner.RunWith; @RunWith(TestParameterInjector.class) public class CelListsExtensionsTest { - private static final Cel CEL = - CelFactory.standardCelBuilder() - .setStandardMacros(CelStandardMacro.STANDARD_MACROS) - .addCompilerLibraries(CelExtensions.lists()) - .addRuntimeLibraries(CelExtensions.lists()) - .setContainer(CelContainer.ofName("cel.expr.conformance.test")) - .addMessageTypes(SimpleTest.getDescriptor()) - .addVar("non_list", SimpleType.DYN) - .build(); - - private static final Cel CEL_WITH_HETEROGENEOUS_NUMERIC_COMPARISONS = - CelFactory.standardCelBuilder() - .setOptions(CelOptions.current().enableHeterogeneousNumericComparisons(true).build()) - .setStandardMacros(CelStandardMacro.STANDARD_MACROS) - .addCompilerLibraries(CelExtensions.lists()) - .addRuntimeLibraries(CelExtensions.lists()) - .setContainer(CelContainer.ofName("cel.expr.conformance.test")) - .addMessageTypes(SimpleTest.getDescriptor()) - .build(); + @TestParameter public CelRuntimeFlavor runtimeFlavor; + @TestParameter public boolean isParseOnly; + + private Cel cel; + + @Before + public void setUp() { + // Legacy runtime does not support parsed-only evaluation mode. + Assume.assumeFalse(runtimeFlavor.equals(CelRuntimeFlavor.LEGACY) && isParseOnly); + this.cel = setupEnv(runtimeFlavor.builder()); + } @Test public void functionList_byVersion() { @@ -89,10 +86,9 @@ public void macroList_byVersion() { @TestParameters("{expression: 'non_list.slice(1, 3)', expected: '[2, 3]'}") public void slice_success(String expression, String expected) throws Exception { Object result = - CEL.createProgram(CEL.compile(expression).getAst()) - .eval(ImmutableMap.of("non_list", ImmutableSortedSet.of(4L, 1L, 3L, 2L))); + eval(cel, expression, ImmutableMap.of("non_list", ImmutableSortedSet.of(4L, 1L, 3L, 2L))); - assertThat(result).isEqualTo(expectedResult(expected)); + assertThat(result).isEqualTo(eval(cel, expected)); } @Test @@ -107,10 +103,7 @@ public void slice_success(String expression, String expected) throws Exception { "{expression: '[1,2,3,4].slice(-5, -3)', " + "expectedError: 'Negative indexes not supported'}") public void slice_throws(String expression, String expectedError) throws Exception { - assertThat( - assertThrows( - CelEvaluationException.class, - () -> CEL.createProgram(CEL.compile(expression).getAst()).eval())) + assertThat(assertThrows(CelEvaluationException.class, () -> eval(cel, expression))) .hasCauseThat() .hasMessageThat() .contains(expectedError); @@ -127,7 +120,7 @@ public void slice_throws(String expression, String expectedError) throws Excepti @TestParameters("{expression: 'dyn([{1: 2}]).flatten() == [{1: 2}]'}") @TestParameters("{expression: 'dyn([1,2,3,4]).flatten() == [1,2,3,4]'}") public void flattenSingleLevel_success(String expression) throws Exception { - boolean result = (boolean) CEL.createProgram(CEL.compile(expression).getAst()).eval(); + boolean result = (boolean) eval(cel, expression); assertThat(result).isTrue(); } @@ -143,7 +136,7 @@ public void flattenSingleLevel_success(String expression) throws Exception { // The overload with the depth accepts and returns a List(dyn), so the following is permitted. @TestParameters("{expression: '[1].flatten(1) == [1]'}") public void flatten_withDepthValue_success(String expression) throws Exception { - boolean result = (boolean) CEL.createProgram(CEL.compile(expression).getAst()).eval(); + boolean result = (boolean) eval(cel, expression); assertThat(result).isTrue(); } @@ -151,13 +144,17 @@ public void flatten_withDepthValue_success(String expression) throws Exception { @Test public void flatten_negativeDepth_throws() { CelEvaluationException e = - assertThrows( - CelEvaluationException.class, - () -> CEL.createProgram(CEL.compile("[1,2,3,4].flatten(-1)").getAst()).eval()); - - assertThat(e) - .hasMessageThat() - .contains("evaluation error at :17: Function 'list_flatten_list_int' failed"); + assertThrows(CelEvaluationException.class, () -> eval(cel, "[1,2,3,4].flatten(-1)")); + + if (isParseOnly) { + assertThat(e) + .hasMessageThat() + .contains("evaluation error at :17: Function 'flatten' failed"); + } else { + assertThat(e) + .hasMessageThat() + .contains("evaluation error at :17: Function 'list_flatten_list_int' failed"); + } assertThat(e).hasCauseThat().hasMessageThat().isEqualTo("Level must be non-negative"); } @@ -166,9 +163,11 @@ public void flatten_negativeDepth_throws() { @TestParameters("{expression: '[{1: 2}].flatten()'}") @TestParameters("{expression: '[1,2,3,4].flatten()'}") public void flattenSingleLevel_listIsSingleLevel_throws(String expression) { + // This is a type-checking failure. + Assume.assumeFalse(isParseOnly); // Note: Java lacks the capability of conditionally disabling type guards // due to the lack of full-fledged dynamic dispatch. - assertThrows(CelValidationException.class, () -> CEL.compile(expression).getAst()); + assertThrows(CelValidationException.class, () -> cel.compile(expression).getAst()); } @Test @@ -176,7 +175,7 @@ public void flattenSingleLevel_listIsSingleLevel_throws(String expression) { @TestParameters("{expression: 'lists.range(0) == []'}") @TestParameters("{expression: 'lists.range(-1) == []'}") public void range_success(String expression) throws Exception { - boolean result = (boolean) CEL.createProgram(CEL.compile(expression).getAst()).eval(); + boolean result = (boolean) eval(cel, expression); assertThat(result).isTrue(); } @@ -204,12 +203,13 @@ public void range_success(String expression) throws Exception { @TestParameters("{expression: 'non_list.distinct()', expected: '[1, 2, 3, 4]'}") public void distinct_success(String expression, String expected) throws Exception { Object result = - CEL.createProgram(CEL.compile(expression).getAst()) - .eval( - ImmutableMap.of( - "non_list", ImmutableSortedMultiset.of(1L, 2L, 3L, 4L, 4L, 1L, 3L, 2L))); + eval( + cel, + expression, + ImmutableMap.of( + "non_list", ImmutableSortedMultiset.of(1L, 2L, 3L, 4L, 4L, 1L, 3L, 2L))); - assertThat(result).isEqualTo(expectedResult(expected)); + assertThat(result).isEqualTo(eval(cel, expected)); } @Test @@ -224,10 +224,9 @@ public void distinct_success(String expression, String expected) throws Exceptio @TestParameters("{expression: 'non_list.reverse()', expected: '[4, 3, 2, 1]'}") public void reverse_success(String expression, String expected) throws Exception { Object result = - CEL.createProgram(CEL.compile(expression).getAst()) - .eval(ImmutableMap.of("non_list", ImmutableSortedSet.of(4L, 1L, 3L, 2L))); + eval(cel, expression, ImmutableMap.of("non_list", ImmutableSortedSet.of(4L, 1L, 3L, 2L))); - assertThat(result).isEqualTo(expectedResult(expected)); + assertThat(result).isEqualTo(eval(cel, expected)); } @Test @@ -238,9 +237,9 @@ public void reverse_success(String expression, String expected) throws Exception "{expression: '[\"d\", \"a\", \"b\", \"c\"].sort()', " + "expected: '[\"a\", \"b\", \"c\", \"d\"]'}") public void sort_success(String expression, String expected) throws Exception { - Object result = CEL.createProgram(CEL.compile(expression).getAst()).eval(); + Object result = eval(cel, expression); - assertThat(result).isEqualTo(expectedResult(expected)); + assertThat(result).isEqualTo(eval(cel, expected)); } @Test @@ -248,29 +247,20 @@ public void sort_success(String expression, String expected) throws Exception { @TestParameters("{expression: '[4, 3, 2, 1].sort()', expected: '[1, 2, 3, 4]'}") public void sort_success_heterogeneousNumbers(String expression, String expected) throws Exception { - Object result = - CEL_WITH_HETEROGENEOUS_NUMERIC_COMPARISONS - .createProgram(CEL_WITH_HETEROGENEOUS_NUMERIC_COMPARISONS.compile(expression).getAst()) - .eval(); + Object result = eval(cel, expression); - assertThat(result).isEqualTo(expectedResult(expected)); + assertThat(result).isEqualTo(eval(cel, expected)); } @Test @TestParameters( "{expression: '[\"d\", 3, 2, \"c\"].sort()', " + "expectedError: 'List elements must have the same type'}") - @TestParameters( - "{expression: '[3.0, 2, 1u].sort()', " - + "expectedError: 'List elements must have the same type'}") @TestParameters( "{expression: '[SimpleTest{name: \"a\"}, SimpleTest{name: \"b\"}].sort()', " + "expectedError: 'List elements must be comparable'}") public void sort_throws(String expression, String expectedError) throws Exception { - assertThat( - assertThrows( - CelEvaluationException.class, - () -> CEL.createProgram(CEL.compile(expression).getAst()).eval())) + assertThat(assertThrows(CelEvaluationException.class, () -> eval(cel, expression))) .hasCauseThat() .hasMessageThat() .contains(expectedError); @@ -296,9 +286,9 @@ public void sort_throws(String expression, String expectedError) throws Exceptio + " SimpleTest{name: \"baz\"}," + " SimpleTest{name: \"foo\"}]'}") public void sortBy_success(String expression, String expected) throws Exception { - Object result = CEL.createProgram(CEL.compile(expression).getAst()).eval(); + Object result = eval(cel, expression); - assertThat(result).isEqualTo(expectedResult(expected)); + assertThat(result).isEqualTo(eval(cel, expected)); } @Test @@ -313,7 +303,7 @@ public void sortBy_throws_validationException(String expression, String expected assertThat( assertThrows( CelValidationException.class, - () -> CEL.createProgram(CEL.compile(expression).getAst()).eval())) + () -> cel.createProgram(cel.compile(expression).getAst()).eval())) .hasMessageThat() .contains(expectedError); } @@ -327,17 +317,31 @@ public void sortBy_throws_validationException(String expression, String expected + "expectedError: 'List elements must be comparable'}") public void sortBy_throws_evaluationException(String expression, String expectedError) throws Exception { - assertThat( - assertThrows( - CelEvaluationException.class, - () -> CEL.createProgram(CEL.compile(expression).getAst()).eval())) + assertThat(assertThrows(CelEvaluationException.class, () -> eval(cel, expression))) .hasCauseThat() .hasMessageThat() .contains(expectedError); } - private static Object expectedResult(String expression) - throws CelEvaluationException, CelValidationException { - return CEL.createProgram(CEL.compile(expression).getAst()).eval(); + private static Cel setupEnv(CelBuilder celBuilder) { + return celBuilder + .setStandardMacros(CelStandardMacro.STANDARD_MACROS) + .addCompilerLibraries(CelExtensions.lists()) + .addRuntimeLibraries(CelExtensions.lists()) + .setContainer(CelContainer.ofName("cel.expr.conformance.test")) + .addMessageTypes(SimpleTest.getDescriptor()) + .addVar("non_list", SimpleType.DYN) + .build(); + } + + + + private Object eval(Cel cel, String expr) throws Exception { + return eval(cel, expr, ImmutableMap.of()); + } + + private Object eval(Cel cel, String expr, Map vars) throws Exception { + CelAbstractSyntaxTree ast = isParseOnly ? cel.parse(expr).getAst() : cel.compile(expr).getAst(); + return cel.createProgram(ast).eval(vars); } }