From 0aa835dcb2d2b037121cb75ee6dc439f1d364ad0 Mon Sep 17 00:00:00 2001 From: Yi Hu Date: Sat, 28 Feb 2026 14:12:10 -0500 Subject: [PATCH 1/4] Support inferring schemas from Python dataclasses --- sdks/python/apache_beam/coders/coder_impl.py | 6 +- .../typehints/native_type_compatibility.py | 6 ++ sdks/python/apache_beam/typehints/row_type.py | 33 ++++++++-- sdks/python/apache_beam/typehints/schemas.py | 3 +- .../apache_beam/typehints/schemas_test.py | 61 +++++++++++++++++++ 5 files changed, 97 insertions(+), 12 deletions(-) diff --git a/sdks/python/apache_beam/coders/coder_impl.py b/sdks/python/apache_beam/coders/coder_impl.py index 3e0b5218b166..b3e45bc7f35c 100644 --- a/sdks/python/apache_beam/coders/coder_impl.py +++ b/sdks/python/apache_beam/coders/coder_impl.py @@ -30,6 +30,7 @@ """ # pytype: skip-file +import dataclasses import decimal import enum import itertools @@ -67,11 +68,6 @@ from apache_beam.utils.timestamp import MIN_TIMESTAMP from apache_beam.utils.timestamp import Timestamp -try: - import dataclasses -except ImportError: - dataclasses = None # type: ignore - try: import dill except ImportError: diff --git a/sdks/python/apache_beam/typehints/native_type_compatibility.py b/sdks/python/apache_beam/typehints/native_type_compatibility.py index 345c04706d6f..bd2ddd1c017e 100644 --- a/sdks/python/apache_beam/typehints/native_type_compatibility.py +++ b/sdks/python/apache_beam/typehints/native_type_compatibility.py @@ -21,6 +21,7 @@ import collections import collections.abc +import dataclasses import logging import sys import types @@ -175,6 +176,10 @@ def match_is_named_tuple(user_type): hasattr(user_type, '__annotations__') and hasattr(user_type, '_fields')) +def match_is_dataclass(user_type): + return dataclasses.is_dataclass(user_type) and isinstance(user_type, type) + + def _match_is_optional(user_type): return _match_is_union(user_type) and sum( tp is type(None) for tp in _get_args(user_type)) == 1 @@ -418,6 +423,7 @@ def convert_to_beam_type(typ): # This MUST appear before the entry for the normal Tuple. _TypeMapEntry( match=match_is_named_tuple, arity=0, beam_type=typehints.Any), + _TypeMapEntry(match=match_is_dataclass, arity=0, beam_type=typehints.Any), _TypeMapEntry( match=_match_is_primitive(tuple), arity=-1, beam_type=typehints.Tuple), diff --git a/sdks/python/apache_beam/typehints/row_type.py b/sdks/python/apache_beam/typehints/row_type.py index 08838c84a050..579f51e83971 100644 --- a/sdks/python/apache_beam/typehints/row_type.py +++ b/sdks/python/apache_beam/typehints/row_type.py @@ -19,6 +19,7 @@ from __future__ import annotations +import dataclasses from typing import Any from typing import Dict from typing import Optional @@ -26,6 +27,7 @@ from typing import Tuple from apache_beam.typehints import typehints +from apache_beam.typehints.native_type_compatibility import match_is_dataclass from apache_beam.typehints.native_type_compatibility import match_is_named_tuple from apache_beam.typehints.schema_registry import SchemaTypeRegistry @@ -56,18 +58,14 @@ def __init__( for guidance on creating PCollections with inferred schemas. Note RowTypeConstraint does not currently store arbitrary functions for - converting to/from the user type. Instead, we only support ``NamedTuple`` - user types and make the follow assumptions: + converting to/from the user type. Instead, we support ``NamedTuple`` and + ``dataclasses`` user types and make the follow assumptions: - The user type can be constructed with field values as arguments in order (i.e. ``constructor(*field_values)``). - Field values can be accessed from instances of the user type by attribute (i.e. with ``getattr(obj, field_name)``). - In the future we will add support for dataclasses - ([#22085](https://github.com/apache/beam/issues/22085)) which also satisfy - these assumptions. - The RowTypeConstraint constructor should not be called directly (even internally to Beam). Prefer static methods ``from_user_type`` or ``from_fields``. @@ -127,6 +125,29 @@ def from_user_type( field_options=field_options, field_descriptions=field_descriptions) + if match_is_dataclass(user_type): + fields = [(field.name, field.type) + for field in dataclasses.fields(user_type)] + + field_descriptions = getattr(user_type, '_field_descriptions', None) + + if _user_type_is_generated(user_type): + return RowTypeConstraint.from_fields( + fields, + schema_id=getattr(user_type, _BEAM_SCHEMA_ID), + schema_options=schema_options, + field_options=field_options, + field_descriptions=field_descriptions) + + # TODO(https://github.com/apache/beam/issues/22125): Add user API for + # specifying schema/field options + return RowTypeConstraint( + fields=fields, + user_type=user_type, + schema_options=schema_options, + field_options=field_options, + field_descriptions=field_descriptions) + return None @staticmethod diff --git a/sdks/python/apache_beam/typehints/schemas.py b/sdks/python/apache_beam/typehints/schemas.py index e9674fa5bc20..5dd8ff290c48 100644 --- a/sdks/python/apache_beam/typehints/schemas.py +++ b/sdks/python/apache_beam/typehints/schemas.py @@ -96,6 +96,7 @@ from apache_beam.typehints.native_type_compatibility import _safe_issubclass from apache_beam.typehints.native_type_compatibility import convert_to_python_type from apache_beam.typehints.native_type_compatibility import extract_optional_type +from apache_beam.typehints.native_type_compatibility import match_is_dataclass from apache_beam.typehints.native_type_compatibility import match_is_named_tuple from apache_beam.typehints.schema_registry import SCHEMA_REGISTRY from apache_beam.typehints.schema_registry import SchemaTypeRegistry @@ -629,7 +630,7 @@ def schema_from_element_type(element_type: type) -> schema_pb2.Schema: Returns schema as a list of (name, python_type) tuples""" if isinstance(element_type, row_type.RowTypeConstraint): return named_fields_to_schema(element_type._fields) - elif match_is_named_tuple(element_type): + elif match_is_named_tuple(element_type) or match_is_dataclass(element_type): if hasattr(element_type, row_type._BEAM_SCHEMA_ID): # if the named tuple's schema is in registry, we just use it instead of # regenerating one. diff --git a/sdks/python/apache_beam/typehints/schemas_test.py b/sdks/python/apache_beam/typehints/schemas_test.py index 73db06b9a8d2..5a5d7396ab30 100644 --- a/sdks/python/apache_beam/typehints/schemas_test.py +++ b/sdks/python/apache_beam/typehints/schemas_test.py @@ -19,6 +19,7 @@ # pytype: skip-file +import dataclasses import itertools import pickle import unittest @@ -388,6 +389,24 @@ def test_namedtuple_roundtrip(self, user_type): self.assertIsInstance(roundtripped, row_type.RowTypeConstraint) self.assert_namedtuple_equivalent(roundtripped.user_type, user_type) + def test_dataclass_roundtrip(self): + @dataclasses.dataclass + class SimpleDataclass: + id: np.int64 + name: str + + roundtripped = typing_from_runner_api( + typing_to_runner_api( + SimpleDataclass, schema_registry=SchemaTypeRegistry()), + schema_registry=SchemaTypeRegistry()) + + self.assertIsInstance(roundtripped, row_type.RowTypeConstraint) + # The roundtripped user_type is generated as a NamedTuple, so we can't test + # equivalence directly with the dataclass. + # Instead, let's verify annotations. + self.assertEqual( + roundtripped.user_type.__annotations__, SimpleDataclass.__annotations__) + def test_row_type_constraint_to_schema(self): result_type = typing_to_runner_api( row_type.RowTypeConstraint.from_fields([ @@ -646,6 +665,48 @@ def test_trivial_example(self): expected.row_type.schema.fields, typing_to_runner_api(MyCuteClass).row_type.schema.fields) + def test_trivial_example_dataclass(self): + @dataclasses.dataclass + class MyCuteDataclass: + name: str + age: Optional[int] + interests: List[str] + height: float + blob: ByteString + + expected = schema_pb2.FieldType( + row_type=schema_pb2.RowType( + schema=schema_pb2.Schema( + fields=[ + schema_pb2.Field( + name='name', + type=schema_pb2.FieldType( + atomic_type=schema_pb2.STRING), + ), + schema_pb2.Field( + name='age', + type=schema_pb2.FieldType( + nullable=True, atomic_type=schema_pb2.INT64)), + schema_pb2.Field( + name='interests', + type=schema_pb2.FieldType( + array_type=schema_pb2.ArrayType( + element_type=schema_pb2.FieldType( + atomic_type=schema_pb2.STRING)))), + schema_pb2.Field( + name='height', + type=schema_pb2.FieldType( + atomic_type=schema_pb2.DOUBLE)), + schema_pb2.Field( + name='blob', + type=schema_pb2.FieldType( + atomic_type=schema_pb2.BYTES)), + ]))) + + self.assertEqual( + expected.row_type.schema.fields, + typing_to_runner_api(MyCuteDataclass).row_type.schema.fields) + def test_user_type_annotated_with_id_after_conversion(self): MyCuteClass = NamedTuple('MyCuteClass', [ ('name', str), From bece3536e2883d43f4027173167e52b69fa28a3d Mon Sep 17 00:00:00 2001 From: Yi Hu Date: Mon, 9 Mar 2026 15:36:06 -0400 Subject: [PATCH 2/4] Address comments; Revert native_type_compatibility _TypeMapEntry change --- .../typehints/native_type_compatibility.py | 1 - sdks/python/apache_beam/typehints/row_type.py | 54 ++++++------------- 2 files changed, 17 insertions(+), 38 deletions(-) diff --git a/sdks/python/apache_beam/typehints/native_type_compatibility.py b/sdks/python/apache_beam/typehints/native_type_compatibility.py index bd2ddd1c017e..886b1505ffec 100644 --- a/sdks/python/apache_beam/typehints/native_type_compatibility.py +++ b/sdks/python/apache_beam/typehints/native_type_compatibility.py @@ -423,7 +423,6 @@ def convert_to_beam_type(typ): # This MUST appear before the entry for the normal Tuple. _TypeMapEntry( match=match_is_named_tuple, arity=0, beam_type=typehints.Any), - _TypeMapEntry(match=match_is_dataclass, arity=0, beam_type=typehints.Any), _TypeMapEntry( match=_match_is_primitive(tuple), arity=-1, beam_type=typehints.Tuple), diff --git a/sdks/python/apache_beam/typehints/row_type.py b/sdks/python/apache_beam/typehints/row_type.py index 579f51e83971..6f96f6f64e32 100644 --- a/sdks/python/apache_beam/typehints/row_type.py +++ b/sdks/python/apache_beam/typehints/row_type.py @@ -105,50 +105,30 @@ def from_user_type( if match_is_named_tuple(user_type): fields = [(name, user_type.__annotations__[name]) for name in user_type._fields] - - field_descriptions = getattr(user_type, '_field_descriptions', None) - - if _user_type_is_generated(user_type): - return RowTypeConstraint.from_fields( - fields, - schema_id=getattr(user_type, _BEAM_SCHEMA_ID), - schema_options=schema_options, - field_options=field_options, - field_descriptions=field_descriptions) - - # TODO(https://github.com/apache/beam/issues/22125): Add user API for - # specifying schema/field options - return RowTypeConstraint( - fields=fields, - user_type=user_type, - schema_options=schema_options, - field_options=field_options, - field_descriptions=field_descriptions) - - if match_is_dataclass(user_type): + elif match_is_dataclass(user_type): fields = [(field.name, field.type) for field in dataclasses.fields(user_type)] + else: + return None - field_descriptions = getattr(user_type, '_field_descriptions', None) - - if _user_type_is_generated(user_type): - return RowTypeConstraint.from_fields( - fields, - schema_id=getattr(user_type, _BEAM_SCHEMA_ID), - schema_options=schema_options, - field_options=field_options, - field_descriptions=field_descriptions) - - # TODO(https://github.com/apache/beam/issues/22125): Add user API for - # specifying schema/field options - return RowTypeConstraint( - fields=fields, - user_type=user_type, + field_descriptions = getattr(user_type, '_field_descriptions', None) + + if _user_type_is_generated(user_type): + return RowTypeConstraint.from_fields( + fields, + schema_id=getattr(user_type, _BEAM_SCHEMA_ID), schema_options=schema_options, field_options=field_options, field_descriptions=field_descriptions) - return None + # TODO(https://github.com/apache/beam/issues/22125): Add user API for + # specifying schema/field options + return RowTypeConstraint( + fields=fields, + user_type=user_type, + schema_options=schema_options, + field_options=field_options, + field_descriptions=field_descriptions) @staticmethod def from_fields( From d9921bd4c3ab1f5d91601f37bfcb46f3998a698b Mon Sep 17 00:00:00 2001 From: Yi Hu Date: Mon, 9 Mar 2026 22:20:13 -0400 Subject: [PATCH 3/4] Add unit test for named tuple and dataclasses encoded by RowCoder and passing through GBK --- .../apache_beam/typehints/row_type_test.py | 86 +++++++++++++++++++ 1 file changed, 86 insertions(+) create mode 100644 sdks/python/apache_beam/typehints/row_type_test.py diff --git a/sdks/python/apache_beam/typehints/row_type_test.py b/sdks/python/apache_beam/typehints/row_type_test.py new file mode 100644 index 000000000000..49b560f86134 --- /dev/null +++ b/sdks/python/apache_beam/typehints/row_type_test.py @@ -0,0 +1,86 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +"""Unit tests for the Beam Row typing functionality.""" + +from dataclasses import dataclass +import typing +import unittest + +import apache_beam as beam +from apache_beam.testing.test_pipeline import TestPipeline +from apache_beam.testing.util import assert_that +from apache_beam.testing.util import equal_to +from apache_beam.typehints import row_type + + +class RowTypeTest(unittest.TestCase): + @staticmethod + def _check_key_type_and_count(x) -> int: + key_type = type(x[0]) + if not row_type._user_type_is_generated(key_type): + raise RuntimeError("Expect type after GBK to be generated user type") + + return len(x[1]) + + def test_group_by_key_namedtuple(self): + MyNamedTuple = typing.NamedTuple( + "MyNamedTuple", [("id", int), ("name", str)]) + + beam.coders.typecoders.registry.register_coder( + MyNamedTuple, beam.coders.RowCoder) + + def generate(num: int): + for i in range(100): + yield (MyNamedTuple(i, 'a'), num) + + pipeline = TestPipeline(is_integration_test=False) + + with pipeline as p: + result = ( + p + | 'Create' >> beam.Create([i for i in range(10)]) + | 'Generate' >> beam.ParDo(generate).with_output_types( + tuple[MyNamedTuple, int]) + | 'GBK' >> beam.GroupByKey() + | 'Count Elements' >> beam.Map(self._check_key_type_and_count)) + assert_that(result, equal_to([10] * 100)) + + def test_group_by_key_dataclass(self): + @dataclass + class MyDataClass: + id: int + name: str + + beam.coders.typecoders.registry.register_coder( + MyDataClass, beam.coders.RowCoder) + + def generate(num: int): + for i in range(100): + yield (MyDataClass(i, 'a'), num) + + pipeline = TestPipeline(is_integration_test=False) + + with pipeline as p: + result = ( + p + | 'Create' >> beam.Create([i for i in range(10)]) + | 'Generate' >> beam.ParDo(generate).with_output_types( + tuple[MyDataClass, int]) + | 'GBK' >> beam.GroupByKey() + | 'Count Elements' >> beam.Map(self._check_key_type_and_count)) + assert_that(result, equal_to([10] * 100)) From a8f585ef4287c5d253f47df98515e2f24f7344d4 Mon Sep 17 00:00:00 2001 From: Yi Hu Date: Tue, 10 Mar 2026 11:15:23 -0400 Subject: [PATCH 4/4] Fix lint --- sdks/python/apache_beam/typehints/row_type_test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sdks/python/apache_beam/typehints/row_type_test.py b/sdks/python/apache_beam/typehints/row_type_test.py index 49b560f86134..0c5da45740ac 100644 --- a/sdks/python/apache_beam/typehints/row_type_test.py +++ b/sdks/python/apache_beam/typehints/row_type_test.py @@ -17,9 +17,9 @@ """Unit tests for the Beam Row typing functionality.""" -from dataclasses import dataclass import typing import unittest +from dataclasses import dataclass import apache_beam as beam from apache_beam.testing.test_pipeline import TestPipeline