Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 51 additions & 6 deletions core/src/main/java/io/substrait/extension/SimpleExtension.java
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@
import java.io.IOException;
import java.io.InputStream;
import java.io.UncheckedIOException;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
Expand Down Expand Up @@ -276,6 +278,8 @@ public String description() {

public abstract Map<String, Option> options();

public abstract Optional<Map<String, Object>> metadata();

public List<Argument> requiredArguments() {
return requiredArgsSupplier.get();
}
Expand Down Expand Up @@ -381,25 +385,29 @@ public abstract static class ScalarFunction {
@Nullable
public abstract String description();

public abstract Optional<Map<String, Object>> metadata();

public abstract List<ScalarFunctionVariant> impls();

public Stream<ScalarFunctionVariant> resolve(String urn) {
return impls().stream().map(f -> f.resolve(urn, name(), description()));
return impls().stream().map(f -> f.resolve(urn, name(), description(), metadata()));
}
}

@JsonDeserialize(as = ImmutableSimpleExtension.ScalarFunctionVariant.class)
@JsonSerialize(as = ImmutableSimpleExtension.ScalarFunctionVariant.class)
@Value.Immutable
public abstract static class ScalarFunctionVariant extends Function {
public ScalarFunctionVariant resolve(String urn, String name, String description) {
public ScalarFunctionVariant resolve(
String urn, String name, String description, Optional<Map<String, Object>> metadata) {
return ImmutableSimpleExtension.ScalarFunctionVariant.builder()
.urn(urn)
.name(name)
.description(description)
.nullability(nullability())
.args(args())
.options(options())
.metadata(metadata)
.ordered(ordered())
.variadic(variadic())
.returnType(returnType())
Expand All @@ -417,10 +425,12 @@ public abstract static class AggregateFunction {
@Nullable
public abstract String description();

public abstract Optional<Map<String, Object>> metadata();

public abstract List<AggregateFunctionVariant> impls();

public Stream<AggregateFunctionVariant> resolve(String urn) {
return impls().stream().map(f -> f.resolve(urn, name(), description()));
return impls().stream().map(f -> f.resolve(urn, name(), description(), metadata()));
}
}

Expand All @@ -434,10 +444,12 @@ public abstract static class WindowFunction {
@Nullable
public abstract String description();

public abstract Optional<Map<String, Object>> metadata();

public abstract List<WindowFunctionVariant> impls();

public Stream<WindowFunctionVariant> resolve(String urn) {
return impls().stream().map(f -> f.resolve(urn, name(), description()));
return impls().stream().map(f -> f.resolve(urn, name(), description(), metadata()));
}

public static ImmutableSimpleExtension.WindowFunction.Builder builder() {
Expand All @@ -463,14 +475,16 @@ public String toString() {
@Nullable
public abstract TypeExpression intermediate();

AggregateFunctionVariant resolve(String urn, String name, String description) {
AggregateFunctionVariant resolve(
String urn, String name, String description, Optional<Map<String, Object>> metadata) {
return ImmutableSimpleExtension.AggregateFunctionVariant.builder()
.urn(urn)
.name(name)
.description(description)
.nullability(nullability())
.args(args())
.options(options())
.metadata(metadata)
.ordered(ordered())
.variadic(variadic())
.decomposability(decomposability())
Expand Down Expand Up @@ -505,14 +519,16 @@ public String toString() {
return super.toString();
}

WindowFunctionVariant resolve(String urn, String name, String description) {
WindowFunctionVariant resolve(
String urn, String name, String description, Optional<Map<String, Object>> metadata) {
return ImmutableSimpleExtension.WindowFunctionVariant.builder()
.urn(urn)
.name(name)
.description(description)
.nullability(nullability())
.args(args())
.options(options())
.metadata(metadata)
.ordered(ordered())
.variadic(variadic())
.decomposability(decomposability())
Expand Down Expand Up @@ -549,6 +565,8 @@ public abstract static class Type {

protected abstract Optional<Boolean> variadic();

public abstract Optional<Map<String, Object>> metadata();

public TypeAnchor getAnchor() {
return anchorSupplier.get();
}
Expand All @@ -574,6 +592,9 @@ public abstract static class ExtensionSignatures {
@JsonProperty("window_functions")
public abstract List<WindowFunction> windows();

@JsonProperty("metadata")
public abstract Optional<Map<String, Object>> metadata();

public int size() {
return (types() == null ? 0 : types().size())
+ (scalars() == null ? 0 : scalars().size())
Expand Down Expand Up @@ -643,6 +664,11 @@ BidiMap<String, String> uriUrnMap() {
return new BidiMap<>();
}

@Value.Default
public Map<String, Map<String, Object>> extensionMetadata() {
return Collections.emptyMap();
}

public abstract List<Type> types();

public abstract List<ScalarFunctionVariant> scalarFunctions();
Expand All @@ -655,6 +681,16 @@ public static ImmutableSimpleExtension.ExtensionCollection.Builder builder() {
return ImmutableSimpleExtension.ExtensionCollection.builder();
}

/**
* Gets the top-level metadata for a specific extension by URN.
*
* @param urn The URN of the extension
* @return The metadata map if present, empty Optional otherwise
*/
public Optional<Map<String, Object>> getExtensionMetadata(String urn) {
return Optional.ofNullable(extensionMetadata().get(urn));
}

public Type getType(TypeAnchor anchor) {
Type type = typeLookup.get().get(anchor);
if (type != null) {
Expand Down Expand Up @@ -744,6 +780,10 @@ public ExtensionCollection merge(ExtensionCollection extensionCollection) {
mergedUriUrnMap.merge(uriUrnMap());
mergedUriUrnMap.merge(extensionCollection.uriUrnMap());

Map<String, Map<String, Object>> mergedExtensionMetadata = new HashMap<>();
mergedExtensionMetadata.putAll(extensionMetadata());
mergedExtensionMetadata.putAll(extensionCollection.extensionMetadata());

return ImmutableSimpleExtension.ExtensionCollection.builder()
.addAllAggregateFunctions(aggregateFunctions())
.addAllAggregateFunctions(extensionCollection.aggregateFunctions())
Expand All @@ -754,6 +794,7 @@ public ExtensionCollection merge(ExtensionCollection extensionCollection) {
.addAllTypes(types())
.addAllTypes(extensionCollection.types())
.uriUrnMap(mergedUriUrnMap)
.extensionMetadata(mergedExtensionMetadata)
.build();
}
}
Expand Down Expand Up @@ -859,13 +900,17 @@ public static ExtensionCollection buildExtensionCollection(
BidiMap<String, String> uriUrnMap = new BidiMap<>();
uriUrnMap.put(uri, urn);

Map<String, Map<String, Object>> extMetadata = new HashMap<>();
extensionSignatures.metadata().ifPresent(m -> extMetadata.put(urn, m));

ImmutableSimpleExtension.ExtensionCollection collection =
ImmutableSimpleExtension.ExtensionCollection.builder()
.scalarFunctions(scalarFunctionVariants)
.aggregateFunctions(aggregateFunctionVariants)
.windowFunctions(allWindowFunctionVariants)
.addAllTypes(extensionSignatures.types())
.uriUrnMap(uriUrnMap)
.extensionMetadata(extMetadata)
.build();

LOGGER.atDebug().log(
Expand Down
137 changes: 137 additions & 0 deletions core/src/test/java/io/substrait/extension/MetadataExtensionTest.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,137 @@
package io.substrait.extension;

import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertTrue;

import io.substrait.TestBase;
import java.io.IOException;
import java.io.UncheckedIOException;
import java.util.Map;
import java.util.Optional;
import org.junit.jupiter.api.Test;

/**
* Verifies that metadata can be read from extension YAML files at multiple levels:
*
* <ul>
* <li>Extension-level metadata (top-level)
* <li>Type-level metadata
* <li>Function-level metadata (scalar, aggregate, window)
* </ul>
*/
class MetadataExtensionTest extends TestBase {

static final String URN = "extension:test:metadata_extensions";
static final SimpleExtension.ExtensionCollection METADATA_EXTENSION;

static {
try {
String extensionStr = asString("extensions/metadata_extensions.yaml");
METADATA_EXTENSION = SimpleExtension.load(URN, extensionStr);
} catch (IOException e) {
throw new UncheckedIOException(e);
}
}

MetadataExtensionTest() {
super(METADATA_EXTENSION);
}

@Test
void testExtensionLevelMetadata() {
Optional<Map<String, Object>> metadata = extensions.getExtensionMetadata(URN);
assertTrue(metadata.isPresent(), "Extension metadata should be present");

Map<String, Object> meta = metadata.get();
assertEquals("1.0", meta.get("version"));
assertEquals("test-team", meta.get("author"));

// Test nested metadata
@SuppressWarnings("unchecked")
Map<String, Object> customData = (Map<String, Object>) meta.get("custom_data");
assertEquals(true, customData.get("nested_value"));
assertEquals(42, customData.get("numeric_value"));
}

@Test
void testExtensionLevelMetadataMissing() {
Optional<Map<String, Object>> metadata =
extensions.getExtensionMetadata("extension:nonexistent:urn");
assertFalse(metadata.isPresent(), "Metadata for non-existent URN should be empty");
}

@Test
void testTypeMetadata() {
SimpleExtension.TypeAnchor anchor = SimpleExtension.TypeAnchor.of(URN, "metadataType");
SimpleExtension.Type type = extensions.getType(anchor);

Optional<Map<String, Object>> metadata = type.metadata();
assertTrue(metadata.isPresent(), "Type metadata should be present");

Map<String, Object> meta = metadata.get();
assertEquals("custom-type-metadata", meta.get("type_info"));
assertEquals("user-defined", meta.get("category"));
}

@Test
void testScalarFunctionMetadata() {
SimpleExtension.FunctionAnchor anchor =
SimpleExtension.FunctionAnchor.of(URN, "metadataScalar:i64");
SimpleExtension.ScalarFunctionVariant fn = extensions.getScalarFunction(anchor);

Optional<Map<String, Object>> metadata = fn.metadata();
assertTrue(metadata.isPresent(), "Scalar function metadata should be present");

Map<String, Object> meta = metadata.get();
assertEquals("vectorized", meta.get("perf_hint"));
assertEquals(1, meta.get("cost"));
}

@Test
void testAggregateFunctionMetadata() {
SimpleExtension.FunctionAnchor anchor =
SimpleExtension.FunctionAnchor.of(URN, "metadataAggregate:i64");
SimpleExtension.AggregateFunctionVariant fn = extensions.getAggregateFunction(anchor);

Optional<Map<String, Object>> metadata = fn.metadata();
assertTrue(metadata.isPresent(), "Aggregate function metadata should be present");

Map<String, Object> meta = metadata.get();
assertEquals("incremental", meta.get("agg_info"));
}

@Test
void testWindowFunctionMetadata() {
SimpleExtension.FunctionAnchor anchor =
SimpleExtension.FunctionAnchor.of(URN, "metadataWindow:i64");
SimpleExtension.WindowFunctionVariant fn = extensions.getWindowFunction(anchor);

Optional<Map<String, Object>> metadata = fn.metadata();
assertTrue(metadata.isPresent(), "Window function metadata should be present");

Map<String, Object> meta = metadata.get();
assertEquals("partitioned", meta.get("window_info"));
}

@Test
void testMergePreservesMetadata() throws IOException {
// Load a second extension without metadata
String customExtensionStr = asString("extensions/custom_extensions.yaml");
SimpleExtension.ExtensionCollection customExtension =
SimpleExtension.load("extension:test:custom_extensions", customExtensionStr);

// Merge the two collections
SimpleExtension.ExtensionCollection merged = METADATA_EXTENSION.merge(customExtension);

// Verify metadata from first extension is still accessible
Optional<Map<String, Object>> metadata = merged.getExtensionMetadata(URN);
assertTrue(metadata.isPresent(), "Metadata should be preserved after merge");
assertEquals("1.0", metadata.get().get("version"));

// Verify the second extension has no metadata
Optional<Map<String, Object>> customMetadata =
merged.getExtensionMetadata("extension:test:custom_extensions");
assertFalse(customMetadata.isPresent(), "Custom extension should have no metadata");
}
}
39 changes: 39 additions & 0 deletions core/src/test/resources/extensions/metadata_extensions.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
%YAML 1.2
---
urn: extension:test:metadata_extensions
metadata:
version: "1.0"
author: "test-team"
custom_data:
nested_value: true
numeric_value: 42
types:
- name: "metadataType"
metadata:
type_info: "custom-type-metadata"
category: "user-defined"
scalar_functions:
- name: "metadataScalar"
metadata:
perf_hint: "vectorized"
cost: 1
impls:
- args:
- value: i64
return: i64
aggregate_functions:
- name: "metadataAggregate"
metadata:
agg_info: "incremental"
impls:
- args:
- value: i64
return: i64
window_functions:
- name: "metadataWindow"
metadata:
window_info: "partitioned"
impls:
- args:
- value: i64
return: i64