diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 181e4d7ca1..a4c4005be0 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -497,7 +497,7 @@ jobs: run: | cd integration_tests/idl_tests/swift/idl_package swift test --disable-automatic-resolution --skip-build - - name: Install Java artifacts for IDL tests + - name: Install Java artifacts for gRPC tests run: | cd java mvn -T16 --no-transfer-progress clean install -DskipTests -Dmaven.javadoc.skip=true -Dmaven.source.skip=true @@ -684,6 +684,35 @@ jobs: - name: Run CI run: python ./ci/run_ci.py java --version integration_tests + grpc_tests: + name: Java/Python gRPC Tests + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v5 + - name: Set up JDK 21 + uses: actions/setup-java@v4 + with: + java-version: 21 + distribution: "temurin" + - name: Set up Python 3.11 + uses: actions/setup-python@v5 + with: + python-version: 3.11 + cache: "pip" + - name: Cache Maven local repository + uses: actions/cache@v4 + with: + path: ~/.m2/repository + key: ${{ runner.os }}-maven-${{ hashFiles('**/pom.xml') }} + restore-keys: | + ${{ runner.os }}-maven- + - name: Install Java artifacts for gRPC tests + run: | + cd java + mvn -T16 --no-transfer-progress clean install -DskipTests -Dmaven.javadoc.skip=true -Dmaven.source.skip=true + - name: Run Java/Python gRPC Tests + run: ./integration_tests/grpc_tests/run_tests.sh + javascript: name: JavaScript CI needs: changes diff --git a/compiler/fory_compiler/cli.py b/compiler/fory_compiler/cli.py index 5eb4885ef0..c1b0f889ac 100644 --- a/compiler/fory_compiler/cli.py +++ b/compiler/fory_compiler/cli.py @@ -122,6 +122,7 @@ def resolve_imports( # Parse the file schema = parse_idl_file(file_path) + annotate_source_package(schema) # Process imports imported_enums = [] @@ -167,6 +168,7 @@ def resolve_imports( enums=imported_enums + schema.enums, messages=imported_messages + schema.messages, unions=imported_unions + schema.unions, + services=schema.services, options=schema.options, source_file=schema.source_file, source_format=schema.source_format, @@ -177,6 +179,13 @@ def resolve_imports( return merged_schema +def annotate_source_package(schema: Schema) -> None: + """Record each parsed top-level type's declaring package for qualified lookups.""" + for type_def in schema.enums + schema.messages + schema.unions: + if type_def.source_package is None: + type_def.source_package = schema.package + + def go_package_info(schema: Schema) -> Tuple[Optional[str], str]: """Return (import_path, package_name) for Go.""" go_package = schema.get_option("go_package") diff --git a/compiler/fory_compiler/frontend/proto/translator.py b/compiler/fory_compiler/frontend/proto/translator.py index f4be61bc2b..7ae7256f72 100644 --- a/compiler/fory_compiler/frontend/proto/translator.py +++ b/compiler/fory_compiler/frontend/proto/translator.py @@ -420,13 +420,11 @@ def _translate_rpc_method(self, proto_method: ProtoRpcMethod) -> RpcMethod: _, options = self._translate_type_options(proto_method.options) return RpcMethod( name=proto_method.name, - request_type=NamedType( - name=proto_method.request_type, - location=self._location(proto_method.line, proto_method.column), + request_type=self._translate_rpc_type( + proto_method.request_type, proto_method ), - response_type=NamedType( - name=proto_method.response_type, - location=self._location(proto_method.line, proto_method.column), + response_type=self._translate_rpc_type( + proto_method.response_type, proto_method ), client_streaming=proto_method.client_streaming, server_streaming=proto_method.server_streaming, @@ -435,3 +433,11 @@ def _translate_rpc_method(self, proto_method: ProtoRpcMethod) -> RpcMethod: column=proto_method.column, location=self._location(proto_method.line, proto_method.column), ) + + def _translate_rpc_type( + self, type_name: str, proto_method: ProtoRpcMethod + ) -> NamedType: + return NamedType( + name=type_name, + location=self._location(proto_method.line, proto_method.column), + ) diff --git a/compiler/fory_compiler/generators/java.py b/compiler/fory_compiler/generators/java.py index 2cc8cabfbe..9b9b7906d8 100644 --- a/compiler/fory_compiler/generators/java.py +++ b/compiler/fory_compiler/generators/java.py @@ -21,6 +21,7 @@ from typing import Dict, List, Optional, Set, Tuple, Union as TypingUnion from fory_compiler.generators.base import BaseGenerator, GeneratedFile +from fory_compiler.generators.services.java import JavaServiceGeneratorMixin from fory_compiler.ir.ast import ( Message, Enum, @@ -37,11 +38,67 @@ from fory_compiler.ir.types import PrimitiveKind -class JavaGenerator(BaseGenerator): +class JavaGenerator(JavaServiceGeneratorMixin, BaseGenerator): """Generates Java POJOs with Fory annotations.""" language_name = "java" file_extension = ".java" + JAVA_RESERVED_IDENTIFIERS = { + "abstract", + "assert", + "boolean", + "break", + "byte", + "case", + "catch", + "char", + "class", + "const", + "continue", + "default", + "do", + "double", + "else", + "enum", + "extends", + "false", + "final", + "finally", + "float", + "for", + "goto", + "if", + "implements", + "import", + "instanceof", + "int", + "interface", + "long", + "native", + "new", + "null", + "package", + "private", + "protected", + "public", + "return", + "short", + "static", + "strictfp", + "super", + "switch", + "synchronized", + "this", + "throw", + "throws", + "transient", + "true", + "try", + "void", + "volatile", + "while", + "_", + } def __init__(self, schema: Schema, options): super().__init__(schema, options) diff --git a/compiler/fory_compiler/generators/python.py b/compiler/fory_compiler/generators/python.py index 0226a677ae..a50a7c5301 100644 --- a/compiler/fory_compiler/generators/python.py +++ b/compiler/fory_compiler/generators/python.py @@ -22,6 +22,7 @@ from typing import Dict, List, Optional, Set, Tuple from fory_compiler.generators.base import BaseGenerator, GeneratedFile +from fory_compiler.generators.services.python import PythonServiceGeneratorMixin from fory_compiler.frontend.utils import parse_idl_file from fory_compiler.ir.ast import ( Message, @@ -39,7 +40,7 @@ from fory_compiler.ir.types import PrimitiveKind -class PythonGenerator(BaseGenerator): +class PythonGenerator(PythonServiceGeneratorMixin, BaseGenerator): """Generates Python dataclasses with pyfory type hints.""" language_name = "python" diff --git a/compiler/fory_compiler/generators/services/__init__.py b/compiler/fory_compiler/generators/services/__init__.py new file mode 100644 index 0000000000..4f1e22dbcb --- /dev/null +++ b/compiler/fory_compiler/generators/services/__init__.py @@ -0,0 +1,18 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +"""Service generator helpers.""" diff --git a/compiler/fory_compiler/generators/services/java.py b/compiler/fory_compiler/generators/services/java.py new file mode 100644 index 0000000000..dc68cbe296 --- /dev/null +++ b/compiler/fory_compiler/generators/services/java.py @@ -0,0 +1,787 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +"""Java gRPC service generator helpers.""" + +from typing import List, Optional, Set + +from fory_compiler.generators.base import GeneratedFile +from fory_compiler.ir.ast import NamedType, RpcMethod, Schema, Service + + +class JavaServiceGeneratorMixin: + """Generates Java gRPC service companions.""" + + def generate_services(self) -> List[GeneratedFile]: + """Generate Java gRPC service classes for local service definitions.""" + local_services = [ + service + for service in self.schema.services + if not self.is_imported_type(service) + ] + if not local_services: + return [] + + outer_classname = self.get_java_outer_classname() + if self.get_java_multiple_files(): + outer_classname = None + self.check_java_grpc_service_collisions(local_services, outer_classname) + self.check_java_grpc_method_collisions(local_services) + return [ + self.generate_grpc_service_file(service, outer_classname) + for service in local_services + ] + + def check_java_grpc_service_collisions( + self, services: List[Service], outer_classname: Optional[str] + ) -> None: + generated_type_names = self.same_package_imported_java_file_names() + if outer_classname: + generated_type_names.add(outer_classname) + else: + generated_type_names.update( + enum.name + for enum in self.schema.enums + if not self.is_imported_type(enum) + ) + generated_type_names.update( + union.name + for union in self.schema.unions + if not self.is_imported_type(union) + ) + generated_type_names.update( + message.name + for message in self.schema.messages + if not self.is_imported_type(message) + ) + for service in services: + class_name = f"{service.name}Grpc" + if class_name in generated_type_names: + raise ValueError( + f"Java gRPC service class {class_name} conflicts with a generated type; " + "rename the service or type" + ) + + def same_package_imported_java_file_names(self) -> Set[str]: + java_package = self.get_java_package() + generated_names: Set[str] = set() + for _source, schema in self._schema_graph()[1:]: + if self._java_package_for_schema(schema) != java_package: + continue + generated_names.update(self.java_type_file_names(schema)) + generated_names.update(f"{service.name}Grpc" for service in schema.services) + return generated_names + + def java_type_file_names(self, schema: Schema) -> Set[str]: + outer_classname = schema.get_option("java_outer_classname") + multiple_files = schema.get_option("java_multiple_files") is True + if outer_classname and not multiple_files: + return {outer_classname} + return { + type_def.name + for type_def in [*schema.enums, *schema.unions, *schema.messages] + } + + def check_java_grpc_method_collisions(self, services: List[Service]) -> None: + for service in services: + seen_descriptors = {} + seen_methods = {} + seen_ids = {} + for method in service.methods: + descriptor = f"get{self.to_pascal_case(method.name)}Method" + java_method = self.java_grpc_method_name(method) + method_id = f"METHODID_{self.to_upper_snake_case(method.name)}" + for seen, key, label in ( + (seen_descriptors, descriptor, "method descriptor"), + (seen_methods, java_method, "Java method"), + (seen_ids, method_id, "method id"), + ): + if key in seen: + raise ValueError( + f"Java gRPC {label} name collision in service {service.name}: " + f"{seen[key]} and {method.name} both generate {key}" + ) + seen[key] = method.name + + def generate_grpc_service_file( + self, service: Service, outer_classname: Optional[str] + ) -> GeneratedFile: + """Generate one grpc-java-style service class.""" + java_package = self.get_java_package() + class_name = f"{service.name}Grpc" + lines: List[str] = [] + + lines.append(self.get_license_header()) + lines.append("") + if java_package: + lines.append(f"package {java_package};") + lines.append("") + + imports = [ + "java.io.ByteArrayInputStream", + "java.io.ByteArrayOutputStream", + "java.io.IOException", + "java.io.InputStream", + "java.util.Iterator", + "org.apache.fory.ThreadSafeFory", + ] + for imp in imports: + lines.append(f"import {imp};") + lines.append("") + + lines.append(f"public final class {class_name} {{") + lines.append(f" private {class_name}() {{") + lines.append(" }") + lines.append("") + lines.append( + f' public static final String SERVICE_NAME = "{self.get_grpc_service_name(service)}";' + ) + lines.append("") + lines.append( + f" private static final ThreadSafeFory FORY = {self.get_module_class_name()}.getFory();" + ) + lines.append("") + + for method in service.methods: + lines.extend( + self.generate_java_grpc_method_descriptor( + method, class_name, outer_classname + ) + ) + + lines.extend(self.generate_java_grpc_stub_factories(service)) + lines.extend(self.generate_java_grpc_service_base(service, outer_classname)) + lines.extend(self.generate_java_grpc_async_stub(service, outer_classname)) + lines.extend(self.generate_java_grpc_blocking_stub(service, outer_classname)) + lines.extend(self.generate_java_grpc_future_stub(service, outer_classname)) + lines.extend(self.generate_java_grpc_bind_service(service, outer_classname)) + lines.extend(self.generate_java_grpc_method_handlers(service, outer_classname)) + lines.extend(self.generate_java_grpc_marshaller()) + + lines.append("}") + lines.append("") + + path = self.get_java_package_path() + if path: + path = f"{path}/{class_name}.java" + else: + path = f"{class_name}.java" + return GeneratedFile(path=path, content="\n".join(lines)) + + def generate_java_grpc_method_descriptor( + self, method: RpcMethod, class_name: str, outer_classname: Optional[str] + ) -> List[str]: + request_type = self.generate_java_grpc_type( + method.request_type, outer_classname + ) + response_type = self.generate_java_grpc_type( + method.response_type, outer_classname + ) + method_suffix = self.to_pascal_case(method.name) + method_field = f"get{method_suffix}Method" + method_type = self.grpc_java_method_type(method) + lines = [] + lines.append( + f" private static volatile io.grpc.MethodDescriptor<{request_type}, {response_type}> {method_field};" + ) + lines.append("") + lines.append( + f" public static io.grpc.MethodDescriptor<{request_type}, {response_type}> {method_field}() {{" + ) + lines.append( + f" io.grpc.MethodDescriptor<{request_type}, {response_type}> local = {method_field};" + ) + lines.append(" if (local == null) {") + lines.append(f" synchronized ({class_name}.class) {{") + lines.append(f" local = {method_field};") + lines.append(" if (local == null) {") + lines.append( + " local = io.grpc.MethodDescriptor.<" + f"{request_type}, {response_type}>newBuilder()" + ) + lines.append( + f" .setType(io.grpc.MethodDescriptor.MethodType.{method_type})" + ) + lines.append( + " .setFullMethodName(io.grpc.MethodDescriptor.generateFullMethodName(" + ) + lines.append(f' SERVICE_NAME, "{method.name}"))') + lines.append(" .setSampledToLocalTracing(true)") + lines.append( + f" .setRequestMarshaller(marshaller({request_type}.class))" + ) + lines.append( + f" .setResponseMarshaller(marshaller({response_type}.class))" + ) + lines.append(" .build();") + lines.append(f" {method_field} = local;") + lines.append(" }") + lines.append(" }") + lines.append(" }") + lines.append(" return local;") + lines.append(" }") + lines.append("") + return lines + + def generate_java_grpc_stub_factories(self, service: Service) -> List[str]: + lines = [] + lines.append( + f" public static {service.name}Stub newStub(io.grpc.Channel channel) {{" + ) + lines.append( + f" io.grpc.stub.AbstractStub.StubFactory<{service.name}Stub> factory =" + ) + lines.append( + f" new io.grpc.stub.AbstractStub.StubFactory<{service.name}Stub>() {{" + ) + lines.append(" @Override") + lines.append( + f" public {service.name}Stub newStub(io.grpc.Channel channel, io.grpc.CallOptions callOptions) {{" + ) + lines.append( + f" return new {service.name}Stub(channel, callOptions);" + ) + lines.append(" }") + lines.append(" };") + lines.append(f" return {service.name}Stub.newStub(factory, channel);") + lines.append(" }") + lines.append("") + lines.append( + f" public static {service.name}BlockingStub newBlockingStub(io.grpc.Channel channel) {{" + ) + lines.append( + f" io.grpc.stub.AbstractStub.StubFactory<{service.name}BlockingStub> factory =" + ) + lines.append( + f" new io.grpc.stub.AbstractStub.StubFactory<{service.name}BlockingStub>() {{" + ) + lines.append(" @Override") + lines.append( + f" public {service.name}BlockingStub newStub(io.grpc.Channel channel, io.grpc.CallOptions callOptions) {{" + ) + lines.append( + f" return new {service.name}BlockingStub(channel, callOptions);" + ) + lines.append(" }") + lines.append(" };") + lines.append( + f" return {service.name}BlockingStub.newStub(factory, channel);" + ) + lines.append(" }") + lines.append("") + lines.append( + f" public static {service.name}FutureStub newFutureStub(io.grpc.Channel channel) {{" + ) + lines.append( + f" io.grpc.stub.AbstractStub.StubFactory<{service.name}FutureStub> factory =" + ) + lines.append( + f" new io.grpc.stub.AbstractStub.StubFactory<{service.name}FutureStub>() {{" + ) + lines.append(" @Override") + lines.append( + f" public {service.name}FutureStub newStub(io.grpc.Channel channel, io.grpc.CallOptions callOptions) {{" + ) + lines.append( + f" return new {service.name}FutureStub(channel, callOptions);" + ) + lines.append(" }") + lines.append(" };") + lines.append( + f" return {service.name}FutureStub.newStub(factory, channel);" + ) + lines.append(" }") + lines.append("") + return lines + + def generate_java_grpc_service_base( + self, service: Service, outer_classname: Optional[str] + ) -> List[str]: + lines = [] + lines.append( + f" public abstract static class {service.name}ImplBase implements io.grpc.BindableService {{" + ) + for method in service.methods: + java_name = self.java_grpc_method_name(method) + request_type = self.generate_java_grpc_type( + method.request_type, outer_classname + ) + response_type = self.generate_java_grpc_type( + method.response_type, outer_classname + ) + method_getter = f"get{self.to_pascal_case(method.name)}Method()" + lines.append("") + if method.client_streaming: + lines.append( + f" public io.grpc.stub.StreamObserver<{request_type}> {java_name}(" + ) + lines.append( + f" io.grpc.stub.StreamObserver<{response_type}> responseObserver) {{" + ) + lines.append( + f" return io.grpc.stub.ServerCalls.asyncUnimplementedStreamingCall({method_getter}, responseObserver);" + ) + lines.append(" }") + else: + lines.append(f" public void {java_name}({request_type} request,") + lines.append( + f" io.grpc.stub.StreamObserver<{response_type}> responseObserver) {{" + ) + lines.append( + f" io.grpc.stub.ServerCalls.asyncUnimplementedUnaryCall({method_getter}, responseObserver);" + ) + lines.append(" }") + lines.append("") + lines.append(" @Override") + lines.append( + " public final io.grpc.ServerServiceDefinition bindService() {" + ) + lines.append(f" return {service.name}Grpc.bindService(this);") + lines.append(" }") + lines.append(" }") + lines.append("") + return lines + + def generate_java_grpc_async_stub( + self, service: Service, outer_classname: Optional[str] + ) -> List[str]: + lines = [] + lines.append( + f" public static final class {service.name}Stub extends io.grpc.stub.AbstractAsyncStub<{service.name}Stub> {{" + ) + lines.append( + f" private {service.name}Stub(io.grpc.Channel channel, io.grpc.CallOptions callOptions) {{" + ) + lines.append(" super(channel, callOptions);") + lines.append(" }") + lines.append("") + lines.append(" @Override") + lines.append( + f" protected {service.name}Stub build(io.grpc.Channel channel, io.grpc.CallOptions callOptions) {{" + ) + lines.append( + f" return new {service.name}Stub(channel, callOptions);" + ) + lines.append(" }") + for method in service.methods: + lines.extend( + self.generate_java_grpc_client_method(method, outer_classname, "async") + ) + lines.append(" }") + lines.append("") + return lines + + def generate_java_grpc_blocking_stub( + self, service: Service, outer_classname: Optional[str] + ) -> List[str]: + lines = [] + lines.append( + f" public static final class {service.name}BlockingStub extends io.grpc.stub.AbstractBlockingStub<{service.name}BlockingStub> {{" + ) + lines.append( + f" private {service.name}BlockingStub(io.grpc.Channel channel, io.grpc.CallOptions callOptions) {{" + ) + lines.append(" super(channel, callOptions);") + lines.append(" }") + lines.append("") + lines.append(" @Override") + lines.append( + f" protected {service.name}BlockingStub build(io.grpc.Channel channel, io.grpc.CallOptions callOptions) {{" + ) + lines.append( + f" return new {service.name}BlockingStub(channel, callOptions);" + ) + lines.append(" }") + for method in service.methods: + if method.client_streaming: + continue + lines.extend( + self.generate_java_grpc_client_method( + method, outer_classname, "blocking" + ) + ) + lines.append(" }") + lines.append("") + return lines + + def generate_java_grpc_future_stub( + self, service: Service, outer_classname: Optional[str] + ) -> List[str]: + lines = [] + lines.append( + f" public static final class {service.name}FutureStub extends io.grpc.stub.AbstractFutureStub<{service.name}FutureStub> {{" + ) + lines.append( + f" private {service.name}FutureStub(io.grpc.Channel channel, io.grpc.CallOptions callOptions) {{" + ) + lines.append(" super(channel, callOptions);") + lines.append(" }") + lines.append("") + lines.append(" @Override") + lines.append( + f" protected {service.name}FutureStub build(io.grpc.Channel channel, io.grpc.CallOptions callOptions) {{" + ) + lines.append( + f" return new {service.name}FutureStub(channel, callOptions);" + ) + lines.append(" }") + for method in service.methods: + if method.client_streaming or method.server_streaming: + continue + lines.extend( + self.generate_java_grpc_client_method(method, outer_classname, "future") + ) + lines.append(" }") + lines.append("") + return lines + + def generate_java_grpc_client_method( + self, method: RpcMethod, outer_classname: Optional[str], stub_kind: str + ) -> List[str]: + java_name = self.java_grpc_method_name(method) + request_type = self.generate_java_grpc_type( + method.request_type, outer_classname + ) + response_type = self.generate_java_grpc_type( + method.response_type, outer_classname + ) + method_getter = f"get{self.to_pascal_case(method.name)}Method()" + lines = [ + "", + ] + if stub_kind == "async": + if method.client_streaming: + lines.append( + f" public io.grpc.stub.StreamObserver<{request_type}> {java_name}(" + ) + lines.append( + f" io.grpc.stub.StreamObserver<{response_type}> responseObserver) {{" + ) + call = ( + "asyncBidiStreamingCall" + if method.server_streaming + else "asyncClientStreamingCall" + ) + lines.append( + f" return io.grpc.stub.ClientCalls.{call}(" + f"getChannel().newCall({method_getter}, getCallOptions()), responseObserver);" + ) + lines.append(" }") + else: + lines.append(f" public void {java_name}({request_type} request,") + lines.append( + f" io.grpc.stub.StreamObserver<{response_type}> responseObserver) {{" + ) + call = ( + "asyncServerStreamingCall" + if method.server_streaming + else "asyncUnaryCall" + ) + lines.append( + f" io.grpc.stub.ClientCalls.{call}(" + f"getChannel().newCall({method_getter}, getCallOptions()), request, responseObserver);" + ) + lines.append(" }") + elif stub_kind == "blocking": + if method.server_streaming: + lines.append( + f" public Iterator<{response_type}> {java_name}({request_type} request) {{" + ) + lines.append( + f" return io.grpc.stub.ClientCalls.blockingServerStreamingCall(" + f"getChannel(), {method_getter}, getCallOptions(), request);" + ) + else: + lines.append( + f" public {response_type} {java_name}({request_type} request) {{" + ) + lines.append( + f" return io.grpc.stub.ClientCalls.blockingUnaryCall(" + f"getChannel(), {method_getter}, getCallOptions(), request);" + ) + lines.append(" }") + elif stub_kind == "future": + lines.append( + f" public com.google.common.util.concurrent.ListenableFuture<{response_type}> {java_name}(" + ) + lines.append(f" {request_type} request) {{") + lines.append( + f" return io.grpc.stub.ClientCalls.futureUnaryCall(" + f"getChannel().newCall({method_getter}, getCallOptions()), request);" + ) + lines.append(" }") + return lines + + def generate_java_grpc_bind_service( + self, service: Service, outer_classname: Optional[str] + ) -> List[str]: + lines = [] + lines.append( + f" private static io.grpc.ServerServiceDefinition bindService({service.name}ImplBase serviceImpl) {{" + ) + lines.append( + " io.grpc.ServerServiceDefinition.Builder builder =" + " io.grpc.ServerServiceDefinition.builder(SERVICE_NAME);" + ) + for method in service.methods: + request_type = self.generate_java_grpc_type( + method.request_type, outer_classname + ) + response_type = self.generate_java_grpc_type( + method.response_type, outer_classname + ) + method_getter = f"get{self.to_pascal_case(method.name)}Method()" + method_id = f"METHODID_{self.to_upper_snake_case(method.name)}" + if method.client_streaming and method.server_streaming: + call = "asyncBidiStreamingCall" + elif method.client_streaming: + call = "asyncClientStreamingCall" + elif method.server_streaming: + call = "asyncServerStreamingCall" + else: + call = "asyncUnaryCall" + lines.append( + f" builder.addMethod({method_getter}, io.grpc.stub.ServerCalls.{call}(" + ) + lines.append( + f" new MethodHandlers<{request_type}, {response_type}>(serviceImpl, {method_id})));" + ) + lines.append(" return builder.build();") + lines.append(" }") + lines.append("") + return lines + + def generate_java_grpc_method_handlers( + self, service: Service, outer_classname: Optional[str] + ) -> List[str]: + lines = [] + for index, method in enumerate(service.methods): + lines.append( + f" private static final int METHODID_{self.to_upper_snake_case(method.name)} = {index};" + ) + lines.append("") + lines.append(' @SuppressWarnings("unchecked")') + lines.append(" private static final class MethodHandlers") + lines.append( + " implements io.grpc.stub.ServerCalls.UnaryMethod," + ) + lines.append( + " io.grpc.stub.ServerCalls.ServerStreamingMethod," + ) + lines.append( + " io.grpc.stub.ServerCalls.ClientStreamingMethod," + ) + lines.append( + " io.grpc.stub.ServerCalls.BidiStreamingMethod {" + ) + lines.append(f" private final {service.name}ImplBase serviceImpl;") + lines.append(" private final int methodId;") + lines.append("") + lines.append( + f" MethodHandlers({service.name}ImplBase serviceImpl, int methodId) {{" + ) + lines.append(" this.serviceImpl = serviceImpl;") + lines.append(" this.methodId = methodId;") + lines.append(" }") + lines.append("") + lines.append(" @Override") + lines.append( + " public void invoke(Req request, io.grpc.stub.StreamObserver responseObserver) {" + ) + lines.append(" switch (methodId) {") + for method in service.methods: + if method.client_streaming: + continue + java_name = self.java_grpc_method_name(method) + request_type = self.generate_java_grpc_type( + method.request_type, outer_classname + ) + response_type = self.generate_java_grpc_type( + method.response_type, outer_classname + ) + method_id = f"METHODID_{self.to_upper_snake_case(method.name)}" + lines.append(f" case {method_id}:") + lines.append( + f" serviceImpl.{java_name}(({request_type}) request," + ) + lines.append( + f" (io.grpc.stub.StreamObserver<{response_type}>) responseObserver);" + ) + lines.append(" break;") + lines.append(" default:") + lines.append(" throw new AssertionError();") + lines.append(" }") + lines.append(" }") + lines.append("") + lines.append(" @Override") + lines.append(" public io.grpc.stub.StreamObserver invoke(") + lines.append( + " io.grpc.stub.StreamObserver responseObserver) {" + ) + lines.append(" switch (methodId) {") + for method in service.methods: + if not method.client_streaming: + continue + java_name = self.java_grpc_method_name(method) + request_type = self.generate_java_grpc_type( + method.request_type, outer_classname + ) + response_type = self.generate_java_grpc_type( + method.response_type, outer_classname + ) + method_id = f"METHODID_{self.to_upper_snake_case(method.name)}" + lines.append(f" case {method_id}:") + lines.append( + f" return (io.grpc.stub.StreamObserver) serviceImpl.{java_name}(" + ) + lines.append( + f" (io.grpc.stub.StreamObserver<{response_type}>) responseObserver);" + ) + lines.append(" default:") + lines.append(" throw new AssertionError();") + lines.append(" }") + lines.append(" }") + lines.append(" }") + lines.append("") + return lines + + def generate_java_grpc_marshaller(self) -> List[str]: + lines = [] + lines.append( + " private static io.grpc.MethodDescriptor.Marshaller marshaller(Class type) {" + ) + lines.append(" return new ForyMarshaller(FORY, type);") + lines.append(" }") + lines.append("") + lines.append( + " private static final class ForyMarshaller implements io.grpc.MethodDescriptor.Marshaller {" + ) + lines.append(" private final ThreadSafeFory fory;") + lines.append(" private final Class type;") + lines.append("") + lines.append(" ForyMarshaller(ThreadSafeFory fory, Class type) {") + lines.append(" this.fory = fory;") + lines.append(" this.type = type;") + lines.append(" }") + lines.append("") + lines.append(" @Override") + lines.append(" public InputStream stream(T value) {") + lines.append(" try {") + lines.append( + " return new KnownLengthByteArrayInputStream(fory.serialize(value));" + ) + lines.append(" } catch (RuntimeException e) {") + lines.append( + ' throw io.grpc.Status.INTERNAL.withDescription("Fory serialization failed")' + ) + lines.append(" .withCause(e).asRuntimeException();") + lines.append(" }") + lines.append(" }") + lines.append("") + lines.append(" @Override") + lines.append(" public T parse(InputStream stream) {") + lines.append(" try {") + lines.append( + " return fory.deserialize(readBytes(stream), type);" + ) + lines.append(" } catch (IOException | RuntimeException e) {") + lines.append( + ' throw io.grpc.Status.INTERNAL.withDescription("Fory deserialization failed")' + ) + lines.append(" .withCause(e).asRuntimeException();") + lines.append(" }") + lines.append(" }") + lines.append(" }") + lines.append("") + lines.append( + " private static final class KnownLengthByteArrayInputStream extends ByteArrayInputStream" + ) + lines.append(" implements io.grpc.KnownLength {") + lines.append(" KnownLengthByteArrayInputStream(byte[] bytes) {") + lines.append(" super(bytes);") + lines.append(" }") + lines.append(" }") + lines.append("") + lines.append( + " private static byte[] readBytes(InputStream stream) throws IOException {" + ) + lines.append(" if (stream instanceof io.grpc.KnownLength) {") + lines.append(" int size = stream.available();") + lines.append(" byte[] bytes = new byte[size];") + lines.append(" int offset = 0;") + lines.append(" while (offset < size) {") + lines.append( + " int read = stream.read(bytes, offset, size - offset);" + ) + lines.append(" if (read == -1) {") + lines.append( + ' throw new java.io.EOFException("Expected " + size + " bytes, got " + offset);' + ) + lines.append(" }") + lines.append(" offset += read;") + lines.append(" }") + lines.append(" return bytes;") + lines.append(" }") + lines.append(" ByteArrayOutputStream out = new ByteArrayOutputStream();") + lines.append(" byte[] buffer = new byte[8192];") + lines.append(" int read;") + lines.append(" while ((read = stream.read(buffer)) != -1) {") + lines.append(" out.write(buffer, 0, read);") + lines.append(" }") + lines.append(" return out.toByteArray();") + lines.append(" }") + lines.append("") + return lines + + def grpc_java_method_type(self, method: RpcMethod) -> str: + if method.client_streaming and method.server_streaming: + return "BIDI_STREAMING" + if method.client_streaming: + return "CLIENT_STREAMING" + if method.server_streaming: + return "SERVER_STREAMING" + return "UNARY" + + def safe_java_identifier(self, identifier: str) -> str: + if identifier in self.JAVA_RESERVED_IDENTIFIERS: + return f"{identifier}_" + return identifier + + def java_grpc_method_name(self, method: RpcMethod) -> str: + return self.safe_java_identifier(self.to_camel_case(method.name)) + + def generate_java_grpc_type( + self, named_type: NamedType, outer_classname: Optional[str] + ) -> str: + type_ref = self.schema.resolve_type_name(named_type.name) + type_def = self.schema.get_type(type_ref) + if type_def is not None and self.is_imported_type(type_def): + schema = self._load_schema( + getattr(getattr(type_def, "location", None), "file", None) + ) + if schema is not None: + imported_outer = schema.get_option("java_outer_classname") + multiple_files = schema.get_option("java_multiple_files") is True + if imported_outer and not multiple_files: + type_ref = f"{imported_outer}.{type_ref}" + java_package = self._java_package_for_type(type_def) + if java_package: + return f"{java_package}.{type_ref}" + return type_ref + if outer_classname: + return f"{outer_classname}.{type_ref}" + return type_ref diff --git a/compiler/fory_compiler/generators/services/python.py b/compiler/fory_compiler/generators/services/python.py new file mode 100644 index 0000000000..feb1a259b1 --- /dev/null +++ b/compiler/fory_compiler/generators/services/python.py @@ -0,0 +1,196 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +"""Python gRPC service generator helpers.""" + +from typing import List + +from fory_compiler.generators.base import GeneratedFile +from fory_compiler.ir.ast import RpcMethod, Service + + +class PythonServiceGeneratorMixin: + """Generates Python gRPC service companions.""" + + def generate_services(self) -> List[GeneratedFile]: + """Generate Python grpc service companion module.""" + local_services = [ + service + for service in self.schema.services + if not self.is_imported_type(service) + ] + if not local_services: + return [] + self.check_python_grpc_service_collisions(local_services) + self.check_python_grpc_method_collisions(local_services) + return [self.generate_grpc_module(local_services)] + + def check_python_grpc_service_collisions(self, services: List[Service]) -> None: + seen_registrations = {} + for service in services: + add_fn = self.python_grpc_add_servicer_name(service) + if add_fn in seen_registrations: + raise ValueError( + f"Python gRPC service registration collision: " + f"{seen_registrations[add_fn]} and {service.name} both generate {add_fn}" + ) + seen_registrations[add_fn] = service.name + + def check_python_grpc_method_collisions(self, services: List[Service]) -> None: + for service in services: + seen_methods = {} + for method in service.methods: + python_method = self.python_grpc_method_name(method) + if python_method in seen_methods: + raise ValueError( + f"Python gRPC method name collision in service {service.name}: " + f"{seen_methods[python_method]} and {method.name} both generate {python_method}" + ) + seen_methods[python_method] = method.name + + def generate_grpc_module(self, services: List[Service]) -> GeneratedFile: + """Generate a grpcio-style companion module for schema services.""" + module_name = self.get_module_name() + lines = [] + lines.append(self.get_license_header("#")) + lines.append("") + lines.append("from __future__ import annotations") + lines.append("") + lines.append("import grpc") + lines.append(f"import {module_name} as _models") + lines.append("") + lines.append("") + lines.append("def _serialize(value):") + lines.append(" return _models._get_fory().serialize(value)") + lines.append("") + lines.append("") + lines.append("def _deserialize(data: bytes):") + lines.append(" return _models._get_fory().deserialize(data)") + lines.append("") + lines.append("") + + for service in services: + lines.extend(self.generate_python_grpc_stub(service)) + lines.append("") + lines.extend(self.generate_python_grpc_servicer(service)) + lines.append("") + lines.extend(self.generate_python_grpc_registration(service)) + lines.append("") + lines.extend(self.generate_python_grpc_add_servicer(services)) + lines.append("") + + return GeneratedFile( + path=f"{module_name}_grpc.py", + content="\n".join(lines), + ) + + def generate_python_grpc_stub(self, service: Service) -> List[str]: + lines = [] + lines.append(f"class {service.name}Stub(object):") + lines.append(f' """Client stub for {service.name}."""') + lines.append("") + lines.append(" def __init__(self, channel):") + for method in service.methods: + channel_call = self.python_grpc_channel_call(method) + python_name = self.python_grpc_method_name(method) + lines.append(f" self.{python_name} = channel.{channel_call}(") + lines.append(f' "{self.get_grpc_method_path(service, method)}",') + lines.append(" request_serializer=_serialize,") + lines.append(" response_deserializer=_deserialize,") + lines.append(" )") + if not service.methods: + lines.append(" pass") + lines.append("") + return lines + + def generate_python_grpc_servicer(self, service: Service) -> List[str]: + lines = [] + lines.append(f"class {service.name}Servicer(object):") + lines.append(f' """Base servicer for {service.name}."""') + if not service.methods: + lines.append(" pass") + return lines + for method in service.methods: + python_name = self.python_grpc_method_name(method) + lines.append("") + if method.client_streaming: + lines.append(f" def {python_name}(self, request_iterator, context):") + else: + lines.append(f" def {python_name}(self, request, context):") + lines.append(" context.set_code(grpc.StatusCode.UNIMPLEMENTED)") + lines.append(' context.set_details("Method not implemented!")') + lines.append(' raise NotImplementedError("Method not implemented!")') + return lines + + def generate_python_grpc_registration(self, service: Service) -> List[str]: + lines = [] + add_fn = self.python_grpc_add_servicer_name(service) + lines.append(f"def {add_fn}(servicer, server):") + lines.append(" rpc_method_handlers = {") + for method in service.methods: + handler = self.python_grpc_handler(method) + python_name = self.python_grpc_method_name(method) + lines.append(f' "{method.name}": grpc.{handler}(') + lines.append(f" servicer.{python_name},") + lines.append(" request_deserializer=_deserialize,") + lines.append(" response_serializer=_serialize,") + lines.append(" ),") + lines.append(" }") + lines.append( + f' generic_handler = grpc.method_handlers_generic_handler("{self.get_grpc_service_name(service)}", rpc_method_handlers)' + ) + lines.append(" server.add_generic_rpc_handlers((generic_handler,))") + return lines + + def generate_python_grpc_add_servicer(self, services: List[Service]) -> List[str]: + lines = [] + lines.append("def add_servicer(servicer, server):") + lines.append(" registered = False") + for service in services: + add_fn = self.python_grpc_add_servicer_name(service) + lines.append(f" if isinstance(servicer, {service.name}Servicer):") + lines.append(f" {add_fn}(servicer, server)") + lines.append(" registered = True") + lines.append(" if not registered:") + lines.append( + ' raise TypeError(f"Unsupported gRPC servicer type: {type(servicer).__name__}")' + ) + return lines + + def python_grpc_channel_call(self, method: RpcMethod) -> str: + if method.client_streaming and method.server_streaming: + return "stream_stream" + if method.client_streaming: + return "stream_unary" + if method.server_streaming: + return "unary_stream" + return "unary_unary" + + def python_grpc_handler(self, method: RpcMethod) -> str: + if method.client_streaming and method.server_streaming: + return "stream_stream_rpc_method_handler" + if method.client_streaming: + return "stream_unary_rpc_method_handler" + if method.server_streaming: + return "unary_stream_rpc_method_handler" + return "unary_unary_rpc_method_handler" + + def python_grpc_method_name(self, method: RpcMethod) -> str: + return self.safe_name(self.to_snake_case(method.name)) + + def python_grpc_add_servicer_name(self, service: Service) -> str: + return f"_add_{self.safe_name(self.to_snake_case(service.name))}_servicer" diff --git a/compiler/fory_compiler/ir/ast.py b/compiler/fory_compiler/ir/ast.py index e57b03f612..f2718e84fa 100644 --- a/compiler/fory_compiler/ir/ast.py +++ b/compiler/fory_compiler/ir/ast.py @@ -196,6 +196,7 @@ class Message: location: Optional[SourceLocation] = None id_generated: bool = False id_source: Optional[str] = None + source_package: Optional[str] = None def __repr__(self) -> str: id_str = f" [id={self.type_id}]" if self.type_id is not None else "" @@ -236,6 +237,7 @@ class Enum: location: Optional[SourceLocation] = None id_generated: bool = False id_source: Optional[str] = None + source_package: Optional[str] = None def __repr__(self) -> str: id_str = f" [id={self.type_id}]" if self.type_id is not None else "" @@ -256,6 +258,7 @@ class Union: location: Optional[SourceLocation] = None id_generated: bool = False id_source: Optional[str] = None + source_package: Optional[str] = None def __repr__(self) -> str: id_str = f" [id={self.type_id}]" if self.type_id is not None else "" @@ -337,6 +340,60 @@ def get_option(self, name: str, default: Optional[str] = None) -> Optional[str]: def get_type(self, name: str) -> Optional[TypingUnion[Message, Enum, "Union"]]: """Look up a type by name, supporting qualified names like Parent.Child.""" + return self._get_type_by_path(self.resolve_type_name(name)) + + def resolve_type_name(self, name: str) -> str: + """Resolve package-qualified type names to the schema-local type path.""" + absolute = name.startswith(".") + cleaned = name.lstrip(".") + if absolute: + package_resolved = self._resolve_package_qualified_type(cleaned) + if package_resolved is not None: + return package_resolved + return name + if self._get_type_by_path(cleaned) is not None: + return cleaned + package_resolved = self._resolve_package_qualified_type(cleaned) + if package_resolved is not None: + return package_resolved + return cleaned + + def _resolve_package_qualified_type(self, name: str) -> Optional[str]: + if "." not in name: + return None + packages = { + type_def.source_package + for type_def in self.get_all_types() + if type_def.source_package + } + if self.package: + packages.add(self.package) + ordered_packages = sorted( + packages, key=lambda package: (-len(package), package) + ) + for package in ordered_packages: + prefix = f"{package}." + if name.startswith(prefix): + candidate = name[len(prefix) :] + if self._type_path_belongs_to_package(candidate, package): + return candidate + return None + + def _type_path_belongs_to_package(self, name: str, package: str) -> bool: + type_def = self._get_type_by_path(name) + if type_def is None: + return False + top_level = self._get_top_level_type(name.split(".", 1)[0]) + source_package = getattr(type_def, "source_package", None) or getattr( + top_level, "source_package", None + ) + if source_package is None: + return package == self.package + return source_package == package + + def _get_type_by_path( + self, name: str + ) -> Optional[TypingUnion[Message, Enum, "Union"]]: # Handle qualified names (e.g., SearchResponse.Result) if "." in name: parts = name.split(".") diff --git a/compiler/fory_compiler/ir/validator.py b/compiler/fory_compiler/ir/validator.py index e335ae96d8..54ab7c17a6 100644 --- a/compiler/fory_compiler/ir/validator.py +++ b/compiler/fory_compiler/ir/validator.py @@ -747,11 +747,11 @@ def _check_services(self) -> None: resolved = self.schema.get_type(named_type.name) if resolved is None: continue - if not isinstance(resolved, Message): - kind = "enum" if isinstance(resolved, Enum) else "union" + if not isinstance(resolved, (Message, Union)): + kind = "an enum" if isinstance(resolved, Enum) else "unknown" self._error( f"RPC type '{named_type.name}' in service {service.name}" - f" must be a message, not a {kind}", + f" must be a message or union, not {kind}", named_type.location, ) diff --git a/compiler/fory_compiler/tests/test_ir_service_validation.py b/compiler/fory_compiler/tests/test_ir_service_validation.py index 7abf44f907..adb0b61312 100644 --- a/compiler/fory_compiler/tests/test_ir_service_validation.py +++ b/compiler/fory_compiler/tests/test_ir_service_validation.py @@ -138,7 +138,8 @@ def test_rpc_request_type_enum_fails_validation(): """ v = validate(parse_fdl(source)) assert any( - "RPC type 'Status'" in e.message and "not a enum" in e.message for e in v.errors + "RPC type 'Status'" in e.message and "not an enum" in e.message + for e in v.errors ) @@ -155,11 +156,12 @@ def test_rpc_response_type_enum_fails_validation(): """ v = validate(parse_fdl(source)) assert any( - "RPC type 'Status'" in e.message and "not a enum" in e.message for e in v.errors + "RPC type 'Status'" in e.message and "not an enum" in e.message + for e in v.errors ) -def test_rpc_request_type_union_fails_validation(): +def test_rpc_request_type_union_passes_validation(): source = """ package test; @@ -171,13 +173,10 @@ def test_rpc_request_type_union_fails_validation(): } """ v = validate(parse_fdl(source)) - assert any( - "RPC type 'Payload'" in e.message and "not a union" in e.message - for e in v.errors - ) + assert v.errors == [] -def test_rpc_response_type_union_fails_validation(): +def test_rpc_response_type_union_passes_validation(): source = """ package test; @@ -189,10 +188,7 @@ def test_rpc_response_type_union_fails_validation(): } """ v = validate(parse_fdl(source)) - assert any( - "RPC type 'Payload'" in e.message and "not a union" in e.message - for e in v.errors - ) + assert v.errors == [] def test_rpc_message_types_pass_validation(): @@ -226,6 +222,31 @@ def test_proto_rpc_enum_type_fails_validation(): assert any("RPC type 'Status'" in e.message for e in v.errors) +def test_fbs_rpc_union_type_passes_validation(): + source = """ + namespace test; + + table Text { + value: string; + } + + table Count { + value: int; + } + + union Payload { + Text, + Count + } + + rpc_service Svc { + Call(Payload):Payload; + } + """ + v = validate(parse_fbs(source)) + assert v.errors == [] + + def test_fbs_rpc_duplicate_method_fails_validation(): source = """ namespace test; diff --git a/compiler/fory_compiler/tests/test_proto_service.py b/compiler/fory_compiler/tests/test_proto_service.py index 80297beffc..5a56dbf0b7 100644 --- a/compiler/fory_compiler/tests/test_proto_service.py +++ b/compiler/fory_compiler/tests/test_proto_service.py @@ -105,6 +105,73 @@ def test_streaming_rpc(): assert m3.server_streaming is True +def test_fully_qualified_rpc_types_pass_validation(): + source = """ + syntax = "proto3"; + package demo; + + message Request {} + message Response {} + + service Greeter { + rpc SayHello (.demo.Request) returns (.demo.Response); + } + """ + schema = parse_and_translate(source) + method = schema.services[0].methods[0] + assert method.request_type.name == ".demo.Request" + assert method.response_type.name == ".demo.Response" + + validator = SchemaValidator(schema) + assert validator.validate(), [issue.message for issue in validator.errors] + + +def test_absolute_rpc_type_prefers_package_qualified_type_over_nested_shadow(): + source = """ + syntax = "proto3"; + package demo; + + message demo { + message Request {} + } + message Request {} + message Response {} + + service Greeter { + rpc SayHello (.demo.Request) returns (.demo.Response); + } + """ + schema = parse_and_translate(source) + method = schema.services[0].methods[0] + assert schema.resolve_type_name(method.request_type.name) == "Request" + + validator = SchemaValidator(schema) + assert validator.validate(), [issue.message for issue in validator.errors] + + +def test_wrong_package_qualified_rpc_type_fails_validation(): + source = """ + syntax = "proto3"; + package demo; + + message other { + message Request {} + } + message Request {} + message Response {} + + service Greeter { + rpc SayHello (.other.Request) returns (.demo.Response); + } + """ + schema = parse_and_translate(source) + validator = SchemaValidator(schema) + assert not validator.validate() + assert any( + "Unknown type '.other.Request'" in err.message for err in validator.errors + ) + + def test_service_options(): source = """ syntax = "proto3"; diff --git a/compiler/fory_compiler/tests/test_service_codegen.py b/compiler/fory_compiler/tests/test_service_codegen.py index e27425d1ff..413fbf0131 100644 --- a/compiler/fory_compiler/tests/test_service_codegen.py +++ b/compiler/fory_compiler/tests/test_service_codegen.py @@ -21,9 +21,15 @@ from textwrap import dedent from typing import Dict, Tuple, Type -from fory_compiler.cli import compile_file +from fory_compiler.cli import compile_file, resolve_imports from fory_compiler.frontend.fdl.lexer import Lexer from fory_compiler.frontend.fdl.parser import Parser +from fory_compiler.frontend.fbs.lexer import Lexer as FbsLexer +from fory_compiler.frontend.fbs.parser import Parser as FbsParser +from fory_compiler.frontend.fbs.translator import FbsTranslator +from fory_compiler.frontend.proto.lexer import Lexer as ProtoLexer +from fory_compiler.frontend.proto.parser import Parser as ProtoParser +from fory_compiler.frontend.proto.translator import ProtoTranslator from fory_compiler.generators.base import BaseGenerator, GeneratorOptions from fory_compiler.generators.cpp import CppGenerator from fory_compiler.generators.csharp import CSharpGenerator @@ -33,6 +39,7 @@ from fory_compiler.generators.rust import RustGenerator from fory_compiler.generators.swift import SwiftGenerator from fory_compiler.ir.ast import Schema +from fory_compiler.ir.validator import SchemaValidator GENERATOR_CLASSES: Tuple[Type[BaseGenerator], ...] = ( @@ -82,6 +89,16 @@ def parse_fdl(source: str) -> Schema: return Parser(Lexer(source).tokenize()).parse() +def parse_proto(source: str) -> Schema: + return ProtoTranslator( + ProtoParser(ProtoLexer(source).tokenize()).parse() + ).translate() + + +def parse_fbs(source: str) -> Schema: + return FbsTranslator(FbsParser(FbsLexer(source).tokenize()).parse()).translate() + + def generate_files( schema: Schema, generator_cls: Type[BaseGenerator] ) -> Dict[str, str]: @@ -90,6 +107,14 @@ def generate_files( return {item.path: item.content for item in generator.generate()} +def generate_service_files( + schema: Schema, generator_cls: Type[BaseGenerator] +) -> Dict[str, str]: + options = GeneratorOptions(output_dir=Path("/tmp"), grpc=True) + generator = generator_cls(schema, options) + return {item.path: item.content for item in generator.generate_services()} + + def test_service_definition_does_not_affect_message_codegen(): schema_with = parse_fdl(_GREETER_WITH_SERVICE) schema_without = parse_fdl(_GREETER_WITHOUT_SERVICE) @@ -101,14 +126,563 @@ def test_service_definition_does_not_affect_message_codegen(): ) -def test_generate_services_returns_empty_list_for_all_generators(): +def test_generate_services_returns_empty_list_for_unsupported_generators(): schema = parse_fdl(_GREETER_WITH_SERVICE) for generator_cls in GENERATOR_CLASSES: + if generator_cls in (JavaGenerator, PythonGenerator): + continue options = GeneratorOptions(output_dir=Path("/tmp")) generator = generator_cls(schema, options) assert generator.generate_services() == [], ( - f"{generator_cls.language_name}: generate_services() should return [] until gRPC is implemented" + f"{generator_cls.language_name}: generate_services() should return []" + ) + + +def test_java_grpc_service_codegen_contains_fory_marshaller(): + schema = parse_fdl(_GREETER_WITH_SERVICE) + files = generate_service_files(schema, JavaGenerator) + assert set(files) == {"demo/greeter/GreeterGrpc.java"} + content = files["demo/greeter/GreeterGrpc.java"] + assert 'SERVICE_NAME = "demo.greeter.Greeter"' in content + assert "io.grpc.MethodDescriptor" in content + assert "implements io.grpc.MethodDescriptor.Marshaller" in content + assert "ThreadSafeFory FORY = GreeterForyModule.getFory()" in content + assert "fory.serialize(value)" in content + assert "fory.deserialize(readBytes(stream), type)" in content + assert "io.grpc.KnownLength" in content + assert "ProtoUtils" not in content + + +def test_python_grpc_service_codegen_uses_byte_callbacks(): + schema = parse_fdl(_GREETER_WITH_SERVICE) + files = generate_service_files(schema, PythonGenerator) + assert set(files) == {"demo_greeter_grpc.py"} + content = files["demo_greeter_grpc.py"] + assert "class GreeterStub(object):" in content + assert "class GreeterServicer(object):" in content + assert "def add_servicer(servicer, server):" in content + assert "add_GreeterServicer_to_server" not in content + assert "self.say_hello = channel.unary_unary(" in content + assert "def say_hello(self, request, context):" in content + assert ' "SayHello": grpc.unary_unary_rpc_method_handler(' in content + assert "servicer.say_hello" in content + assert "return _models._get_fory().serialize(value)" in content + assert "return _models._get_fory().deserialize(data)" in content + assert '"/demo.greeter.Greeter/SayHello"' in content + assert "SerializeToString" not in content + assert "FromString" not in content + + +def test_grpc_streaming_method_shapes(): + schema = parse_fdl( + dedent( + """ + package demo.streams; + + message Req {} + message Res {} + union Payload { Req req = 1; Res res = 2; } + + service Streamer { + rpc Unary (Req) returns (Res); + rpc Server (Req) returns (stream Res); + rpc Client (stream Req) returns (Res); + rpc Bidi (stream Payload) returns (stream Payload); + } + """ + ) + ) + + java = next(iter(generate_service_files(schema, JavaGenerator).values())) + assert "MethodType.UNARY" in java + assert "MethodType.SERVER_STREAMING" in java + assert "MethodType.CLIENT_STREAMING" in java + assert "MethodType.BIDI_STREAMING" in java + assert "asyncServerStreamingCall" in java + assert "asyncClientStreamingCall" in java + assert "asyncBidiStreamingCall" in java + assert "FutureStub" in java + assert "futureUnaryCall" in java + assert "blockingUnaryCall" in java + assert "blockingServerStreamingCall" in java + assert "io.grpc.MethodDescriptor" in java + + python = next(iter(generate_service_files(schema, PythonGenerator).values())) + assert "channel.unary_unary(" in python + assert "channel.unary_stream(" in python + assert "channel.stream_unary(" in python + assert "channel.stream_stream(" in python + assert "grpc.stream_stream_rpc_method_handler(" in python + assert "self.unary = channel.unary_unary(" in python + assert "self.server = channel.unary_stream(" in python + assert "self.client = channel.stream_unary(" in python + assert "self.bidi = channel.stream_stream(" in python + + +def test_java_outer_classname_service_references_nested_model_types(): + schema = parse_fdl( + dedent( + """ + package demo.outer; + option java_outer_classname = "OuterModels"; + + message Req {} + message Res {} + + service OuterService { + rpc Call (Req) returns (Res); + } + """ + ) + ) + + files = generate_service_files(schema, JavaGenerator) + content = files["demo/outer/OuterServiceGrpc.java"] + assert "io.grpc.MethodDescriptor" in content + assert "marshaller(OuterModels.Req.class)" in content + assert "marshaller(OuterModels.Res.class)" in content + + +def test_grpc_services_use_imported_java_type_references(tmp_path: Path): + common = tmp_path / "common.fdl" + common.write_text( + dedent( + """ + package common; + option java_package = "com.example.common"; + option java_outer_classname = "CommonModels"; + + message Shared {} + + service ApiService { + rpc ImportedCall (Shared) returns (Shared); + } + """ + ) + ) + main = tmp_path / "main.fdl" + main.write_text( + dedent( + """ + package api; + option java_package = "com.example.api"; + + import "common.fdl"; + + message Local {} + + service ApiService { + rpc Get (Shared) returns (Local); + } + """ + ) + ) + + schema = resolve_imports(main, [tmp_path]) + validator = SchemaValidator(schema) + assert validator.validate(), [issue.message for issue in validator.errors] + assert [service.name for service in schema.services] == ["ApiService"] + + java_files = generate_service_files(schema, JavaGenerator) + assert set(java_files) == {"com/example/api/ApiServiceGrpc.java"} + java = java_files["com/example/api/ApiServiceGrpc.java"] + assert ( + "io.grpc.MethodDescriptor" + in java + ) + assert "marshaller(com.example.common.CommonModels.Shared.class)" in java + + python_files = generate_service_files(schema, PythonGenerator) + assert set(python_files) == {"api_grpc.py"} + python = python_files["api_grpc.py"] + assert "class ApiServiceStub" in python + assert "ImportedCall" not in python + + +def test_proto_grpc_services_use_imported_qualified_type_references(tmp_path: Path): + common = tmp_path / "common.proto" + common.write_text( + dedent( + """ + syntax = "proto3"; + package common; + option java_package = "com.example.common"; + option java_outer_classname = "CommonModels"; + + message Shared {} + """ + ) + ) + main = tmp_path / "main.proto" + main.write_text( + dedent( + """ + syntax = "proto3"; + package api; + option java_package = "com.example.api"; + + import "common.proto"; + + message Local {} + + service ApiService { + rpc Get (common.Shared) returns (.api.Local); + } + """ + ) + ) + + schema = resolve_imports(main, [tmp_path]) + validator = SchemaValidator(schema) + assert validator.validate(), [issue.message for issue in validator.errors] + + java_files = generate_service_files(schema, JavaGenerator) + java = java_files["com/example/api/ApiServiceGrpc.java"] + assert ( + "io.grpc.MethodDescriptor" + in java + ) + assert "marshaller(com.example.common.CommonModels.Shared.class)" in java + + +def test_proto_grpc_absolute_rpc_type_uses_package_type_not_nested_shadow(): + schema = parse_proto( + dedent( + """ + syntax = "proto3"; + package demo; + + message demo { + message Request {} + } + message Request {} + message Response {} + + service ApiService { + rpc Get (.demo.Request) returns (.demo.Response); + } + """ + ) + ) + validator = SchemaValidator(schema) + assert validator.validate(), [issue.message for issue in validator.errors] + + java_files = generate_service_files(schema, JavaGenerator) + java = java_files["demo/ApiServiceGrpc.java"] + assert "io.grpc.MethodDescriptor" in java + assert "io.grpc.MethodDescriptor" not in java + + +def test_proto_grpc_absolute_rpc_type_prefers_longest_package_prefix(tmp_path: Path): + common = tmp_path / "common.proto" + common.write_text( + dedent( + """ + syntax = "proto3"; + package alpha.beta; + option java_package = "pkg.two"; + + message C {} + """ + ) + ) + main = tmp_path / "main.proto" + main.write_text( + dedent( + """ + syntax = "proto3"; + package alpha; + option java_package = "pkg.one"; + + import "common.proto"; + + message beta { + message C {} + } + + service ApiService { + rpc Get (.alpha.beta.C) returns (.alpha.beta.C); + } + """ + ) + ) + + schema = resolve_imports(main, [tmp_path]) + validator = SchemaValidator(schema) + assert validator.validate(), [issue.message for issue in validator.errors] + + java_files = generate_service_files(schema, JavaGenerator) + java = java_files["pkg/one/ApiServiceGrpc.java"] + assert "io.grpc.MethodDescriptor" in java + assert "marshaller(pkg.two.C.class)" in java + assert "io.grpc.MethodDescriptor" not in java + + +def test_java_grpc_service_class_collision_fails(): + schema = parse_fdl( + dedent( + """ + package demo.collision; + + message GreeterGrpc {} + message Req {} + message Res {} + + service Greeter { + rpc Call (Req) returns (Res); + } + """ + ) + ) + generator = JavaGenerator( + schema, GeneratorOptions(output_dir=Path("/tmp"), grpc=True) + ) + try: + generator.generate_services() + except ValueError as e: + assert "Java gRPC service class GreeterGrpc conflicts" in str(e) + else: + raise AssertionError("Expected Java gRPC service class collision") + + +def test_java_grpc_service_class_collision_with_imported_type_fails(tmp_path: Path): + common = tmp_path / "common.fdl" + common.write_text( + dedent( + """ + package demo.collision; + + message GreeterGrpc {} + """ + ) + ) + main = tmp_path / "main.fdl" + main.write_text( + dedent( + """ + package demo.collision; + + import "common.fdl"; + + message Req {} + message Res {} + + service Greeter { + rpc Call (Req) returns (Res); + } + """ + ) + ) + + schema = resolve_imports(main, [tmp_path]) + generator = JavaGenerator( + schema, GeneratorOptions(output_dir=Path("/tmp"), grpc=True) + ) + try: + generator.generate_services() + except ValueError as e: + assert "Java gRPC service class GreeterGrpc conflicts" in str(e) + else: + raise AssertionError("Expected imported Java gRPC service class collision") + + +def test_java_grpc_service_class_collision_with_imported_outer_fails(tmp_path: Path): + common = tmp_path / "common.fdl" + common.write_text( + dedent( + """ + package demo.collision; + option java_outer_classname = "GreeterGrpc"; + + message Shared {} + """ + ) + ) + main = tmp_path / "main.fdl" + main.write_text( + dedent( + """ + package demo.collision; + + import "common.fdl"; + + message Req {} + message Res {} + + service Greeter { + rpc Call (Req) returns (Res); + } + """ + ) + ) + + schema = resolve_imports(main, [tmp_path]) + generator = JavaGenerator( + schema, GeneratorOptions(output_dir=Path("/tmp"), grpc=True) + ) + try: + generator.generate_services() + except ValueError as e: + assert "Java gRPC service class GreeterGrpc conflicts" in str(e) + else: + raise AssertionError("Expected imported Java outer class collision") + + +def test_grpc_method_name_collisions_fail(): + schema = parse_fdl( + dedent( + """ + package demo.collision; + + message Req {} + message Res {} + + service Greeter { + rpc Foo (Req) returns (Res); + rpc foo (Req) returns (Res); + } + """ + ) + ) + + java_generator = JavaGenerator( + schema, GeneratorOptions(output_dir=Path("/tmp"), grpc=True) + ) + try: + java_generator.generate_services() + except ValueError as e: + assert "Java gRPC" in str(e) and "Foo and foo" in str(e) + else: + raise AssertionError("Expected Java gRPC method name collision") + + python_generator = PythonGenerator( + schema, GeneratorOptions(output_dir=Path("/tmp"), grpc=True) + ) + try: + python_generator.generate_services() + except ValueError as e: + assert "Python gRPC method name collision" in str(e) + else: + raise AssertionError("Expected Python gRPC method name collision") + + +def test_python_grpc_method_keywords_are_safe_names(): + schema = parse_fdl( + dedent( + """ + package demo.keywords; + + message Req {} + message Res {} + + service Greeter { + rpc Class (Req) returns (Res); + } + """ + ) + ) + + java = next(iter(generate_service_files(schema, JavaGenerator).values())) + assert "public void class_(Req request," in java + assert "public Res class_(Req request)" in java + assert "serviceImpl.class_((Req) request," in java + + python = next(iter(generate_service_files(schema, PythonGenerator).values())) + assert "self.class_ = channel.unary_unary(" in python + assert "def class_(self, request, context):" in python + assert "servicer.class_" in python + assert ' "Class": grpc.unary_unary_rpc_method_handler(' in python + + +def test_python_grpc_service_registration_collisions_fail(): + schema = parse_fdl( + dedent( + """ + package demo.collision; + + service FooBar {} + service FooBAR {} + """ + ) + ) + + generator = PythonGenerator( + schema, GeneratorOptions(output_dir=Path("/tmp"), grpc=True) + ) + try: + generator.generate_services() + except ValueError as e: + assert "Python gRPC service registration collision" in str(e) + else: + raise AssertionError("Expected Python gRPC service registration collision") + + +def test_default_package_java_grpc_output_path_and_service_name(): + schema = parse_fdl( + dedent( + """ + message Req {} + message Res {} + + service DefaultService { + rpc Call (Req) returns (Res); + } + """ + ) + ) + + files = generate_service_files(schema, JavaGenerator) + assert set(files) == {"DefaultServiceGrpc.java"} + java = files["DefaultServiceGrpc.java"] + assert "package " not in java + assert 'SERVICE_NAME = "DefaultService"' in java + assert "generateFullMethodName(" in java + assert 'SERVICE_NAME, "Call"' in java + + +def test_proto_and_fbs_grpc_service_codegen(): + proto_schema = parse_proto( + dedent( + """ + syntax = "proto3"; + package demo.proto; + + message Req {} + message Res {} + + service ProtoSvc { + rpc Call (Req) returns (stream Res); + } + """ + ) + ) + proto_java = generate_service_files(proto_schema, JavaGenerator) + proto_python = generate_service_files(proto_schema, PythonGenerator) + assert "demo/proto/ProtoSvcGrpc.java" in proto_java + assert "demo_proto_grpc.py" in proto_python + assert "MethodType.SERVER_STREAMING" in proto_java["demo/proto/ProtoSvcGrpc.java"] + assert "channel.unary_stream(" in proto_python["demo_proto_grpc.py"] + + fbs_schema = parse_fbs( + dedent( + """ + namespace demo.fbs; + + table Req {} + table Res {} + + rpc_service FbsSvc { + Call(Req):Res; + } + """ ) + ) + fbs_java = generate_service_files(fbs_schema, JavaGenerator) + fbs_python = generate_service_files(fbs_schema, PythonGenerator) + assert "demo/fbs/FbsSvcGrpc.java" in fbs_java + assert "demo_fbs_grpc.py" in fbs_python + assert 'SERVICE_NAME = "demo.fbs.FbsSvc"' in fbs_java["demo/fbs/FbsSvcGrpc.java"] + assert '"/demo.fbs.FbsSvc/Call"' in fbs_python["demo_fbs_grpc.py"] def test_service_schema_produces_one_file_per_message_per_language(): @@ -130,6 +704,8 @@ def test_compile_service_schema_with_grpc_flag(tmp_path: Path): for lang, lang_dir in lang_dirs.items(): files = [p for p in lang_dir.rglob("*") if p.is_file()] assert len(files) >= 1, f"{lang}: expected at least one file with grpc=True" + assert (lang_dirs["java"] / "demo" / "greeter" / "GreeterGrpc.java").exists() + assert (lang_dirs["python"] / "demo_greeter_grpc.py").exists() def test_generated_message_contains_key_signatures(): diff --git a/docs/compiler/compiler-guide.md b/docs/compiler/compiler-guide.md index 6fff34d0b3..055e4fd8f3 100644 --- a/docs/compiler/compiler-guide.md +++ b/docs/compiler/compiler-guide.md @@ -72,6 +72,7 @@ Compile options: | `--swift_namespace_style` | Swift namespace style: `enum` or `flatten` | `enum` | | `--emit-fdl` | Emit translated FDL (for non-FDL inputs) | `false` | | `--emit-fdl-path` | Write translated FDL to this path (file or directory) | (stdout) | +| `--grpc` | Generate gRPC service companions for Java and Python | `false` | Schema-level file options are supported for language-specific generation choices. For `go_nested_type_style` and `swift_namespace_style`, the CLI flag overrides @@ -146,6 +147,18 @@ foryc user.fdl order.fdl product.fdl --output ./generated foryc compiler/examples/service.fdl --java_out=./generated/java --python_out=./generated/python ``` +**Generate Java and Python gRPC service companions:** + +```bash +foryc compiler/examples/service.fdl --java_out=./generated/java --python_out=./generated/python --grpc +``` + +The generated gRPC service code uses Fory to serialize request and response +payloads. Java output imports grpc-java APIs and Python output imports `grpc`; +applications that compile or run those generated service files must provide +their own gRPC dependencies. Fory's Java and Python runtime packages do not add a +hard gRPC dependency for this feature. + **Use import search paths:** ```bash diff --git a/docs/compiler/flatbuffers-idl.md b/docs/compiler/flatbuffers-idl.md index 90600640b4..e57d1c8a4d 100644 --- a/docs/compiler/flatbuffers-idl.md +++ b/docs/compiler/flatbuffers-idl.md @@ -122,6 +122,27 @@ message Container { } ``` +### Services + +FlatBuffers `rpc_service` definitions are translated to Fory services. With +`--grpc`, the compiler emits Java and Python gRPC service companions that use +Fory serialization for request and response payloads. + +```fbs +rpc_service SearchService { + Lookup(SearchRequest):SearchResponse; + StreamLookup(SearchRequest):SearchResponse (streaming: "server"); +} +``` + +```bash +foryc api.fbs --java_out=./generated/java --python_out=./generated/python --grpc +``` + +Generated service code imports grpc APIs, so applications must provide grpc-java +or `grpcio` dependencies when they compile or run those files. The Fory runtime +packages do not add gRPC as a hard dependency. + ### Defaults and Metadata - FlatBuffers default values are parsed but not applied as Fory runtime defaults. diff --git a/docs/compiler/generated-code.md b/docs/compiler/generated-code.md index 520dda8d3b..beee9cf15d 100644 --- a/docs/compiler/generated-code.md +++ b/docs/compiler/generated-code.md @@ -238,6 +238,32 @@ byte[] data = person.toBytes(); Person restored = Person.fromBytes(data); ``` +### gRPC Service Companions + +When a schema contains services and the compiler is run with `--grpc`, Java +generation emits one `Grpc.java` file per service next to the model +types. + +```java +public final class AddressBookServiceGrpc { + public static final String SERVICE_NAME = "addressbook.AddressBookService"; + + public static AddressBookServiceStub newStub(io.grpc.Channel channel) { ... } + public static AddressBookServiceBlockingStub newBlockingStub(io.grpc.Channel channel) { ... } + public static AddressBookServiceFutureStub newFutureStub(io.grpc.Channel channel) { ... } + + public abstract static class AddressBookServiceImplBase + implements io.grpc.BindableService { + public void lookup(Person request, io.grpc.stub.StreamObserver responseObserver) { ... } + } +} +``` + +The generated marshaller serializes each request or response with the schema +module's `ThreadSafeFory`. It uses grpc-java's `MethodDescriptor.Marshaller` +API, so applications compiling these files must provide grpc-java dependencies. +Those dependencies are not added to Fory Java runtime artifacts. + ## Python ### Output Layout @@ -337,6 +363,38 @@ data = person.to_bytes() restored = Person.from_bytes(data) ``` +### gRPC Service Companions + +When a schema contains services and the compiler is run with `--grpc`, Python +generation emits a companion module named `_grpc.py`. The module name is +derived from the Fory package by replacing dots with underscores, or `generated` +when the schema has no package. + +```python +class AddressBookServiceStub: + def __init__(self, channel): + self.lookup = channel.unary_unary( + "/addressbook.AddressBookService/Lookup", + request_serializer=_serialize, + response_deserializer=_deserialize, + ) + + +class AddressBookServiceServicer: + def lookup(self, request, context): + raise NotImplementedError("Method not implemented!") + + +def add_servicer(servicer, server): ... +``` + +Python gRPC serializers receive and return complete `bytes` payloads, so the +generated callbacks call the model module's `_get_fory().serialize(...)` and +`_get_fory().deserialize(...)` directly. Applications using the generated +companion module must install `grpcio`; `pyfory` does not add a hard gRPC +dependency. The Python API uses snake_case method names while preserving the +original IDL method names in the gRPC wire paths. + ## Rust ### Output Layout diff --git a/docs/compiler/index.md b/docs/compiler/index.md index 8f516ec7ab..0b227d7e2a 100644 --- a/docs/compiler/index.md +++ b/docs/compiler/index.md @@ -22,7 +22,9 @@ license: | Fory IDL is a schema definition language for Apache Fory that enables type-safe cross-language serialization. Define your data structures once and generate native data structure code for Java, Python, Go, Rust, C++, C#, Swift, -JavaScript, Dart, Scala, and Kotlin. +JavaScript, Dart, Scala, and Kotlin. Fory IDL can also describe RPC services; +for Java and Python, the compiler can generate gRPC service companions that use +Fory serialization for request and response payloads. ## Example Schema @@ -70,8 +72,32 @@ union Animal [id=106] { Dog dog = 1; Cat cat = 2; } + +message LookupRequest [id=107] { + string name = 1; +} + +message LookupResponse [id=108] { + Animal animal = 1; +} + +service AnimalService { + rpc Lookup (LookupRequest) returns (LookupResponse); + rpc Classify (Animal) returns (Animal); +} +``` + +Generate Java and Python models plus gRPC service companions with: + +```bash +foryc animals.fdl --java_out=./generated/java --python_out=./generated/python --grpc ``` +The generated service code uses normal gRPC APIs, but request and response +objects are serialized with Fory. Applications provide their own grpc-java or +`grpcio` dependencies; Fory runtime packages do not add gRPC as a hard +dependency. + ## Why Fory IDL? ### Schema-First Development @@ -170,14 +196,15 @@ data = bytes(person) # or `person.to_bytes()` ## Documentation -| Document | Description | -| ----------------------------------------------- | ------------------------------------------------- | -| [Fory IDL Syntax](schema-idl.md) | Complete language syntax and grammar | -| [Type System](schema-idl.md#type-system) | Primitive types, collections, and type rules | -| [Compiler Guide](compiler-guide.md) | CLI options and build integration | -| [Generated Code](generated-code.md) | Output format for each target language | -| [Protocol Buffers IDL Support](protobuf-idl.md) | Protobuf mapping rules and adoption guidance | -| [FlatBuffers IDL Support](flatbuffers-idl.md) | FlatBuffers mapping rules and codegen differences | +| Document | Description | +| ------------------------------------------------ | ------------------------------------------------- | +| [Fory IDL Syntax](schema-idl.md) | Complete language syntax and grammar | +| [Type System](schema-idl.md#type-system) | Primitive types, collections, and type rules | +| [RPC Services](schema-idl.md#service-definition) | Service and RPC method syntax | +| [Compiler Guide](compiler-guide.md) | CLI options and build integration | +| [Generated Code](generated-code.md) | Output format for each target language | +| [Protocol Buffers IDL Support](protobuf-idl.md) | Protobuf mapping rules and adoption guidance | +| [FlatBuffers IDL Support](flatbuffers-idl.md) | FlatBuffers mapping rules and codegen differences | ## Key Concepts diff --git a/docs/compiler/protobuf-idl.md b/docs/compiler/protobuf-idl.md index 0105a89558..3a13fb8ff4 100644 --- a/docs/compiler/protobuf-idl.md +++ b/docs/compiler/protobuf-idl.md @@ -49,10 +49,13 @@ how protobuf concepts map to Fory, and how to use protobuf-only Fory extension o | Circular refs | Not supported | Supported | | Unknown fields | Preserved | Not preserved | | Generated types | Protobuf-specific model types | Native language constructs | -| gRPC ecosystem | Native | In progress (active development) | +| gRPC ecosystem | Native | Java/Python service codegen | -Fory gRPC support is under active development. For production gRPC -workflows today, protobuf remains the mature/default choice. +Fory can generate Java and Python gRPC service companions with `--grpc`. Those +services use normal gRPC transports but serialize request and response payloads +with Fory rather than protobuf. For broad gRPC ecosystem tooling, schema +reflection, and protobuf-native interceptors, protobuf remains the mature/default +choice. ## Why Use Apache Fory @@ -307,6 +310,19 @@ modifiers (and optional `ref(weak=true)` where needed). Replace protobuf generation steps with the Fory compiler invocation for target languages. +For Java and Python services, add `--grpc` to emit gRPC companion code: + +```bash +foryc api.proto --java_out=./generated/java --python_out=./generated/python --grpc +``` + +Generated Java service files compile against grpc-java, and generated Python +service modules import `grpc`. Add those dependencies in your application build; +Fory runtime packages do not add gRPC as a hard dependency. Protobuf `oneof` +fields are translated to Fory union fields inside request and response messages. +Direct union RPC request or response types are not part of normal protobuf RPC +syntax. + ### Step 5: Run Compatibility Checks For staged transitions, keep both formats in parallel and verify payload-level diff --git a/docs/compiler/schema-idl.md b/docs/compiler/schema-idl.md index 799b8fb0ee..8b5835faaf 100644 --- a/docs/compiler/schema-idl.md +++ b/docs/compiler/schema-idl.md @@ -34,6 +34,7 @@ An Fory IDL file typically consists of: 2. Optional file-level options 3. Optional import statements 4. Type definitions (enums, messages, and unions) +5. Optional service definitions ```protobuf // Optional package declaration @@ -48,8 +49,14 @@ import "common/types.fdl"; // Type definitions enum Color [id=100] { ... } message User [id=101] { ... } -message Order [id=102] { ... } -union Event [id=103] { ... } +message OrderRequest [id=102] { ... } +message Order [id=103] { ... } +union Event [id=104] { ... } + +// Service definitions +service OrderService { + rpc GetOrder (OrderRequest) returns (Order); +} ``` ## Comments @@ -876,6 +883,62 @@ union_def := 'union' IDENTIFIER [type_options] '{' union_field* '}' union_field := field_type IDENTIFIER '=' INTEGER ';' ``` +## Service Definition + +Services define RPC method contracts in Fory IDL. They are optional: schemas +with services still generate the normal data model types, and gRPC service code +is generated only when the compiler is run with `--grpc` for Java or Python. + +```protobuf +message GetPetRequest [id=200] { + string name = 1; +} + +message PetRecord [id=201] { + string name = 1; + Animal animal = 2; +} + +service PetDirectory { + rpc GetPet (GetPetRequest) returns (PetRecord); + rpc Classify (Animal) returns (Animal); +} +``` + +The first method uses message request and response types. The second method uses +a direct union request and response type, which is supported in Fory IDL. + +### Streaming RPCs + +Use `stream` before the request type, the response type, or both: + +```protobuf +service PetDirectory { + rpc GetPet (GetPetRequest) returns (PetRecord); // unary + rpc WatchPets (GetPetRequest) returns (stream PetRecord); // server streaming + rpc ImportPets (stream PetRecord) returns (PetRecord); // client streaming + rpc ChatPets (stream Animal) returns (stream Animal); // bidirectional streaming +} +``` + +### RPC Type Rules + +- Request and response types must reference named message or union types. +- Enum, primitive, collection, map, and array types are not valid direct RPC + request or response types. Wrap those values in a message when they are part + of a service contract. +- The generated Java and Python gRPC companions use Fory serialization for each + RPC payload. Applications that compile or run those companions provide their + own grpc-java or `grpcio` dependency. + +**Grammar:** + +``` +service_def := 'service' IDENTIFIER '{' rpc_method* '}' +rpc_method := 'rpc' IDENTIFIER '(' ['stream'] named_type ')' + 'returns' '(' ['stream'] named_type ')' ';' +``` + ## Field Definition Fields define the properties of a message. @@ -1570,7 +1633,7 @@ For protobuf-specific extension options and `(fory).` syntax, see ## Grammar Summary ``` -file := [package_decl] file_option* import_decl* type_def* +file := [package_decl] file_option* import_decl* definition* package_decl := 'package' package_name ['alias' package_name] ';' package_name := IDENTIFIER ('.' IDENTIFIER)* @@ -1580,6 +1643,7 @@ option_name := IDENTIFIER import_decl := 'import' STRING ';' +definition := type_def | service_def type_def := enum_def | message_def | union_def enum_def := 'enum' IDENTIFIER [type_options] '{' enum_body '}' @@ -1593,6 +1657,10 @@ field_def := [modifiers] field_type IDENTIFIER '=' INTEGER [field_options] '; union_def := 'union' IDENTIFIER [type_options] '{' union_field* '}' union_field := ['repeated'] field_type IDENTIFIER '=' INTEGER [field_options] ';' + +service_def := 'service' IDENTIFIER '{' rpc_method* '}' +rpc_method := 'rpc' IDENTIFIER '(' ['stream'] named_type ')' + 'returns' '(' ['stream'] named_type ')' ';' option_value := 'true' | 'false' | IDENTIFIER | INTEGER | STRING reserved_stmt := 'reserved' reserved_items ';' diff --git a/integration_tests/grpc_tests/generate_grpc.py b/integration_tests/grpc_tests/generate_grpc.py new file mode 100644 index 0000000000..22ea0a4467 --- /dev/null +++ b/integration_tests/grpc_tests/generate_grpc.py @@ -0,0 +1,78 @@ +#!/usr/bin/env python3 +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import os +import subprocess +import sys +from pathlib import Path + +REPO_ROOT = Path(__file__).resolve().parents[2] +TEST_DIR = Path(__file__).resolve().parent + +SCHEMAS = [ + TEST_DIR / "idl" / "grpc_fdl.fdl", + TEST_DIR / "idl" / "grpc_pb.proto", + TEST_DIR / "idl" / "grpc_fbs.fbs", +] + +OUTPUTS = { + "java": TEST_DIR / "java/src/main/java/generated", + "python": TEST_DIR / "python/grpc_tests/generated", +} + + +def main() -> int: + env = os.environ.copy() + compiler_path = str(REPO_ROOT / "compiler") + env["PYTHONPATH"] = compiler_path + os.pathsep + env.get("PYTHONPATH", "") + + for root in OUTPUTS.values(): + root.mkdir(parents=True, exist_ok=True) + subprocess.check_call( + [ + sys.executable, + "-m", + "fory_compiler", + "--scan-generated", + "--delete", + "--root", + str(root), + ], + env=env, + ) + + for schema in SCHEMAS: + subprocess.check_call( + [ + sys.executable, + "-m", + "fory_compiler", + "compile", + str(schema), + f"--java_out={OUTPUTS['java']}", + f"--python_out={OUTPUTS['python']}", + "--grpc", + ], + env=env, + ) + + return 0 + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/integration_tests/grpc_tests/idl/grpc_fbs.fbs b/integration_tests/grpc_tests/idl/grpc_fbs.fbs new file mode 100644 index 0000000000..7727499964 --- /dev/null +++ b/integration_tests/grpc_tests/idl/grpc_fbs.fbs @@ -0,0 +1,46 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +namespace grpc_fbs; + +table GrpcFbsRequest { + id: string; + count: int; + payload: string; +} + +table GrpcFbsResponse { + id: string; + count: int; + payload: string; +} + +union GrpcFbsUnion { + GrpcFbsRequest, + GrpcFbsResponse +} + +rpc_service FbsGrpcService { + UnaryMessage(GrpcFbsRequest):GrpcFbsResponse; + ServerStreamMessage(GrpcFbsRequest):GrpcFbsResponse (streaming: "server"); + ClientStreamMessage(GrpcFbsRequest):GrpcFbsResponse (streaming: "client"); + BidiStreamMessage(GrpcFbsRequest):GrpcFbsResponse (streaming: "bidi"); + UnaryUnion(GrpcFbsUnion):GrpcFbsUnion; + ServerStreamUnion(GrpcFbsUnion):GrpcFbsUnion (streaming: "server"); + ClientStreamUnion(GrpcFbsUnion):GrpcFbsUnion (streaming: "client"); + BidiStreamUnion(GrpcFbsUnion):GrpcFbsUnion (streaming: "bidi"); +} diff --git a/integration_tests/grpc_tests/idl/grpc_fdl.fdl b/integration_tests/grpc_tests/idl/grpc_fdl.fdl new file mode 100644 index 0000000000..a552cb1c3a --- /dev/null +++ b/integration_tests/grpc_tests/idl/grpc_fdl.fdl @@ -0,0 +1,46 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package grpc_fdl; + +message GrpcFdlRequest { + string id = 1; + int32 count = 2; + string payload = 3; +} + +message GrpcFdlResponse { + string id = 1; + int32 count = 2; + string payload = 3; +} + +union GrpcFdlUnion { + GrpcFdlRequest request = 1; + GrpcFdlResponse response = 2; +} + +service FdlGrpcService { + rpc UnaryMessage (GrpcFdlRequest) returns (GrpcFdlResponse); + rpc ServerStreamMessage (GrpcFdlRequest) returns (stream GrpcFdlResponse); + rpc ClientStreamMessage (stream GrpcFdlRequest) returns (GrpcFdlResponse); + rpc BidiStreamMessage (stream GrpcFdlRequest) returns (stream GrpcFdlResponse); + rpc UnaryUnion (GrpcFdlUnion) returns (GrpcFdlUnion); + rpc ServerStreamUnion (GrpcFdlUnion) returns (stream GrpcFdlUnion); + rpc ClientStreamUnion (stream GrpcFdlUnion) returns (GrpcFdlUnion); + rpc BidiStreamUnion (stream GrpcFdlUnion) returns (stream GrpcFdlUnion); +} diff --git a/integration_tests/grpc_tests/idl/grpc_pb.proto b/integration_tests/grpc_tests/idl/grpc_pb.proto new file mode 100644 index 0000000000..c6482c5900 --- /dev/null +++ b/integration_tests/grpc_tests/idl/grpc_pb.proto @@ -0,0 +1,47 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +syntax = "proto3"; + +package grpc_pb; + +message GrpcPbRequest { + string id = 1; + int32 count = 2; + oneof payload { + string text = 3; + int32 number = 4; + } +} + +message GrpcPbResponse { + string id = 1; + int32 count = 2; + oneof payload { + string text = 3; + int32 number = 4; + } +} + +service PbGrpcService { + rpc UnaryMessage (GrpcPbRequest) returns (GrpcPbResponse); + rpc ServerStreamMessage (GrpcPbRequest) returns (stream GrpcPbResponse); + rpc ClientStreamMessage (stream GrpcPbRequest) returns (GrpcPbResponse); + rpc BidiStreamMessage (stream GrpcPbRequest) returns (stream GrpcPbResponse); +} diff --git a/integration_tests/grpc_tests/java/pom.xml b/integration_tests/grpc_tests/java/pom.xml new file mode 100644 index 0000000000..7ce2e2baa3 --- /dev/null +++ b/integration_tests/grpc_tests/java/pom.xml @@ -0,0 +1,69 @@ + + + + + org.apache.fory + fory-parent + 0.18.0-SNAPSHOT + ../../../java + + 4.0.0 + + grpc_tests + + + 1.62.2 + 8 + 8 + UTF-8 + + + + + org.apache.fory + fory-core + ${project.version} + + + io.grpc + grpc-api + ${grpc.version} + + + io.grpc + grpc-netty-shaded + ${grpc.version} + + + io.grpc + grpc-stub + ${grpc.version} + + + org.testng + testng + test + + + diff --git a/integration_tests/grpc_tests/java/src/test/java/org/apache/fory/grpc_tests/GrpcInteropTest.java b/integration_tests/grpc_tests/java/src/test/java/org/apache/fory/grpc_tests/GrpcInteropTest.java new file mode 100644 index 0000000000..791c2a6a30 --- /dev/null +++ b/integration_tests/grpc_tests/java/src/test/java/org/apache/fory/grpc_tests/GrpcInteropTest.java @@ -0,0 +1,907 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.fory.grpc_tests; + +import io.grpc.ManagedChannel; +import io.grpc.ManagedChannelBuilder; +import io.grpc.Server; +import io.grpc.ServerBuilder; +import io.grpc.stub.StreamObserver; +import java.io.ByteArrayOutputStream; +import java.io.File; +import java.io.IOException; +import java.io.InputStream; +import java.nio.charset.StandardCharsets; +import java.nio.file.Files; +import java.nio.file.Path; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.Iterator; +import java.util.List; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; +import java.util.function.Function; +import org.testng.Assert; +import org.testng.annotations.Test; + +public class GrpcInteropTest { + + @Test + public void testJavaServerPythonClient() throws Exception { + Server server = + ServerBuilder.forPort(0) + .addService(new FdlService()) + .addService(new FbsService()) + .addService(new PbService()) + .build() + .start(); + try { + runPython("python-grpc-client", "client", "--target", "127.0.0.1:" + server.getPort()); + } finally { + server.shutdownNow(); + server.awaitTermination(10, TimeUnit.SECONDS); + } + } + + @Test + public void testJavaClientPythonServer() throws Exception { + Path portFile = Files.createTempFile("fory-grpc-python-", ".port"); + Files.deleteIfExists(portFile); + PeerCommand command = pythonCommand("server", "--port-file", portFile.toString()); + Process process = startPeer(command); + PeerOutputCollector outputCollector = + new PeerOutputCollector(process.getInputStream(), "python-grpc-server"); + outputCollector.start(); + try { + int port = waitForPort(process, outputCollector, portFile); + ManagedChannel channel = + ManagedChannelBuilder.forAddress("127.0.0.1", port).usePlaintext().build(); + try { + exerciseFdl(channel); + exerciseFbs(channel); + exercisePb(channel); + } finally { + channel.shutdownNow(); + channel.awaitTermination(10, TimeUnit.SECONDS); + } + } finally { + process.destroy(); + process.waitFor(10, TimeUnit.SECONDS); + if (process.isAlive()) { + process.destroyForcibly(); + process.waitFor(10, TimeUnit.SECONDS); + } + outputCollector.awaitOutput(); + Files.deleteIfExists(portFile); + } + } + + private void exerciseFdl(ManagedChannel channel) throws InterruptedException { + grpc_fdl.FdlGrpcServiceGrpc.FdlGrpcServiceBlockingStub blocking = + grpc_fdl.FdlGrpcServiceGrpc.newBlockingStub(channel); + grpc_fdl.FdlGrpcServiceGrpc.FdlGrpcServiceStub async = + grpc_fdl.FdlGrpcServiceGrpc.newStub(channel); + + List messages = + Arrays.asList(fdlRequest("fdl-a", 1, "alpha"), fdlRequest("fdl-b", 2, "beta")); + assertFdlMessages(blocking, async, messages); + + List unions = + Arrays.asList( + grpc_fdl.GrpcFdlUnion.ofRequest(fdlRequest("fdl-u-a", 3, "union-alpha")), + grpc_fdl.GrpcFdlUnion.ofRequest(fdlRequest("fdl-u-b", 4, "union-beta"))); + assertFdlUnions(blocking, async, unions); + } + + private void assertFdlMessages( + grpc_fdl.FdlGrpcServiceGrpc.FdlGrpcServiceBlockingStub blocking, + grpc_fdl.FdlGrpcServiceGrpc.FdlGrpcServiceStub async, + List requests) + throws InterruptedException { + grpc_fdl.GrpcFdlRequest first = requests.get(0); + Assert.assertEquals(blocking.unaryMessage(first), fdlResponse(first, "unary", 10)); + Assert.assertEquals( + toList(blocking.serverStreamMessage(first)), + Arrays.asList( + fdlResponse(first, "server-0", 0), + fdlResponse(first, "server-1", 1), + fdlResponse(first, "server-2", 2))); + + CollectingObserver clientObserver = new CollectingObserver<>(); + sendAll(async.clientStreamMessage(clientObserver), requests); + Assert.assertEquals(clientObserver.await(), Collections.singletonList(fdlAggregate(requests))); + + CollectingObserver bidiObserver = new CollectingObserver<>(); + sendAll(async.bidiStreamMessage(bidiObserver), requests); + Assert.assertEquals( + bidiObserver.await(), + Arrays.asList( + fdlResponse(requests.get(0), "bidi-0", 0), fdlResponse(requests.get(1), "bidi-1", 1))); + } + + private void assertFdlUnions( + grpc_fdl.FdlGrpcServiceGrpc.FdlGrpcServiceBlockingStub blocking, + grpc_fdl.FdlGrpcServiceGrpc.FdlGrpcServiceStub async, + List requests) + throws InterruptedException { + grpc_fdl.GrpcFdlRequest first = fdlRequestFromUnion(requests.get(0)); + Assert.assertEquals(blocking.unaryUnion(requests.get(0)), fdlUnionResponse(first, "unary", 10)); + Assert.assertEquals( + toList(blocking.serverStreamUnion(requests.get(0))), + Arrays.asList( + fdlUnionResponse(first, "server-0", 0), + fdlUnionResponse(first, "server-1", 1), + fdlUnionResponse(first, "server-2", 2))); + + CollectingObserver clientObserver = new CollectingObserver<>(); + sendAll(async.clientStreamUnion(clientObserver), requests); + Assert.assertEquals( + clientObserver.await(), Collections.singletonList(fdlUnionAggregate(requests))); + + CollectingObserver bidiObserver = new CollectingObserver<>(); + sendAll(async.bidiStreamUnion(bidiObserver), requests); + Assert.assertEquals( + bidiObserver.await(), + Arrays.asList( + fdlUnionResponse(fdlRequestFromUnion(requests.get(0)), "bidi-0", 0), + fdlUnionResponse(fdlRequestFromUnion(requests.get(1)), "bidi-1", 1))); + } + + private void exerciseFbs(ManagedChannel channel) throws InterruptedException { + grpc_fbs.FbsGrpcServiceGrpc.FbsGrpcServiceBlockingStub blocking = + grpc_fbs.FbsGrpcServiceGrpc.newBlockingStub(channel); + grpc_fbs.FbsGrpcServiceGrpc.FbsGrpcServiceStub async = + grpc_fbs.FbsGrpcServiceGrpc.newStub(channel); + + List messages = + Arrays.asList(fbsRequest("fbs-a", 5, "alpha"), fbsRequest("fbs-b", 6, "beta")); + assertFbsMessages(blocking, async, messages); + + List unions = + Arrays.asList( + grpc_fbs.GrpcFbsUnion.ofGrpcFbsRequest(fbsRequest("fbs-u-a", 7, "union-alpha")), + grpc_fbs.GrpcFbsUnion.ofGrpcFbsRequest(fbsRequest("fbs-u-b", 8, "union-beta"))); + assertFbsUnions(blocking, async, unions); + } + + private void assertFbsMessages( + grpc_fbs.FbsGrpcServiceGrpc.FbsGrpcServiceBlockingStub blocking, + grpc_fbs.FbsGrpcServiceGrpc.FbsGrpcServiceStub async, + List requests) + throws InterruptedException { + grpc_fbs.GrpcFbsRequest first = requests.get(0); + Assert.assertEquals(blocking.unaryMessage(first), fbsResponse(first, "unary", 10)); + Assert.assertEquals( + toList(blocking.serverStreamMessage(first)), + Arrays.asList( + fbsResponse(first, "server-0", 0), + fbsResponse(first, "server-1", 1), + fbsResponse(first, "server-2", 2))); + + CollectingObserver clientObserver = new CollectingObserver<>(); + sendAll(async.clientStreamMessage(clientObserver), requests); + Assert.assertEquals(clientObserver.await(), Collections.singletonList(fbsAggregate(requests))); + + CollectingObserver bidiObserver = new CollectingObserver<>(); + sendAll(async.bidiStreamMessage(bidiObserver), requests); + Assert.assertEquals( + bidiObserver.await(), + Arrays.asList( + fbsResponse(requests.get(0), "bidi-0", 0), fbsResponse(requests.get(1), "bidi-1", 1))); + } + + private void assertFbsUnions( + grpc_fbs.FbsGrpcServiceGrpc.FbsGrpcServiceBlockingStub blocking, + grpc_fbs.FbsGrpcServiceGrpc.FbsGrpcServiceStub async, + List requests) + throws InterruptedException { + grpc_fbs.GrpcFbsRequest first = fbsRequestFromUnion(requests.get(0)); + Assert.assertEquals(blocking.unaryUnion(requests.get(0)), fbsUnionResponse(first, "unary", 10)); + Assert.assertEquals( + toList(blocking.serverStreamUnion(requests.get(0))), + Arrays.asList( + fbsUnionResponse(first, "server-0", 0), + fbsUnionResponse(first, "server-1", 1), + fbsUnionResponse(first, "server-2", 2))); + + CollectingObserver clientObserver = new CollectingObserver<>(); + sendAll(async.clientStreamUnion(clientObserver), requests); + Assert.assertEquals( + clientObserver.await(), Collections.singletonList(fbsUnionAggregate(requests))); + + CollectingObserver bidiObserver = new CollectingObserver<>(); + sendAll(async.bidiStreamUnion(bidiObserver), requests); + Assert.assertEquals( + bidiObserver.await(), + Arrays.asList( + fbsUnionResponse(fbsRequestFromUnion(requests.get(0)), "bidi-0", 0), + fbsUnionResponse(fbsRequestFromUnion(requests.get(1)), "bidi-1", 1))); + } + + private void exercisePb(ManagedChannel channel) throws InterruptedException { + grpc_pb.PbGrpcServiceGrpc.PbGrpcServiceBlockingStub blocking = + grpc_pb.PbGrpcServiceGrpc.newBlockingStub(channel); + grpc_pb.PbGrpcServiceGrpc.PbGrpcServiceStub async = grpc_pb.PbGrpcServiceGrpc.newStub(channel); + + List requests = + Arrays.asList( + pbRequest("pb-a", 9, grpc_pb.GrpcPbRequest.Payload.ofText("alpha")), + pbRequest("pb-b", 10, grpc_pb.GrpcPbRequest.Payload.ofNumber(42))); + grpc_pb.GrpcPbRequest first = requests.get(0); + Assert.assertEquals(blocking.unaryMessage(first), pbResponse(first, "unary", 10)); + Assert.assertEquals( + toList(blocking.serverStreamMessage(first)), + Arrays.asList( + pbResponse(first, "server-0", 0), + pbResponse(first, "server-1", 1), + pbResponse(first, "server-2", 2))); + + CollectingObserver clientObserver = new CollectingObserver<>(); + sendAll(async.clientStreamMessage(clientObserver), requests); + Assert.assertEquals(clientObserver.await(), Collections.singletonList(pbAggregate(requests))); + + CollectingObserver bidiObserver = new CollectingObserver<>(); + sendAll(async.bidiStreamMessage(bidiObserver), requests); + Assert.assertEquals( + bidiObserver.await(), + Arrays.asList( + pbResponse(requests.get(0), "bidi-0", 0), pbResponse(requests.get(1), "bidi-1", 1))); + } + + private PeerCommand pythonCommand(String... args) { + Path repoRoot = repoRoot(); + Path grpcRoot = repoRoot.resolve("integration_tests").resolve("grpc_tests"); + Path pythonRoot = grpcRoot.resolve("python"); + String pythonPath = + pythonRoot.resolve("grpc_tests").resolve("generated") + + File.pathSeparator + + pythonRoot + + File.pathSeparator + + repoRoot.resolve("python"); + String existingPythonPath = System.getenv("PYTHONPATH"); + if (existingPythonPath != null && !existingPythonPath.isEmpty()) { + pythonPath = pythonPath + File.pathSeparator + existingPythonPath; + } + List command = new ArrayList<>(); + command.add("python"); + command.add("-m"); + command.add("grpc_tests.grpc_interop"); + command.addAll(Arrays.asList(args)); + PeerCommand peerCommand = new PeerCommand(); + peerCommand.command = command; + peerCommand.workDir = grpcRoot; + peerCommand.environment.put("PYTHONPATH", pythonPath); + peerCommand.environment.put("ENABLE_FORY_CYTHON_SERIALIZATION", "0"); + peerCommand.environment.put("ENABLE_FORY_DEBUG_OUTPUT", "1"); + peerCommand.environment.put("NO_PROXY", "127.0.0.1,localhost"); + peerCommand.environment.put("no_proxy", "127.0.0.1,localhost"); + // Some developer and CI environments set proxy variables that grpcio honors + // even for localhost unless no_proxy is configured correctly. + for (String proxyVar : + Arrays.asList( + "all_proxy", "http_proxy", "https_proxy", "ALL_PROXY", "HTTP_PROXY", "HTTPS_PROXY")) { + peerCommand.environment.put(proxyVar, ""); + } + return peerCommand; + } + + private void runPython(String peer, String... args) throws IOException, InterruptedException { + Process process = startPeer(pythonCommand(args)); + PeerOutputCollector outputCollector = new PeerOutputCollector(process.getInputStream(), peer); + outputCollector.start(); + boolean finished = process.waitFor(180, TimeUnit.SECONDS); + if (!finished) { + process.destroyForcibly(); + process.waitFor(10, TimeUnit.SECONDS); + Assert.fail("Peer process timed out for " + peer + peerOutput(outputCollector)); + } + int exitCode = process.exitValue(); + if (exitCode != 0) { + Assert.fail( + "Peer process failed for " + + peer + + " with exit code " + + exitCode + + peerOutput(outputCollector)); + } + outputCollector.awaitOutput(); + } + + private Process startPeer(PeerCommand command) throws IOException { + ProcessBuilder builder = new ProcessBuilder(command.command); + builder.redirectErrorStream(true); + builder.directory(command.workDir.toFile()); + builder.environment().putAll(command.environment); + return builder.start(); + } + + private int waitForPort(Process process, PeerOutputCollector outputCollector, Path portFile) + throws IOException, InterruptedException { + long deadline = System.nanoTime() + TimeUnit.SECONDS.toNanos(60); + while (System.nanoTime() < deadline) { + if (!process.isAlive()) { + Assert.fail("Python gRPC server exited early" + peerOutput(outputCollector)); + } + if (Files.exists(portFile)) { + String value = new String(Files.readAllBytes(portFile), StandardCharsets.UTF_8).trim(); + if (!value.isEmpty()) { + return Integer.parseInt(value); + } + } + Thread.sleep(100); + } + process.destroyForcibly(); + process.waitFor(10, TimeUnit.SECONDS); + Assert.fail("Timed out waiting for Python gRPC server port" + peerOutput(outputCollector)); + return -1; + } + + private String peerOutput(PeerOutputCollector outputCollector) + throws IOException, InterruptedException { + String output = outputCollector.awaitOutput(); + return output.isEmpty() ? "" : "\noutput:\n" + output; + } + + private Path repoRoot() { + Path moduleDir = java.nio.file.Paths.get("").toAbsolutePath(); + return moduleDir.getParent().getParent().getParent(); + } + + private static grpc_fdl.GrpcFdlRequest fdlRequest(String id, int count, String payload) { + grpc_fdl.GrpcFdlRequest request = new grpc_fdl.GrpcFdlRequest(); + request.setId(id); + request.setCount(count); + request.setPayload(payload); + return request; + } + + private static grpc_fdl.GrpcFdlResponse fdlResponse( + grpc_fdl.GrpcFdlRequest request, String tag, int offset) { + grpc_fdl.GrpcFdlResponse response = new grpc_fdl.GrpcFdlResponse(); + response.setId(tag + ":" + request.getId()); + response.setCount(request.getCount() + offset); + response.setPayload(tag + ":" + request.getPayload()); + return response; + } + + private static grpc_fdl.GrpcFdlResponse fdlAggregate(List requests) { + grpc_fdl.GrpcFdlResponse response = new grpc_fdl.GrpcFdlResponse(); + response.setId("client:" + joinFdlIds(requests)); + response.setCount(requests.stream().mapToInt(grpc_fdl.GrpcFdlRequest::getCount).sum()); + response.setPayload("client:" + joinFdlPayloads(requests)); + return response; + } + + private static grpc_fdl.GrpcFdlUnion fdlUnionResponse( + grpc_fdl.GrpcFdlRequest request, String tag, int offset) { + return grpc_fdl.GrpcFdlUnion.ofResponse(fdlResponse(request, tag, offset)); + } + + private static grpc_fdl.GrpcFdlUnion fdlUnionAggregate(List unions) { + return grpc_fdl.GrpcFdlUnion.ofResponse( + fdlAggregate(map(unions, GrpcInteropTest::fdlRequestFromUnion))); + } + + private static grpc_fdl.GrpcFdlRequest fdlRequestFromUnion(grpc_fdl.GrpcFdlUnion union) { + Assert.assertTrue(union.hasRequest()); + return union.getRequest(); + } + + private static grpc_fbs.GrpcFbsRequest fbsRequest(String id, int count, String payload) { + grpc_fbs.GrpcFbsRequest request = new grpc_fbs.GrpcFbsRequest(); + request.setId(id); + request.setCount(count); + request.setPayload(payload); + return request; + } + + private static grpc_fbs.GrpcFbsResponse fbsResponse( + grpc_fbs.GrpcFbsRequest request, String tag, int offset) { + grpc_fbs.GrpcFbsResponse response = new grpc_fbs.GrpcFbsResponse(); + response.setId(tag + ":" + request.getId()); + response.setCount(request.getCount() + offset); + response.setPayload(tag + ":" + request.getPayload()); + return response; + } + + private static grpc_fbs.GrpcFbsResponse fbsAggregate(List requests) { + grpc_fbs.GrpcFbsResponse response = new grpc_fbs.GrpcFbsResponse(); + response.setId("client:" + joinFbsIds(requests)); + response.setCount(requests.stream().mapToInt(grpc_fbs.GrpcFbsRequest::getCount).sum()); + response.setPayload("client:" + joinFbsPayloads(requests)); + return response; + } + + private static grpc_fbs.GrpcFbsUnion fbsUnionResponse( + grpc_fbs.GrpcFbsRequest request, String tag, int offset) { + return grpc_fbs.GrpcFbsUnion.ofGrpcFbsResponse(fbsResponse(request, tag, offset)); + } + + private static grpc_fbs.GrpcFbsUnion fbsUnionAggregate(List unions) { + return grpc_fbs.GrpcFbsUnion.ofGrpcFbsResponse( + fbsAggregate(map(unions, GrpcInteropTest::fbsRequestFromUnion))); + } + + private static grpc_fbs.GrpcFbsRequest fbsRequestFromUnion(grpc_fbs.GrpcFbsUnion union) { + Assert.assertTrue(union.hasGrpcFbsRequest()); + return union.getGrpcFbsRequest(); + } + + private static grpc_pb.GrpcPbRequest pbRequest( + String id, int count, grpc_pb.GrpcPbRequest.Payload payload) { + grpc_pb.GrpcPbRequest request = new grpc_pb.GrpcPbRequest(); + request.setId(id); + request.setCount(count); + request.setPayload(payload); + return request; + } + + private static grpc_pb.GrpcPbResponse pbResponse( + grpc_pb.GrpcPbRequest request, String tag, int offset) { + grpc_pb.GrpcPbResponse response = new grpc_pb.GrpcPbResponse(); + response.setId(tag + ":" + request.getId()); + response.setCount(request.getCount() + offset); + response.setPayload(pbResponsePayload(request.getPayload(), tag, offset)); + return response; + } + + private static grpc_pb.GrpcPbResponse pbAggregate(List requests) { + grpc_pb.GrpcPbResponse response = new grpc_pb.GrpcPbResponse(); + response.setId("client:" + joinPbIds(requests)); + response.setCount(requests.stream().mapToLong(grpc_pb.GrpcPbRequest::getCount).sum()); + response.setPayload(grpc_pb.GrpcPbResponse.Payload.ofText("client:" + joinPbIds(requests))); + return response; + } + + private static grpc_pb.GrpcPbResponse.Payload pbResponsePayload( + grpc_pb.GrpcPbRequest.Payload payload, String tag, int offset) { + if (payload == null) { + return null; + } + if (payload.hasText()) { + return grpc_pb.GrpcPbResponse.Payload.ofText(tag + ":" + payload.getText()); + } + Assert.assertTrue(payload.hasNumber()); + return grpc_pb.GrpcPbResponse.Payload.ofNumber(payload.getNumber() + offset); + } + + private static String joinFdlIds(List requests) { + List ids = new ArrayList<>(); + for (grpc_fdl.GrpcFdlRequest request : requests) { + ids.add(request.getId()); + } + return String.join("+", ids); + } + + private static String joinFdlPayloads(List requests) { + List payloads = new ArrayList<>(); + for (grpc_fdl.GrpcFdlRequest request : requests) { + payloads.add(request.getPayload()); + } + return String.join("+", payloads); + } + + private static String joinFbsIds(List requests) { + List ids = new ArrayList<>(); + for (grpc_fbs.GrpcFbsRequest request : requests) { + ids.add(request.getId()); + } + return String.join("+", ids); + } + + private static String joinFbsPayloads(List requests) { + List payloads = new ArrayList<>(); + for (grpc_fbs.GrpcFbsRequest request : requests) { + payloads.add(request.getPayload()); + } + return String.join("+", payloads); + } + + private static String joinPbIds(List requests) { + List ids = new ArrayList<>(); + for (grpc_pb.GrpcPbRequest request : requests) { + ids.add(request.getId()); + } + return String.join("+", ids); + } + + private static void sendAll(StreamObserver observer, List values) { + for (T value : values) { + observer.onNext(value); + } + observer.onCompleted(); + } + + private static List toList(Iterator iterator) { + List result = new ArrayList<>(); + while (iterator.hasNext()) { + result.add(iterator.next()); + } + return result; + } + + private static List map(List values, Function mapper) { + List result = new ArrayList<>(); + for (T value : values) { + result.add(mapper.apply(value)); + } + return result; + } + + private static void sendOne(StreamObserver observer, T value) { + observer.onNext(value); + observer.onCompleted(); + } + + private static void sendMany(StreamObserver observer, List values) { + for (T value : values) { + observer.onNext(value); + } + observer.onCompleted(); + } + + private static StreamObserver collectAndRespond( + StreamObserver responseObserver, Function, Resp> responseFactory) { + return new StreamObserver() { + private final List requests = new ArrayList<>(); + + @Override + public void onNext(Req value) { + requests.add(value); + } + + @Override + public void onError(Throwable t) { + responseObserver.onError(t); + } + + @Override + public void onCompleted() { + sendOne(responseObserver, responseFactory.apply(requests)); + } + }; + } + + private static final class FdlService extends grpc_fdl.FdlGrpcServiceGrpc.FdlGrpcServiceImplBase { + @Override + public void unaryMessage( + grpc_fdl.GrpcFdlRequest request, + StreamObserver responseObserver) { + sendOne(responseObserver, fdlResponse(request, "unary", 10)); + } + + @Override + public void serverStreamMessage( + grpc_fdl.GrpcFdlRequest request, + StreamObserver responseObserver) { + sendMany( + responseObserver, + Arrays.asList( + fdlResponse(request, "server-0", 0), + fdlResponse(request, "server-1", 1), + fdlResponse(request, "server-2", 2))); + } + + @Override + public StreamObserver clientStreamMessage( + StreamObserver responseObserver) { + return collectAndRespond(responseObserver, GrpcInteropTest::fdlAggregate); + } + + @Override + public StreamObserver bidiStreamMessage( + StreamObserver responseObserver) { + return new StreamObserver() { + private int index; + + @Override + public void onNext(grpc_fdl.GrpcFdlRequest value) { + responseObserver.onNext(fdlResponse(value, "bidi-" + index, index)); + index++; + } + + @Override + public void onError(Throwable t) { + responseObserver.onError(t); + } + + @Override + public void onCompleted() { + responseObserver.onCompleted(); + } + }; + } + + @Override + public void unaryUnion( + grpc_fdl.GrpcFdlUnion request, StreamObserver responseObserver) { + sendOne(responseObserver, fdlUnionResponse(fdlRequestFromUnion(request), "unary", 10)); + } + + @Override + public void serverStreamUnion( + grpc_fdl.GrpcFdlUnion request, StreamObserver responseObserver) { + grpc_fdl.GrpcFdlRequest value = fdlRequestFromUnion(request); + sendMany( + responseObserver, + Arrays.asList( + fdlUnionResponse(value, "server-0", 0), + fdlUnionResponse(value, "server-1", 1), + fdlUnionResponse(value, "server-2", 2))); + } + + @Override + public StreamObserver clientStreamUnion( + StreamObserver responseObserver) { + return collectAndRespond(responseObserver, GrpcInteropTest::fdlUnionAggregate); + } + + @Override + public StreamObserver bidiStreamUnion( + StreamObserver responseObserver) { + return new StreamObserver() { + private int index; + + @Override + public void onNext(grpc_fdl.GrpcFdlUnion value) { + responseObserver.onNext( + fdlUnionResponse(fdlRequestFromUnion(value), "bidi-" + index, index)); + index++; + } + + @Override + public void onError(Throwable t) { + responseObserver.onError(t); + } + + @Override + public void onCompleted() { + responseObserver.onCompleted(); + } + }; + } + } + + private static final class FbsService extends grpc_fbs.FbsGrpcServiceGrpc.FbsGrpcServiceImplBase { + @Override + public void unaryMessage( + grpc_fbs.GrpcFbsRequest request, + StreamObserver responseObserver) { + sendOne(responseObserver, fbsResponse(request, "unary", 10)); + } + + @Override + public void serverStreamMessage( + grpc_fbs.GrpcFbsRequest request, + StreamObserver responseObserver) { + sendMany( + responseObserver, + Arrays.asList( + fbsResponse(request, "server-0", 0), + fbsResponse(request, "server-1", 1), + fbsResponse(request, "server-2", 2))); + } + + @Override + public StreamObserver clientStreamMessage( + StreamObserver responseObserver) { + return collectAndRespond(responseObserver, GrpcInteropTest::fbsAggregate); + } + + @Override + public StreamObserver bidiStreamMessage( + StreamObserver responseObserver) { + return new StreamObserver() { + private int index; + + @Override + public void onNext(grpc_fbs.GrpcFbsRequest value) { + responseObserver.onNext(fbsResponse(value, "bidi-" + index, index)); + index++; + } + + @Override + public void onError(Throwable t) { + responseObserver.onError(t); + } + + @Override + public void onCompleted() { + responseObserver.onCompleted(); + } + }; + } + + @Override + public void unaryUnion( + grpc_fbs.GrpcFbsUnion request, StreamObserver responseObserver) { + sendOne(responseObserver, fbsUnionResponse(fbsRequestFromUnion(request), "unary", 10)); + } + + @Override + public void serverStreamUnion( + grpc_fbs.GrpcFbsUnion request, StreamObserver responseObserver) { + grpc_fbs.GrpcFbsRequest value = fbsRequestFromUnion(request); + sendMany( + responseObserver, + Arrays.asList( + fbsUnionResponse(value, "server-0", 0), + fbsUnionResponse(value, "server-1", 1), + fbsUnionResponse(value, "server-2", 2))); + } + + @Override + public StreamObserver clientStreamUnion( + StreamObserver responseObserver) { + return collectAndRespond(responseObserver, GrpcInteropTest::fbsUnionAggregate); + } + + @Override + public StreamObserver bidiStreamUnion( + StreamObserver responseObserver) { + return new StreamObserver() { + private int index; + + @Override + public void onNext(grpc_fbs.GrpcFbsUnion value) { + responseObserver.onNext( + fbsUnionResponse(fbsRequestFromUnion(value), "bidi-" + index, index)); + index++; + } + + @Override + public void onError(Throwable t) { + responseObserver.onError(t); + } + + @Override + public void onCompleted() { + responseObserver.onCompleted(); + } + }; + } + } + + private static final class PbService extends grpc_pb.PbGrpcServiceGrpc.PbGrpcServiceImplBase { + @Override + public void unaryMessage( + grpc_pb.GrpcPbRequest request, StreamObserver responseObserver) { + sendOne(responseObserver, pbResponse(request, "unary", 10)); + } + + @Override + public void serverStreamMessage( + grpc_pb.GrpcPbRequest request, StreamObserver responseObserver) { + sendMany( + responseObserver, + Arrays.asList( + pbResponse(request, "server-0", 0), + pbResponse(request, "server-1", 1), + pbResponse(request, "server-2", 2))); + } + + @Override + public StreamObserver clientStreamMessage( + StreamObserver responseObserver) { + return collectAndRespond(responseObserver, GrpcInteropTest::pbAggregate); + } + + @Override + public StreamObserver bidiStreamMessage( + StreamObserver responseObserver) { + return new StreamObserver() { + private int index; + + @Override + public void onNext(grpc_pb.GrpcPbRequest value) { + responseObserver.onNext(pbResponse(value, "bidi-" + index, index)); + index++; + } + + @Override + public void onError(Throwable t) { + responseObserver.onError(t); + } + + @Override + public void onCompleted() { + responseObserver.onCompleted(); + } + }; + } + } + + private static final class CollectingObserver implements StreamObserver { + private final List values = new ArrayList<>(); + private final CountDownLatch done = new CountDownLatch(1); + private Throwable failure; + + @Override + public void onNext(T value) { + values.add(value); + } + + @Override + public void onError(Throwable t) { + failure = t; + done.countDown(); + } + + @Override + public void onCompleted() { + done.countDown(); + } + + private List await() throws InterruptedException { + if (!done.await(30, TimeUnit.SECONDS)) { + Assert.fail("Timed out waiting for gRPC responses"); + } + if (failure != null) { + Assert.fail("gRPC call failed", failure); + } + return values; + } + } + + private static final class PeerCommand { + private List command; + private Path workDir; + private final java.util.Map environment = new java.util.HashMap<>(); + } + + private static final class PeerOutputCollector extends Thread { + private final InputStream inputStream; + private final ByteArrayOutputStream outputStream = new ByteArrayOutputStream(); + private IOException readFailure; + + private PeerOutputCollector(InputStream inputStream, String peer) { + super("idl-grpc-peer-output-" + peer); + setDaemon(true); + this.inputStream = inputStream; + } + + @Override + public void run() { + byte[] buffer = new byte[4096]; + int bytesRead; + try { + while ((bytesRead = inputStream.read(buffer)) != -1) { + outputStream.write(buffer, 0, bytesRead); + } + } catch (IOException e) { + readFailure = e; + } finally { + try { + inputStream.close(); + } catch (IOException ignored) { + } + } + } + + private String awaitOutput() throws IOException, InterruptedException { + join(); + if (readFailure != null) { + throw readFailure; + } + return new String(outputStream.toByteArray(), StandardCharsets.UTF_8); + } + } +} diff --git a/integration_tests/grpc_tests/python/grpc_tests/__init__.py b/integration_tests/grpc_tests/python/grpc_tests/__init__.py new file mode 100644 index 0000000000..13a83393a9 --- /dev/null +++ b/integration_tests/grpc_tests/python/grpc_tests/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/integration_tests/grpc_tests/python/grpc_tests/grpc_interop.py b/integration_tests/grpc_tests/python/grpc_tests/grpc_interop.py new file mode 100644 index 0000000000..e8a3b46947 --- /dev/null +++ b/integration_tests/grpc_tests/python/grpc_tests/grpc_interop.py @@ -0,0 +1,400 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from __future__ import annotations + +import argparse +from concurrent import futures +from pathlib import Path +from typing import Iterable, Optional, Sequence + +import grpc +import grpc_fbs +import grpc_fbs_grpc +import grpc_fdl +import grpc_fdl_grpc +import grpc_pb +import grpc_pb_grpc + + +def _fdl_request(id_value: str, count: int, payload: str) -> grpc_fdl.GrpcFdlRequest: + return grpc_fdl.GrpcFdlRequest(id=id_value, count=count, payload=payload) + + +def _fdl_response( + request: grpc_fdl.GrpcFdlRequest, tag: str, offset: int +) -> grpc_fdl.GrpcFdlResponse: + return grpc_fdl.GrpcFdlResponse( + id=f"{tag}:{request.id}", + count=request.count + offset, + payload=f"{tag}:{request.payload}", + ) + + +def _fdl_aggregate( + requests: Sequence[grpc_fdl.GrpcFdlRequest], +) -> grpc_fdl.GrpcFdlResponse: + return grpc_fdl.GrpcFdlResponse( + id="client:" + "+".join(request.id for request in requests), + count=sum(request.count for request in requests), + payload="client:" + "+".join(request.payload for request in requests), + ) + + +def _fdl_union_request( + request: grpc_fdl.GrpcFdlRequest, +) -> grpc_fdl.GrpcFdlUnion: + return grpc_fdl.GrpcFdlUnion.request(request) + + +def _fdl_union_response( + request: grpc_fdl.GrpcFdlRequest, tag: str, offset: int +) -> grpc_fdl.GrpcFdlUnion: + return grpc_fdl.GrpcFdlUnion.response(_fdl_response(request, tag, offset)) + + +def _fdl_union_aggregate( + requests: Sequence[grpc_fdl.GrpcFdlRequest], +) -> grpc_fdl.GrpcFdlUnion: + return grpc_fdl.GrpcFdlUnion.response(_fdl_aggregate(requests)) + + +def _fdl_request_from_union( + union: grpc_fdl.GrpcFdlUnion, +) -> grpc_fdl.GrpcFdlRequest: + assert union.is_request() + return union.request_value() + + +def _fbs_request(id_value: str, count: int, payload: str) -> grpc_fbs.GrpcFbsRequest: + return grpc_fbs.GrpcFbsRequest(id=id_value, count=count, payload=payload) + + +def _fbs_response( + request: grpc_fbs.GrpcFbsRequest, tag: str, offset: int +) -> grpc_fbs.GrpcFbsResponse: + return grpc_fbs.GrpcFbsResponse( + id=f"{tag}:{request.id}", + count=request.count + offset, + payload=f"{tag}:{request.payload}", + ) + + +def _fbs_aggregate( + requests: Sequence[grpc_fbs.GrpcFbsRequest], +) -> grpc_fbs.GrpcFbsResponse: + return grpc_fbs.GrpcFbsResponse( + id="client:" + "+".join(request.id for request in requests), + count=sum(request.count for request in requests), + payload="client:" + "+".join(request.payload for request in requests), + ) + + +def _fbs_union_request( + request: grpc_fbs.GrpcFbsRequest, +) -> grpc_fbs.GrpcFbsUnion: + return grpc_fbs.GrpcFbsUnion.grpc_fbs_request(request) + + +def _fbs_union_response( + request: grpc_fbs.GrpcFbsRequest, tag: str, offset: int +) -> grpc_fbs.GrpcFbsUnion: + return grpc_fbs.GrpcFbsUnion.grpc_fbs_response(_fbs_response(request, tag, offset)) + + +def _fbs_union_aggregate( + requests: Sequence[grpc_fbs.GrpcFbsRequest], +) -> grpc_fbs.GrpcFbsUnion: + return grpc_fbs.GrpcFbsUnion.grpc_fbs_response(_fbs_aggregate(requests)) + + +def _fbs_request_from_union( + union: grpc_fbs.GrpcFbsUnion, +) -> grpc_fbs.GrpcFbsRequest: + assert union.is_grpc_fbs_request() + return union.grpc_fbs_request_value() + + +def _pb_payload_text(value: str) -> grpc_pb.GrpcPbRequest.Payload: + return grpc_pb.GrpcPbRequest.Payload.text(value) + + +def _pb_response_payload( + payload: Optional[grpc_pb.GrpcPbRequest.Payload], + tag: str, + offset: int, +) -> Optional[grpc_pb.GrpcPbResponse.Payload]: + if payload is None: + return None + if payload.is_text(): + return grpc_pb.GrpcPbResponse.Payload.text(f"{tag}:{payload.text_value()}") + assert payload.is_number() + return grpc_pb.GrpcPbResponse.Payload.number(payload.number_value() + offset) + + +def _pb_request( + id_value: str, count: int, payload: grpc_pb.GrpcPbRequest.Payload +) -> grpc_pb.GrpcPbRequest: + return grpc_pb.GrpcPbRequest(id=id_value, count=count, payload=payload) + + +def _pb_response( + request: grpc_pb.GrpcPbRequest, tag: str, offset: int +) -> grpc_pb.GrpcPbResponse: + return grpc_pb.GrpcPbResponse( + id=f"{tag}:{request.id}", + count=request.count + offset, + payload=_pb_response_payload(request.payload, tag, offset), + ) + + +def _pb_aggregate( + requests: Sequence[grpc_pb.GrpcPbRequest], +) -> grpc_pb.GrpcPbResponse: + return grpc_pb.GrpcPbResponse( + id="client:" + "+".join(request.id for request in requests), + count=sum(request.count for request in requests), + payload=grpc_pb.GrpcPbResponse.Payload.text( + "client:" + "+".join(request.id for request in requests) + ), + ) + + +class FdlService(grpc_fdl_grpc.FdlGrpcServiceServicer): + def unary_message(self, request, context): + return _fdl_response(request, "unary", 10) + + def server_stream_message(self, request, context): + for index in range(3): + yield _fdl_response(request, f"server-{index}", index) + + def client_stream_message(self, request_iterator, context): + return _fdl_aggregate(list(request_iterator)) + + def bidi_stream_message(self, request_iterator, context): + for index, request in enumerate(request_iterator): + yield _fdl_response(request, f"bidi-{index}", index) + + def unary_union(self, request, context): + return _fdl_union_response(_fdl_request_from_union(request), "unary", 10) + + def server_stream_union(self, request, context): + item = _fdl_request_from_union(request) + for index in range(3): + yield _fdl_union_response(item, f"server-{index}", index) + + def client_stream_union(self, request_iterator, context): + requests = [_fdl_request_from_union(item) for item in request_iterator] + return _fdl_union_aggregate(requests) + + def bidi_stream_union(self, request_iterator, context): + for index, item in enumerate(request_iterator): + yield _fdl_union_response( + _fdl_request_from_union(item), f"bidi-{index}", index + ) + + +class FbsService(grpc_fbs_grpc.FbsGrpcServiceServicer): + def unary_message(self, request, context): + return _fbs_response(request, "unary", 10) + + def server_stream_message(self, request, context): + for index in range(3): + yield _fbs_response(request, f"server-{index}", index) + + def client_stream_message(self, request_iterator, context): + return _fbs_aggregate(list(request_iterator)) + + def bidi_stream_message(self, request_iterator, context): + for index, request in enumerate(request_iterator): + yield _fbs_response(request, f"bidi-{index}", index) + + def unary_union(self, request, context): + return _fbs_union_response(_fbs_request_from_union(request), "unary", 10) + + def server_stream_union(self, request, context): + item = _fbs_request_from_union(request) + for index in range(3): + yield _fbs_union_response(item, f"server-{index}", index) + + def client_stream_union(self, request_iterator, context): + requests = [_fbs_request_from_union(item) for item in request_iterator] + return _fbs_union_aggregate(requests) + + def bidi_stream_union(self, request_iterator, context): + for index, item in enumerate(request_iterator): + yield _fbs_union_response( + _fbs_request_from_union(item), f"bidi-{index}", index + ) + + +class PbService(grpc_pb_grpc.PbGrpcServiceServicer): + def unary_message(self, request, context): + return _pb_response(request, "unary", 10) + + def server_stream_message(self, request, context): + for index in range(3): + yield _pb_response(request, f"server-{index}", index) + + def client_stream_message(self, request_iterator, context): + return _pb_aggregate(list(request_iterator)) + + def bidi_stream_message(self, request_iterator, context): + for index, request in enumerate(request_iterator): + yield _pb_response(request, f"bidi-{index}", index) + + +def _assert_iterable_equal( + actual: Iterable[object], expected: Sequence[object] +) -> None: + assert list(actual) == list(expected) + + +def _exercise_message_stub( + stub, + requests: Sequence[object], + response_fn, + aggregate_fn, +) -> None: + first = requests[0] + assert stub.unary_message(first) == response_fn(first, "unary", 10) + _assert_iterable_equal( + stub.server_stream_message(first), + [response_fn(first, f"server-{index}", index) for index in range(3)], + ) + assert stub.client_stream_message(iter(requests)) == aggregate_fn(requests) + _assert_iterable_equal( + stub.bidi_stream_message(iter(requests)), + [ + response_fn(request, f"bidi-{index}", index) + for index, request in enumerate(requests) + ], + ) + + +def _exercise_union_stub( + stub, + requests: Sequence[object], + union_request_fn, + union_response_fn, + union_aggregate_fn, +) -> None: + union_requests = [union_request_fn(request) for request in requests] + first = union_requests[0] + first_request = requests[0] + assert stub.unary_union(first) == union_response_fn(first_request, "unary", 10) + _assert_iterable_equal( + stub.server_stream_union(first), + [ + union_response_fn(first_request, f"server-{index}", index) + for index in range(3) + ], + ) + assert stub.client_stream_union(iter(union_requests)) == union_aggregate_fn( + requests + ) + _assert_iterable_equal( + stub.bidi_stream_union(iter(union_requests)), + [ + union_response_fn(request, f"bidi-{index}", index) + for index, request in enumerate(requests) + ], + ) + + +def run_client(target: str) -> None: + with grpc.insecure_channel(target) as channel: + _exercise_message_stub( + grpc_fdl_grpc.FdlGrpcServiceStub(channel), + [ + _fdl_request("fdl-a", 1, "alpha"), + _fdl_request("fdl-b", 2, "beta"), + ], + _fdl_response, + _fdl_aggregate, + ) + _exercise_union_stub( + grpc_fdl_grpc.FdlGrpcServiceStub(channel), + [ + _fdl_request("fdl-u-a", 3, "union-alpha"), + _fdl_request("fdl-u-b", 4, "union-beta"), + ], + _fdl_union_request, + _fdl_union_response, + _fdl_union_aggregate, + ) + + _exercise_message_stub( + grpc_fbs_grpc.FbsGrpcServiceStub(channel), + [ + _fbs_request("fbs-a", 5, "alpha"), + _fbs_request("fbs-b", 6, "beta"), + ], + _fbs_response, + _fbs_aggregate, + ) + _exercise_union_stub( + grpc_fbs_grpc.FbsGrpcServiceStub(channel), + [ + _fbs_request("fbs-u-a", 7, "union-alpha"), + _fbs_request("fbs-u-b", 8, "union-beta"), + ], + _fbs_union_request, + _fbs_union_response, + _fbs_union_aggregate, + ) + + _exercise_message_stub( + grpc_pb_grpc.PbGrpcServiceStub(channel), + [ + _pb_request("pb-a", 9, _pb_payload_text("alpha")), + _pb_request("pb-b", 10, grpc_pb.GrpcPbRequest.Payload.number(42)), + ], + _pb_response, + _pb_aggregate, + ) + + +def run_server(port_file: Path) -> None: + server = grpc.server(futures.ThreadPoolExecutor(max_workers=8)) + grpc_fdl_grpc.add_servicer(FdlService(), server) + grpc_fbs_grpc.add_servicer(FbsService(), server) + grpc_pb_grpc.add_servicer(PbService(), server) + port = server.add_insecure_port("127.0.0.1:0") + server.start() + port_file.write_text(str(port)) + server.wait_for_termination() + + +def main() -> int: + parser = argparse.ArgumentParser(description="Java/Python Fory gRPC interop peer") + subparsers = parser.add_subparsers(dest="command", required=True) + client_parser = subparsers.add_parser("client") + client_parser.add_argument("--target", required=True) + server_parser = subparsers.add_parser("server") + server_parser.add_argument("--port-file", type=Path, required=True) + args = parser.parse_args() + + if args.command == "client": + run_client(args.target) + else: + run_server(args.port_file) + return 0 + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/integration_tests/grpc_tests/python/pyproject.toml b/integration_tests/grpc_tests/python/pyproject.toml new file mode 100644 index 0000000000..8803016e72 --- /dev/null +++ b/integration_tests/grpc_tests/python/pyproject.toml @@ -0,0 +1,32 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +[build-system] +requires = ["setuptools>=61.0"] +build-backend = "setuptools.build_meta" + +[project] +name = "fory-grpc-tests" +version = "0.18.0.dev0" +description = "gRPC compiler integration tests for Apache Fory" +requires-python = ">=3.8" +license = {text = "Apache-2.0"} +dependencies = ["grpcio>=1.62.2,<1.71", "pyfory"] + +[tool.setuptools.packages.find] +where = ["."] +include = ["grpc_tests"] diff --git a/integration_tests/grpc_tests/run_tests.sh b/integration_tests/grpc_tests/run_tests.sh new file mode 100755 index 0000000000..7cb2a4ad61 --- /dev/null +++ b/integration_tests/grpc_tests/run_tests.sh @@ -0,0 +1,33 @@ +#!/bin/bash + +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +set -euo pipefail + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +ROOT_DIR="$(cd "${SCRIPT_DIR}/../.." && pwd)" + +python -m pip install "grpcio>=1.62.2,<1.71" +python -m pip install -v -e "${ROOT_DIR}/python" + +python "${SCRIPT_DIR}/generate_grpc.py" + +cd "${ROOT_DIR}/integration_tests/grpc_tests/java" +ENABLE_FORY_DEBUG_OUTPUT=1 mvn -T16 --no-transfer-progress \ + -Dtest=GrpcInteropTest \ + test