Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@
import com.google.testing.junit.testparameterinjector.TestParameterInjector;
import com.google.testing.junit.testparameterinjector.TestParameters;
import dev.cel.bundle.Cel;
import dev.cel.bundle.CelFactory;
import dev.cel.common.CelAbstractSyntaxTree;
import dev.cel.common.CelContainer;
import dev.cel.common.CelFunctionDecl;
Expand All @@ -35,36 +34,42 @@
import dev.cel.common.CelValidationException;
import dev.cel.common.types.SimpleType;
import dev.cel.common.types.StructTypeReference;
import dev.cel.compiler.CelCompiler;
import dev.cel.compiler.CelCompilerFactory;
import dev.cel.expr.conformance.proto2.Proto2ExtensionScopedMessage;
import dev.cel.expr.conformance.proto2.TestAllTypes;
import dev.cel.expr.conformance.proto2.TestAllTypes.NestedEnum;
import dev.cel.expr.conformance.proto2.TestAllTypesExtensions;
import dev.cel.parser.CelMacro;
import dev.cel.parser.CelStandardMacro;
import dev.cel.runtime.CelFunctionBinding;
import dev.cel.runtime.CelRuntime;
import dev.cel.runtime.CelRuntimeFactory;
import dev.cel.testing.CelRuntimeFlavor;
import java.util.Map;
import org.junit.Assume;
import org.junit.Before;
import org.junit.Test;
import org.junit.runner.RunWith;

