Skip to content

Commit dadad06

Browse files
committed
feat: Added pynamodb 6 support
2 parents 376792d + b7692cf commit dadad06

File tree

10 files changed

+138
-127
lines changed

10 files changed

+138
-127
lines changed

src/pynamodb_utils/attributes.py

Lines changed: 21 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,12 @@
11
import json
22
from enum import Enum
3-
from typing import Collection, FrozenSet, Union
3+
from typing import Collection, FrozenSet, Optional, Union
44

55
import six
66
from pynamodb.attributes import MapAttribute, NumberAttribute, UnicodeAttribute
7+
from pynamodb.constants import NUMBER
8+
9+
from pynamodb_utils.exceptions import EnumSerializationException
710

811

912
class DynamicMapAttribute(MapAttribute):
@@ -58,15 +61,17 @@ def __str__(self) -> str:
5861

5962

6063
class EnumNumberAttribute(NumberAttribute):
64+
attr_type = NUMBER
65+
6166
def __init__(
6267
self,
6368
enum: Enum,
64-
hash_key=False,
65-
range_key=False,
66-
null=None,
67-
default: Enum = None,
68-
default_for_new: Enum = None,
69-
attr_name=None,
69+
hash_key: bool = False,
70+
range_key: bool = False,
71+
null: Optional[bool] = None,
72+
default: Optional[Enum] = None,
73+
default_for_new: Optional[Enum] = None,
74+
attr_name: Optional[str] = None,
7075
):
7176
if isinstance(enum, Enum):
7277
raise ValueError("enum must be Enum class")
@@ -97,7 +102,7 @@ def serialize(self, value: Union[Enum, str]) -> str:
97102
f'Value Error: {value} must be in {", ".join([item for item in self.enum.__members__.keys()])}'
98103
)
99104
except TypeError as e:
100-
raise Exception(value, self.enum) from e
105+
raise EnumSerializationException(f"Error serializing {value} with enum {self.enum}") from e
101106

