Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
{
"comment": "Modify this file in a trivial way to cause this test suite to run",
"modification": 15
"modification": 16
}
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
{
"comment": "Modify this file in a trivial way to cause this test suite to run",
"modification": 16
"modification": 17
}
16 changes: 16 additions & 0 deletions sdks/python/apache_beam/coders/typecoders.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,7 @@ def _normalize_typehint_type(typehint_type):
def register_coder(
self, typehint_type: Any,
typehint_coder_class: Type[coders.Coder]) -> None:
"Register a user type with a coder"
if not isinstance(typehint_coder_class, type):
raise TypeError(
'Coder registration requires a coder class object. '
Expand All @@ -133,6 +134,21 @@ def register_coder(
self._register_coder_internal(
self._normalize_typehint_type(typehint_type), typehint_coder_class)

def register_row(self, typehint_type: Any) -> None:
"""
Register a user type with a Beam Row.

This registers the type with a RowCoder and register its schema.
"""
from apache_beam.coders import RowCoder
from apache_beam.typehints.schemas import typing_to_runner_api

# Register with row coder
self.register_coder(typehint_type, RowCoder)
# This call generated a schema id for the type and register it with
# schema registry
typing_to_runner_api(typehint_type)

def get_coder(self, typehint: Any) -> coders.Coder:
if typehint and typehint.__module__ == '__main__':
# See https://github.com/apache/beam/issues/21541
Expand Down
13 changes: 12 additions & 1 deletion sdks/python/apache_beam/internal/cloudpickle_pickler.py
Original file line number Diff line number Diff line change
Expand Up @@ -256,20 +256,27 @@ def dump_session(file_path):
# dump supported Beam Registries (currently only logical type registry)
from apache_beam.coders import typecoders
from apache_beam.typehints import schemas
from apache_beam.typehints.schema_registry import SCHEMA_REGISTRY

with _pickle_lock, open(file_path, 'wb') as file:
coder_reg = typecoders.registry.get_custom_type_coder_tuples()
logical_type_reg = schemas.LogicalType._known_logical_types.copy_custom()
schema_reg = SCHEMA_REGISTRY.get_registered_typings()

pickler = cloudpickle.CloudPickler(file)
# TODO(https://github.com/apache/beam/issues/18500) add file system registry
# once implemented
pickler.dump({"coder": coder_reg, "logical_type": logical_type_reg})
pickler.dump({
"coder": coder_reg,
"logical_type": logical_type_reg,
"schema": schema_reg
})


def load_session(file_path):
from apache_beam.coders import typecoders
from apache_beam.typehints import schemas
from apache_beam.typehints.schema_registry import SCHEMA_REGISTRY

with _pickle_lock, open(file_path, 'rb') as file:
registries = cloudpickle.load(file)
Expand All @@ -284,3 +291,7 @@ def load_session(file_path):
schemas.LogicalType._known_logical_types.load(registries["logical_type"])
else:
_LOGGER.warning('No logical type registry found in saved session')
if "schema" in registries:
SCHEMA_REGISTRY.load_registered_typings(registries["schema"])
else:
_LOGGER.warning('No schema registry found in saved session')
6 changes: 3 additions & 3 deletions sdks/python/apache_beam/io/external/xlang_jdbcio_it_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@
("f_timestamp", Timestamp), ("f_decimal", Decimal),
("f_date", datetime.date), ("f_time", datetime.time)],
)
coders.registry.register_coder(JdbcTestRow, coders.RowCoder)
coders.registry.register_row(JdbcTestRow)

CustomSchemaRow = typing.NamedTuple(
"CustomSchemaRow",
Expand All @@ -82,11 +82,11 @@
("renamed_time", datetime.time),
],
)
coders.registry.register_coder(CustomSchemaRow, coders.RowCoder)
coders.registry.register_row(CustomSchemaRow)

SimpleRow = typing.NamedTuple(
"SimpleRow", [("id", int), ("name", str), ("value", float)])
coders.registry.register_coder(SimpleRow, coders.RowCoder)
coders.registry.register_row(SimpleRow)


