From f841ef82816b5b698eb662d598241bdd90b9d250 Mon Sep 17 00:00:00 2001 From: apanich Date: Thu, 9 Apr 2026 17:32:34 -0400 Subject: [PATCH 1/4] Update core.py hints types follow up on go/beam-commit/b090e22. make the hits more typed --- sdks/python/apache_beam/transforms/core.py | 35 ++++++++++++---------- 1 file changed, 20 insertions(+), 15 deletions(-) diff --git a/sdks/python/apache_beam/transforms/core.py b/sdks/python/apache_beam/transforms/core.py index b5c3178210d9..b151d6c56269 100644 --- a/sdks/python/apache_beam/transforms/core.py +++ b/sdks/python/apache_beam/transforms/core.py @@ -1688,7 +1688,8 @@ def with_exception_handling( on_failure_callback, allow_unsafe_userstate_in_process, self.get_resource_hints(), - self.get_type_hints()) + self.get_type_hints(), + ) def with_error_handler(self, error_handler, **exception_handling_kwargs): """An alias for `with_exception_handling(error_handler=error_handler, ...)` @@ -1984,7 +1985,8 @@ def with_exception_handling(self, main_tag=None, **kwargs): if main_tag is None: main_tag = self._main_tag or 'good' named = self._do_transform.with_exception_handling( - main_tag=main_tag, **kwargs) + main_tag=main_tag, **kwargs + ) # named is _NamedPTransform wrapping _ExceptionHandlingWrapper named.transform._extra_tags = self._tags return named @@ -2331,7 +2333,8 @@ def __init__( on_failure_callback, allow_unsafe_userstate_in_process, resource_hints, - pardo_type_hints=None): + pardo_type_hints=None, + ): if partial and use_subprocess: raise ValueError('partial and use_subprocess are mutually incompatible.') self._fn = fn @@ -2431,24 +2434,25 @@ def check_threshold(bad, total, threshold, window=DoFn.WindowParam): return result def expand_2_72_0(self, pcoll): - """Pre-2.73.0 behavior: manual element_type override, no with_output_types. - """ + """Pre-2.73.0 behavior: manual element_type override, no with_output_types.""" pardo = self._build_pardo(pcoll) result = pcoll | pardo.with_outputs( - self._dead_letter_tag, main=self._main_tag, allow_unknown_tags=True) - #TODO(BEAM-18957): Fix when type inference supports tagged outputs. + self._dead_letter_tag, main=self._main_tag, allow_unknown_tags=True + ) + # TODO(BEAM-18957): Fix when type inference supports tagged outputs. result[self._main_tag].element_type = self._fn.infer_output_type( - pcoll.element_type) + pcoll.element_type + ) return self._post_process_result(pcoll, result) def expand(self, pcoll): - if pcoll.pipeline.options.is_compat_version_prior_to("2.73.0"): + if pcoll.pipeline.options.is_compat_version_prior_to('2.73.0'): return self.expand_2_72_0(pcoll) pardo = self._build_pardo(pcoll) - if (self._pardo_type_hints and self._pardo_type_hints._has_output_types()): + if self._pardo_type_hints and self._pardo_type_hints._has_output_types(): main_output_type = self._pardo_type_hints.simple_output_type(self.label) tagged_type_hints = dict(self._pardo_type_hints.tagged_output_types()) else: @@ -2456,17 +2460,18 @@ def expand(self, pcoll): tagged_type_hints = dict(self._fn.get_type_hints().tagged_output_types()) # Dead letter format: Tuple[element, Tuple[exception_type, repr, traceback]] - dead_letter_type = typehints.Tuple[pcoll.element_type, - typehints.Tuple[type, - str, - typehints.List[str]]] + dead_letter_type = typehints.Tuple[ + pcoll.element_type, + typehints.Tuple[type[typing.Any], str, typehints.Sequence[str]], + ] tagged_type_hints[self._dead_letter_tag] = dead_letter_type pardo = pardo.with_output_types(main_output_type, **tagged_type_hints) all_tags = tuple(set(self._extra_tags or ()) | {self._dead_letter_tag}) result = pcoll | pardo.with_outputs( - *all_tags, main=self._main_tag, allow_unknown_tags=True) + *all_tags, main=self._main_tag, allow_unknown_tags=True + ) return self._post_process_result(pcoll, result) From eec00fe0e21b8c88bded874feaa9807184205baa Mon Sep 17 00:00:00 2001 From: apanich Date: Thu, 9 Apr 2026 17:36:03 -0400 Subject: [PATCH 2/4] Update core_test.py --- .../apache_beam/transforms/core_test.py | 120 ++++++++++++------ 1 file changed, 80 insertions(+), 40 deletions(-) diff --git a/sdks/python/apache_beam/transforms/core_test.py b/sdks/python/apache_beam/transforms/core_test.py index d80a03bdf53b..8814c7ac1e95 100644 --- a/sdks/python/apache_beam/transforms/core_test.py +++ b/sdks/python/apache_beam/transforms/core_test.py @@ -21,7 +21,11 @@ import logging import os import tempfile -import unittest +import typing +from google3.testing.pybase import googletest +googletest.ThisTestIsUsefulWithoutCallingMain() +unittest = googletest + from typing import Iterable from typing import Literal from typing import TypeVar @@ -520,14 +524,19 @@ class TagHint(ResourceHint): class ExceptionHandlingWithOutputsTest(unittest.TestCase): """Tests for combining with_exception_handling() and with_outputs().""" + def _create_dofn_with_tagged_outputs(self): """A DoFn that yields tagged outputs and can raise on even numbers.""" + class DoWithFailures(beam.DoFn): + def process( self, element: int - ) -> Iterable[int - | beam.pvalue.TaggedOutput[Literal['threes'], int] - | beam.pvalue.TaggedOutput[Literal['fives'], str]]: + ) -> Iterable[ + int + | beam.pvalue.TaggedOutput[Literal['threes'], int] + | beam.pvalue.TaggedOutput[Literal['fives'], str] + ]: if element % 2 == 0: raise ValueError(f'Even numbers not allowed {element}') if element % 3 == 0: @@ -546,9 +555,10 @@ def test_with_exception_handling_then_with_outputs(self): results = ( p | beam.Create([1, 2, 3, 4, 5, 6, 7]) - | beam.ParDo(self._create_dofn_with_tagged_outputs()). - with_exception_handling().with_outputs( - 'threes', 'fives', main='main')) + | beam.ParDo(self._create_dofn_with_tagged_outputs()) + .with_exception_handling() + .with_outputs('threes', 'fives', main='main') + ) assert_that(results.main, equal_to([1, 7]), 'main') assert_that(results.threes, equal_to([3]), 'threes') @@ -561,7 +571,8 @@ def test_with_exception_handling_then_with_outputs(self): self.assertEqual(results.fives.element_type, str) self.assertEqual( results.bad.element_type, - typehints.Tuple[int, typehints.Tuple[type, str, typehints.List[str]]]) + typehints.Tuple[int, typehints.Tuple[type[typing.Any], str, typehints.Sequence[str]]], + ) def test_with_outputs_then_with_exception_handling(self): """Direction 2: .with_outputs().with_exception_handling()""" @@ -570,8 +581,10 @@ def test_with_outputs_then_with_exception_handling(self): results = ( p | beam.Create([1, 2, 3, 4, 5, 6, 7]) - | beam.ParDo(self._create_dofn_with_tagged_outputs()).with_outputs( - 'threes', 'fives', main='main').with_exception_handling()) + | beam.ParDo(self._create_dofn_with_tagged_outputs()) + .with_outputs('threes', 'fives', main='main') + .with_exception_handling() + ) assert_that(results.main, equal_to([1, 7]), 'main') assert_that(results.threes, equal_to([3]), 'threes') @@ -584,19 +597,22 @@ def test_with_outputs_then_with_exception_handling(self): self.assertEqual(results.fives.element_type, str) self.assertEqual( results.bad.element_type, - typehints.Tuple[int, typehints.Tuple[type, str, typehints.List[str]]]) + typehints.Tuple[int, typehints.Tuple[type[typing.Any], str, typehints.Sequence[str]]], + ) def test_with_outputs_then_with_exception_handling_custom_dead_letter_tag( - self): + self, + ): """Direction 2 with custom dead_letter_tag.""" with beam.Pipeline() as p: results = ( p | beam.Create([1, 2, 3]) - | beam.ParDo(self._create_dofn_with_tagged_outputs()).with_outputs( - 'threes', - main='main').with_exception_handling(dead_letter_tag='errors')) + | beam.ParDo(self._create_dofn_with_tagged_outputs()) + .with_outputs('threes', main='main') + .with_exception_handling(dead_letter_tag='errors') + ) assert_that(results.main, equal_to([1]), 'main') assert_that(results.threes, equal_to([3]), 'threes') @@ -605,19 +621,22 @@ def test_with_outputs_then_with_exception_handling_custom_dead_letter_tag( self.assertEqual(results.threes.element_type, int) self.assertEqual( results.errors.element_type, - typehints.Tuple[int, typehints.Tuple[type, str, typehints.List[str]]]) + typehints.Tuple[int, typehints.Tuple[type[typing.Any], str, typehints.Sequence[str]]], + ) def test_with_exception_handling_then_with_outputs_custom_dead_letter_tag( - self): + self, + ): """Direction 1 with custom dead_letter_tag.""" with beam.Pipeline() as p: results = ( p | beam.Create([1, 2, 3]) - | beam.ParDo( - self._create_dofn_with_tagged_outputs()).with_exception_handling( - dead_letter_tag='errors').with_outputs('threes', main='main')) + | beam.ParDo(self._create_dofn_with_tagged_outputs()) + .with_exception_handling(dead_letter_tag='errors') + .with_outputs('threes', main='main') + ) assert_that(results.main, equal_to([1]), 'main') assert_that(results.threes, equal_to([3]), 'threes') @@ -626,7 +645,8 @@ def test_with_exception_handling_then_with_outputs_custom_dead_letter_tag( self.assertEqual(results.threes.element_type, int) self.assertEqual( results.errors.element_type, - typehints.Tuple[int, typehints.Tuple[type, str, typehints.List[str]]]) + typehints.Tuple[int, typehints.Tuple[type[typing.Any], str, typehints.Sequence[str]]], + ) def test_exception_handling_no_with_outputs_backward_compat(self): """Without with_outputs(), behavior is unchanged.""" @@ -635,8 +655,10 @@ def test_exception_handling_no_with_outputs_backward_compat(self): good, bad = ( p | beam.Create([1, 2, 7]) - | beam.ParDo(self._create_dofn_with_tagged_outputs()) - .with_exception_handling()) + | beam.ParDo( + self._create_dofn_with_tagged_outputs() + ).with_exception_handling() + ) assert_that(good, equal_to([1, 7]), 'good') bad_elements = bad | beam.Keys() @@ -644,13 +666,15 @@ def test_exception_handling_no_with_outputs_backward_compat(self): def test_exception_handling_compat_version_uses_old_behavior(self): """With compat version < 2.73.0, old expand path is used.""" - options = PipelineOptions(update_compatibility_version="2.72.0") + options = PipelineOptions(update_compatibility_version='2.72.0') with beam.Pipeline(options=options) as p: good, bad = ( p | beam.Create([1, 2, 7]) - | beam.ParDo(self._create_dofn_with_tagged_outputs()) - .with_exception_handling()) + | beam.ParDo( + self._create_dofn_with_tagged_outputs() + ).with_exception_handling() + ) assert_that(good, equal_to([1, 7]), 'good') bad_elements = bad | beam.Keys() @@ -658,15 +682,19 @@ def test_exception_handling_compat_version_uses_old_behavior(self): def test_exception_handling_compat_version_element_type_set_manually(self): """With compat version < 2.73.0, element_type is set via manual override - (the old behavior) rather than via with_output_types.""" - options = PipelineOptions(update_compatibility_version="2.72.0") + (the old behavior) rather than via with_output_types. + """ + + options = PipelineOptions(update_compatibility_version='2.72.0') with beam.Pipeline(options=options) as p: results = ( p | beam.Create([1, 2, 3]) - | beam.ParDo(self._create_dofn_with_tagged_outputs()). - with_exception_handling().with_outputs('threes', main='main')) + | beam.ParDo(self._create_dofn_with_tagged_outputs()) + .with_exception_handling() + .with_outputs('threes', main='main') + ) # In old path, dead letter type is Any (no with_output_types call) self.assertEqual(results.bad.element_type, typehints.Any) @@ -682,17 +710,23 @@ def test_with_outputs_then_exception_handling_with_map(self): results = ( p | beam.Create([1, 2, 3, 4, 5]) - | beam.Map(lambda x: x if x % 2 != 0 else 1 / 0).with_outputs( - main='main').with_exception_handling()) + | beam.Map(lambda x: x if x % 2 != 0 else 1 / 0) + .with_outputs(main='main') + .with_exception_handling() + ) assert_that(results.main, equal_to([1, 3, 5]), 'main') bad_elements = results.bad | beam.Keys() assert_that(bad_elements, equal_to([2, 4]), 'bad') def test_with_output_types_chained_on_pardo(self): """When type hints are chained on the ParDo (not annotations on the DoFn), + tagged output types should still be propagated through - with_exception_handling().with_outputs().""" + with_exception_handling().with_outputs(). + """ + class DoWithFailuresNoAnnotations(beam.DoFn): + def process(self, element): if element % 2 == 0: raise ValueError(f'Even numbers not allowed {element}') @@ -705,9 +739,11 @@ def process(self, element): results = ( p | beam.Create([1, 2, 3, 7]) - | beam.ParDo(DoWithFailuresNoAnnotations()).with_output_types( - int, threes=int).with_exception_handling().with_outputs( - 'threes', main='main')) + | beam.ParDo(DoWithFailuresNoAnnotations()) + .with_output_types(int, threes=int) + .with_exception_handling() + .with_outputs('threes', main='main') + ) assert_that(results.main, equal_to([1, 7]), 'main') assert_that(results.threes, equal_to([3]), 'threes') @@ -718,16 +754,20 @@ def process(self, element): def test_with_outputs_and_error_handler(self): """with_outputs() + error_handler should return DoOutputsTuple, not a - bare PCollection.""" + + bare PCollection. + """ from apache_beam.transforms.error_handling import ErrorHandler + with beam.Pipeline() as p: with ErrorHandler(beam.Map(lambda x: x)) as handler: results = ( p | beam.Create([1, 2, 3, 4, 5, 6, 7]) - | beam.ParDo(self._create_dofn_with_tagged_outputs()).with_outputs( - 'threes', 'fives', - main='main').with_exception_handling(error_handler=handler)) + | beam.ParDo(self._create_dofn_with_tagged_outputs()) + .with_outputs('threes', 'fives', main='main') + .with_exception_handling(error_handler=handler) + ) assert_that(results.main, equal_to([1, 7]), 'main') assert_that(results.threes, equal_to([3]), 'threes') From e9077445a40867e24b458a9d35f2134caa5ac7ce Mon Sep 17 00:00:00 2001 From: apanich Date: Fri, 10 Apr 2026 14:27:50 -0400 Subject: [PATCH 3/4] Update core_test.py --- .../apache_beam/transforms/core_test.py | 273 +++--------------- 1 file changed, 41 insertions(+), 232 deletions(-) diff --git a/sdks/python/apache_beam/transforms/core_test.py b/sdks/python/apache_beam/transforms/core_test.py index 8814c7ac1e95..357971b9c0b0 100644 --- a/sdks/python/apache_beam/transforms/core_test.py +++ b/sdks/python/apache_beam/transforms/core_test.py @@ -1,37 +1,13 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one or more -# contributor license agreements. See the NOTICE file distributed with -# this work for additional information regarding copyright ownership. -# The ASF licenses this file to You under the Apache License, Version 2.0 -# (the "License"); you may not use this file except in compliance with -# the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - """Unit tests for the core python file.""" -# pytype: skip-file - import logging import os import tempfile +import unittest import typing -from google3.testing.pybase import googletest -googletest.ThisTestIsUsefulWithoutCallingMain() -unittest = googletest - from typing import Iterable from typing import Literal from typing import TypeVar - import pytest - import apache_beam as beam from apache_beam.coders import coders from apache_beam.options.pipeline_options import PipelineOptions @@ -46,67 +22,47 @@ from apache_beam.typehints import TypeCheckError from apache_beam.typehints import row_type from apache_beam.typehints import typehints - RETURN_NONE_PARTIAL_WARNING = "Process method returned None" - - class TestDoFn0(beam.DoFn): """Returning without a value is allowed""" def process(self, element): if not element: return yield element - - class TestDoFn1(beam.DoFn): def process(self, element): yield element - - class TestDoFn2(beam.DoFn): def process(self, element): def inner_func(x): yield x - return inner_func(element) - - class TestDoFn3(beam.DoFn): """mixing return and yield is not allowed""" def process(self, element): if not element: return -1 yield element - - class TestDoFn4(beam.DoFn): """test the variable name containing return""" def process(self, element): my_return = element yield my_return - - class TestDoFn5(beam.DoFn): """test the variable name containing yield""" def process(self, element): my_yield = element return my_yield - - class TestDoFn6(beam.DoFn): """test the variable name containing return""" def process(self, element): return_test = element yield return_test - - class TestDoFn7(beam.DoFn): """test the variable name containing yield""" def process(self, element): yield_test = element return yield_test - - class TestDoFn8(beam.DoFn): """test the code containing yield and yield from""" def process(self, element): @@ -114,35 +70,25 @@ def process(self, element): yield from [1, 2, 3] else: yield element - - class TestDoFn9(beam.DoFn): def process(self, element): if len(element) > 3: raise ValueError('Not allowed to have long elements') yield element - - class TestDoFn10(beam.DoFn): """test process returning None explicitly""" def process(self, element): return None - - class TestDoFn11(beam.DoFn): """test process returning None (no return and no yield)""" def process(self, element): pass - - class TestDoFn12(beam.DoFn): """test process returning None in a filter pattern""" def process(self, element): if element == 0: return return element - - class TestDoFnStateful(beam.DoFn): STATE_SPEC = ReadModifyWriteStateSpec('num_elements', coders.VarIntCoder()) """test process with a stateful dofn""" @@ -152,8 +98,6 @@ def process(self, element, state=beam.DoFn.StateParam(STATE_SPEC)): current_value = state.read() or 1 state.write(current_value + 1) yield current_value - - class TestDoFnWithTimer(beam.DoFn): ALL_ELEMENTS = BagStateSpec('buffer', coders.VarIntCoder()) TIMER = TimerSpec('timer', beam.TimeDomain.WATERMARK) @@ -168,25 +112,18 @@ def process( raise ValueError('Not allowed to have large numbers') state.add(element[1]) timer.set(t) - return [] - @on_timer(TIMER) def expiry_callback(self, state=beam.DoFn.StateParam(ALL_ELEMENTS)): unique_elements = list(state.read()) state.clear() - return unique_elements - - class CreateTest(unittest.TestCase): @pytest.fixture(autouse=True) def inject_fixtures(self, caplog): self._caplog = caplog - def test_dofn_with_yield_and_return(self): warning_text = 'Using yield and return' - with self._caplog.at_level(logging.WARNING): assert beam.ParDo(sum) assert beam.ParDo(TestDoFn0()) @@ -198,28 +135,22 @@ def test_dofn_with_yield_and_return(self): assert beam.ParDo(TestDoFn7()) assert beam.ParDo(TestDoFn8()) assert warning_text not in self._caplog.text - with self._caplog.at_level(logging.WARNING): beam.ParDo(TestDoFn3()) assert warning_text in self._caplog.text - def test_dofn_with_explicit_return_none(self): with self._caplog.at_level(logging.WARNING): beam.ParDo(TestDoFn10()) assert RETURN_NONE_PARTIAL_WARNING in self._caplog.text assert str(TestDoFn10) in self._caplog.text - def test_dofn_with_implicit_return_none_missing_return_and_yield(self): with self._caplog.at_level(logging.WARNING): beam.ParDo(TestDoFn11()) assert RETURN_NONE_PARTIAL_WARNING not in self._caplog.text - def test_dofn_with_implicit_return_none_and_value(self): with self._caplog.at_level(logging.WARNING): beam.ParDo(TestDoFn12()) assert RETURN_NONE_PARTIAL_WARNING not in self._caplog.text - - class PartitionTest(unittest.TestCase): def test_partition_with_bools(self): with pytest.raises( @@ -234,89 +165,69 @@ def test_partition_with_bools(self): _ = ( p | beam.Create([input_value]) | beam.Partition(lambda x, _: x, 2)) - def test_partition_with_numpy_integers(self): # Test that numpy integer types are correctly accepted by the # ApplyPartitionFnFn class import numpy as np - # Create an instance of the ApplyPartitionFnFn class apply_partition_fn = beam.Partition.ApplyPartitionFnFn() - # Define a simple partition function class SimplePartitionFn(beam.PartitionFn): def partition_for(self, element, num_partitions): return element % num_partitions - partition_fn = SimplePartitionFn() - # Test with numpy.int32 # This should not raise an exception outputs = list(apply_partition_fn.process(np.int32(1), partition_fn, 3)) self.assertEqual(len(outputs), 1) self.assertEqual(outputs[0].tag, '1') # 1 % 3 = 1 - # Test with numpy.int64 # This should not raise an exception outputs = list(apply_partition_fn.process(np.int64(2), partition_fn, 3)) self.assertEqual(len(outputs), 1) self.assertEqual(outputs[0].tag, '2') # 2 % 3 = 2 - def test_partition_fn_returning_numpy_integers(self): # Test that partition functions can return numpy integer types import numpy as np - # Create an instance of the ApplyPartitionFnFn class apply_partition_fn = beam.Partition.ApplyPartitionFnFn() - # Define partition functions that return numpy integer types class Int32PartitionFn(beam.PartitionFn): def partition_for(self, element, num_partitions): return np.int32(element % num_partitions) - class Int64PartitionFn(beam.PartitionFn): def partition_for(self, element, num_partitions): return np.int64(element % num_partitions) - # Test with partition function returning numpy.int32 # This should not raise an exception outputs = list(apply_partition_fn.process(1, Int32PartitionFn(), 3)) self.assertEqual(len(outputs), 1) self.assertEqual(outputs[0].tag, '1') # 1 % 3 = 1 - # Test with partition function returning numpy.int64 # This should not raise an exception outputs = list(apply_partition_fn.process(2, Int64PartitionFn(), 3)) self.assertEqual(len(outputs), 1) self.assertEqual(outputs[0].tag, '2') # 2 % 3 = 2 - def test_partition_boundedness(self): def partition_fn(val, num_partitions): return val % num_partitions - class UnboundedDoFn(beam.DoFn): @beam.DoFn.unbounded_per_element() def process(self, element): yield element - with beam.testing.test_pipeline.TestPipeline() as p: source = p | beam.Create([1, 2, 3, 4, 5]) p1, p2, p3 = source | "bounded" >> beam.Partition(partition_fn, 3) - self.assertEqual(source.is_bounded, True) self.assertEqual(p1.is_bounded, True) self.assertEqual(p2.is_bounded, True) self.assertEqual(p3.is_bounded, True) - unbounded = source | beam.ParDo(UnboundedDoFn()) p4, p5, p6 = unbounded | "unbounded" >> beam.Partition(partition_fn, 3) - self.assertEqual(unbounded.is_bounded, False) self.assertEqual(p4.is_bounded, False) self.assertEqual(p5.is_bounded, False) self.assertEqual(p6.is_bounded, False) - - class FlattenTest(unittest.TestCase): def test_flatten_identical_windows(self): with beam.testing.test_pipeline.TestPipeline() as p: @@ -328,7 +239,6 @@ def test_flatten_identical_windows(self): FixedWindows(100)) out = (source1, source2, source3) | "flatten" >> beam.Flatten() assert_that(out, equal_to([1, 2, 3, 4, 5, 6, 7, 8, 9, 10])) - def test_flatten_no_windows(self): with beam.testing.test_pipeline.TestPipeline() as p: source1 = p | "c1" >> beam.Create([1, 2, 3, 4, 5]) @@ -336,7 +246,6 @@ def test_flatten_no_windows(self): source3 = p | "c3" >> beam.Create([9, 10]) out = (source1, source2, source3) | "flatten" >> beam.Flatten() assert_that(out, equal_to([1, 2, 3, 4, 5, 6, 7, 8, 9, 10])) - def test_flatten_mismatched_windows(self): with beam.testing.test_pipeline.TestPipeline() as p: source1 = p | "c1" >> beam.Create( @@ -346,8 +255,6 @@ def test_flatten_mismatched_windows(self): source3 = p | "c3" >> beam.Create([9, 10]) | "w3" >> beam.WindowInto( FixedWindows(100)) _ = (source1, source2, source3) | "flatten" >> beam.Flatten() - - class ExceptionHandlingTest(unittest.TestCase): def test_routes_failures(self): with beam.Pipeline() as pipeline: @@ -358,12 +265,10 @@ def test_routes_failures(self): bad_elements = bad | beam.Keys() assert_that(good, equal_to(['abc', 'foo', 'bar']), 'good') assert_that(bad_elements, equal_to(['long_word', 'foobar']), 'bad') - def test_handles_callbacks(self): with tempfile.TemporaryDirectory() as tmp_dirname: tmp_path = os.path.join(tmp_dirname, 'tmp_filename') file_contents = 'random content' - def failure_callback(e, el): if type(e) is not ValueError: raise Exception(f'Failed to pass in correct exception, received {e}') @@ -373,7 +278,6 @@ def failure_callback(e, el): logging.warning(tmp_path) f.write(file_contents) f.close() - with beam.Pipeline() as pipeline: good, bad = ( pipeline | beam.Create(['abc', 'bcd', 'foo', 'bar', 'foobar']) @@ -386,18 +290,15 @@ def failure_callback(e, el): with open(tmp_path) as f: s = f.read() self.assertEqual(s, file_contents) - def test_handles_no_callback_triggered(self): with tempfile.TemporaryDirectory() as tmp_dirname: tmp_path = os.path.join(tmp_dirname, 'tmp_filename') file_contents = 'random content' - def failure_callback(e, el): f = open(tmp_path, "a") logging.warning(tmp_path) f.write(file_contents) f.close() - with beam.Pipeline() as pipeline: good, bad = ( pipeline | beam.Create(['abc', 'bcd', 'foo', 'bar']) @@ -408,7 +309,6 @@ def failure_callback(e, el): assert_that(good, equal_to(['abc', 'bcd', 'foo', 'bar']), 'good') assert_that(bad_elements, equal_to([]), 'bad') self.assertFalse(os.path.isfile(tmp_path)) - def test_stateful_exception_handling(self): with beam.Pipeline() as pipeline: good, bad = ( @@ -421,7 +321,6 @@ def test_stateful_exception_handling(self): assert_that(good, equal_to([1, 2, 3]), 'good') assert_that( bad_elements, equal_to([(1, 'long_word'), (1, 'foobar')]), 'bad') - def test_timer_exception_handling(self): with beam.Pipeline() as pipeline: good, bad = ( @@ -432,11 +331,9 @@ def test_timer_exception_handling(self): bad_elements = bad | beam.Keys() assert_that(good, equal_to([0, 1, 2]), 'good') assert_that(bad_elements, equal_to([(1, 5), (1, 10)]), 'bad') - def test_tags_with_exception_handling_then_resource_hint(self): class TagHint(ResourceHint): urn = 'beam:resources:tags:v1' - ResourceHint.register_resource_hint('tags', TagHint) with beam.Pipeline() as pipeline: ok, unused_errors = ( @@ -454,11 +351,9 @@ class TagHint(ResourceHint): pd.get_resource_hints(), {'beam:resources:tags:v1': b'test_tag'}, ) - def test_tags_with_exception_handling_timeout_then_resource_hint(self): class TagHint(ResourceHint): urn = 'beam:resources:tags:v1' - ResourceHint.register_resource_hint('tags', TagHint) with beam.Pipeline() as pipeline: ok, unused_errors = ( @@ -476,11 +371,9 @@ class TagHint(ResourceHint): pd.get_resource_hints(), {'beam:resources:tags:v1': b'test_tag'}, ) - def test_tags_with_resource_hint_then_exception_handling(self): class TagHint(ResourceHint): urn = 'beam:resources:tags:v1' - ResourceHint.register_resource_hint('tags', TagHint) with beam.Pipeline() as pipeline: ok, unused_errors = ( @@ -498,11 +391,9 @@ class TagHint(ResourceHint): pd.get_resource_hints(), {'beam:resources:tags:v1': b'test_tag'}, ) - def test_tags_with_resource_hint_then_exception_handling_timeout(self): class TagHint(ResourceHint): urn = 'beam:resources:tags:v1' - ResourceHint.register_resource_hint('tags', TagHint) with beam.Pipeline() as pipeline: ok, unused_errors = ( @@ -520,23 +411,16 @@ class TagHint(ResourceHint): pd.get_resource_hints(), {'beam:resources:tags:v1': b'test_tag'}, ) - - class ExceptionHandlingWithOutputsTest(unittest.TestCase): """Tests for combining with_exception_handling() and with_outputs().""" - def _create_dofn_with_tagged_outputs(self): """A DoFn that yields tagged outputs and can raise on even numbers.""" - class DoWithFailures(beam.DoFn): - def process( self, element: int - ) -> Iterable[ - int - | beam.pvalue.TaggedOutput[Literal['threes'], int] - | beam.pvalue.TaggedOutput[Literal['fives'], str] - ]: + ) -> Iterable[int + | beam.pvalue.TaggedOutput[Literal['threes'], int] + | beam.pvalue.TaggedOutput[Literal['fives'], str]]: if element % 2 == 0: raise ValueError(f'Even numbers not allowed {element}') if element % 3 == 0: @@ -545,21 +429,16 @@ def process( yield beam.pvalue.TaggedOutput('fives', str(element)) # type: ignore[misc] else: yield element - return DoWithFailures() - def test_with_exception_handling_then_with_outputs(self): """Direction 1: .with_exception_handling().with_outputs()""" - with beam.Pipeline() as p: results = ( p | beam.Create([1, 2, 3, 4, 5, 6, 7]) - | beam.ParDo(self._create_dofn_with_tagged_outputs()) - .with_exception_handling() - .with_outputs('threes', 'fives', main='main') - ) - + | beam.ParDo(self._create_dofn_with_tagged_outputs()). + with_exception_handling().with_outputs( + 'threes', 'fives', main='main')) assert_that(results.main, equal_to([1, 7]), 'main') assert_that(results.threes, equal_to([3]), 'threes') assert_that(results.fives, equal_to(['5']), 'fives') @@ -571,21 +450,15 @@ def test_with_exception_handling_then_with_outputs(self): self.assertEqual(results.fives.element_type, str) self.assertEqual( results.bad.element_type, - typehints.Tuple[int, typehints.Tuple[type[typing.Any], str, typehints.Sequence[str]]], - ) - + typehints.Tuple[int, typehints.Tuple[type[typing.Any], str, typehints.Sequence[str]]]) def test_with_outputs_then_with_exception_handling(self): """Direction 2: .with_outputs().with_exception_handling()""" - with beam.Pipeline() as p: results = ( p | beam.Create([1, 2, 3, 4, 5, 6, 7]) - | beam.ParDo(self._create_dofn_with_tagged_outputs()) - .with_outputs('threes', 'fives', main='main') - .with_exception_handling() - ) - + | beam.ParDo(self._create_dofn_with_tagged_outputs()).with_outputs( + 'threes', 'fives', main='main').with_exception_handling()) assert_that(results.main, equal_to([1, 7]), 'main') assert_that(results.threes, equal_to([3]), 'threes') assert_that(results.fives, equal_to(['5']), 'fives') @@ -597,23 +470,17 @@ def test_with_outputs_then_with_exception_handling(self): self.assertEqual(results.fives.element_type, str) self.assertEqual( results.bad.element_type, - typehints.Tuple[int, typehints.Tuple[type[typing.Any], str, typehints.Sequence[str]]], - ) - + typehints.Tuple[int, typehints.Tuple[type[typing.Any], str, typehints.Sequence[str]]]) def test_with_outputs_then_with_exception_handling_custom_dead_letter_tag( - self, - ): + self): """Direction 2 with custom dead_letter_tag.""" - with beam.Pipeline() as p: results = ( p | beam.Create([1, 2, 3]) - | beam.ParDo(self._create_dofn_with_tagged_outputs()) - .with_outputs('threes', main='main') - .with_exception_handling(dead_letter_tag='errors') - ) - + | beam.ParDo(self._create_dofn_with_tagged_outputs()).with_outputs( + 'threes', + main='main').with_exception_handling(dead_letter_tag='errors')) assert_that(results.main, equal_to([1]), 'main') assert_that(results.threes, equal_to([3]), 'threes') bad_elements = results.errors | beam.Keys() @@ -621,23 +488,17 @@ def test_with_outputs_then_with_exception_handling_custom_dead_letter_tag( self.assertEqual(results.threes.element_type, int) self.assertEqual( results.errors.element_type, - typehints.Tuple[int, typehints.Tuple[type[typing.Any], str, typehints.Sequence[str]]], - ) - + typehints.Tuple[int, typehints.Tuple[type[typing.Any], str, typehints.Sequence[str]]]) def test_with_exception_handling_then_with_outputs_custom_dead_letter_tag( - self, - ): + self): """Direction 1 with custom dead_letter_tag.""" - with beam.Pipeline() as p: results = ( p | beam.Create([1, 2, 3]) - | beam.ParDo(self._create_dofn_with_tagged_outputs()) - .with_exception_handling(dead_letter_tag='errors') - .with_outputs('threes', main='main') - ) - + | beam.ParDo( + self._create_dofn_with_tagged_outputs()).with_exception_handling( + dead_letter_tag='errors').with_outputs('threes', main='main')) assert_that(results.main, equal_to([1]), 'main') assert_that(results.threes, equal_to([3]), 'threes') bad_elements = results.errors | beam.Keys() @@ -645,57 +506,40 @@ def test_with_exception_handling_then_with_outputs_custom_dead_letter_tag( self.assertEqual(results.threes.element_type, int) self.assertEqual( results.errors.element_type, - typehints.Tuple[int, typehints.Tuple[type[typing.Any], str, typehints.Sequence[str]]], - ) - + typehints.Tuple[int, typehints.Tuple[type[typing.Any], str, typehints.Sequence[str]]]) def test_exception_handling_no_with_outputs_backward_compat(self): """Without with_outputs(), behavior is unchanged.""" - with beam.Pipeline() as p: good, bad = ( p | beam.Create([1, 2, 7]) - | beam.ParDo( - self._create_dofn_with_tagged_outputs() - ).with_exception_handling() - ) - + | beam.ParDo(self._create_dofn_with_tagged_outputs()) + .with_exception_handling()) assert_that(good, equal_to([1, 7]), 'good') bad_elements = bad | beam.Keys() assert_that(bad_elements, equal_to([2]), 'bad') - def test_exception_handling_compat_version_uses_old_behavior(self): """With compat version < 2.73.0, old expand path is used.""" - options = PipelineOptions(update_compatibility_version='2.72.0') + options = PipelineOptions(update_compatibility_version="2.72.0") with beam.Pipeline(options=options) as p: good, bad = ( p | beam.Create([1, 2, 7]) - | beam.ParDo( - self._create_dofn_with_tagged_outputs() - ).with_exception_handling() - ) - + | beam.ParDo(self._create_dofn_with_tagged_outputs()) + .with_exception_handling()) assert_that(good, equal_to([1, 7]), 'good') bad_elements = bad | beam.Keys() assert_that(bad_elements, equal_to([2]), 'bad') - def test_exception_handling_compat_version_element_type_set_manually(self): """With compat version < 2.73.0, element_type is set via manual override - - (the old behavior) rather than via with_output_types. - """ - - options = PipelineOptions(update_compatibility_version='2.72.0') + (the old behavior) rather than via with_output_types.""" + options = PipelineOptions(update_compatibility_version="2.72.0") with beam.Pipeline(options=options) as p: results = ( p | beam.Create([1, 2, 3]) - | beam.ParDo(self._create_dofn_with_tagged_outputs()) - .with_exception_handling() - .with_outputs('threes', main='main') - ) - + | beam.ParDo(self._create_dofn_with_tagged_outputs()). + with_exception_handling().with_outputs('threes', main='main')) # In old path, dead letter type is Any (no with_output_types call) self.assertEqual(results.bad.element_type, typehints.Any) # Tagged outputs still get types from DoFn Literal annotations @@ -703,30 +547,22 @@ def test_exception_handling_compat_version_element_type_set_manually(self): self.assertEqual(results.threes.element_type, int) # Main output type should still be inferred via manual override assert_that(results.main, equal_to([1]), 'main') - def test_with_outputs_then_exception_handling_with_map(self): """with_outputs().with_exception_handling() also works on Map.""" with beam.Pipeline() as p: results = ( p | beam.Create([1, 2, 3, 4, 5]) - | beam.Map(lambda x: x if x % 2 != 0 else 1 / 0) - .with_outputs(main='main') - .with_exception_handling() - ) + | beam.Map(lambda x: x if x % 2 != 0 else 1 / 0).with_outputs( + main='main').with_exception_handling()) assert_that(results.main, equal_to([1, 3, 5]), 'main') bad_elements = results.bad | beam.Keys() assert_that(bad_elements, equal_to([2, 4]), 'bad') - def test_with_output_types_chained_on_pardo(self): """When type hints are chained on the ParDo (not annotations on the DoFn), - tagged output types should still be propagated through - with_exception_handling().with_outputs(). - """ - + with_exception_handling().with_outputs().""" class DoWithFailuresNoAnnotations(beam.DoFn): - def process(self, element): if element % 2 == 0: raise ValueError(f'Even numbers not allowed {element}') @@ -734,66 +570,48 @@ def process(self, element): yield beam.pvalue.TaggedOutput('threes', element) else: yield element - with beam.Pipeline() as p: results = ( p | beam.Create([1, 2, 3, 7]) - | beam.ParDo(DoWithFailuresNoAnnotations()) - .with_output_types(int, threes=int) - .with_exception_handling() - .with_outputs('threes', main='main') - ) - + | beam.ParDo(DoWithFailuresNoAnnotations()).with_output_types( + int, threes=int).with_exception_handling().with_outputs( + 'threes', main='main')) assert_that(results.main, equal_to([1, 7]), 'main') assert_that(results.threes, equal_to([3]), 'threes') bad_elements = results.bad | beam.Keys() assert_that(bad_elements, equal_to([2]), 'bad') self.assertEqual(results.main.element_type, int) self.assertEqual(results.threes.element_type, int) - def test_with_outputs_and_error_handler(self): """with_outputs() + error_handler should return DoOutputsTuple, not a - - bare PCollection. - """ + bare PCollection.""" from apache_beam.transforms.error_handling import ErrorHandler - with beam.Pipeline() as p: with ErrorHandler(beam.Map(lambda x: x)) as handler: results = ( p | beam.Create([1, 2, 3, 4, 5, 6, 7]) - | beam.ParDo(self._create_dofn_with_tagged_outputs()) - .with_outputs('threes', 'fives', main='main') - .with_exception_handling(error_handler=handler) - ) - + | beam.ParDo(self._create_dofn_with_tagged_outputs()).with_outputs( + 'threes', 'fives', + main='main').with_exception_handling(error_handler=handler)) assert_that(results.main, equal_to([1, 7]), 'main') assert_that(results.threes, equal_to([3]), 'threes') assert_that(results.fives, equal_to(['5']), 'fives') - - def test_callablewrapper_typehint(): T = TypeVar("T") - def identity(x: T) -> T: return x - dofn = beam.core.CallableWrapperDoFn(identity) assert dofn.get_type_hints().strip_iterable()[1][0][0] == typehints.Any - - class FlatMapTest(unittest.TestCase): def test_default(self): - with beam.Pipeline() as pipeline: letters = ( pipeline | beam.Create(['abc', 'def'], reshuffle=False) | beam.FlatMap()) assert_that(letters, equal_to(['a', 'b', 'c', 'd', 'e', 'f'])) - def test_default_identity_function_with_typehint(self): with beam.Pipeline() as pipeline: letters = ( @@ -801,19 +619,15 @@ def test_default_identity_function_with_typehint(self): | beam.Create([["abc"]], reshuffle=False) | beam.FlatMap() | beam.Map(lambda s: s.upper()).with_input_types(str)) - assert_that(letters, equal_to(["ABC"])) - def test_typecheck_with_default(self): with pytest.raises(TypeCheckError): with beam.Pipeline() as pipeline: _ = ( - pipeline +pipeline | beam.Create([[1, 2, 3]], reshuffle=False) | beam.FlatMap() | beam.Map(lambda s: s.upper()).with_input_types(str)) - - class CreateInferOutputSchemaTest(unittest.TestCase): def test_multiple_types_for_field(self): output_type = beam.Create([beam.Row(a=1), @@ -823,13 +637,11 @@ def test_multiple_types_for_field(self): row_type.RowTypeConstraint.from_fields([ ('a', typehints.Union[int, str]) ])) - def test_single_type_for_field(self): output_type = beam.Create([beam.Row(a=1), beam.Row(a=2)]).infer_output_type(None) self.assertEqual( output_type, row_type.RowTypeConstraint.from_fields([('a', int)])) - def test_optional_type_for_field(self): output_type = beam.Create([beam.Row(a=1), beam.Row(a=None)]).infer_output_type(None) @@ -837,13 +649,10 @@ def test_optional_type_for_field(self): output_type, row_type.RowTypeConstraint.from_fields([('a', typehints.Optional[int]) ])) - def test_none_type_for_field_raises_error(self): with self.assertRaisesRegex(TypeError, "('No types found for field %s', 'a')"): beam.Create([beam.Row(a=None), beam.Row(a=None)]).infer_output_type(None) - - if __name__ == '__main__': logging.getLogger().setLevel(logging.INFO) unittest.main() From dbbf734d211145609f8592716453e5c21c05aff3 Mon Sep 17 00:00:00 2001 From: apanich Date: Fri, 10 Apr 2026 14:45:43 -0400 Subject: [PATCH 4/4] Update core.py --- sdks/python/apache_beam/transforms/core.py | 27 +++++++++------------- 1 file changed, 11 insertions(+), 16 deletions(-) diff --git a/sdks/python/apache_beam/transforms/core.py b/sdks/python/apache_beam/transforms/core.py index b151d6c56269..893b4cb58f4f 100644 --- a/sdks/python/apache_beam/transforms/core.py +++ b/sdks/python/apache_beam/transforms/core.py @@ -1688,8 +1688,7 @@ def with_exception_handling( on_failure_callback, allow_unsafe_userstate_in_process, self.get_resource_hints(), - self.get_type_hints(), - ) + self.get_type_hints()) def with_error_handler(self, error_handler, **exception_handling_kwargs): """An alias for `with_exception_handling(error_handler=error_handler, ...)` @@ -1985,8 +1984,7 @@ def with_exception_handling(self, main_tag=None, **kwargs): if main_tag is None: main_tag = self._main_tag or 'good' named = self._do_transform.with_exception_handling( - main_tag=main_tag, **kwargs - ) + main_tag=main_tag, **kwargs) # named is _NamedPTransform wrapping _ExceptionHandlingWrapper named.transform._extra_tags = self._tags return named @@ -2333,8 +2331,7 @@ def __init__( on_failure_callback, allow_unsafe_userstate_in_process, resource_hints, - pardo_type_hints=None, - ): + pardo_type_hints=None): if partial and use_subprocess: raise ValueError('partial and use_subprocess are mutually incompatible.') self._fn = fn @@ -2434,25 +2431,24 @@ def check_threshold(bad, total, threshold, window=DoFn.WindowParam): return result def expand_2_72_0(self, pcoll): - """Pre-2.73.0 behavior: manual element_type override, no with_output_types.""" + """Pre-2.73.0 behavior: manual element_type override, no with_output_types. + """ pardo = self._build_pardo(pcoll) result = pcoll | pardo.with_outputs( - self._dead_letter_tag, main=self._main_tag, allow_unknown_tags=True - ) - # TODO(BEAM-18957): Fix when type inference supports tagged outputs. + self._dead_letter_tag, main=self._main_tag, allow_unknown_tags=True) + #TODO(BEAM-18957): Fix when type inference supports tagged outputs. result[self._main_tag].element_type = self._fn.infer_output_type( - pcoll.element_type - ) + pcoll.element_type) return self._post_process_result(pcoll, result) def expand(self, pcoll): - if pcoll.pipeline.options.is_compat_version_prior_to('2.73.0'): + if pcoll.pipeline.options.is_compat_version_prior_to("2.73.0"): return self.expand_2_72_0(pcoll) pardo = self._build_pardo(pcoll) - if self._pardo_type_hints and self._pardo_type_hints._has_output_types(): + if (self._pardo_type_hints and self._pardo_type_hints._has_output_types()): main_output_type = self._pardo_type_hints.simple_output_type(self.label) tagged_type_hints = dict(self._pardo_type_hints.tagged_output_types()) else: @@ -2470,8 +2466,7 @@ def expand(self, pcoll): all_tags = tuple(set(self._extra_tags or ()) | {self._dead_letter_tag}) result = pcoll | pardo.with_outputs( - *all_tags, main=self._main_tag, allow_unknown_tags=True - ) + *all_tags, main=self._main_tag, allow_unknown_tags=True) return self._post_process_result(pcoll, result)