Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 20 additions & 15 deletions sdks/python/apache_beam/transforms/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -1688,7 +1688,8 @@ def with_exception_handling(
on_failure_callback,
allow_unsafe_userstate_in_process,
self.get_resource_hints(),
self.get_type_hints())
self.get_type_hints(),
)

def with_error_handler(self, error_handler, **exception_handling_kwargs):
"""An alias for `with_exception_handling(error_handler=error_handler, ...)`
Expand Down Expand Up @@ -1984,7 +1985,8 @@ def with_exception_handling(self, main_tag=None, **kwargs):
if main_tag is None:
main_tag = self._main_tag or 'good'
named = self._do_transform.with_exception_handling(
main_tag=main_tag, **kwargs)
main_tag=main_tag, **kwargs
)
# named is _NamedPTransform wrapping _ExceptionHandlingWrapper
named.transform._extra_tags = self._tags
return named
Expand Down Expand Up @@ -2331,7 +2333,8 @@ def __init__(
on_failure_callback,
allow_unsafe_userstate_in_process,
resource_hints,
pardo_type_hints=None):
pardo_type_hints=None,
):
if partial and use_subprocess:
raise ValueError('partial and use_subprocess are mutually incompatible.')
self._fn = fn
Expand Down Expand Up @@ -2431,42 +2434,44 @@ def check_threshold(bad, total, threshold, window=DoFn.WindowParam):
return result

def expand_2_72_0(self, pcoll):
"""Pre-2.73.0 behavior: manual element_type override, no with_output_types.
"""
"""Pre-2.73.0 behavior: manual element_type override, no with_output_types."""
pardo = self._build_pardo(pcoll)
result = pcoll | pardo.with_outputs(
self._dead_letter_tag, main=self._main_tag, allow_unknown_tags=True)
#TODO(BEAM-18957): Fix when type inference supports tagged outputs.
self._dead_letter_tag, main=self._main_tag, allow_unknown_tags=True
)
# TODO(BEAM-18957): Fix when type inference supports tagged outputs.
result[self._main_tag].element_type = self._fn.infer_output_type(
pcoll.element_type)
pcoll.element_type
)

return self._post_process_result(pcoll, result)

def expand(self, pcoll):
if pcoll.pipeline.options.is_compat_version_prior_to("2.73.0"):
if pcoll.pipeline.options.is_compat_version_prior_to('2.73.0'):
return self.expand_2_72_0(pcoll)

pardo = self._build_pardo(pcoll)

if (self._pardo_type_hints and self._pardo_type_hints._has_output_types()):
if self._pardo_type_hints and self._pardo_type_hints._has_output_types():
main_output_type = self._pardo_type_hints.simple_output_type(self.label)
tagged_type_hints = dict(self._pardo_type_hints.tagged_output_types())
else:
main_output_type = self._fn.infer_output_type(pcoll.element_type)
tagged_type_hints = dict(self._fn.get_type_hints().tagged_output_types())

# Dead letter format: Tuple[element, Tuple[exception_type, repr, traceback]]
dead_letter_type = typehints.Tuple[pcoll.element_type,
typehints.Tuple[type,
str,
typehints.List[str]]]
dead_letter_type = typehints.Tuple[
pcoll.element_type,
typehints.Tuple[type[typing.Any], str, typehints.Sequence[str]],
]

tagged_type_hints[self._dead_letter_tag] = dead_letter_type
pardo = pardo.with_output_types(main_output_type, **tagged_type_hints)

all_tags = tuple(set(self._extra_tags or ()) | {self._dead_letter_tag})
result = pcoll | pardo.with_outputs(
*all_tags, main=self._main_tag, allow_unknown_tags=True)
*all_tags, main=self._main_tag, allow_unknown_tags=True
)

return self._post_process_result(pcoll, result)

Expand Down
120 changes: 80 additions & 40 deletions sdks/python/apache_beam/transforms/core_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,11 @@
import logging
import os
import tempfile
import unittest
import typing
from google3.testing.pybase import googletest
googletest.ThisTestIsUsefulWithoutCallingMain()
unittest = googletest

from typing import Iterable
from typing import Literal
from typing import TypeVar
Expand Down Expand Up @@ -520,14 +524,19 @@ class TagHint(ResourceHint):

class ExceptionHandlingWithOutputsTest(unittest.TestCase):
"""Tests for combining with_exception_handling() and with_outputs()."""

def _create_dofn_with_tagged_outputs(self):
"""A DoFn that yields tagged outputs and can raise on even numbers."""

class DoWithFailures(beam.DoFn):

def process(
self, element: int
) -> Iterable[int
| beam.pvalue.TaggedOutput[Literal['threes'], int]
| beam.pvalue.TaggedOutput[Literal['fives'], str]]:
) -> Iterable[
int
| beam.pvalue.TaggedOutput[Literal['threes'], int]
| beam.pvalue.TaggedOutput[Literal['fives'], str]
]:
if element % 2 == 0:
raise ValueError(f'Even numbers not allowed {element}')
if element % 3 == 0:
Expand All @@ -546,9 +555,10 @@ def test_with_exception_handling_then_with_outputs(self):
results = (
p
| beam.Create([1, 2, 3, 4, 5, 6, 7])
| beam.ParDo(self._create_dofn_with_tagged_outputs()).
with_exception_handling().with_outputs(
'threes', 'fives', main='main'))
| beam.ParDo(self._create_dofn_with_tagged_outputs())
.with_exception_handling()
.with_outputs('threes', 'fives', main='main')
)

assert_that(results.main, equal_to([1, 7]), 'main')
assert_that(results.threes, equal_to([3]), 'threes')
Expand All @@ -561,7 +571,8 @@ def test_with_exception_handling_then_with_outputs(self):
self.assertEqual(results.fives.element_type, str)
self.assertEqual(
results.bad.element_type,
typehints.Tuple[int, typehints.Tuple[type, str, typehints.List[str]]])
typehints.Tuple[int, typehints.Tuple[type[typing.Any], str, typehints.Sequence[str]]],
)

def test_with_outputs_then_with_exception_handling(self):
"""Direction 2: .with_outputs().with_exception_handling()"""
Expand All @@ -570,8 +581,10 @@ def test_with_outputs_then_with_exception_handling(self):
results = (
p
| beam.Create([1, 2, 3, 4, 5, 6, 7])
| beam.ParDo(self._create_dofn_with_tagged_outputs()).with_outputs(
'threes', 'fives', main='main').with_exception_handling())
| beam.ParDo(self._create_dofn_with_tagged_outputs())
.with_outputs('threes', 'fives', main='main')
.with_exception_handling()
)

assert_that(results.main, equal_to([1, 7]), 'main')
assert_that(results.threes, equal_to([3]), 'threes')
Expand All @@ -584,19 +597,22 @@ def test_with_outputs_then_with_exception_handling(self):
self.assertEqual(results.fives.element_type, str)
self.assertEqual(
results.bad.element_type,
typehints.Tuple[int, typehints.Tuple[type, str, typehints.List[str]]])
typehints.Tuple[int, typehints.Tuple[type[typing.Any], str, typehints.Sequence[str]]],
)

def test_with_outputs_then_with_exception_handling_custom_dead_letter_tag(
self):
self,
):
"""Direction 2 with custom dead_letter_tag."""

with beam.Pipeline() as p:
results = (
p
| beam.Create([1, 2, 3])
| beam.ParDo(self._create_dofn_with_tagged_outputs()).with_outputs(
'threes',
main='main').with_exception_handling(dead_letter_tag='errors'))
| beam.ParDo(self._create_dofn_with_tagged_outputs())
.with_outputs('threes', main='main')
.with_exception_handling(dead_letter_tag='errors')
)

assert_that(results.main, equal_to([1]), 'main')
assert_that(results.threes, equal_to([3]), 'threes')
Expand All @@ -605,19 +621,22 @@ def test_with_outputs_then_with_exception_handling_custom_dead_letter_tag(
self.assertEqual(results.threes.element_type, int)
self.assertEqual(
results.errors.element_type,
typehints.Tuple[int, typehints.Tuple[type, str, typehints.List[str]]])
typehints.Tuple[int, typehints.Tuple[type[typing.Any], str, typehints.Sequence[str]]],
)

def test_with_exception_handling_then_with_outputs_custom_dead_letter_tag(
self):
self,
):
"""Direction 1 with custom dead_letter_tag."""

with beam.Pipeline() as p:
results = (
p
| beam.Create([1, 2, 3])
| beam.ParDo(
self._create_dofn_with_tagged_outputs()).with_exception_handling(
dead_letter_tag='errors').with_outputs('threes', main='main'))
| beam.ParDo(self._create_dofn_with_tagged_outputs())
.with_exception_handling(dead_letter_tag='errors')
.with_outputs('threes', main='main')
)

assert_that(results.main, equal_to([1]), 'main')
assert_that(results.threes, equal_to([3]), 'threes')
Expand All @@ -626,7 +645,8 @@ def test_with_exception_handling_then_with_outputs_custom_dead_letter_tag(
self.assertEqual(results.threes.element_type, int)
self.assertEqual(
results.errors.element_type,
typehints.Tuple[int, typehints.Tuple[type, str, typehints.List[str]]])
typehints.Tuple[int, typehints.Tuple[type[typing.Any], str, typehints.Sequence[str]]],
)

def test_exception_handling_no_with_outputs_backward_compat(self):
"""Without with_outputs(), behavior is unchanged."""
Expand All @@ -635,38 +655,46 @@ def test_exception_handling_no_with_outputs_backward_compat(self):
good, bad = (
p
| beam.Create([1, 2, 7])
| beam.ParDo(self._create_dofn_with_tagged_outputs())
.with_exception_handling())
| beam.ParDo(
self._create_dofn_with_tagged_outputs()
).with_exception_handling()
)

assert_that(good, equal_to([1, 7]), 'good')
bad_elements = bad | beam.Keys()
assert_that(bad_elements, equal_to([2]), 'bad')

def test_exception_handling_compat_version_uses_old_behavior(self):
"""With compat version < 2.73.0, old expand path is used."""
options = PipelineOptions(update_compatibility_version="2.72.0")
options = PipelineOptions(update_compatibility_version='2.72.0')
with beam.Pipeline(options=options) as p:
good, bad = (
p
| beam.Create([1, 2, 7])
| beam.ParDo(self._create_dofn_with_tagged_outputs())
.with_exception_handling())
| beam.ParDo(
self._create_dofn_with_tagged_outputs()
).with_exception_handling()
)

assert_that(good, equal_to([1, 7]), 'good')
bad_elements = bad | beam.Keys()
assert_that(bad_elements, equal_to([2]), 'bad')

def test_exception_handling_compat_version_element_type_set_manually(self):
"""With compat version < 2.73.0, element_type is set via manual override
(the old behavior) rather than via with_output_types."""

options = PipelineOptions(update_compatibility_version="2.72.0")
(the old behavior) rather than via with_output_types.
"""

