diff --git a/app/solomon/infrastructure/database.py b/app/solomon/infrastructure/database.py index 207b21b..5a3fdd8 100644 --- a/app/solomon/infrastructure/database.py +++ b/app/solomon/infrastructure/database.py @@ -14,6 +14,34 @@ class CustomQuery(Query): + OPERATORS = { + "eq": lambda field, value: field == value, + "in": lambda field, value: field.in_(value), + "like": lambda field, value: field.like(value), + "ilike": lambda field, value: field.ilike(value), + "gt": lambda field, value: field > value, + "lt": lambda field, value: field < value, + "gte": lambda field, value: field >= value, + "lte": lambda field, value: field <= value, + } + + def apply_filters(self, model, filters): + for attribute, value in filters.items(): + if "__" in attribute: + field_name, operator_name = attribute.split("__") + operator = self.OPERATORS[operator_name] + if operator is None: + raise ValueError( + f"Invalid operator '{operator_name}' for field '{field_name}'" + ) + else: + raise ValueError(f"No operator specified for field '{attribute}'") + + field = getattr(model, field_name) + self = self.filter(operator(field, value)) + + return self + def paginate(self, params: Optional[AbstractParams]): return paginate(self, params) diff --git a/app/solomon/transactions/application/services.py b/app/solomon/transactions/application/services.py index e528cb4..57cedb1 100644 --- a/app/solomon/transactions/application/services.py +++ b/app/solomon/transactions/application/services.py @@ -23,6 +23,7 @@ PaginatedTransactionResponseMapper, Transaction, TransactionCreate, + TransactionFilters, TransactionResponseMapper, ) @@ -217,7 +218,10 @@ def get_transaction( return TransactionResponseMapper.create(transaction=transaction) def get_transactions( - self, user_id: str, params: Params = None + self, + user_id: str, + pagination_params: Params = None, + filters: TransactionFilters = None, ) -> PaginatedTransactionResponseMapper: """ Retrieve all transactions. @@ -232,8 +236,9 @@ def get_transactions( List[Transaction] A list of all transactions. """ + filters_dict = filters.model_dump(exclude_none=True) paginated_transaction = self.transaction_repository.get_all( - user_id=user_id, params=params + user_id=user_id, pagination_params=pagination_params, filters=filters_dict ) return PaginatedTransactionResponseMapper.create( diff --git a/app/solomon/transactions/domain/models.py b/app/solomon/transactions/domain/models.py index 70e4942..837d099 100644 --- a/app/solomon/transactions/domain/models.py +++ b/app/solomon/transactions/domain/models.py @@ -45,7 +45,9 @@ class Transaction(BaseModel): recurring_day = Column(Integer, nullable=True) kind = Column(String(20), nullable=False) - installments = relationship("Installment", back_populates="transaction") + installments = relationship( + "Installment", back_populates="transaction", lazy="noload" + ) user = relationship("User", back_populates="transactions") credit_card = relationship("CreditCard", back_populates="transactions") category = relationship("Category", back_populates="transactions") diff --git a/app/solomon/transactions/infrastructure/repositories.py b/app/solomon/transactions/infrastructure/repositories.py index c85e4db..a62a339 100644 --- a/app/solomon/transactions/infrastructure/repositories.py +++ b/app/solomon/transactions/infrastructure/repositories.py @@ -1,6 +1,7 @@ from typing import List from fastapi_pagination import Params +from sqlalchemy.orm import joinedload from app.solomon.common.models import PaginatedResponse from app.solomon.transactions.domain.models import ( @@ -32,9 +33,9 @@ class CreditCardRepository: def __init__(self, session): self.session = session - def get_all(self, user_id: str) -> List[CreditCard]: + def get_all(self, user_id: str, **kwargs: dict) -> List[CreditCard]: """Get all Credit Cards.""" - return self.session.query(CreditCard).filter_by(user_id=user_id).all() + return self.session.query(CreditCard).filter_by(user_id=user_id, **kwargs).all() def get_by_id(self, id: str, user_id: str) -> CreditCard | None: """Get a Credit Card by id.""" @@ -83,18 +84,27 @@ def rollback(self): """Rollback the current transaction.""" self.session.rollback() - def get_all(self, user_id: str, params: Params = None) -> PaginatedResponse: + def get_all( + self, + user_id: str, + pagination_params: Params, + filters: dict, + ) -> PaginatedResponse: """Get all transactions based on specified filters.""" - return ( - self.session.query(Transaction) - .filter(Transaction.user_id == user_id) - .paginate(params) + transactions_query = self.session.query(Transaction).filter( + Transaction.user_id == user_id ) + if filters: + transactions_query = transactions_query.apply_filters(Transaction, filters) + + return transactions_query.paginate(pagination_params) + def get_by_id(self, transaction_id: str, user_id: str) -> Transaction | None: """Get a Transaction by id.""" return ( self.session.query(Transaction) + .options(joinedload(Transaction.installments)) .filter(Transaction.id == transaction_id, Transaction.user_id == user_id) .first() ) @@ -114,4 +124,9 @@ def create_with_installments( self.session.add(transaction) self.session.commit() - return transaction + + return ( + self.session.query(Transaction) + .options(joinedload(Transaction.installments)) + .get(transaction.id) + ) diff --git a/app/solomon/transactions/presentation/models.py b/app/solomon/transactions/presentation/models.py index 0979333..37504d5 100644 --- a/app/solomon/transactions/presentation/models.py +++ b/app/solomon/transactions/presentation/models.py @@ -196,3 +196,12 @@ def create( data=[TransactionMapper.create(transaction) for transaction in items], meta=PaginationMeta(page=page, pages=pages, size=size, total=total), ) + + +class TransactionFilters(BaseModel): + date__gt: Optional[datetime.date] = None + date__lt: Optional[datetime.date] = None + category_id__eq: Optional[str] = None + kind__eq: Optional[str] = None + is_fixed__eq: Optional[bool] = None + is_revenue__eq: Optional[bool] = None diff --git a/app/solomon/transactions/presentation/transactions_resources.py b/app/solomon/transactions/presentation/transactions_resources.py index 809dc42..a47f46f 100644 --- a/app/solomon/transactions/presentation/transactions_resources.py +++ b/app/solomon/transactions/presentation/transactions_resources.py @@ -14,6 +14,7 @@ from app.solomon.transactions.presentation.models import ( PaginatedTransactionResponseMapper, TransactionCreate, + TransactionFilters, TransactionResponseMapper, ) @@ -95,6 +96,7 @@ async def get_transactions( transaction_service: TransactionService = Depends(get_transaction_service), current_user: UserTokenAuthenticated = Depends(get_current_user), pagination: Params = Depends(), + filters: TransactionFilters = Depends(), ) -> PaginatedTransactionResponseMapper: """ Retrieve all transactions. @@ -115,4 +117,4 @@ async def get_transactions( List[Transaction] The retrieved transactions. """ - return transaction_service.get_transactions(current_user.id, pagination) + return transaction_service.get_transactions(current_user.id, pagination, filters) diff --git a/app/tests/solomon/conftest.py b/app/tests/solomon/conftest.py index ac7a17e..c7ea4d6 100644 --- a/app/tests/solomon/conftest.py +++ b/app/tests/solomon/conftest.py @@ -13,6 +13,7 @@ from app.solomon.users.domain.models import User from app.tests.solomon.factories.category_factory import CategoryFactory from app.tests.solomon.factories.credit_card_factory import CreditCardFactory +from app.tests.solomon.factories.installment_factory import InstallmentFactory from app.tests.solomon.factories.transaction_factory import ( TransactionCreateFactory, TransactionFactory, @@ -102,3 +103,4 @@ def current_user(current_user_token) -> User: register(CategoryFactory) register(TransactionCreateFactory) register(TransactionFactory) +register(InstallmentFactory) diff --git a/app/tests/solomon/factories/installment_factory.py b/app/tests/solomon/factories/installment_factory.py index a30e4a8..7aa1f46 100644 --- a/app/tests/solomon/factories/installment_factory.py +++ b/app/tests/solomon/factories/installment_factory.py @@ -1,5 +1,6 @@ import factory +from app.solomon.transactions.domain.models import Installment from app.solomon.transactions.presentation.models import InstallmentCreate from app.tests.solomon.factories.base_factory import BaseFactory @@ -19,3 +20,12 @@ class Meta: model = InstallmentCreate installment_number = factory.LazyAttribute(lambda x: next(incrementing_numbers)) + + +class InstallmentFactory(BaseFactory): + class Meta: + model = Installment + + date = factory.Faker("date") + amount = factory.Faker("pydecimal", left_digits=4, right_digits=2, positive=True) + installment_number = factory.LazyAttribute(lambda x: next(incrementing_numbers)) diff --git a/app/tests/solomon/transactions/application/test_transactions_services.py b/app/tests/solomon/transactions/application/test_transactions_services.py index 4c1a2ee..fdda953 100644 --- a/app/tests/solomon/transactions/application/test_transactions_services.py +++ b/app/tests/solomon/transactions/application/test_transactions_services.py @@ -14,6 +14,7 @@ from app.solomon.transactions.domain.options import Kinds from app.solomon.transactions.presentation.models import ( PaginatedTransactionResponseMapper, + TransactionFilters, ) from app.tests.solomon.factories.credit_card_factory import CreditCardFactory from app.tests.solomon.factories.installment_factory import InstallmentCreateFactory @@ -356,6 +357,7 @@ def test_get_invalid_transaction(self, transaction_service, mock_repository): def test_get_transactions(self, transaction_service, mock_repository): user_id = "123" pagination_params = Params(page=1, size=5) + filters = TransactionFilters() mock_transactions = [ TransactionFactory.build(id=str(uuid4), category_id=str(uuid4)), TransactionFactory.build(id=str(uuid4), category_id=str(uuid4)), @@ -374,9 +376,53 @@ def test_get_transactions(self, transaction_service, mock_repository): total=3, ) - result = transaction_service.get_transactions(user_id, pagination_params) + result = transaction_service.get_transactions( + user_id, pagination_params, filters + ) + + assert result == expected_result + mock_repository.get_all.assert_called_once_with( + user_id=user_id, pagination_params=pagination_params, filters={} + ) + + def test_get_transactions_with_filters(self, transaction_service, mock_repository): + user_id = "123" + pagination_params = Params(page=1, size=5) + category_id = str(uuid4()) + filters = TransactionFilters( + is_fixed__eq=False, + kind__eq=Kinds.CREDIT.value, + category_id__eq=category_id, + ) + mock_transactions = [ + TransactionFactory.build(id=str(uuid4), category_id=str(uuid4)), + TransactionFactory.build(id=str(uuid4), category_id=str(uuid4)), + TransactionFactory.build(id=str(uuid4), category_id=str(uuid4)), + ] + + mock_repository.get_all.return_value = PaginatedResponse( + items=mock_transactions, page=1, pages=1, size=5, total=3 + ) + + expected_result = PaginatedTransactionResponseMapper.create( + items=mock_transactions, + page=1, + pages=1, + size=5, + total=3, + ) + + result = transaction_service.get_transactions( + user_id, pagination_params, filters + ) assert result == expected_result mock_repository.get_all.assert_called_once_with( - user_id=user_id, params=pagination_params + user_id=user_id, + pagination_params=pagination_params, + filters={ + "is_fixed__eq": False, + "kind__eq": Kinds.CREDIT.value, + "category_id__eq": category_id, + }, ) diff --git a/app/tests/solomon/transactions/presentation/test_transactions_resources.py b/app/tests/solomon/transactions/presentation/test_transactions_resources.py index b6895f3..8014622 100644 --- a/app/tests/solomon/transactions/presentation/test_transactions_resources.py +++ b/app/tests/solomon/transactions/presentation/test_transactions_resources.py @@ -1,10 +1,12 @@ import datetime from unittest.mock import patch +from urllib.parse import urlencode from uuid import uuid4 from fastapi.encoders import jsonable_encoder from fastapi_sqlalchemy import db +from app.solomon.transactions.domain.models import Transaction from app.solomon.transactions.domain.options import Kinds @@ -101,12 +103,18 @@ def test_create_invalid_credit_card_variable_transaction( ) assert response.status_code == 500 - def test_get_transaction(self, auth_client, current_user, transaction_factory): + def test_get_transaction( + self, auth_client, current_user, transaction_factory, installment_factory + ): with db(): - transaction = transaction_factory.create(user=current_user) + transaction = transaction_factory.create( + user=current_user, kind=Kinds.CREDIT.value, amount=300.00 + ) + installment_factory.create_batch(3, transaction=transaction) response = auth_client.get(f"/transactions/{transaction.id}/") result = response.json()["data"] + installments = result["installments"] assert response.status_code == 200 assert result["id"] == transaction.id @@ -114,7 +122,7 @@ def test_get_transaction(self, auth_client, current_user, transaction_factory): assert result["kind"] == transaction.kind assert result["is_fixed"] == transaction.is_fixed assert result["amount"] == transaction.amount - assert result["installments"] == [] + assert len(installments) == 3 def test_get_invalid_transaction(self, auth_client): with db(): @@ -152,3 +160,59 @@ def test_get_transactions_with_pagination( assert meta["total"] == 15 assert len(data) == 5 assert isinstance(data, list) + + def test_get_transactions_with_filters( + self, auth_client, category_factory, transaction_factory, current_user + ): + with db(): + food_category = category_factory.create(description="Food") + home_category = category_factory.create(description="Home") + + transaction_factory.create_batch( + 45, + user=current_user, + is_fixed=False, + kind=Kinds.CREDIT.value, + category=food_category, + date=datetime.date(2023, 5, 1), + is_revenue=False, + ) + transaction_factory.create_batch( + 23, + user=current_user, + is_fixed=True, + kind=Kinds.DEBIT.value, + category=home_category, + date=datetime.date(2023, 8, 15), + is_revenue=False, + ) + base_url = "/transactions/" + params = { + "kind__eq": "debit", + "is_fixed__eq": "true", + "category_id__eq": home_category.id, + "date__gt": "2023-08-01", + "date__lt": "2023-08-30", + "is_revenue__eq": "false", + } + url = f"{base_url}?{urlencode(params)}" + total_transactions = db.session.query(Transaction).count() + + response = auth_client.get(url) + + data = response.json()["data"] + + assert response.status_code == 200 + assert total_transactions == 68 + assert len(data) == 23 + for item in data: + assert item["kind"] == "debit" + assert item["is_fixed"] is True + assert item["category_id"] == home_category.id + assert datetime.date.fromisoformat(item["date"]) > datetime.date( + 2023, 8, 1 + ) + assert datetime.date.fromisoformat(item["date"]) < datetime.date( + 2023, 8, 30 + ) + assert item["is_revenue"] is False