Skip to content

Commit c35b7e4

Browse files
VladIftimeFlix6x
andauthored
Fixed unit test gen (#62)
* Fixed unit test gen to generate only if the tests do not already exists * fix: remove leftover breakpoint Signed-off-by: F.N. Claessen <felix@seita.nl> * refactor: simplify if statements Signed-off-by: F.N. Claessen <felix@seita.nl> --------- Signed-off-by: F.N. Claessen <felix@seita.nl> Co-authored-by: F.N. Claessen <felix@seita.nl>
1 parent ad6f255 commit c35b7e4

File tree

1 file changed

+25
-12
lines changed

1 file changed

+25
-12
lines changed

development_utilities/gen_unit_test_template.py

Lines changed: 25 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import datetime
22
import json
3+
import os
34
from enum import Enum
45
import inspect
56
import pprint
@@ -17,6 +18,7 @@
1718
import uuid
1819

1920
import pydantic
21+
from pydantic.types import AwareDatetime
2022

2123
from s2python import frbc
2224
from s2python.common import Duration, PowerRange, NumberRange
@@ -64,7 +66,7 @@ def get_list_arg(field_type):
6466

6567

6668
def is_enum(field_type):
67-
return issubclass(field_type, Enum)
69+
return inspect.isclass(field_type) and issubclass(field_type, Enum)
6870

6971

7072
def snake_case(camelcased: str) -> str:
@@ -111,17 +113,16 @@ def generate_json_test_data_for_field(field_type: Type):
111113
value = bool(random.randint(0, 1))
112114
elif field_type is float:
113115
value = random.random() * 9000.0
114-
elif field_type is datetime.datetime:
116+
elif field_type in (AwareDatetime, datetime.datetime):
117+
# Generate a timezone-aware datetime
115118
value = datetime.datetime(
116119
year=random.randint(2020, 2023),
117120
month=random.randint(1, 12),
118121
day=random.randint(1, 28),
119122
hour=random.randint(0, 23),
120123
minute=random.randint(0, 59),
121124
second=random.randint(0, 59),
122-
tzinfo=datetime.timezone(
123-
offset=datetime.timedelta(hours=random.randint(0, 2))
124-
),
125+
tzinfo=datetime.timezone(datetime.timedelta(hours=random.randint(-12, 14))),
125126
)
126127
elif field_type is uuid.UUID:
127128
value = uuid.uuid4()
@@ -167,12 +168,19 @@ def dump_test_data_as_constructor_field_for(test_data, field_type: Type) -> str:
167168
value = str(test_data)
168169
elif field_type is float:
169170
value = str(test_data)
170-
elif field_type is datetime.datetime:
171+
elif field_type is AwareDatetime or field_type is datetime.datetime:
171172
test_data: datetime.datetime
172173
offset: datetime.timedelta = test_data.tzinfo.utcoffset(None)
173-
value = f"datetime(year={test_data.year}, month={test_data.month}, day={test_data.day}, hour={test_data.hour}, minute={test_data.minute}, second={test_data.second}, tzinfo=offset(offset=timedelta(seconds={offset.total_seconds()})))"
174+
value = (
175+
f"datetime("
176+
f"year={test_data.year}, month={test_data.month}, day={test_data.day}, "
177+
f"hour={test_data.hour}, minute={test_data.minute}, second={test_data.second}, "
178+
f"tzinfo=offset(offset=timedelta(seconds={offset.total_seconds()})))"
179+
)
174180
elif field_type is uuid.UUID:
175181
value = f'uuid.UUID("{test_data}")'
182+
elif type(field_type).__name__ == "_LiteralGenericAlias":
183+
value = field_type.__args__[0]
176184
else:
177185
raise RuntimeError(
178186
f"Please implement dump test data for field type {field_type}"
@@ -217,11 +225,13 @@ def dump_test_data_as_json_field_for(test_data, field_type: Type):
217225
value = test_data
218226
elif field_type is float:
219227
value = test_data
220-
elif field_type is datetime.datetime:
228+
elif field_type in (AwareDatetime, datetime.datetime):
221229
test_data: datetime.datetime
222230
value = test_data.isoformat()
223231
elif field_type is uuid.UUID:
224232
value = str(test_data)
233+
elif type(field_type).__name__ == "_LiteralGenericAlias":
234+
value = test_data
225235
else:
226236
raise RuntimeError(
227237
f"Please implement dump test data to json for field type {field_type}"
@@ -294,7 +304,10 @@ def test__to_json__happy_path_full(self):
294304
print()
295305
print()
296306

297-
with open(
298-
f"tests/unit/frbc/{snake_case(class_name)}_test.py", "w+"
299-
) as unit_test_file:
300-
unit_test_file.write(template)
307+
# Check if the file already exists
308+
if not os.path.exists(f"tests/unit/frbc/{snake_case(class_name)}_test.py"):
309+
with open(
310+
f"tests/unit/frbc/{snake_case(class_name)}_test.py", "w+"
311+
) as unit_test_file:
312+
unit_test_file.write(template)
313+
print(f"Created tests/unit/frbc/{snake_case(class_name)}_test.py")

0 commit comments

Comments
 (0)