@RunWith(TestParameterInjector.class)
public final class CelProtoExtensionsTest {

private static final CelCompiler CEL_COMPILER =
CelCompilerFactory.standardCelCompilerBuilder()
.addLibraries(CelExtensions.protos())
.setStandardMacros(CelStandardMacro.STANDARD_MACROS)
.addFileTypes(TestAllTypesExtensions.getDescriptor())
.addVar("msg", StructTypeReference.create("cel.expr.conformance.proto2.TestAllTypes"))
.setContainer(CelContainer.ofName("cel.expr.conformance.proto2"))
.build();
@TestParameter public CelRuntimeFlavor runtimeFlavor;
@TestParameter public boolean isParseOnly;

private static final CelRuntime CEL_RUNTIME =
CelRuntimeFactory.standardCelRuntimeBuilder()
.addFileTypes(TestAllTypesExtensions.getDescriptor())
.build();
private Cel cel;

@Before
public void setUp() {
// Legacy runtime does not support parsed-only evaluation mode.
Assume.assumeFalse(runtimeFlavor.equals(CelRuntimeFlavor.LEGACY) && isParseOnly);
this.cel =
runtimeFlavor
.builder()
.addCompilerLibraries(CelExtensions.protos())
.setStandardMacros(CelStandardMacro.STANDARD_MACROS)
.addFileTypes(TestAllTypesExtensions.getDescriptor())
.addVar("msg", StructTypeReference.create("cel.expr.conformance.proto2.TestAllTypes"))
.setContainer(CelContainer.ofName("cel.expr.conformance.proto2"))
.build();
}

private static final TestAllTypes PACKAGE_SCOPED_EXT_MSG =
TestAllTypes.newBuilder()
Expand Down Expand Up @@ -106,10 +111,7 @@ public void library() {
"{expr: 'proto.hasExt(msg, cel.expr.conformance.proto2.repeated_test_all_types)'}")
@TestParameters("{expr: '!proto.hasExt(msg, cel.expr.conformance.proto2.test_all_types_ext)'}")
public void hasExt_packageScoped_success(String expr) throws Exception {
CelAbstractSyntaxTree ast = CEL_COMPILER.compile(expr).getAst();
boolean result =
(boolean)
CEL_RUNTIME.createProgram(ast).eval(ImmutableMap.of("msg", PACKAGE_SCOPED_EXT_MSG));
boolean result = (boolean) eval(expr, ImmutableMap.of("msg", PACKAGE_SCOPED_EXT_MSG));

assertThat(result).isTrue();
}
Expand All @@ -128,10 +130,7 @@ public void hasExt_packageScoped_success(String expr) throws Exception {
"{expr: '!proto.hasExt(msg,"
+ " cel.expr.conformance.proto2.Proto2ExtensionScopedMessage.nested_enum_ext)'}")
public void hasExt_messageScoped_success(String expr) throws Exception {
CelAbstractSyntaxTree ast = CEL_COMPILER.compile(expr).getAst();
boolean result =
(boolean)
CEL_RUNTIME.createProgram(ast).eval(ImmutableMap.of("msg", MESSAGE_SCOPED_EXT_MSG));
boolean result = (boolean) eval(expr, ImmutableMap.of("msg", MESSAGE_SCOPED_EXT_MSG));

assertThat(result).isTrue();
}
Expand All @@ -142,9 +141,10 @@ public void hasExt_messageScoped_success(String expr) throws Exception {
public void hasExt_nonProtoNamespace_success(String expr) throws Exception {
StructTypeReference proto2MessageTypeReference =
StructTypeReference.create("cel.expr.conformance.proto2.TestAllTypes");
CelCompiler celCompiler =
CelCompilerFactory.standardCelCompilerBuilder()
.addLibraries(CelExtensions.protos())
Cel customCel =
runtimeFlavor
.builder()
.addCompilerLibraries(CelExtensions.protos())
.addVar("msg", proto2MessageTypeReference)
.addFunctionDeclarations(
CelFunctionDecl.newFunctionDeclaration(
Expand All @@ -154,37 +154,35 @@ public void hasExt_nonProtoNamespace_success(String expr) throws Exception {
SimpleType.BOOL,
ImmutableList.of(
proto2MessageTypeReference, SimpleType.STRING, SimpleType.INT))))
.build();
CelRuntime celRuntime =
CelRuntimeFactory.standardCelRuntimeBuilder()
.addFunctionBindings(
CelFunctionBinding.from(
"msg_hasExt",
ImmutableList.of(TestAllTypes.class, String.class, Long.class),
(arg) -> {
TestAllTypes msg = (TestAllTypes) arg[0];
String extensionField = (String) arg[1];
return msg.getAllFields().keySet().stream()
.anyMatch(fd -> fd.getFullName().equals(extensionField));
}))
CelFunctionBinding.fromOverloads(
"hasExt",
CelFunctionBinding.from(
"msg_hasExt",
ImmutableList.of(TestAllTypes.class, String.class, Long.class),
(arg) -> {
TestAllTypes msg = (TestAllTypes) arg[0];
String extensionField = (String) arg[1];
return msg.getAllFields().keySet().stream()
.anyMatch(fd -> fd.getFullName().equals(extensionField));
})))
.build();

CelAbstractSyntaxTree ast = celCompiler.compile(expr).getAst();
boolean result =
(boolean)
celRuntime.createProgram(ast).eval(ImmutableMap.of("msg", PACKAGE_SCOPED_EXT_MSG));
(boolean) eval(customCel, expr, ImmutableMap.of("msg", PACKAGE_SCOPED_EXT_MSG));

assertThat(result).isTrue();
}

@Test
public void hasExt_undefinedField_throwsException() {
// This is a type-checking failure
Assume.assumeFalse(isParseOnly);
CelValidationException exception =
assertThrows(
CelValidationException.class,
() ->
CEL_COMPILER
.compile("!proto.hasExt(msg, cel.expr.conformance.proto2.undefined_field)")
cel.compile("!proto.hasExt(msg, cel.expr.conformance.proto2.undefined_field)")
.getAst());

assertThat(exception)
Expand All @@ -204,10 +202,7 @@ public void hasExt_undefinedField_throwsException() {
"{expr: 'proto.getExt(msg, cel.expr.conformance.proto2.repeated_test_all_types) =="
+ " [TestAllTypes{single_string: ''A''}, TestAllTypes{single_string: ''B''}]'}")
public void getExt_packageScoped_success(String expr) throws Exception {
CelAbstractSyntaxTree ast = CEL_COMPILER.compile(expr).getAst();
boolean result =
(boolean)
CEL_RUNTIME.createProgram(ast).eval(ImmutableMap.of("msg", PACKAGE_SCOPED_EXT_MSG));
boolean result = (boolean) eval(expr, ImmutableMap.of("msg", PACKAGE_SCOPED_EXT_MSG));

assertThat(result).isTrue();
}
Expand All @@ -221,22 +216,20 @@ public void getExt_packageScoped_success(String expr) throws Exception {
"{expr: 'proto.getExt(msg,"
+ " cel.expr.conformance.proto2.Proto2ExtensionScopedMessage.int64_ext) == 1'}")
public void getExt_messageScopedSuccess(String expr) throws Exception {
CelAbstractSyntaxTree ast = CEL_COMPILER.compile(expr).getAst();
boolean result =
(boolean)
CEL_RUNTIME.createProgram(ast).eval(ImmutableMap.of("msg", MESSAGE_SCOPED_EXT_MSG));
boolean result = (boolean) eval(expr, ImmutableMap.of("msg", MESSAGE_SCOPED_EXT_MSG));

assertThat(result).isTrue();
}

@Test
public void getExt_undefinedField_throwsException() {
// This is a type-checking failure
Assume.assumeFalse(isParseOnly);
CelValidationException exception =
assertThrows(
CelValidationException.class,
() ->
CEL_COMPILER
.compile("!proto.getExt(msg, cel.expr.conformance.proto2.undefined_field)")
cel.compile("!proto.getExt(msg, cel.expr.conformance.proto2.undefined_field)")
.getAst());

assertThat(exception)
Expand All @@ -250,9 +243,10 @@ public void getExt_undefinedField_throwsException() {
public void getExt_nonProtoNamespace_success(String expr) throws Exception {
StructTypeReference proto2MessageTypeReference =
StructTypeReference.create("cel.expr.conformance.proto2.TestAllTypes");
CelCompiler celCompiler =
CelCompilerFactory.standardCelCompilerBuilder()
.addLibraries(CelExtensions.protos())
Cel customCel =
runtimeFlavor
.builder()
.addCompilerLibraries(CelExtensions.protos())
.addVar("msg", proto2MessageTypeReference)
.addFunctionDeclarations(
CelFunctionDecl.newFunctionDeclaration(
Expand All @@ -262,29 +256,26 @@ public void getExt_nonProtoNamespace_success(String expr) throws Exception {
SimpleType.DYN,
ImmutableList.of(
proto2MessageTypeReference, SimpleType.STRING, SimpleType.INT))))
.build();
CelRuntime celRuntime =
CelRuntimeFactory.standardCelRuntimeBuilder()
.addFunctionBindings(
CelFunctionBinding.from(
"msg_getExt",
ImmutableList.of(TestAllTypes.class, String.class, Long.class),
(arg) -> {
TestAllTypes msg = (TestAllTypes) arg[0];
String extensionField = (String) arg[1];
FieldDescriptor extensionDescriptor =
msg.getAllFields().keySet().stream()
.filter(fd -> fd.getFullName().equals(extensionField))
.findAny()
.get();
return msg.getField(extensionDescriptor);
}))
CelFunctionBinding.fromOverloads(
"getExt",
CelFunctionBinding.from(
"msg_getExt",
ImmutableList.of(TestAllTypes.class, String.class, Long.class),
(arg) -> {
TestAllTypes msg = (TestAllTypes) arg[0];
String extensionField = (String) arg[1];
FieldDescriptor extensionDescriptor =
msg.getAllFields().keySet().stream()
.filter(fd -> fd.getFullName().equals(extensionField))
.findAny()
.get();
return msg.getField(extensionDescriptor);
})))
.build();

CelAbstractSyntaxTree ast = celCompiler.compile(expr).getAst();
boolean result =
(boolean)
celRuntime.createProgram(ast).eval(ImmutableMap.of("msg", PACKAGE_SCOPED_EXT_MSG));
(boolean) eval(customCel, expr, ImmutableMap.of("msg", PACKAGE_SCOPED_EXT_MSG));

assertThat(result).isTrue();
}
Expand All @@ -293,21 +284,24 @@ public void getExt_nonProtoNamespace_success(String expr) throws Exception {
public void getExt_onAnyPackedExtensionField_success() throws Exception {
ExtensionRegistry extensionRegistry = ExtensionRegistry.newInstance();
TestAllTypesExtensions.registerAllExtensions(extensionRegistry);
Cel cel =
CelFactory.standardCelBuilder()
Cel customCel =
runtimeFlavor
.builder()
// CEL-Internal-2
.addCompilerLibraries(CelExtensions.protos())
.addFileTypes(TestAllTypesExtensions.getDescriptor())
.setExtensionRegistry(extensionRegistry)
.addVar("msg", StructTypeReference.create("cel.expr.conformance.proto2.TestAllTypes"))
.build();
CelAbstractSyntaxTree ast =
cel.compile("proto.getExt(msg, cel.expr.conformance.proto2.int32_ext)").getAst();
Any anyMsg =
Any.pack(
TestAllTypes.newBuilder().setExtension(TestAllTypesExtensions.int32Ext, 1).build());

Long result = (Long) cel.createProgram(ast).eval(ImmutableMap.of("msg", anyMsg));
Long result =
(Long)
eval(
customCel,
"proto.getExt(msg, cel.expr.conformance.proto2.int32_ext)",
ImmutableMap.of("msg", anyMsg));

assertThat(result).isEqualTo(1);
}
Expand Down Expand Up @@ -343,9 +337,18 @@ private enum ParseErrorTestCase {
@Test
public void parseErrors(@TestParameter ParseErrorTestCase testcase) {
CelValidationException e =
assertThrows(
CelValidationException.class, () -> CEL_COMPILER.compile(testcase.expr).getAst());
assertThrows(CelValidationException.class, () -> cel.parse(testcase.expr).getAst());

assertThat(e).hasMessageThat().isEqualTo(testcase.error);
}

private Object eval(String expression, Map<String, ?> variables) throws Exception {
return eval(this.cel, expression, variables);
}

private Object eval(Cel cel, String expression, Map<String, ?> variables) throws Exception {
CelAbstractSyntaxTree ast =
this.isParseOnly ? cel.parse(expression).getAst() : cel.compile(expression).getAst();
return cel.createProgram(ast).eval(variables);
}
}
Loading