Skip to content

Commit

Permalink
Fixed bug where decorated models lose proper type hints
Browse files Browse the repository at this point in the history
  • Loading branch information
ti-oluwa committed Feb 10, 2024
1 parent 2b0b5b7 commit 28271a3
Show file tree
Hide file tree
Showing 7 changed files with 34 additions and 26 deletions.
Binary file added dist/django_utz-0.1.9-py3-none-any.whl
Binary file not shown.
Binary file added dist/django_utz-0.1.9.tar.gz
Binary file not shown.
2 changes: 1 addition & 1 deletion django_utz/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
@Author: Daniel T. Afolayan (ti-oluwa.github.io)
"""

__version__ = "0.1.8"
__version__ = "0.1.9"
__author__ = "Daniel T. Afolayan"

alias = "django-user-timezone"
Expand Down
11 changes: 6 additions & 5 deletions django_utz/decorators/models/bases.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
from abc import ABC, abstractmethod
import inspect
from typing import Any
from typing import Any, TypeVar
from django.db import models
from django.db import models

from ..bases import UTZDecorator
from .exceptions import ModelConfigurationError

DjangoModel = TypeVar("DjangoModel", bound=models.Model)


class ModelDecorator(UTZDecorator, ABC):
Expand All @@ -17,19 +18,19 @@ class ModelDecorator(UTZDecorator, ABC):
required_configs = ()
__slots__ = ("model",)

def __init__(self, model: type[models.Model]) -> None:
def __init__(self, model: DjangoModel) -> None:
self.model = self.check_model(model)
super().__init__()


def __call__(self) -> type[models.Model]:
def __call__(self) -> DjangoModel:
prepared_model = self.prepare_model()
if not issubclass(prepared_model, models.Model):
raise TypeError("prepare_model method must return a model")
return prepared_model


def check_model(self, model: type[models.Model]) -> type[models.Model]:
def check_model(self, model: DjangoModel) -> DjangoModel:
"""
Check if model and model configuration is valid. Returns the model if it is valid.
Expand All @@ -54,7 +55,7 @@ def check_model(self, model: type[models.Model]) -> type[models.Model]:


@abstractmethod
def prepare_model(self) -> type[models.Model]:
def prepare_model(self) -> DjangoModel:
"""
Prepare the model for use. This where you can customize the model.
Expand Down
21 changes: 11 additions & 10 deletions django_utz/decorators/models/decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import zoneinfo
except:
from backports import zoneinfo
from django.contrib.auth.models import AbstractUser, AbstractBaseUser, User
from django.contrib.auth.models import AbstractBaseUser
from typing import TypeVar, List
from django.core.exceptions import FieldDoesNotExist
from typing import Callable
Expand All @@ -17,7 +17,7 @@
from ..utils import is_datetime_field, is_timezone_valid, validate_timezone, transform_utz_decorator
from ...datetime import utzdatetime
from .exceptions import ModelError, ModelConfigurationError
from .bases import ModelDecorator
from .bases import ModelDecorator, DjangoModel
from .utils import get_user, is_user_model, FunctionAttribute


Expand Down Expand Up @@ -72,8 +72,9 @@ def to_local_timezone(self, _datetime: datetime.datetime) -> utzdatetime:
return utz_dt


UserModel = TypeVar("UserModel", AbstractBaseUser, AbstractUser, User)
UTZUserModel = TypeVar("UTZUserModel", AbstractBaseUser, AbstractUser, User, UserModelUTZMixin)

UserModel = TypeVar("UserModel", bound=AbstractBaseUser)
UTZUserModel = TypeVar("UTZUserModel", bound=type[AbstractBaseUser | UserModelUTZMixin])


class UserModelDecorator(ModelDecorator):
Expand All @@ -89,7 +90,7 @@ def check_model(self, model: UserModel) -> UserModel:
"""Ensures that the model in which this mixin is used is the project's user model"""
if not is_user_model(model):
raise ModelError(f"Model '{model.__name__}' is not the project's user model")
return super().check_model(model)
return super().check_model(model)


def validate_timezone_field(self, value: str) -> None:
Expand Down Expand Up @@ -118,7 +119,7 @@ class RegularModelDecorator(ModelDecorator):
all_configs = ("datetime_fields", "attribute_suffix", "use_related_user_timezone", "related_user")
required_configs = ("datetime_fields",)

def check_model(self, model: models.Model) -> models.Model:
def check_model(self, model: DjangoModel) -> DjangoModel:
model = super().check_model(model)

related_user = getattr(model.UTZMeta, "related_user", None)
Expand All @@ -127,7 +128,7 @@ def check_model(self, model: models.Model) -> models.Model:
return model


def prepare_model(self) -> type[models.Model]:
def prepare_model(self) -> DjangoModel:
if self.get_config("datetime_fields") == "__all__":
self.set_config("datetime_fields", self.get_datetime_fields(self.model))