@pytest.mark.uses_gcp_java_expansion_service
Expand Down
18 changes: 8 additions & 10 deletions sdks/python/apache_beam/typehints/row_type_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,17 +33,18 @@ class RowTypeTest(unittest.TestCase):
@staticmethod
def _check_key_type_and_count(x) -> int:
key_type = type(x[0])
if not row_type._user_type_is_generated(key_type):
raise RuntimeError("Expect type after GBK to be generated user type")
if row_type._user_type_is_generated(key_type):
raise RuntimeError("Type after GBK not preserved, get generated type")
if not hasattr(key_type, row_type._BEAM_SCHEMA_ID):
raise RuntimeError("Type after GBK missing Beam schema ID")

return len(x[1])

def test_group_by_key_namedtuple(self):
MyNamedTuple = typing.NamedTuple(
"MyNamedTuple", [("id", int), ("name", str)])

beam.coders.typecoders.registry.register_coder(
MyNamedTuple, beam.coders.RowCoder)
beam.coders.typecoders.registry.register_row(MyNamedTuple)

def generate(num: int):
for i in range(100):
Expand All @@ -67,8 +68,7 @@ class MyDataClass:
id: int
name: str

beam.coders.typecoders.registry.register_coder(
MyDataClass, beam.coders.RowCoder)
beam.coders.typecoders.registry.register_row(MyDataClass)

def generate(num: int):
for i in range(100):
Expand Down Expand Up @@ -120,10 +120,8 @@ class DataClassInt:
class DataClassStr(DataClassInt):
name: str

beam.coders.typecoders.registry.register_coder(
DataClassInt, beam.coders.RowCoder)
beam.coders.typecoders.registry.register_coder(
DataClassStr, beam.coders.RowCoder)
beam.coders.typecoders.registry.register_row(DataClassInt)
beam.coders.typecoders.registry.register_row(DataClassStr)

def generate(num: int):
for i in range(10):
Expand Down
11 changes: 10 additions & 1 deletion sdks/python/apache_beam/typehints/schema_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
class SchemaTypeRegistry(object):
def __init__(self):
self.by_id = {}
self.by_typing = {}
self.by_typing = {} # currently not used

def generate_new_id(self):
for _ in range(100):
Expand All @@ -43,6 +43,15 @@ def add(self, typing, schema):
if schema.id:
self.by_id[schema.id] = (typing, schema)

def load_registered_typings(self, by_id):
for id, typing in by_id.items():
if id not in self.by_id:
self.by_id[id] = (typing, None)

def get_registered_typings(self):
# Used by save_main_session, as pb2.schema isn't picklable
return {k: v[0] for k, v in self.by_id.items()}

def get_typing_by_id(self, unique_id):
if not unique_id:
return None
Expand Down
29 changes: 16 additions & 13 deletions sdks/python/apache_beam/typehints/schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -587,19 +587,22 @@ def named_tuple_from_schema(self, schema: schema_pb2.Schema) -> type:
descriptions[field.name] = field.description
subfields.append((field.name, field_py_type))

user_type = NamedTuple(type_name, subfields)

# Define a reduce function, otherwise these types can't be pickled
# (See BEAM-9574)
setattr(
user_type,
'__reduce__',
_named_tuple_reduce_method(schema.SerializeToString()))
setattr(user_type, "_field_descriptions", descriptions)
setattr(user_type, row_type._BEAM_SCHEMA_ID, schema.id)

self.schema_registry.add(user_type, schema)
coders.registry.register_coder(user_type, coders.RowCoder)
if schema.id in self.schema_registry.by_id:
user_type = self.schema_registry.by_id[schema.id][0]
else:
user_type = NamedTuple(type_name, subfields)

# Define a reduce function, otherwise these types can't be pickled
# (See BEAM-9574)
setattr(
user_type,
'__reduce__',
_named_tuple_reduce_method(schema.SerializeToString()))
setattr(user_type, "_field_descriptions", descriptions)
setattr(user_type, row_type._BEAM_SCHEMA_ID, schema.id)

self.schema_registry.add(user_type, schema)
coders.registry.register_coder(user_type, coders.RowCoder)

return user_type

Expand Down
Loading