Skip to content

Commit

Permalink
bump version
Browse files Browse the repository at this point in the history
  • Loading branch information
kheina committed Jul 13, 2024
1 parent 9f4da30 commit 12bed7c
Show file tree
Hide file tree
Showing 6 changed files with 115 additions and 92 deletions.
4 changes: 2 additions & 2 deletions avrofastapi/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,10 @@


# this needs to be imported last
from fastapi.applications import routing # isort:skip
from fastapi.applications import routing # type: ignore # isort:skip


__version__: str = '0.0.4'
__version__: str = '0.0.5'


class AvroFastAPI(FastAPI) :
Expand Down
7 changes: 5 additions & 2 deletions avrofastapi/routing.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
from fastapi.datastructures import Default, DefaultPlaceholder
from fastapi.dependencies.models import Dependant
from fastapi.dependencies.utils import solve_dependencies
from fastapi.encoders import DictIntStrAny, SetIntStr
from fastapi.exceptions import RequestValidationError
from fastapi.responses import Response
from fastapi.routing import APIRoute, APIRouter, run_endpoint_function, serialize_response
Expand All @@ -27,7 +26,7 @@
from pydantic.fields import ModelField, Undefined
from starlette.datastructures import Headers
from starlette.requests import Request
from starlette.responses import JSONResponse, Response
from starlette.responses import JSONResponse
from starlette.routing import BaseRoute
from starlette.types import ASGIApp, Receive, Scope, Send

Expand All @@ -38,6 +37,10 @@
from avrofastapi.serialization import AvroDeserializer, AvroSerializer, avro_frame, read_avro_frames


SetIntStr = Set[Union[int, str]]
DictIntStrAny = Dict[Union[int, str], Any]


class CalcDict(dict) :

def __init__(self, default: Callable[[Hashable], Any]) -> None :
Expand Down
119 changes: 66 additions & 53 deletions avrofastapi/schema.py
Original file line number Diff line number Diff line change
@@ -1,49 +1,52 @@
try :
from re import _pattern_type as Pattern
from re import compile as re_compile
except ImportError :
from re import Pattern, compile as re_compile

from datetime import date, datetime, time
from decimal import Decimal
from enum import Enum
from typing import Any, Callable, Dict, Iterable, List, Optional, Set, Type, Union
from re import Pattern
from re import compile as re_compile
from types import UnionType
from typing import Any, Callable, Dict, List, Mapping, Optional, Self, Sequence, Set, Type, Union
from uuid import UUID

from avro.errors import AvroException
from pydantic import BaseModel, ConstrainedBytes, ConstrainedDecimal, ConstrainedInt, conint
from pydantic import BaseModel, conint
from pydantic.types import ConstrainedBytes, ConstrainedDecimal


AvroInt: ConstrainedInt = conint(ge=-2147483648, le=2147483647)
AvroInt: type[int] = conint(ge=-2147483648, le=2147483647)
_avro_name_format: Pattern = re_compile(r'^[A-Za-z_][A-Za-z0-9_]*$')


class AvroFloat(float) :
pass


AvroSchema: Type = Union[str, Dict[str, Union['AvroSchema', int]], List['AvroSchema']]
AvroSchema = Union[str, Mapping[str, Union['AvroSchema', list[str], int]], list['AvroSchema']]


def convert_schema(model: Type[BaseModel], error: bool = False, conversions: Dict[type, Callable[[type], AvroSchema]] = { }) -> AvroSchema :
def convert_schema(model: Type[BaseModel], error: bool = False, conversions: dict[type, Union[Callable[['AvroSchemaGenerator', type], AvroSchema], AvroSchema]] = { }) -> AvroSchema :
generator: AvroSchemaGenerator = AvroSchemaGenerator(model, error, conversions)
return generator.schema()


def _validate_avro_name(name: str) :
def _validate_avro_name(name: str) -> str :
if _avro_name_format.match(name) is None :
raise AvroException(f'{name} does not match the avro name format: names must start with [A-Za-z_] and subsequently contain only [A-Za-z0-9_]')

return name


def get_name(model: Type) -> str :
origin: Optional[Type] = getattr(model, '__origin__', None) # for types from typing library
name: str = None
name: Optional[str] = None

if origin :
if origin or (model.__class__ and model.__class__.__module__ == 'types') :
name = str(origin)
if name.startswith('typing.') :
name = name[7:]

if name.startswith('types.') :
name = name[6:]

name += '_' + '_'.join(list(map(get_name, model.__args__)))

elif issubclass(model, ConstrainedBytes) and model.__name__ == 'ConstrainedBytesValue' :
Expand All @@ -58,7 +61,7 @@ def get_name(model: Type) -> str :
return name


