diff --git a/api/src/main/java/io/grpc/Grpc.java b/api/src/main/java/io/grpc/Grpc.java index baa9f5f0ab6..cd25f79041b 100644 --- a/api/src/main/java/io/grpc/Grpc.java +++ b/api/src/main/java/io/grpc/Grpc.java @@ -101,6 +101,35 @@ public static ManagedChannelBuilder newChannelBuilder( return ManagedChannelRegistry.getDefaultRegistry().newChannelBuilder(target, creds); } + /** + * Creates a channel builder with a target string, credentials, and a specific + * name resolver registry. + * + *

This method uses the {@link ManagedChannelRegistry#getDefaultRegistry()} to + * find an appropriate underlying transport provider based on the target and credentials. + * The provided {@code nameResolverRegistry} is used to resolve the target address + * into physical addresses (e.g., DNS or custom schemes). + * + * @param target the target URI for the channel, such as {@code "localhost:8080"} + * or {@code "dns:///example.com"} + * @param creds the channel credentials to use for secure communication + * @param nameResolverRegistry the registry used to look up {@link NameResolver} + * providers for the target + * @return a {@link ManagedChannelBuilder} instance configured with the given parameters + * @throws IllegalArgumentException if no provider is available for the given target + * or credentials + * @since 1.79.0 + */ + public static ManagedChannelBuilder newChannelBuilder( + String target, + ChannelCredentials creds, + NameResolverRegistry nameResolverRegistry) { + return ManagedChannelRegistry.getDefaultRegistry().newChannelBuilder( + nameResolverRegistry, + target, + creds); + } + /** * Creates a channel builder from a host, port, and credentials. The host and port are combined to * form an authority string and then passed to {@link #newChannelBuilder(String, diff --git a/api/src/main/java/io/grpc/ManagedChannelProvider.java b/api/src/main/java/io/grpc/ManagedChannelProvider.java index 42941dfc809..18a4329a146 100644 --- a/api/src/main/java/io/grpc/ManagedChannelProvider.java +++ b/api/src/main/java/io/grpc/ManagedChannelProvider.java @@ -81,6 +81,31 @@ protected NewChannelBuilderResult newChannelBuilder(String target, ChannelCreden return NewChannelBuilderResult.error("ChannelCredentials are unsupported"); } + /** + * Creates a channel builder using the provided target, credentials, and resolution + * components. + * + *

This method allows for fine-grained control over name resolution by providing + * both a {@link NameResolverRegistry} and a specific {@link NameResolverProvider}. + * Unlike the public factory methods, this returns a {@link NewChannelBuilderResult}, + * which may contain an error string if the provided credentials or target are + * not supported by this provider. + * + * @param target the target URI for the channel + * @param creds the channel credentials to use + * @param nameResolverRegistry the registry used for looking up name resolvers + * @param nameResolverProvider a specific provider to use, or {@code null} to + * search the registry + * @return a {@link NewChannelBuilderResult} containing either the builder or an + * error description + * @since 1.79.0 + */ + protected NewChannelBuilderResult newChannelBuilder(String target, ChannelCredentials creds, + NameResolverRegistry nameResolverRegistry, + NameResolverProvider nameResolverProvider) { + return newChannelBuilder(target, creds); + } + /** * Returns the {@link SocketAddress} types this ManagedChannelProvider supports. */ diff --git a/api/src/main/java/io/grpc/ManagedChannelRegistry.java b/api/src/main/java/io/grpc/ManagedChannelRegistry.java index ec47b325ffc..c70b6812651 100644 --- a/api/src/main/java/io/grpc/ManagedChannelRegistry.java +++ b/api/src/main/java/io/grpc/ManagedChannelRegistry.java @@ -158,7 +158,6 @@ ManagedChannelBuilder newChannelBuilder(String target, ChannelCredentials cre return newChannelBuilder(NameResolverRegistry.getDefaultRegistry(), target, creds); } - @VisibleForTesting ManagedChannelBuilder newChannelBuilder(NameResolverRegistry nameResolverRegistry, String target, ChannelCredentials creds) { NameResolverProvider nameResolverProvider = null; @@ -198,7 +197,7 @@ ManagedChannelBuilder newChannelBuilder(NameResolverRegistry nameResolverRegi continue; } ManagedChannelProvider.NewChannelBuilderResult result - = provider.newChannelBuilder(target, creds); + = provider.newChannelBuilder(target, creds, nameResolverRegistry, nameResolverProvider); if (result.getChannelBuilder() != null) { return result.getChannelBuilder(); } diff --git a/api/src/main/java/io/grpc/NameResolver.java b/api/src/main/java/io/grpc/NameResolver.java index 53dbc5d6888..9810d65bab3 100644 --- a/api/src/main/java/io/grpc/NameResolver.java +++ b/api/src/main/java/io/grpc/NameResolver.java @@ -191,7 +191,7 @@ public abstract static class Factory { */ public NameResolver newNameResolver(Uri targetUri, final Args args) { // Not every io.grpc.Uri can be converted but in the ordinary ManagedChannel creation flow, - // any IllegalArgumentException thrown here would happened anyway, just earlier. That's + // any IllegalArgumentException thrown here would have happened anyway, just earlier. That's // because parse/toString is transparent so java.net.URI#create here sees the original target // string just like it did before the io.grpc.Uri migration. // diff --git a/api/src/test/java/io/grpc/ManagedChannelRegistryTest.java b/api/src/test/java/io/grpc/ManagedChannelRegistryTest.java index 2479e339791..7b02459290e 100644 --- a/api/src/test/java/io/grpc/ManagedChannelRegistryTest.java +++ b/api/src/test/java/io/grpc/ManagedChannelRegistryTest.java @@ -243,6 +243,53 @@ public NewChannelBuilderResult newChannelBuilder( mcb); } + @Test + public void newChannelBuilder_propagatesRegistry() { + final NameResolverRegistry nameResolverRegistry = new NameResolverRegistry(); + class SocketAddress1 extends SocketAddress { + } + + ManagedChannelRegistry registry = new ManagedChannelRegistry(); + class MockChannelBuilder extends ForwardingChannelBuilder2 { + @Override + public ManagedChannelBuilder delegate() { + throw new UnsupportedOperationException(); + } + } + + final ManagedChannelBuilder mcb = new MockChannelBuilder(); + registry.register(new BaseProvider(true, 4) { + @Override + protected Collection> getSupportedSocketAddressTypes() { + return Collections.singleton(SocketAddress1.class); + } + + @Override + public NewChannelBuilderResult newChannelBuilder( + String passedTarget, ChannelCredentials passedCreds, + NameResolverRegistry passedRegistry, NameResolverProvider passedProvider) { + assertThat(passedRegistry).isSameInstanceAs(nameResolverRegistry); + return NewChannelBuilderResult.channelBuilder(mcb); + } + }); + + // ManagedChannelRegistry.newChannelBuilder(NameResolverRegistry, String, ChannelCredentials) + // gets the scheme from target. Then it gets NameResolverProvider from registry for that scheme. + // Then it gets producedSocketAddressTypes from that provider. + // Then it finds a ManagedChannelProvider that supports those types. + // So we need a registered NameResolverProvider for the scheme. + nameResolverRegistry.register(new BaseNameResolverProvider(true, 5, "sc1") { + @Override + public Collection> getProducedSocketAddressTypes() { + return Collections.singleton(SocketAddress1.class); + } + }); + + assertThat( + registry.newChannelBuilder(nameResolverRegistry, "sc1:" + target, creds)).isSameInstanceAs( + mcb); + } + @Test public void newChannelBuilder_unsupportedSocketAddressTypes() { NameResolverRegistry nameResolverRegistry = new NameResolverRegistry(); diff --git a/core/src/main/java/io/grpc/internal/ManagedChannelImplBuilder.java b/core/src/main/java/io/grpc/internal/ManagedChannelImplBuilder.java index 128c929ec0e..ed043ce47f1 100644 --- a/core/src/main/java/io/grpc/internal/ManagedChannelImplBuilder.java +++ b/core/src/main/java/io/grpc/internal/ManagedChannelImplBuilder.java @@ -156,6 +156,9 @@ public static ManagedChannelBuilder forTarget(String target) { private final List interceptors = new ArrayList<>(); NameResolverRegistry nameResolverRegistry = NameResolverRegistry.getDefaultRegistry(); + @Nullable + NameResolverProvider nameResolverProvider; + final List transportFilters = new ArrayList<>(); final String target; @@ -307,6 +310,49 @@ public ManagedChannelImplBuilder( InternalConfiguratorRegistry.configureChannelBuilder(this); } + /** + * Creates a new managed channel builder with a target string, which can be + * either a valid {@link io.grpc.NameResolver}-compliant URI, or an authority + * string. Transport + * implementors must provide client transport factory builder, and may set + * custom channel default + * port provider. + * + * @param channelCreds The ChannelCredentials provided by the user. + * These may be used when + * creating derivative channels. + * @param nameResolverRegistry the registry used to look up name resolvers. + * @param nameResolverProvider the provider used to look up name resolvers. + */ + public ManagedChannelImplBuilder( + String target, @Nullable ChannelCredentials channelCreds, @Nullable CallCredentials callCreds, + ClientTransportFactoryBuilder clientTransportFactoryBuilder, + @Nullable ChannelBuilderDefaultPortProvider channelBuilderDefaultPortProvider, + @Nullable NameResolverRegistry nameResolverRegistry, + @Nullable NameResolverProvider nameResolverProvider) { + this.target = checkNotNull(target, "target"); + this.channelCredentials = channelCreds; + this.callCredentials = callCreds; + this.clientTransportFactoryBuilder = checkNotNull(clientTransportFactoryBuilder, + "clientTransportFactoryBuilder"); + this.directServerAddress = null; + + if (channelBuilderDefaultPortProvider != null) { + this.channelBuilderDefaultPortProvider = channelBuilderDefaultPortProvider; + } else { + this.channelBuilderDefaultPortProvider = new ManagedChannelDefaultPortProvider(); + } + if (nameResolverRegistry != null) { + this.nameResolverRegistry = nameResolverRegistry; + } + if (nameResolverProvider != null) { + this.nameResolverProvider = nameResolverProvider; + } + + // TODO(dnvindhya): Move configurator to all the individual builders + InternalConfiguratorRegistry.configureChannelBuilder(this); + } + /** * Returns a target string for the SocketAddress. It is only used as a placeholder, because * DirectAddressNameResolverProvider will not actually try to use it. However, it must be a valid @@ -422,6 +468,7 @@ public ManagedChannelImplBuilder nameResolverFactory(NameResolver.Factory resolv Preconditions.checkState(directServerAddress == null, "directServerAddress is set (%s), which forbids the use of NameResolverFactory", directServerAddress); + if (resolverFactory != null) { NameResolverRegistry reg = new NameResolverRegistry(); if (resolverFactory instanceof NameResolverProvider) { @@ -724,7 +771,7 @@ public ManagedChannel build() { ResolvedNameResolver resolvedResolver = InternalFeatureFlags.getRfc3986UrisEnabled() ? getNameResolverProviderRfc3986(target, nameResolverRegistry) - : getNameResolverProvider(target, nameResolverRegistry); + : getNameResolverProvider(target, nameResolverRegistry, nameResolverProvider); resolvedResolver.checkAddressTypes(clientTransportFactory.getSupportedSocketAddressTypes()); return new ManagedChannelOrphanWrapper(new ManagedChannelImpl( this, @@ -845,7 +892,8 @@ void checkAddressTypes( @VisibleForTesting static ResolvedNameResolver getNameResolverProvider( - String target, NameResolverRegistry nameResolverRegistry) { + String target, NameResolverRegistry nameResolverRegistry, + NameResolverProvider nameResolverProvider) { // Finding a NameResolver. Try using the target string as the URI. If that fails, try prepending // "dns:///". NameResolverProvider provider = null; @@ -860,19 +908,34 @@ static ResolvedNameResolver getNameResolverProvider( if (targetUri != null) { // For "localhost:8080" this would likely cause provider to be null, because "localhost" is // parsed as the scheme. Will hit the next case and try "dns:///localhost:8080". - provider = nameResolverRegistry.getProviderForScheme(targetUri.getScheme()); + provider = nameResolverProvider; + if (provider == null) { + provider = nameResolverRegistry.getProviderForScheme(targetUri.getScheme()); + } } - if (provider == null && !URI_PATTERN.matcher(target).matches()) { - // It doesn't look like a URI target. Maybe it's an authority string. Try with the default - // scheme from the registry. + if (!URI_PATTERN.matcher(target).matches()) { + // It doesn't look like a URI target. Maybe it's an authority string. Try with + // the default scheme from the registry (if provider is not specified) or + // the provider's default scheme (if provider is specified). + String scheme = (provider != null) + ? provider.getDefaultScheme() + : (nameResolverProvider != null + ? nameResolverProvider.getDefaultScheme() + : nameResolverRegistry.getDefaultScheme()); try { - targetUri = new URI(nameResolverRegistry.getDefaultScheme(), "", "/" + target, null); + targetUri = new URI(scheme, "", "/" + target, null); } catch (URISyntaxException e) { - // Should not be possible. + // Should not happen because we just validated the URI. throw new IllegalArgumentException(e); } - provider = nameResolverRegistry.getProviderForScheme(targetUri.getScheme()); + if (provider == null) { + if (nameResolverProvider != null) { + provider = nameResolverProvider; + } else { + provider = nameResolverRegistry.getProviderForScheme(targetUri.getScheme()); + } + } } if (provider == null) { diff --git a/core/src/test/java/io/grpc/internal/ManagedChannelImplBuilderTest.java b/core/src/test/java/io/grpc/internal/ManagedChannelImplBuilderTest.java index b0939239477..61c47d8b35b 100644 --- a/core/src/test/java/io/grpc/internal/ManagedChannelImplBuilderTest.java +++ b/core/src/test/java/io/grpc/internal/ManagedChannelImplBuilderTest.java @@ -48,11 +48,13 @@ import io.grpc.MethodDescriptor; import io.grpc.MetricSink; import io.grpc.NameResolver; +import io.grpc.NameResolverProvider; import io.grpc.NameResolverRegistry; import io.grpc.StaticTestingClassLoader; import io.grpc.internal.ManagedChannelImplBuilder.ChannelBuilderDefaultPortProvider; import io.grpc.internal.ManagedChannelImplBuilder.ClientTransportFactoryBuilder; import io.grpc.internal.ManagedChannelImplBuilder.FixedPortProvider; +import io.grpc.internal.ManagedChannelImplBuilder.ResolvedNameResolver; import io.grpc.internal.ManagedChannelImplBuilder.UnsupportedClientTransportFactoryBuilder; import io.grpc.testing.GrpcCleanupRule; import java.net.InetSocketAddress; @@ -385,7 +387,7 @@ public void transportDoesNotSupportAddressTypes() { builder = new ManagedChannelImplBuilder(DUMMY_AUTHORITY_VALID, mockClientTransportFactoryBuilder, new FixedPortProvider(DUMMY_PORT)); try { - ManagedChannel unused = grpcCleanupRule.register(builder.build()); + grpcCleanupRule.register(builder.build()); fail("Should fail"); } catch (IllegalArgumentException e) { assertThat(e) @@ -408,7 +410,7 @@ public void transportAddressTypeCompatibilityCheckSkipped() { builder = new ManagedChannelImplBuilder(DUMMY_AUTHORITY_VALID, mockClientTransportFactoryBuilder, new FixedPortProvider(DUMMY_PORT)); // should not fail - ManagedChannel unused = grpcCleanupRule.register(builder.build()); + grpcCleanupRule.register(builder.build()); } @Test @@ -800,4 +802,109 @@ public void uriPattern() { } private static class CustomSocketAddress extends SocketAddress {} + + @Test + public void getNameResolverProvider_explicitProviderWithIpTarget() { + String target = "127.0.0.1:8080"; + NameResolverRegistry registry = new NameResolverRegistry(); + NameResolverProvider explicitProvider = mock(NameResolverProvider.class); + when(explicitProvider.getDefaultScheme()).thenReturn("dns"); + + ManagedChannelImplBuilder.ResolvedNameResolver resolved; + resolved = ManagedChannelImplBuilder + .getNameResolverProvider(target, registry, explicitProvider); + + assertThat(resolved.provider).isSameInstanceAs(explicitProvider); + assertThat(resolved.targetUri.toString()).isEqualTo("dns:///127.0.0.1:8080"); + } + + @Test + public void getNameResolverProvider_explicitProviderWithInvalidUri() { + String target = "::1"; + NameResolverRegistry registry = new NameResolverRegistry(); + NameResolverProvider explicitProvider = mock(NameResolverProvider.class); + when(explicitProvider.getDefaultScheme()).thenReturn("dns"); + + ManagedChannelImplBuilder.ResolvedNameResolver resolved; + resolved = ManagedChannelImplBuilder + .getNameResolverProvider(target, registry, explicitProvider); + + assertThat(resolved.provider).isSameInstanceAs(explicitProvider); + assertThat(resolved.targetUri.toString()).isEqualTo("dns:///::1"); + } + + @Test + public void getNameResolverProvider_explicitProviderWithValidUri() { + String target = "dns:///localhost"; + NameResolverRegistry registry = new NameResolverRegistry(); + NameResolverProvider explicitProvider = new NameResolverProvider() { + @Override + public NameResolver newNameResolver(URI targetUri, NameResolver.Args args) { + return null; + } + + @Override + public String getDefaultScheme() { + return "dns"; + } + + @Override + protected boolean isAvailable() { + return true; + } + + @Override + protected int priority() { + return 5; + } + }; + + ResolvedNameResolver resolved = ManagedChannelImplBuilder.getNameResolverProvider( + target, registry, explicitProvider); + + // Should prefer explicit provider if scheme matches? + // Current logic: provider passed to getNameResolverProvider is prioritized for + // SCHEME determination for fallback + // BUT for valid URI, it logic matches URI scheme. + // If explicit provider is passed, it is used if target not valid URI. + // If target IS valid URI, it checks if provider != null. + // Wait, the code: + // if (provider == null) { ... } + // If explicit 'provider' arg is NOT null, logic uses it? + // Let's re-read ManagedChannelImplBuilder.getNameResolverProvider + assertThat(resolved.provider).isSameInstanceAs(explicitProvider); + } + + @Test + public void getNameResolverProvider_registryFallback() { + String target = "dns:///localhost"; + final NameResolverProvider registryProvider = new NameResolverProvider() { + @Override + public NameResolver newNameResolver(URI targetUri, NameResolver.Args args) { + return null; + } + + @Override + public String getDefaultScheme() { + return "dns"; + } + + @Override + protected boolean isAvailable() { + return true; + } + + @Override + protected int priority() { + return 5; + } + }; + NameResolverRegistry registry = new NameResolverRegistry(); + registry.register(registryProvider); + + ResolvedNameResolver resolved = ManagedChannelImplBuilder.getNameResolverProvider( + target, registry, null); + + assertThat(resolved.provider).isSameInstanceAs(registryProvider); + } } diff --git a/core/src/test/java/io/grpc/internal/ManagedChannelImplGetNameResolverTest.java b/core/src/test/java/io/grpc/internal/ManagedChannelImplGetNameResolverTest.java index 792f4daca4e..4bfc66beb12 100644 --- a/core/src/test/java/io/grpc/internal/ManagedChannelImplGetNameResolverTest.java +++ b/core/src/test/java/io/grpc/internal/ManagedChannelImplGetNameResolverTest.java @@ -118,7 +118,8 @@ public void validTargetNoProvider() { NameResolverRegistry nameResolverRegistry = new NameResolverRegistry(); try { ManagedChannelImplBuilder.getNameResolverProvider( - "foo.googleapis.com:8080", nameResolverRegistry); + "foo.googleapis.com:8080", nameResolverRegistry, + null); fail("Should fail"); } catch (IllegalArgumentException e) { // expected @@ -130,7 +131,7 @@ public void validTargetProviderAddrTypesNotSupported() { NameResolverRegistry nameResolverRegistry = getTestRegistry("testscheme"); try { ManagedChannelImplBuilder.getNameResolverProvider( - "testscheme:///foo.googleapis.com:8080", nameResolverRegistry) + "testscheme:///foo.googleapis.com:8080", nameResolverRegistry, null) .checkAddressTypes(Collections.singleton(CustomSocketAddress.class)); fail("Should fail"); } catch (IllegalArgumentException e) { @@ -143,7 +144,7 @@ public void validTargetProviderAddrTypesNotSupported() { private void testValidTarget(String target, String expectedUriString, URI expectedUri) { NameResolverRegistry nameResolverRegistry = getTestRegistry(expectedUri.getScheme()); ManagedChannelImplBuilder.ResolvedNameResolver resolved = - ManagedChannelImplBuilder.getNameResolverProvider(target, nameResolverRegistry); + ManagedChannelImplBuilder.getNameResolverProvider(target, nameResolverRegistry, null); assertThat(resolved.provider).isInstanceOf(FakeNameResolverProvider.class); assertThat(resolved.targetUri).isEqualTo(wrap(expectedUri)); assertThat(resolved.targetUri.toString()).isEqualTo(expectedUriString); @@ -154,7 +155,7 @@ private void testInvalidTarget(String target) { try { ManagedChannelImplBuilder.ResolvedNameResolver resolved = - ManagedChannelImplBuilder.getNameResolverProvider(target, nameResolverRegistry); + ManagedChannelImplBuilder.getNameResolverProvider(target, nameResolverRegistry, null); FakeNameResolverProvider nameResolverProvider = (FakeNameResolverProvider) resolved.provider; fail("Should have failed, but got resolver provider " + nameResolverProvider); } catch (IllegalArgumentException e) { diff --git a/netty/src/main/java/io/grpc/netty/NettyChannelBuilder.java b/netty/src/main/java/io/grpc/netty/NettyChannelBuilder.java index 258aa15b005..9aacaa6059c 100644 --- a/netty/src/main/java/io/grpc/netty/NettyChannelBuilder.java +++ b/netty/src/main/java/io/grpc/netty/NettyChannelBuilder.java @@ -38,6 +38,8 @@ import io.grpc.HttpConnectProxiedSocketAddress; import io.grpc.Internal; import io.grpc.ManagedChannelBuilder; +import io.grpc.NameResolverProvider; +import io.grpc.NameResolverRegistry; import io.grpc.internal.AtomicBackoff; import io.grpc.internal.ClientTransportFactory; import io.grpc.internal.ConnectionClientTransport; @@ -207,10 +209,20 @@ public int getDefaultPort() { NettyChannelBuilder( String target, ChannelCredentials channelCreds, CallCredentials callCreds, ProtocolNegotiator.ClientFactory negotiator) { + this(target, channelCreds, callCreds, negotiator, null, null); + } + + NettyChannelBuilder( + String target, ChannelCredentials channelCreds, CallCredentials callCreds, + ProtocolNegotiator.ClientFactory negotiator, + NameResolverRegistry nameResolverRegistry, + NameResolverProvider nameResolverProvider) { managedChannelImplBuilder = new ManagedChannelImplBuilder( target, channelCreds, callCreds, new NettyChannelTransportFactoryBuilder(), - new NettyChannelDefaultPortProvider()); + new NettyChannelDefaultPortProvider(), + nameResolverRegistry, + nameResolverProvider); this.protocolNegotiatorFactory = checkNotNull(negotiator, "negotiator"); this.freezeProtocolNegotiatorFactory = true; } @@ -708,6 +720,8 @@ NettyChannelBuilder setTransportTracerFactory(TransportTracer.Factory transportT return this; } + + static Collection> getSupportedSocketAddressTypes() { return Collections.singleton(InetSocketAddress.class); } diff --git a/netty/src/main/java/io/grpc/netty/NettyChannelProvider.java b/netty/src/main/java/io/grpc/netty/NettyChannelProvider.java index 1b22a95a44b..7b2b747a1c4 100644 --- a/netty/src/main/java/io/grpc/netty/NettyChannelProvider.java +++ b/netty/src/main/java/io/grpc/netty/NettyChannelProvider.java @@ -19,6 +19,8 @@ import io.grpc.ChannelCredentials; import io.grpc.Internal; import io.grpc.ManagedChannelProvider; +import io.grpc.NameResolverProvider; +import io.grpc.NameResolverRegistry; import java.net.SocketAddress; import java.util.Collection; @@ -55,6 +57,19 @@ public NewChannelBuilderResult newChannelBuilder(String target, ChannelCredentia new NettyChannelBuilder(target, creds, result.callCredentials, result.negotiator)); } + @Override + public NewChannelBuilderResult newChannelBuilder(String target, ChannelCredentials creds, + NameResolverRegistry nameResolverRegistry, + NameResolverProvider nameResolverProvider) { + ProtocolNegotiators.FromChannelCredentialsResult result = ProtocolNegotiators.from(creds); + if (result.error != null) { + return NewChannelBuilderResult.error(result.error); + } + NettyChannelBuilder builder = new NettyChannelBuilder(target, creds, + result.callCredentials, result.negotiator, nameResolverRegistry, nameResolverProvider); + return NewChannelBuilderResult.channelBuilder(builder); + } + @Override protected Collection> getSupportedSocketAddressTypes() { return NettyChannelBuilder.getSupportedSocketAddressTypes(); diff --git a/netty/src/main/java/io/grpc/netty/UdsNettyChannelProvider.java b/netty/src/main/java/io/grpc/netty/UdsNettyChannelProvider.java index 4e9895da0a8..23a09889de0 100644 --- a/netty/src/main/java/io/grpc/netty/UdsNettyChannelProvider.java +++ b/netty/src/main/java/io/grpc/netty/UdsNettyChannelProvider.java @@ -20,6 +20,8 @@ import io.grpc.ChannelCredentials; import io.grpc.Internal; import io.grpc.ManagedChannelProvider; +import io.grpc.NameResolverProvider; +import io.grpc.NameResolverRegistry; import io.grpc.internal.SharedResourcePool; import io.netty.channel.unix.DomainSocketAddress; import java.net.SocketAddress; @@ -62,6 +64,21 @@ public NewChannelBuilderResult newChannelBuilder(String target, ChannelCredentia return result; } + @Override + public NewChannelBuilderResult newChannelBuilder(String target, ChannelCredentials creds, + NameResolverRegistry nameResolverRegistry, + NameResolverProvider nameResolverProvider) { + Preconditions.checkState(isAvailable()); + NewChannelBuilderResult result = new NettyChannelProvider().newChannelBuilder( + target, creds, nameResolverRegistry, nameResolverProvider); + if (result.getChannelBuilder() != null) { + ((NettyChannelBuilder) result.getChannelBuilder()) + .eventLoopGroupPool(SharedResourcePool.forResource(Utils.DEFAULT_WORKER_EVENT_LOOP_GROUP)) + .channelType(Utils.EPOLL_DOMAIN_CLIENT_CHANNEL_TYPE, DomainSocketAddress.class); + } + return result; + } + @Override protected Collection> getSupportedSocketAddressTypes() { return Collections.singleton(DomainSocketAddress.class); diff --git a/netty/src/test/java/io/grpc/netty/NettyChannelProviderTest.java b/netty/src/test/java/io/grpc/netty/NettyChannelProviderTest.java index 86c1389f002..fe5a45848e1 100644 --- a/netty/src/test/java/io/grpc/netty/NettyChannelProviderTest.java +++ b/netty/src/test/java/io/grpc/netty/NettyChannelProviderTest.java @@ -87,4 +87,164 @@ public void newChannelBuilder_fail() { TlsChannelCredentials.newBuilder().requireFakeFeature().build()); assertThat(result.getError()).contains("FAKE"); } + + @Test + public void newChannelBuilder_withRegistry() { + io.grpc.NameResolverRegistry registry = new io.grpc.NameResolverRegistry(); + NewChannelBuilderResult result = provider.newChannelBuilder( + "localhost:443", TlsChannelCredentials.create(), registry, null); + assertThat(result.getChannelBuilder()).isInstanceOf(NettyChannelBuilder.class); + } + + @Test + public void newChannelBuilder_withProvider() { + io.grpc.NameResolverProvider resolverProvider = new io.grpc.NameResolverProvider() { + @Override + protected boolean isAvailable() { + return true; + } + + @Override + protected int priority() { + return 5; + } + + @Override + public String getDefaultScheme() { + return "dns"; + } + + @Override + public io.grpc.NameResolver newNameResolver(java.net.URI targetUri, + io.grpc.NameResolver.Args args) { + return null; + } + }; + NewChannelBuilderResult result = provider.newChannelBuilder( + "localhost:443", TlsChannelCredentials.create(), null, + resolverProvider); + assertThat(result.getChannelBuilder()).isInstanceOf(NettyChannelBuilder.class); + } + + @Test + public void newChannelBuilder_registryPropagation_e2e() { + String scheme = "testscheme"; + final io.grpc.NameResolverRegistry registry = new io.grpc.NameResolverRegistry(); + final java.util.concurrent.atomic.AtomicReference + capturedRegistry = new java.util.concurrent.atomic.AtomicReference<>(); + + final io.grpc.NameResolverProvider resolverProvider = new io.grpc.NameResolverProvider() { + @Override + protected boolean isAvailable() { + return true; + } + + @Override + protected int priority() { + return 5; + } + + @Override + public String getDefaultScheme() { + return scheme; + } + + @Override + public io.grpc.NameResolver newNameResolver(java.net.URI targetUri, + io.grpc.NameResolver.Args args) { + capturedRegistry.set(args.getNameResolverRegistry()); + return new io.grpc.NameResolver() { + @Override + public String getServiceAuthority() { + return "authority"; + } + + @Override + public void start(Listener2 listener) { + } + + @Override + public void shutdown() { + } + }; + } + }; + registry.register(resolverProvider); + + NewChannelBuilderResult result = provider.newChannelBuilder( + scheme + ":///target", TlsChannelCredentials.create(), registry, + null); + assertThat(result.getChannelBuilder()).isInstanceOf(NettyChannelBuilder.class); + // Verify build() succeeds + result.getChannelBuilder().build(); + + // Verify the registry passed to args is the exact same instance + assertSame("Registry should be propagated to NameResolver.Args", registry, + capturedRegistry.get()); + + // Verify default registry (empty) fails + NewChannelBuilderResult defaultResult = provider.newChannelBuilder( + scheme + ":///target", TlsChannelCredentials.create(), + new io.grpc.NameResolverRegistry(), null); + // The provider might still return a builder, but build() should fail if it + // can't find the resolver. + // However, NettyChannelProvider just delegates to NettyChannelBuilder. + // NettyChannelBuilder delegates to ManagedChannelImplBuilder. + // ManagedChannelImplBuilder.build() calls getNameResolverProvider(), which + // throws if not found. + try { + defaultResult.getChannelBuilder().build(); + fail("Should have failed to build() without correct registry"); + } catch (IllegalArgumentException e) { + // Expected + } + } + + @Test + public void newChannelBuilder_providerPropagation_e2e() { + String scheme = "otherscheme"; + final io.grpc.NameResolverProvider resolverProvider = new io.grpc.NameResolverProvider() { + @Override + protected boolean isAvailable() { + return true; + } + + @Override + protected int priority() { + return 5; + } + + @Override + public String getDefaultScheme() { + return scheme; + } + + @Override + public io.grpc.NameResolver newNameResolver(java.net.URI targetUri, + io.grpc.NameResolver.Args args) { + return new io.grpc.NameResolver() { + @Override + public String getServiceAuthority() { + return "authority"; + } + + @Override + public void start(Listener2 listener) { + } + + @Override + public void shutdown() { + } + }; + } + }; + + // Pass explicit provider, null registry + NewChannelBuilderResult result = provider.newChannelBuilder( + scheme + ":///target", TlsChannelCredentials.create(), + null, resolverProvider); + assertThat(result.getChannelBuilder()).isInstanceOf(NettyChannelBuilder.class); + // Should succeed because we passed the specific provider + result.getChannelBuilder().build(); + } } diff --git a/netty/src/test/java/io/grpc/netty/UdsNettyChannelProviderTest.java b/netty/src/test/java/io/grpc/netty/UdsNettyChannelProviderTest.java index e0c3d5a8525..ad26caf1713 100644 --- a/netty/src/test/java/io/grpc/netty/UdsNettyChannelProviderTest.java +++ b/netty/src/test/java/io/grpc/netty/UdsNettyChannelProviderTest.java @@ -108,6 +108,16 @@ public void newChannelBuilder_success() { assertThat(result.getChannelBuilder()).isInstanceOf(NettyChannelBuilder.class); } + @Test + public void newChannelBuilder_withRegistry_success() { + Assume.assumeTrue(Utils.isEpollAvailable()); + NewChannelBuilderResult result = provider.newChannelBuilder("unix:sock.sock", + TlsChannelCredentials.create(), + io.grpc.NameResolverRegistry.getDefaultRegistry(), + new io.grpc.internal.DnsNameResolverProvider()); + assertThat(result.getChannelBuilder()).isInstanceOf(NettyChannelBuilder.class); + } + @Test public void managedChannelRegistry_newChannelBuilder() { Assume.assumeTrue(Utils.isEpollAvailable()); diff --git a/okhttp/src/main/java/io/grpc/okhttp/OkHttpChannelBuilder.java b/okhttp/src/main/java/io/grpc/okhttp/OkHttpChannelBuilder.java index 98f764132fe..3e1d7a774c1 100644 --- a/okhttp/src/main/java/io/grpc/okhttp/OkHttpChannelBuilder.java +++ b/okhttp/src/main/java/io/grpc/okhttp/OkHttpChannelBuilder.java @@ -35,6 +35,8 @@ import io.grpc.InsecureChannelCredentials; import io.grpc.Internal; import io.grpc.ManagedChannelBuilder; +import io.grpc.NameResolverProvider; +import io.grpc.NameResolverRegistry; import io.grpc.TlsChannelCredentials; import io.grpc.internal.AtomicBackoff; import io.grpc.internal.ClientTransportFactory; @@ -214,10 +216,20 @@ private OkHttpChannelBuilder(String target) { OkHttpChannelBuilder( String target, ChannelCredentials channelCreds, CallCredentials callCreds, SSLSocketFactory factory) { + this(target, channelCreds, callCreds, factory, null, null); + } + + OkHttpChannelBuilder( + String target, ChannelCredentials channelCreds, CallCredentials callCreds, + SSLSocketFactory factory, + NameResolverRegistry nameResolverRegistry, + NameResolverProvider nameResolverProvider) { managedChannelImplBuilder = new ManagedChannelImplBuilder( target, channelCreds, callCreds, new OkHttpChannelTransportFactoryBuilder(), - new OkHttpChannelDefaultPortProvider()); + new OkHttpChannelDefaultPortProvider(), + nameResolverRegistry, + nameResolverProvider); this.sslSocketFactory = factory; this.negotiationType = factory == null ? NegotiationType.PLAINTEXT : NegotiationType.TLS; this.freezeSecurityConfiguration = true; @@ -588,6 +600,8 @@ SSLSocketFactory createSslSocketFactory() { } } + + private static final EnumSet understoodTlsFeatures = EnumSet.of( TlsChannelCredentials.Feature.MTLS, TlsChannelCredentials.Feature.CUSTOM_MANAGERS); diff --git a/okhttp/src/main/java/io/grpc/okhttp/OkHttpChannelProvider.java b/okhttp/src/main/java/io/grpc/okhttp/OkHttpChannelProvider.java index bf2a9be6fee..f4485d237a3 100644 --- a/okhttp/src/main/java/io/grpc/okhttp/OkHttpChannelProvider.java +++ b/okhttp/src/main/java/io/grpc/okhttp/OkHttpChannelProvider.java @@ -20,6 +20,8 @@ import io.grpc.Internal; import io.grpc.InternalServiceProviders; import io.grpc.ManagedChannelProvider; +import io.grpc.NameResolverProvider; +import io.grpc.NameResolverRegistry; import java.net.SocketAddress; import java.util.Collection; @@ -60,6 +62,21 @@ public NewChannelBuilderResult newChannelBuilder(String target, ChannelCredentia target, creds, result.callCredentials, result.factory)); } + @Override + public NewChannelBuilderResult newChannelBuilder(String target, ChannelCredentials creds, + NameResolverRegistry nameResolverRegistry, + NameResolverProvider nameResolverProvider) { + OkHttpChannelBuilder.SslSocketFactoryResult result = + OkHttpChannelBuilder.sslSocketFactoryFrom(creds); + if (result.error != null) { + return NewChannelBuilderResult.error(result.error); + } + OkHttpChannelBuilder builder = new OkHttpChannelBuilder( + target, creds, result.callCredentials, result.factory, + nameResolverRegistry, nameResolverProvider); + return NewChannelBuilderResult.channelBuilder(builder); + } + @Override protected Collection> getSupportedSocketAddressTypes() { return OkHttpChannelBuilder.getSupportedSocketAddressTypes(); diff --git a/okhttp/src/test/java/io/grpc/okhttp/OkHttpChannelProviderTest.java b/okhttp/src/test/java/io/grpc/okhttp/OkHttpChannelProviderTest.java index 363f11e71eb..40cf0a41449 100644 --- a/okhttp/src/test/java/io/grpc/okhttp/OkHttpChannelProviderTest.java +++ b/okhttp/src/test/java/io/grpc/okhttp/OkHttpChannelProviderTest.java @@ -85,4 +85,13 @@ public void newChannelBuilder_fail() { TlsChannelCredentials.newBuilder().requireFakeFeature().build()); assertThat(result.getError()).contains("FAKE"); } + + @Test + public void newChannelBuilder_withRegistry_success() { + NewChannelBuilderResult result = provider.newChannelBuilder("localhost:443", + TlsChannelCredentials.create(), + io.grpc.NameResolverRegistry.getDefaultRegistry(), + new io.grpc.internal.DnsNameResolverProvider()); + assertThat(result.getChannelBuilder()).isInstanceOf(OkHttpChannelBuilder.class); + } }