Expand Down Expand Up @@ -160,7 +161,7 @@ def validate_use_related_user_timezone(self, value: bool) -> None:
return None


def get_datetime_fields(self, model: type[models.Model]) -> List[str]:
def get_datetime_fields(self, model: DjangoModel) -> List[str]:
"""Returns the datetime fields in the given model."""
return [field.name for field in model._meta.fields if isinstance(field, models.DateTimeField)]

Expand All @@ -185,7 +186,7 @@ def func(model_instance: models.Model) -> utzdatetime:
return func


def update_model_attrs(self, model: type[models.Model]) -> type[models.Model]:
def update_model_attrs(self, model: DjangoModel) -> DjangoModel:
"""
Updates the model with the read-only attributes for the datetime fields.
Expand All @@ -208,7 +209,7 @@ def update_model_attrs(self, model: type[models.Model]) -> type[models.Model]:

# Function-type decorator for django models

def model(model: type[models.Model]) -> type[models.Model]:
def model(model: DjangoModel) -> DjangoModel:
"""
#### `django_utz` decorator for django models.
Expand Down
17 changes: 10 additions & 7 deletions django_utz/decorators/serializers/decorators.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from rest_framework import serializers
from django.db import models
from typing import Any, Dict
from typing import Any, Dict, TypeVar
import inspect
import datetime

Expand All @@ -9,7 +9,9 @@
from ..models.exceptions import ModelError
from .exceptions import SerializerConfigurationError
from ...serializer_fields import UTZDateTimeField
from ..models.bases import DjangoModel

DRFModelSerializer = TypeVar("DRFModelSerializer", bound=serializers.ModelSerializer)


class ModelSerializerDecorator(UTZDecorator):
Expand All @@ -18,27 +20,27 @@ class ModelSerializerDecorator(UTZDecorator):
required_configs = ("auto_add_fields",)
__slots__ = ("serializer",)

def __init__(self, serializer: type[serializers.ModelSerializer]) -> None:
def __init__(self, serializer: DRFModelSerializer) -> None:
self.serializer = self.check_model_serializer(serializer)
super().__init__()


@property
def serializer_model(self) -> type[models.Model]:
def serializer_model(self) -> DjangoModel:
"""
Returns the serializer's model.
"""
return self.serializer.Meta.model


def __call__(self) -> type[models.Model]:
def __call__(self) -> DRFModelSerializer:
prepared_serializer = self.prepare_serializer()
if not issubclass(prepared_serializer, serializers.ModelSerializer):
raise TypeError("prepare_serializer method must return a model serializer")
return prepared_serializer


def check_model_serializer(self, serializer: type[serializers.ModelSerializer]) -> type[serializers.ModelSerializer]:
def check_model_serializer(self, serializer: DRFModelSerializer) -> DRFModelSerializer:
"""
Check if the model serializer is properly setup.
Expand Down Expand Up @@ -85,7 +87,7 @@ def validate_datetime_format(self, value: Any) -> None:
return None


def prepare_serializer(self) -> type[serializers.ModelSerializer]:
def prepare_serializer(self) -> DRFModelSerializer:
"""
Prepare the serializer for use. This where you can customize the serializer.
Expand Down Expand Up @@ -171,9 +173,10 @@ def get_config(self, attr: str, default: Any = None) -> Any | None:
return val



# Funtion-type decorator for `rest_framework.serializers.ModelSerializer` classes

def modelserializer(serializer: type[serializers.ModelSerializer]) -> type[serializers.ModelSerializer]:
def modelserializer(serializer: DRFModelSerializer) -> DRFModelSerializer:
"""
#### `django_utz` decorator for `reest_framework.serializers.ModelSerializer` classes.
Expand Down
9 changes: 6 additions & 3 deletions django_utz/decorators/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import zoneinfo
except:
from backports import zoneinfo
from typing import Any, Callable
from typing import Any, Callable, TypeVar
import pytz
import datetime
from django.db import models
Expand Down Expand Up @@ -125,7 +125,10 @@ def is_date_field(model: type[models.Model], field_name: str) -> bool:



def transform_utz_decorator(decorator: type[UTZDecorator]) -> Callable[[type[object]], type[object]]:
Class = TypeVar("Class", bound=type[object])


def transform_utz_decorator(decorator: type[UTZDecorator]) -> Callable[[Class], Class]:
"""
Transforms class type utz decorator to a function type decorator.
Expand All @@ -137,7 +140,7 @@ def transform_utz_decorator(decorator: type[UTZDecorator]) -> Callable[[type[obj
and returns the modified class.
"""
@functools.wraps(decorator)
def decorator_wrapper(cls: type[object]) -> type[object]:
def decorator_wrapper(cls: Class) -> Class:
"""Wrapper function that applies the utz decorator to the decorated class."""
return decorator(cls)()

Expand Down

0 comments on commit 28271a3

Please sign in to comment.