diff --git a/core/src/main/java/io/substrait/extension/SimpleExtension.java b/core/src/main/java/io/substrait/extension/SimpleExtension.java index ab9058e77..e877f4572 100644 --- a/core/src/main/java/io/substrait/extension/SimpleExtension.java +++ b/core/src/main/java/io/substrait/extension/SimpleExtension.java @@ -22,6 +22,8 @@ import java.io.IOException; import java.io.InputStream; import java.io.UncheckedIOException; +import java.util.Collections; +import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.Optional; @@ -276,6 +278,8 @@ public String description() { public abstract Map options(); + public abstract Optional> metadata(); + public List requiredArguments() { return requiredArgsSupplier.get(); } @@ -381,10 +385,12 @@ public abstract static class ScalarFunction { @Nullable public abstract String description(); + public abstract Optional> metadata(); + public abstract List impls(); public Stream resolve(String urn) { - return impls().stream().map(f -> f.resolve(urn, name(), description())); + return impls().stream().map(f -> f.resolve(urn, name(), description(), metadata())); } } @@ -392,7 +398,8 @@ public Stream resolve(String urn) { @JsonSerialize(as = ImmutableSimpleExtension.ScalarFunctionVariant.class) @Value.Immutable public abstract static class ScalarFunctionVariant extends Function { - public ScalarFunctionVariant resolve(String urn, String name, String description) { + public ScalarFunctionVariant resolve( + String urn, String name, String description, Optional> metadata) { return ImmutableSimpleExtension.ScalarFunctionVariant.builder() .urn(urn) .name(name) @@ -400,6 +407,7 @@ public ScalarFunctionVariant resolve(String urn, String name, String description .nullability(nullability()) .args(args()) .options(options()) + .metadata(metadata) .ordered(ordered()) .variadic(variadic()) .returnType(returnType()) @@ -417,10 +425,12 @@ public abstract static class AggregateFunction { @Nullable public abstract String description(); + public abstract Optional> metadata(); + public abstract List impls(); public Stream resolve(String urn) { - return impls().stream().map(f -> f.resolve(urn, name(), description())); + return impls().stream().map(f -> f.resolve(urn, name(), description(), metadata())); } } @@ -434,10 +444,12 @@ public abstract static class WindowFunction { @Nullable public abstract String description(); + public abstract Optional> metadata(); + public abstract List impls(); public Stream resolve(String urn) { - return impls().stream().map(f -> f.resolve(urn, name(), description())); + return impls().stream().map(f -> f.resolve(urn, name(), description(), metadata())); } public static ImmutableSimpleExtension.WindowFunction.Builder builder() { @@ -463,7 +475,8 @@ public String toString() { @Nullable public abstract TypeExpression intermediate(); - AggregateFunctionVariant resolve(String urn, String name, String description) { + AggregateFunctionVariant resolve( + String urn, String name, String description, Optional> metadata) { return ImmutableSimpleExtension.AggregateFunctionVariant.builder() .urn(urn) .name(name) @@ -471,6 +484,7 @@ AggregateFunctionVariant resolve(String urn, String name, String description) { .nullability(nullability()) .args(args()) .options(options()) + .metadata(metadata) .ordered(ordered()) .variadic(variadic()) .decomposability(decomposability()) @@ -505,7 +519,8 @@ public String toString() { return super.toString(); } - WindowFunctionVariant resolve(String urn, String name, String description) { + WindowFunctionVariant resolve( + String urn, String name, String description, Optional> metadata) { return ImmutableSimpleExtension.WindowFunctionVariant.builder() .urn(urn) .name(name) @@ -513,6 +528,7 @@ WindowFunctionVariant resolve(String urn, String name, String description) { .nullability(nullability()) .args(args()) .options(options()) + .metadata(metadata) .ordered(ordered()) .variadic(variadic()) .decomposability(decomposability()) @@ -549,6 +565,8 @@ public abstract static class Type { protected abstract Optional variadic(); + public abstract Optional> metadata(); + public TypeAnchor getAnchor() { return anchorSupplier.get(); } @@ -574,6 +592,9 @@ public abstract static class ExtensionSignatures { @JsonProperty("window_functions") public abstract List windows(); + @JsonProperty("metadata") + public abstract Optional> metadata(); + public int size() { return (types() == null ? 0 : types().size()) + (scalars() == null ? 0 : scalars().size()) @@ -643,6 +664,11 @@ BidiMap uriUrnMap() { return new BidiMap<>(); } + @Value.Default + public Map> extensionMetadata() { + return Collections.emptyMap(); + } + public abstract List types(); public abstract List scalarFunctions(); @@ -655,6 +681,16 @@ public static ImmutableSimpleExtension.ExtensionCollection.Builder builder() { return ImmutableSimpleExtension.ExtensionCollection.builder(); } + /** + * Gets the top-level metadata for a specific extension by URN. + * + * @param urn The URN of the extension + * @return The metadata map if present, empty Optional otherwise + */ + public Optional> getExtensionMetadata(String urn) { + return Optional.ofNullable(extensionMetadata().get(urn)); + } + public Type getType(TypeAnchor anchor) { Type type = typeLookup.get().get(anchor); if (type != null) { @@ -744,6 +780,10 @@ public ExtensionCollection merge(ExtensionCollection extensionCollection) { mergedUriUrnMap.merge(uriUrnMap()); mergedUriUrnMap.merge(extensionCollection.uriUrnMap()); + Map> mergedExtensionMetadata = new HashMap<>(); + mergedExtensionMetadata.putAll(extensionMetadata()); + mergedExtensionMetadata.putAll(extensionCollection.extensionMetadata()); + return ImmutableSimpleExtension.ExtensionCollection.builder() .addAllAggregateFunctions(aggregateFunctions()) .addAllAggregateFunctions(extensionCollection.aggregateFunctions()) @@ -754,6 +794,7 @@ public ExtensionCollection merge(ExtensionCollection extensionCollection) { .addAllTypes(types()) .addAllTypes(extensionCollection.types()) .uriUrnMap(mergedUriUrnMap) + .extensionMetadata(mergedExtensionMetadata) .build(); } } @@ -859,6 +900,9 @@ public static ExtensionCollection buildExtensionCollection( BidiMap uriUrnMap = new BidiMap<>(); uriUrnMap.put(uri, urn); + Map> extMetadata = new HashMap<>(); + extensionSignatures.metadata().ifPresent(m -> extMetadata.put(urn, m)); + ImmutableSimpleExtension.ExtensionCollection collection = ImmutableSimpleExtension.ExtensionCollection.builder() .scalarFunctions(scalarFunctionVariants) @@ -866,6 +910,7 @@ public static ExtensionCollection buildExtensionCollection( .windowFunctions(allWindowFunctionVariants) .addAllTypes(extensionSignatures.types()) .uriUrnMap(uriUrnMap) + .extensionMetadata(extMetadata) .build(); LOGGER.atDebug().log( diff --git a/core/src/test/java/io/substrait/extension/MetadataExtensionTest.java b/core/src/test/java/io/substrait/extension/MetadataExtensionTest.java new file mode 100644 index 000000000..924c73837 --- /dev/null +++ b/core/src/test/java/io/substrait/extension/MetadataExtensionTest.java @@ -0,0 +1,137 @@ +package io.substrait.extension; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import io.substrait.TestBase; +import java.io.IOException; +import java.io.UncheckedIOException; +import java.util.Map; +import java.util.Optional; +import org.junit.jupiter.api.Test; + +/** + * Verifies that metadata can be read from extension YAML files at multiple levels: + * + *
    + *
  • Extension-level metadata (top-level) + *
  • Type-level metadata + *
  • Function-level metadata (scalar, aggregate, window) + *
+ */ +class MetadataExtensionTest extends TestBase { + + static final String URN = "extension:test:metadata_extensions"; + static final SimpleExtension.ExtensionCollection METADATA_EXTENSION; + + static { + try { + String extensionStr = asString("extensions/metadata_extensions.yaml"); + METADATA_EXTENSION = SimpleExtension.load(URN, extensionStr); + } catch (IOException e) { + throw new UncheckedIOException(e); + } + } + + MetadataExtensionTest() { + super(METADATA_EXTENSION); + } + + @Test + void testExtensionLevelMetadata() { + Optional> metadata = extensions.getExtensionMetadata(URN); + assertTrue(metadata.isPresent(), "Extension metadata should be present"); + + Map meta = metadata.get(); + assertEquals("1.0", meta.get("version")); + assertEquals("test-team", meta.get("author")); + + // Test nested metadata + @SuppressWarnings("unchecked") + Map customData = (Map) meta.get("custom_data"); + assertEquals(true, customData.get("nested_value")); + assertEquals(42, customData.get("numeric_value")); + } + + @Test + void testExtensionLevelMetadataMissing() { + Optional> metadata = + extensions.getExtensionMetadata("extension:nonexistent:urn"); + assertFalse(metadata.isPresent(), "Metadata for non-existent URN should be empty"); + } + + @Test + void testTypeMetadata() { + SimpleExtension.TypeAnchor anchor = SimpleExtension.TypeAnchor.of(URN, "metadataType"); + SimpleExtension.Type type = extensions.getType(anchor); + + Optional> metadata = type.metadata(); + assertTrue(metadata.isPresent(), "Type metadata should be present"); + + Map meta = metadata.get(); + assertEquals("custom-type-metadata", meta.get("type_info")); + assertEquals("user-defined", meta.get("category")); + } + + @Test + void testScalarFunctionMetadata() { + SimpleExtension.FunctionAnchor anchor = + SimpleExtension.FunctionAnchor.of(URN, "metadataScalar:i64"); + SimpleExtension.ScalarFunctionVariant fn = extensions.getScalarFunction(anchor); + + Optional> metadata = fn.metadata(); + assertTrue(metadata.isPresent(), "Scalar function metadata should be present"); + + Map meta = metadata.get(); + assertEquals("vectorized", meta.get("perf_hint")); + assertEquals(1, meta.get("cost")); + } + + @Test + void testAggregateFunctionMetadata() { + SimpleExtension.FunctionAnchor anchor = + SimpleExtension.FunctionAnchor.of(URN, "metadataAggregate:i64"); + SimpleExtension.AggregateFunctionVariant fn = extensions.getAggregateFunction(anchor); + + Optional> metadata = fn.metadata(); + assertTrue(metadata.isPresent(), "Aggregate function metadata should be present"); + + Map meta = metadata.get(); + assertEquals("incremental", meta.get("agg_info")); + } + + @Test + void testWindowFunctionMetadata() { + SimpleExtension.FunctionAnchor anchor = + SimpleExtension.FunctionAnchor.of(URN, "metadataWindow:i64"); + SimpleExtension.WindowFunctionVariant fn = extensions.getWindowFunction(anchor); + + Optional> metadata = fn.metadata(); + assertTrue(metadata.isPresent(), "Window function metadata should be present"); + + Map meta = metadata.get(); + assertEquals("partitioned", meta.get("window_info")); + } + + @Test + void testMergePreservesMetadata() throws IOException { + // Load a second extension without metadata + String customExtensionStr = asString("extensions/custom_extensions.yaml"); + SimpleExtension.ExtensionCollection customExtension = + SimpleExtension.load("extension:test:custom_extensions", customExtensionStr); + + // Merge the two collections + SimpleExtension.ExtensionCollection merged = METADATA_EXTENSION.merge(customExtension); + + // Verify metadata from first extension is still accessible + Optional> metadata = merged.getExtensionMetadata(URN); + assertTrue(metadata.isPresent(), "Metadata should be preserved after merge"); + assertEquals("1.0", metadata.get().get("version")); + + // Verify the second extension has no metadata + Optional> customMetadata = + merged.getExtensionMetadata("extension:test:custom_extensions"); + assertFalse(customMetadata.isPresent(), "Custom extension should have no metadata"); + } +} diff --git a/core/src/test/resources/extensions/metadata_extensions.yaml b/core/src/test/resources/extensions/metadata_extensions.yaml new file mode 100644 index 000000000..596b99e45 --- /dev/null +++ b/core/src/test/resources/extensions/metadata_extensions.yaml @@ -0,0 +1,39 @@ +%YAML 1.2 +--- +urn: extension:test:metadata_extensions +metadata: + version: "1.0" + author: "test-team" + custom_data: + nested_value: true + numeric_value: 42 +types: + - name: "metadataType" + metadata: + type_info: "custom-type-metadata" + category: "user-defined" +scalar_functions: + - name: "metadataScalar" + metadata: + perf_hint: "vectorized" + cost: 1 + impls: + - args: + - value: i64 + return: i64 +aggregate_functions: + - name: "metadataAggregate" + metadata: + agg_info: "incremental" + impls: + - args: + - value: i64 + return: i64 +window_functions: + - name: "metadataWindow" + metadata: + window_info: "partitioned" + impls: + - args: + - value: i64 + return: i64 diff --git a/substrait b/substrait index 92d2e757a..8cf616e7d 160000 --- a/substrait +++ b/substrait @@ -1 +1 @@ -Subproject commit 92d2e757a330f9c973bb566817dc92afd1badcb2 +Subproject commit 8cf616e7d034282b8f5dc9cc1cc5f2410d1fac1b