diff --git a/sdks/python/apache_beam/coders/coder_impl.py b/sdks/python/apache_beam/coders/coder_impl.py index 3e0b5218b166..b3e45bc7f35c 100644 --- a/sdks/python/apache_beam/coders/coder_impl.py +++ b/sdks/python/apache_beam/coders/coder_impl.py @@ -30,6 +30,7 @@ """ # pytype: skip-file +import dataclasses import decimal import enum import itertools @@ -67,11 +68,6 @@ from apache_beam.utils.timestamp import MIN_TIMESTAMP from apache_beam.utils.timestamp import Timestamp -try: - import dataclasses -except ImportError: - dataclasses = None # type: ignore - try: import dill except ImportError: diff --git a/sdks/python/apache_beam/typehints/native_type_compatibility.py b/sdks/python/apache_beam/typehints/native_type_compatibility.py index 345c04706d6f..886b1505ffec 100644 --- a/sdks/python/apache_beam/typehints/native_type_compatibility.py +++ b/sdks/python/apache_beam/typehints/native_type_compatibility.py @@ -21,6 +21,7 @@ import collections import collections.abc +import dataclasses import logging import sys import types @@ -175,6 +176,10 @@ def match_is_named_tuple(user_type): hasattr(user_type, '__annotations__') and hasattr(user_type, '_fields')) +def match_is_dataclass(user_type): + return dataclasses.is_dataclass(user_type) and isinstance(user_type, type) + + def _match_is_optional(user_type): return _match_is_union(user_type) and sum( tp is type(None) for tp in _get_args(user_type)) == 1 diff --git a/sdks/python/apache_beam/typehints/row_type.py b/sdks/python/apache_beam/typehints/row_type.py index 08838c84a050..6f96f6f64e32 100644 --- a/sdks/python/apache_beam/typehints/row_type.py +++ b/sdks/python/apache_beam/typehints/row_type.py @@ -19,6 +19,7 @@ from __future__ import annotations +import dataclasses from typing import Any from typing import Dict from typing import Optional @@ -26,6 +27,7 @@ from typing import Tuple from apache_beam.typehints import typehints +from apache_beam.typehints.native_type_compatibility import match_is_dataclass from 
apache_beam.typehints.native_type_compatibility import match_is_named_tuple from apache_beam.typehints.schema_registry import SchemaTypeRegistry @@ -56,18 +58,14 @@ def __init__( for guidance on creating PCollections with inferred schemas. Note RowTypeConstraint does not currently store arbitrary functions for - converting to/from the user type. Instead, we only support ``NamedTuple`` - user types and make the follow assumptions: + converting to/from the user type. Instead, we support ``NamedTuple`` and + ``dataclasses`` user types and make the following assumptions: - The user type can be constructed with field values as arguments in order (i.e. ``constructor(*field_values)``). - Field values can be accessed from instances of the user type by attribute (i.e. with ``getattr(obj, field_name)``). - In the future we will add support for dataclasses - ([#22085](https://github.com/apache/beam/issues/22085)) which also satisfy - these assumptions. - The RowTypeConstraint constructor should not be called directly (even internally to Beam). Prefer static methods ``from_user_type`` or ``from_fields``. 
@@ -107,27 +105,30 @@ def from_user_type( if match_is_named_tuple(user_type): fields = [(name, user_type.__annotations__[name]) for name in user_type._fields] - - field_descriptions = getattr(user_type, '_field_descriptions', None) - - if _user_type_is_generated(user_type): - return RowTypeConstraint.from_fields( - fields, - schema_id=getattr(user_type, _BEAM_SCHEMA_ID), - schema_options=schema_options, - field_options=field_options, - field_descriptions=field_descriptions) - - # TODO(https://github.com/apache/beam/issues/22125): Add user API for - # specifying schema/field options - return RowTypeConstraint( - fields=fields, - user_type=user_type, + elif match_is_dataclass(user_type): + fields = [(field.name, field.type) + for field in dataclasses.fields(user_type)] + else: + return None + + field_descriptions = getattr(user_type, '_field_descriptions', None) + + if _user_type_is_generated(user_type): + return RowTypeConstraint.from_fields( + fields, + schema_id=getattr(user_type, _BEAM_SCHEMA_ID), schema_options=schema_options, field_options=field_options, field_descriptions=field_descriptions) - return None + # TODO(https://github.com/apache/beam/issues/22125): Add user API for + # specifying schema/field options + return RowTypeConstraint( + fields=fields, + user_type=user_type, + schema_options=schema_options, + field_options=field_options, + field_descriptions=field_descriptions) @staticmethod def from_fields( diff --git a/sdks/python/apache_beam/typehints/row_type_test.py b/sdks/python/apache_beam/typehints/row_type_test.py new file mode 100644 index 000000000000..0c5da45740ac --- /dev/null +++ b/sdks/python/apache_beam/typehints/row_type_test.py @@ -0,0 +1,86 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. 
+# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +"""Unit tests for the Beam Row typing functionality.""" + +import typing +import unittest +from dataclasses import dataclass + +import apache_beam as beam +from apache_beam.testing.test_pipeline import TestPipeline +from apache_beam.testing.util import assert_that +from apache_beam.testing.util import equal_to +from apache_beam.typehints import row_type + + +class RowTypeTest(unittest.TestCase): + @staticmethod + def _check_key_type_and_count(x) -> int: + key_type = type(x[0]) + if not row_type._user_type_is_generated(key_type): + raise RuntimeError("Expect type after GBK to be generated user type") + + return len(x[1]) + + def test_group_by_key_namedtuple(self): + MyNamedTuple = typing.NamedTuple( + "MyNamedTuple", [("id", int), ("name", str)]) + + beam.coders.typecoders.registry.register_coder( + MyNamedTuple, beam.coders.RowCoder) + + def generate(num: int): + for i in range(100): + yield (MyNamedTuple(i, 'a'), num) + + pipeline = TestPipeline(is_integration_test=False) + + with pipeline as p: + result = ( + p + | 'Create' >> beam.Create([i for i in range(10)]) + | 'Generate' >> beam.ParDo(generate).with_output_types( + tuple[MyNamedTuple, int]) + | 'GBK' >> beam.GroupByKey() + | 'Count Elements' >> beam.Map(self._check_key_type_and_count)) + assert_that(result, equal_to([10] * 100)) + + def test_group_by_key_dataclass(self): + @dataclass + class MyDataClass: + id: int + name: str + + 
beam.coders.typecoders.registry.register_coder( + MyDataClass, beam.coders.RowCoder) + + def generate(num: int): + for i in range(100): + yield (MyDataClass(i, 'a'), num) + + pipeline = TestPipeline(is_integration_test=False) + + with pipeline as p: + result = ( + p + | 'Create' >> beam.Create([i for i in range(10)]) + | 'Generate' >> beam.ParDo(generate).with_output_types( + tuple[MyDataClass, int]) + | 'GBK' >> beam.GroupByKey() + | 'Count Elements' >> beam.Map(self._check_key_type_and_count)) + assert_that(result, equal_to([10] * 100)) diff --git a/sdks/python/apache_beam/typehints/schemas.py b/sdks/python/apache_beam/typehints/schemas.py index e9674fa5bc20..5dd8ff290c48 100644 --- a/sdks/python/apache_beam/typehints/schemas.py +++ b/sdks/python/apache_beam/typehints/schemas.py @@ -96,6 +96,7 @@ from apache_beam.typehints.native_type_compatibility import _safe_issubclass from apache_beam.typehints.native_type_compatibility import convert_to_python_type from apache_beam.typehints.native_type_compatibility import extract_optional_type +from apache_beam.typehints.native_type_compatibility import match_is_dataclass from apache_beam.typehints.native_type_compatibility import match_is_named_tuple from apache_beam.typehints.schema_registry import SCHEMA_REGISTRY from apache_beam.typehints.schema_registry import SchemaTypeRegistry @@ -629,7 +630,7 @@ def schema_from_element_type(element_type: type) -> schema_pb2.Schema: Returns schema as a list of (name, python_type) tuples""" if isinstance(element_type, row_type.RowTypeConstraint): return named_fields_to_schema(element_type._fields) - elif match_is_named_tuple(element_type): + elif match_is_named_tuple(element_type) or match_is_dataclass(element_type): if hasattr(element_type, row_type._BEAM_SCHEMA_ID): # if the named tuple's schema is in registry, we just use it instead of # regenerating one. 
diff --git a/sdks/python/apache_beam/typehints/schemas_test.py b/sdks/python/apache_beam/typehints/schemas_test.py index 73db06b9a8d2..5a5d7396ab30 100644 --- a/sdks/python/apache_beam/typehints/schemas_test.py +++ b/sdks/python/apache_beam/typehints/schemas_test.py @@ -19,6 +19,7 @@ # pytype: skip-file +import dataclasses import itertools import pickle import unittest @@ -388,6 +389,24 @@ def test_namedtuple_roundtrip(self, user_type): self.assertIsInstance(roundtripped, row_type.RowTypeConstraint) self.assert_namedtuple_equivalent(roundtripped.user_type, user_type) + def test_dataclass_roundtrip(self): + @dataclasses.dataclass + class SimpleDataclass: + id: np.int64 + name: str + + roundtripped = typing_from_runner_api( + typing_to_runner_api( + SimpleDataclass, schema_registry=SchemaTypeRegistry()), + schema_registry=SchemaTypeRegistry()) + + self.assertIsInstance(roundtripped, row_type.RowTypeConstraint) + # The roundtripped user_type is generated as a NamedTuple, so we can't test + # equivalence directly with the dataclass. + # Instead, let's verify annotations. 
+ self.assertEqual( + roundtripped.user_type.__annotations__, SimpleDataclass.__annotations__) + def test_row_type_constraint_to_schema(self): result_type = typing_to_runner_api( row_type.RowTypeConstraint.from_fields([ @@ -646,6 +665,48 @@ def test_trivial_example(self): expected.row_type.schema.fields, typing_to_runner_api(MyCuteClass).row_type.schema.fields) + def test_trivial_example_dataclass(self): + @dataclasses.dataclass + class MyCuteDataclass: + name: str + age: Optional[int] + interests: List[str] + height: float + blob: ByteString + + expected = schema_pb2.FieldType( + row_type=schema_pb2.RowType( + schema=schema_pb2.Schema( + fields=[ + schema_pb2.Field( + name='name', + type=schema_pb2.FieldType( + atomic_type=schema_pb2.STRING), + ), + schema_pb2.Field( + name='age', + type=schema_pb2.FieldType( + nullable=True, atomic_type=schema_pb2.INT64)), + schema_pb2.Field( + name='interests', + type=schema_pb2.FieldType( + array_type=schema_pb2.ArrayType( + element_type=schema_pb2.FieldType( + atomic_type=schema_pb2.STRING)))), + schema_pb2.Field( + name='height', + type=schema_pb2.FieldType( + atomic_type=schema_pb2.DOUBLE)), + schema_pb2.Field( + name='blob', + type=schema_pb2.FieldType( + atomic_type=schema_pb2.BYTES)), + ]))) + + self.assertEqual( + expected.row_type.schema.fields, + typing_to_runner_api(MyCuteDataclass).row_type.schema.fields) + def test_user_type_annotated_with_id_after_conversion(self): MyCuteClass = NamedTuple('MyCuteClass', [ ('name', str),