options = PipelineOptions(update_compatibility_version='2.72.0')
with beam.Pipeline(options=options) as p:
results = (
p
| beam.Create([1, 2, 3])
| beam.ParDo(self._create_dofn_with_tagged_outputs()).
with_exception_handling().with_outputs('threes', main='main'))
| beam.ParDo(self._create_dofn_with_tagged_outputs())
.with_exception_handling()
.with_outputs('threes', main='main')
)

# In old path, dead letter type is Any (no with_output_types call)
self.assertEqual(results.bad.element_type, typehints.Any)
Expand All @@ -682,17 +710,23 @@ def test_with_outputs_then_exception_handling_with_map(self):
results = (
p
| beam.Create([1, 2, 3, 4, 5])
| beam.Map(lambda x: x if x % 2 != 0 else 1 / 0).with_outputs(
main='main').with_exception_handling())
| beam.Map(lambda x: x if x % 2 != 0 else 1 / 0)
.with_outputs(main='main')
.with_exception_handling()
)
assert_that(results.main, equal_to([1, 3, 5]), 'main')
bad_elements = results.bad | beam.Keys()
assert_that(bad_elements, equal_to([2, 4]), 'bad')

def test_with_output_types_chained_on_pardo(self):
"""When type hints are chained on the ParDo (not annotations on the DoFn),

tagged output types should still be propagated through
with_exception_handling().with_outputs()."""
with_exception_handling().with_outputs().
"""

class DoWithFailuresNoAnnotations(beam.DoFn):

def process(self, element):
if element % 2 == 0:
raise ValueError(f'Even numbers not allowed {element}')
Expand All @@ -705,9 +739,11 @@ def process(self, element):
results = (
p
| beam.Create([1, 2, 3, 7])
| beam.ParDo(DoWithFailuresNoAnnotations()).with_output_types(
int, threes=int).with_exception_handling().with_outputs(
'threes', main='main'))
| beam.ParDo(DoWithFailuresNoAnnotations())
.with_output_types(int, threes=int)
.with_exception_handling()
.with_outputs('threes', main='main')
)

assert_that(results.main, equal_to([1, 7]), 'main')
assert_that(results.threes, equal_to([3]), 'threes')
Expand All @@ -718,16 +754,20 @@ def process(self, element):

def test_with_outputs_and_error_handler(self):
"""with_outputs() + error_handler should return DoOutputsTuple, not a
bare PCollection."""

bare PCollection.
"""
from apache_beam.transforms.error_handling import ErrorHandler

with beam.Pipeline() as p:
with ErrorHandler(beam.Map(lambda x: x)) as handler:
results = (
p
| beam.Create([1, 2, 3, 4, 5, 6, 7])
| beam.ParDo(self._create_dofn_with_tagged_outputs()).with_outputs(
'threes', 'fives',
main='main').with_exception_handling(error_handler=handler))
| beam.ParDo(self._create_dofn_with_tagged_outputs())
.with_outputs('threes', 'fives', main='main')
.with_exception_handling(error_handler=handler)
)

assert_that(results.main, equal_to([1, 7]), 'main')
assert_that(results.threes, equal_to([3]), 'threes')
Expand Down
Loading