def _validate_avro_namespace(namespace: str, parent_namespace: str = None) :
def _validate_avro_namespace(namespace: str, parent_namespace: Optional[str] = None) :
if not all(map(_avro_name_format.match, namespace.split('.'))) :
raise AvroException(f'{namespace} does not match the avro namespace format: A namespace is a dot-separated sequence of names. names must start with [A-Za-z_] and subsequently contain only [A-Za-z0-9_]')

Expand All @@ -74,7 +77,7 @@ def _validate_avro_namespace(namespace: str, parent_namespace: str = None) :

class AvroSchemaGenerator :

def __init__(self, model: Type[BaseModel], error: bool = False, conversions: Dict[type, Callable[[type], AvroSchema]] = { }) -> None :
def __init__(self, model: Type[BaseModel], error: bool = False, conversions: dict[type, Union[Callable[[Self, type], AvroSchema], AvroSchema]] = { }) -> None :
"""
:param model: the pydantic model to generate a schema for
:param error: whether or not the model is an error
Expand All @@ -86,45 +89,43 @@ def __init__(self, model: Type[BaseModel], error: bool = False, conversions: Dic
self.namespace: str = getattr(model, '__namespace__', self.name)
_validate_avro_namespace(self.namespace)
self.error: bool = error or self.name.lower().endswith('error')
self.refs: Optional[Set[str]] = None
self._conversions: Dict[type, Callable[[type], AvroSchema]] = {
**self._conversions_,
self.refs: Set[str] = set()
self._conversions: dict[type, Union[Callable[[Self, type], AvroSchema], AvroSchema]] = {
**AvroSchemaGenerator._conversions_,
**conversions,
}


def schema(self: 'AvroSchemaGenerator') -> AvroSchema :
def schema(self: Self) -> AvroSchema :
self.refs = set()
schema: AvroSchema = self._get_type(self.model)

if isinstance(schema, dict) :
schema['namespace'] = self.namespace
schema['name'] = self.name

if schema.get('type') == 'record' and self.error :
if schema.get('type') == 'record' and self.error:
schema['type'] = 'error'

return schema


def _convert_array(self: 'AvroSchemaGenerator', model: Type[Iterable[Any]]) -> Dict[str, AvroSchema] :
object_type: AvroSchema = self._get_type(model.__args__[0])
def _convert_array(self: Self, model: Type[Sequence[Any]]) -> AvroSchema :
object_type: AvroSchema = self._get_type(model.__args__[0]) # type: ignore

# TODO: does this do anything?
if (
isinstance(object_type, dict)
and isinstance(object_type.get('type'), dict)
and object_type['type'].get('logicalType') is not None
):
object_type = object_type['type']
and object_type['type'].get('logicalType') is not None # type: ignore
) :
object_type = object_type['type'] # type: ignore

return {
'type': 'array',
'items': object_type,
}


def _convert_object(self: 'AvroSchemaGenerator', model: Type[BaseModel]) -> Dict[str, Union[str, List[AvroSchema]]] :
def _convert_object(self: Self, model: Type[BaseModel]) -> AvroSchema :
sub_namespace: Optional[str] = getattr(model, '__namespace__', None)
parent_namespace: str = self.namespace
if sub_namespace :
Expand All @@ -134,15 +135,15 @@ def _convert_object(self: 'AvroSchemaGenerator', model: Type[BaseModel]) -> Dict
fields: List[AvroSchema] = []

for name, field in model.__fields__.items() :
_validate_avro_name(name)
f: AvroSchema = { }
f['name'] = _validate_avro_name(name)
submodel = model.__annotations__[name]
f: AvroSchema = { 'name': name }

if getattr(submodel, '__origin__', None) is Union and len(submodel.__args__) == 2 and type(None) in submodel.__args__ and field.default is None :
# this is a special case where the field is nullable and the default value is null, but the actual value can be omitted from the schema
# we rearrange Optional[Type] and Union[Type, None] to Union[None, Type] so that null becomes the default type and the 'default' key is unnecessary
type_index: int = 0 if submodel.__args__.index(type(None)) else 1
f['type'] = self._get_type(Union[None, submodel.__args__[type_index]])
f['type'] = self._get_type(Union[None, submodel.__args__[type_index]]) # type: ignore

else :
f['type'] = self._get_type(submodel)
Expand Down Expand Up @@ -171,11 +172,11 @@ def _convert_object(self: 'AvroSchemaGenerator', model: Type[BaseModel]) -> Dict
return schema


def _convert_union(self: 'AvroSchemaGenerator', model: Type[Union[Any, Any]]) -> List[AvroSchema] :
def _convert_union(self: Self, model: Type[Union[Any, Any]]) -> List[AvroSchema] :
return list(map(self._get_type, model.__args__))


def _convert_enum(self: 'AvroSchemaGenerator', model: Type[Enum]) -> Dict[str, Union[str, List[str]]] :
def _convert_enum(self: Self, model: Type[Enum]) -> AvroSchema :
name: str = get_name(model)
_validate_avro_name(name)

Expand All @@ -198,14 +199,13 @@ def _convert_enum(self: 'AvroSchemaGenerator', model: Type[Enum]) -> Dict[str, U
return schema


def _convert_bytes(self: 'AvroSchemaGenerator', model: Type[ConstrainedBytes]) -> Dict[str, Union[str, int]] :
def _convert_bytes(self: Self, model: Type[ConstrainedBytes]) -> AvroSchema :
if model.min_length == model.max_length and model.max_length :
schema: Dict[str, Union[str, int]] = {
'type': 'fixed',
'name': get_name(model),
'name': _validate_avro_name(get_name(model)),
'size': model.max_length,
}
_validate_avro_name(schema['name'])

self_namespace: Optional[str] = getattr(model, '__namespace__', None)
if self_namespace :
Expand All @@ -217,24 +217,24 @@ def _convert_bytes(self: 'AvroSchemaGenerator', model: Type[ConstrainedBytes]) -
return 'bytes'


def _convert_map(self: 'AvroSchemaGenerator', model: Type[Dict[str, Any]]) -> Dict[str, AvroSchema] :
def _convert_map(self: Self, model: Type[dict[str, Any]]) -> AvroSchema :
if not hasattr(model, '__args__') :
raise AvroException('typing.Dict must be used to determine key/value type, not dict')

if model.__args__[0] != str :
if model.__args__[0] is not str : # type: ignore
raise AvroException('maps must have string keys')

return {
'type': 'map',
'values': self._get_type(model.__args__[1]),
'values': self._get_type(model.__args__[1]), # type: ignore
}


def _convert_decimal(self: 'AvroSchemaGenerator', model: Type[Decimal]) -> None :
def _convert_decimal(self: Self, _: Type[Decimal]) -> AvroSchema :
raise AvroException('Support for unconstrained decimals is not possible due to the nature of avro decimals. please use pydantic.condecimal(max_digits=int, decimal_places=int)')


def _convert_condecimal(self: 'AvroSchemaGenerator', model: Type[ConstrainedDecimal]) -> Dict[str, Union[str, int]] :
def _convert_condecimal(self: Self, model: Type[ConstrainedDecimal]) -> AvroSchema :
if not model.max_digits or model.decimal_places is None :
raise AvroException('Decimal attributes max_digits and decimal_places must be provided in order to map to avro decimals')

Expand All @@ -245,10 +245,8 @@ def _convert_condecimal(self: 'AvroSchemaGenerator', model: Type[ConstrainedDeci
'scale': model.decimal_places,
}


_conversions_ = {
_conversions_: dict[type, Union[Callable[[Self, type], AvroSchema], AvroSchema]] = {
BaseModel: _convert_object,
Union: _convert_union,
list: _convert_array,
Enum: _convert_enum,
ConstrainedBytes: _convert_bytes,
Expand Down Expand Up @@ -282,29 +280,44 @@ def _convert_condecimal(self: 'AvroSchemaGenerator', model: Type[ConstrainedDeci
},
}


def _get_type(self: 'AvroSchemaGenerator', model: Type[BaseModel]) -> AvroSchema :
def _get_type(self: Self, model: type) -> AvroSchema :
name: str = get_name(model)

if name in self.refs :
return name

origin: Optional[Type] = getattr(model, '__origin__', None)

if origin in self._conversions :
if origin and origin in self._conversions :
# none of these can be converted without funcs
schema: AvroSchema = self._conversions[origin](self, model)
schema: AvroSchema = self._conversions[origin](self, model) # type: ignore
if isinstance(schema, dict) and 'name' in schema :
assert isinstance(schema['name'], str)
self.refs.add(schema['name'])
return schema

clss: Optional[Type] = getattr(model, '__class__', None)

if clss and clss in self._conversions :
# none of these can be converted without funcs
schema: AvroSchema = self._conversions[clss](self, model) # type: ignore
if isinstance(schema, dict) and 'name' in schema :
assert isinstance(schema['name'], str)
self.refs.add(schema['name'])
return schema

for cls in getattr(model, '__mro__', []) :
if cls in self._conversions :
if isinstance(self._conversions[cls], Callable) :
schema: AvroSchema = self._conversions[cls](self, model)
if callable(self._conversions[cls]) :
schema: AvroSchema = self._conversions[cls](self, model) # type: ignore
if 'name' in schema :
self.refs.add(schema['name'])
self.refs.add(schema['name']) # type: ignore
return schema
return self._conversions[cls]
return self._conversions[cls] # type: ignore

raise NotImplementedError(f'{model} missing from conversion map.')


# I didn't want to ignore the whole definition above, so assign it down here
AvroSchemaGenerator._conversions_[Union] = AvroSchemaGenerator._convert_union # type: ignore
AvroSchemaGenerator._conversions_[UnionType] = AvroSchemaGenerator._convert_union # type: ignore
Loading

0 comments on commit 12bed7c

Please sign in to comment.