diff --git a/extensions/src/test/java/dev/cel/extensions/CelProtoExtensionsTest.java b/extensions/src/test/java/dev/cel/extensions/CelProtoExtensionsTest.java index 15f6df5be..2e55619db 100644 --- a/extensions/src/test/java/dev/cel/extensions/CelProtoExtensionsTest.java +++ b/extensions/src/test/java/dev/cel/extensions/CelProtoExtensionsTest.java @@ -26,7 +26,6 @@ import com.google.testing.junit.testparameterinjector.TestParameterInjector; import com.google.testing.junit.testparameterinjector.TestParameters; import dev.cel.bundle.Cel; -import dev.cel.bundle.CelFactory; import dev.cel.common.CelAbstractSyntaxTree; import dev.cel.common.CelContainer; import dev.cel.common.CelFunctionDecl; @@ -35,8 +34,6 @@ import dev.cel.common.CelValidationException; import dev.cel.common.types.SimpleType; import dev.cel.common.types.StructTypeReference; -import dev.cel.compiler.CelCompiler; -import dev.cel.compiler.CelCompilerFactory; import dev.cel.expr.conformance.proto2.Proto2ExtensionScopedMessage; import dev.cel.expr.conformance.proto2.TestAllTypes; import dev.cel.expr.conformance.proto2.TestAllTypes.NestedEnum; @@ -44,27 +41,35 @@ import dev.cel.parser.CelMacro; import dev.cel.parser.CelStandardMacro; import dev.cel.runtime.CelFunctionBinding; -import dev.cel.runtime.CelRuntime; -import dev.cel.runtime.CelRuntimeFactory; +import dev.cel.testing.CelRuntimeFlavor; +import java.util.Map; +import org.junit.Assume; +import org.junit.Before; import org.junit.Test; import org.junit.runner.RunWith; @RunWith(TestParameterInjector.class) public final class CelProtoExtensionsTest { - private static final CelCompiler CEL_COMPILER = - CelCompilerFactory.standardCelCompilerBuilder() - .addLibraries(CelExtensions.protos()) - .setStandardMacros(CelStandardMacro.STANDARD_MACROS) - .addFileTypes(TestAllTypesExtensions.getDescriptor()) - .addVar("msg", StructTypeReference.create("cel.expr.conformance.proto2.TestAllTypes")) - .setContainer(CelContainer.ofName("cel.expr.conformance.proto2")) - .build(); + @TestParameter public CelRuntimeFlavor runtimeFlavor; + @TestParameter public boolean isParseOnly; - private static final CelRuntime CEL_RUNTIME = - CelRuntimeFactory.standardCelRuntimeBuilder() - .addFileTypes(TestAllTypesExtensions.getDescriptor()) - .build(); + private Cel cel; + + @Before + public void setUp() { + // Legacy runtime does not support parsed-only evaluation mode. + Assume.assumeFalse(runtimeFlavor.equals(CelRuntimeFlavor.LEGACY) && isParseOnly); + this.cel = + runtimeFlavor + .builder() + .addCompilerLibraries(CelExtensions.protos()) + .setStandardMacros(CelStandardMacro.STANDARD_MACROS) + .addFileTypes(TestAllTypesExtensions.getDescriptor()) + .addVar("msg", StructTypeReference.create("cel.expr.conformance.proto2.TestAllTypes")) + .setContainer(CelContainer.ofName("cel.expr.conformance.proto2")) + .build(); + } private static final TestAllTypes PACKAGE_SCOPED_EXT_MSG = TestAllTypes.newBuilder() @@ -106,10 +111,7 @@ public void library() { "{expr: 'proto.hasExt(msg, cel.expr.conformance.proto2.repeated_test_all_types)'}") @TestParameters("{expr: '!proto.hasExt(msg, cel.expr.conformance.proto2.test_all_types_ext)'}") public void hasExt_packageScoped_success(String expr) throws Exception { - CelAbstractSyntaxTree ast = CEL_COMPILER.compile(expr).getAst(); - boolean result = - (boolean) - CEL_RUNTIME.createProgram(ast).eval(ImmutableMap.of("msg", PACKAGE_SCOPED_EXT_MSG)); + boolean result = (boolean) eval(expr, ImmutableMap.of("msg", PACKAGE_SCOPED_EXT_MSG)); assertThat(result).isTrue(); } @@ -128,10 +130,7 @@ public void hasExt_packageScoped_success(String expr) throws Exception { "{expr: '!proto.hasExt(msg," + " cel.expr.conformance.proto2.Proto2ExtensionScopedMessage.nested_enum_ext)'}") public void hasExt_messageScoped_success(String expr) throws Exception { - CelAbstractSyntaxTree ast = CEL_COMPILER.compile(expr).getAst(); - boolean result = - (boolean) - CEL_RUNTIME.createProgram(ast).eval(ImmutableMap.of("msg", MESSAGE_SCOPED_EXT_MSG)); + boolean result = (boolean) eval(expr, ImmutableMap.of("msg", MESSAGE_SCOPED_EXT_MSG)); assertThat(result).isTrue(); } @@ -142,9 +141,10 @@ public void hasExt_messageScoped_success(String expr) throws Exception { public void hasExt_nonProtoNamespace_success(String expr) throws Exception { StructTypeReference proto2MessageTypeReference = StructTypeReference.create("cel.expr.conformance.proto2.TestAllTypes"); - CelCompiler celCompiler = - CelCompilerFactory.standardCelCompilerBuilder() - .addLibraries(CelExtensions.protos()) + Cel customCel = + runtimeFlavor + .builder() + .addCompilerLibraries(CelExtensions.protos()) .addVar("msg", proto2MessageTypeReference) .addFunctionDeclarations( CelFunctionDecl.newFunctionDeclaration( @@ -154,37 +154,35 @@ public void hasExt_nonProtoNamespace_success(String expr) throws Exception { SimpleType.BOOL, ImmutableList.of( proto2MessageTypeReference, SimpleType.STRING, SimpleType.INT)))) - .build(); - CelRuntime celRuntime = - CelRuntimeFactory.standardCelRuntimeBuilder() .addFunctionBindings( - CelFunctionBinding.from( - "msg_hasExt", - ImmutableList.of(TestAllTypes.class, String.class, Long.class), - (arg) -> { - TestAllTypes msg = (TestAllTypes) arg[0]; - String extensionField = (String) arg[1]; - return msg.getAllFields().keySet().stream() - .anyMatch(fd -> fd.getFullName().equals(extensionField)); - })) + CelFunctionBinding.fromOverloads( + "hasExt", + CelFunctionBinding.from( + "msg_hasExt", + ImmutableList.of(TestAllTypes.class, String.class, Long.class), + (arg) -> { + TestAllTypes msg = (TestAllTypes) arg[0]; + String extensionField = (String) arg[1]; + return msg.getAllFields().keySet().stream() + .anyMatch(fd -> fd.getFullName().equals(extensionField)); + }))) .build(); - CelAbstractSyntaxTree ast = celCompiler.compile(expr).getAst(); boolean result = - (boolean) - celRuntime.createProgram(ast).eval(ImmutableMap.of("msg", PACKAGE_SCOPED_EXT_MSG)); + (boolean) eval(customCel, expr, ImmutableMap.of("msg", PACKAGE_SCOPED_EXT_MSG)); assertThat(result).isTrue(); } @Test public void hasExt_undefinedField_throwsException() { + // This is a type-checking failure + Assume.assumeFalse(isParseOnly); CelValidationException exception = assertThrows( CelValidationException.class, () -> - CEL_COMPILER - .compile("!proto.hasExt(msg, cel.expr.conformance.proto2.undefined_field)") + cel.compile("!proto.hasExt(msg, cel.expr.conformance.proto2.undefined_field)") .getAst()); assertThat(exception) @@ -204,10 +202,7 @@ public void hasExt_undefinedField_throwsException() { "{expr: 'proto.getExt(msg, cel.expr.conformance.proto2.repeated_test_all_types) ==" + " [TestAllTypes{single_string: ''A''}, TestAllTypes{single_string: ''B''}]'}") public void getExt_packageScoped_success(String expr) throws Exception { - CelAbstractSyntaxTree ast = CEL_COMPILER.compile(expr).getAst(); - boolean result = - (boolean) - CEL_RUNTIME.createProgram(ast).eval(ImmutableMap.of("msg", PACKAGE_SCOPED_EXT_MSG)); + boolean result = (boolean) eval(expr, ImmutableMap.of("msg", PACKAGE_SCOPED_EXT_MSG)); assertThat(result).isTrue(); } @@ -221,22 +216,20 @@ public void getExt_packageScoped_success(String expr) throws Exception { "{expr: 'proto.getExt(msg," + " cel.expr.conformance.proto2.Proto2ExtensionScopedMessage.int64_ext) == 1'}") public void getExt_messageScopedSuccess(String expr) throws Exception { - CelAbstractSyntaxTree ast = CEL_COMPILER.compile(expr).getAst(); - boolean result = - (boolean) - CEL_RUNTIME.createProgram(ast).eval(ImmutableMap.of("msg", MESSAGE_SCOPED_EXT_MSG)); + boolean result = (boolean) eval(expr, ImmutableMap.of("msg", MESSAGE_SCOPED_EXT_MSG)); assertThat(result).isTrue(); } @Test public void getExt_undefinedField_throwsException() { + // This is a type-checking failure + Assume.assumeFalse(isParseOnly); CelValidationException exception = assertThrows( CelValidationException.class, () -> - CEL_COMPILER - .compile("!proto.getExt(msg, cel.expr.conformance.proto2.undefined_field)") + cel.compile("!proto.getExt(msg, cel.expr.conformance.proto2.undefined_field)") .getAst()); assertThat(exception) @@ -250,9 +243,10 @@ public void getExt_undefinedField_throwsException() { public void getExt_nonProtoNamespace_success(String expr) throws Exception { StructTypeReference proto2MessageTypeReference = StructTypeReference.create("cel.expr.conformance.proto2.TestAllTypes"); - CelCompiler celCompiler = - CelCompilerFactory.standardCelCompilerBuilder() - .addLibraries(CelExtensions.protos()) + Cel customCel = + runtimeFlavor + .builder() + .addCompilerLibraries(CelExtensions.protos()) .addVar("msg", proto2MessageTypeReference) .addFunctionDeclarations( CelFunctionDecl.newFunctionDeclaration( @@ -262,29 +256,26 @@ public void getExt_nonProtoNamespace_success(String expr) throws Exception { SimpleType.DYN, ImmutableList.of( proto2MessageTypeReference, SimpleType.STRING, SimpleType.INT)))) - .build(); - CelRuntime celRuntime = - CelRuntimeFactory.standardCelRuntimeBuilder() .addFunctionBindings( - CelFunctionBinding.from( - "msg_getExt", - ImmutableList.of(TestAllTypes.class, String.class, Long.class), - (arg) -> { - TestAllTypes msg = (TestAllTypes) arg[0]; - String extensionField = (String) arg[1]; - FieldDescriptor extensionDescriptor = - msg.getAllFields().keySet().stream() - .filter(fd -> fd.getFullName().equals(extensionField)) - .findAny() - .get(); - return msg.getField(extensionDescriptor); - })) + CelFunctionBinding.fromOverloads( + "getExt", + CelFunctionBinding.from( + "msg_getExt", + ImmutableList.of(TestAllTypes.class, String.class, Long.class), + (arg) -> { + TestAllTypes msg = (TestAllTypes) arg[0]; + String extensionField = (String) arg[1]; + FieldDescriptor extensionDescriptor = + msg.getAllFields().keySet().stream() + .filter(fd -> fd.getFullName().equals(extensionField)) + .findAny() + .get(); + return msg.getField(extensionDescriptor); + }))) .build(); - CelAbstractSyntaxTree ast = celCompiler.compile(expr).getAst(); boolean result = - (boolean) - celRuntime.createProgram(ast).eval(ImmutableMap.of("msg", PACKAGE_SCOPED_EXT_MSG)); + (boolean) eval(customCel, expr, ImmutableMap.of("msg", PACKAGE_SCOPED_EXT_MSG)); assertThat(result).isTrue(); } @@ -293,21 +284,24 @@ public void getExt_nonProtoNamespace_success(String expr) throws Exception { public void getExt_onAnyPackedExtensionField_success() throws Exception { ExtensionRegistry extensionRegistry = ExtensionRegistry.newInstance(); TestAllTypesExtensions.registerAllExtensions(extensionRegistry); - Cel cel = - CelFactory.standardCelBuilder() + Cel customCel = + runtimeFlavor + .builder() // CEL-Internal-2 .addCompilerLibraries(CelExtensions.protos()) .addFileTypes(TestAllTypesExtensions.getDescriptor()) .setExtensionRegistry(extensionRegistry) .addVar("msg", StructTypeReference.create("cel.expr.conformance.proto2.TestAllTypes")) .build(); - CelAbstractSyntaxTree ast = - cel.compile("proto.getExt(msg, cel.expr.conformance.proto2.int32_ext)").getAst(); Any anyMsg = Any.pack( TestAllTypes.newBuilder().setExtension(TestAllTypesExtensions.int32Ext, 1).build()); - - Long result = (Long) cel.createProgram(ast).eval(ImmutableMap.of("msg", anyMsg)); + Long result = + (Long) + eval( + customCel, + "proto.getExt(msg, cel.expr.conformance.proto2.int32_ext)", + ImmutableMap.of("msg", anyMsg)); assertThat(result).isEqualTo(1); } @@ -343,9 +337,18 @@ private enum ParseErrorTestCase { @Test public void parseErrors(@TestParameter ParseErrorTestCase testcase) { CelValidationException e = - assertThrows( - CelValidationException.class, () -> CEL_COMPILER.compile(testcase.expr).getAst()); + assertThrows(CelValidationException.class, () -> cel.parse(testcase.expr).getAst()); assertThat(e).hasMessageThat().isEqualTo(testcase.error); } + + private Object eval(String expression, Map variables) throws Exception { + return eval(this.cel, expression, variables); + } + + private Object eval(Cel cel, String expression, Map variables) throws Exception { + CelAbstractSyntaxTree ast = + this.isParseOnly ? cel.parse(expression).getAst() : cel.compile(expression).getAst(); + return cel.createProgram(ast).eval(variables); + } }