diff --git a/.github/trigger_files/beam_PostCommit_Python_Xlang_Gcp_Dataflow.json b/.github/trigger_files/beam_PostCommit_Python_Xlang_Gcp_Dataflow.json index bb5da04014ec..83346d34aee0 100644 --- a/.github/trigger_files/beam_PostCommit_Python_Xlang_Gcp_Dataflow.json +++ b/.github/trigger_files/beam_PostCommit_Python_Xlang_Gcp_Dataflow.json @@ -1,4 +1,4 @@ { "comment": "Modify this file in a trivial way to cause this test suite to run", - "modification": 15 + "modification": 16 } diff --git a/.github/trigger_files/beam_PostCommit_Python_Xlang_Gcp_Direct.json b/.github/trigger_files/beam_PostCommit_Python_Xlang_Gcp_Direct.json index 83346d34aee0..c5309eebb070 100644 --- a/.github/trigger_files/beam_PostCommit_Python_Xlang_Gcp_Direct.json +++ b/.github/trigger_files/beam_PostCommit_Python_Xlang_Gcp_Direct.json @@ -1,4 +1,4 @@ { "comment": "Modify this file in a trivial way to cause this test suite to run", - "modification": 16 + "modification": 17 } diff --git a/sdks/python/apache_beam/coders/typecoders.py b/sdks/python/apache_beam/coders/typecoders.py index 9683e00f0c2a..76bab83e4c8e 100644 --- a/sdks/python/apache_beam/coders/typecoders.py +++ b/sdks/python/apache_beam/coders/typecoders.py @@ -124,6 +124,7 @@ def _normalize_typehint_type(typehint_type): def register_coder( self, typehint_type: Any, typehint_coder_class: Type[coders.Coder]) -> None: + "Register a user type with a coder" if not isinstance(typehint_coder_class, type): raise TypeError( 'Coder registration requires a coder class object. ' @@ -133,6 +134,21 @@ def register_coder( self._register_coder_internal( self._normalize_typehint_type(typehint_type), typehint_coder_class) + def register_row(self, typehint_type: Any) -> None: + """ + Register a user type with a Beam Row. + + This registers the type with a RowCoder and registers its schema.
+ """ + from apache_beam.coders import RowCoder + from apache_beam.typehints.schemas import typing_to_runner_api + + # Register with row coder + self.register_coder(typehint_type, RowCoder) + # This call generates a schema id for the type and registers it with + # the schema registry + typing_to_runner_api(typehint_type) + def get_coder(self, typehint: Any) -> coders.Coder: if typehint and typehint.__module__ == '__main__': # See https://github.com/apache/beam/issues/21541 diff --git a/sdks/python/apache_beam/internal/cloudpickle_pickler.py b/sdks/python/apache_beam/internal/cloudpickle_pickler.py index acdcc46cd40d..cea4f01f803c 100644 --- a/sdks/python/apache_beam/internal/cloudpickle_pickler.py +++ b/sdks/python/apache_beam/internal/cloudpickle_pickler.py @@ -256,20 +256,27 @@ def dump_session(file_path): # dump supported Beam Registries (currently only logical type registry) from apache_beam.coders import typecoders from apache_beam.typehints import schemas + from apache_beam.typehints.schema_registry import SCHEMA_REGISTRY with _pickle_lock, open(file_path, 'wb') as file: coder_reg = typecoders.registry.get_custom_type_coder_tuples() logical_type_reg = schemas.LogicalType._known_logical_types.copy_custom() + schema_reg = SCHEMA_REGISTRY.get_registered_typings() pickler = cloudpickle.CloudPickler(file) # TODO(https://github.com/apache/beam/issues/18500) add file system registry # once implemented - pickler.dump({"coder": coder_reg, "logical_type": logical_type_reg}) + pickler.dump({ + "coder": coder_reg, + "logical_type": logical_type_reg, + "schema": schema_reg + }) def load_session(file_path): from apache_beam.coders import typecoders from apache_beam.typehints import schemas + from apache_beam.typehints.schema_registry import SCHEMA_REGISTRY with _pickle_lock, open(file_path, 'rb') as file: registries = cloudpickle.load(file) @@ -284,3 +291,7 @@ def load_session(file_path): schemas.LogicalType._known_logical_types.load(registries["logical_type"]) else:
_LOGGER.warning('No logical type registry found in saved session') + if "schema" in registries: + SCHEMA_REGISTRY.load_registered_typings(registries["schema"]) + else: + _LOGGER.warning('No schema registry found in saved session') diff --git a/sdks/python/apache_beam/io/external/xlang_jdbcio_it_test.py b/sdks/python/apache_beam/io/external/xlang_jdbcio_it_test.py index 069f13e11bfb..0e967f1beec3 100644 --- a/sdks/python/apache_beam/io/external/xlang_jdbcio_it_test.py +++ b/sdks/python/apache_beam/io/external/xlang_jdbcio_it_test.py @@ -30,6 +30,7 @@ import apache_beam as beam from apache_beam import coders +from apache_beam.io import jdbc from apache_beam.io.jdbc import ReadFromJdbc from apache_beam.io.jdbc import WriteToJdbc from apache_beam.options.pipeline_options import StandardOptions @@ -64,7 +65,7 @@ ("f_timestamp", Timestamp), ("f_decimal", Decimal), ("f_date", datetime.date), ("f_time", datetime.time)], ) -coders.registry.register_coder(JdbcTestRow, coders.RowCoder) +coders.registry.register_row(JdbcTestRow) CustomSchemaRow = typing.NamedTuple( "CustomSchemaRow", @@ -82,11 +83,17 @@ ("renamed_time", datetime.time), ], ) -coders.registry.register_coder(CustomSchemaRow, coders.RowCoder) + +# Need to put inside enforce_millis_instant_for_timestamp context to align +# with the same setup in ReadFromJdbc.__init__. 
Remove once Beam moved to +# micros instant for timestamp +# Alternatively, use coders.registry.register_coder(CustomSchemaRow, RowCoder) +with jdbc.enforce_millis_instant_for_timestamp(): + coders.registry.register_row(CustomSchemaRow) SimpleRow = typing.NamedTuple( "SimpleRow", [("id", int), ("name", str), ("value", float)]) -coders.registry.register_coder(SimpleRow, coders.RowCoder) +coders.registry.register_row(SimpleRow) @pytest.mark.uses_gcp_java_expansion_service diff --git a/sdks/python/apache_beam/typehints/row_type_test.py b/sdks/python/apache_beam/typehints/row_type_test.py index 97012d9561d7..54e64caf6fa7 100644 --- a/sdks/python/apache_beam/typehints/row_type_test.py +++ b/sdks/python/apache_beam/typehints/row_type_test.py @@ -33,8 +33,10 @@ class RowTypeTest(unittest.TestCase): @staticmethod def _check_key_type_and_count(x) -> int: key_type = type(x[0]) - if not row_type._user_type_is_generated(key_type): - raise RuntimeError("Expect type after GBK to be generated user type") + if row_type._user_type_is_generated(key_type): + raise RuntimeError("Type after GBK not preserved, get generated type") + if not hasattr(key_type, row_type._BEAM_SCHEMA_ID): + raise RuntimeError("Type after GBK missing Beam schema ID") return len(x[1]) @@ -42,8 +44,7 @@ def test_group_by_key_namedtuple(self): MyNamedTuple = typing.NamedTuple( "MyNamedTuple", [("id", int), ("name", str)]) - beam.coders.typecoders.registry.register_coder( - MyNamedTuple, beam.coders.RowCoder) + beam.coders.typecoders.registry.register_row(MyNamedTuple) def generate(num: int): for i in range(100): @@ -67,8 +68,7 @@ class MyDataClass: id: int name: str - beam.coders.typecoders.registry.register_coder( - MyDataClass, beam.coders.RowCoder) + beam.coders.typecoders.registry.register_row(MyDataClass) def generate(num: int): for i in range(100): @@ -120,10 +120,8 @@ class DataClassInt: class DataClassStr(DataClassInt): name: str - beam.coders.typecoders.registry.register_coder( - DataClassInt, 
beam.coders.RowCoder) - beam.coders.typecoders.registry.register_coder( - DataClassStr, beam.coders.RowCoder) + beam.coders.typecoders.registry.register_row(DataClassInt) + beam.coders.typecoders.registry.register_row(DataClassStr) def generate(num: int): for i in range(10): diff --git a/sdks/python/apache_beam/typehints/schema_registry.py b/sdks/python/apache_beam/typehints/schema_registry.py index 7d8cdcf57d3f..684bf8734a5f 100644 --- a/sdks/python/apache_beam/typehints/schema_registry.py +++ b/sdks/python/apache_beam/typehints/schema_registry.py @@ -26,7 +26,7 @@ class SchemaTypeRegistry(object): def __init__(self): self.by_id = {} - self.by_typing = {} + self.by_typing = {} # currently not used def generate_new_id(self): for _ in range(100): @@ -43,6 +43,15 @@ def add(self, typing, schema): if schema.id: self.by_id[schema.id] = (typing, schema) + def load_registered_typings(self, by_id): + for id, typing in by_id.items(): + if id not in self.by_id: + self.by_id[id] = (typing, None) + + def get_registered_typings(self): + # Used by save_main_session, as pb2.schema isn't picklable + return {k: v[0] for k, v in self.by_id.items()} + def get_typing_by_id(self, unique_id): if not unique_id: return None diff --git a/sdks/python/apache_beam/typehints/schemas.py b/sdks/python/apache_beam/typehints/schemas.py index d2c4db8cabca..108e75aac9c0 100644 --- a/sdks/python/apache_beam/typehints/schemas.py +++ b/sdks/python/apache_beam/typehints/schemas.py @@ -587,19 +587,22 @@ def named_tuple_from_schema(self, schema: schema_pb2.Schema) -> type: descriptions[field.name] = field.description subfields.append((field.name, field_py_type)) - user_type = NamedTuple(type_name, subfields) - - # Define a reduce function, otherwise these types can't be pickled - # (See BEAM-9574) - setattr( - user_type, - '__reduce__', - _named_tuple_reduce_method(schema.SerializeToString())) - setattr(user_type, "_field_descriptions", descriptions) - setattr(user_type, row_type._BEAM_SCHEMA_ID, 
schema.id) - - self.schema_registry.add(user_type, schema) - coders.registry.register_coder(user_type, coders.RowCoder) + if schema.id in self.schema_registry.by_id: + user_type = self.schema_registry.by_id[schema.id][0] + else: + user_type = NamedTuple(type_name, subfields) + + # Define a reduce function, otherwise these types can't be pickled + # (See BEAM-9574) + setattr( + user_type, + '__reduce__', + _named_tuple_reduce_method(schema.SerializeToString())) + setattr(user_type, "_field_descriptions", descriptions) + setattr(user_type, row_type._BEAM_SCHEMA_ID, schema.id) + + self.schema_registry.add(user_type, schema) + coders.registry.register_coder(user_type, coders.RowCoder) return user_type