diff --git a/extensions/BUILD.bazel b/extensions/BUILD.bazel index c6a029106..28ad1ecb8 100644 --- a/extensions/BUILD.bazel +++ b/extensions/BUILD.bazel @@ -37,6 +37,11 @@ java_library( exports = ["//extensions/src/main/java/dev/cel/extensions:math"], ) +java_library( + name = "network", + exports = ["//extensions/src/main/java/dev/cel/extensions:network"], +) + java_library( name = "optional_library", exports = ["//extensions/src/main/java/dev/cel/extensions:optional_library"], diff --git a/extensions/src/main/java/dev/cel/extensions/BUILD.bazel b/extensions/src/main/java/dev/cel/extensions/BUILD.bazel index 9b897cf84..61f603576 100644 --- a/extensions/src/main/java/dev/cel/extensions/BUILD.bazel +++ b/extensions/src/main/java/dev/cel/extensions/BUILD.bazel @@ -34,6 +34,7 @@ java_library( ":encoders", ":lists", ":math", + ":network", ":optional_library", ":protos", ":regex", @@ -135,6 +136,25 @@ java_library( ], ) +java_library( + name = "network", + srcs = ["CelNetworkExtensions.java"], + deps = [ + ":extension_library", + "//checker:checker_builder", + "//common:compiler_common", + "//common/types", + "//common/types:type_providers", + "//compiler:compiler_builder", + "//java/com/google/common/net", + "//java/com/google/net/base:base-core", + "//runtime", + "//runtime:function_binding", + "@maven//:com_google_errorprone_error_prone_annotations", + "@maven//:com_google_guava_guava", + ], +) + java_library( name = "bindings", srcs = ["CelBindingsExtensions.java"], diff --git a/extensions/src/main/java/dev/cel/extensions/CelExtensions.java b/extensions/src/main/java/dev/cel/extensions/CelExtensions.java index 2d14ed118..f1e9bf519 100644 --- a/extensions/src/main/java/dev/cel/extensions/CelExtensions.java +++ b/extensions/src/main/java/dev/cel/extensions/CelExtensions.java @@ -36,6 +36,7 @@ public final class CelExtensions { private static final CelRegexExtensions REGEX_EXTENSIONS = new CelRegexExtensions(); private static final CelComprehensionsExtensions COMPREHENSIONS_EXTENSIONS = new CelComprehensionsExtensions(); + private static final CelNetworkExtensions NETWORK_EXTENSIONS = new CelNetworkExtensions(); /** * Implementation of optional values. @@ -319,6 +320,18 @@ public static CelComprehensionsExtensions comprehensions() { return COMPREHENSIONS_EXTENSIONS; } + /** + * Extended functions for Network manipulation. + * + *

Refer to README.md for available functions. + * + *

This will include all functions denoted in {@link CelNetworkExtensions.Function}, including + * any future additions. + */ + public static CelNetworkExtensions network() { + return NETWORK_EXTENSIONS; + } + /** * Retrieves all function names used by every extension libraries. * @@ -339,6 +352,8 @@ public static ImmutableSet getAllFunctionNames() { .map(CelListsExtensions.Function::getFunction), stream(CelRegexExtensions.Function.values()) .map(CelRegexExtensions.Function::getFunction), + stream(CelNetworkExtensions.Function.values()) + .map(CelNetworkExtensions.Function::getFunction), stream(CelComprehensionsExtensions.Function.values()) .map(CelComprehensionsExtensions.Function::getFunction)) .collect(toImmutableSet()); @@ -346,31 +361,21 @@ public static ImmutableSet getAllFunctionNames() { public static CelExtensionLibrary getExtensionLibrary( String name, CelOptions options) { - switch (name) { - case "bindings": - return CelBindingsExtensions.library(); - case "encoders": - return CelEncoderExtensions.library(options); - case "lists": - return CelListsExtensions.library(); - case "math": - return CelMathExtensions.library(options); - case "optional": - return CelOptionalLibrary.library(); - case "protos": - return CelProtoExtensions.library(); - case "regex": - return CelRegexExtensions.library(); - case "sets": - return CelSetsExtensions.library(options); - case "strings": - return CelStringExtensions.library(); - case "comprehensions": - return CelComprehensionsExtensions.library(); + return switch (name) { + case "bindings" -> CelBindingsExtensions.library(); + case "encoders" -> CelEncoderExtensions.library(options); + case "lists" -> CelListsExtensions.library(); + case "math" -> CelMathExtensions.library(options); + case "network" -> CelNetworkExtensions.library(); + case "optional" -> CelOptionalLibrary.library(); + case "protos" -> CelProtoExtensions.library(); + case "regex" -> CelRegexExtensions.library(); + case "sets" -> CelSetsExtensions.library(options); + case "strings" -> CelStringExtensions.library(); + case "comprehensions" -> CelComprehensionsExtensions.library(); // TODO: add support for remaining standard extensions - default: - throw new IllegalArgumentException("Unknown standard extension '" + name + "'"); - } + default -> throw new IllegalArgumentException("Unknown standard extension '" + name + "'"); + }; } private CelExtensions() {} diff --git a/extensions/src/main/java/dev/cel/extensions/CelNetworkExtensions.java b/extensions/src/main/java/dev/cel/extensions/CelNetworkExtensions.java new file mode 100644 index 000000000..3ff56ecf0 --- /dev/null +++ b/extensions/src/main/java/dev/cel/extensions/CelNetworkExtensions.java @@ -0,0 +1,635 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package dev.cel.extensions; + +import com.google.common.collect.ImmutableCollection; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableSet; +import com.google.common.net.InetAddresses; +import com.google.errorprone.annotations.Immutable; +import com.google.net.base.CidrAddressBlock; +import dev.cel.checker.CelCheckerBuilder; +import dev.cel.common.CelFunctionDecl; +import dev.cel.common.CelOverloadDecl; +import dev.cel.common.types.CelType; +import dev.cel.common.types.CelTypeProvider; +import dev.cel.common.types.OpaqueType; +import dev.cel.common.types.SimpleType; +import dev.cel.compiler.CelCompilerLibrary; +import dev.cel.runtime.CelFunctionBinding; +import dev.cel.runtime.CelRuntimeBuilder; +import dev.cel.runtime.CelRuntimeLibrary; +import java.net.Inet4Address; +import java.net.Inet6Address; +import java.net.InetAddress; +import java.util.Objects; +import java.util.Optional; +import java.util.Set; + +/** + * CEL Extension for Network functions (IP and CIDR). + * + *

Provides functions for creating, inspecting, and manipulating IP addresses and CIDR blocks, + * maintaining consistency with the CEL Go and C++ network extensions. + */ +@Immutable +public final class CelNetworkExtensions + implements CelCompilerLibrary, CelRuntimeLibrary, CelExtensionLibrary.FeatureSet { + + // Opaque Type Definitions + public static final CelType IP_TYPE = OpaqueType.create("net.IP"); + public static final CelType CIDR_TYPE = OpaqueType.create("net.CIDR"); + + // Package-private constructor + CelNetworkExtensions() { + this.functions = ImmutableSet.copyOf(Function.values()); + } + + // Constructor for creating subsets + CelNetworkExtensions(Set functions) { + this.functions = ImmutableSet.copyOf(functions); + } + + /** Wrapper for InetAddress to represent the net.IP opaque type in CEL. */ + @Immutable + public static class IpAddress { + private final InetAddress address; + + private IpAddress(InetAddress address) { + this.address = address; + } + + public static IpAddress create(String val) { + InetAddress addr = parseStrictIp(val); + return new IpAddress(addr); + } + + public static IpAddress create(InetAddress addr) { + return new IpAddress(addr); + } + + public InetAddress getAddress() { + return address; + } + + @Override + public boolean equals(Object obj) { + if (this == obj) { + return true; + } + if (!(obj instanceof IpAddress ipAddress)) { + return false; + } + return address.equals(ipAddress.address); + } + + @Override + public int hashCode() { + return address.hashCode(); + } + + @Override + public String toString() { + return InetAddresses.toAddrString(address); + } + } + + /** Wrapper for CidrAddressBlock to represent the net.CIDR opaque type in CEL. */ + @Immutable + public static class CidrAddress { + private final CidrAddressBlock block; + private final InetAddress originalHost; // To preserve the non-truncated IP + private final int prefixLength; + + private CidrAddress(CidrAddressBlock block, InetAddress originalHost, int prefixLength) { + this.block = block; + this.originalHost = originalHost; + this.prefixLength = prefixLength; + } + + public static CidrAddress create(String val) { + String[] parts = val.split("/", 2); + if (parts.length != 2) { + throw new IllegalArgumentException("Invalid CIDR string format: " + val); + } + InetAddress host = parseStrictIp(parts[0]); + int prefixLength; + try { + prefixLength = Integer.parseInt(parts[1]); + } catch (NumberFormatException e) { + throw new IllegalArgumentException("Invalid prefix length: " + parts[1], e); + } + + if ((host instanceof Inet4Address && (prefixLength < 0 || prefixLength > 32)) + || (host instanceof Inet6Address && (prefixLength < 0 || prefixLength > 128))) { + throw new IllegalArgumentException("Invalid prefix length for IP type: " + prefixLength); + } + + return new CidrAddress(CidrAddressBlock.create(host, prefixLength), host, prefixLength); + } + + public CidrAddressBlock getBlock() { + return block; + } + + public InetAddress getOriginalHost() { + return originalHost; + } + + public int getPrefixLength() { + return prefixLength; + } + + @Override + public boolean equals(Object obj) { + if (this == obj) { + return true; + } + if (!(obj instanceof CidrAddress that)) { + return false; + } + return prefixLength == that.prefixLength && originalHost.equals(that.originalHost); + } + + @Override + public int hashCode() { + return Objects.hash(originalHost, prefixLength); + } + + @Override + public String toString() { + // Use InetAddresses.toAddrString to ensure canonical IPv6 formatting + return InetAddresses.toAddrString(originalHost) + "/" + prefixLength; + } + } + + // -------------------------------------------------------------------------- + // Strict Parsing Helpers + // -------------------------------------------------------------------------- + private static InetAddress parseStrictIp(String val) { + if (val == null || val.isEmpty()) { + throw new IllegalArgumentException("IP address string cannot be null or empty"); + } + if (val.contains("%")) { + throw new IllegalArgumentException("IP address string must not include a zone index: " + val); + } + String ipStr = val; + if (ipStr.startsWith("[") && ipStr.endsWith("]")) { // Pure IPv6 in brackets + ipStr = ipStr.substring(1, ipStr.length() - 1); + } else if (ipStr.contains(":") && ipStr.lastIndexOf(':') != ipStr.indexOf(':')) { // IPv6 + // Handled by InetAddresses.forString + } else if (ipStr.contains(":")) { // Potentially IPv4 with port or invalid + throw new IllegalArgumentException("Invalid IP address format: " + val); + } + + InetAddress addr; + try { + addr = InetAddresses.forString(ipStr); + } catch (IllegalArgumentException e) { + throw new IllegalArgumentException("Invalid IP address string: " + val, e); + } + + if (addr instanceof Inet6Address) { + if (InetAddresses.isMappedIPv4Address(addr.getHostAddress())) { + throw new IllegalArgumentException("IPv4-mapped IPv6 addresses are not allowed: " + val); + } + } + return addr; + } + + // -------------------------------------------------------------------------- + // Function Enum & Declarations + // -------------------------------------------------------------------------- + /** Enum of all functions in this extension. */ + public enum Function { + IS_IP( + CelFunctionDecl.newFunctionDeclaration( + "isIP", + CelOverloadDecl.newGlobalOverload( + "is_ip_string", + "Checks if a string is a valid IP address", + SimpleType.BOOL, + ImmutableList.of(SimpleType.STRING))), + CelFunctionBinding.from("is_ip_string", String.class, CelNetworkExtensions::isIp)), + STRING_TO_IP( + CelFunctionDecl.newFunctionDeclaration( + "ip", + CelOverloadDecl.newGlobalOverload( + "string_to_ip", + "Converts a string to an IP address object", + IP_TYPE, + ImmutableList.of(SimpleType.STRING))), + CelFunctionBinding.from("string_to_ip", String.class, CelNetworkExtensions::stringToIp)), + IS_CIDR( + CelFunctionDecl.newFunctionDeclaration( + "isCIDR", + CelOverloadDecl.newGlobalOverload( + "is_cidr_string", + "Checks if a string is a valid CIDR notation", + SimpleType.BOOL, + ImmutableList.of(SimpleType.STRING))), + CelFunctionBinding.from("is_cidr_string", String.class, CelNetworkExtensions::isCidr)), + STRING_TO_CIDR( + CelFunctionDecl.newFunctionDeclaration( + "cidr", + CelOverloadDecl.newGlobalOverload( + "string_to_cidr", + "Converts a string to a CIDR object", + CIDR_TYPE, + ImmutableList.of(SimpleType.STRING))), + CelFunctionBinding.from( + "string_to_cidr", String.class, CelNetworkExtensions::stringToCidr)), + IP_IS_CANONICAL( + CelFunctionDecl.newFunctionDeclaration( + "isCanonical", + CelOverloadDecl.newMemberOverload( + "ip_is_canonical_string", + "Checks if a string is a canonical representation of an IP address", + SimpleType.BOOL, + ImmutableList.of(SimpleType.STRING)), + CelOverloadDecl.newMemberOverload( + "ip_is_canonical_ip", + "Checks if an IP address object is a canonical representation", + SimpleType.BOOL, + ImmutableList.of(IP_TYPE))), + CelFunctionBinding.from( + "ip_is_canonical_string", String.class, CelNetworkExtensions::ipIsCanonicalString), + CelFunctionBinding.from( + "ip_is_canonical_ip", IpAddress.class, CelNetworkExtensions::ipIsCanonical)), + IP_FAMILY( + CelFunctionDecl.newFunctionDeclaration( + "family", + CelOverloadDecl.newMemberOverload( + "ip_family", + "Returns the IP family (4 or 6)", + SimpleType.INT, + ImmutableList.of(IP_TYPE))), + CelFunctionBinding.from("ip_family", IpAddress.class, CelNetworkExtensions::ipFamily)), + IP_IS_LOOPBACK( + CelFunctionDecl.newFunctionDeclaration( + "isLoopback", + CelOverloadDecl.newMemberOverload( + "ip_is_loopback", + "Checks if the IP is a loopback address", + SimpleType.BOOL, + ImmutableList.of(IP_TYPE))), + CelFunctionBinding.from( + "ip_is_loopback", IpAddress.class, CelNetworkExtensions::ipIsLoopback)), + IP_IS_GLOBAL_UNICAST( + CelFunctionDecl.newFunctionDeclaration( + "isGlobalUnicast", + CelOverloadDecl.newMemberOverload( + "ip_is_global_unicast", + "Checks if the IP is a global unicast address", + SimpleType.BOOL, + ImmutableList.of(IP_TYPE))), + CelFunctionBinding.from( + "ip_is_global_unicast", IpAddress.class, CelNetworkExtensions::ipIsGlobalUnicast)), + IP_IS_LINK_LOCAL_MULTICAST( + CelFunctionDecl.newFunctionDeclaration( + "isLinkLocalMulticast", + CelOverloadDecl.newMemberOverload( + "ip_is_link_local_multicast", + "Checks if the IP is a link-local multicast address", + SimpleType.BOOL, + ImmutableList.of(IP_TYPE))), + CelFunctionBinding.from( + "ip_is_link_local_multicast", + IpAddress.class, + CelNetworkExtensions::ipIsLinkLocalMulticast)), + IP_IS_LINK_LOCAL_UNICAST( + CelFunctionDecl.newFunctionDeclaration( + "isLinkLocalUnicast", + CelOverloadDecl.newMemberOverload( + "ip_is_link_local_unicast", + "Checks if the IP is a link-local unicast address", + SimpleType.BOOL, + ImmutableList.of(IP_TYPE))), + CelFunctionBinding.from( + "ip_is_link_local_unicast", + IpAddress.class, + CelNetworkExtensions::ipIsLinkLocalUnicast)), + IP_IS_UNSPECIFIED( + CelFunctionDecl.newFunctionDeclaration( + "isUnspecified", + CelOverloadDecl.newMemberOverload( + "ip_is_unspecified", + "Checks if the IP is an unspecified address", + SimpleType.BOOL, + ImmutableList.of(IP_TYPE))), + CelFunctionBinding.from( + "ip_is_unspecified", IpAddress.class, CelNetworkExtensions::ipIsUnspecified)), + IP_TO_STRING( + CelFunctionDecl.newFunctionDeclaration( + "string", + CelOverloadDecl.newMemberOverload( + "ip_to_string", + "Converts the IP address to its string representation", + SimpleType.STRING, + ImmutableList.of(IP_TYPE))), + CelFunctionBinding.from("ip_to_string", IpAddress.class, IpAddress::toString)), + CIDR_IP( + CelFunctionDecl.newFunctionDeclaration( + "ip", + CelOverloadDecl.newMemberOverload( + "cidr_ip", + "Returns the base IP address of the CIDR block", + IP_TYPE, + ImmutableList.of(CIDR_TYPE))), + CelFunctionBinding.from("cidr_ip", CidrAddress.class, CelNetworkExtensions::cidrIp)), + CIDR_CONTAINS_IP( + CelFunctionDecl.newFunctionDeclaration( + "containsIP", + CelOverloadDecl.newMemberOverload( + "cidr_contains_ip_ip", + "Checks if the CIDR block contains the given IP address object", + SimpleType.BOOL, + ImmutableList.of(CIDR_TYPE, IP_TYPE)), + CelOverloadDecl.newMemberOverload( + "cidr_contains_ip_string", + "Checks if the CIDR block contains the given IP address string", + SimpleType.BOOL, + ImmutableList.of(CIDR_TYPE, SimpleType.STRING))), + CelFunctionBinding.from( + "cidr_contains_ip_ip", + CidrAddress.class, + IpAddress.class, + CelNetworkExtensions::cidrContainsIp), + CelFunctionBinding.from( + "cidr_contains_ip_string", + CidrAddress.class, + String.class, + CelNetworkExtensions::cidrContainsIpString)), + CIDR_CONTAINS_CIDR( + CelFunctionDecl.newFunctionDeclaration( + "containsCIDR", + CelOverloadDecl.newMemberOverload( + "cidr_contains_cidr_cidr", + "Checks if the CIDR block contains the other CIDR block object", + SimpleType.BOOL, + ImmutableList.of(CIDR_TYPE, CIDR_TYPE)), + CelOverloadDecl.newMemberOverload( + "cidr_contains_cidr_string", + "Checks if the CIDR block contains the other CIDR block string", + SimpleType.BOOL, + ImmutableList.of(CIDR_TYPE, SimpleType.STRING))), + CelFunctionBinding.from( + "cidr_contains_cidr_cidr", + CidrAddress.class, + CidrAddress.class, + CelNetworkExtensions::cidrContainsCidr), + CelFunctionBinding.from( + "cidr_contains_cidr_string", + CidrAddress.class, + String.class, + CelNetworkExtensions::cidrContainsCidrString)), + CIDR_MASKED( + CelFunctionDecl.newFunctionDeclaration( + "masked", + CelOverloadDecl.newMemberOverload( + "cidr_masked", + "Returns the network address (masked IP) of the CIDR block", + CIDR_TYPE, + ImmutableList.of(CIDR_TYPE))), + CelFunctionBinding.from( + "cidr_masked", CidrAddress.class, CelNetworkExtensions::cidrMasked)), + CIDR_PREFIX_LENGTH( + CelFunctionDecl.newFunctionDeclaration( + "prefixLength", + CelOverloadDecl.newMemberOverload( + "cidr_prefix_length", + "Returns the prefix length of the CIDR block", + SimpleType.INT, + ImmutableList.of(CIDR_TYPE))), + CelFunctionBinding.from( + "cidr_prefix_length", CidrAddress.class, CelNetworkExtensions::cidrPrefixLength)), + CIDR_TO_STRING( + CelFunctionDecl.newFunctionDeclaration( + "string", + CelOverloadDecl.newMemberOverload( + "cidr_to_string", + "Converts the CIDR block to its string representation", + SimpleType.STRING, + ImmutableList.of(CIDR_TYPE))), + CelFunctionBinding.from("cidr_to_string", CidrAddress.class, CidrAddress::toString)); + + private final CelFunctionDecl functionDecl; + private final ImmutableSet bindings; + + public String getFunction() { + return functionDecl.name(); + } + + Function(CelFunctionDecl functionDecl, CelFunctionBinding... bindings) { + this.functionDecl = functionDecl; + this.bindings = ImmutableSet.copyOf(bindings); + } + + public CelFunctionDecl getFunctionDecl() { + return functionDecl; + } + + public ImmutableSet getBindings() { + return bindings; + } + } + + private final ImmutableSet functions; + + // -------------------------------------------------------------------------- + // Library Registration + // -------------------------------------------------------------------------- + private static final CelExtensionLibrary LIBRARY = + new CelExtensionLibrary() { + private final CelNetworkExtensions version0 = new CelNetworkExtensions(); + + @Override + public String name() { + return "network"; + } + + @Override + public ImmutableSet versions() { + return ImmutableSet.of(version0); + } + }; + + public static CelExtensionLibrary library() { + return LIBRARY; + } + + @Override + public int version() { + return 0; + } + + @Override + public ImmutableSet functions() { + return functions.stream().map(Function::getFunctionDecl).collect(ImmutableSet.toImmutableSet()); + } + + // -------------------------------------------------------------------------- + // CelCheckerLibrary Implementation + // -------------------------------------------------------------------------- + @Override + public void setCheckerOptions(CelCheckerBuilder builder) { + for (Function func : functions) { + builder.addFunctionDeclarations(func.getFunctionDecl()); + } + builder.setTypeProvider(new NetworkTypeProvider()); + } + + // -------------------------------------------------------------------------- + // CelRuntimeLibrary Implementation + // -------------------------------------------------------------------------- + @Override + public void setRuntimeOptions(CelRuntimeBuilder builder) { + for (Function func : functions) { + builder.addFunctionBindings(func.getBindings()); + } + } + + // -------------------------------------------------------------------------- + // Function Implementations + // -------------------------------------------------------------------------- + private static boolean isIp(String val) { + try { + parseStrictIp(val); + return true; + } catch (IllegalArgumentException e) { + return false; + } + } + + private static IpAddress stringToIp(String val) { + return IpAddress.create(val); + } + + private static boolean isCidr(String val) { + try { + CidrAddress.create(val); + return true; + } catch (IllegalArgumentException e) { + return false; + } + } + + private static CidrAddress stringToCidr(String val) { + return CidrAddress.create(val); + } + + private static boolean ipIsCanonicalString(String val) { + return ipIsCanonical(IpAddress.create(val)); + } + + private static boolean ipIsCanonical(IpAddress ip) { + InetAddress addr = ip.getAddress(); + // InetAddresses.toAddrString() returns the canonical string form. + // We check if the input string matches this canonical form. + return InetAddresses.toAddrString(addr).equals(ip.toString()); + } + + private static long ipFamily(IpAddress ip) { + return (ip.getAddress() instanceof Inet4Address) ? 4L : 6L; + } + + private static boolean ipIsLoopback(IpAddress ip) { + return ip.getAddress().isLoopbackAddress(); + } + + private static boolean ipIsGlobalUnicast(IpAddress ip) { + InetAddress addr = ip.getAddress(); + return !addr.isAnyLocalAddress() + && !addr.isLoopbackAddress() + && !addr.isLinkLocalAddress() + && !addr.isSiteLocalAddress() + && !addr.isMulticastAddress(); + } + + private static boolean ipIsLinkLocalMulticast(IpAddress ip) { + return ip.getAddress().isMCLinkLocal(); + } + + private static boolean ipIsLinkLocalUnicast(IpAddress ip) { + return ip.getAddress().isLinkLocalAddress() && !ip.getAddress().isMulticastAddress(); + } + + private static boolean ipIsUnspecified(IpAddress ip) { + return ip.getAddress().isAnyLocalAddress(); + } + + private static IpAddress cidrIp(CidrAddress cidr) { + return IpAddress.create(cidr.getOriginalHost()); + } + + private static boolean cidrContainsIp(CidrAddress cidr, IpAddress ip) { + return cidr.getBlock().contains(ip.getAddress()); + } + + private static boolean cidrContainsIpString(CidrAddress cidr, String ipStr) { + try { + return cidr.getBlock().contains(parseStrictIp(ipStr)); + } catch (IllegalArgumentException e) { + throw new IllegalArgumentException("Invalid IP string in containsIP", e); + } + } + + private static boolean cidrContainsCidr(CidrAddress parent, CidrAddress child) { + return parent.getBlock().contains(child.getBlock()); + } + + private static boolean cidrContainsCidrString(CidrAddress parent, String childStr) { + try { + CidrAddress child = CidrAddress.create(childStr); + return parent.getBlock().contains(child.getBlock()); + } catch (IllegalArgumentException e) { + throw new IllegalArgumentException("Invalid CIDR string in containsCIDR", e); + } + } + + private static CidrAddress cidrMasked(CidrAddress cidr) { + CidrAddressBlock maskedBlock = cidr.getBlock(); + return new CidrAddress(maskedBlock, maskedBlock.getInetAddress(), cidr.getPrefixLength()); + } + + private static long cidrPrefixLength(CidrAddress cidr) { + return (long) cidr.getPrefixLength(); + } + + // -------------------------------------------------------------------------- + // Custom Type Provider + // -------------------------------------------------------------------------- + @Immutable + private static class NetworkTypeProvider implements CelTypeProvider { + private static final ImmutableSet SUPPORTED_TYPES = + ImmutableSet.of(IP_TYPE, CIDR_TYPE); + + @Override + public Optional findType(String typeName) { + if (typeName.equals(IP_TYPE.name())) { + return Optional.of(IP_TYPE); + } + if (typeName.equals(CIDR_TYPE.name())) { + return Optional.of(CIDR_TYPE); + } + return Optional.empty(); + } + + @Override + public ImmutableCollection types() { + return SUPPORTED_TYPES; + } + } +} diff --git a/extensions/src/main/java/dev/cel/extensions/README.md b/extensions/src/main/java/dev/cel/extensions/README.md index 10c5217e8..82f905158 100644 --- a/extensions/src/main/java/dev/cel/extensions/README.md +++ b/extensions/src/main/java/dev/cel/extensions/README.md @@ -351,6 +351,228 @@ Examples: math.sqrt(4) // returns 2.0 math.sqrt(-4) // returns NaN +## Network + +The Network extension provides types and functions for working with IP addresses +and CIDR ranges. It introduces two opaque types: `net.IP` and `net.CIDR`. + +**Types** + +* `net.IP`: Represents an IP address (either IPv4 or IPv6). +* `net.CIDR`: Represents a CIDR range, retaining the original host and prefix length. + +**Functions** + +### isIP + +Checks if a string is a valid IP address (IPv4 or IPv6). Excludes addresses with +ports or zone indices. + + isIP() -> + +Examples: + + isIP("192.168.0.1") // returns true + isIP("2001:db8::1") // returns true + isIP("192.168.0.256") // returns false + isIP("1.2.3.4:80") // returns false + +### ip + +Converts a string to a `net.IP` object. Throws an error if the string is not a +valid IP address. + + ip() -> + +Examples: + + ip("127.0.0.1") // returns net.IP object + ip("2001:db8::1") // returns net.IP object + ip("invalid") // error + +### isCIDR + +Checks if a string is a valid CIDR notation (e.g., "192.168.0.0/24"). + + isCIDR() -> + +Examples: + + isCIDR("192.168.0.0/24") // returns true + isCIDR("2001:db8::/32") // returns true + isCIDR("192.168.0.0/33") // returns false + isCIDR("192.168.0.0") // returns false + +### cidr + +Converts a string in CIDR notation to a `net.CIDR` object. Throws an error if +the string is not valid CIDR notation. + + cidr() -> + +Examples: + + cidr("192.168.1.0/24") // returns net.CIDR object + cidr("2001:db8::/48") // returns net.CIDR object + cidr("192.168.1.0/33") // error + +### ip.isCanonical + +Checks if a string is the canonical representation of an IP address. + + ip.isCanonical() -> + +Examples: + + ip.isCanonical("192.168.0.1") // returns true + ip.isCanonical("2001:db8::1") // returns true + ip.isCanonical("2001:db8:0:0:0:0:0:1") // returns false (not canonical) + ip.isCanonical("127.00.0.1") // returns false (not canonical) + +### family + +Returns the IP family of a `net.IP` object as an integer (4 for IPv4, 6 for +IPv6). + + .family() -> + +Examples: + + ip("192.168.0.1").family() // returns 4 + ip("2001:db8::1").family() // returns 6 + +### isLoopback + +Checks if the `net.IP` object is a loopback address. + + .isLoopback() -> + +Examples: + + ip("127.0.0.1").isLoopback() // returns true + ip("::1").isLoopback() // returns true + ip("8.8.8.8").isLoopback() // returns false + +### isGlobalUnicast + +Checks if the `net.IP` object is a global unicast address. + + .isGlobalUnicast() -> + +Examples: + + ip("8.8.8.8").isGlobalUnicast() // returns true + ip("192.168.0.1").isGlobalUnicast() // returns false (private) + ip("127.0.0.1").isGlobalUnicast() // returns false (loopback) + +### isLinkLocalMulticast + +Checks if the `net.IP` object is a link-local multicast address. + + .isLinkLocalMulticast() -> + +Examples: + + ip("ff02::1").isLinkLocalMulticast() // returns true + ip("224.0.0.1").isLinkLocalMulticast() // returns false + +### isLinkLocalUnicast + +Checks if the `net.IP` object is a link-local unicast address. + + .isLinkLocalUnicast() -> + +Examples: + + ip("169.254.0.1").isLinkLocalUnicast() // returns true + ip("fe80::1").isLinkLocalUnicast() // returns true + ip("192.168.0.1").isLinkLocalUnicast() // returns false + +### isUnspecified + +Checks if the `net.IP` object is an unspecified address +(e.g., "0.0.0.0" or "::"). + + .isUnspecified() -> + +Examples: + + ip("0.0.0.0").isUnspecified() // returns true + ip("::").isUnspecified() // returns true + ip("1.2.3.4").isUnspecified() // returns false + +### string + +Converts a `net.IP` or `net.CIDR` object to its string representation. + + .string() -> + .string() -> + +Examples: + + ip("1.2.3.4").string() // returns "1.2.3.4" + cidr("10.0.0.0/8").string() // returns "10.0.0.0/8" + cidr("10.0.0.1/8").string() // returns "10.0.0.1/8" + +### ip (CIDR member) + +Returns the original base `net.IP` object from a `net.CIDR` object. + + .ip() -> + +Example: + + cidr("192.168.1.5/24").ip() // returns ip("192.168.1.5") + +### containsIP + +Checks if a `net.CIDR` range contains the given IP address (either as a `net.IP` +object or a string). + + .containsIP() -> + .containsIP() -> + +Examples: + + cidr("10.0.0.0/8").containsIP(ip("10.1.2.3")) // returns true + cidr("10.0.0.0/8").containsIP("10.1.2.3") // returns true + cidr("10.0.0.0/8").containsIP("11.0.0.1") // returns false + +### containsCIDR + +Checks if a `net.CIDR` range completely contains another CIDR range (either as a +`net.CIDR` object or a string). + + .containsCIDR() -> + .containsCIDR() -> + +Examples: + + cidr("10.0.0.0/8").containsCIDR(cidr("10.1.0.0/16")) // returns true + cidr("10.0.0.0/8").containsCIDR("10.1.0.0/16") // returns true + cidr("10.1.0.0/16").containsCIDR("10.0.0.0/8") // returns false + +### masked + +Returns a new `net.CIDR` object representing the network range with the host +bits masked off. + + .masked() -> + +Example: + + cidr("192.168.1.5/24").masked() // returns cidr("192.168.1.0/24") + +### prefixLength + +Returns the prefix length of the `net.CIDR` object. + + .prefixLength() -> + +Example: + + cidr("192.168.1.0/24").prefixLength() // returns 24 + ## Protos Extended macros and functions for proto manipulation. @@ -392,8 +614,8 @@ zero-based. ### CharAt -Returns the character at the given position. If the position is negative, or greater than -the length of the string, the function will produce an error. +Returns the character at the given position. If the position is negative, or +greater than the length of the string, the function will produce an error. .charAt() -> @@ -405,11 +627,12 @@ Examples: ### IndexOf -Returns the integer index of the first occurrence of the search string. If the search string is -not found the function returns -1. +Returns the integer index of the first occurrence of the search string. If the +search string is not found the function returns -1. -The function also accepts an optional offset from which to begin the substring search. If the -substring is the empty string, the index where the search starts is returned (zero or custom). +The function also accepts an optional offset from which to begin the substring +search. If the substring is the empty string, the index where the search starts +is returned (zero or custom). .indexOf() -> .indexOf(, ) -> @@ -427,7 +650,8 @@ Examples: Returns a new string where the elements of string list are concatenated. -The function also accepts an optional separator which is placed between elements in the resulting string. +The function also accepts an optional separator which is placed between elements +in the resulting string. >.join() -> >.join() -> @@ -495,8 +719,8 @@ Examples: ### Split -Returns a mutable list of strings split from the input by the given separator. The -function accepts an optional argument specifying a limit on the number of +Returns a mutable list of strings split from the input by the given separator. +The function accepts an optional argument specifying a limit on the number of substrings produced by the split. When the split limit is 0, the result is an empty list. When the limit is 1, @@ -682,7 +906,8 @@ Examples: Introduced at version: 1 -Flattens a list by one level, or to the specified level. Providing a negative level will error. +Flattens a list by one level, or to the specified level. Providing a negative +level will error. Examples: diff --git a/extensions/src/test/java/dev/cel/extensions/BUILD.bazel b/extensions/src/test/java/dev/cel/extensions/BUILD.bazel index 45d48aeca..b1175f821 100644 --- a/extensions/src/test/java/dev/cel/extensions/BUILD.bazel +++ b/extensions/src/test/java/dev/cel/extensions/BUILD.bazel @@ -14,7 +14,6 @@ java_library( "//common:compiler_common", "//common:container", "//common:options", - "//common/internal:proto_time_utils", "//common/types", "//common/types:type_providers", "//common/values", @@ -25,6 +24,7 @@ java_library( "//extensions:extension_library", "//extensions:lite_extensions", "//extensions:math", + "//extensions:network", "//extensions:optional_library", "//extensions:sets", "//extensions:sets_function", diff --git a/extensions/src/test/java/dev/cel/extensions/CelExtensionsTest.java b/extensions/src/test/java/dev/cel/extensions/CelExtensionsTest.java index 61922f70f..599c5f481 100644 --- a/extensions/src/test/java/dev/cel/extensions/CelExtensionsTest.java +++ b/extensions/src/test/java/dev/cel/extensions/CelExtensionsTest.java @@ -188,6 +188,22 @@ public void getAllFunctionNames() { "regex.replace", "regex.extract", "regex.extractAll", + "isIP", + "ip", + "isCIDR", + "cidr", + "isCanonical", + "family", + "isLoopback", + "isGlobalUnicast", + "isLinkLocalMulticast", + "isLinkLocalUnicast", + "isUnspecified", + "string", + "containsIP", + "containsCIDR", + "masked", + "prefixLength", "cel.@mapInsert"); } } diff --git a/extensions/src/test/java/dev/cel/extensions/CelNetworkExtensionsTest.java b/extensions/src/test/java/dev/cel/extensions/CelNetworkExtensionsTest.java new file mode 100644 index 000000000..516a7212f --- /dev/null +++ b/extensions/src/test/java/dev/cel/extensions/CelNetworkExtensionsTest.java @@ -0,0 +1,264 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package dev.cel.extensions; + +import static com.google.common.truth.Truth.assertThat; +import static org.junit.Assert.assertThrows; + +import com.google.testing.junit.testparameterinjector.TestParameterInjector; +import com.google.testing.junit.testparameterinjector.TestParameters; +import dev.cel.common.CelAbstractSyntaxTree; +import dev.cel.common.CelValidationException; +import dev.cel.compiler.CelCompiler; +import dev.cel.compiler.CelCompilerFactory; +import dev.cel.runtime.CelEvaluationException; +import dev.cel.runtime.CelRuntime; +import dev.cel.runtime.CelRuntimeFactory; +import org.junit.Test; +import org.junit.runner.RunWith; + +@RunWith(TestParameterInjector.class) +public final class CelNetworkExtensionsTest { + + private static final CelCompiler COMPILER = + CelCompilerFactory.standardCelCompilerBuilder() + .addLibraries(new CelNetworkExtensions()) + .build(); + + private static final CelRuntime RUNTIME = + CelRuntimeFactory.standardCelRuntimeBuilder() + .addLibraries(new CelNetworkExtensions()) + .build(); + + private Object eval(String expression) throws CelEvaluationException, CelValidationException { + CelAbstractSyntaxTree ast = COMPILER.compile(expression).getAst(); + return RUNTIME.createProgram(ast).eval(); + } + + // --- Global Checks (isIP, isCIDR) --- + @Test + @TestParameters({ + "{expr: 'isIP(\"1.2.3.4\")', expected: true}", + "{expr: 'isIP(\"2001:db8::1\")', expected: true}", + "{expr: 'isIP(\"not.an.ip\")', expected: false}", + "{expr: 'isIP(\"127.0.0.1:80\")', expected: false}", + "{expr: 'isIP(\"[2001:db8::1]:80\")', expected: false}", + "{expr: 'isIP(\"1.2.3.4%\")', expected: false}", + }) + public void isIP_testCases(String expr, boolean expected) throws Exception { + assertThat(eval(expr)).isEqualTo(expected); + } + + @Test + @TestParameters({ + "{expr: 'isCIDR(\"10.0.0.0/8\")', expected: true}", + "{expr: 'isCIDR(\"10.0.0.1/8\")', expected: true}", + "{expr: 'isCIDR(\"2001:db8::/32\")', expected: true}", + "{expr: 'isCIDR(\"10.0.0.0/33\")', expected: false}", + "{expr: 'isCIDR(\"10.0.0.0/999\")', expected: false}", + "{expr: 'isCIDR(\"10.0.0.0\")', expected: false}", + "{expr: 'isCIDR(\"invalid\")', expected: false}", + }) + public void isCIDR_testCases(String expr, boolean expected) throws Exception { + assertThat(eval(expr)).isEqualTo(expected); + } + + // --- IP Constructors & Equality --- + @Test + @TestParameters({ + "{expr: 'ip(\"127.0.0.1\") == ip(\"127.0.0.1\")', expected: true}", + "{expr: 'ip(\"127.0.0.1\") == ip(\"1.2.3.4\")', expected: false}", + "{expr: 'ip(\"2001:db8::1\") == ip(\"2001:DB8::1\")', expected: true}", + "{expr: 'ip(\"2001:db8::1\") == ip(\"2001:db8::2\")', expected: false}", + }) + public void ip_equality(String expr, boolean expected) throws Exception { + assertThat(eval(expr)).isEqualTo(expected); + } + + // --- String Conversion --- + @Test + @TestParameters({ + "{expr: 'ip(\"1.2.3.4\").string()', expected: \"1.2.3.4\"}", + "{expr: 'ip(\"2001:db8::1\").string()', expected: \"2001:db8::1\"}", + "{expr: 'cidr(\"10.0.0.0/8\").string()', expected: \"10.0.0.0/8\"}", + "{expr: 'cidr(\"10.0.0.1/8\").string()', expected: \"10.0.0.1/8\"}", + "{expr: 'cidr(\"::1/128\").string()', expected: \"::1/128\"}", + }) + public void string_conversion(String expr, String expected) throws Exception { + assertThat(eval(expr)).isEqualTo(expected); + } + + // --- Family --- + @Test + @TestParameters({ + "{expr: 'ip(\"127.0.0.1\").family()', expected: 4}", + "{expr: 'ip(\"::1\").family()', expected: 6}", + }) + public void ip_family(String expr, long expected) throws Exception { + assertThat(eval(expr)).isEqualTo(expected); + } + + // --- Canonicalization --- + @Test + @TestParameters({ + "{expr: 'ip(\"127.0.0.1\").isCanonical()', expected: true}", + "{expr: 'ip(\"2001:db8::1\").isCanonical()', expected: true}", + "{expr: 'ip(ip(\"2001:DB8::1\").string()).isCanonical()', expected: true}", + "{expr: 'ip(ip(\"2001:db8:0:0:0:0:0:1\").string()).isCanonical()', expected:" + " true}", + }) + public void ip_isCanonical(String expr, boolean expected) throws Exception { + assertThat(eval(expr)).isEqualTo(expected); + } + + // --- IP Types (Loopback, Unspecified, etc) --- + @Test + @TestParameters({ + "{expr: 'ip(\"127.0.0.1\").isLoopback()', expected: true}", + "{expr: 'ip(\"::1\").isLoopback()', expected: true}", + "{expr: 'ip(\"192.168.0.1\").isLoopback()', expected: false}", + "{expr: 'ip(\"0.0.0.0\").isUnspecified()', expected: true}", + "{expr: 'ip(\"::\").isUnspecified()', expected: true}", + "{expr: 'ip(\"1.2.3.4\").isUnspecified()', expected: false}", + "{expr: 'ip(\"8.8.8.8\").isGlobalUnicast()', expected: true}", + "{expr: 'ip(\"192.168.0.1\").isGlobalUnicast()', expected: false}", // Private + "{expr: 'ip(\"127.0.0.1\").isGlobalUnicast()', expected: false}", // Loopback + "{expr: 'ip(\"ff02::1\").isLinkLocalMulticast()', expected: true}", + "{expr: 'ip(\"224.0.0.1\").isLinkLocalMulticast()', expected: true}", + "{expr: 'ip(\"224.0.1.1\").isLinkLocalMulticast()', expected: false}", + "{expr: 'ip(\"fe80::1\").isLinkLocalUnicast()', expected: true}", + "{expr: 'ip(\"169.254.0.1\").isLinkLocalUnicast()', expected: true}", + }) + public void ip_types(String expr, boolean expected) throws Exception { + assertThat(eval(expr)).isEqualTo(expected); + } + + // --- CIDR Accessors --- + @Test + @TestParameters({ + "{expr: 'cidr(\"192.168.0.0/24\").prefixLength()', expected: 24}", + "{expr: 'cidr(\"2001:db8::/32\").prefixLength()', expected: 32}", + }) + public void cidr_prefixLength(String expr, long expected) throws Exception { + assertThat(eval(expr)).isEqualTo(expected); + } + + @Test + @TestParameters({ + "{expr: 'cidr(\"192.168.0.0/24\").ip() == ip(\"192.168.0.0\")', expected: true}", + "{expr: 'cidr(\"192.168.1.5/24\").ip() == ip(\"192.168.1.5\")', expected: true}", + "{expr: 'cidr(\"2001:db8::1/128\").ip() == ip(\"2001:db8::1\")', expected: true}", + }) + public void cidr_ip_extraction(String expr, boolean expected) throws Exception { + assertThat(eval(expr)).isEqualTo(expected); + } + + @Test + @TestParameters({ + "{expr: 'cidr(\"192.168.1.5/24\").masked().string()', expected: \"192.168.1.0/24\"}", + "{expr: 'cidr(\"192.168.1.0/24\").masked().string()', expected: \"192.168.1.0/24\"}", + "{expr: 'cidr(\"2001:db8:abcd:1234::1/64\").masked().string()', expected:" + + " \"2001:db8:abcd:1234::/64\"}", + }) + public void cidr_masked(String expr, String expected) throws Exception { + assertThat(eval(expr)).isEqualTo(expected); + } + + // --- Containment (IP in CIDR) --- + @Test + @TestParameters({ + "{expr: 'cidr(\"10.0.0.0/8\").containsIP(ip(\"10.1.2.3\"))', expected: true}", + "{expr: 'cidr(\"10.0.0.0/8\").containsIP(ip(\"11.0.0.0\"))', expected: false}", + "{expr: 'cidr(\"10.0.0.0/8\").containsIP(\"10.255.255.255\")', expected: true}", + "{expr: 'cidr(\"2001:db8::/32\").containsIP(\"2001:db8:ffff::1\")', expected: true}", + "{expr: 'cidr(\"2001:db8::/32\").containsIP(\"2001:db9::\")', expected: false}", + }) + public void cidr_containsIP(String expr, boolean expected) throws Exception { + assertThat(eval(expr)).isEqualTo(expected); + } + + // --- Containment (CIDR in CIDR) --- + @Test + @TestParameters({ + "{expr: 'cidr(\"10.0.0.0/8\").containsCIDR(cidr(\"10.1.0.0/16\"))', expected: true}", + "{expr: 'cidr(\"10.1.0.0/16\").containsCIDR(cidr(\"10.0.0.0/8\"))', expected: false}", + "{expr: 'cidr(\"10.0.0.0/8\").containsCIDR(\"10.0.0.0/8\")', expected: true}", + "{expr: 'cidr(\"2001:db8::/32\").containsCIDR(\"2001:db8:abcd::/48\")', expected: true}", + "{expr: 'cidr(\"10.0.0.0/8\").containsCIDR(\"11.0.0.0/8\")', expected: false}", + "{expr: 'cidr(\"192.168.1.0/24\").containsCIDR(\"192.168.1.128/25\")', expected: true}", + "{expr: 'cidr(\"192.168.1.128/25\").containsCIDR(\"192.168.1.0/24\")', expected: false}", + }) + public void cidr_containsCIDR(String expr, boolean expected) throws Exception { + assertThat(eval(expr)).isEqualTo(expected); + } + + // --- Runtime Errors --- + @Test + public void err_ip_invalid() { + CelEvaluationException e = + assertThrows(CelEvaluationException.class, () -> eval("ip('999.999.999.999')")); + assertThat(e).hasCauseThat().isInstanceOf(IllegalArgumentException.class); + assertThat(e).hasCauseThat().hasMessageThat().contains("Invalid IP address string"); + } + + @Test + public void err_cidr_invalidFormat() { + CelEvaluationException e = + assertThrows(CelEvaluationException.class, () -> eval("cidr('1.2.3.4')")); + assertThat(e).hasCauseThat().isInstanceOf(IllegalArgumentException.class); + assertThat(e).hasCauseThat().hasMessageThat().contains("Invalid CIDR string format"); + } + + @Test + public void err_cidr_invalidMask() { + CelEvaluationException e = + assertThrows(CelEvaluationException.class, () -> eval("cidr('10.0.0.0/999')")); + assertThat(e).hasCauseThat().isInstanceOf(IllegalArgumentException.class); + assertThat(e).hasCauseThat().hasMessageThat().contains("Invalid prefix length"); + } + + @Test + public void err_containsIP_stringInvalid() { + CelEvaluationException e = + assertThrows( + CelEvaluationException.class, () -> eval("cidr('10.0.0.0/8').containsIP('not-an-ip')")); + assertThat(e).hasCauseThat().isInstanceOf(IllegalArgumentException.class); + assertThat(e).hasCauseThat().hasMessageThat().contains("Invalid IP string in containsIP"); + } + + @Test + public void err_containsCIDR_stringInvalid() { + CelEvaluationException e = + assertThrows( + CelEvaluationException.class, + () -> eval("cidr('10.0.0.0/8').containsCIDR('not-a-cidr')")); + assertThat(e).hasCauseThat().isInstanceOf(IllegalArgumentException.class); + assertThat(e).hasCauseThat().hasMessageThat().contains("Invalid CIDR string in containsCIDR"); + } + + @Test + @TestParameters({ + "{ip: '192.168.1.1', expected: false}", + "{ip: '127.0.0.1', expected: false}", + "{ip: '0.0.0.0', expected: false}", + "{ip: '169.254.0.1', expected: true}", + "{ip: '::1', expected: false}", + "{ip: '2001:db8::1', expected: false}", + "{ip: 'fe80::1', expected: true}", + "{ip: 'ff02::1', expected: false}", // Multicast + }) + public void ip_isLinkLocalUnicast_testCases(String ip, boolean expected) throws Exception { + assertThat(eval(String.format("ip('%s').isLinkLocalUnicast()", ip))).isEqualTo(expected); + } +}