102107
def deserialize(self, value: str) -> str:
103108
return self.enum(int(value)).name
@@ -107,12 +112,12 @@ class EnumUnicodeAttribute(UnicodeAttribute):
107112
def __init__(
108113
self,
109114
enum: Enum,
110-
hash_key=False,
111-
range_key=False,
112-
null=None,
113-
default: Enum = None,
114-
default_for_new: Enum = None,
115-
attr_name=None,
115+
hash_key: bool = False,
116+
range_key: bool = False,
117+
null: Optional[bool] = None,
118+
default: Optional[Enum] = None,
119+
default_for_new: Optional[Enum] = None,
120+
attr_name: Optional[str] = None,
116121
):
117122
if isinstance(enum, Enum):
118123
raise ValueError("enum must be Enum class")
@@ -135,9 +140,8 @@ def __init__(
135140
def serialize(self, value: Union[Enum, str]) -> str:
136141
if isinstance(value, self.enum):
137142
return str(value.value)
138-
elif isinstance(value, str):
139-
if value in self.enum.__members__.keys():
140-
return getattr(self.enum, value).value
143+
elif isinstance(value, str) and value in self.enum.__members__.keys():
144+
return getattr(self.enum, value).value
141145
raise ValueError(
142146
f'Value Error: {value} must be in {", ".join([item for item in self.enum.__members__.keys()])}'
143147
)

src/pynamodb_utils/conditions.py

Lines changed: 35 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import operator
22
from functools import reduce
3-
from typing import Any, Callable, Dict, List, Set
3+
from typing import Any, Callable, Dict, List, Optional
44

55
from pynamodb.attributes import Attribute
66
from pynamodb.expressions.condition import Condition
@@ -12,13 +12,30 @@
1212
from pynamodb_utils.utils import get_attribute, get_available_attributes_list
1313

1414

15+
def _is_available(field_path: str, available_attributes: List, raise_exception: bool):
16+
if "." in field_path:
17+
_field_path = field_path.split(".", 1)[0] + ".*"
18+
is_available = _field_path in available_attributes
19+
else:
20+
is_available = field_path in available_attributes
21+
if not is_available and raise_exception:
22+
raise FilterError(
23+
message={
24+
field_path: [
25+
f"Parameter {field_path} does not exist."
26+
f" Choose some of available: {', '.join(available_attributes)}"
27+
]
28+
}
29+
)
30+
31+
1532
def create_model_condition(
16-
model: Model,
17-
args: Dict[str, Any],
18-
_operator: Callable = operator.and_,
19-
raise_exception: bool = True,
20-
unavailable_attributes: List[str] = []
21-
) -> Condition:
33+
model: Model,
34+
args: Dict[str, Any],
35+
_operator: Callable = operator.and_,
36+
raise_exception: bool = True,
37+
unavailable_attributes: Optional[List[str]] = None
38+
) -> Optional[Condition]:
2239
"""
2340
Function creates pynamodb conditions based on input dictionary (args)
2441
Parameters:
@@ -31,52 +48,29 @@ def create_model_condition(
3148
condition (Condition): computed pynamodb condition
3249
"""
3350
conditions_list: List[Condition] = []
34-
35-
available_attributes: Set[str] = get_available_attributes_list(
51+
available_attributes: List[str] = get_available_attributes_list(
3652
model=model,
37-
unavaiable_attrs=unavailable_attributes
53+
unavailable_attrs=unavailable_attributes
3854
)
39-
40-
key: str
41-
value: Any
4255
for key, value in args.items():
43-
array: List[str] = key.rsplit('__', 1)
56+
array: List[str] = key.rsplit("__", 1)
4457
field_path: str = array[0]
45-
operator_name: str = array[1] if len(array) > 1 and array[1] != 'not' else ''
46-
47-
if "." in field_path:
48-
_field_path = field_path.split(".", 1)[0] + ".*"
49-
is_available = _field_path in available_attributes
50-
else:
51-
52-
is_available = field_path in available_attributes
53-
54-
if operator_name.replace('not_', '') not in OPERATORS_MAPPING:
55-
raise FilterError(
56-
message={key: [f'Operator {operator_name} does not exist.'
57-
f' Choose some of available: {", ".join(OPERATORS_MAPPING.keys())}']}
58-
)
59-
if not is_available and raise_exception:
58+
operator_name: str = array[1] if len(array) > 1 and array[1] != "not" else ""
59+
if operator_name.replace("not_", "") not in OPERATORS_MAPPING:
6060
raise FilterError(
61-
message={
62-
field_path: [
63-
f"Parameter {field_path} does not exist."
64-
f' Choose some of available: {", ".join(available_attributes)}'
65-
]
66-
}
61+
message={key: [f"Operator {operator_name} does not exist."
62+
f" Choose some of available: {', '.join(OPERATORS_MAPPING.keys())}"]}
6763
)
68-
64+
_is_available(field_path, available_attributes, raise_exception)
6965
attr: Attribute = get_attribute(model, field_path)
70-
7166
if isinstance(attr, (Attribute, Path)):
7267
if 'not_' in operator_name:
73-
operator_name = operator_name.replace('not_', '')
68+
operator_name = operator_name.replace("not_", "")
7469
operator_handler = OPERATORS_MAPPING[operator_name]
7570
conditions_list.append(~operator_handler(model, field_path, attr, value))
7671
else:
7772
operator_handler = OPERATORS_MAPPING[operator_name]
7873
conditions_list.append(operator_handler(model, field_path, attr, value))
79-
if not conditions_list:
80-
return None
81-
else:
74+
if conditions_list:
8275
return reduce(_operator, conditions_list)
76+
return None

src/pynamodb_utils/exceptions.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,3 +10,11 @@ class FilterError(Error):
1010

1111
class SerializerError(Error):
1212
pass
13+
14+
15+
class IndexNotFoundError(Exception):
16+
pass
17+
18+
19+
class EnumSerializationException(Exception):
20+
pass

src/pynamodb_utils/models.py

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from datetime import timezone
2-
from typing import Any, List
2+
from typing import Any, List, Optional
33

44
from pynamodb.attributes import UTCDateTimeAttribute
55
from pynamodb.expressions.condition import Condition
@@ -20,12 +20,14 @@ def get_conditions_from_json(cls, query: dict, raise_exception: bool = True) ->
2020
2121
Parameters:
2222
query (dict): A decimal integer
23+
raise_exception (bool): Throwing an exception in case of an error
2324
2425
Returns:
2526
condition (Condition): computed pynamodb condition
2627
"""
2728
query_unavailable_attributes: List[str] = getattr(cls.Meta, "query_unavailable_attributes", [])
28-
return ConditionsSerializer(cls, query_unavailable_attributes).load(data=query, raise_exception=raise_exception)
29+
return ConditionsSerializer(cls, query_unavailable_attributes).load(data=query,
30+
raise_exception=raise_exception)
2931

3032
@classmethod
3133
def make_index_query(cls, query: dict, raise_exception: bool = True, **kwargs) -> ResultIterator[Model]:
@@ -34,6 +36,7 @@ def make_index_query(cls, query: dict, raise_exception: bool = True, **kwargs) -
3436
3537
Parameters:
3638
query (dict): A decimal integer
39+
raise_exception (bool): Throwing an exception in case of an error
3740
3841
Returns:
3942
result_iterator (result_iterator): result iterator for optimized query
@@ -69,6 +72,9 @@ def _pop_path(obj: dict, path: str) -> Any:
6972
obj = obj[key]
7073

7174

75+
TZ_INFO = "TZINFO"
76+
77+
7278
class TimestampedModel(Model):
7379
created_at = UTCDateTimeAttribute(default=get_timestamp)
7480
updated_at = UTCDateTimeAttribute(default=get_timestamp)
@@ -77,17 +83,17 @@ class TimestampedModel(Model):
7783
class Meta:
7884
abstract = True
7985

80-
def save(self, condition=None):
81-
tz_info = getattr(self.Meta, "TZINFO", None)
86+
def save(self, condition: Optional[Condition] = None, *, add_version_condition: bool = True):
87+
tz_info = getattr(self.Meta, TZ_INFO, None)
8288
self.created_at = self.created_at.astimezone(tz=tz_info or timezone.utc)
83-
self.updated_at = get_timestamp(tzinfo=tz_info)
84-
super().save(condition=condition)
89+
self.updated_at = get_timestamp(tz=tz_info)
90+
super().save(condition=condition, add_version_condition=add_version_condition)
8591

8692
def save_without_timestamp_update(self, condition=None):
8793
super().save(condition=condition)
8894

8995
def soft_delete(self, condition=None):
9096
""" Puts delete_at timestamp """
91-
tz_info = getattr(self.Meta, "TZINFO", None)
97+
tz_info = getattr(self.Meta, TZ_INFO, None)
9298
self.deleted_at = get_timestamp(tz_info)
9399
super().save(condition=condition)

src/pynamodb_utils/parsers.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -58,13 +58,13 @@ def default_list_parser(value: List[Any], field_name: str, model: Model) -> List
5858
raise FilterError(message={field_name: [f"{value} is not valid type of {field_name}."]})
5959

6060

61-
def default_dict_parser(value: Dict, field_name: str, *args) -> Dict[Any, Any]:
61+
def default_dict_parser(value: Dict, field_name: str, *args) -> Union[Dict[Any, Any], str]:
6262
if isinstance(value, (dict, NoneType)):
6363
return value
6464
elif isinstance(value, str):
6565
try:
6666
return json.dumps(value, default=str)
67-
except (ValueError, json.JSONDecodeError):
67+
except ValueError:
6868
pass
6969
raise FilterError(
7070
message={field_name: [f"{value} is not valid type of {field_name}."]}

src/pynamodb_utils/utils.py

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,12 @@
11
from datetime import datetime, timezone
2-
from typing import Any, Dict, List, Optional, Set, Tuple, Union
2+
from typing import Any, Dict, List, Optional, Tuple, Union
33

44
from pynamodb.attributes import Attribute, MapAttribute
55
from pynamodb.indexes import GlobalSecondaryIndex, LocalSecondaryIndex
66
from pynamodb.models import Model
77

88
from pynamodb_utils.attributes import DynamicMapAttribute
9+
from pynamodb_utils.exceptions import IndexNotFoundError
910

1011
NoneType = type(None)
1112

@@ -31,7 +32,7 @@ def create_index_map(
3132
).get("AttributeName")
3233
idx_map[(hash_key, range_key)] = getattr(model, k)
3334
except StopIteration as e:
34-
raise Exception("Could not find index keys") from e
35+
raise IndexNotFoundError("Could not find index keys") from e
3536

3637
return idx_map
3738

@@ -53,12 +54,14 @@ def pick_index_keys(
5354
return keys
5455

5556

56-
def parse_attr(attr: Attribute) -> Union[Dict, datetime]:
57+
def parse_attr(attr: Attribute) -> Union[Attribute, Dict, List, datetime, str]:
5758
"""
5859
Function parses attribute to corresponding values
5960
"""
6061
if isinstance(attr, DynamicMapAttribute):
6162
return attr.as_dict()
63+
elif isinstance(attr, List):
64+
return [parse_attr(el) for el in attr]
6265
elif isinstance(attr, MapAttribute):
6366
return parse_attrs_to_dict(attr)
6467
elif isinstance(attr, datetime):
@@ -85,12 +88,12 @@ def get_attributes_list(model: Model, depth: int = 0) -> List[str]:
8588
return attrs
8689

8790

88-
def get_available_attributes_list(model: Model, unavaiable_attrs: List[str] = []) -> Set[str]:
91+
def get_available_attributes_list(model: Model, unavailable_attrs: Optional[List[str]] = None) -> List[str]:
8992
attrs: List[str] = get_attributes_list(model)
90-
return [attr for attr in attrs if attr not in unavaiable_attrs]
93+
return sorted(set(attr for attr in attrs if attr not in unavailable_attrs))
9194

9295

93-
def get_attribute(model: Model, attr_string: str) -> Attribute:
96+
def get_attribute(model: Model, attr_string: str) -> Optional[Attribute]:
9497
"""
9598
Function gets nested attribute based on path (attr_string)
9699
"""
@@ -106,5 +109,5 @@ def get_attribute(model: Model, attr_string: str) -> Attribute:
106109
return result
107110

108111

109-
def get_timestamp(tzinfo: timezone = None) -> datetime:
110-
return datetime.now(tzinfo or timezone.utc)
112+
def get_timestamp(tz: timezone = None) -> datetime:
113+
return datetime.now(tz or timezone.utc)

src/tests/conftest.py

Lines changed: 11 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -3,35 +3,30 @@
33
from datetime import timezone
44

55
import pytest
6-
from moto import mock_dynamodb
6+
from moto import mock_aws
77
from pynamodb.attributes import UnicodeAttribute, UTCDateTimeAttribute
88
from pynamodb.indexes import AllProjection, GlobalSecondaryIndex
99

1010
from pynamodb_utils import AsDictModel, DynamicMapAttribute, EnumAttribute, JSONQueryModel, TimestampedModel
1111

1212

13-
@pytest.fixture(scope="session")
13+
@pytest.fixture
1414
def aws_environ():
15-
vars = {
15+
env_vars = {
1616
"AWS_DEFAULT_REGION": "us-east-1"
1717
}
18-
for k, v in vars.items():
19-
os.environ[k] = v
20-
21-
yield
22-
23-
for k in vars:
24-
del os.environ[k]
18+
with mock_aws():
19+
for k, v in env_vars.items():
20+
os.environ[k] = v
2521

26-
27-
@pytest.fixture
28-
def dynamodb(aws_environ):
29-
with mock_dynamodb():
3022
yield
3123

24+
for k in env_vars:
25+
del os.environ[k]
26+
3227

3328
@pytest.fixture
34-
def post_table(dynamodb):
29+
def post_table(aws_environ):
3530
class CategoryEnum(enum.Enum):
3631
finance = enum.auto()
3732
politics = enum.auto()
@@ -49,7 +44,7 @@ class Post(AsDictModel, JSONQueryModel, TimestampedModel):
4944
sub_name = UnicodeAttribute(range_key=True)
5045
category = EnumAttribute(enum=CategoryEnum, default=CategoryEnum.finance)
5146
content = UnicodeAttribute()
52-
tags = DynamicMapAttribute(default={})
47+
tags = DynamicMapAttribute(default=None)
5348
category_created_at_gsi = PostCategoryCreatedAtGSI()
5449
secret_parameter = UnicodeAttribute(default="secret")
5550

0 commit comments

Comments
 (0)