diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 00000000..993b4237 --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,15 @@ +repos: + - repo: https://github.com/psf/black + rev: 24.8.0 # Use the latest stable version of Black + hooks: + - id: black + + - repo: https://github.com/pycqa/isort + rev: 5.13.2 + hooks: + - id: isort + args: [--profile, black] + - repo: https://github.com/hhatto/autopep8 + rev: v2.3.1 # Use the latest stable version of autopep8 + hooks: + - id: autopep8 diff --git a/backend/app/auth/__init__.py b/backend/app/auth/__init__.py index fada7103..6f00fa6e 100644 --- a/backend/app/auth/__init__.py +++ b/backend/app/auth/__init__.py @@ -1,7 +1,13 @@ # Auth module from .routes import router -from .service import auth_service -from .security import verify_token, create_access_token from .schemas import UserResponse +from .security import create_access_token, verify_token +from .service import auth_service -__all__ = ["router", "auth_service", "verify_token", "create_access_token", "UserResponse"] +__all__ = [ + "router", + "auth_service", + "verify_token", + "create_access_token", + "UserResponse", +] diff --git a/backend/app/auth/routes.py b/backend/app/auth/routes.py index 8dbb7ee7..433c0a04 100644 --- a/backend/app/auth/routes.py +++ b/backend/app/auth/routes.py @@ -1,19 +1,33 @@ -from fastapi import APIRouter, HTTPException, status, Depends +from datetime import timedelta + from app.auth.schemas import ( - EmailSignupRequest, EmailLoginRequest, GoogleLoginRequest, - RefreshTokenRequest, PasswordResetRequest, PasswordResetConfirm, - TokenVerifyRequest, AuthResponse, TokenResponse, SuccessResponse, - UserResponse, ErrorResponse + AuthResponse, + EmailLoginRequest, + EmailSignupRequest, + ErrorResponse, + GoogleLoginRequest, + PasswordResetConfirm, + PasswordResetRequest, + RefreshTokenRequest, + SuccessResponse, + TokenResponse, + TokenVerifyRequest, + UserResponse, ) +from app.auth.security import create_access_token, oauth2_scheme # Import oauth2_scheme from app.auth.service import auth_service -from app.auth.security import create_access_token, oauth2_scheme # Import oauth2_scheme -from fastapi.security import OAuth2PasswordRequestForm # Import OAuth2PasswordRequestForm -from datetime import timedelta from app.config import settings +from fastapi import APIRouter, Depends, HTTPException, status +from fastapi.security import ( # Import OAuth2PasswordRequestForm + OAuth2PasswordRequestForm, +) router = APIRouter(prefix="/auth", tags=["Authentication"]) -@router.post("/token", response_model=TokenResponse, include_in_schema=False) # include_in_schema=False to hide from docs if desired, or True to show + +@router.post( + "/token", response_model=TokenResponse, include_in_schema=False +) # include_in_schema=False to hide from docs if desired, or True to show async def login_for_access_token(form_data: OAuth2PasswordRequestForm = Depends()): """ OAuth2 compatible token login, get an access token for future requests. @@ -24,13 +38,14 @@ async def login_for_access_token(form_data: OAuth2PasswordRequestForm = Depends( # Note: OAuth2PasswordRequestForm uses 'username' field for the user identifier. # We'll treat it as email here. result = await auth_service.authenticate_user_with_email( - email=form_data.username, # form_data.username is the email - password=form_data.password + email=form_data.username, # form_data.username is the email + password=form_data.password, ) access_token = create_access_token( data={"sub": str(result["user"]["_id"])}, - expires_delta=timedelta(minutes=settings.access_token_expire_minutes) + expires_delta=timedelta( + minutes=settings.access_token_expire_minutes), ) return TokenResponse(access_token=access_token, token_type="bearer") @@ -40,233 +55,236 @@ async def login_for_access_token(form_data: OAuth2PasswordRequestForm = Depends( # It's good practice to log the exception here raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail=f"Authentication failed: {str(e)}" + detail=f"Authentication failed: {str(e)}", ) + @router.post("/signup/email", response_model=AuthResponse) async def signup_with_email(request: EmailSignupRequest): """ Registers a new user using email, password, and name, and returns authentication tokens and user information. - + Args: request: Contains the user's email, password, and name for registration. - + Returns: An AuthResponse with access token, refresh token, and user details. - + Raises: HTTPException: If registration fails or an unexpected error occurs. """ try: result = await auth_service.create_user_with_email( - email=request.email, - password=request.password, - name=request.name + email=request.email, password=request.password, name=request.name ) - + # Create access token access_token = create_access_token( data={"sub": str(result["user"]["_id"])}, - expires_delta=timedelta(minutes=settings.access_token_expire_minutes) + expires_delta=timedelta( + minutes=settings.access_token_expire_minutes), ) - + # Convert ObjectId to string for response result["user"]["_id"] = str(result["user"]["_id"]) - + return AuthResponse( access_token=access_token, refresh_token=result["refresh_token"], - user=UserResponse(**result["user"]) + user=UserResponse(**result["user"]), ) except HTTPException: raise except Exception as e: raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail=f"Registration failed: {str(e)}" + detail=f"Registration failed: {str(e)}", ) + @router.post("/login/email", response_model=AuthResponse) async def login_with_email(request: EmailLoginRequest): """ Authenticates a user using email and password credentials. - + On successful authentication, returns an access token, refresh token, and user information. Raises an HTTP 500 error if authentication fails due to an unexpected error. """ try: result = await auth_service.authenticate_user_with_email( - email=request.email, - password=request.password + email=request.email, password=request.password ) - + # Create access token access_token = create_access_token( data={"sub": str(result["user"]["_id"])}, - expires_delta=timedelta(minutes=settings.access_token_expire_minutes) + expires_delta=timedelta( + minutes=settings.access_token_expire_minutes), ) - + # Convert ObjectId to string for response result["user"]["_id"] = str(result["user"]["_id"]) - + return AuthResponse( access_token=access_token, refresh_token=result["refresh_token"], - user=UserResponse(**result["user"]) + user=UserResponse(**result["user"]), ) except HTTPException: raise except Exception as e: raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail=f"Login failed: {str(e)}" + detail=f"Login failed: {str(e)}", ) + @router.post("/login/google", response_model=AuthResponse) async def login_with_google(request: GoogleLoginRequest): """ Authenticates or registers a user using a Google OAuth ID token. - + On success, returns an access token, refresh token, and user information. Raises an HTTP 500 error if Google authentication fails. """ try: result = await auth_service.authenticate_with_google(request.id_token) - + # Create access token access_token = create_access_token( data={"sub": str(result["user"]["_id"])}, - expires_delta=timedelta(minutes=settings.access_token_expire_minutes) + expires_delta=timedelta( + minutes=settings.access_token_expire_minutes), ) - + # Convert ObjectId to string for response result["user"]["_id"] = str(result["user"]["_id"]) - + return AuthResponse( access_token=access_token, refresh_token=result["refresh_token"], - user=UserResponse(**result["user"]) + user=UserResponse(**result["user"]), ) except HTTPException: raise except Exception as e: raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail=f"Google authentication failed: {str(e)}" + detail=f"Google authentication failed: {str(e)}", ) + @router.post("/refresh", response_model=TokenResponse) async def refresh_token(request: RefreshTokenRequest): """ Refreshes JWT tokens using a valid refresh token. - + Validates the provided refresh token, issues a new access token and refresh token if valid, and returns them. Raises a 401 error if the refresh token is invalid or revoked. - + Returns: - A TokenResponse containing the new access and refresh tokens. + A TokenResponse containing the new access and refresh tokens. """ try: - new_refresh_token = await auth_service.refresh_access_token(request.refresh_token) - + new_refresh_token = await auth_service.refresh_access_token( + request.refresh_token + ) + # Get user from the new refresh token to create access token from app.database import get_database + db = get_database() - token_record = await db.refresh_tokens.find_one({ - "token": new_refresh_token, - "revoked": False - }) - + token_record = await db.refresh_tokens.find_one( + {"token": new_refresh_token, "revoked": False} + ) + if not token_record: raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, - detail="Failed to create new tokens" + detail="Failed to create new tokens", ) - # Create new access token + # Create new access token access_token = create_access_token( data={"sub": str(token_record["user_id"])}, - expires_delta=timedelta(minutes=settings.access_token_expire_minutes) - ) - - return TokenResponse( - access_token=access_token, - refresh_token=new_refresh_token + expires_delta=timedelta( + minutes=settings.access_token_expire_minutes), ) + + return TokenResponse(access_token=access_token, refresh_token=new_refresh_token) except HTTPException: raise except Exception as e: raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail=f"Token refresh failed: {str(e)}" + detail=f"Token refresh failed: {str(e)}", ) + @router.post("/token/verify", response_model=UserResponse) async def verify_token(request: TokenVerifyRequest): """ Verifies an access token and returns the associated user information. - + Raises: HTTPException: If the token is invalid or expired, returns a 401 Unauthorized error. """ try: user = await auth_service.verify_access_token(request.access_token) - + # Convert ObjectId to string for response user["_id"] = str(user["_id"]) - + return UserResponse(**user) except HTTPException: raise except Exception as e: raise HTTPException( - status_code=status.HTTP_401_UNAUTHORIZED, - detail="Invalid or expired token" + status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid or expired token" ) + @router.post("/password/reset/request", response_model=SuccessResponse) async def request_password_reset(request: PasswordResetRequest): """ Initiates a password reset process by sending a reset link to the provided email address. - + Returns: SuccessResponse: Indicates whether the password reset email was sent if the email exists. """ try: await auth_service.request_password_reset(request.email) return SuccessResponse( - success=True, - message="If the email exists, a reset link has been sent" + success=True, message="If the email exists, a reset link has been sent" ) except Exception as e: raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail=f"Password reset request failed: {str(e)}" + detail=f"Password reset request failed: {str(e)}", ) + @router.post("/password/reset/confirm", response_model=SuccessResponse) async def confirm_password_reset(request: PasswordResetConfirm): """ Resets a user's password using a valid password reset token. - + Args: request: Contains the password reset token and the new password. - + Returns: SuccessResponse indicating the password has been reset successfully. - + Raises: HTTPException: If the reset token is invalid or an error occurs during the reset process. """ try: await auth_service.confirm_password_reset( - reset_token=request.reset_token, - new_password=request.new_password + reset_token=request.reset_token, new_password=request.new_password ) return SuccessResponse( - success=True, - message="Password has been reset successfully" + success=True, message="Password has been reset successfully" ) except HTTPException: raise except Exception as e: raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail=f"Password reset failed: {str(e)}" + detail=f"Password reset failed: {str(e)}", ) diff --git a/backend/app/auth/schemas.py b/backend/app/auth/schemas.py index b2e8e595..3c76dc3e 100644 --- a/backend/app/auth/schemas.py +++ b/backend/app/auth/schemas.py @@ -1,6 +1,8 @@ -from pydantic import BaseModel, EmailStr, Field, ConfigDict -from typing import Optional from datetime import datetime +from typing import Optional + +from pydantic import BaseModel, ConfigDict, EmailStr, Field + # Request Models class EmailSignupRequest(BaseModel): @@ -8,26 +10,33 @@ class EmailSignupRequest(BaseModel): password: str = Field(..., min_length=6) name: str = Field(..., min_length=1) + class EmailLoginRequest(BaseModel): email: EmailStr password: str + class GoogleLoginRequest(BaseModel): id_token: str + class RefreshTokenRequest(BaseModel): refresh_token: str + class PasswordResetRequest(BaseModel): email: EmailStr + class PasswordResetConfirm(BaseModel): reset_token: str new_password: str = Field(..., min_length=6) + class TokenVerifyRequest(BaseModel): access_token: str + # Response Models class UserResponse(BaseModel): id: str = Field(alias="_id") @@ -39,18 +48,22 @@ class UserResponse(BaseModel): model_config = ConfigDict(populate_by_name=True) + class AuthResponse(BaseModel): access_token: str refresh_token: str user: UserResponse + class TokenResponse(BaseModel): access_token: str refresh_token: Optional[str] = None + class SuccessResponse(BaseModel): success: bool = True message: Optional[str] = None + class ErrorResponse(BaseModel): error: str diff --git a/backend/app/auth/security.py b/backend/app/auth/security.py index 53f1475e..85bfb749f 100644 --- a/backend/app/auth/security.py +++ b/backend/app/auth/security.py @@ -1,56 +1,64 @@ +import secrets from datetime import datetime, timedelta, timezone -from typing import Optional, Dict, Any +from typing import Any, Dict, Optional + +from app.config import settings +from fastapi import Depends, HTTPException, status +from fastapi.security import OAuth2PasswordBearer from jose import JWTError, jwt from passlib.context import CryptContext -from fastapi import HTTPException, status, Depends -from fastapi.security import OAuth2PasswordBearer -from app.config import settings -import secrets # Password hashing with better bcrypt configuration try: pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto") except Exception: # Fallback for bcrypt version compatibility issues - pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto", bcrypt__rounds=12) + pwd_context = CryptContext( + schemes=["bcrypt"], deprecated="auto", bcrypt__rounds=12) + +oauth2_scheme = OAuth2PasswordBearer( + tokenUrl="/auth/token") # Updated tokenUrl -oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/auth/token") # Updated tokenUrl def verify_password(plain_password: str, hashed_password: str) -> bool: """ Verifies whether a plaintext password matches a given hashed password. - + Args: plain_password: The plaintext password to verify. hashed_password: The hashed password to compare against. - + Returns: True if the plaintext password matches the hash, otherwise False. """ return pwd_context.verify(plain_password, hashed_password) + def get_password_hash(password: str) -> str: """ Hashes a plaintext password using bcrypt. - + Args: password: The plaintext password to hash. - + Returns: The bcrypt-hashed password as a string. """ return pwd_context.hash(password) -def create_access_token(data: Dict[str, Any], expires_delta: Optional[timedelta] = None) -> str: + +def create_access_token( + data: Dict[str, Any], expires_delta: Optional[timedelta] = None +) -> str: """ Creates a JWT access token embedding the provided data and an expiration time. - + If `expires_delta` is not specified, the token expires after the default duration from settings. The payload includes an expiration timestamp and a type field set to "access". The token is signed using the configured secret key and algorithm. - + Args: data: The payload to include in the token. expires_delta: Optional timedelta specifying how long the token is valid. - + Returns: A signed JWT access token as a string. """ @@ -58,30 +66,38 @@ def create_access_token(data: Dict[str, Any], expires_delta: Optional[timedelta] if expires_delta: expire = datetime.now(timezone.utc) + expires_delta else: - expire = datetime.now(timezone.utc) + timedelta(minutes=settings.access_token_expire_minutes) - + expire = datetime.now(timezone.utc) + timedelta( + minutes=settings.access_token_expire_minutes + ) + to_encode.update({"exp": expire, "type": "access"}) - encoded_jwt = jwt.encode(to_encode, settings.secret_key, algorithm=settings.algorithm) + encoded_jwt = jwt.encode( + to_encode, settings.secret_key, algorithm=settings.algorithm + ) return encoded_jwt + def create_refresh_token() -> str: """ Generates a secure random refresh token as a URL-safe string. - + Returns: A cryptographically secure, URL-safe refresh token string. """ return secrets.token_urlsafe(32) + def verify_token(token: str) -> Dict[str, Any]: """ Verifies and decodes a JWT token. - + If the token is invalid or cannot be verified, raises an HTTP 401 Unauthorized exception. Returns the decoded token payload as a dictionary. """ try: - payload = jwt.decode(token, settings.secret_key, algorithms=[settings.algorithm]) + payload = jwt.decode( + token, settings.secret_key, algorithms=[settings.algorithm] + ) return payload except JWTError: raise HTTPException( @@ -90,30 +106,35 @@ def verify_token(token: str) -> Dict[str, Any]: headers={"WWW-Authenticate": "Bearer"}, ) + def generate_reset_token() -> str: """ Generates a secure, URL-safe token for password reset operations. - + Returns: A random 32-byte URL-safe string suitable for use as a password reset token. """ return secrets.token_urlsafe(32) + def get_current_user(token: str = Depends(oauth2_scheme)) -> Dict[str, Any]: """ Retrieves the current user based on the provided JWT token using centralized verification. - + Args: token: The JWT token from which to extract the user information. - + Returns: A dictionary containing the current user's information. - + Raises: HTTPException: If the token is invalid or user information cannot be extracted. """ - payload = verify_token(token) # Centralized JWT validation and error handling + payload = verify_token( + token) # Centralized JWT validation and error handling user_id = payload.get("sub") if user_id is None: - raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid token payload") + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid token payload" + ) return {"_id": user_id} diff --git a/backend/app/auth/service.py b/backend/app/auth/service.py index 9d3fe558..fcd46993 100644 --- a/backend/app/auth/service.py +++ b/backend/app/auth/service.py @@ -1,93 +1,115 @@ +import json +import os from datetime import datetime, timedelta, timezone -from typing import Optional, Dict, Any -from pymongo.errors import DuplicateKeyError +from typing import Any, Dict, Optional + +import firebase_admin +from app.auth.schemas import UserResponse +from app.auth.security import ( + create_refresh_token, + generate_reset_token, + get_password_hash, + verify_password, +) +from app.config import logger, settings +from app.database import get_database from bson import ObjectId from fastapi import HTTPException, status -from app.database import get_database -from app.auth.security import get_password_hash, verify_password, create_refresh_token, generate_reset_token -from app.auth.schemas import UserResponse -import firebase_admin -from firebase_admin import auth as firebase_auth, credentials -from app.config import settings,logger -import os -import json +from firebase_admin import auth as firebase_auth +from firebase_admin import credentials +from pymongo.errors import DuplicateKeyError # Initialize Firebase Admin SDK if not firebase_admin._apps: # First, check if we have credentials in environment variables - if all([ - settings.firebase_type, - settings.firebase_project_id, - settings.firebase_private_key_id, - settings.firebase_private_key, - settings.firebase_client_email - ]): + if all( + [ + settings.firebase_type, + settings.firebase_project_id, + settings.firebase_private_key_id, + settings.firebase_private_key, + settings.firebase_client_email, + ] + ): # Create a credential dictionary from environment variables cred_dict = { "type": settings.firebase_type, "project_id": settings.firebase_project_id, "private_key_id": settings.firebase_private_key_id, - "private_key": settings.firebase_private_key.replace("\\n", "\n"), # Replace escaped newlines + "private_key": settings.firebase_private_key.replace( + "\\n", "\n" + ), # Replace escaped newlines "client_email": settings.firebase_client_email, "client_id": settings.firebase_client_id, "auth_uri": settings.firebase_auth_uri, "token_uri": settings.firebase_token_uri, "auth_provider_x509_cert_url": settings.firebase_auth_provider_x509_cert_url, - "client_x509_cert_url": settings.firebase_client_x509_cert_url + "client_x509_cert_url": settings.firebase_client_x509_cert_url, } cred = credentials.Certificate(cred_dict) - firebase_admin.initialize_app(cred, { - 'projectId': settings.firebase_project_id, - }) - logger.info("Firebase initialized with credentials from environment variables") + firebase_admin.initialize_app( + cred, + { + "projectId": settings.firebase_project_id, + }, + ) + logger.info( + "Firebase initialized with credentials from environment variables") # Fall back to service account JSON file if env vars are not available elif os.path.exists(settings.firebase_service_account_path): cred = credentials.Certificate(settings.firebase_service_account_path) - firebase_admin.initialize_app(cred, { - 'projectId': settings.firebase_project_id, - }) + firebase_admin.initialize_app( + cred, + { + "projectId": settings.firebase_project_id, + }, + ) logger.info("Firebase initialized with service account file") else: - logger.warning("Firebase service account not found. Google auth will not work.") + logger.warning( + "Firebase service account not found. Google auth will not work.") + class AuthService: def __init__(self): # Initializes the AuthService instance. pass - + def get_db(self): """ Returns a database connection instance from the application's database module. """ return get_database() - async def create_user_with_email(self, email: str, password: str, name: str) -> Dict[str, Any]: + async def create_user_with_email( + self, email: str, password: str, name: str + ) -> Dict[str, Any]: """ Creates a new user account with the provided email, password, and name. - + Checks for existing users with the same email and raises an error if found. Stores the user with a hashed password and default profile fields, then generates and returns a refresh token along with the user data. - + Args: email: The user's email address. password: The user's plaintext password. name: The user's display name. - + Returns: A dictionary containing the created user document and a refresh token. - + Raises: HTTPException: If a user with the given email already exists. """ db = self.get_db() - + # Check if user already exists existing_user = await db.users.find_one({"email": email}) if existing_user: raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, - detail="User with this email already exists" + detail="User with this email already exists", ) - + # Create user document user_doc = { "email": email, @@ -97,86 +119,84 @@ async def create_user_with_email(self, email: str, password: str, name: str) -> "currency": "USD", "created_at": datetime.now(timezone.utc), "auth_provider": "email", - "firebase_uid": None + "firebase_uid": None, } - + try: result = await db.users.insert_one(user_doc) user_doc["_id"] = str(result.inserted_id) - + # Create refresh token - refresh_token = await self._create_refresh_token_record(str(result.inserted_id)) - - return { - "user": user_doc, - "refresh_token": refresh_token - } + refresh_token = await self._create_refresh_token_record( + str(result.inserted_id) + ) + + return {"user": user_doc, "refresh_token": refresh_token} except DuplicateKeyError: raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, - detail="User with this email already exists" + detail="User with this email already exists", ) - async def authenticate_user_with_email(self, email: str, password: str) -> Dict[str, Any]: + async def authenticate_user_with_email( + self, email: str, password: str + ) -> Dict[str, Any]: """ Authenticates a user using email and password credentials. - + Verifies the provided email and password against stored user data. If authentication succeeds, returns the user information and a new refresh token. Raises an HTTP 401 error if credentials are invalid. - + Returns: A dictionary containing the authenticated user and a new refresh token. """ db = self.get_db() - + user = await db.users.find_one({"email": email}) if not user or not verify_password(password, user.get("hashed_password", "")): raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, - detail="Incorrect email or password" + detail="Incorrect email or password", ) - + # Create new refresh token refresh_token = await self._create_refresh_token_record(str(user["_id"])) - - return { - "user": user, - "refresh_token": refresh_token - } + + return {"user": user, "refresh_token": refresh_token} async def authenticate_with_google(self, id_token: str) -> Dict[str, Any]: """ Authenticates a user using a Google OAuth ID token, creating a new user if necessary. - + Verifies the provided Firebase ID token, retrieves or creates the corresponding user in the database, updates user information if needed, and issues a new refresh token. Raises an HTTP 400 error if the email is missing or if authentication fails, and HTTP 401 if the token is invalid. - + Args: id_token: The Firebase ID token obtained from Google OAuth. - + Returns: A dictionary containing the user data and a new refresh token. """ try: # Verify the Firebase ID token decoded_token = firebase_auth.verify_id_token(id_token) - firebase_uid = decoded_token['uid'] - email = decoded_token.get('email') - name = decoded_token.get('name', email.split('@')[0] if email else 'User') - picture = decoded_token.get('picture') - + firebase_uid = decoded_token["uid"] + email = decoded_token.get("email") + name = decoded_token.get( + "name", email.split("@")[0] if email else "User") + picture = decoded_token.get("picture") + if not email: raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, - detail="Email not provided by Google" + detail="Email not provided by Google", ) - + db = self.get_db() - + # Check if user exists - user = await db.users.find_one({"$or": [ - {"email": email}, - {"firebase_uid": firebase_uid} - ]}) - + user = await db.users.find_one( + {"$or": [{"email": email}, {"firebase_uid": firebase_uid}]} + ) + if user: # Update user info if needed update_data = {} @@ -184,11 +204,10 @@ async def authenticate_with_google(self, id_token: str) -> Dict[str, Any]: update_data["firebase_uid"] = firebase_uid if user.get("avatar") != picture and picture: update_data["avatar"] = picture - + if update_data: await db.users.update_one( - {"_id": user["_id"]}, - {"$set": update_data} + {"_id": user["_id"]}, {"$set": update_data} ) user.update(update_data) else: @@ -201,222 +220,229 @@ async def authenticate_with_google(self, id_token: str) -> Dict[str, Any]: "created_at": datetime.now(timezone.utc), "auth_provider": "google", "firebase_uid": firebase_uid, - "hashed_password": None + "hashed_password": None, } - + result = await db.users.insert_one(user_doc) user_doc["_id"] = result.inserted_id user = user_doc - + # Create refresh token refresh_token = await self._create_refresh_token_record(str(user["_id"])) - - return { - "user": user, - "refresh_token": refresh_token - } - + + return {"user": user, "refresh_token": refresh_token} + except firebase_auth.InvalidIdTokenError: raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, - detail="Invalid Google ID token" + detail="Invalid Google ID token", ) except Exception as e: raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, - detail=f"Google authentication failed: {str(e)}" + detail=f"Google authentication failed: {str(e)}", ) async def refresh_access_token(self, refresh_token: str) -> str: """ Refreshes an access token by validating and rotating the provided refresh token. - + If the refresh token is valid and not expired, issues a new refresh token and revokes the old one. Raises an HTTP 401 error if the token is invalid, expired, or the associated user does not exist. - + Args: refresh_token: The refresh token string to validate and rotate. - + Returns: A new refresh token string. """ db = self.get_db() - + # Find and validate refresh token - token_record = await db.refresh_tokens.find_one({ - "token": refresh_token, - "revoked": False, - "expires_at": {"$gt": datetime.now(timezone.utc)} - }) - + token_record = await db.refresh_tokens.find_one( + { + "token": refresh_token, + "revoked": False, + "expires_at": {"$gt": datetime.now(timezone.utc)}, + } + ) + if not token_record: raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, - detail="Invalid or expired refresh token" + detail="Invalid or expired refresh token", ) - + # Get user user = await db.users.find_one({"_id": token_record["user_id"]}) if not user: raise HTTPException( - status_code=status.HTTP_401_UNAUTHORIZED, - detail="User not found" + status_code=status.HTTP_401_UNAUTHORIZED, detail="User not found" ) - + # Create new refresh token (token rotation) new_refresh_token = await self._create_refresh_token_record(str(user["_id"])) - + # Revoke old token await db.refresh_tokens.update_one( - {"_id": token_record["_id"]}, - {"$set": {"revoked": True}} + {"_id": token_record["_id"]}, {"$set": {"revoked": True}} ) - - return new_refresh_token + + return new_refresh_token + async def verify_access_token(self, token: str) -> Dict[str, Any]: """ Verifies an access token and retrieves the associated user. - + Args: token: The JWT access token to verify. - + Returns: The user document corresponding to the token's subject. - + Raises: HTTPException: If the token is invalid or the user does not exist. """ from app.auth.security import verify_token - + payload = verify_token(token) user_id = payload.get("sub") - + if not user_id: raise HTTPException( - status_code=status.HTTP_401_UNAUTHORIZED, - detail="Invalid token" + status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid token" ) - + db = self.get_db() user = await db.users.find_one({"_id": user_id}) - + if not user: raise HTTPException( - status_code=status.HTTP_401_UNAUTHORIZED, - detail="User not found" + status_code=status.HTTP_401_UNAUTHORIZED, detail="User not found" ) - + return user async def request_password_reset(self, email: str) -> bool: """ Initiates a password reset process for the specified email address. - + If the user exists, generates a password reset token with a 1-hour expiration and stores it in the database. The reset token and link are logged for development purposes. Always returns True to avoid revealing whether the email is registered. """ db = self.get_db() - + user = await db.users.find_one({"email": email}) if not user: # Don't reveal if email exists or not return True - + # Generate reset token reset_token = generate_reset_token() - reset_expires = datetime.now(timezone.utc) + timedelta(hours=1) # 1 hour expiry - + reset_expires = datetime.now( + timezone.utc) + timedelta(hours=1) # 1 hour expiry + # Store reset token - await db.password_resets.insert_one({ - "user_id": user["_id"], - "token": reset_token, - "expires_at": reset_expires, - "used": False, - "created_at": datetime.utcnow() - }) - + await db.password_resets.insert_one( + { + "user_id": user["_id"], + "token": reset_token, + "expires_at": reset_expires, + "used": False, + "created_at": datetime.utcnow(), + } + ) + # For development/free tier: just log the reset token # In production, you would send this via email logger.info(f"Password reset token for {email}: {reset_token[:6]}") - logger.info(f"Reset link: https://yourapp.com/reset-password?token={reset_token}") - + logger.info( + f"Reset link: https://yourapp.com/reset-password?token={reset_token}" + ) + return True async def confirm_password_reset(self, reset_token: str, new_password: str) -> bool: """ Confirms a password reset using a valid reset token and updates the user's password. - + Validates the reset token, updates the user's password, marks the token as used, and revokes all existing refresh tokens for the user to require re-authentication. - + Args: reset_token: The password reset token to validate. new_password: The new password to set for the user. - + Returns: True if the password reset is successful. - + Raises: HTTPException: If the reset token is invalid or expired. """ db = self.get_db() - + # Find and validate reset token - reset_record = await db.password_resets.find_one({ - "token": reset_token, - "used": False, - "expires_at": {"$gt": datetime.now(timezone.utc)} - }) - + reset_record = await db.password_resets.find_one( + { + "token": reset_token, + "used": False, + "expires_at": {"$gt": datetime.now(timezone.utc)}, + } + ) + if not reset_record: raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, - detail="Invalid or expired reset token" + detail="Invalid or expired reset token", ) - + # Update user password new_hash = get_password_hash(new_password) await db.users.update_one( - {"_id": reset_record["user_id"]}, - {"$set": {"hashed_password": new_hash}} + {"_id": reset_record["user_id"]}, { + "$set": {"hashed_password": new_hash}} ) - + # Mark token as used await db.password_resets.update_one( - {"_id": reset_record["_id"]}, - {"$set": {"used": True}} + {"_id": reset_record["_id"]}, {"$set": {"used": True}} ) - + # Revoke all refresh tokens for this user (force re-login) await db.refresh_tokens.update_many( - {"user_id": reset_record["user_id"]}, - {"$set": {"revoked": True}} + {"user_id": reset_record["user_id"]}, {"$set": {"revoked": True}} ) - - return True + + return True + async def _create_refresh_token_record(self, user_id: str) -> str: """ Generates and stores a new refresh token for the specified user. - + Creates a refresh token with an expiration date and saves it in the database for token management and rotation. - + Args: user_id: The unique identifier of the user for whom the refresh token is created. - + Returns: The generated refresh token string. """ db = self.get_db() - + refresh_token = create_refresh_token() - expires_at = datetime.now(timezone.utc) + timedelta(days=settings.refresh_token_expire_days) - - await db.refresh_tokens.insert_one({ - "token": refresh_token, - "user_id": ObjectId(user_id) if isinstance(user_id, str) else user_id, - "expires_at": expires_at, - "revoked": False, - "created_at": datetime.now(timezone.utc) - }) - + expires_at = datetime.now(timezone.utc) + timedelta( + days=settings.refresh_token_expire_days + ) + + await db.refresh_tokens.insert_one( + { + "token": refresh_token, + "user_id": ObjectId(user_id) if isinstance(user_id, str) else user_id, + "expires_at": expires_at, + "revoked": False, + "created_at": datetime.now(timezone.utc), + } + ) + return refresh_token + # Create service instance auth_service = AuthService() diff --git a/backend/app/config.py b/backend/app/config.py index 9debb879..3ee6f809 100644 --- a/backend/app/config.py +++ b/backend/app/config.py @@ -1,23 +1,25 @@ -import os -from pydantic_settings import BaseSettings -from typing import Optional import logging +import os +import time from logging.config import dictConfig +from typing import Optional + +from pydantic_settings import BaseSettings from starlette.middleware.base import BaseHTTPMiddleware from starlette.requests import Request -import time + class Settings(BaseSettings): # Database mongodb_url: str = "mongodb://localhost:27017" database_name: str = "splitwiser" - + # JWT secret_key: str = "your-super-secret-jwt-key-change-this-in-production" algorithm: str = "HS256" access_token_expire_minutes: int = 15 refresh_token_expire_days: int = 30 - # Firebase + # Firebase firebase_project_id: Optional[str] = None firebase_service_account_path: str = "./firebase-service-account.json" # Firebase service account credentials as environment variables @@ -30,20 +32,23 @@ class Settings(BaseSettings): firebase_token_uri: Optional[str] = None firebase_auth_provider_x509_cert_url: Optional[str] = None firebase_client_x509_cert_url: Optional[str] = None - + # App debug: bool = False - + # CORS - Add your frontend domain here for production - allowed_origins: str = "http://localhost:3000,http://localhost:5173,http://127.0.0.1:3000,http://localhost:8081" + allowed_origins: str = ( + "http://localhost:3000,http://localhost:5173,http://127.0.0.1:3000,http://localhost:8081" + ) allow_all_origins: bool = False class Config: env_file = ".env" + settings = Settings() -#centralized logging config +# centralized logging config LOGGING_CONFIG = { "version": 1, "disable_existing_loggers": False, @@ -67,17 +72,20 @@ class Config: dictConfig(LOGGING_CONFIG) logger = logging.getLogger("splitwiser") + class RequestResponseLoggingMiddleware(BaseHTTPMiddleware): async def dispatch(self, request: Request, call_next): logger = logging.getLogger("splitwiser") - + logger.info(f"Incoming request: {request.method} {request.url}") - + start_time = time.time() response = await call_next(request) process_time = time.time() - start_time - - logger.info(f"Response status: {response.status_code} for {request.method} {request.url}") + + logger.info( + f"Response status: {response.status_code} for {request.method} {request.url}" + ) logger.info(f"Response time: {process_time:.2f} seconds") - - return response \ No newline at end of file + + return response diff --git a/backend/app/database.py b/backend/app/database.py index ceef4d93..f298c178 100644 --- a/backend/app/database.py +++ b/backend/app/database.py @@ -1,26 +1,30 @@ +from app.config import logger, settings from motor.motor_asyncio import AsyncIOMotorClient -from app.config import settings,logger + class MongoDB: client: AsyncIOMotorClient = None database = None + mongodb = MongoDB() + async def connect_to_mongo(): """ Initializes an asynchronous connection to MongoDB and sets the active database. - + Establishes a connection using the configured MongoDB URL and selects the database specified in the application settings. """ mongodb.client = AsyncIOMotorClient(settings.mongodb_url) mongodb.database = mongodb.client[settings.database_name] logger.info("Connected to MongoDB") + async def close_mongo_connection(): """ Closes the MongoDB client connection if it is currently open. - + This function safely terminates the connection to the MongoDB server by closing the existing client instance. """ @@ -28,10 +32,11 @@ async def close_mongo_connection(): mongodb.client.close() logger.info("Disconnected from MongoDB") + def get_database(): """ Returns the current MongoDB database instance. - + Use this function to access the active database connection managed by the module. """ return mongodb.database diff --git a/backend/app/dependencies.py b/backend/app/dependencies.py index 6032515f..38877af6 100644 --- a/backend/app/dependencies.py +++ b/backend/app/dependencies.py @@ -1,18 +1,22 @@ -from fastapi import Depends, HTTPException, status -from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials +from typing import Any, Dict + from app.auth.security import verify_token from app.database import get_database from bson import ObjectId -from typing import Dict, Any +from fastapi import Depends, HTTPException, status +from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer security = HTTPBearer() -async def get_current_user(credentials: HTTPAuthorizationCredentials = Depends(security)) -> Dict[str, Any]: + +async def get_current_user( + credentials: HTTPAuthorizationCredentials = Depends(security), +) -> Dict[str, Any]: """ Retrieves the currently authenticated user based on a JWT token from the HTTP Authorization header. - + Verifies the provided JWT token, extracts the user ID, and fetches the corresponding user document from the database. Raises an HTTP 401 Unauthorized error if the token is invalid, the user ID is missing, or the user does not exist. - + Returns: A dictionary representing the authenticated user, with the `_id` field as a string. """ @@ -20,31 +24,31 @@ async def get_current_user(credentials: HTTPAuthorizationCredentials = Depends(s # Verify token payload = verify_token(credentials.credentials) user_id = payload.get("sub") - - if not user_id: + + if not user_id: raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid authentication credentials", headers={"WWW-Authenticate": "Bearer"}, ) - + # Get user from database db = get_database() user = await db.users.find_one({"_id": ObjectId(user_id)}) - + if not user: raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, detail="User not found", headers={"WWW-Authenticate": "Bearer"}, ) - + # Create a copy of the user document to avoid mutating the original user_copy = dict(user) # Convert ObjectId to string user_copy["_id"] = str(user_copy["_id"]) return user_copy - + except HTTPException: raise except Exception as e: diff --git a/backend/app/expenses/routes.py b/backend/app/expenses/routes.py index 226620b1..3277f43c 100644 --- a/backend/app/expenses/routes.py +++ b/backend/app/expenses/routes.py @@ -1,38 +1,67 @@ -from fastapi import APIRouter, Depends, HTTPException, status, Query, UploadFile, File, Response -from fastapi.responses import StreamingResponse +import io +import uuid +from datetime import datetime, timedelta +from typing import Any, Dict, List, Optional + +from app.auth.security import get_current_user +from app.config import logger from app.expenses.schemas import ( - ExpenseCreateRequest, ExpenseCreateResponse, ExpenseListResponse, ExpenseResponse, - ExpenseUpdateRequest, SettlementCreateRequest, Settlement, SettlementUpdateRequest, - SettlementListResponse, OptimizedSettlementsResponse, FriendsBalanceResponse, - BalanceSummaryResponse, UserBalance, ExpenseAnalytics, AttachmentUploadResponse + AttachmentUploadResponse, + BalanceSummaryResponse, + ExpenseAnalytics, + ExpenseCreateRequest, + ExpenseCreateResponse, + ExpenseListResponse, + ExpenseResponse, + ExpenseUpdateRequest, + FriendsBalanceResponse, + OptimizedSettlementsResponse, + Settlement, + SettlementCreateRequest, + SettlementListResponse, + SettlementUpdateRequest, + UserBalance, ) from app.expenses.service import expense_service -from app.auth.security import get_current_user -from app.config import logger -from typing import Dict, Any, List, Optional -from datetime import datetime, timedelta -import io -import uuid +from fastapi import ( + APIRouter, + Depends, + File, + HTTPException, + Query, + Response, + UploadFile, + status, +) +from fastapi.responses import StreamingResponse router = APIRouter(prefix="/groups/{group_id}", tags=["Expenses"]) # Expense CRUD Operations -@router.post("/expenses", response_model=ExpenseCreateResponse, status_code=status.HTTP_201_CREATED) + +@router.post( + "/expenses", + response_model=ExpenseCreateResponse, + status_code=status.HTTP_201_CREATED, +) async def create_expense( group_id: str, expense_data: ExpenseCreateRequest, - current_user: Dict[str, Any] = Depends(get_current_user) + current_user: Dict[str, Any] = Depends(get_current_user), ): """Create a new expense within a group""" try: - result = await expense_service.create_expense(group_id, expense_data, current_user["_id"]) + result = await expense_service.create_expense( + group_id, expense_data, current_user["_id"] + ) return result except ValueError as e: raise HTTPException(status_code=400, detail=str(e)) except Exception as e: raise HTTPException(status_code=500, detail="Failed to create expense") + @router.get("/expenses", response_model=ExpenseListResponse) async def list_group_expenses( group_id: str, @@ -41,7 +70,7 @@ async def list_group_expenses( from_date: Optional[datetime] = Query(None, alias="from"), to_date: Optional[datetime] = Query(None, alias="to"), tags: Optional[str] = Query(None), - current_user: Dict[str, Any] = Depends(get_current_user) + current_user: Dict[str, Any] = Depends(get_current_user), ): """List all expenses for a group with pagination and filtering""" try: @@ -55,47 +84,58 @@ async def list_group_expenses( except Exception as e: raise HTTPException(status_code=500, detail="Failed to fetch expenses") + @router.get("/expenses/{expense_id}") async def get_single_expense( group_id: str, expense_id: str, - current_user: Dict[str, Any] = Depends(get_current_user) + current_user: Dict[str, Any] = Depends(get_current_user), ): """Retrieve details for a single expense""" try: - result = await expense_service.get_expense_by_id(group_id, expense_id, current_user["_id"]) + result = await expense_service.get_expense_by_id( + group_id, expense_id, current_user["_id"] + ) return result except ValueError as e: raise HTTPException(status_code=404, detail=str(e)) except Exception as e: raise HTTPException(status_code=500, detail="Failed to fetch expense") + @router.patch("/expenses/{expense_id}", response_model=ExpenseResponse) async def update_expense( group_id: str, expense_id: str, updates: ExpenseUpdateRequest, - current_user: Dict[str, Any] = Depends(get_current_user) + current_user: Dict[str, Any] = Depends(get_current_user), ): """Update an existing expense""" try: - result = await expense_service.update_expense(group_id, expense_id, updates, current_user["_id"]) + result = await expense_service.update_expense( + group_id, expense_id, updates, current_user["_id"] + ) return result except ValueError as e: raise HTTPException(status_code=400, detail=str(e)) except Exception as e: - logger.error(f"Error updating expense: {str(e)}",exc_info=True) - raise HTTPException(status_code=500, detail=f"Failed to update expense: {str(e)}") + logger.error(f"Error updating expense: {str(e)}", exc_info=True) + raise HTTPException( + status_code=500, detail=f"Failed to update expense: {str(e)}" + ) + @router.delete("/expenses/{expense_id}") async def delete_expense( group_id: str, expense_id: str, - current_user: Dict[str, Any] = Depends(get_current_user) + current_user: Dict[str, Any] = Depends(get_current_user), ): """Delete an expense""" try: - success = await expense_service.delete_expense(group_id, expense_id, current_user["_id"]) + success = await expense_service.delete_expense( + group_id, expense_id, current_user["_id"] + ) if not success: raise HTTPException(status_code=404, detail="Expense not found") return {"success": True} @@ -104,76 +144,96 @@ async def delete_expense( except Exception as e: raise HTTPException(status_code=500, detail="Failed to delete expense") + # Attachment Handling -@router.post("/expenses/{expense_id}/attachments", response_model=AttachmentUploadResponse, status_code=status.HTTP_201_CREATED) + +@router.post( + "/expenses/{expense_id}/attachments", + response_model=AttachmentUploadResponse, + status_code=status.HTTP_201_CREATED, +) async def upload_attachment_for_expense( group_id: str, expense_id: str, file: UploadFile = File(...), - current_user: Dict[str, Any] = Depends(get_current_user) + current_user: Dict[str, Any] = Depends(get_current_user), ): """Upload attachment for an expense""" try: # Verify user has access to the expense - await expense_service.get_expense_by_id(group_id, expense_id, current_user["_id"]) - + await expense_service.get_expense_by_id( + group_id, expense_id, current_user["_id"] + ) + # Generate unique key for the attachment - file_extension = file.filename.split(".")[-1] if "." in file.filename else "" + file_extension = file.filename.split( + ".")[-1] if "." in file.filename else "" attachment_key = f"{expense_id}_{uuid.uuid4().hex}.{file_extension}" - + # In a real implementation, you would upload to cloud storage (S3, etc.) # For now, we'll simulate this file_content = await file.read() - + # Store file metadata (in practice, store the actual file and return the URL) url = f"https://storage.example.com/attachments/{attachment_key}" - - return AttachmentUploadResponse( - attachment_key=attachment_key, - url=url - ) + + return AttachmentUploadResponse(attachment_key=attachment_key, url=url) except ValueError as e: raise HTTPException(status_code=404, detail=str(e)) except Exception as e: - raise HTTPException(status_code=500, detail="Failed to upload attachment") + raise HTTPException( + status_code=500, detail="Failed to upload attachment") + @router.get("/expenses/{expense_id}/attachments/{key}") async def get_attachment( group_id: str, expense_id: str, key: str, - current_user: Dict[str, Any] = Depends(get_current_user) + current_user: Dict[str, Any] = Depends(get_current_user), ): """Get/download an attachment""" try: # Verify user has access to the expense - await expense_service.get_expense_by_id(group_id, expense_id, current_user["_id"]) - + await expense_service.get_expense_by_id( + group_id, expense_id, current_user["_id"] + ) + # In a real implementation, you would fetch from cloud storage # For now, we'll return a placeholder response - raise HTTPException(status_code=501, detail="Attachment download not implemented") + raise HTTPException( + status_code=501, detail="Attachment download not implemented" + ) except ValueError as e: raise HTTPException(status_code=404, detail=str(e)) except Exception as e: raise HTTPException(status_code=500, detail="Failed to get attachment") + # Settlement Management -@router.post("/settlements", response_model=Settlement, status_code=status.HTTP_201_CREATED) + +@router.post( + "/settlements", response_model=Settlement, status_code=status.HTTP_201_CREATED +) async def manually_record_payment( group_id: str, settlement_data: SettlementCreateRequest, - current_user: Dict[str, Any] = Depends(get_current_user) + current_user: Dict[str, Any] = Depends(get_current_user), ): """Manually record a payment settlement between users in a group""" try: - result = await expense_service.create_manual_settlement(group_id, settlement_data, current_user["_id"]) + result = await expense_service.create_manual_settlement( + group_id, settlement_data, current_user["_id"] + ) return result except ValueError as e: raise HTTPException(status_code=400, detail=str(e)) except Exception as e: - raise HTTPException(status_code=500, detail="Failed to record settlement") + raise HTTPException( + status_code=500, detail="Failed to record settlement") + @router.get("/settlements", response_model=SettlementListResponse) async def get_group_settlements( @@ -181,8 +241,10 @@ async def get_group_settlements( status_filter: Optional[str] = Query(None, alias="status"), page: int = Query(1, ge=1), limit: int = Query(50, ge=1, le=100), - algorithm: str = Query("advanced", description="Settlement algorithm: 'normal' or 'advanced'"), - current_user: Dict[str, Any] = Depends(get_current_user) + algorithm: str = Query( + "advanced", description="Settlement algorithm: 'normal' or 'advanced'" + ), + current_user: Dict[str, Any] = Depends(get_current_user), ): """Retrieve pending and optimized settlements for a group""" try: @@ -192,16 +254,23 @@ async def get_group_settlements( ) # Get optimized settlements - optimized_settlements = await expense_service.calculate_optimized_settlements(group_id, algorithm) + optimized_settlements = await expense_service.calculate_optimized_settlements( + group_id, algorithm + ) # Calculate summary from app.database import mongodb - total_pending_result = await mongodb.database.settlements.aggregate([ - {"$match": {"groupId": group_id, "status": "pending"}}, - {"$group": {"_id": None, "totalPending": {"$sum": "$amount"}}} - ]).to_list(None) - - total_pending = total_pending_result[0]["totalPending"] if total_pending_result else 0 + + total_pending_result = await mongodb.database.settlements.aggregate( + [ + {"$match": {"groupId": group_id, "status": "pending"}}, + {"$group": {"_id": None, "totalPending": {"$sum": "$amount"}}}, + ] + ).to_list(None) + + total_pending = ( + total_pending_result[0]["totalPending"] if total_pending_result else 0 + ) return SettlementListResponse( settlements=settlements_result["settlements"], @@ -209,41 +278,47 @@ async def get_group_settlements( summary={ "totalPending": total_pending, "transactionCount": len(settlements_result["settlements"]), - "optimizedCount": len(optimized_settlements) + "optimizedCount": len(optimized_settlements), }, pagination={ "currentPage": page, "totalPages": (settlements_result["total"] + limit - 1) // limit, "totalItems": settlements_result["total"], - "limit": limit - } + "limit": limit, + }, ) except ValueError as e: raise HTTPException(status_code=404, detail=str(e)) except Exception as e: - raise HTTPException(status_code=500, detail="Failed to fetch settlements") + raise HTTPException( + status_code=500, detail="Failed to fetch settlements") + @router.get("/settlements/{settlement_id}", response_model=Settlement) async def get_single_settlement( group_id: str, settlement_id: str, - current_user: Dict[str, Any] = Depends(get_current_user) + current_user: Dict[str, Any] = Depends(get_current_user), ): """Retrieve details for a single settlement""" try: - settlement = await expense_service.get_settlement_by_id(group_id, settlement_id, current_user["_id"]) + settlement = await expense_service.get_settlement_by_id( + group_id, settlement_id, current_user["_id"] + ) return settlement except ValueError as e: raise HTTPException(status_code=404, detail=str(e)) except Exception as e: - raise HTTPException(status_code=500, detail="Failed to fetch settlement") + raise HTTPException( + status_code=500, detail="Failed to fetch settlement") + @router.patch("/settlements/{settlement_id}", response_model=Settlement) async def mark_settlement_as_paid( group_id: str, settlement_id: str, updates: SettlementUpdateRequest, - current_user: Dict[str, Any] = Depends(get_current_user) + current_user: Dict[str, Any] = Depends(get_current_user), ): """Mark a settlement as paid""" try: @@ -254,65 +329,80 @@ async def mark_settlement_as_paid( except ValueError as e: raise HTTPException(status_code=404, detail=str(e)) except Exception as e: - raise HTTPException(status_code=500, detail="Failed to update settlement") + raise HTTPException( + status_code=500, detail="Failed to update settlement") + @router.delete("/settlements/{settlement_id}") async def delete_settlement( group_id: str, settlement_id: str, - current_user: Dict[str, Any] = Depends(get_current_user) + current_user: Dict[str, Any] = Depends(get_current_user), ): """Delete/undo a recorded settlement""" try: - success = await expense_service.delete_settlement(group_id, settlement_id, current_user["_id"]) + success = await expense_service.delete_settlement( + group_id, settlement_id, current_user["_id"] + ) if not success: raise HTTPException(status_code=404, detail="Settlement not found") - return { - "success": True, - "message": "Settlement record deleted successfully." - } + return {"success": True, "message": "Settlement record deleted successfully."} except ValueError as e: raise HTTPException(status_code=404, detail=str(e)) except Exception as e: - raise HTTPException(status_code=500, detail="Failed to delete settlement") + raise HTTPException( + status_code=500, detail="Failed to delete settlement") + @router.post("/settlements/optimize", response_model=OptimizedSettlementsResponse) async def calculate_optimized_settlements( group_id: str, - algorithm: str = Query("advanced", description="Settlement algorithm: 'normal' or 'advanced'"), - current_user: Dict[str, Any] = Depends(get_current_user) + algorithm: str = Query( + "advanced", description="Settlement algorithm: 'normal' or 'advanced'" + ), + current_user: Dict[str, Any] = Depends(get_current_user), ): """Calculate and return optimized (simplified) settlements for a group""" try: - optimized_settlements = await expense_service.calculate_optimized_settlements(group_id, algorithm) - + optimized_settlements = await expense_service.calculate_optimized_settlements( + group_id, algorithm + ) + # Calculate savings from app.database import mongodb - total_settlements = await mongodb.database.settlements.count_documents({ - "groupId": group_id, - "status": "pending" - }) - + + total_settlements = await mongodb.database.settlements.count_documents( + {"groupId": group_id, "status": "pending"} + ) + optimized_count = len(optimized_settlements) - reduction_percentage = ((total_settlements - optimized_count) / total_settlements * 100) if total_settlements > 0 else 0 + reduction_percentage = ( + ((total_settlements - optimized_count) / total_settlements * 100) + if total_settlements > 0 + else 0 + ) return OptimizedSettlementsResponse( optimizedSettlements=optimized_settlements, savings={ "originalTransactions": total_settlements, "optimizedTransactions": optimized_count, - "reductionPercentage": round(reduction_percentage, 1) - } + "reductionPercentage": round(reduction_percentage, 1), + }, ) except Exception as e: - raise HTTPException(status_code=500, detail="Failed to calculate optimized settlements") + raise HTTPException( + status_code=500, detail="Failed to calculate optimized settlements" + ) + # User Balance Endpoints # These endpoints are defined at the root level in a separate router balance_router = APIRouter(prefix="/users/me", tags=["User Balance"]) + @balance_router.get("/friends-balance", response_model=FriendsBalanceResponse) async def get_cross_group_friend_balances( current_user: Dict[str, Any] = Depends(get_current_user) @@ -322,7 +412,9 @@ async def get_cross_group_friend_balances( result = await expense_service.get_friends_balance_summary(current_user["_id"]) return FriendsBalanceResponse(**result) except Exception as e: - raise HTTPException(status_code=500, detail="Failed to fetch friends balance") + raise HTTPException( + status_code=500, detail="Failed to fetch friends balance") + @balance_router.get("/balance-summary", response_model=BalanceSummaryResponse) async def get_overall_user_balance_summary( @@ -333,68 +425,81 @@ async def get_overall_user_balance_summary( result = await expense_service.get_overall_balance_summary(current_user["_id"]) return BalanceSummaryResponse(**result) except Exception as e: - raise HTTPException(status_code=500, detail="Failed to fetch balance summary") + raise HTTPException( + status_code=500, detail="Failed to fetch balance summary") + # Group-specific user balance @router.get("/users/{user_id}/balance", response_model=UserBalance) async def get_user_balance_in_specific_group( group_id: str, user_id: str, - current_user: Dict[str, Any] = Depends(get_current_user) + current_user: Dict[str, Any] = Depends(get_current_user), ): """Get a specific user's balance within a particular group""" try: - result = await expense_service.get_user_balance_in_group(group_id, user_id, current_user["_id"]) + result = await expense_service.get_user_balance_in_group( + group_id, user_id, current_user["_id"] + ) return UserBalance(**result) except ValueError as e: raise HTTPException(status_code=404, detail=str(e)) except Exception as e: - raise HTTPException(status_code=500, detail="Failed to fetch user balance") + raise HTTPException( + status_code=500, detail="Failed to fetch user balance") + # Analytics @router.get("/analytics", response_model=ExpenseAnalytics) async def group_expense_analytics( group_id: str, - period: str = Query("month", description="Analytics period: 'week', 'month', 'year'"), + period: str = Query( + "month", description="Analytics period: 'week', 'month', 'year'" + ), year: int = Query(...), month: Optional[int] = Query(None), - current_user: Dict[str, Any] = Depends(get_current_user) + current_user: Dict[str, Any] = Depends(get_current_user), ): """Provide expense analytics for a group""" try: - result = await expense_service.get_group_analytics(group_id, current_user["_id"], period, year, month) + result = await expense_service.get_group_analytics( + group_id, current_user["_id"], period, year, month + ) return ExpenseAnalytics(**result) except ValueError as e: raise HTTPException(status_code=404, detail=str(e)) except Exception as e: - raise HTTPException(status_code=500, detail="Failed to fetch analytics") + raise HTTPException( + status_code=500, detail="Failed to fetch analytics") + # Debug endpoint (remove in production) @router.get("/expenses/{expense_id}/debug") async def debug_expense( group_id: str, expense_id: str, - current_user: Dict[str, Any] = Depends(get_current_user) + current_user: Dict[str, Any] = Depends(get_current_user), ): """Debug endpoint to check expense details and user permissions""" try: from app.database import mongodb from bson import ObjectId - + # Check if expense exists - expense = await mongodb.database.expenses.find_one({"_id": ObjectId(expense_id)}) + expense = await mongodb.database.expenses.find_one( + {"_id": ObjectId(expense_id)} + ) if not expense: return {"error": "Expense not found", "expense_id": expense_id} - + # Check group membership - group = await mongodb.database.groups.find_one({ - "_id": ObjectId(group_id), - "members.userId": current_user["_id"] - }) - + group = await mongodb.database.groups.find_one( + {"_id": ObjectId(group_id), "members.userId": current_user["_id"]} + ) + # Check if user created the expense user_created = expense.get("createdBy") == current_user["_id"] - + return { "expense_exists": True, "expense_id": expense_id, @@ -410,8 +515,8 @@ async def debug_expense( "amount": expense.get("amount"), "splits_count": len(expense.get("splits", [])), "created_at": expense.get("createdAt"), - "updated_at": expense.get("updatedAt") - } + "updated_at": expense.get("updatedAt"), + }, } except Exception as e: return {"error": str(e), "type": type(e).__name__} diff --git a/backend/app/expenses/schemas.py b/backend/app/expenses/schemas.py index b4126ad4..734fa31e 100644 --- a/backend/app/expenses/schemas.py +++ b/backend/app/expenses/schemas.py @@ -1,23 +1,28 @@ -from pydantic import BaseModel, Field, validator,ConfigDict -from typing import Optional, List, Dict, Any from datetime import datetime from enum import Enum +from typing import Any, Dict, List, Optional + +from pydantic import BaseModel, ConfigDict, Field, validator + class SplitType(str, Enum): EQUAL = "equal" UNEQUAL = "unequal" PERCENTAGE = "percentage" + class SettlementStatus(str, Enum): PENDING = "pending" COMPLETED = "completed" CANCELLED = "cancelled" + class ExpenseSplit(BaseModel): userId: str amount: float = Field(..., gt=0) type: SplitType = SplitType.EQUAL + class ExpenseCreateRequest(BaseModel): description: str = Field(..., min_length=1, max_length=500) amount: float = Field(..., gt=0) @@ -26,14 +31,18 @@ class ExpenseCreateRequest(BaseModel): tags: Optional[List[str]] = [] receiptUrls: Optional[List[str]] = [] - @validator('splits') + @validator("splits") def validate_splits_sum(cls, v, values): - if 'amount' in values: + if "amount" in values: total_split = sum(split.amount for split in v) - if abs(total_split - values['amount']) > 0.01: # Allow small floating point differences - raise ValueError('Split amounts must sum to total expense amount') + if ( + abs(total_split - values["amount"]) > 0.01 + ): # Allow small floating point differences + raise ValueError( + "Split amounts must sum to total expense amount") return v + class ExpenseUpdateRequest(BaseModel): description: Optional[str] = Field(None, min_length=1, max_length=500) amount: Optional[float] = Field(None, gt=0) @@ -41,19 +50,21 @@ class ExpenseUpdateRequest(BaseModel): tags: Optional[List[str]] = None receiptUrls: Optional[List[str]] = None - @validator('splits') + @validator("splits") def validate_splits_sum(cls, v, values): # Only validate if both splits and amount are provided in the update - if v is not None and 'amount' in values and values['amount'] is not None: + if v is not None and "amount" in values and values["amount"] is not None: total_split = sum(split.amount for split in v) - if abs(total_split - values['amount']) > 0.01: - raise ValueError('Split amounts must sum to total expense amount') + if abs(total_split - values["amount"]) > 0.01: + raise ValueError( + "Split amounts must sum to total expense amount") return v - + class Config: # Allow validation to work with partial updates validate_assignment = True + class ExpenseComment(BaseModel): id: str = Field(alias="_id") userId: str @@ -62,11 +73,10 @@ class ExpenseComment(BaseModel): createdAt: datetime model_config = ConfigDict( - populate_by_name=True, - str_strip_whitespace=True, - validate_assignment=True + populate_by_name=True, str_strip_whitespace=True, validate_assignment=True ) + class ExpenseHistoryEntry(BaseModel): id: str = Field(alias="_id") userId: str @@ -76,6 +86,7 @@ class ExpenseHistoryEntry(BaseModel): model_config = ConfigDict(populate_by_name=True) + class ExpenseResponse(BaseModel): id: str = Field(alias="_id") groupId: str @@ -93,6 +104,7 @@ class ExpenseResponse(BaseModel): model_config = ConfigDict(populate_by_name=True) + class Settlement(BaseModel): id: str = Field(alias="_id") expenseId: Optional[str] = None # None for manual settlements @@ -109,6 +121,7 @@ class Settlement(BaseModel): model_config = ConfigDict(populate_by_name=True) + class OptimizedSettlement(BaseModel): fromUserId: str toUserId: str @@ -117,21 +130,25 @@ class OptimizedSettlement(BaseModel): amount: float consolidatedExpenses: Optional[List[str]] = [] + class GroupSummary(BaseModel): totalExpenses: float totalSettlements: int optimizedSettlements: List[OptimizedSettlement] + class ExpenseCreateResponse(BaseModel): expense: ExpenseResponse settlements: List[Settlement] groupSummary: GroupSummary + class ExpenseListResponse(BaseModel): expenses: List[ExpenseResponse] pagination: Dict[str, Any] summary: Dict[str, Any] + class SettlementCreateRequest(BaseModel): payer_id: str payee_id: str @@ -139,16 +156,19 @@ class SettlementCreateRequest(BaseModel): description: Optional[str] = None paidAt: Optional[datetime] = None + class SettlementUpdateRequest(BaseModel): status: SettlementStatus paidAt: Optional[datetime] = None + class SettlementListResponse(BaseModel): settlements: List[Settlement] optimizedSettlements: List[OptimizedSettlement] summary: Dict[str, Any] pagination: Dict[str, Any] + class UserBalance(BaseModel): userId: str userName: str @@ -159,12 +179,14 @@ class UserBalance(BaseModel): pendingSettlements: List[Settlement] = [] recentExpenses: List[Dict[str, Any]] = [] + class FriendBalanceBreakdown(BaseModel): groupId: str groupName: str balance: float owesYou: bool + class FriendBalance(BaseModel): userId: str userName: str @@ -174,10 +196,12 @@ class FriendBalance(BaseModel): breakdown: List[FriendBalanceBreakdown] lastActivity: datetime + class FriendsBalanceResponse(BaseModel): friendsBalance: List[FriendBalance] summary: Dict[str, Any] + class BalanceSummaryResponse(BaseModel): totalOwedToYou: float totalYouOwe: float @@ -185,6 +209,7 @@ class BalanceSummaryResponse(BaseModel): currency: str = "USD" groupsSummary: List[Dict[str, Any]] + class ExpenseAnalytics(BaseModel): period: str totalExpenses: float @@ -194,10 +219,12 @@ class ExpenseAnalytics(BaseModel): memberContributions: List[Dict[str, Any]] expenseTrends: List[Dict[str, Any]] + class AttachmentUploadResponse(BaseModel): attachment_key: str url: str + class OptimizedSettlementsResponse(BaseModel): optimizedSettlements: List[OptimizedSettlement] savings: Dict[str, Any] diff --git a/backend/app/expenses/service.py b/backend/app/expenses/service.py index 8f2b2162..609fb0f6 100644 --- a/backend/app/expenses/service.py +++ b/backend/app/expenses/service.py @@ -1,49 +1,58 @@ -from typing import List, Dict, Any, Optional, Tuple +import asyncio +from collections import defaultdict, deque from datetime import datetime, timedelta -from bson import ObjectId -from app.database import mongodb +from typing import Any, Dict, List, Optional, Tuple + from app.config import logger +from app.database import mongodb from app.expenses.schemas import ( - ExpenseCreateRequest, ExpenseUpdateRequest, ExpenseResponse, Settlement, - OptimizedSettlement, SettlementCreateRequest, SettlementStatus, SplitType + ExpenseCreateRequest, + ExpenseResponse, + ExpenseUpdateRequest, + OptimizedSettlement, + Settlement, + SettlementCreateRequest, + SettlementStatus, + SplitType, ) -import asyncio -from collections import defaultdict, deque +from bson import ObjectId + class ExpenseService: def __init__(self): pass - + @property def expenses_collection(self): return mongodb.database.expenses - + @property def settlements_collection(self): return mongodb.database.settlements - + @property def groups_collection(self): return mongodb.database.groups - + @property def users_collection(self): return mongodb.database.users - async def create_expense(self, group_id: str, expense_data: ExpenseCreateRequest, user_id: str) -> Dict[str, Any]: + async def create_expense( + self, group_id: str, expense_data: ExpenseCreateRequest, user_id: str + ) -> Dict[str, Any]: """Create a new expense and calculate settlements""" - + # Validate and convert group_id to ObjectId try: group_obj_id = ObjectId(group_id) except Exception: raise ValueError("Group not found or user not a member") - + # Verify user is member of the group - group = await self.groups_collection.find_one({ - "_id": group_obj_id, - "members.userId": user_id - }) + group = await self.groups_collection.find_one( + {"_id": group_obj_id, "members.userId": user_id} + ) if not group: raise ValueError("Group not found or user not a member") @@ -61,7 +70,7 @@ async def create_expense(self, group_id: str, expense_data: ExpenseCreateRequest "comments": [], "history": [], "createdAt": datetime.utcnow(), - "updatedAt": datetime.utcnow() + "updatedAt": datetime.utcnow(), } # Insert expense @@ -82,19 +91,25 @@ async def create_expense(self, group_id: str, expense_data: ExpenseCreateRequest return { "expense": expense_response, "settlements": settlements, - "groupSummary": group_summary + "groupSummary": group_summary, } - async def _create_settlements_for_expense(self, expense_doc: Dict[str, Any], payer_id: str) -> List[Settlement]: + async def _create_settlements_for_expense( + self, expense_doc: Dict[str, Any], payer_id: str + ) -> List[Settlement]: """Create settlement records for an expense""" settlements = [] expense_id = str(expense_doc["_id"]) group_id = expense_doc["groupId"] # Get user names for the settlements - user_ids = [split["userId"] for split in expense_doc["splits"]] + [payer_id] - users = await self.users_collection.find({"_id": {"$in": [ObjectId(uid) for uid in user_ids]}}).to_list(None) - user_names = {str(user["_id"]): user.get("name", "Unknown") for user in users} + user_ids = [split["userId"] + for split in expense_doc["splits"]] + [payer_id] + users = await self.users_collection.find( + {"_id": {"$in": [ObjectId(uid) for uid in user_ids]}} + ).to_list(None) + user_names = {str(user["_id"]): user.get( + "name", "Unknown") for user in users} for split in expense_doc["splits"]: settlement_doc = { @@ -108,36 +123,41 @@ async def _create_settlements_for_expense(self, expense_doc: Dict[str, Any], pay "amount": split["amount"], "status": "completed" if split["userId"] == payer_id else "pending", "description": f"Share for {expense_doc['description']}", - "createdAt": datetime.utcnow() + "createdAt": datetime.utcnow(), } await self.settlements_collection.insert_one(settlement_doc) - + # Convert to Settlement model - settlement = Settlement(**{ - **settlement_doc, - "_id": str(settlement_doc["_id"]) - }) + settlement = Settlement( + **{**settlement_doc, "_id": str(settlement_doc["_id"])} + ) settlements.append(settlement) return settlements - async def list_group_expenses(self, group_id: str, user_id: str, page: int = 1, limit: int = 20, - from_date: Optional[datetime] = None, to_date: Optional[datetime] = None, - tags: Optional[List[str]] = None) -> Dict[str, Any]: + async def list_group_expenses( + self, + group_id: str, + user_id: str, + page: int = 1, + limit: int = 20, + from_date: Optional[datetime] = None, + to_date: Optional[datetime] = None, + tags: Optional[List[str]] = None, + ) -> Dict[str, Any]: """List expenses for a group with pagination and filtering""" - + # Verify user access - group = await self.groups_collection.find_one({ - "_id": ObjectId(group_id), - "members.userId": user_id - }) + group = await self.groups_collection.find_one( + {"_id": ObjectId(group_id), "members.userId": user_id} + ) if not group: raise ValueError("Group not found or user not a member") # Build query query = {"groupId": group_id} - + if from_date or to_date: date_filter = {} if from_date: @@ -154,7 +174,12 @@ async def list_group_expenses(self, group_id: str, user_id: str, page: int = 1, # Get expenses with pagination skip = (page - 1) * limit - expenses_cursor = self.expenses_collection.find(query).sort("createdAt", -1).skip(skip).limit(limit) + expenses_cursor = ( + self.expenses_collection.find(query) + .sort("createdAt", -1) + .skip(skip) + .limit(limit) + ) expenses_docs = await expenses_cursor.to_list(None) expenses = [] @@ -165,19 +190,23 @@ async def list_group_expenses(self, group_id: str, user_id: str, page: int = 1, # Calculate summary pipeline = [ {"$match": query}, - {"$group": { - "_id": None, - "totalAmount": {"$sum": "$amount"}, - "expenseCount": {"$sum": 1}, - "avgExpense": {"$avg": "$amount"} - }} + { + "$group": { + "_id": None, + "totalAmount": {"$sum": "$amount"}, + "expenseCount": {"$sum": 1}, + "avgExpense": {"$avg": "$amount"}, + } + }, ] - summary_result = await self.expenses_collection.aggregate(pipeline).to_list(None) - summary = summary_result[0] if summary_result else { - "totalAmount": 0, - "expenseCount": 0, - "avgExpense": 0 - } + summary_result = await self.expenses_collection.aggregate(pipeline).to_list( + None + ) + summary = ( + summary_result[0] + if summary_result + else {"totalAmount": 0, "expenseCount": 0, "avgExpense": 0} + ) summary.pop("_id", None) return { @@ -188,72 +217,70 @@ async def list_group_expenses(self, group_id: str, user_id: str, page: int = 1, "total": total, "totalPages": (total + limit - 1) // limit, "hasNext": page * limit < total, - "hasPrev": page > 1 + "hasPrev": page > 1, }, - "summary": summary + "summary": summary, } - async def get_expense_by_id(self, group_id: str, expense_id: str, user_id: str) -> Dict[str, Any]: + async def get_expense_by_id( + self, group_id: str, expense_id: str, user_id: str + ) -> Dict[str, Any]: """Get a single expense with details""" - + # Validate ObjectIds try: group_obj_id = ObjectId(group_id) expense_obj_id = ObjectId(expense_id) except Exception: raise ValueError("Group not found or user not a member") - + # Verify user access - group = await self.groups_collection.find_one({ - "_id": group_obj_id, - "members.userId": user_id - }) + group = await self.groups_collection.find_one( + {"_id": group_obj_id, "members.userId": user_id} + ) if not group: raise ValueError("Group not found or user not a member") - expense_doc = await self.expenses_collection.find_one({ - "_id": expense_obj_id, - "groupId": group_id - }) + expense_doc = await self.expenses_collection.find_one( + {"_id": expense_obj_id, "groupId": group_id} + ) if not expense_doc: raise ValueError("Expense not found") expense = await self._expense_doc_to_response(expense_doc) # Get related settlements - settlements_docs = await self.settlements_collection.find({ - "expenseId": expense_id - }).to_list(None) + settlements_docs = await self.settlements_collection.find( + {"expenseId": expense_id} + ).to_list(None) settlements = [] for doc in settlements_docs: - settlement = Settlement(**{ - **doc, - "_id": str(doc["_id"]) - }) + settlement = Settlement(**{**doc, "_id": str(doc["_id"])}) settlements.append(settlement) - return { - "expense": expense, - "relatedSettlements": settlements - } + return {"expense": expense, "relatedSettlements": settlements} - async def update_expense(self, group_id: str, expense_id: str, updates: ExpenseUpdateRequest, user_id: str) -> ExpenseResponse: + async def update_expense( + self, + group_id: str, + expense_id: str, + updates: ExpenseUpdateRequest, + user_id: str, + ) -> ExpenseResponse: """Update an expense""" - + try: # Validate ObjectId format try: expense_obj_id = ObjectId(expense_id) except Exception as e: raise ValueError(f"Invalid expense ID format: {expense_id}") - + # Verify user access and that they created the expense - expense_doc = await self.expenses_collection.find_one({ - "_id": expense_obj_id, - "groupId": group_id, - "createdBy": user_id - }) + expense_doc = await self.expenses_collection.find_one( + {"_id": expense_obj_id, "groupId": group_id, "createdBy": user_id} + ) if not expense_doc: raise ValueError("Expense not found or not authorized to edit") @@ -261,31 +288,34 @@ async def update_expense(self, group_id: str, expense_id: str, updates: ExpenseU if updates.splits is not None and updates.amount is not None: total_split = sum(split.amount for split in updates.splits) if abs(total_split - updates.amount) > 0.01: - raise ValueError('Split amounts must sum to total expense amount') - + raise ValueError( + "Split amounts must sum to total expense amount") + # If only splits are being updated, validate against current amount elif updates.splits is not None: current_amount = expense_doc["amount"] total_split = sum(split.amount for split in updates.splits) if abs(total_split - current_amount) > 0.01: - raise ValueError('Split amounts must sum to current expense amount') + raise ValueError( + "Split amounts must sum to current expense amount") # Store original data for history original_data = { "amount": expense_doc["amount"], "description": expense_doc["description"], - "splits": expense_doc["splits"] + "splits": expense_doc["splits"], } # Build update document update_doc = {"updatedAt": datetime.utcnow()} - + if updates.description is not None: update_doc["description"] = updates.description if updates.amount is not None: update_doc["amount"] = updates.amount if updates.splits is not None: - update_doc["splits"] = [split.model_dump() for split in updates.splits] + update_doc["splits"] = [split.model_dump() + for split in updates.splits] if updates.tags is not None: update_doc["tags"] = updates.tags if updates.receiptUrls is not None: @@ -295,37 +325,38 @@ async def update_expense(self, group_id: str, expense_id: str, updates: ExpenseU if len(update_doc) > 1: # More than just updatedAt # Get user name try: - user = await self.users_collection.find_one({"_id": ObjectId(user_id)}) - user_name = user.get("name", "Unknown User") if user else "Unknown User" + user = await self.users_collection.find_one( + {"_id": ObjectId(user_id)} + ) + user_name = ( + user.get( + "name", "Unknown User") if user else "Unknown User" + ) except: user_name = "Unknown User" - + history_entry = { "_id": ObjectId(), "userId": user_id, "userName": user_name, "beforeData": original_data, - "editedAt": datetime.utcnow() + "editedAt": datetime.utcnow(), } # Update expense with both $set and $push operations result = await self.expenses_collection.update_one( {"_id": expense_obj_id}, - { - "$set": update_doc, - "$push": {"history": history_entry} - } + {"$set": update_doc, "$push": {"history": history_entry}}, ) - + if result.matched_count == 0: raise ValueError("Expense not found during update") else: # No actual changes, just update the timestamp result = await self.expenses_collection.update_one( - {"_id": expense_obj_id}, - {"$set": update_doc} + {"_id": expense_obj_id}, {"$set": update_doc} ) - + if result.matched_count == 0: raise ValueError("Expense not found during update") @@ -333,40 +364,52 @@ async def update_expense(self, group_id: str, expense_id: str, updates: ExpenseU if updates.splits is not None or updates.amount is not None: try: # Delete old settlements for this expense - await self.settlements_collection.delete_many({"expenseId": expense_id}) - + await self.settlements_collection.delete_many( + {"expenseId": expense_id} + ) + # Get updated expense - updated_expense = await self.expenses_collection.find_one({"_id": expense_obj_id}) - + updated_expense = await self.expenses_collection.find_one( + {"_id": expense_obj_id} + ) + if updated_expense: # Create new settlements - await self._create_settlements_for_expense(updated_expense, user_id) + await self._create_settlements_for_expense( + updated_expense, user_id + ) except Exception as e: - logger.error(f"Warning: Failed to recalculate settlements: {e}",exc_info=True) + logger.error( + f"Warning: Failed to recalculate settlements: {e}", + exc_info=True, + ) # Continue anyway, as the expense update succeeded # Return updated expense - updated_expense = await self.expenses_collection.find_one({"_id": expense_obj_id}) + updated_expense = await self.expenses_collection.find_one( + {"_id": expense_obj_id} + ) if not updated_expense: raise ValueError("Failed to retrieve updated expense") - + return await self._expense_doc_to_response(updated_expense) - + except ValueError: raise except Exception as e: - logger.error(f"Error in update_expense: {str(e)}",exc_info=True) + logger.error(f"Error in update_expense: {str(e)}", exc_info=True) raise Exception(f"Database error during expense update: {str(e)}") - async def delete_expense(self, group_id: str, expense_id: str, user_id: str) -> bool: + async def delete_expense( + self, group_id: str, expense_id: str, user_id: str + ) -> bool: """Delete an expense""" - + # Verify user access and that they created the expense - expense_doc = await self.expenses_collection.find_one({ - "_id": ObjectId(expense_id), - "groupId": group_id, - "createdBy": user_id - }) + expense_doc = await self.expenses_collection.find_one( + {"_id": ObjectId(expense_id), "groupId": group_id, + "createdBy": user_id} + ) if not expense_doc: raise ValueError("Expense not found or not authorized to delete") @@ -374,25 +417,30 @@ async def delete_expense(self, group_id: str, expense_id: str, user_id: str) -> await self.settlements_collection.delete_many({"expenseId": expense_id}) # Delete the expense - result = await self.expenses_collection.delete_one({"_id": ObjectId(expense_id)}) + result = await self.expenses_collection.delete_one( + {"_id": ObjectId(expense_id)} + ) return result.deleted_count > 0 - async def calculate_optimized_settlements(self, group_id: str, algorithm: str = "advanced") -> List[OptimizedSettlement]: + async def calculate_optimized_settlements( + self, group_id: str, algorithm: str = "advanced" + ) -> List[OptimizedSettlement]: """Calculate optimized settlements using specified algorithm""" - + if algorithm == "normal": return await self._calculate_normal_settlements(group_id) else: return await self._calculate_advanced_settlements(group_id) - async def _calculate_normal_settlements(self, group_id: str) -> List[OptimizedSettlement]: + async def _calculate_normal_settlements( + self, group_id: str + ) -> List[OptimizedSettlement]: """Normal splitting algorithm - simplifies only direct relationships""" - + # Get all pending settlements for the group - settlements = await self.settlements_collection.find({ - "groupId": group_id, - "status": "pending" - }).to_list(None) + settlements = await self.settlements_collection.find( + {"groupId": group_id, "status": "pending"} + ).to_list(None) # Calculate net balances between each pair of users net_balances = defaultdict(lambda: defaultdict(float)) @@ -402,10 +450,10 @@ async def _calculate_normal_settlements(self, group_id: str) -> List[OptimizedSe payer = settlement["payerId"] payee = settlement["payeeId"] amount = settlement["amount"] - + user_names[payer] = settlement["payerName"] user_names[payee] = settlement["payeeName"] - + # Net amount that payer owes to payee net_balances[payer][payee] += amount @@ -415,36 +463,41 @@ async def _calculate_normal_settlements(self, group_id: str) -> List[OptimizedSe for payee in net_balances[payer]: payer_owes_payee = net_balances[payer][payee] payee_owes_payer = net_balances[payee][payer] - + net_amount = payer_owes_payee - payee_owes_payer - + if net_amount > 0.01: # Payer owes payee - optimized.append(OptimizedSettlement( - fromUserId=payer, - toUserId=payee, - fromUserName=user_names.get(payer, "Unknown"), - toUserName=user_names.get(payee, "Unknown"), - amount=round(net_amount, 2) - )) + optimized.append( + OptimizedSettlement( + fromUserId=payer, + toUserId=payee, + fromUserName=user_names.get(payer, "Unknown"), + toUserName=user_names.get(payee, "Unknown"), + amount=round(net_amount, 2), + ) + ) elif net_amount < -0.01: # Payee owes payer - optimized.append(OptimizedSettlement( - fromUserId=payee, - toUserId=payer, - fromUserName=user_names.get(payee, "Unknown"), - toUserName=user_names.get(payer, "Unknown"), - amount=round(-net_amount, 2) - )) + optimized.append( + OptimizedSettlement( + fromUserId=payee, + toUserId=payer, + fromUserName=user_names.get(payee, "Unknown"), + toUserName=user_names.get(payer, "Unknown"), + amount=round(-net_amount, 2), + ) + ) return optimized - async def _calculate_advanced_settlements(self, group_id: str) -> List[OptimizedSettlement]: + async def _calculate_advanced_settlements( + self, group_id: str + ) -> List[OptimizedSettlement]: """Advanced settlement algorithm using graph optimization""" - + # Get all pending settlements for the group - settlements = await self.settlements_collection.find({ - "groupId": group_id, - "status": "pending" - }).to_list(None) + settlements = await self.settlements_collection.find( + {"groupId": group_id, "status": "pending"} + ).to_list(None) # Calculate net balance for each user (what they owe - what they are owed) user_balances = defaultdict(float) @@ -454,10 +507,10 @@ async def _calculate_advanced_settlements(self, group_id: str) -> List[Optimized payer = settlement["payerId"] payee = settlement["payeeId"] amount = settlement["amount"] - + user_names[payer] = settlement["payerName"] user_names[payee] = settlement["payeeName"] - + # Payer paid for payee, so payee owes payer user_balances[payee] += amount # Positive means owes money user_balances[payer] -= amount # Negative means is owed money @@ -489,13 +542,15 @@ async def _calculate_advanced_settlements(self, group_id: str) -> List[Optimized settlement_amount = min(debt_amount, credit_amount) if settlement_amount > 0.01: - optimized.append(OptimizedSettlement( - fromUserId=debtor_id, - toUserId=creditor_id, - fromUserName=user_names.get(debtor_id, "Unknown"), - toUserName=user_names.get(creditor_id, "Unknown"), - amount=round(settlement_amount, 2) - )) + optimized.append( + OptimizedSettlement( + fromUserId=debtor_id, + toUserId=creditor_id, + fromUserName=user_names.get(debtor_id, "Unknown"), + toUserName=user_names.get(creditor_id, "Unknown"), + amount=round(settlement_amount, 2), + ) + ) # Update remaining amounts debtors[i][1] -= settlement_amount @@ -511,22 +566,31 @@ async def _calculate_advanced_settlements(self, group_id: str) -> List[Optimized return optimized - async def create_manual_settlement(self, group_id: str, settlement_data: SettlementCreateRequest, user_id: str) -> Settlement: + async def create_manual_settlement( + self, group_id: str, settlement_data: SettlementCreateRequest, user_id: str + ) -> Settlement: """Create a manual settlement record""" - + # Verify user access - group = await self.groups_collection.find_one({ - "_id": ObjectId(group_id), - "members.userId": user_id - }) + group = await self.groups_collection.find_one( + {"_id": ObjectId(group_id), "members.userId": user_id} + ) if not group: raise ValueError("Group not found or user not a member") # Get user names - users = await self.users_collection.find({ - "_id": {"$in": [ObjectId(settlement_data.payer_id), ObjectId(settlement_data.payee_id)]} - }).to_list(None) - user_names = {str(user["_id"]): user.get("name", "Unknown") for user in users} + users = await self.users_collection.find( + { + "_id": { + "$in": [ + ObjectId(settlement_data.payer_id), + ObjectId(settlement_data.payee_id), + ] + } + } + ).to_list(None) + user_names = {str(user["_id"]): user.get( + "name", "Unknown") for user in users} settlement_doc = { "_id": ObjectId(), @@ -540,56 +604,67 @@ async def create_manual_settlement(self, group_id: str, settlement_data: Settlem "status": "completed", "description": settlement_data.description or "Manual settlement", "paidAt": settlement_data.paidAt or datetime.utcnow(), - "createdAt": datetime.utcnow() + "createdAt": datetime.utcnow(), } await self.settlements_collection.insert_one(settlement_doc) - return Settlement(**{ - **settlement_doc, - "_id": str(settlement_doc["_id"]) - }) + return Settlement(**{**settlement_doc, "_id": str(settlement_doc["_id"])}) async def _expense_doc_to_response(self, doc: Dict[str, Any]) -> ExpenseResponse: """Convert expense document to response model""" - return ExpenseResponse(**{ - **doc, - "_id": str(doc["_id"]) - }) + return ExpenseResponse(**{**doc, "_id": str(doc["_id"])}) - async def _get_group_summary(self, group_id: str, optimized_settlements: List[OptimizedSettlement]) -> Dict[str, Any]: + async def _get_group_summary( + self, group_id: str, optimized_settlements: List[OptimizedSettlement] + ) -> Dict[str, Any]: """Get group summary statistics""" - + # Get total expenses pipeline = [ {"$match": {"groupId": group_id}}, - {"$group": { - "_id": None, - "totalExpenses": {"$sum": "$amount"}, - "expenseCount": {"$sum": 1} - }} + { + "$group": { + "_id": None, + "totalExpenses": {"$sum": "$amount"}, + "expenseCount": {"$sum": 1}, + } + }, ] - expense_result = await self.expenses_collection.aggregate(pipeline).to_list(None) - expense_stats = expense_result[0] if expense_result else {"totalExpenses": 0, "expenseCount": 0} + expense_result = await self.expenses_collection.aggregate(pipeline).to_list( + None + ) + expense_stats = ( + expense_result[0] + if expense_result + else {"totalExpenses": 0, "expenseCount": 0} + ) # Get total settlements count - settlement_count = await self.settlements_collection.count_documents({"groupId": group_id}) + settlement_count = await self.settlements_collection.count_documents( + {"groupId": group_id} + ) return { "totalExpenses": expense_stats["totalExpenses"], "totalSettlements": settlement_count, - "optimizedSettlements": optimized_settlements + "optimizedSettlements": optimized_settlements, } - async def get_group_settlements(self, group_id: str, user_id: str, status_filter: Optional[str] = None, - page: int = 1, limit: int = 50) -> Dict[str, Any]: + async def get_group_settlements( + self, + group_id: str, + user_id: str, + status_filter: Optional[str] = None, + page: int = 1, + limit: int = 50, + ) -> Dict[str, Any]: """Get settlements for a group with pagination""" - + # Verify user access - group = await self.groups_collection.find_one({ - "_id": ObjectId(group_id), - "members.userId": user_id - }) + group = await self.groups_collection.find_one( + {"_id": ObjectId(group_id), "members.userId": user_id} + ) if not group: raise ValueError("Group not found or user not a member") @@ -603,101 +678,104 @@ async def get_group_settlements(self, group_id: str, user_id: str, status_filter # Get settlements with pagination skip = (page - 1) * limit - settlements_docs = await self.settlements_collection.find(query).sort("createdAt", -1).skip(skip).limit(limit).to_list(None) + settlements_docs = ( + await self.settlements_collection.find(query) + .sort("createdAt", -1) + .skip(skip) + .limit(limit) + .to_list(None) + ) settlements = [] for doc in settlements_docs: - settlement = Settlement(**{ - **doc, - "_id": str(doc["_id"]) - }) + settlement = Settlement(**{**doc, "_id": str(doc["_id"])}) settlements.append(settlement) return { "settlements": settlements, "total": total, "page": page, - "limit": limit + "limit": limit, } - async def get_settlement_by_id(self, group_id: str, settlement_id: str, user_id: str) -> Settlement: + async def get_settlement_by_id( + self, group_id: str, settlement_id: str, user_id: str + ) -> Settlement: """Get a single settlement by ID""" - + # Verify user access - group = await self.groups_collection.find_one({ - "_id": ObjectId(group_id), - "members.userId": user_id - }) + group = await self.groups_collection.find_one( + {"_id": ObjectId(group_id), "members.userId": user_id} + ) if not group: raise ValueError("Group not found or user not a member") - settlement_doc = await self.settlements_collection.find_one({ - "_id": ObjectId(settlement_id), - "groupId": group_id - }) - + settlement_doc = await self.settlements_collection.find_one( + {"_id": ObjectId(settlement_id), "groupId": group_id} + ) + if not settlement_doc: raise ValueError("Settlement not found") - return Settlement(**{ - **settlement_doc, - "_id": str(settlement_doc["_id"]) - }) + return Settlement(**{**settlement_doc, "_id": str(settlement_doc["_id"])}) - async def update_settlement_status(self, group_id: str, settlement_id: str, status: SettlementStatus, - paid_at: Optional[datetime] = None, user_id: str = None) -> Settlement: + async def update_settlement_status( + self, + group_id: str, + settlement_id: str, + status: SettlementStatus, + paid_at: Optional[datetime] = None, + user_id: str = None, + ) -> Settlement: """Update settlement status""" - - update_doc = { - "status": status.value, - "updatedAt": datetime.utcnow() - } - + + update_doc = {"status": status.value, "updatedAt": datetime.utcnow()} + if paid_at: update_doc["paidAt"] = paid_at result = await self.settlements_collection.update_one( - {"_id": ObjectId(settlement_id), "groupId": group_id}, - {"$set": update_doc} + {"_id": ObjectId(settlement_id), "groupId": group_id}, { + "$set": update_doc} ) if result.matched_count == 0: raise ValueError("Settlement not found") # Get updated settlement - settlement_doc = await self.settlements_collection.find_one({"_id": ObjectId(settlement_id)}) - - return Settlement(**{ - **settlement_doc, - "_id": str(settlement_doc["_id"]) - }) - - async def delete_settlement(self, group_id: str, settlement_id: str, user_id: str) -> bool: + settlement_doc = await self.settlements_collection.find_one( + {"_id": ObjectId(settlement_id)} + ) + + return Settlement(**{**settlement_doc, "_id": str(settlement_doc["_id"])}) + + async def delete_settlement( + self, group_id: str, settlement_id: str, user_id: str + ) -> bool: """Delete a settlement""" - + # Verify user access - group = await self.groups_collection.find_one({ - "_id": ObjectId(group_id), - "members.userId": user_id - }) + group = await self.groups_collection.find_one( + {"_id": ObjectId(group_id), "members.userId": user_id} + ) if not group: raise ValueError("Group not found or user not a member") - result = await self.settlements_collection.delete_one({ - "_id": ObjectId(settlement_id), - "groupId": group_id - }) + result = await self.settlements_collection.delete_one( + {"_id": ObjectId(settlement_id), "groupId": group_id} + ) return result.deleted_count > 0 - async def get_user_balance_in_group(self, group_id: str, target_user_id: str, current_user_id: str) -> Dict[str, Any]: + async def get_user_balance_in_group( + self, group_id: str, target_user_id: str, current_user_id: str + ) -> Dict[str, Any]: """Get a user's balance within a specific group""" - + # Verify current user access - group = await self.groups_collection.find_one({ - "_id": ObjectId(group_id), - "members.userId": current_user_id - }) + group = await self.groups_collection.find_one( + {"_id": ObjectId(group_id), "members.userId": current_user_id} + ) if not group: raise ValueError("Group not found or user not a member") @@ -707,66 +785,70 @@ async def get_user_balance_in_group(self, group_id: str, target_user_id: str, cu # Calculate totals from settlements pipeline = [ - {"$match": { - "groupId": group_id, - "$or": [ - {"payerId": target_user_id}, - {"payeeId": target_user_id} - ] - }}, - {"$group": { - "_id": None, - "totalPaid": { - "$sum": { - "$cond": [ - {"$eq": ["$payerId", target_user_id]}, - "$amount", - 0 - ] - } - }, - "totalOwed": { - "$sum": { - "$cond": [ - {"$eq": ["$payeeId", target_user_id]}, - "$amount", - 0 - ] - } + { + "$match": { + "groupId": group_id, + "$or": [{"payerId": target_user_id}, {"payeeId": target_user_id}], } - }} + }, + { + "$group": { + "_id": None, + "totalPaid": { + "$sum": { + "$cond": [ + {"$eq": ["$payerId", target_user_id]}, + "$amount", + 0, + ] + } + }, + "totalOwed": { + "$sum": { + "$cond": [ + {"$eq": ["$payeeId", target_user_id]}, + "$amount", + 0, + ] + } + }, + } + }, ] result = await self.settlements_collection.aggregate(pipeline).to_list(None) - balance_data = result[0] if result else {"totalPaid": 0, "totalOwed": 0} + balance_data = result[0] if result else { + "totalPaid": 0, "totalOwed": 0} total_paid = balance_data["totalPaid"] total_owed = balance_data["totalOwed"] net_balance = total_paid - total_owed # Get pending settlements - pending_settlements = await self.settlements_collection.find({ - "groupId": group_id, - "payeeId": target_user_id, - "status": "pending" - }).to_list(None) + pending_settlements = await self.settlements_collection.find( + {"groupId": group_id, "payeeId": target_user_id, "status": "pending"} + ).to_list(None) pending_settlement_objects = [] for doc in pending_settlements: - settlement = Settlement(**{ - **doc, - "_id": str(doc["_id"]) - }) + settlement = Settlement(**{**doc, "_id": str(doc["_id"])}) pending_settlement_objects.append(settlement) # Get recent expenses where user was involved - recent_expenses = await self.expenses_collection.find({ - "groupId": group_id, - "$or": [ - {"createdBy": target_user_id}, - {"splits.userId": target_user_id} - ] - }).sort("createdAt", -1).limit(5).to_list(None) + recent_expenses = ( + await self.expenses_collection.find( + { + "groupId": group_id, + "$or": [ + {"createdBy": target_user_id}, + {"splits.userId": target_user_id}, + ], + } + ) + .sort("createdAt", -1) + .limit(5) + .to_list(None) + ) recent_expense_data = [] for expense in recent_expenses: @@ -777,12 +859,14 @@ async def get_user_balance_in_group(self, group_id: str, target_user_id: str, cu user_share = split["amount"] break - recent_expense_data.append({ - "expenseId": str(expense["_id"]), - "description": expense["description"], - "userShare": user_share, - "createdAt": expense["createdAt"] - }) + recent_expense_data.append( + { + "expenseId": str(expense["_id"]), + "description": expense["description"], + "userShare": user_share, + "createdAt": expense["createdAt"], + } + ) return { "userId": target_user_id, @@ -792,16 +876,16 @@ async def get_user_balance_in_group(self, group_id: str, target_user_id: str, cu "netBalance": net_balance, "owesYou": net_balance > 0, "pendingSettlements": pending_settlement_objects, - "recentExpenses": recent_expense_data + "recentExpenses": recent_expense_data, } async def get_friends_balance_summary(self, user_id: str) -> Dict[str, Any]: """Get cross-group friend balances for a user""" - + # Get all groups user belongs to - groups = await self.groups_collection.find({ - "members.userId": user_id - }).to_list(None) + groups = await self.groups_collection.find({"members.userId": user_id}).to_list( + None + ) friends_balance = [] user_totals = {"totalOwedToYou": 0, "totalYouOwe": 0} @@ -814,10 +898,11 @@ async def get_friends_balance_summary(self, user_id: str) -> Dict[str, Any]: friend_ids.add(member["userId"]) # Get user names - users = await self.users_collection.find({ - "_id": {"$in": [ObjectId(uid) for uid in friend_ids]} - }).to_list(None) - user_names = {str(user["_id"]): user.get("name", "Unknown") for user in users} + users = await self.users_collection.find( + {"_id": {"$in": [ObjectId(uid) for uid in friend_ids]}} + ).to_list(None) + user_names = {str(user["_id"]): user.get( + "name", "Unknown") for user in users} for friend_id in friend_ids: friend_balance_data = { @@ -827,7 +912,7 @@ async def get_friends_balance_summary(self, user_id: str) -> Dict[str, Any]: "netBalance": 0, "owesYou": False, "breakdown": [], - "lastActivity": datetime.utcnow() + "lastActivity": datetime.utcnow(), } total_friend_balance = 0 @@ -835,68 +920,90 @@ async def get_friends_balance_summary(self, user_id: str) -> Dict[str, Any]: # Calculate balance for each group for group in groups: group_id = str(group["_id"]) - + # Check if friend is in this group - friend_in_group = any(member["userId"] == friend_id for member in group["members"]) + friend_in_group = any( + member["userId"] == friend_id for member in group["members"] + ) if not friend_in_group: continue # Calculate net balance between user and friend in this group pipeline = [ - {"$match": { - "groupId": group_id, - "$or": [ - {"payerId": user_id, "payeeId": friend_id}, - {"payerId": friend_id, "payeeId": user_id} - ] - }}, - {"$group": { - "_id": None, - "userOwes": { - "$sum": { - "$cond": [ - {"$and": [ - {"$eq": ["$payerId", friend_id]}, - {"$eq": ["$payeeId", user_id]} - ]}, - "$amount", - 0 - ] - } - }, - "friendOwes": { - "$sum": { - "$cond": [ - {"$and": [ - {"$eq": ["$payerId", user_id]}, - {"$eq": ["$payeeId", friend_id]} - ]}, - "$amount", - 0 - ] - } + { + "$match": { + "groupId": group_id, + "$or": [ + {"payerId": user_id, "payeeId": friend_id}, + {"payerId": friend_id, "payeeId": user_id}, + ], + } + }, + { + "$group": { + "_id": None, + "userOwes": { + "$sum": { + "$cond": [ + { + "$and": [ + {"$eq": [ + "$payerId", friend_id]}, + {"$eq": ["$payeeId", user_id]}, + ] + }, + "$amount", + 0, + ] + } + }, + "friendOwes": { + "$sum": { + "$cond": [ + { + "$and": [ + {"$eq": ["$payerId", user_id]}, + {"$eq": [ + "$payeeId", friend_id]}, + ] + }, + "$amount", + 0, + ] + } + }, } - }} + }, ] - result = await self.settlements_collection.aggregate(pipeline).to_list(None) - balance_data = result[0] if result else {"userOwes": 0, "friendOwes": 0} + result = await self.settlements_collection.aggregate(pipeline).to_list( + None + ) + balance_data = result[0] if result else { + "userOwes": 0, "friendOwes": 0} - group_balance = balance_data["friendOwes"] - balance_data["userOwes"] + group_balance = balance_data["friendOwes"] - \ + balance_data["userOwes"] total_friend_balance += group_balance - if abs(group_balance) > 0.01: # Only include if there's a significant balance - friend_balance_data["breakdown"].append({ - "groupId": group_id, - "groupName": group["name"], - "balance": group_balance, - "owesYou": group_balance > 0 - }) + if ( + abs(group_balance) > 0.01 + ): # Only include if there's a significant balance + friend_balance_data["breakdown"].append( + { + "groupId": group_id, + "groupName": group["name"], + "balance": group_balance, + "owesYou": group_balance > 0, + } + ) - if abs(total_friend_balance) > 0.01: # Only include friends with non-zero balance + if ( + abs(total_friend_balance) > 0.01 + ): # Only include friends with non-zero balance friend_balance_data["netBalance"] = total_friend_balance friend_balance_data["owesYou"] = total_friend_balance > 0 - + if total_friend_balance > 0: user_totals["totalOwedToYou"] += total_friend_balance else: @@ -909,19 +1016,20 @@ async def get_friends_balance_summary(self, user_id: str) -> Dict[str, Any]: "summary": { "totalOwedToYou": user_totals["totalOwedToYou"], "totalYouOwe": user_totals["totalYouOwe"], - "netBalance": user_totals["totalOwedToYou"] - user_totals["totalYouOwe"], + "netBalance": user_totals["totalOwedToYou"] + - user_totals["totalYouOwe"], "friendCount": len(friends_balance), - "activeGroups": len(groups) - } + "activeGroups": len(groups), + }, } async def get_overall_balance_summary(self, user_id: str) -> Dict[str, Any]: """Get overall balance summary for a user""" - + # Get all groups user belongs to - groups = await self.groups_collection.find({ - "members.userId": user_id - }).to_list(None) + groups = await self.groups_collection.find({"members.userId": user_id}).to_list( + None + ) total_owed_to_you = 0 total_you_owe = 0 @@ -929,50 +1037,49 @@ async def get_overall_balance_summary(self, user_id: str) -> Dict[str, Any]: for group in groups: group_id = str(group["_id"]) - + # Calculate user's balance in this group pipeline = [ - {"$match": { - "groupId": group_id, - "$or": [ - {"payerId": user_id}, - {"payeeId": user_id} - ] - }}, - {"$group": { - "_id": None, - "totalPaid": { - "$sum": { - "$cond": [ - {"$eq": ["$payerId", user_id]}, - "$amount", - 0 - ] - } - }, - "totalOwed": { - "$sum": { - "$cond": [ - {"$eq": ["$payeeId", user_id]}, - "$amount", - 0 - ] - } + { + "$match": { + "groupId": group_id, + "$or": [{"payerId": user_id}, {"payeeId": user_id}], } - }} + }, + { + "$group": { + "_id": None, + "totalPaid": { + "$sum": { + "$cond": [{"$eq": ["$payerId", user_id]}, "$amount", 0] + } + }, + "totalOwed": { + "$sum": { + "$cond": [{"$eq": ["$payeeId", user_id]}, "$amount", 0] + } + }, + } + }, ] result = await self.settlements_collection.aggregate(pipeline).to_list(None) - balance_data = result[0] if result else {"totalPaid": 0, "totalOwed": 0} + balance_data = result[0] if result else { + "totalPaid": 0, "totalOwed": 0} + + group_balance = balance_data["totalPaid"] - \ + balance_data["totalOwed"] - group_balance = balance_data["totalPaid"] - balance_data["totalOwed"] - - if abs(group_balance) > 0.01: # Only include groups with significant balance - groups_summary.append({ - "group_id": group_id, - "group_name": group["name"], - "yourBalanceInGroup": group_balance - }) + if ( + abs(group_balance) > 0.01 + ): # Only include groups with significant balance + groups_summary.append( + { + "group_id": group_id, + "group_name": group["name"], + "yourBalanceInGroup": group_balance, + } + ) if group_balance > 0: total_owed_to_you += group_balance @@ -984,18 +1091,23 @@ async def get_overall_balance_summary(self, user_id: str) -> Dict[str, Any]: "totalYouOwe": total_you_owe, "netBalance": total_owed_to_you - total_you_owe, "currency": "USD", - "groupsSummary": groups_summary + "groupsSummary": groups_summary, } - async def get_group_analytics(self, group_id: str, user_id: str, period: str = "month", - year: int = None, month: int = None) -> Dict[str, Any]: + async def get_group_analytics( + self, + group_id: str, + user_id: str, + period: str = "month", + year: int = None, + month: int = None, + ) -> Dict[str, Any]: """Get expense analytics for a group""" - + # Verify user access - group = await self.groups_collection.find_one({ - "_id": ObjectId(group_id), - "members.userId": user_id - }) + group = await self.groups_collection.find_one( + {"_id": ObjectId(group_id), "members.userId": user_id} + ) if not group: raise ValueError("Group not found or user not a member") @@ -1022,10 +1134,10 @@ async def get_group_analytics(self, group_id: str, user_id: str, period: str = " period_str = f"{now.year}-{now.month:02d}" # Get expenses in the period - expenses = await self.expenses_collection.find({ - "groupId": group_id, - "createdAt": {"$gte": start_date, "$lt": end_date} - }).to_list(None) + expenses = await self.expenses_collection.find( + {"groupId": group_id, "createdAt": { + "$gte": start_date, "$lt": end_date}} + ).to_list(None) total_expenses = sum(expense["amount"] for expense in expenses) expense_count = len(expenses) @@ -1039,50 +1151,71 @@ async def get_group_analytics(self, group_id: str, user_id: str, period: str = " tag_stats[tag]["count"] += 1 top_categories = [] - for tag, stats in sorted(tag_stats.items(), key=lambda x: x[1]["amount"], reverse=True): - top_categories.append({ - "tag": tag, - "amount": stats["amount"], - "count": stats["count"], - "percentage": round((stats["amount"] / total_expenses * 100) if total_expenses > 0 else 0, 1) - }) + for tag, stats in sorted( + tag_stats.items(), key=lambda x: x[1]["amount"], reverse=True + ): + top_categories.append( + { + "tag": tag, + "amount": stats["amount"], + "count": stats["count"], + "percentage": round( + ( + (stats["amount"] / total_expenses * 100) + if total_expenses > 0 + else 0 + ), + 1, + ), + } + ) # Member contributions member_contributions = [] - group_members = {member["userId"]: member for member in group["members"]} - + group_members = {member["userId"] : member for member in group["members"]} + for member_id in group_members: # Get user info user = await self.users_collection.find_one({"_id": ObjectId(member_id)}) user_name = user.get("name", "Unknown") if user else "Unknown" - + # Calculate contributions - total_paid = sum(expense["amount"] for expense in expenses if expense["createdBy"] == member_id) - + total_paid = sum( + expense["amount"] + for expense in expenses + if expense["createdBy"] == member_id + ) + total_owed = 0 for expense in expenses: for split in expense["splits"]: if split["userId"] == member_id: total_owed += split["amount"] - member_contributions.append({ - "userId": member_id, - "userName": user_name, - "totalPaid": total_paid, - "totalOwed": total_owed, - "netContribution": total_paid - total_owed - }) + member_contributions.append( + { + "userId": member_id, + "userName": user_name, + "totalPaid": total_paid, + "totalOwed": total_owed, + "netContribution": total_paid - total_owed, + } + ) # Expense trends (daily) expense_trends = [] current_date = start_date while current_date < end_date: - day_expenses = [e for e in expenses if e["createdAt"].date() == current_date.date()] - expense_trends.append({ - "date": current_date.strftime("%Y-%m-%d"), - "amount": sum(e["amount"] for e in day_expenses), - "count": len(day_expenses) - }) + day_expenses = [ + e for e in expenses if e["createdAt"].date() == current_date.date() + ] + expense_trends.append( + { + "date": current_date.strftime("%Y-%m-%d"), + "amount": sum(e["amount"] for e in day_expenses), + "count": len(day_expenses), + } + ) current_date += timedelta(days=1) return { @@ -1092,7 +1225,9 @@ async def get_group_analytics(self, group_id: str, user_id: str, period: str = " "avgExpenseAmount": round(avg_expense, 2), "topCategories": top_categories[:10], # Top 10 categories "memberContributions": member_contributions, - "expenseTrends": expense_trends + "expenseTrends": expense_trends, } + + # Create service instance expense_service = ExpenseService() diff --git a/backend/app/groups/routes.py b/backend/app/groups/routes.py index 5d3a12d7..eccc8295 100644 --- a/backend/app/groups/routes.py +++ b/backend/app/groups/routes.py @@ -1,89 +1,108 @@ -from fastapi import APIRouter, Depends, HTTPException, status +from typing import Any, Dict, List + +from app.auth.security import get_current_user from app.groups.schemas import ( - GroupCreateRequest, GroupResponse, GroupListResponse, GroupUpdateRequest, - JoinGroupRequest, JoinGroupResponse, MemberRoleUpdateRequest, - LeaveGroupResponse, DeleteGroupResponse, RemoveMemberResponse, - GroupMemberWithDetails + DeleteGroupResponse, + GroupCreateRequest, + GroupListResponse, + GroupMemberWithDetails, + GroupResponse, + GroupUpdateRequest, + JoinGroupRequest, + JoinGroupResponse, + LeaveGroupResponse, + MemberRoleUpdateRequest, + RemoveMemberResponse, ) from app.groups.service import group_service -from app.auth.security import get_current_user -from typing import Dict, Any, List +from fastapi import APIRouter, Depends, HTTPException, status router = APIRouter(prefix="/groups", tags=["Groups"]) + @router.post("", response_model=GroupResponse, status_code=status.HTTP_201_CREATED) async def create_group( group_data: GroupCreateRequest, - current_user: Dict[str, Any] = Depends(get_current_user) + current_user: Dict[str, Any] = Depends(get_current_user), ): """Create a new group""" group = await group_service.create_group( - group_data.model_dump(exclude_unset=True), - current_user["_id"] + group_data.model_dump(exclude_unset=True), current_user["_id"] ) if not group: raise HTTPException(status_code=500, detail="Failed to create group") return group + @router.get("", response_model=GroupListResponse) async def list_user_groups(current_user: Dict[str, Any] = Depends(get_current_user)): """List all groups the current user belongs to""" groups = await group_service.get_user_groups(current_user["_id"]) return {"groups": groups} + @router.get("/{group_id}", response_model=GroupResponse) async def get_group_details( - group_id: str, - current_user: Dict[str, Any] = Depends(get_current_user) + group_id: str, current_user: Dict[str, Any] = Depends(get_current_user) ): """Get group details including members""" group = await group_service.get_group_by_id(group_id, current_user["_id"]) if not group: - raise HTTPException(status_code=404, detail="Group not found or access denied") + raise HTTPException( + status_code=404, detail="Group not found or access denied") return group + @router.patch("/{group_id}", response_model=GroupResponse) async def update_group_metadata( group_id: str, updates: GroupUpdateRequest, - current_user: Dict[str, Any] = Depends(get_current_user) + current_user: Dict[str, Any] = Depends(get_current_user), ): """Update group metadata (admin only)""" update_data = updates.model_dump(exclude_unset=True) if not update_data: - raise HTTPException(status_code=400, detail="No update fields provided") - - updated_group = await group_service.update_group(group_id, update_data, current_user["_id"]) + raise HTTPException( + status_code=400, detail="No update fields provided") + + updated_group = await group_service.update_group( + group_id, update_data, current_user["_id"] + ) if not updated_group: - raise HTTPException(status_code=404, detail="Group not found or access denied") + raise HTTPException( + status_code=404, detail="Group not found or access denied") return updated_group + @router.delete("/{group_id}", response_model=DeleteGroupResponse) async def delete_group( - group_id: str, - current_user: Dict[str, Any] = Depends(get_current_user) + group_id: str, current_user: Dict[str, Any] = Depends(get_current_user) ): """Delete a group (admin only)""" deleted = await group_service.delete_group(group_id, current_user["_id"]) if not deleted: - raise HTTPException(status_code=404, detail="Group not found or access denied") + raise HTTPException( + status_code=404, detail="Group not found or access denied") return DeleteGroupResponse(success=True, message="Group deleted successfully") + @router.post("/join", response_model=JoinGroupResponse) async def join_group_by_code( join_data: JoinGroupRequest, - current_user: Dict[str, Any] = Depends(get_current_user) + current_user: Dict[str, Any] = Depends(get_current_user), ): """Join a group using a join code""" - group = await group_service.join_group_by_code(join_data.joinCode, current_user["_id"]) + group = await group_service.join_group_by_code( + join_data.joinCode, current_user["_id"] + ) if not group: raise HTTPException(status_code=404, detail="Invalid join code") return {"group": group} + @router.post("/{group_id}/leave", response_model=LeaveGroupResponse) async def leave_group( - group_id: str, - current_user: Dict[str, Any] = Depends(get_current_user) + group_id: str, current_user: Dict[str, Any] = Depends(get_current_user) ): """Leave a group (only if no outstanding balances)""" left = await group_service.leave_group(group_id, current_user["_id"]) @@ -91,38 +110,43 @@ async def leave_group( raise HTTPException(status_code=400, detail="Failed to leave group") return LeaveGroupResponse(success=True, message="Successfully left the group") + @router.get("/{group_id}/members", response_model=List[GroupMemberWithDetails]) async def get_group_members( - group_id: str, - current_user: Dict[str, Any] = Depends(get_current_user) + group_id: str, current_user: Dict[str, Any] = Depends(get_current_user) ): """Get list of group members with detailed user information""" members = await group_service.get_group_members(group_id, current_user["_id"]) return members + @router.patch("/{group_id}/members/{member_id}", response_model=Dict[str, str]) async def update_member_role( group_id: str, member_id: str, role_update: MemberRoleUpdateRequest, - current_user: Dict[str, Any] = Depends(get_current_user) + current_user: Dict[str, Any] = Depends(get_current_user), ): """Change member role (admin only)""" updated = await group_service.update_member_role( group_id, member_id, role_update.role, current_user["_id"] ) if not updated: - raise HTTPException(status_code=400, detail="Failed to update member role") + raise HTTPException( + status_code=400, detail="Failed to update member role") return {"message": f"Member role updated to {role_update.role}"} + @router.delete("/{group_id}/members/{member_id}", response_model=RemoveMemberResponse) async def remove_group_member( group_id: str, member_id: str, - current_user: Dict[str, Any] = Depends(get_current_user) + current_user: Dict[str, Any] = Depends(get_current_user), ): """Remove a member from the group (admin only)""" - removed = await group_service.remove_member(group_id, member_id, current_user["_id"]) + removed = await group_service.remove_member( + group_id, member_id, current_user["_id"] + ) if not removed: raise HTTPException(status_code=400, detail="Failed to remove member") return RemoveMemberResponse(success=True, message="Member removed successfully") diff --git a/backend/app/groups/schemas.py b/backend/app/groups/schemas.py index 78ddd6d0..71647ba7 100644 --- a/backend/app/groups/schemas.py +++ b/backend/app/groups/schemas.py @@ -1,27 +1,33 @@ -from pydantic import BaseModel, Field,ConfigDict -from typing import Optional, List from datetime import datetime +from typing import List, Optional + +from pydantic import BaseModel, ConfigDict, Field + class GroupMember(BaseModel): userId: str role: str = "member" # "admin" or "member" joinedAt: datetime + class GroupMemberWithDetails(BaseModel): userId: str role: str = "member" # "admin" or "member" joinedAt: datetime user: Optional[dict] = None # Contains user details like name, email + class GroupCreateRequest(BaseModel): name: str = Field(..., min_length=1, max_length=100) currency: Optional[str] = "USD" imageUrl: Optional[str] = None + class GroupUpdateRequest(BaseModel): name: Optional[str] = Field(None, min_length=1, max_length=100) imageUrl: Optional[str] = None + class GroupResponse(BaseModel): id: str = Field(alias="_id") name: str @@ -34,26 +40,33 @@ class GroupResponse(BaseModel): model_config = ConfigDict(populate_by_name=True) + class GroupListResponse(BaseModel): groups: List[GroupResponse] + class JoinGroupRequest(BaseModel): joinCode: str = Field(..., min_length=1) + class JoinGroupResponse(BaseModel): group: GroupResponse + class MemberRoleUpdateRequest(BaseModel): role: str = Field(..., pattern="^(admin|member)$") + class LeaveGroupResponse(BaseModel): success: bool message: str + class DeleteGroupResponse(BaseModel): success: bool message: str + class RemoveMemberResponse(BaseModel): success: bool message: str diff --git a/backend/app/groups/service.py b/backend/app/groups/service.py index ad920fea..6d40a7bc 100644 --- a/backend/app/groups/service.py +++ b/backend/app/groups/service.py @@ -1,10 +1,12 @@ -from fastapi import HTTPException, status -from app.database import get_database -from bson import ObjectId -from datetime import datetime, timezone -from typing import Optional, Dict, Any, List import secrets import string +from datetime import datetime, timezone +from typing import Any, Dict, List, Optional + +from app.database import get_database +from bson import ObjectId +from fastapi import HTTPException, status + class GroupService: def __init__(self): @@ -16,13 +18,15 @@ def get_db(self): def generate_join_code(self, length: int = 6) -> str: """Generate a random alphanumeric join code""" characters = string.ascii_uppercase + string.digits - return ''.join(secrets.choice(characters) for _ in range(length)) + return "".join(secrets.choice(characters) for _ in range(length)) - async def _enrich_members_with_user_details(self, members: List[dict]) -> List[dict]: + async def _enrich_members_with_user_details( + self, members: List[dict] + ) -> List[dict]: """Private method to enrich member data with user details from users collection""" db = self.get_db() enriched_members = [] - + for member in members: member_user_id = member.get("userId") if member_user_id: @@ -30,39 +34,59 @@ async def _enrich_members_with_user_details(self, members: List[dict]) -> List[d # Fetch user details from users collection user_obj_id = ObjectId(member_user_id) user = await db.users.find_one({"_id": user_obj_id}) - + # Create enriched member object enriched_member = { "userId": member_user_id, "role": member.get("role", "member"), "joinedAt": member.get("joinedAt"), - "user": { - "name": user.get("name", f"User {member_user_id[-4:]}") if user else f"User {member_user_id[-4:]}", - "email": user.get("email", f"{member_user_id}@example.com") if user else f"{member_user_id}@example.com", - "avatar": user.get("imageUrl") or user.get("avatar") if user else None - } if user else { - "name": f"User {member_user_id[-4:]}", - "email": f"{member_user_id}@example.com", - "avatar": None - } + "user": ( + { + "name": ( + user.get( + "name", f"User {member_user_id[-4:]}") + if user + else f"User {member_user_id[-4:]}" + ), + "email": ( + user.get( + "email", f"{member_user_id}@example.com") + if user + else f"{member_user_id}@example.com" + ), + "avatar": ( + user.get("imageUrl") or user.get("avatar") + if user + else None + ), + } + if user + else { + "name": f"User {member_user_id[-4:]}", + "email": f"{member_user_id}@example.com", + "avatar": None, + } + ), } enriched_members.append(enriched_member) except Exception as e: # If user lookup fails, add member with basic info - enriched_members.append({ - "userId": member_user_id, - "role": member.get("role", "member"), - "joinedAt": member.get("joinedAt"), - "user": { - "name": f"User {member_user_id[-4:]}", - "email": f"{member_user_id}@example.com", - "avatar": None + enriched_members.append( + { + "userId": member_user_id, + "role": member.get("role", "member"), + "joinedAt": member.get("joinedAt"), + "user": { + "name": f"User {member_user_id[-4:]}", + "email": f"{member_user_id}@example.com", + "avatar": None, + }, } - }) + ) else: # Add member without user details if userId is missing enriched_members.append(member) - + return enriched_members def transform_group_document(self, group: dict) -> dict: @@ -73,7 +97,7 @@ def transform_group_document(self, group: dict) -> dict: group_id = str(group["_id"]) except Exception: return None - + return { "_id": group_id, "name": group.get("name"), @@ -82,13 +106,13 @@ def transform_group_document(self, group: dict) -> dict: "createdBy": group.get("createdBy"), "createdAt": group.get("createdAt"), "imageUrl": group.get("imageUrl"), - "members": group.get("members", []) + "members": group.get("members", []), } async def create_group(self, group_data: dict, user_id: str) -> dict: """Create a new group with the user as admin""" db = self.get_db() - + # Generate unique join code join_code = None for _ in range(10): # Try up to 10 times to generate unique code @@ -96,9 +120,11 @@ async def create_group(self, group_data: dict, user_id: str) -> dict: existing = await db.groups.find_one({"joinCode": join_code}) if not existing: break - + if not join_code: - raise HTTPException(status_code=500, detail="Failed to generate unique join code") + raise HTTPException( + status_code=500, detail="Failed to generate unique join code" + ) now = datetime.now(timezone.utc) group_doc = { @@ -108,13 +134,9 @@ async def create_group(self, group_data: dict, user_id: str) -> dict: "joinCode": join_code, "createdBy": user_id, "createdAt": now, - "members": [{ - "userId": user_id, - "role": "admin", - "joinedAt": now - }] + "members": [{"userId": user_id, "role": "admin", "joinedAt": now}], } - + result = await db.groups.insert_one(group_doc) created_group = await db.groups.find_one({"_id": result.inserted_id}) return self.transform_group_document(created_group) @@ -137,26 +159,27 @@ async def get_group_by_id(self, group_id: str, user_id: str) -> Optional[dict]: obj_id = ObjectId(group_id) except Exception: return None - - group = await db.groups.find_one({ - "_id": obj_id, - "members.userId": user_id - }) - + + group = await db.groups.find_one({"_id": obj_id, "members.userId": user_id}) + if not group: return None - + # Transform the basic group document transformed_group = self.transform_group_document(group) - + if transformed_group and transformed_group.get("members"): # Enrich member details with user information - enriched_members = await self._enrich_members_with_user_details(transformed_group["members"]) + enriched_members = await self._enrich_members_with_user_details( + transformed_group["members"] + ) transformed_group["members"] = enriched_members - + return transformed_group - async def update_group(self, group_id: str, updates: dict, user_id: str) -> Optional[dict]: + async def update_group( + self, group_id: str, updates: dict, user_id: str + ) -> Optional[dict]: """Update group metadata (admin only)""" db = self.get_db() try: @@ -165,17 +188,19 @@ async def update_group(self, group_id: str, updates: dict, user_id: str) -> Opti return None # Check if user is admin - group = await db.groups.find_one({ - "_id": obj_id, - "members": {"$elemMatch": {"userId": user_id, "role": "admin"}} - }) + group = await db.groups.find_one( + { + "_id": obj_id, + "members": {"$elemMatch": {"userId": user_id, "role": "admin"}}, + } + ) if not group: - raise HTTPException(status_code=403, detail="Only group admins can update group details") + raise HTTPException( + status_code=403, detail="Only group admins can update group details" + ) result = await db.groups.find_one_and_update( - {"_id": obj_id}, - {"$set": updates}, - return_document=True + {"_id": obj_id}, {"$set": updates}, return_document=True ) return self.transform_group_document(result) @@ -188,12 +213,16 @@ async def delete_group(self, group_id: str, user_id: str) -> bool: return False # Check if user is admin - group = await db.groups.find_one({ - "_id": obj_id, - "members": {"$elemMatch": {"userId": user_id, "role": "admin"}} - }) + group = await db.groups.find_one( + { + "_id": obj_id, + "members": {"$elemMatch": {"userId": user_id, "role": "admin"}}, + } + ) if not group: - raise HTTPException(status_code=403, detail="Only group admins can delete groups") + raise HTTPException( + status_code=403, detail="Only group admins can delete groups" + ) result = await db.groups.delete_one({"_id": obj_id}) return result.deleted_count == 1 @@ -201,28 +230,32 @@ async def delete_group(self, group_id: str, user_id: str) -> bool: async def join_group_by_code(self, join_code: str, user_id: str) -> Optional[dict]: """Join a group using join code""" db = self.get_db() - + # Find group by join code group = await db.groups.find_one({"joinCode": join_code.upper()}) if not group: raise HTTPException(status_code=404, detail="Invalid join code") # Check if user is already a member - existing_member = next((m for m in group.get("members", []) if m["userId"] == user_id), None) + existing_member = next( + (m for m in group.get("members", []) if m["userId"] == user_id), None + ) if existing_member: - raise HTTPException(status_code=400, detail="You are already a member of this group") + raise HTTPException( + status_code=400, detail="You are already a member of this group" + ) # Add user as member new_member = { "userId": user_id, "role": "member", - "joinedAt": datetime.now(timezone.utc) + "joinedAt": datetime.now(timezone.utc), } result = await db.groups.find_one_and_update( {"_id": group["_id"]}, {"$push": {"members": new_member}}, - return_document=True + return_document=True, ) return self.transform_group_document(result) @@ -235,21 +268,24 @@ async def leave_group(self, group_id: str, user_id: str) -> bool: return False # Check if user is a member - group = await db.groups.find_one({ - "_id": obj_id, - "members.userId": user_id - }) + group = await db.groups.find_one({"_id": obj_id, "members.userId": user_id}) if not group: - raise HTTPException(status_code=404, detail="Group not found or you are not a member") + raise HTTPException( + status_code=404, detail="Group not found or you are not a member" + ) # Check if user is the last admin - user_member = next((m for m in group.get("members", []) if m["userId"] == user_id), None) + user_member = next( + (m for m in group.get("members", []) if m["userId"] == user_id), None + ) if user_member and user_member["role"] == "admin": - admin_count = sum(1 for m in group.get("members", []) if m["role"] == "admin") + admin_count = sum( + 1 for m in group.get("members", []) if m["role"] == "admin" + ) if admin_count <= 1: raise HTTPException( - status_code=400, - detail="Cannot leave group when you are the only admin. Delete the group or promote another member to admin first." + status_code=400, + detail="Cannot leave group when you are the only admin. Delete the group or promote another member to admin first.", ) # TODO: Check for outstanding balances with expense service @@ -257,8 +293,7 @@ async def leave_group(self, group_id: str, user_id: str) -> bool: # This should be implemented when expense service is ready result = await db.groups.update_one( - {"_id": obj_id}, - {"$pull": {"members": {"userId": user_id}}} + {"_id": obj_id}, {"$pull": {"members": {"userId": user_id}}} ) return result.modified_count == 1 @@ -270,21 +305,20 @@ async def get_group_members(self, group_id: str, user_id: str) -> List[dict]: except Exception: return [] - group = await db.groups.find_one({ - "_id": obj_id, - "members.userId": user_id - }) + group = await db.groups.find_one({"_id": obj_id, "members.userId": user_id}) if not group: return [] members = group.get("members", []) - + # Fetch user details for each member enriched_members = await self._enrich_members_with_user_details(members) return enriched_members - async def update_member_role(self, group_id: str, member_id: str, new_role: str, user_id: str) -> bool: + async def update_member_role( + self, group_id: str, member_id: str, new_role: str, user_id: str + ) -> bool: """Update member role (admin only)""" db = self.get_db() try: @@ -293,30 +327,39 @@ async def update_member_role(self, group_id: str, member_id: str, new_role: str, return False # Check if user is admin - group = await db.groups.find_one({ - "_id": obj_id, - "members": {"$elemMatch": {"userId": user_id, "role": "admin"}} - }) + group = await db.groups.find_one( + { + "_id": obj_id, + "members": {"$elemMatch": {"userId": user_id, "role": "admin"}}, + } + ) if not group: - raise HTTPException(status_code=403, detail="Only group admins can update member roles") + raise HTTPException( + status_code=403, detail="Only group admins can update member roles" + ) # Check if target member exists - target_member = next((m for m in group.get("members", []) if m["userId"] == member_id), None) + target_member = next( + (m for m in group.get("members", []) if m["userId"] == member_id), None + ) if not target_member: - raise HTTPException(status_code=404, detail="Member not found in group") + raise HTTPException( + status_code=404, detail="Member not found in group") # Prevent admins from demoting themselves if they are the only admin if member_id == user_id and new_role != "admin": - admin_count = sum(1 for m in group.get("members", []) if m["role"] == "admin") + admin_count = sum( + 1 for m in group.get("members", []) if m["role"] == "admin" + ) if admin_count <= 1: raise HTTPException( - status_code=400, - detail="Cannot demote yourself when you are the only admin. Promote another member to admin first." + status_code=400, + detail="Cannot demote yourself when you are the only admin. Promote another member to admin first.", ) result = await db.groups.update_one( {"_id": obj_id, "members.userId": member_id}, - {"$set": {"members.$.role": new_role}} + {"$set": {"members.$.role": new_role}}, ) return result.modified_count == 1 @@ -329,33 +372,43 @@ async def remove_member(self, group_id: str, member_id: str, user_id: str) -> bo return False # Check if group exists and user is admin - group = await db.groups.find_one({ - "_id": obj_id, - "members": {"$elemMatch": {"userId": user_id, "role": "admin"}} - }) + group = await db.groups.find_one( + { + "_id": obj_id, + "members": {"$elemMatch": {"userId": user_id, "role": "admin"}}, + } + ) if not group: # Check if group exists at all group_exists = await db.groups.find_one({"_id": obj_id}) if not group_exists: raise HTTPException(status_code=404, detail="Group not found") else: - raise HTTPException(status_code=403, detail="Only group admins can remove members") + raise HTTPException( + status_code=403, detail="Only group admins can remove members" + ) # Check if target member exists and is not the requesting user - target_member = next((m for m in group.get("members", []) if m["userId"] == member_id), None) + target_member = next( + (m for m in group.get("members", []) if m["userId"] == member_id), None + ) if not target_member: - raise HTTPException(status_code=404, detail="Member not found in group") - + raise HTTPException( + status_code=404, detail="Member not found in group") + if member_id == user_id: - raise HTTPException(status_code=400, detail="Cannot remove yourself. Use leave group instead") + raise HTTPException( + status_code=400, + detail="Cannot remove yourself. Use leave group instead", + ) # TODO: Check for outstanding balances with expense service # For now, we'll allow removal without balance check result = await db.groups.update_one( - {"_id": obj_id}, - {"$pull": {"members": {"userId": member_id}}} + {"_id": obj_id}, {"$pull": {"members": {"userId": member_id}}} ) return result.modified_count == 1 + group_service = GroupService() diff --git a/backend/app/user/routes.py b/backend/app/user/routes.py index 8c4dca76..1c29e291 100644 --- a/backend/app/user/routes.py +++ b/backend/app/user/routes.py @@ -1,34 +1,58 @@ -from fastapi import APIRouter, Depends, HTTPException, status -from app.user.schemas import UserProfileResponse, UserProfileUpdateRequest, DeleteUserResponse -from app.user.service import user_service +from typing import Any, Dict + from app.auth.security import get_current_user -from typing import Dict, Any +from app.user.schemas import ( + DeleteUserResponse, + UserProfileResponse, + UserProfileUpdateRequest, +) +from app.user.service import user_service +from fastapi import APIRouter, Depends, HTTPException, status router = APIRouter(prefix="/users", tags=["User"]) + @router.get("/me", response_model=UserProfileResponse) -async def get_current_user_profile(current_user: Dict[str, Any] = Depends(get_current_user)): +async def get_current_user_profile( + current_user: Dict[str, Any] = Depends(get_current_user) +): user = await user_service.get_user_by_id(current_user["_id"]) if not user: - raise HTTPException(status_code=404, detail={"error": "NotFound", "message": "User not found"}) + raise HTTPException( + status_code=404, detail={"error": "NotFound", "message": "User not found"} + ) return user + @router.patch("/me", response_model=Dict[str, Any]) async def update_user_profile( updates: UserProfileUpdateRequest, - current_user: Dict[str, Any] = Depends(get_current_user) + current_user: Dict[str, Any] = Depends(get_current_user), ): update_data = updates.model_dump(exclude_unset=True) if not update_data: - raise HTTPException(status_code=400, detail={"error": "InvalidInput", "message": "No update fields provided."}) - updated_user = await user_service.update_user_profile(current_user["_id"], update_data) + raise HTTPException( + status_code=400, + detail={"error": "InvalidInput", + "message": "No update fields provided."}, + ) + updated_user = await user_service.update_user_profile( + current_user["_id"], update_data + ) if not updated_user: - raise HTTPException(status_code=404, detail={"error": "NotFound", "message": "User not found"}) + raise HTTPException( + status_code=404, detail={"error": "NotFound", "message": "User not found"} + ) return {"user": updated_user} + @router.delete("/me", response_model=DeleteUserResponse) async def delete_user_account(current_user: Dict[str, Any] = Depends(get_current_user)): deleted = await user_service.delete_user(current_user["_id"]) if not deleted: - raise HTTPException(status_code=404, detail={"error": "NotFound", "message": "User not found"}) - return DeleteUserResponse(success=True, message="User account scheduled for deletion.") + raise HTTPException( + status_code=404, detail={"error": "NotFound", "message": "User not found"} + ) + return DeleteUserResponse( + success=True, message="User account scheduled for deletion." + ) diff --git a/backend/app/user/schemas.py b/backend/app/user/schemas.py index 1c584089..985d2710 100644 --- a/backend/app/user/schemas.py +++ b/backend/app/user/schemas.py @@ -1,6 +1,8 @@ -from pydantic import BaseModel, EmailStr -from typing import Optional from datetime import datetime +from typing import Optional + +from pydantic import BaseModel, EmailStr + class UserProfileResponse(BaseModel): id: str @@ -11,11 +13,13 @@ class UserProfileResponse(BaseModel): createdAt: datetime updatedAt: datetime + class UserProfileUpdateRequest(BaseModel): name: Optional[str] = None imageUrl: Optional[str] = None currency: Optional[str] = None + class DeleteUserResponse(BaseModel): success: bool = True message: Optional[str] = None diff --git a/backend/app/user/service.py b/backend/app/user/service.py index 2b81fa4d..589a682a 100644 --- a/backend/app/user/service.py +++ b/backend/app/user/service.py @@ -1,8 +1,10 @@ -from fastapi import HTTPException, status, Depends +from datetime import datetime, timezone +from typing import Any, Dict, Optional + from app.database import get_database from bson import ObjectId -from datetime import datetime, timezone -from typing import Optional, Dict, Any +from fastapi import Depends, HTTPException, status + class UserService: def __init__(self): @@ -22,7 +24,11 @@ def iso(dt): return dt # Normalize to UTC and append 'Z' try: - dt_utc = dt.astimezone(timezone.utc) if getattr(dt, 'tzinfo', None) else dt.replace(tzinfo=timezone.utc) + dt_utc = ( + dt.astimezone(timezone.utc) + if getattr(dt, "tzinfo", None) + else dt.replace(tzinfo=timezone.utc) + ) return dt_utc.isoformat().replace("+00:00", "Z") except AttributeError: return str(dt) @@ -61,9 +67,7 @@ async def update_user_profile(self, user_id: str, updates: dict) -> Optional[dic updates = {k: v for k, v in updates.items() if k in allowed} updates["updated_at"] = datetime.now(timezone.utc) result = await db.users.find_one_and_update( - {"_id": obj_id}, - {"$set": updates}, - return_document=True + {"_id": obj_id}, {"$set": updates}, return_document=True ) return self.transform_user_document(result) @@ -76,4 +80,5 @@ async def delete_user(self, user_id: str) -> bool: result = await db.users.delete_one({"_id": obj_id}) return result.deleted_count > 0 + user_service = UserService() diff --git a/backend/generate_secret.py b/backend/generate_secret.py index 2623d347..96c1bf4f 100644 --- a/backend/generate_secret.py +++ b/backend/generate_secret.py @@ -1,20 +1,24 @@ import secrets import string + from app.config import logger + + def generate_jwt_secret(): """ Generates a cryptographically secure 64-character secret key for JWT authentication. - + The key consists of uppercase and lowercase letters, digits, and the special characters "!@#$%^&*". - + Returns: A randomly generated 64-character string suitable for use as a JWT secret key. """ # Generate a 64-character secret key alphabet = string.ascii_letters + string.digits + "!@#$%^&*" - secret_key = ''.join(secrets.choice(alphabet) for _ in range(64)) + secret_key = "".join(secrets.choice(alphabet) for _ in range(64)) return secret_key + if __name__ == "__main__": secret = generate_jwt_secret() logger.info("Generated JWT Secret Key:") diff --git a/backend/main.py b/backend/main.py index a4e2c212..3372ffb8 100644 --- a/backend/main.py +++ b/backend/main.py @@ -1,13 +1,16 @@ -from fastapi import FastAPI, HTTPException, Request -from fastapi.middleware.cors import CORSMiddleware -from fastapi.responses import Response from contextlib import asynccontextmanager -from app.database import connect_to_mongo, close_mongo_connection + from app.auth.routes import router as auth_router -from app.user.routes import router as user_router +from app.config import RequestResponseLoggingMiddleware, logger, settings +from app.database import close_mongo_connection, connect_to_mongo +from app.expenses.routes import balance_router +from app.expenses.routes import router as expenses_router from app.groups.routes import router as groups_router -from app.expenses.routes import router as expenses_router, balance_router -from app.config import settings, logger,RequestResponseLoggingMiddleware +from app.user.routes import router as user_router +from fastapi import FastAPI, HTTPException, Request +from fastapi.middleware.cors import CORSMiddleware +from fastapi.responses import Response + @asynccontextmanager async def lifespan(app: FastAPI): @@ -21,13 +24,14 @@ async def lifespan(app: FastAPI): await close_mongo_connection() logger.info("Lifespan: MongoDB connection closed.") + app = FastAPI( title="Splitwiser API", description="Backend API for Splitwiser expense tracking application", version="1.0.0", docs_url="/docs", redoc_url="/redoc", - lifespan=lifespan + lifespan=lifespan, ) # CORS middleware - Enhanced configuration for production @@ -38,7 +42,11 @@ async def lifespan(app: FastAPI): logger.debug("Development mode: CORS configured to allow all origins") elif settings.allowed_origins: # Use specified origins in production mode - allowed_origins = [origin.strip() for origin in settings.allowed_origins.split(",") if origin.strip()] + allowed_origins = [ + origin.strip() + for origin in settings.allowed_origins.split(",") + if origin.strip() + ] else: # Fallback to allow all origins if not specified (not recommended for production) allowed_origins = ["*"] @@ -55,54 +63,65 @@ async def lifespan(app: FastAPI): allow_headers=[ "Accept", "Accept-Language", - "Content-Language", + "Content-Language", "Content-Type", "Authorization", "X-Requested-With", "Origin", "Cache-Control", "Pragma", - "X-CSRFToken" + "X-CSRFToken", ], expose_headers=["*"], max_age=3600, # Cache preflight responses for 1 hour ) + # Add a catch-all OPTIONS handler that should work for any path @app.options("/{path:path}") async def options_handler(request: Request, path: str): """Handle all OPTIONS requests""" logger.info(f"OPTIONS request received for path: /{path}") logger.info(f"Origin: {request.headers.get('origin', 'No origin header')}") - + response = Response(status_code=200) - + # Manually set CORS headers for debugging origin = request.headers.get("origin") if origin and (origin in allowed_origins or "*" in allowed_origins): response.headers["Access-Control-Allow-Origin"] = origin - response.headers["Access-Control-Allow-Methods"] = "GET, POST, PUT, DELETE, OPTIONS, HEAD, PATCH" - response.headers["Access-Control-Allow-Headers"] = "Accept, Accept-Language, Content-Language, Content-Type, Authorization, X-Requested-With, Origin, Cache-Control, Pragma, X-CSRFToken" + response.headers["Access-Control-Allow-Methods"] = ( + "GET, POST, PUT, DELETE, OPTIONS, HEAD, PATCH" + ) + response.headers["Access-Control-Allow-Headers"] = ( + "Accept, Accept-Language, Content-Language, Content-Type, Authorization, X-Requested-With, Origin, Cache-Control, Pragma, X-CSRFToken" + ) response.headers["Access-Control-Allow-Credentials"] = "true" response.headers["Access-Control-Max-Age"] = "3600" elif "*" in allowed_origins: response.headers["Access-Control-Allow-Origin"] = "*" - response.headers["Access-Control-Allow-Methods"] = "GET, POST, PUT, DELETE, OPTIONS, HEAD, PATCH" - response.headers["Access-Control-Allow-Headers"] = "Accept, Accept-Language, Content-Language, Content-Type, Authorization, X-Requested-With, Origin, Cache-Control, Pragma, X-CSRFToken" + response.headers["Access-Control-Allow-Methods"] = ( + "GET, POST, PUT, DELETE, OPTIONS, HEAD, PATCH" + ) + response.headers["Access-Control-Allow-Headers"] = ( + "Accept, Accept-Language, Content-Language, Content-Type, Authorization, X-Requested-With, Origin, Cache-Control, Pragma, X-CSRFToken" + ) response.headers["Access-Control-Max-Age"] = "3600" - + return response + # Health check @app.get("/health") async def health_check(): """ Returns the health status of the Splitwiser API service. - + This endpoint can be used for health checks and monitoring. """ return {"status": "healthy", "service": "Splitwiser API"} + # Include routers app.include_router(auth_router) app.include_router(user_router) @@ -112,9 +131,5 @@ async def health_check(): if __name__ == "__main__": import uvicorn - uvicorn.run( - "main:app", - host="0.0.0.0", - port=8000, - reload=settings.debug - ) + + uvicorn.run("main:app", host="0.0.0.0", port=8000, reload=settings.debug) diff --git a/backend/tests/auth/test_auth_routes.py b/backend/tests/auth/test_auth_routes.py index 72aa417e..0610a0fe 100644 --- a/backend/tests/auth/test_auth_routes.py +++ b/backend/tests/auth/test_auth_routes.py @@ -1,11 +1,17 @@ -import pytest -from httpx import AsyncClient, ASGITransport -from fastapi import FastAPI, status -from main import app # Assuming your FastAPI app instance is here -from app.config import settings # To potentially override settings if needed, or check values -from app.auth.security import verify_password, get_password_hash # For checking hashed password if necessary from datetime import datetime, timezone + +import pytest +from app.auth.security import ( # For checking hashed password if necessary + get_password_hash, + verify_password, +) +from app.config import ( # To potentially override settings if needed, or check values + settings, +) from bson import ObjectId +from fastapi import FastAPI, status +from httpx import ASGITransport, AsyncClient +from main import app # Assuming your FastAPI app instance is here # It's good practice to set a specific test secret key if not relying on external env vars # For now, we assume 'your-super-secret-jwt-key-change-this-in-production' from config.py is used, @@ -15,24 +21,32 @@ # Helper to get the mock_db if direct interaction is needed (though often not preferred) # from app.database import get_database + @pytest.mark.asyncio -async def test_signup_with_email_success(mock_db): # mock_db fixture is auto-used - async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as ac: +# mock_db fixture is auto-used +async def test_signup_with_email_success(mock_db): + async with AsyncClient( + transport=ASGITransport(app=app), base_url="http://test" + ) as ac: signup_data = { "email": "testuser@example.com", "password": "securepassword123", - "name": "Test User" + "name": "Test User", } response = await ac.post("/auth/signup/email", json=signup_data) - print(f"Response text for test_signup_with_email_success: {response.text}") # Print response text - assert response.status_code == status.HTTP_200_OK # Or the actual success code used by the app + print( + f"Response text for test_signup_with_email_success: {response.text}" + ) # Print response text + assert ( + response.status_code == status.HTTP_200_OK + ) # Or the actual success code used by the app response_data = response.json() assert "access_token" in response_data assert "refresh_token" in response_data assert "user" in response_data assert response_data["user"]["email"] == signup_data["email"] assert response_data["user"]["name"] == signup_data["name"] - assert "_id" in response_data["user"] # Changed 'id' to '_id' + assert "_id" in response_data["user"] # Changed 'id' to '_id' # Verify user creation in the mock database # db = get_database() # This will be the mock_db instance due to the fixture @@ -40,30 +54,38 @@ async def test_signup_with_email_success(mock_db): # mock_db fixture is auto-use created_user = await mock_db.users.find_one({"email": signup_data["email"]}) assert created_user is not None assert created_user["name"] == signup_data["name"] - assert verify_password(signup_data["password"], created_user["hashed_password"]) + assert verify_password( + signup_data["password"], created_user["hashed_password"]) # Verify refresh token creation - refresh_token_record = await mock_db.refresh_tokens.find_one({"user_id": created_user["_id"]}) + refresh_token_record = await mock_db.refresh_tokens.find_one( + {"user_id": created_user["_id"]} + ) assert refresh_token_record is not None assert not refresh_token_record["revoked"] assert response_data["refresh_token"] == refresh_token_record["token"] + @pytest.mark.asyncio async def test_signup_with_existing_email(mock_db): # Pre-populate with a user existing_email = "existing@example.com" - await mock_db.users.insert_one({ - "email": existing_email, - "hashed_password": "hashedpassword", - "name": "Existing User", - "created_at": "sometime" # Simplified for mock - }) - - async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as ac: + await mock_db.users.insert_one( + { + "email": existing_email, + "hashed_password": "hashedpassword", + "name": "Existing User", + "created_at": "sometime", # Simplified for mock + } + ) + + async with AsyncClient( + transport=ASGITransport(app=app), base_url="http://test" + ) as ac: signup_data = { "email": existing_email, "password": "newpassword123", - "name": "New User" + "name": "New User", } response = await ac.post("/auth/signup/email", json=signup_data) @@ -72,6 +94,7 @@ async def test_signup_with_existing_email(mock_db): assert "detail" in response_data assert "User with this email already exists" in response_data["detail"] + @pytest.mark.asyncio @pytest.mark.parametrize( "payload_modifier, affected_field, description", @@ -79,19 +102,26 @@ async def test_signup_with_existing_email(mock_db): (lambda p: p.pop("email"), "email", "missing_email"), (lambda p: p.pop("password"), "password", "missing_password"), (lambda p: p.pop("name"), "name", "missing_name"), - (lambda p: p.update({"password": "short"}), "password", "short_password"), - (lambda p: p.update({"email": "invalidemail"}), "email", "invalid_email"), - ] + (lambda p: p.update({"password": "short"}), + "password", "short_password"), + (lambda p: p.update({"email": "invalidemail"}), + "email", "invalid_email"), + ], ) -async def test_signup_invalid_input_refined(mock_db, payload_modifier, affected_field, description): +async def test_signup_invalid_input_refined( + mock_db, payload_modifier, affected_field, description +): base_payload = { "email": "testuser@example.com", "password": "securepassword123", - "name": "Test User" + "name": "Test User", } - payload_modifier(base_payload) # Modify the payload based on the current test case + # Modify the payload based on the current test case + payload_modifier(base_payload) - async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as ac: + async with AsyncClient( + transport=ASGITransport(app=app), base_url="http://test" + ) as ac: response = await ac.post("/auth/signup/email", json=base_payload) assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY @@ -107,13 +137,17 @@ async def test_signup_invalid_input_refined(mock_db, payload_modifier, affected_ if description == "short_password" and error_type == "string_too_short": error_found = True break - elif description == "invalid_email" and error_type == "value_error": # Simpler check, msg gives more detail + elif ( + description == "invalid_email" and error_type == "value_error" + ): # Simpler check, msg gives more detail error_found = True break elif "missing" in description and error_type == "missing": error_found = True break - assert error_found, f"Validation error for '{description}' (field: {affected_field}) not found or not specific enough in {response_data['detail']}" + assert ( + error_found + ), f"Validation error for '{description}' (field: {affected_field}) not found or not specific enough in {response_data['detail']}" @pytest.mark.asyncio @@ -128,23 +162,25 @@ async def test_login_with_email_success(mock_db): # For consistency with how AuthService creates user_id for refresh tokens (ObjectId(user_id)), # let's store _id as ObjectId here. user_obj_id = ObjectId() - await mock_db.users.insert_one({ - "_id": user_obj_id, - "email": user_email, - "hashed_password": hashed_password, - "name": "Login User", - "avatar": None, - "currency": "USD", - "created_at": datetime.now(timezone.utc), # Ensure datetime is used - "auth_provider": "email", - "firebase_uid": None - }) - - async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as ac: - login_data = { + await mock_db.users.insert_one( + { + "_id": user_obj_id, "email": user_email, - "password": user_password + "hashed_password": hashed_password, + "name": "Login User", + "avatar": None, + "currency": "USD", + # Ensure datetime is used + "created_at": datetime.now(timezone.utc), + "auth_provider": "email", + "firebase_uid": None, } + ) + + async with AsyncClient( + transport=ASGITransport(app=app), base_url="http://test" + ) as ac: + login_data = {"email": user_email, "password": user_password} response = await ac.post("/auth/login/email", json=login_data) assert response.status_code == status.HTTP_200_OK @@ -153,34 +189,39 @@ async def test_login_with_email_success(mock_db): assert "refresh_token" in response_data assert "user" in response_data assert response_data["user"]["email"] == user_email - assert response_data["user"]["_id"] == str(user_obj_id) # Changed 'id' to '_id' + assert response_data["user"]["_id"] == str( + user_obj_id) # Changed 'id' to '_id' # Verify refresh token creation for this user # Refresh token service stores user_id as ObjectId - refresh_token_record = await mock_db.refresh_tokens.find_one({"user_id": user_obj_id}) + refresh_token_record = await mock_db.refresh_tokens.find_one( + {"user_id": user_obj_id} + ) assert refresh_token_record is not None assert not refresh_token_record["revoked"] assert response_data["refresh_token"] == refresh_token_record["token"] + @pytest.mark.asyncio async def test_login_with_incorrect_password(mock_db): user_email = "wrongpass@example.com" correct_password = "correctpassword" incorrect_password = "incorrectpassword" - await mock_db.users.insert_one({ - "_id": ObjectId(), - "email": user_email, - "hashed_password": get_password_hash(correct_password), - "name": "Wrong Pass User", - "created_at": datetime.now(timezone.utc) - }) - - async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as ac: - login_data = { + await mock_db.users.insert_one( + { + "_id": ObjectId(), "email": user_email, - "password": incorrect_password + "hashed_password": get_password_hash(correct_password), + "name": "Wrong Pass User", + "created_at": datetime.now(timezone.utc), } + ) + + async with AsyncClient( + transport=ASGITransport(app=app), base_url="http://test" + ) as ac: + login_data = {"email": user_email, "password": incorrect_password} response = await ac.post("/auth/login/email", json=login_data) assert response.status_code == status.HTTP_401_UNAUTHORIZED @@ -188,19 +229,23 @@ async def test_login_with_incorrect_password(mock_db): assert "detail" in response_data assert "Incorrect email or password" in response_data["detail"] + @pytest.mark.asyncio async def test_login_with_non_existent_email(mock_db): - async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as ac: - login_data = { - "email": "nosuchuser@example.com", - "password": "anypassword" - } + async with AsyncClient( + transport=ASGITransport(app=app), base_url="http://test" + ) as ac: + login_data = {"email": "nosuchuser@example.com", + "password": "anypassword"} response = await ac.post("/auth/login/email", json=login_data) assert response.status_code == status.HTTP_401_UNAUTHORIZED response_data = response.json() assert "detail" in response_data - assert "Incorrect email or password" in response_data["detail"] # Same message for both cases + assert ( + "Incorrect email or password" in response_data["detail"] + ) # Same message for both cases + @pytest.mark.asyncio @pytest.mark.parametrize( @@ -208,19 +253,25 @@ async def test_login_with_non_existent_email(mock_db): [ (lambda p: p.pop("email"), "email", "missing_email"), (lambda p: p.pop("password"), "password", "missing_password"), - (lambda p: p.update({"email": "invalidemailformat"}), "email", "invalid_email_format"), - ] + ( + lambda p: p.update({"email": "invalidemailformat"}), + "email", + "invalid_email_format", + ), + ], ) -async def test_login_invalid_input(mock_db, payload_modifier, affected_field, description): - base_payload = { - "email": "validuser@example.com", - "password": "validpassword123" - } +async def test_login_invalid_input( + mock_db, payload_modifier, affected_field, description +): + base_payload = {"email": "validuser@example.com", + "password": "validpassword123"} # It doesn't matter if the user exists or not for input validation, # as validation happens before DB lookup for these kinds of errors. payload_modifier(base_payload) - async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as ac: + async with AsyncClient( + transport=ASGITransport(app=app), base_url="http://test" + ) as ac: response = await ac.post("/auth/login/email", json=base_payload) assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY @@ -231,10 +282,14 @@ async def test_login_invalid_input(mock_db, payload_modifier, affected_field, de for error_item in response_data["detail"]: if affected_field in error_item.get("loc", []): error_type = error_item.get("type", "") - if description == "invalid_email_format" and error_type == "value_error": # Simpler check + if ( + description == "invalid_email_format" and error_type == "value_error" + ): # Simpler check error_found = True break elif "missing" in description and error_type == "missing": error_found = True break - assert error_found, f"Validation error for '{description}' (field: {affected_field}) not found in {response_data['detail']}" + assert ( + error_found + ), f"Validation error for '{description}' (field: {affected_field}) not found in {response_data['detail']}" diff --git a/backend/tests/conftest.py b/backend/tests/conftest.py index 4e95af99..caf6c346 100644 --- a/backend/tests/conftest.py +++ b/backend/tests/conftest.py @@ -1,17 +1,18 @@ +import os # Added +import sys # Added +from pathlib import Path # Added +from unittest.mock import MagicMock, patch + +import firebase_admin # Added import pytest import pytest_asyncio -from unittest.mock import patch, MagicMock -import firebase_admin # Added -import os # Added -import sys # Added -from pathlib import Path # Added +from mongomock_motor import AsyncMongoMockClient # Add project root to sys.path to allow imports from app and main # This assumes conftest.py is in backend/tests/ project_root = Path(__file__).resolve().parent.parent sys.path.insert(0, str(project_root)) -from mongomock_motor import AsyncMongoMockClient @pytest.fixture(scope="session", autouse=True) def mock_firebase_admin(request): @@ -34,13 +35,15 @@ def mock_firebase_admin(request): "uid": "test_firebase_uid", "email": "firebaseuser@example.com", "name": "Firebase User", - "picture": None - } # Dummy decoded token + "picture": None, + } # Dummy decoded token patches = [ patch("firebase_admin.credentials.Certificate", mock_certificate), patch("firebase_admin.initialize_app", mock_initialize_app), - patch("firebase_admin.auth.verify_id_token", mock_firebase_auth.verify_id_token) # Mock specific function + patch( + "firebase_admin.auth.verify_id_token", mock_firebase_auth.verify_id_token + ), # Mock specific function ] for p in patches: @@ -57,19 +60,26 @@ def mock_firebase_admin(request): # If not using the os.environ patch, just yield: yield + @pytest_asyncio.fixture(scope="function", autouse=True) async def mock_db(): print("mock_db fixture: Creating AsyncMongoMockClient") mock_mongo_client = AsyncMongoMockClient() - print(f"mock_db fixture: mock_mongo_client type: {type(mock_mongo_client)}") + print( + f"mock_db fixture: mock_mongo_client type: {type(mock_mongo_client)}") mock_database_instance = mock_mongo_client["test_db"] - print(f"mock_db fixture: mock_database_instance type: {type(mock_database_instance)}, is None: {mock_database_instance is None}") + print( + f"mock_db fixture: mock_database_instance type: {type(mock_database_instance)}, is None: {mock_database_instance is None}" + ) # Patch get_database for all services that use it patches = [ - patch("app.auth.service.get_database", return_value=mock_database_instance), - patch("app.user.service.get_database", return_value=mock_database_instance), - patch("app.groups.service.get_database", return_value=mock_database_instance), + patch("app.auth.service.get_database", + return_value=mock_database_instance), + patch("app.user.service.get_database", + return_value=mock_database_instance), + patch("app.groups.service.get_database", + return_value=mock_database_instance), ] # Start all patches diff --git a/backend/tests/expenses/test_expense_routes.py b/backend/tests/expenses/test_expense_routes.py index 67610eae..f55ee2fd 100644 --- a/backend/tests/expenses/test_expense_routes.py +++ b/backend/tests/expenses/test_expense_routes.py @@ -1,19 +1,25 @@ -import pytest -from httpx import AsyncClient, ASGITransport -from fastapi import status from unittest.mock import AsyncMock, patch -from main import app # Adjusted import + +import pytest from app.expenses.schemas import ExpenseCreateRequest, ExpenseSplit +from fastapi import status +from httpx import ASGITransport, AsyncClient +from main import app # Adjusted import + @pytest.fixture async def async_client(): - async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as ac: + async with AsyncClient( + transport=ASGITransport(app=app), base_url="http://test" + ) as ac: yield ac + @pytest.fixture def mock_current_user(): return {"_id": "test_user_123", "email": "test@example.com"} + @pytest.fixture def sample_expense_data(): return { @@ -21,19 +27,26 @@ def sample_expense_data(): "amount": 100.0, "splits": [ {"userId": "user_a", "amount": 50.0, "type": "equal"}, - {"userId": "user_b", "amount": 50.0, "type": "equal"} + {"userId": "user_b", "amount": 50.0, "type": "equal"}, ], "splitType": "equal", "tags": ["dinner", "test"], - "receiptUrls": [] + "receiptUrls": [], } + @pytest.mark.asyncio @patch("app.expenses.routes.get_current_user") @patch("app.expenses.service.expense_service.create_expense") -async def test_create_expense_endpoint(mock_create_expense, mock_get_current_user, sample_expense_data, mock_current_user, async_client: AsyncClient): +async def test_create_expense_endpoint( + mock_create_expense, + mock_get_current_user, + sample_expense_data, + mock_current_user, + async_client: AsyncClient, +): """Test create expense endpoint""" - + mock_get_current_user.return_value = mock_current_user mock_create_expense.return_value = { "expense": { @@ -49,32 +62,42 @@ async def test_create_expense_endpoint(mock_create_expense, mock_get_current_use "receiptUrls": [], "comments": [], "history": [], - "splitType": "equal" + "splitType": "equal", }, "settlements": [], "groupSummary": { "totalExpenses": 100.0, "totalSettlements": 2, - "optimizedSettlements": [] - } + "optimizedSettlements": [], + }, } - + response = await async_client.post( "/groups/group_123/expenses", json=sample_expense_data, - headers={"Authorization": "Bearer test_token"} + headers={"Authorization": "Bearer test_token"}, ) - + # This test would need proper authentication mocking to work # For now, it demonstrates the structure - assert response.status_code in [status.HTTP_201_CREATED, status.HTTP_401_UNAUTHORIZED, status.HTTP_422_UNPROCESSABLE_ENTITY] # Depending on auth setup + assert response.status_code in [ + status.HTTP_201_CREATED, + status.HTTP_401_UNAUTHORIZED, + status.HTTP_422_UNPROCESSABLE_ENTITY, + ] # Depending on auth setup + @pytest.mark.asyncio @patch("app.expenses.routes.get_current_user") @patch("app.expenses.service.expense_service.list_group_expenses") -async def test_list_expenses_endpoint(mock_list_expenses, mock_get_current_user, mock_current_user, async_client: AsyncClient): +async def test_list_expenses_endpoint( + mock_list_expenses, + mock_get_current_user, + mock_current_user, + async_client: AsyncClient, +): """Test list expenses endpoint""" - + mock_get_current_user.return_value = mock_current_user mock_list_expenses.return_value = { "expenses": [], @@ -84,29 +107,31 @@ async def test_list_expenses_endpoint(mock_list_expenses, mock_get_current_user, "total": 0, "totalPages": 0, "hasNext": False, - "hasPrev": False + "hasPrev": False, }, - "summary": { - "totalAmount": 0, - "expenseCount": 0, - "avgExpense": 0 - } + "summary": {"totalAmount": 0, "expenseCount": 0, "avgExpense": 0}, } - + response = await async_client.get( - "/groups/group_123/expenses", - headers={"Authorization": "Bearer test_token"} + "/groups/group_123/expenses", headers={"Authorization": "Bearer test_token"} ) - + # This test would need proper authentication mocking to work - assert response.status_code in [status.HTTP_200_OK, status.HTTP_401_UNAUTHORIZED] + assert response.status_code in [ + status.HTTP_200_OK, status.HTTP_401_UNAUTHORIZED] + @pytest.mark.asyncio @patch("app.expenses.routes.get_current_user") @patch("app.expenses.service.expense_service.calculate_optimized_settlements") -async def test_optimized_settlements_endpoint(mock_calculate_settlements, mock_get_current_user, mock_current_user, async_client: AsyncClient): +async def test_optimized_settlements_endpoint( + mock_calculate_settlements, + mock_get_current_user, + mock_current_user, + async_client: AsyncClient, +): """Test optimized settlements calculation endpoint""" - + mock_get_current_user.return_value = mock_current_user mock_calculate_settlements.return_value = [ { @@ -115,41 +140,47 @@ async def test_optimized_settlements_endpoint(mock_calculate_settlements, mock_g "fromUserName": "Alice", "toUserName": "Bob", "amount": 25.0, - "consolidatedExpenses": ["expense_1", "expense_2"] + "consolidatedExpenses": ["expense_1", "expense_2"], } ] - + response = await async_client.post( "/groups/group_123/settlements/optimize", - headers={"Authorization": "Bearer test_token"} + headers={"Authorization": "Bearer test_token"}, ) - + # This test would need proper authentication mocking to work - assert response.status_code in [status.HTTP_200_OK, status.HTTP_401_UNAUTHORIZED] + assert response.status_code in [ + status.HTTP_200_OK, status.HTTP_401_UNAUTHORIZED] + @pytest.mark.asyncio async def test_expense_validation(async_client: AsyncClient): """Test expense data validation""" - + # Invalid expense - splits don't sum to total invalid_data = { "description": "Test expense", "amount": 100.0, "splits": [ {"userId": "user_a", "amount": 40.0, "type": "equal"}, - {"userId": "user_b", "amount": 50.0, "type": "equal"} # Only 90 total + {"userId": "user_b", "amount": 50.0, "type": "equal"}, # Only 90 total ], - "splitType": "equal" + "splitType": "equal", } - + response = await async_client.post( "/groups/group_123/expenses", json=invalid_data, - headers={"Authorization": "Bearer test_token"} + headers={"Authorization": "Bearer test_token"}, ) - + # Should return validation error - assert response.status_code in [status.HTTP_422_UNPROCESSABLE_ENTITY, status.HTTP_401_UNAUTHORIZED] # 422 for validation error, 401 if auth fails first + assert response.status_code in [ + status.HTTP_422_UNPROCESSABLE_ENTITY, + status.HTTP_401_UNAUTHORIZED, + ] # 422 for validation error, 401 if auth fails first + if __name__ == "__main__": pytest.main([__file__]) diff --git a/backend/tests/expenses/test_expense_service.py b/backend/tests/expenses/test_expense_service.py index dc0733ce..a51cfdbf 100644 --- a/backend/tests/expenses/test_expense_service.py +++ b/backend/tests/expenses/test_expense_service.py @@ -1,10 +1,12 @@ -import pytest +import asyncio +from datetime import datetime, timedelta, timezone from unittest.mock import AsyncMock, MagicMock, patch -from app.expenses.service import ExpenseService + +import pytest from app.expenses.schemas import ExpenseCreateRequest, ExpenseSplit, SplitType +from app.expenses.service import ExpenseService from bson import ObjectId -from datetime import datetime, timezone, timedelta -import asyncio + @pytest.fixture def expense_service(): @@ -12,6 +14,7 @@ def expense_service(): service = ExpenseService() return service + @pytest.fixture def mock_group_data(): """Mock group data for testing""" @@ -21,10 +24,11 @@ def mock_group_data(): "members": [ {"userId": "user_a", "role": "admin"}, {"userId": "user_b", "role": "member"}, - {"userId": "user_c", "role": "member"} - ] + {"userId": "user_c", "role": "member"}, + ], } + @pytest.fixture def mock_expense_data(): """Mock expense data for testing""" @@ -36,7 +40,7 @@ def mock_expense_data(): "amount": 100.0, "splits": [ {"userId": "user_a", "amount": 50.0, "type": "equal"}, - {"userId": "user_b", "amount": 50.0, "type": "equal"} + {"userId": "user_b", "amount": 50.0, "type": "equal"}, ], "splitType": "equal", "tags": ["dinner"], @@ -44,9 +48,10 @@ def mock_expense_data(): "comments": [], "history": [], "createdAt": datetime.now(timezone.utc), - "updatedAt": datetime.now(timezone.utc) + "updatedAt": datetime.now(timezone.utc), } + @pytest.mark.asyncio async def test_create_expense_success(expense_service, mock_group_data): """Test successful expense creation""" @@ -55,32 +60,43 @@ async def test_create_expense_success(expense_service, mock_group_data): amount=100.0, splits=[ ExpenseSplit(userId="user_a", amount=50.0), - ExpenseSplit(userId="user_b", amount=50.0) + ExpenseSplit(userId="user_b", amount=50.0), ], splitType=SplitType.EQUAL, - tags=["dinner"] + tags=["dinner"], ) - - with patch('app.expenses.service.mongodb') as mock_mongodb, \ - patch.object(expense_service, '_create_settlements_for_expense') as mock_settlements, \ - patch.object(expense_service, 'calculate_optimized_settlements') as mock_optimized, \ - patch.object(expense_service, '_get_group_summary') as mock_summary, \ - patch.object(expense_service, '_expense_doc_to_response') as mock_response: - + + with patch("app.expenses.service.mongodb") as mock_mongodb, patch.object( + expense_service, "_create_settlements_for_expense" + ) as mock_settlements, patch.object( + expense_service, "calculate_optimized_settlements" + ) as mock_optimized, patch.object( + expense_service, "_get_group_summary" + ) as mock_summary, patch.object( + expense_service, "_expense_doc_to_response" + ) as mock_response: + # Mock database collections mock_db = MagicMock() mock_mongodb.database = mock_db - + mock_db.groups.find_one = AsyncMock(return_value=mock_group_data) mock_db.expenses.insert_one = AsyncMock() - + mock_settlements.return_value = [] mock_optimized.return_value = [] - mock_summary.return_value = {"totalExpenses": 100.0, "totalSettlements": 1, "optimizedSettlements": []} - mock_response.return_value = {"id": "test_id", "description": "Test Dinner"} - - result = await expense_service.create_expense("65f1a2b3c4d5e6f7a8b9c0d0", expense_request, "user_a") - + mock_summary.return_value = { + "totalExpenses": 100.0, + "totalSettlements": 1, + "optimizedSettlements": [], + } + mock_response.return_value = { + "id": "test_id", "description": "Test Dinner"} + + result = await expense_service.create_expense( + "65f1a2b3c4d5e6f7a8b9c0d0", expense_request, "user_a" + ) + # Assertions assert result is not None assert "expense" in result @@ -89,6 +105,7 @@ async def test_create_expense_success(expense_service, mock_group_data): mock_db.groups.find_one.assert_called_once() mock_db.expenses.insert_one.assert_called_once() + @pytest.mark.asyncio async def test_create_expense_invalid_group(expense_service): """Test expense creation with invalid group""" @@ -97,30 +114,35 @@ async def test_create_expense_invalid_group(expense_service): amount=100.0, splits=[ExpenseSplit(userId="user_a", amount=100.0)], ) - - with patch('app.expenses.service.mongodb') as mock_mongodb: + + with patch("app.expenses.service.mongodb") as mock_mongodb: mock_db = MagicMock() mock_mongodb.database = mock_db mock_db.groups.find_one = AsyncMock(return_value=None) - + # Test with invalid ObjectId format with pytest.raises(ValueError, match="Group not found or user not a member"): - await expense_service.create_expense("invalid_group", expense_request, "user_a") - + await expense_service.create_expense( + "invalid_group", expense_request, "user_a" + ) + # Test with valid ObjectId format but non-existent group with pytest.raises(ValueError, match="Group not found or user not a member"): - await expense_service.create_expense("65f1a2b3c4d5e6f7a8b9c0d0", expense_request, "user_a") + await expense_service.create_expense( + "65f1a2b3c4d5e6f7a8b9c0d0", expense_request, "user_a" + ) + @pytest.mark.asyncio async def test_calculate_optimized_settlements_advanced(expense_service): """Test advanced settlement algorithm with real optimization logic""" group_id = "test_group_123" - + # Create proper ObjectIds for users user_a_id = ObjectId() user_b_id = ObjectId() user_c_id = ObjectId() - + # Mock settlements representing: B owes A $100, C owes B $100 # Expected optimization: C should pay A $100 directly (instead of C->B and B->A) mock_settlements = [ @@ -128,49 +150,51 @@ async def test_calculate_optimized_settlements_advanced(expense_service): "_id": ObjectId(), "groupId": group_id, "payerId": str(user_b_id), - "payeeId": str(user_a_id), + "payeeId": str(user_a_id), "amount": 100.0, "status": "pending", "payerName": "Bob", - "payeeName": "Alice" + "payeeName": "Alice", }, { "_id": ObjectId(), "groupId": group_id, "payerId": str(user_c_id), "payeeId": str(user_b_id), - "amount": 100.0, + "amount": 100.0, "status": "pending", "payerName": "Charlie", - "payeeName": "Bob" - } + "payeeName": "Bob", + }, ] - + # Mock user data mock_users = { str(user_a_id): {"_id": user_a_id, "name": "Alice"}, - str(user_b_id): {"_id": user_b_id, "name": "Bob"}, - str(user_c_id): {"_id": user_c_id, "name": "Charlie"} + str(user_b_id): {"_id": user_b_id, "name": "Bob"}, + str(user_c_id): {"_id": user_c_id, "name": "Charlie"}, } - - with patch('app.expenses.service.mongodb') as mock_mongodb: + + with patch("app.expenses.service.mongodb") as mock_mongodb: mock_db = MagicMock() mock_mongodb.database = mock_db - + # Setup async iterator for settlements mock_cursor = AsyncMock() mock_cursor.to_list.return_value = mock_settlements mock_db.settlements.find.return_value = mock_cursor - + # Setup user lookups async def mock_user_find_one(query): user_id = str(query["_id"]) return mock_users.get(user_id) - + mock_db.users.find_one = AsyncMock(side_effect=mock_user_find_one) - - result = await expense_service.calculate_optimized_settlements(group_id, "advanced") - + + result = await expense_service.calculate_optimized_settlements( + group_id, "advanced" + ) + # Verify optimization: should result in 1 transaction instead of 2 assert len(result) == 1 # The optimized result should be Alice paying Charlie $100 @@ -182,15 +206,16 @@ async def mock_user_find_one(query): assert settlement.fromUserId == str(user_a_id) assert settlement.toUserId == str(user_c_id) -@pytest.mark.asyncio + +@pytest.mark.asyncio async def test_calculate_optimized_settlements_normal(expense_service): """Test normal settlement algorithm - only simplifies direct relationships""" group_id = "test_group_123" - + # Create proper ObjectIds for users user_a_id = ObjectId() user_b_id = ObjectId() - + # Mock settlements: A owes B $100, B owes A $30 mock_settlements = [ { @@ -201,237 +226,269 @@ async def test_calculate_optimized_settlements_normal(expense_service): "amount": 100.0, "status": "pending", "payerName": "Bob", - "payeeName": "Alice" + "payeeName": "Alice", }, { - "_id": ObjectId(), + "_id": ObjectId(), "groupId": group_id, "payerId": str(user_a_id), "payeeId": str(user_b_id), "amount": 30.0, "status": "pending", "payerName": "Alice", - "payeeName": "Bob" - } + "payeeName": "Bob", + }, ] - + mock_users = { str(user_a_id): {"_id": user_a_id, "name": "Alice"}, - str(user_b_id): {"_id": user_b_id, "name": "Bob"} + str(user_b_id): {"_id": user_b_id, "name": "Bob"}, } - - with patch('app.expenses.service.mongodb') as mock_mongodb: + + with patch("app.expenses.service.mongodb") as mock_mongodb: mock_db = MagicMock() mock_mongodb.database = mock_db - + mock_cursor = AsyncMock() - mock_cursor.to_list.return_value = mock_settlements + mock_cursor.to_list.return_value = mock_settlements mock_db.settlements.find.return_value = mock_cursor - + async def mock_user_find_one(query): user_id = str(query["_id"]) return mock_users.get(user_id) - + mock_db.users.find_one = AsyncMock(side_effect=mock_user_find_one) - - result = await expense_service.calculate_optimized_settlements(group_id, "normal") - + + result = await expense_service.calculate_optimized_settlements( + group_id, "normal" + ) + # Should result in optimized settlements. The normal algorithm may produce duplicates # but should calculate the correct net amount assert len(result) >= 1 - + # Find the settlement where Bob pays Alice - bob_to_alice_settlements = [s for s in result if s.fromUserName == "Bob" and s.toUserName == "Alice"] + bob_to_alice_settlements = [ + s for s in result if s.fromUserName == "Bob" and s.toUserName == "Alice" + ] assert len(bob_to_alice_settlements) >= 1 - + # Verify the amount is correct (100 - 30 = 70) settlement = bob_to_alice_settlements[0] assert settlement.amount == 70.0 assert settlement.fromUserId == str(user_b_id) assert settlement.toUserId == str(user_a_id) + @pytest.mark.asyncio async def test_update_expense_success(expense_service, mock_expense_data): """Test successful expense update""" from app.expenses.schemas import ExpenseUpdateRequest - + update_request = ExpenseUpdateRequest( - description="Updated Dinner", - amount=120.0 - ) - + description="Updated Dinner", amount=120.0) + updated_expense_data = mock_expense_data.copy() updated_expense_data["description"] = "Updated Dinner" updated_expense_data["amount"] = 120.0 - - with patch('app.expenses.service.mongodb') as mock_mongodb: + + with patch("app.expenses.service.mongodb") as mock_mongodb: mock_db = MagicMock() mock_mongodb.database = mock_db - + # Mock finding the expense - mock_db.expenses.find_one = AsyncMock(side_effect=[mock_expense_data, updated_expense_data]) - + mock_db.expenses.find_one = AsyncMock( + side_effect=[mock_expense_data, updated_expense_data] + ) + # Mock user lookup - mock_db.users.find_one = AsyncMock(return_value={"_id": ObjectId("65f1a2b3c4d5e6f7a8b9c0d2"), "name": "Alice"}) - - # Mock update operation + mock_db.users.find_one = AsyncMock( + return_value={"_id": ObjectId( + "65f1a2b3c4d5e6f7a8b9c0d2"), "name": "Alice"} + ) + + # Mock update operation mock_update_result = MagicMock() mock_update_result.matched_count = 1 - mock_db.expenses.update_one = AsyncMock(return_value=mock_update_result) - - with patch.object(expense_service, '_expense_doc_to_response') as mock_response: - mock_response.return_value = {"id": "test_id", "description": "Updated Dinner"} - + mock_db.expenses.update_one = AsyncMock( + return_value=mock_update_result) + + with patch.object(expense_service, "_expense_doc_to_response") as mock_response: + mock_response.return_value = { + "id": "test_id", + "description": "Updated Dinner", + } + result = await expense_service.update_expense( "65f1a2b3c4d5e6f7a8b9c0d0", - "65f1a2b3c4d5e6f7a8b9c0d1", + "65f1a2b3c4d5e6f7a8b9c0d1", update_request, - "user_a" + "user_a", ) - + assert result is not None mock_db.expenses.update_one.assert_called_once() + @pytest.mark.asyncio async def test_update_expense_unauthorized(expense_service): """Test expense update by non-creator""" from app.expenses.schemas import ExpenseUpdateRequest - + update_request = ExpenseUpdateRequest(description="Unauthorized Update") - - with patch('app.expenses.service.mongodb') as mock_mongodb: + + with patch("app.expenses.service.mongodb") as mock_mongodb: mock_db = MagicMock() mock_mongodb.database = mock_db - + # Mock finding no expense (user not creator) mock_db.expenses.find_one = AsyncMock(return_value=None) - - with pytest.raises(ValueError, match="Expense not found or not authorized to edit"): + + with pytest.raises( + ValueError, match="Expense not found or not authorized to edit" + ): await expense_service.update_expense( - "group_id", + "group_id", "65f1a2b3c4d5e6f7a8b9c0d1", - update_request, - "unauthorized_user" + update_request, + "unauthorized_user", ) + def test_expense_split_validation(): """Test expense split validation with proper assertions""" # Valid split - should not raise exception splits = [ ExpenseSplit(userId="user_a", amount=50.0), - ExpenseSplit(userId="user_b", amount=50.0) + ExpenseSplit(userId="user_b", amount=50.0), ] - + expense_request = ExpenseCreateRequest( - description="Test expense", - amount=100.0, - splits=splits + description="Test expense", amount=100.0, splits=splits ) - + # Verify the expense was created successfully assert expense_request.amount == 100.0 assert len(expense_request.splits) == 2 assert sum(split.amount for split in expense_request.splits) == 100.0 - + # Invalid split - should raise validation error - with pytest.raises(ValueError, match="Split amounts must sum to total expense amount"): + with pytest.raises( + ValueError, match="Split amounts must sum to total expense amount" + ): invalid_splits = [ ExpenseSplit(userId="user_a", amount=40.0), - ExpenseSplit(userId="user_b", amount=50.0) # Total 90, but expense is 100 + # Total 90, but expense is 100 + ExpenseSplit(userId="user_b", amount=50.0), ] - + ExpenseCreateRequest( - description="Test expense", - amount=100.0, - splits=invalid_splits + description="Test expense", amount=100.0, splits=invalid_splits ) + def test_split_types(): """Test different split types with proper validation""" # Equal split equal_splits = [ ExpenseSplit(userId="user_a", amount=33.33, type=SplitType.EQUAL), ExpenseSplit(userId="user_b", amount=33.33, type=SplitType.EQUAL), - ExpenseSplit(userId="user_c", amount=33.34, type=SplitType.EQUAL) + ExpenseSplit(userId="user_c", amount=33.34, type=SplitType.EQUAL), ] - + expense = ExpenseCreateRequest( description="Equal split expense", amount=100.0, splits=equal_splits, - splitType=SplitType.EQUAL + splitType=SplitType.EQUAL, ) - + assert expense.splitType == SplitType.EQUAL assert len(expense.splits) == 3 # Verify total with floating point tolerance total = sum(split.amount for split in expense.splits) assert abs(total - 100.0) < 0.01 - + # Unequal split unequal_splits = [ ExpenseSplit(userId="user_a", amount=60.0, type=SplitType.UNEQUAL), - ExpenseSplit(userId="user_b", amount=40.0, type=SplitType.UNEQUAL) + ExpenseSplit(userId="user_b", amount=40.0, type=SplitType.UNEQUAL), ] - + expense = ExpenseCreateRequest( - description="Unequal split expense", + description="Unequal split expense", amount=100.0, splits=unequal_splits, - splitType=SplitType.UNEQUAL + splitType=SplitType.UNEQUAL, ) - + assert expense.splitType == SplitType.UNEQUAL assert expense.splits[0].amount == 60.0 assert expense.splits[1].amount == 40.0 + @pytest.mark.asyncio async def test_get_expense_by_id_success(expense_service, mock_expense_data): """Test successful expense retrieval""" - with patch('app.expenses.service.mongodb') as mock_mongodb: + with patch("app.expenses.service.mongodb") as mock_mongodb: mock_db = MagicMock() mock_mongodb.database = mock_db - + # Mock group membership check - mock_db.groups.find_one = AsyncMock(return_value={"_id": ObjectId("65f1a2b3c4d5e6f7a8b9c0d0")}) - + mock_db.groups.find_one = AsyncMock( + return_value={"_id": ObjectId("65f1a2b3c4d5e6f7a8b9c0d0")} + ) + # Mock expense lookup mock_db.expenses.find_one = AsyncMock(return_value=mock_expense_data) - + # Mock settlements lookup mock_cursor = AsyncMock() mock_cursor.to_list.return_value = [] mock_db.settlements.find.return_value = mock_cursor - - with patch.object(expense_service, '_expense_doc_to_response') as mock_response: - mock_response.return_value = {"id": "expense_id", "description": "Test Dinner"} - - result = await expense_service.get_expense_by_id("65f1a2b3c4d5e6f7a8b9c0d0", "65f1a2b3c4d5e6f7a8b9c0d1", "user_a") - + + with patch.object(expense_service, "_expense_doc_to_response") as mock_response: + mock_response.return_value = { + "id": "expense_id", + "description": "Test Dinner", + } + + result = await expense_service.get_expense_by_id( + "65f1a2b3c4d5e6f7a8b9c0d0", "65f1a2b3c4d5e6f7a8b9c0d1", "user_a" + ) + assert result is not None mock_db.groups.find_one.assert_called_once() mock_db.expenses.find_one.assert_called_once() + @pytest.mark.asyncio async def test_get_expense_by_id_not_found(expense_service): """Test expense retrieval when expense doesn't exist""" - with patch('app.expenses.service.mongodb') as mock_mongodb: + with patch("app.expenses.service.mongodb") as mock_mongodb: mock_db = MagicMock() mock_mongodb.database = mock_db - + # Mock group membership check - mock_db.groups.find_one = AsyncMock(return_value={"_id": ObjectId("65f1a2b3c4d5e6f7a8b9c0d0")}) - + mock_db.groups.find_one = AsyncMock( + return_value={"_id": ObjectId("65f1a2b3c4d5e6f7a8b9c0d0")} + ) + # Mock expense not found mock_db.expenses.find_one = AsyncMock(return_value=None) - + with pytest.raises(ValueError, match="Expense not found"): - await expense_service.get_expense_by_id("65f1a2b3c4d5e6f7a8b9c0d0", "65f1a2b3c4d5e6f7a8b9c0d1", "user_a") + await expense_service.get_expense_by_id( + "65f1a2b3c4d5e6f7a8b9c0d0", "65f1a2b3c4d5e6f7a8b9c0d1", "user_a" + ) + @pytest.mark.asyncio -async def test_list_group_expenses_success(expense_service, mock_group_data, mock_expense_data): +async def test_list_group_expenses_success( + expense_service, mock_group_data, mock_expense_data +): """Test successful listing of group expenses""" - with patch('app.expenses.service.mongodb') as mock_mongodb: + with patch("app.expenses.service.mongodb") as mock_mongodb: mock_db = MagicMock() mock_mongodb.database = mock_db @@ -441,18 +498,29 @@ async def test_list_group_expenses_success(expense_service, mock_group_data, moc # Mock expense lookup mock_expense_cursor = AsyncMock() mock_expense_cursor.to_list.return_value = [mock_expense_data] - mock_db.expenses.find.return_value.sort.return_value.skip.return_value.limit.return_value = mock_expense_cursor + mock_db.expenses.find.return_value.sort.return_value.skip.return_value.limit.return_value = ( + mock_expense_cursor + ) mock_db.expenses.count_documents = AsyncMock(return_value=1) # Mock aggregation for summary mock_aggregate_cursor = AsyncMock() - mock_aggregate_cursor.to_list.return_value = [{"totalAmount": 100.0, "expenseCount": 1, "avgExpense": 100.0}] + mock_aggregate_cursor.to_list.return_value = [ + {"totalAmount": 100.0, "expenseCount": 1, "avgExpense": 100.0} + ] mock_db.expenses.aggregate.return_value = mock_aggregate_cursor - with patch.object(expense_service, '_expense_doc_to_response', new_callable=AsyncMock) as mock_response: - mock_response.return_value = {"id": "expense_id", "description": "Test Dinner"} + with patch.object( + expense_service, "_expense_doc_to_response", new_callable=AsyncMock + ) as mock_response: + mock_response.return_value = { + "id": "expense_id", + "description": "Test Dinner", + } - result = await expense_service.list_group_expenses("65f1a2b3c4d5e6f7a8b9c0d0", "user_a") + result = await expense_service.list_group_expenses( + "65f1a2b3c4d5e6f7a8b9c0d0", "user_a" + ) assert result is not None assert "expenses" in result @@ -466,78 +534,107 @@ async def test_list_group_expenses_success(expense_service, mock_group_data, moc mock_db.expenses.count_documents.assert_called_once() mock_db.expenses.aggregate.assert_called_once() + @pytest.mark.asyncio async def test_list_group_expenses_empty(expense_service, mock_group_data): """Test listing group expenses when there are none""" - with patch('app.expenses.service.mongodb') as mock_mongodb: + with patch("app.expenses.service.mongodb") as mock_mongodb: mock_db = MagicMock() mock_mongodb.database = mock_db mock_db.groups.find_one = AsyncMock(return_value=mock_group_data) mock_expense_cursor = AsyncMock() - mock_expense_cursor.to_list.return_value = [] # No expenses - mock_db.expenses.find.return_value.sort.return_value.skip.return_value.limit.return_value = mock_expense_cursor + mock_expense_cursor.to_list.return_value = [] # No expenses + mock_db.expenses.find.return_value.sort.return_value.skip.return_value.limit.return_value = ( + mock_expense_cursor + ) mock_db.expenses.count_documents = AsyncMock(return_value=0) mock_aggregate_cursor = AsyncMock() - mock_aggregate_cursor.to_list.return_value = [] # No summary + mock_aggregate_cursor.to_list.return_value = [] # No summary mock_db.expenses.aggregate.return_value = mock_aggregate_cursor - result = await expense_service.list_group_expenses("65f1a2b3c4d5e6f7a8b9c0d0", "user_a") + result = await expense_service.list_group_expenses( + "65f1a2b3c4d5e6f7a8b9c0d0", "user_a" + ) assert result is not None assert len(result["expenses"]) == 0 assert result["pagination"]["total"] == 0 assert result["summary"]["totalAmount"] == 0 + @pytest.mark.asyncio -async def test_list_group_expenses_pagination(expense_service, mock_group_data, mock_expense_data): +async def test_list_group_expenses_pagination( + expense_service, mock_group_data, mock_expense_data +): """Test pagination for listing group expenses""" - with patch('app.expenses.service.mongodb') as mock_mongodb: + with patch("app.expenses.service.mongodb") as mock_mongodb: mock_db = MagicMock() mock_mongodb.database = mock_db mock_db.groups.find_one = AsyncMock(return_value=mock_group_data) # Simulate 5 expenses, limit 2, page 2 - expenses_page_2 = [mock_expense_data, mock_expense_data] # Dummy data for page 2 + expenses_page_2 = [ + mock_expense_data, + mock_expense_data, + ] # Dummy data for page 2 mock_expense_cursor = AsyncMock() mock_expense_cursor.to_list.return_value = expenses_page_2 - mock_db.expenses.find.return_value.sort.return_value.skip.return_value.limit.return_value = mock_expense_cursor - mock_db.expenses.count_documents = AsyncMock(return_value=5) # Total 5 expenses + mock_db.expenses.find.return_value.sort.return_value.skip.return_value.limit.return_value = ( + mock_expense_cursor + ) + mock_db.expenses.count_documents = AsyncMock( + return_value=5) # Total 5 expenses mock_aggregate_cursor = AsyncMock() - mock_aggregate_cursor.to_list.return_value = [{"totalAmount": 200.0, "expenseCount": 2, "avgExpense": 100.0}] + mock_aggregate_cursor.to_list.return_value = [ + {"totalAmount": 200.0, "expenseCount": 2, "avgExpense": 100.0} + ] mock_db.expenses.aggregate.return_value = mock_aggregate_cursor - with patch.object(expense_service, '_expense_doc_to_response', new_callable=AsyncMock) as mock_response: + with patch.object( + expense_service, "_expense_doc_to_response", new_callable=AsyncMock + ) as mock_response: # Each call to _expense_doc_to_response will return a unique dict to simulate different expenses - mock_response.side_effect = [{"id": "expense_1", "description": "Dinner 1"}, {"id": "expense_2", "description": "Dinner 2"}] + mock_response.side_effect = [ + {"id": "expense_1", "description": "Dinner 1"}, + {"id": "expense_2", "description": "Dinner 2"}, + ] - result = await expense_service.list_group_expenses("65f1a2b3c4d5e6f7a8b9c0d0", "user_a", page=2, limit=2) + result = await expense_service.list_group_expenses( + "65f1a2b3c4d5e6f7a8b9c0d0", "user_a", page=2, limit=2 + ) assert len(result["expenses"]) == 2 assert result["pagination"]["page"] == 2 assert result["pagination"]["limit"] == 2 assert result["pagination"]["total"] == 5 - assert result["pagination"]["totalPages"] == 3 # (5 + 2 - 1) // 2 + assert result["pagination"]["totalPages"] == 3 # (5 + 2 - 1) // 2 assert result["pagination"]["hasNext"] is True assert result["pagination"]["hasPrev"] is True # Check skip value: (page - 1) * limit = (2 - 1) * 2 = 2 - mock_db.expenses.find.return_value.sort.return_value.skip.assert_called_with(2) - mock_db.expenses.find.return_value.sort.return_value.skip.return_value.limit.assert_called_with(2) + mock_db.expenses.find.return_value.sort.return_value.skip.assert_called_with( + 2 + ) + mock_db.expenses.find.return_value.sort.return_value.skip.return_value.limit.assert_called_with( + 2 + ) @pytest.mark.asyncio -async def test_list_group_expenses_filters(expense_service, mock_group_data, mock_expense_data): +async def test_list_group_expenses_filters( + expense_service, mock_group_data, mock_expense_data +): """Test filters (date, tags) for listing group expenses""" from_date = datetime(2023, 1, 1, tzinfo=timezone.utc) to_date = datetime(2023, 1, 31, tzinfo=timezone.utc) tags = ["food", "urgent"] - with patch('app.expenses.service.mongodb') as mock_mongodb: + with patch("app.expenses.service.mongodb") as mock_mongodb: mock_db = MagicMock() mock_mongodb.database = mock_db @@ -545,19 +642,31 @@ async def test_list_group_expenses_filters(expense_service, mock_group_data, moc mock_expense_cursor = AsyncMock() mock_expense_cursor.to_list.return_value = [mock_expense_data] - mock_db.expenses.find.return_value.sort.return_value.skip.return_value.limit.return_value = mock_expense_cursor + mock_db.expenses.find.return_value.sort.return_value.skip.return_value.limit.return_value = ( + mock_expense_cursor + ) mock_db.expenses.count_documents = AsyncMock(return_value=1) mock_aggregate_cursor = AsyncMock() - mock_aggregate_cursor.to_list.return_value = [{"totalAmount": 100.0, "expenseCount": 1, "avgExpense": 100.0}] + mock_aggregate_cursor.to_list.return_value = [ + {"totalAmount": 100.0, "expenseCount": 1, "avgExpense": 100.0} + ] mock_db.expenses.aggregate.return_value = mock_aggregate_cursor - with patch.object(expense_service, '_expense_doc_to_response', new_callable=AsyncMock) as mock_response: - mock_response.return_value = {"id": "expense_id", "description": "Filtered Dinner"} + with patch.object( + expense_service, "_expense_doc_to_response", new_callable=AsyncMock + ) as mock_response: + mock_response.return_value = { + "id": "expense_id", + "description": "Filtered Dinner", + } await expense_service.list_group_expenses( - "65f1a2b3c4d5e6f7a8b9c0d0", "user_a", - from_date=from_date, to_date=to_date, tags=tags + "65f1a2b3c4d5e6f7a8b9c0d0", + "user_a", + from_date=from_date, + to_date=to_date, + tags=tags, ) # Check if find query was called with correct filters @@ -583,13 +692,17 @@ async def test_list_group_expenses_filters(expense_service, mock_group_data, moc async def test_list_group_expenses_group_not_found(expense_service): """Test listing expenses when group is not found or user not member""" valid_but_non_existent_group_id = str(ObjectId()) - with patch('app.expenses.service.mongodb') as mock_mongodb: + with patch("app.expenses.service.mongodb") as mock_mongodb: mock_db = MagicMock() mock_mongodb.database = mock_db - mock_db.groups.find_one = AsyncMock(return_value=None) # Group not found + mock_db.groups.find_one = AsyncMock( + return_value=None) # Group not found with pytest.raises(ValueError, match="Group not found or user not a member"): - await expense_service.list_group_expenses(valid_but_non_existent_group_id, "user_a") + await expense_service.list_group_expenses( + valid_but_non_existent_group_id, "user_a" + ) + @pytest.mark.asyncio async def test_delete_expense_success(expense_service, mock_expense_data): @@ -598,7 +711,7 @@ async def test_delete_expense_success(expense_service, mock_expense_data): expense_id = str(mock_expense_data["_id"]) user_id = mock_expense_data["createdBy"] - with patch('app.expenses.service.mongodb') as mock_mongodb: + with patch("app.expenses.service.mongodb") as mock_mongodb: mock_db = MagicMock() mock_mongodb.database = mock_db @@ -608,47 +721,59 @@ async def test_delete_expense_success(expense_service, mock_expense_data): # Mock successful deletion of expense mock_delete_expense_result = MagicMock() mock_delete_expense_result.deleted_count = 1 - mock_db.expenses.delete_one = AsyncMock(return_value=mock_delete_expense_result) + mock_db.expenses.delete_one = AsyncMock( + return_value=mock_delete_expense_result) # Mock successful deletion of related settlements mock_delete_settlements_result = MagicMock() - mock_delete_settlements_result.deleted_count = 2 # Assume 2 settlements deleted - mock_db.settlements.delete_many = AsyncMock(return_value=mock_delete_settlements_result) + mock_delete_settlements_result.deleted_count = 2 # Assume 2 settlements deleted + mock_db.settlements.delete_many = AsyncMock( + return_value=mock_delete_settlements_result + ) result = await expense_service.delete_expense(group_id, expense_id, user_id) assert result is True - mock_db.expenses.find_one.assert_called_once_with({ - "_id": ObjectId(expense_id), - "groupId": group_id, - "createdBy": user_id - }) - mock_db.settlements.delete_many.assert_called_once_with({"expenseId": expense_id}) - mock_db.expenses.delete_one.assert_called_once_with({"_id": ObjectId(expense_id)}) + mock_db.expenses.find_one.assert_called_once_with( + {"_id": ObjectId(expense_id), "groupId": group_id, + "createdBy": user_id} + ) + mock_db.settlements.delete_many.assert_called_once_with( + {"expenseId": expense_id} + ) + mock_db.expenses.delete_one.assert_called_once_with( + {"_id": ObjectId(expense_id)} + ) + @pytest.mark.asyncio async def test_delete_expense_not_found(expense_service): """Test deleting an expense that is not found or user not authorized""" - group_id = str(ObjectId()) # Valid format - expense_id = str(ObjectId()) # Valid format - user_id = "user_id_test" # This is used for matching createdBy, can be string + group_id = str(ObjectId()) # Valid format + expense_id = str(ObjectId()) # Valid format + user_id = "user_id_test" # This is used for matching createdBy, can be string - with patch('app.expenses.service.mongodb') as mock_mongodb: + with patch("app.expenses.service.mongodb") as mock_mongodb: mock_db = MagicMock() mock_mongodb.database = mock_db # Mock finding no expense mock_db.expenses.find_one = AsyncMock(return_value=None) - mock_db.settlements.delete_many = AsyncMock() # Should not be called if expense not found - mock_db.expenses.delete_one = AsyncMock() # Should not be called + mock_db.settlements.delete_many = ( + AsyncMock() + ) # Should not be called if expense not found + mock_db.expenses.delete_one = AsyncMock() # Should not be called - with pytest.raises(ValueError, match="Expense not found or not authorized to delete"): + with pytest.raises( + ValueError, match="Expense not found or not authorized to delete" + ): await expense_service.delete_expense(group_id, expense_id, user_id) mock_db.settlements.delete_many.assert_not_called() mock_db.expenses.delete_one.assert_not_called() + @pytest.mark.asyncio async def test_delete_expense_failed_deletion(expense_service, mock_expense_data): """Test scenario where expense deletion from DB fails""" @@ -656,31 +781,34 @@ async def test_delete_expense_failed_deletion(expense_service, mock_expense_data expense_id = str(mock_expense_data["_id"]) user_id = mock_expense_data["createdBy"] - with patch('app.expenses.service.mongodb') as mock_mongodb: + with patch("app.expenses.service.mongodb") as mock_mongodb: mock_db = MagicMock() mock_mongodb.database = mock_db mock_db.expenses.find_one = AsyncMock(return_value=mock_expense_data) mock_delete_expense_result = MagicMock() - mock_delete_expense_result.deleted_count = 0 # Simulate DB deletion failure - mock_db.expenses.delete_one = AsyncMock(return_value=mock_delete_expense_result) + mock_delete_expense_result.deleted_count = 0 # Simulate DB deletion failure + mock_db.expenses.delete_one = AsyncMock( + return_value=mock_delete_expense_result) mock_db.settlements.delete_many = AsyncMock() result = await expense_service.delete_expense(group_id, expense_id, user_id) - assert result is False # Deletion failed - mock_db.settlements.delete_many.assert_called_once() # Settlements should still be attempted to be deleted + assert result is False # Deletion failed + # Settlements should still be attempted to be deleted + mock_db.settlements.delete_many.assert_called_once() mock_db.expenses.delete_one.assert_called_once() + @pytest.mark.asyncio async def test_create_manual_settlement_success(expense_service, mock_group_data): """Test successful creation of a manual settlement""" from app.expenses.schemas import SettlementCreateRequest group_id = str(mock_group_data["_id"]) - user_id = "user_a" # User creating the settlement + user_id = "user_a" # User creating the settlement payer_id_obj = ObjectId() payee_id_obj = ObjectId() payer_id_str = str(payer_id_obj) @@ -690,13 +818,13 @@ async def test_create_manual_settlement_success(expense_service, mock_group_data payer_id=payer_id_str, payee_id=payee_id_str, amount=50.0, - description="Manual payback" + description="Manual payback", ) mock_user_b_data = {"_id": payer_id_obj, "name": "User B"} mock_user_c_data = {"_id": payee_id_obj, "name": "User C"} - with patch('app.expenses.service.mongodb') as mock_mongodb: + with patch("app.expenses.service.mongodb") as mock_mongodb: mock_db = MagicMock() mock_mongodb.database = mock_db @@ -714,18 +842,23 @@ def sync_mock_user_find_cursor_factory(query, *args, **kwargs): if payee_id_obj in ids_in_query_objs: users_to_return.append(mock_user_c_data) - cursor_mock = AsyncMock() # This is the cursor mock - cursor_mock.to_list = AsyncMock(return_value=users_to_return) # .to_list() is an async method on the cursor - return cursor_mock # The factory returns the configured cursor mock + cursor_mock = AsyncMock() # This is the cursor mock + cursor_mock.to_list = AsyncMock( + return_value=users_to_return + ) # .to_list() is an async method on the cursor + return cursor_mock # The factory returns the configured cursor mock # mock_db.users.find is a MagicMock because .find() is a synchronous method. # Its side_effect (our factory) is called when mock_db.users.find() is invoked. - mock_db.users.find = MagicMock(side_effect=sync_mock_user_find_cursor_factory) + mock_db.users.find = MagicMock( + side_effect=sync_mock_user_find_cursor_factory) # Mock settlement insertion mock_db.settlements.insert_one = AsyncMock() - result = await expense_service.create_manual_settlement(group_id, settlement_request, user_id) + result = await expense_service.create_manual_settlement( + group_id, settlement_request, user_id + ) assert result is not None assert result.groupId == group_id @@ -733,42 +866,47 @@ def sync_mock_user_find_cursor_factory(query, *args, **kwargs): assert result.payeeId == payee_id_str assert result.amount == 50.0 assert result.description == "Manual payback" - assert result.status == "completed" # Manual settlements are marked completed + assert result.status == "completed" # Manual settlements are marked completed assert result.payerName == "User B" assert result.payeeName == "User C" - mock_db.groups.find_one.assert_called_once_with({ - "_id": ObjectId(group_id), - "members.userId": user_id - }) + mock_db.groups.find_one.assert_called_once_with( + {"_id": ObjectId(group_id), "members.userId": user_id} + ) mock_db.users.find.assert_called_once() mock_db.settlements.insert_one.assert_called_once() inserted_doc = mock_db.settlements.insert_one.call_args[0][0] - assert inserted_doc["expenseId"] is None # Manual settlements have no expenseId + # Manual settlements have no expenseId + assert inserted_doc["expenseId"] is None + @pytest.mark.asyncio async def test_create_manual_settlement_group_not_found(expense_service): """Test creating manual settlement when group is not found or user not member""" from app.expenses.schemas import SettlementCreateRequest - group_id = str(ObjectId()) # Valid format + group_id = str(ObjectId()) # Valid format user_id = "user_a" settlement_request = SettlementCreateRequest( - payer_id=str(ObjectId()), # Valid format - payee_id=str(ObjectId()), # Valid format - amount=50.0 + payer_id=str(ObjectId()), # Valid format + payee_id=str(ObjectId()), # Valid format + amount=50.0, ) - with patch('app.expenses.service.mongodb') as mock_mongodb: + with patch("app.expenses.service.mongodb") as mock_mongodb: mock_db = MagicMock() mock_mongodb.database = mock_db - mock_db.groups.find_one = AsyncMock(return_value=None) # Group not found + mock_db.groups.find_one = AsyncMock( + return_value=None) # Group not found with pytest.raises(ValueError, match="Group not found or user not a member"): - await expense_service.create_manual_settlement(group_id, settlement_request, user_id) + await expense_service.create_manual_settlement( + group_id, settlement_request, user_id + ) mock_db.settlements.insert_one.assert_not_called() + @pytest.mark.asyncio async def test_get_group_settlements_success(expense_service, mock_group_data): """Test successful listing of group settlements""" @@ -776,12 +914,19 @@ async def test_get_group_settlements_success(expense_service, mock_group_data): user_id = "user_a" mock_settlement_doc = { - "_id": ObjectId(), "groupId": group_id, "payerId": "user_b", "payeeId": "user_c", - "amount": 50.0, "status": "pending", "description": "A settlement", - "createdAt": datetime.now(timezone.utc), "payerName": "User B", "payeeName": "User C" + "_id": ObjectId(), + "groupId": group_id, + "payerId": "user_b", + "payeeId": "user_c", + "amount": 50.0, + "status": "pending", + "description": "A settlement", + "createdAt": datetime.now(timezone.utc), + "payerName": "User B", + "payeeName": "User C", } - with patch('app.expenses.service.mongodb') as mock_mongodb: + with patch("app.expenses.service.mongodb") as mock_mongodb: mock_db = MagicMock() mock_mongodb.database = mock_db @@ -789,7 +934,9 @@ async def test_get_group_settlements_success(expense_service, mock_group_data): mock_settlements_cursor = AsyncMock() mock_settlements_cursor.to_list.return_value = [mock_settlement_doc] - mock_db.settlements.find.return_value.sort.return_value.skip.return_value.limit.return_value = mock_settlements_cursor + mock_db.settlements.find.return_value.sort.return_value.skip.return_value.limit.return_value = ( + mock_settlements_cursor + ) mock_db.settlements.count_documents = AsyncMock(return_value=1) result = await expense_service.get_group_settlements(group_id, user_id) @@ -807,13 +954,20 @@ async def test_get_group_settlements_success(expense_service, mock_group_data): mock_db.settlements.find.assert_called_once() mock_db.settlements.count_documents.assert_called_once() # Check default sort, skip, limit - mock_db.settlements.find.return_value.sort.assert_called_with("createdAt", -1) - mock_db.settlements.find.return_value.sort.return_value.skip.assert_called_with(0) # (1-1)*50 - mock_db.settlements.find.return_value.sort.return_value.skip.return_value.limit.assert_called_with(50) + mock_db.settlements.find.return_value.sort.assert_called_with( + "createdAt", -1) + mock_db.settlements.find.return_value.sort.return_value.skip.assert_called_with( + 0 + ) # (1-1)*50 + mock_db.settlements.find.return_value.sort.return_value.skip.return_value.limit.assert_called_with( + 50 + ) @pytest.mark.asyncio -async def test_get_group_settlements_with_filters_and_pagination(expense_service, mock_group_data): +async def test_get_group_settlements_with_filters_and_pagination( + expense_service, mock_group_data +): """Test listing group settlements with status filter and pagination""" group_id = str(mock_group_data["_id"]) user_id = "user_a" @@ -822,23 +976,38 @@ async def test_get_group_settlements_with_filters_and_pagination(expense_service limit = 10 mock_settlement_doc = { - "_id": ObjectId(), "groupId": group_id, "payerId": "user_b", "payeeId": "user_c", - "amount": 50.0, "status": "completed", "description": "A settlement", - "createdAt": datetime.now(timezone.utc), "payerName": "User B", "payeeName": "User C" + "_id": ObjectId(), + "groupId": group_id, + "payerId": "user_b", + "payeeId": "user_c", + "amount": 50.0, + "status": "completed", + "description": "A settlement", + "createdAt": datetime.now(timezone.utc), + "payerName": "User B", + "payeeName": "User C", } - with patch('app.expenses.service.mongodb') as mock_mongodb: + with patch("app.expenses.service.mongodb") as mock_mongodb: mock_db = MagicMock() mock_mongodb.database = mock_db mock_db.groups.find_one = AsyncMock(return_value=mock_group_data) mock_settlements_cursor = AsyncMock() - mock_settlements_cursor.to_list.return_value = [mock_settlement_doc] * 5 # Simulate 5 settlements for this page - mock_db.settlements.find.return_value.sort.return_value.skip.return_value.limit.return_value = mock_settlements_cursor - mock_db.settlements.count_documents = AsyncMock(return_value=15) # Total 15 settlements matching filter + mock_settlements_cursor.to_list.return_value = [ + mock_settlement_doc + ] * 5 # Simulate 5 settlements for this page + mock_db.settlements.find.return_value.sort.return_value.skip.return_value.limit.return_value = ( + mock_settlements_cursor + ) + mock_db.settlements.count_documents = AsyncMock( + return_value=15 + ) # Total 15 settlements matching filter - result = await expense_service.get_group_settlements(group_id, user_id, status_filter=status_filter, page=page, limit=limit) + result = await expense_service.get_group_settlements( + group_id, user_id, status_filter=status_filter, page=page, limit=limit + ) assert len(result["settlements"]) == 5 assert result["total"] == 15 @@ -856,19 +1025,25 @@ async def test_get_group_settlements_with_filters_and_pagination(expense_service assert count_call_args["status"] == status_filter # Verify skip and limit - mock_db.settlements.find.return_value.sort.return_value.skip.assert_called_with((page - 1) * limit) - mock_db.settlements.find.return_value.sort.return_value.skip.return_value.limit.assert_called_with(limit) + mock_db.settlements.find.return_value.sort.return_value.skip.assert_called_with( + (page - 1) * limit + ) + mock_db.settlements.find.return_value.sort.return_value.skip.return_value.limit.assert_called_with( + limit + ) + @pytest.mark.asyncio async def test_get_group_settlements_group_not_found(expense_service): """Test listing settlements when group not found or user not member""" - group_id = str(ObjectId()) # Valid format + group_id = str(ObjectId()) # Valid format user_id = "user_a" - with patch('app.expenses.service.mongodb') as mock_mongodb: + with patch("app.expenses.service.mongodb") as mock_mongodb: mock_db = MagicMock() mock_mongodb.database = mock_db - mock_db.groups.find_one = AsyncMock(return_value=None) # Group not found + mock_db.groups.find_one = AsyncMock( + return_value=None) # Group not found with pytest.raises(ValueError, match="Group not found or user not a member"): await expense_service.get_group_settlements(group_id, user_id) @@ -876,6 +1051,7 @@ async def test_get_group_settlements_group_not_found(expense_service): mock_db.settlements.find.assert_not_called() mock_db.settlements.count_documents.assert_not_called() + @pytest.mark.asyncio async def test_get_settlement_by_id_success(expense_service, mock_group_data): """Test successful retrieval of a settlement by ID""" @@ -885,51 +1061,64 @@ async def test_get_settlement_by_id_success(expense_service, mock_group_data): settlement_id_str = str(settlement_id_obj) mock_settlement_doc = { - "_id": settlement_id_obj, "groupId": group_id, "payerId": "user_b", - "payeeId": "user_c", "amount": 75.0, "status": "pending", - "description": "Specific settlement", "createdAt": datetime.now(timezone.utc), - "payerName": "User B", "payeeName": "User C" + "_id": settlement_id_obj, + "groupId": group_id, + "payerId": "user_b", + "payeeId": "user_c", + "amount": 75.0, + "status": "pending", + "description": "Specific settlement", + "createdAt": datetime.now(timezone.utc), + "payerName": "User B", + "payeeName": "User C", } - with patch('app.expenses.service.mongodb') as mock_mongodb: + with patch("app.expenses.service.mongodb") as mock_mongodb: mock_db = MagicMock() mock_mongodb.database = mock_db mock_db.groups.find_one = AsyncMock(return_value=mock_group_data) - mock_db.settlements.find_one = AsyncMock(return_value=mock_settlement_doc) + mock_db.settlements.find_one = AsyncMock( + return_value=mock_settlement_doc) - result = await expense_service.get_settlement_by_id(group_id, settlement_id_str, user_id) + result = await expense_service.get_settlement_by_id( + group_id, settlement_id_str, user_id + ) assert result is not None - assert result.id == settlement_id_str # Changed from _id to id + assert result.id == settlement_id_str # Changed from _id to id assert result.amount == 75.0 assert result.description == "Specific settlement" - mock_db.groups.find_one.assert_called_once_with({ - "_id": ObjectId(group_id), - "members.userId": user_id - }) - mock_db.settlements.find_one.assert_called_once_with({ - "_id": ObjectId(settlement_id_str), - "groupId": group_id - }) + mock_db.groups.find_one.assert_called_once_with( + {"_id": ObjectId(group_id), "members.userId": user_id} + ) + mock_db.settlements.find_one.assert_called_once_with( + {"_id": ObjectId(settlement_id_str), "groupId": group_id} + ) + @pytest.mark.asyncio async def test_get_settlement_by_id_not_found(expense_service, mock_group_data): """Test retrieving a settlement by ID when it's not found""" group_id = str(mock_group_data["_id"]) user_id = "user_a" - settlement_id_str = str(ObjectId()) # Non-existent ID + settlement_id_str = str(ObjectId()) # Non-existent ID - with patch('app.expenses.service.mongodb') as mock_mongodb: + with patch("app.expenses.service.mongodb") as mock_mongodb: mock_db = MagicMock() mock_mongodb.database = mock_db mock_db.groups.find_one = AsyncMock(return_value=mock_group_data) - mock_db.settlements.find_one = AsyncMock(return_value=None) # Settlement not found + mock_db.settlements.find_one = AsyncMock( + return_value=None + ) # Settlement not found with pytest.raises(ValueError, match="Settlement not found"): - await expense_service.get_settlement_by_id(group_id, settlement_id_str, user_id) + await expense_service.get_settlement_by_id( + group_id, settlement_id_str, user_id + ) + @pytest.mark.asyncio async def test_get_settlement_by_id_group_access_denied(expense_service): @@ -938,17 +1127,22 @@ async def test_get_settlement_by_id_group_access_denied(expense_service): user_id = "user_a" settlement_id_str = str(ObjectId()) - with patch('app.expenses.service.mongodb') as mock_mongodb: + with patch("app.expenses.service.mongodb") as mock_mongodb: mock_db = MagicMock() mock_mongodb.database = mock_db - mock_db.groups.find_one = AsyncMock(return_value=None) # User not in group / group doesn't exist + mock_db.groups.find_one = AsyncMock( + return_value=None + ) # User not in group / group doesn't exist with pytest.raises(ValueError, match="Group not found or user not a member"): - await expense_service.get_settlement_by_id(group_id, settlement_id_str, user_id) + await expense_service.get_settlement_by_id( + group_id, settlement_id_str, user_id + ) mock_db.settlements.find_one.assert_not_called() + @pytest.mark.asyncio async def test_update_settlement_status_success(expense_service): """Test successful update of settlement status""" @@ -962,46 +1156,61 @@ async def test_update_settlement_status_success(expense_service): # Original settlement doc (before update) original_settlement_doc = { - "_id": settlement_id_obj, "groupId": group_id, "status": "pending", - "payerId": "p1", "payeeId": "p2", "amount": 10, "payerName": "P1", "payeeName": "P2", - "createdAt": datetime.now(timezone.utc) - timedelta(days=1) + "_id": settlement_id_obj, + "groupId": group_id, + "status": "pending", + "payerId": "p1", + "payeeId": "p2", + "amount": 10, + "payerName": "P1", + "payeeName": "P2", + "createdAt": datetime.now(timezone.utc) - timedelta(days=1), } # Settlement doc after update updated_settlement_doc = original_settlement_doc.copy() updated_settlement_doc["status"] = new_status.value updated_settlement_doc["paidAt"] = paid_at_time - updated_settlement_doc["updatedAt"] = datetime.now(timezone.utc) # Will be set by the method + updated_settlement_doc["updatedAt"] = datetime.now( + timezone.utc + ) # Will be set by the method - with patch('app.expenses.service.mongodb') as mock_mongodb: + with patch("app.expenses.service.mongodb") as mock_mongodb: mock_db = MagicMock() mock_mongodb.database = mock_db mock_update_result = MagicMock() mock_update_result.matched_count = 1 - mock_db.settlements.update_one = AsyncMock(return_value=mock_update_result) + mock_db.settlements.update_one = AsyncMock( + return_value=mock_update_result) # find_one is called to retrieve the updated document - mock_db.settlements.find_one = AsyncMock(return_value=updated_settlement_doc) + mock_db.settlements.find_one = AsyncMock( + return_value=updated_settlement_doc) result = await expense_service.update_settlement_status( group_id, settlement_id_str, new_status, paid_at=paid_at_time ) assert result is not None - assert result.id == settlement_id_str # Changed from _id to id + assert result.id == settlement_id_str # Changed from _id to id assert result.status == new_status.value assert result.paidAt == paid_at_time mock_db.settlements.update_one.assert_called_once() update_call_args = mock_db.settlements.update_one.call_args[0] - assert update_call_args[0] == {"_id": settlement_id_obj, "groupId": group_id} # Filter query + assert update_call_args[0] == { + "_id": settlement_id_obj, + "groupId": group_id, + } # Filter query assert "$set" in update_call_args[1] set_doc = update_call_args[1]["$set"] assert set_doc["status"] == new_status.value assert set_doc["paidAt"] == paid_at_time assert "updatedAt" in set_doc - mock_db.settlements.find_one.assert_called_once_with({"_id": settlement_id_obj}) + mock_db.settlements.find_one.assert_called_once_with( + {"_id": settlement_id_obj}) + @pytest.mark.asyncio async def test_update_settlement_status_not_found(expense_service): @@ -1009,36 +1218,38 @@ async def test_update_settlement_status_not_found(expense_service): from app.expenses.schemas import SettlementStatus group_id = str(ObjectId()) - settlement_id_str = str(ObjectId()) # Non-existent ID + settlement_id_str = str(ObjectId()) # Non-existent ID new_status = SettlementStatus.CANCELLED - with patch('app.expenses.service.mongodb') as mock_mongodb: + with patch("app.expenses.service.mongodb") as mock_mongodb: mock_db = MagicMock() mock_mongodb.database = mock_db mock_update_result = MagicMock() - mock_update_result.matched_count = 0 # Simulate settlement not found - mock_db.settlements.update_one = AsyncMock(return_value=mock_update_result) + mock_update_result.matched_count = 0 # Simulate settlement not found + mock_db.settlements.update_one = AsyncMock( + return_value=mock_update_result) mock_db.settlements.find_one = AsyncMock(return_value=None) - with pytest.raises(ValueError, match="Settlement not found"): await expense_service.update_settlement_status( group_id, settlement_id_str, new_status ) - mock_db.settlements.find_one.assert_not_called() # Should not be called if update fails + # Should not be called if update fails + mock_db.settlements.find_one.assert_not_called() + @pytest.mark.asyncio async def test_delete_settlement_success(expense_service, mock_group_data): """Test successful deletion of a settlement""" group_id = str(mock_group_data["_id"]) - user_id = "user_a" # User performing the deletion + user_id = "user_a" # User performing the deletion settlement_id_obj = ObjectId() settlement_id_str = str(settlement_id_obj) - with patch('app.expenses.service.mongodb') as mock_mongodb: + with patch("app.expenses.service.mongodb") as mock_mongodb: mock_db = MagicMock() mock_mongodb.database = mock_db @@ -1048,41 +1259,47 @@ async def test_delete_settlement_success(expense_service, mock_group_data): # Mock successful deletion mock_delete_result = MagicMock() mock_delete_result.deleted_count = 1 - mock_db.settlements.delete_one = AsyncMock(return_value=mock_delete_result) + mock_db.settlements.delete_one = AsyncMock( + return_value=mock_delete_result) - result = await expense_service.delete_settlement(group_id, settlement_id_str, user_id) + result = await expense_service.delete_settlement( + group_id, settlement_id_str, user_id + ) assert result is True - mock_db.groups.find_one.assert_called_once_with({ - "_id": ObjectId(group_id), - "members.userId": user_id - }) - mock_db.settlements.delete_one.assert_called_once_with({ - "_id": ObjectId(settlement_id_str), - "groupId": group_id - }) + mock_db.groups.find_one.assert_called_once_with( + {"_id": ObjectId(group_id), "members.userId": user_id} + ) + mock_db.settlements.delete_one.assert_called_once_with( + {"_id": ObjectId(settlement_id_str), "groupId": group_id} + ) + @pytest.mark.asyncio async def test_delete_settlement_not_found(expense_service, mock_group_data): """Test deleting a settlement that is not found""" group_id = str(mock_group_data["_id"]) user_id = "user_a" - settlement_id_str = str(ObjectId()) # Non-existent ID + settlement_id_str = str(ObjectId()) # Non-existent ID - with patch('app.expenses.service.mongodb') as mock_mongodb: + with patch("app.expenses.service.mongodb") as mock_mongodb: mock_db = MagicMock() mock_mongodb.database = mock_db mock_db.groups.find_one = AsyncMock(return_value=mock_group_data) mock_delete_result = MagicMock() - mock_delete_result.deleted_count = 0 # Simulate not found - mock_db.settlements.delete_one = AsyncMock(return_value=mock_delete_result) + mock_delete_result.deleted_count = 0 # Simulate not found + mock_db.settlements.delete_one = AsyncMock( + return_value=mock_delete_result) - result = await expense_service.delete_settlement(group_id, settlement_id_str, user_id) + result = await expense_service.delete_settlement( + group_id, settlement_id_str, user_id + ) assert result is False + @pytest.mark.asyncio async def test_delete_settlement_group_access_denied(expense_service): """Test deleting settlement when user not member of the group""" @@ -1090,24 +1307,28 @@ async def test_delete_settlement_group_access_denied(expense_service): user_id = "user_a" settlement_id_str = str(ObjectId()) - with patch('app.expenses.service.mongodb') as mock_mongodb: + with patch("app.expenses.service.mongodb") as mock_mongodb: mock_db = MagicMock() mock_mongodb.database = mock_db - mock_db.groups.find_one = AsyncMock(return_value=None) # User not in group + mock_db.groups.find_one = AsyncMock( + return_value=None) # User not in group with pytest.raises(ValueError, match="Group not found or user not a member"): - await expense_service.delete_settlement(group_id, settlement_id_str, user_id) + await expense_service.delete_settlement( + group_id, settlement_id_str, user_id + ) mock_db.settlements.delete_one.assert_not_called() + @pytest.mark.asyncio async def test_get_user_balance_in_group_success(expense_service, mock_group_data): """Test successful retrieval of a user's balance in a group""" group_id = str(mock_group_data["_id"]) target_user_id_obj = ObjectId() target_user_id_str = str(target_user_id_obj) - current_user_id = "user_a" # User making the request + current_user_id = "user_a" # User making the request mock_target_user_doc = {"_id": target_user_id_obj, "name": "User B Target"} @@ -1116,25 +1337,37 @@ async def test_get_user_balance_in_group_success(expense_service, mock_group_dat # User C paid 50 for User B (User B owes User C 50) # Net for User B: Paid 100, Owed 50. Net Balance = 50 (User B is owed 50 overall) mock_settlements_aggregate = [ - {"_id": None, "totalPaid": 100.0, "totalOwed": 50.0} - ] - mock_pending_settlements_docs = [ # User B is payee, i.e. is owed + {"_id": None, "totalPaid": 100.0, "totalOwed": 50.0}] + mock_pending_settlements_docs = [ # User B is payee, i.e. is owed { - "_id": ObjectId(), "groupId": group_id, "payerId": "user_a", "payeeId": target_user_id_str, - "amount": 100.0, "status": "pending", "description": "Owed to B", - "createdAt": datetime.now(timezone.utc), "payerName": "User A", "payeeName": "User B Target" + "_id": ObjectId(), + "groupId": group_id, + "payerId": "user_a", + "payeeId": target_user_id_str, + "amount": 100.0, + "status": "pending", + "description": "Owed to B", + "createdAt": datetime.now(timezone.utc), + "payerName": "User A", + "payeeName": "User B Target", } ] - mock_recent_expenses_docs = [ # Expense created by B, B also has a split + mock_recent_expenses_docs = [ # Expense created by B, B also has a split { - "_id": ObjectId(), "groupId": group_id, "createdBy": target_user_id_str, - "description": "Lunch by B", "amount": 150.0, - "splits": [{"userId": target_user_id_str, "amount": 75.0}, {"userId": "user_c", "amount": 75.0}], - "createdAt": datetime.now(timezone.utc) + "_id": ObjectId(), + "groupId": group_id, + "createdBy": target_user_id_str, + "description": "Lunch by B", + "amount": 150.0, + "splits": [ + {"userId": target_user_id_str, "amount": 75.0}, + {"userId": "user_c", "amount": 75.0}, + ], + "createdAt": datetime.now(timezone.utc), } ] - with patch('app.expenses.service.mongodb') as mock_mongodb: + with patch("app.expenses.service.mongodb") as mock_mongodb: mock_db = MagicMock() mock_mongodb.database = mock_db @@ -1151,24 +1384,33 @@ async def test_get_user_balance_in_group_success(expense_service, mock_group_dat # Mock pending settlements find mock_pending_cursor = AsyncMock() mock_pending_cursor.to_list.return_value = mock_pending_settlements_docs - mock_db.settlements.find.return_value = mock_pending_cursor # This is the first .find() call + mock_db.settlements.find.return_value = ( + mock_pending_cursor # This is the first .find() call + ) # Mock recent expenses find mock_expenses_cursor = AsyncMock() mock_expenses_cursor.to_list.return_value = mock_recent_expenses_docs # Ensure the second .find() call (for expenses) is correctly patched - mock_db.expenses.find.return_value.sort.return_value.limit.return_value = mock_expenses_cursor - + mock_db.expenses.find.return_value.sort.return_value.limit.return_value = ( + mock_expenses_cursor + ) - result = await expense_service.get_user_balance_in_group(group_id, target_user_id_str, current_user_id) + result = await expense_service.get_user_balance_in_group( + group_id, target_user_id_str, current_user_id + ) assert result is not None assert result["userId"] == target_user_id_str assert result["userName"] == "User B Target" assert result["totalPaid"] == 100.0 assert result["totalOwed"] == 50.0 - assert result["netBalance"] == 50.0 # 100 - 50 - assert result["owesYou"] is True # Net balance is positive, so target_user_id is owed money (by others in general) + assert result["netBalance"] == 50.0 # 100 - 50 + assert ( + result["owesYou"] + is True + # Net balance is positive, so target_user_id is owed money (by others in general) + ) assert len(result["pendingSettlements"]) == 1 assert result["pendingSettlements"][0].amount == 100.0 @@ -1177,40 +1419,49 @@ async def test_get_user_balance_in_group_success(expense_service, mock_group_dat assert result["recentExpenses"][0]["description"] == "Lunch by B" assert result["recentExpenses"][0]["userShare"] == 75.0 - mock_db.groups.find_one.assert_called_once_with({ - "_id": ObjectId(group_id), "members.userId": current_user_id - }) - mock_db.users.find_one.assert_called_once_with({"_id": target_user_id_obj}) + mock_db.groups.find_one.assert_called_once_with( + {"_id": ObjectId(group_id), "members.userId": current_user_id} + ) + mock_db.users.find_one.assert_called_once_with( + {"_id": target_user_id_obj}) mock_db.settlements.aggregate.assert_called_once() # Check the two find calls to settlements and expenses collections settlements_find_call_args = mock_db.settlements.find.call_args[0][0] - assert settlements_find_call_args["payeeId"] == target_user_id_str # For pending settlements + assert ( + settlements_find_call_args["payeeId"] == target_user_id_str + ) # For pending settlements expenses_find_call_args = mock_db.expenses.find.call_args[0][0] - assert "$or" in expenses_find_call_args # For recent expenses + assert "$or" in expenses_find_call_args # For recent expenses @pytest.mark.asyncio async def test_get_user_balance_in_group_access_denied(expense_service): """Test get user balance when current user not in group""" group_id = str(ObjectId()) - target_user_id_str = str(ObjectId()) # Use a valid ObjectId string for target - current_user_id = "user_x" # Not in group + # Use a valid ObjectId string for target + target_user_id_str = str(ObjectId()) + current_user_id = "user_x" # Not in group - with patch('app.expenses.service.mongodb') as mock_mongodb: + with patch("app.expenses.service.mongodb") as mock_mongodb: mock_db = MagicMock() mock_mongodb.database = mock_db - mock_db.groups.find_one = AsyncMock(return_value=None) # Current user not member + mock_db.groups.find_one = AsyncMock( + return_value=None + ) # Current user not member with pytest.raises(ValueError, match="Group not found or user not a member"): - await expense_service.get_user_balance_in_group(group_id, target_user_id_str, current_user_id) + await expense_service.get_user_balance_in_group( + group_id, target_user_id_str, current_user_id + ) mock_db.users.find_one.assert_not_called() mock_db.settlements.aggregate.assert_not_called() mock_db.settlements.find.assert_not_called() mock_db.expenses.find.assert_not_called() + @pytest.mark.asyncio async def test_get_friends_balance_summary_success(expense_service): """Test successful retrieval of friends balance summary""" @@ -1221,7 +1472,8 @@ async def test_get_friends_balance_summary_success(expense_service): friend1_id_str = str(friend1_id_obj) friend2_id_str = str(friend2_id_obj) - group1_id = str(ObjectId()) # Remains as string, used for direct comparison in mock + # Remains as string, used for direct comparison in mock + group1_id = str(ObjectId()) group2_id = str(ObjectId()) mock_user_main_doc = {"_id": user_id_obj, "name": "Main User"} @@ -1230,13 +1482,19 @@ async def test_get_friends_balance_summary_success(expense_service): mock_groups_data = [ { - "_id": ObjectId(group1_id), "name": "Group Alpha", - "members": [{"userId": user_id_str}, {"userId": friend1_id_str}] + "_id": ObjectId(group1_id), + "name": "Group Alpha", + "members": [{"userId": user_id_str}, {"userId": friend1_id_str}], }, { - "_id": ObjectId(group2_id), "name": "Group Beta", - "members": [{"userId": user_id_str}, {"userId": friend1_id_str}, {"userId": friend2_id_str}] - } + "_id": ObjectId(group2_id), + "name": "Group Beta", + "members": [ + {"userId": user_id_str}, + {"userId": friend1_id_str}, + {"userId": friend2_id_str}, + ], + }, ] # Mocking settlement aggregations for each friend in each group @@ -1269,18 +1527,26 @@ def sync_mock_settlements_aggregate_cursor_factory(pipeline, *args, **kwargs): mock_agg_cursor = AsyncMock() if group_id_pipeline == group1_id and pipeline_friend_id == friend1_id_str: # Main owes Friend1 50 in Group Alpha - mock_agg_cursor.to_list.return_value = [{"_id": None, "userOwes": 50.0, "friendOwes": 0.0}] + mock_agg_cursor.to_list.return_value = [ + {"_id": None, "userOwes": 50.0, "friendOwes": 0.0} + ] elif group_id_pipeline == group2_id and pipeline_friend_id == friend1_id_str: # Friend1 owes Main 30 in Group Beta - mock_agg_cursor.to_list.return_value = [{"_id": None, "userOwes": 0.0, "friendOwes": 30.0}] + mock_agg_cursor.to_list.return_value = [ + {"_id": None, "userOwes": 0.0, "friendOwes": 30.0} + ] elif group_id_pipeline == group2_id and pipeline_friend_id == friend2_id_str: # Main owes Friend2 70 in Group Beta - mock_agg_cursor.to_list.return_value = [{"_id": None, "userOwes": 70.0, "friendOwes": 0.0}] + mock_agg_cursor.to_list.return_value = [ + {"_id": None, "userOwes": 70.0, "friendOwes": 0.0} + ] else: - mock_agg_cursor.to_list.return_value = [{"_id": None, "userOwes": 0.0, "friendOwes": 0.0}] # Default empty + mock_agg_cursor.to_list.return_value = [ + {"_id": None, "userOwes": 0.0, "friendOwes": 0.0} + ] # Default empty return mock_agg_cursor - with patch('app.expenses.service.mongodb') as mock_mongodb: + with patch("app.expenses.service.mongodb") as mock_mongodb: mock_db = MagicMock() mock_mongodb.database = mock_db @@ -1292,20 +1558,27 @@ def sync_mock_settlements_aggregate_cursor_factory(pipeline, *args, **kwargs): # Mock user name lookups # This side effect is for the users.find() call. It returns a cursor mock. def mock_user_find_cursor_side_effect(query, *args, **kwargs): - ids_in_query = query["_id"]["$in"] # These are already ObjectIds from the service + ids_in_query = query["_id"][ + "$in" + ] # These are already ObjectIds from the service users_to_return = [] - if friend1_id_obj in ids_in_query: users_to_return.append(mock_friend1_doc) - if friend2_id_obj in ids_in_query: users_to_return.append(mock_friend2_doc) + if friend1_id_obj in ids_in_query: + users_to_return.append(mock_friend1_doc) + if friend2_id_obj in ids_in_query: + users_to_return.append(mock_friend2_doc) cursor_mock = AsyncMock() cursor_mock.to_list = AsyncMock(return_value=users_to_return) return cursor_mock - mock_db.users.find = MagicMock(side_effect=mock_user_find_cursor_side_effect) + + mock_db.users.find = MagicMock( + side_effect=mock_user_find_cursor_side_effect) # Mock settlement aggregation logic # .aggregate() is sync, returns an async cursor. - mock_db.settlements.aggregate = MagicMock(side_effect=sync_mock_settlements_aggregate_cursor_factory) - + mock_db.settlements.aggregate = MagicMock( + side_effect=sync_mock_settlements_aggregate_cursor_factory + ) result = await expense_service.get_friends_balance_summary(user_id_str) @@ -1316,10 +1589,14 @@ def mock_user_find_cursor_side_effect(query, *args, **kwargs): friends_balance = result["friendsBalance"] summary = result["summary"] - assert len(friends_balance) == 2 # Friend1 and Friend2 + assert len(friends_balance) == 2 # Friend1 and Friend2 - friend1_summary = next(f for f in friends_balance if f["userId"] == friend1_id_str) - friend2_summary = next(f for f in friends_balance if f["userId"] == friend2_id_str) + friend1_summary = next( + f for f in friends_balance if f["userId"] == friend1_id_str + ) + friend2_summary = next( + f for f in friends_balance if f["userId"] == friend2_id_str + ) # Friend1: owes Main 30 (Group Beta), Main owes Friend1 50 (Group Alpha) # Net for Friend1: Friend1 owes Main (30 - 50) = -20. So Main is owed 20 by Friend1. @@ -1329,7 +1606,9 @@ def mock_user_find_cursor_side_effect(query, *args, **kwargs): # Group Beta: friendOwes (Friend1 to Main) = 30, userOwes (Main to Friend1) = 0. Balance = 30 - 0 = +30 (F1 owes Main 30) # Total for Friend1: Net Balance = -50 (from G1) + 30 (from G2) = -20. So Main User owes Friend1 20. assert friend1_summary["userName"] == "Friend One" - assert abs(friend1_summary["netBalance"] - (-20.0)) < 0.01 # Main owes Friend1 20 + assert ( + abs(friend1_summary["netBalance"] - (-20.0)) < 0.01 + ) # Main owes Friend1 20 assert friend1_summary["owesYou"] is False assert len(friend1_summary["breakdown"]) == 2 @@ -1337,13 +1616,14 @@ def mock_user_find_cursor_side_effect(query, *args, **kwargs): # Group Beta: friendOwes (Friend2 to Main) = 0, userOwes (Main to Friend2) = 70. Balance = 0 - 70 = -70 # Total for Friend2: Net Balance = -70. So Main User owes Friend2 70. assert friend2_summary["userName"] == "Friend Two" - assert abs(friend2_summary["netBalance"] - (-70.0)) < 0.01 # Main owes Friend2 70 + assert ( + abs(friend2_summary["netBalance"] - (-70.0)) < 0.01 + ) # Main owes Friend2 70 assert friend2_summary["owesYou"] is False assert len(friend2_summary["breakdown"]) == 1 assert friend2_summary["breakdown"][0]["groupName"] == "Group Beta" assert abs(friend2_summary["breakdown"][0]["balance"] - (-70.0)) < 0.01 - # Summary: Main owes Friend1 20, Main owes Friend2 70. # totalOwedToYou = 0 # totalYouOwe = 20 (to F1) + 70 (to F2) = 90 @@ -1354,7 +1634,8 @@ def mock_user_find_cursor_side_effect(query, *args, **kwargs): assert summary["activeGroups"] == 2 # Verify mocks - mock_db.groups.find.assert_called_once_with({"members.userId": user_id_str}) + mock_db.groups.find.assert_called_once_with( + {"members.userId": user_id_str}) # settlements.aggregate is called for each friend in each group they share with user_id_str # Friend1 is in 2 groups with user_id_str, Friend2 is in 1 group with user_id_str. Total 3 calls. assert mock_db.settlements.aggregate.call_count == 3 @@ -1365,7 +1646,7 @@ async def test_get_friends_balance_summary_no_friends_or_groups(expense_service) """Test friends balance summary when user has no friends or no shared groups with balances""" user_id = "lonely_user" - with patch('app.expenses.service.mongodb') as mock_mongodb: + with patch("app.expenses.service.mongodb") as mock_mongodb: mock_db = MagicMock() mock_mongodb.database = mock_db @@ -1378,9 +1659,13 @@ async def test_get_friends_balance_summary_no_friends_or_groups(expense_service) # However, if it were called, it should return a proper cursor. mock_user_find_cursor = AsyncMock() mock_user_find_cursor.to_list = AsyncMock(return_value=[]) - mock_db.users.find = MagicMock(return_value=mock_user_find_cursor) # find is sync, returns async cursor + mock_db.users.find = MagicMock( + return_value=mock_user_find_cursor + ) # find is sync, returns async cursor - mock_db.settlements.aggregate = AsyncMock() # Won't be called if no friends/groups + mock_db.settlements.aggregate = ( + AsyncMock() + ) # Won't be called if no friends/groups result = await expense_service.get_friends_balance_summary(user_id) @@ -1395,18 +1680,31 @@ async def test_get_friends_balance_summary_no_friends_or_groups(expense_service) # it would be mock_db.users.find.assert_called_once_with({'_id': {'$in': []}}) # For now, removing the assertion is fine as the main check is the summary. + @pytest.mark.asyncio async def test_get_overall_balance_summary_success(expense_service): """Test successful retrieval of overall balance summary for a user""" user_id = "user_test_overall" group1_id = str(ObjectId()) group2_id = str(ObjectId()) - group3_id = str(ObjectId()) # Group with zero balance for the user + group3_id = str(ObjectId()) # Group with zero balance for the user mock_groups_data = [ - {"_id": ObjectId(group1_id), "name": "Group One", "members": [{"userId": user_id}]}, - {"_id": ObjectId(group2_id), "name": "Group Two", "members": [{"userId": user_id}]}, - {"_id": ObjectId(group3_id), "name": "Group Three", "members": [{"userId": user_id}]} + { + "_id": ObjectId(group1_id), + "name": "Group One", + "members": [{"userId": user_id}], + }, + { + "_id": ObjectId(group2_id), + "name": "Group Two", + "members": [{"userId": user_id}], + }, + { + "_id": ObjectId(group3_id), + "name": "Group Three", + "members": [{"userId": user_id}], + }, ] # Mocking settlement aggregations for the user in each group @@ -1422,16 +1720,25 @@ def mock_aggregate_cursor_side_effect(pipeline, *args, **kwargs): cursor_mock = AsyncMock() if group_id_pipeline == group1_id: - cursor_mock.to_list = AsyncMock(return_value=[{"_id": None, "totalPaid": 100.0, "totalOwed": 20.0}]) + cursor_mock.to_list = AsyncMock( + return_value=[ + {"_id": None, "totalPaid": 100.0, "totalOwed": 20.0}] + ) elif group_id_pipeline == group2_id: - cursor_mock.to_list = AsyncMock(return_value=[{"_id": None, "totalPaid": 50.0, "totalOwed": 150.0}]) - elif group_id_pipeline == group3_id: # Zero balance - cursor_mock.to_list = AsyncMock(return_value=[{"_id": None, "totalPaid": 50.0, "totalOwed": 50.0}]) - else: # Should not happen in this test + cursor_mock.to_list = AsyncMock( + return_value=[ + {"_id": None, "totalPaid": 50.0, "totalOwed": 150.0}] + ) + elif group_id_pipeline == group3_id: # Zero balance + cursor_mock.to_list = AsyncMock( + return_value=[ + {"_id": None, "totalPaid": 50.0, "totalOwed": 50.0}] + ) + else: # Should not happen in this test cursor_mock.to_list = AsyncMock(return_value=[]) return cursor_mock - with patch('app.expenses.service.mongodb') as mock_mongodb: + with patch("app.expenses.service.mongodb") as mock_mongodb: mock_db = MagicMock() mock_mongodb.database = mock_db @@ -1442,7 +1749,9 @@ def mock_aggregate_cursor_side_effect(pipeline, *args, **kwargs): # Mock settlement aggregation # .aggregate() is a sync method returning an async cursor - mock_db.settlements.aggregate = MagicMock(side_effect=mock_aggregate_cursor_side_effect) + mock_db.settlements.aggregate = MagicMock( + side_effect=mock_aggregate_cursor_side_effect + ) result = await expense_service.get_overall_balance_summary(user_id) @@ -1460,8 +1769,12 @@ def mock_aggregate_cursor_side_effect(pipeline, *args, **kwargs): # Group three had zero balance, so it should not be in groupsSummary assert len(result["groupsSummary"]) == 2 - group1_summary = next(g for g in result["groupsSummary"] if g["group_id"] == group1_id) - group2_summary = next(g for g in result["groupsSummary"] if g["group_id"] == group2_id) + group1_summary = next( + g for g in result["groupsSummary"] if g["group_id"] == group1_id + ) + group2_summary = next( + g for g in result["groupsSummary"] if g["group_id"] == group2_id + ) assert group1_summary["group_name"] == "Group One" assert abs(group1_summary["yourBalanceInGroup"] - 80.0) < 0.01 @@ -1470,23 +1783,25 @@ def mock_aggregate_cursor_side_effect(pipeline, *args, **kwargs): assert abs(group2_summary["yourBalanceInGroup"] - (-100.0)) < 0.01 # Verify mocks - mock_db.groups.find.assert_called_once_with({"members.userId": user_id}) - assert mock_db.settlements.aggregate.call_count == 3 # Called for each group + mock_db.groups.find.assert_called_once_with( + {"members.userId": user_id}) + assert mock_db.settlements.aggregate.call_count == 3 # Called for each group + @pytest.mark.asyncio async def test_get_overall_balance_summary_no_groups(expense_service): """Test overall balance summary when user is in no groups""" user_id = "user_no_groups" - with patch('app.expenses.service.mongodb') as mock_mongodb: + with patch("app.expenses.service.mongodb") as mock_mongodb: mock_db = MagicMock() mock_mongodb.database = mock_db mock_groups_cursor = AsyncMock() - mock_groups_cursor.to_list.return_value = [] # No groups + mock_groups_cursor.to_list.return_value = [] # No groups mock_db.groups.find.return_value = mock_groups_cursor - mock_db.settlements.aggregate = AsyncMock() # Should not be called + mock_db.settlements.aggregate = AsyncMock() # Should not be called result = await expense_service.get_overall_balance_summary(user_id) @@ -1496,13 +1811,15 @@ async def test_get_overall_balance_summary_no_groups(expense_service): assert len(result["groupsSummary"]) == 0 mock_db.settlements.aggregate.assert_not_called() + @pytest.mark.asyncio async def test_get_group_analytics_success(expense_service, mock_group_data): """Test successful retrieval of group analytics""" - group_id_str = str(mock_group_data["_id"]) # Changed variable name for clarity - user_a_obj = ObjectId() # This is the user making the request and also a member + group_id_str = str(mock_group_data["_id"] + ) # Changed variable name for clarity + user_a_obj = ObjectId() # This is the user making the request and also a member user_b_obj = ObjectId() - user_c_obj = ObjectId() # In group but no expenses + user_c_obj = ObjectId() # In group but no expenses user_a_str = str(user_a_obj) user_b_str = str(user_b_obj) user_c_str = str(user_c_obj) @@ -1520,17 +1837,31 @@ async def test_get_group_analytics_success(expense_service, mock_group_data): expense2_date = datetime(year, month, 15, tzinfo=timezone.utc) mock_expenses_in_period = [ { - "_id": ObjectId(), "groupId": group_id_str, "createdBy": user_a_str, - "description": "Groceries", "amount": 70.0, "tags": ["food", "household"], - "splits": [{"userId": user_a_str, "amount": 35.0}, {"userId": user_b_str, "amount": 35.0}], - "createdAt": expense1_date + "_id": ObjectId(), + "groupId": group_id_str, + "createdBy": user_a_str, + "description": "Groceries", + "amount": 70.0, + "tags": ["food", "household"], + "splits": [ + {"userId": user_a_str, "amount": 35.0}, + {"userId": user_b_str, "amount": 35.0}, + ], + "createdAt": expense1_date, }, { - "_id": ObjectId(), "groupId": group_id_str, "createdBy": user_b_str, - "description": "Movies", "amount": 30.0, "tags": ["entertainment", "food"], - "splits": [{"userId": user_a_str, "amount": 15.0}, {"userId": user_b_str, "amount": 15.0}], - "createdAt": expense2_date - } + "_id": ObjectId(), + "groupId": group_id_str, + "createdBy": user_b_str, + "description": "Movies", + "amount": 30.0, + "tags": ["entertainment", "food"], + "splits": [ + {"userId": user_a_str, "amount": 15.0}, + {"userId": user_b_str, "amount": 15.0}, + ], + "createdAt": expense2_date, + }, ] # Mock user data for member contributions @@ -1539,10 +1870,13 @@ async def test_get_group_analytics_success(expense_service, mock_group_data): mock_user_c_doc_db = {"_id": user_c_obj, "name": "User C"} async def mock_users_find_one_side_effect(query, *args, **kwargs): - user_id_query_obj = query["_id"] # This should be an ObjectId - if user_id_query_obj == user_a_obj: return mock_user_a_doc_db - if user_id_query_obj == user_b_obj: return mock_user_b_doc_db - if user_id_query_obj == user_c_obj: return mock_user_c_doc_db + user_id_query_obj = query["_id"] # This should be an ObjectId + if user_id_query_obj == user_a_obj: + return mock_user_a_doc_db + if user_id_query_obj == user_b_obj: + return mock_user_b_doc_db + if user_id_query_obj == user_c_obj: + return mock_user_c_doc_db return None # Adjust mock_group_data to ensure its members list matches what the service method expects @@ -1552,34 +1886,39 @@ async def mock_users_find_one_side_effect(query, *args, **kwargs): # Let's redefine mock_group_data for this specific test to ensure consistency. current_test_mock_group_data = { - "_id": ObjectId(group_id_str), # Use the same ObjectId as in the service call + # Use the same ObjectId as in the service call + "_id": ObjectId(group_id_str), "name": "Test Group Analytics", "members": [ {"userId": user_a_str, "role": "admin"}, {"userId": user_b_str, "role": "member"}, - {"userId": user_c_str, "role": "member"} - ] + {"userId": user_c_str, "role": "member"}, + ], } - - with patch('app.expenses.service.mongodb') as mock_mongodb: + with patch("app.expenses.service.mongodb") as mock_mongodb: mock_db = MagicMock() mock_mongodb.database = mock_db # Mock group membership check - mock_db.groups.find_one = AsyncMock(return_value=current_test_mock_group_data) # Use the adjusted mock + mock_db.groups.find_one = AsyncMock( + return_value=current_test_mock_group_data + ) # Use the adjusted mock # Mock expenses find for the period mock_expenses_cursor = AsyncMock() mock_expenses_cursor.to_list.return_value = mock_expenses_in_period mock_db.expenses.find.return_value = mock_expenses_cursor # Mock user lookups for member names - mock_db.users.find_one = AsyncMock(side_effect=mock_users_find_one_side_effect) + mock_db.users.find_one = AsyncMock( + side_effect=mock_users_find_one_side_effect) - result = await expense_service.get_group_analytics(group_id_str, user_a_str, period="month", year=year, month=month) + result = await expense_service.get_group_analytics( + group_id_str, user_a_str, period="month", year=year, month=month + ) assert result is not None assert result["period"] == f"{year}-{month:02d}" - assert abs(result["totalExpenses"] - 100.0) < 0.01 # 70 + 30 + assert abs(result["totalExpenses"] - 100.0) < 0.01 # 70 + 30 assert result["expenseCount"] == 2 assert abs(result["avgExpenseAmount"] - 50.0) < 0.01 @@ -1589,20 +1928,33 @@ async def mock_users_find_one_side_effect(query, *args, **kwargs): # household: 70 # entertainment: 30 food_cat = next(c for c in top_categories if c["tag"] == "food") - household_cat = next(c for c in top_categories if c["tag"] == "household") - entertainment_cat = next(c for c in top_categories if c["tag"] == "entertainment") + household_cat = next( + c for c in top_categories if c["tag"] == "household") + entertainment_cat = next( + c for c in top_categories if c["tag"] == "entertainment" + ) - assert abs(food_cat["amount"] - 100.0) < 0.01 and food_cat["count"] == 2 - assert abs(household_cat["amount"] - 70.0) < 0.01 and household_cat["count"] == 1 - assert abs(entertainment_cat["amount"] - 30.0) < 0.01 and entertainment_cat["count"] == 1 + assert abs(food_cat["amount"] - + 100.0) < 0.01 and food_cat["count"] == 2 + assert ( + abs(household_cat["amount"] - + 70.0) < 0.01 and household_cat["count"] == 1 + ) + assert ( + abs(entertainment_cat["amount"] - 30.0) < 0.01 + and entertainment_cat["count"] == 1 + ) assert "memberContributions" in result member_contribs = result["memberContributions"] - assert len(member_contribs) == 3 # user_a_str, user_b_str, user_c_str + assert len(member_contribs) == 3 # user_a_str, user_b_str, user_c_str - user_a_contrib = next(m for m in member_contribs if m["userId"] == user_a_str) - user_b_contrib = next(m for m in member_contribs if m["userId"] == user_b_str) - user_c_contrib = next(m for m in member_contribs if m["userId"] == user_c_str) + user_a_contrib = next( + m for m in member_contribs if m["userId"] == user_a_str) + user_b_contrib = next( + m for m in member_contribs if m["userId"] == user_b_str) + user_c_contrib = next( + m for m in member_contribs if m["userId"] == user_c_str) # User A: Paid 70 (Groceries). Owed 35 (Groceries) + 15 (Movies) = 50. Net = 70 - 50 = 20 assert user_a_contrib["userName"] == "User A" @@ -1624,31 +1976,42 @@ async def mock_users_find_one_side_effect(query, *args, **kwargs): assert "expenseTrends" in result # Should have entries for each day in the month. Check a couple. - assert len(result["expenseTrends"]) >= 28 # Days in Oct - day5_trend = next(d for d in result["expenseTrends"] if d["date"] == f"{year}-{month:02d}-05") - assert abs(day5_trend["amount"] - 70.0) < 0.01 and day5_trend["count"] == 1 - day15_trend = next(d for d in result["expenseTrends"] if d["date"] == f"{year}-{month:02d}-15") - assert abs(day15_trend["amount"] - 30.0) < 0.01 and day15_trend["count"] == 1 - day10_trend = next(d for d in result["expenseTrends"] if d["date"] == f"{year}-{month:02d}-10") # No expense + assert len(result["expenseTrends"]) >= 28 # Days in Oct + day5_trend = next( + d for d in result["expenseTrends"] if d["date"] == f"{year}-{month:02d}-05" + ) + assert abs(day5_trend["amount"] - + 70.0) < 0.01 and day5_trend["count"] == 1 + day15_trend = next( + d for d in result["expenseTrends"] if d["date"] == f"{year}-{month:02d}-15" + ) + assert abs(day15_trend["amount"] - + 30.0) < 0.01 and day15_trend["count"] == 1 + day10_trend = next( + d for d in result["expenseTrends"] if d["date"] == f"{year}-{month:02d}-10" + ) # No expense assert day10_trend["amount"] == 0 and day10_trend["count"] == 0 # Verify mocks mock_db.groups.find_one.assert_called_once() mock_db.expenses.find.assert_called_once() # users.find_one called for each member in current_test_mock_group_data["members"] - assert mock_db.users.find_one.call_count == len(current_test_mock_group_data["members"]) + assert mock_db.users.find_one.call_count == len( + current_test_mock_group_data["members"] + ) @pytest.mark.asyncio async def test_get_group_analytics_group_not_found(expense_service): """Test get group analytics when group not found or user not member""" - group_id = str(ObjectId()) # Valid format + group_id = str(ObjectId()) # Valid format user_id = "user_a" - with patch('app.expenses.service.mongodb') as mock_mongodb: + with patch("app.expenses.service.mongodb") as mock_mongodb: mock_db = MagicMock() mock_mongodb.database = mock_db - mock_db.groups.find_one = AsyncMock(return_value=None) # Group not found + mock_db.groups.find_one = AsyncMock( + return_value=None) # Group not found with pytest.raises(ValueError, match="Group not found or user not a member"): await expense_service.get_group_analytics(group_id, user_id) @@ -1656,5 +2019,6 @@ async def test_get_group_analytics_group_not_found(expense_service): mock_db.expenses.find.assert_not_called() mock_db.users.find_one.assert_not_called() + if __name__ == "__main__": pytest.main([__file__]) diff --git a/backend/tests/groups/test_groups_routes.py b/backend/tests/groups/test_groups_routes.py index ef61a645..c629fb4e 100644 --- a/backend/tests/groups/test_groups_routes.py +++ b/backend/tests/groups/test_groups_routes.py @@ -1,26 +1,31 @@ +from datetime import datetime, timedelta +from unittest.mock import patch + import pytest -from httpx import AsyncClient, ASGITransport +from app.auth.security import create_access_token from fastapi import status +from httpx import ASGITransport, AsyncClient from main import app -from app.auth.security import create_access_token -from datetime import datetime, timedelta -from unittest.mock import patch # Sample user data for testing TEST_USER_ID = "60c72b2f9b1e8a3f9c8b4567" TEST_USER_EMAIL = "testuser@example.com" + @pytest.fixture def auth_headers(): token = create_access_token( data={"sub": TEST_USER_EMAIL, "_id": TEST_USER_ID}, - expires_delta=timedelta(minutes=15) + expires_delta=timedelta(minutes=15), ) return {"Authorization": f"Bearer {token}"} + @pytest.fixture async def async_client(): - async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as ac: + async with AsyncClient( + transport=ASGITransport(app=app), base_url="http://test" + ) as ac: yield ac @@ -28,13 +33,12 @@ class TestGroupsRoutes: """Test cases for Groups API endpoints""" @pytest.mark.asyncio - async def test_create_group_success(self, async_client: AsyncClient, auth_headers, mock_db): + async def test_create_group_success( + self, async_client: AsyncClient, auth_headers, mock_db + ): """Test successful group creation""" - group_data = { - "name": "Test Group", - "currency": "USD" - } - + group_data = {"name": "Test Group", "currency": "USD"} + with patch("app.groups.service.group_service.create_group") as mock_create: mock_create.return_value = { "_id": "642f1e4a9b3c2d1f6a1b2c3d", @@ -44,15 +48,19 @@ async def test_create_group_success(self, async_client: AsyncClient, auth_header "createdBy": "user123", "createdAt": "2023-01-01T00:00:00Z", "imageUrl": None, - "members": [{ - "userId": "user123", - "role": "admin", - "joinedAt": "2023-01-01T00:00:00Z" - }] + "members": [ + { + "userId": "user123", + "role": "admin", + "joinedAt": "2023-01-01T00:00:00Z", + } + ], } - - response = await async_client.post("/groups", json=group_data, headers=auth_headers) - + + response = await async_client.post( + "/groups", json=group_data, headers=auth_headers + ) + assert response.status_code == status.HTTP_201_CREATED data = response.json() assert data["name"] == "Test Group" @@ -60,43 +68,50 @@ async def test_create_group_success(self, async_client: AsyncClient, auth_header assert "joinCode" in data @pytest.mark.asyncio - async def test_create_group_empty_name(self, async_client: AsyncClient, auth_headers, mock_db): + async def test_create_group_empty_name( + self, async_client: AsyncClient, auth_headers, mock_db + ): """Test group creation with empty name""" - group_data = { - "name": "", - "currency": "USD" - } - - response = await async_client.post("/groups", json=group_data, headers=auth_headers) + group_data = {"name": "", "currency": "USD"} + + response = await async_client.post( + "/groups", json=group_data, headers=auth_headers + ) assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY @pytest.mark.asyncio - async def test_list_user_groups(self, async_client: AsyncClient, auth_headers, mock_db): + async def test_list_user_groups( + self, async_client: AsyncClient, auth_headers, mock_db + ): """Test listing user groups""" with patch("app.groups.service.group_service.get_user_groups") as mock_get: - mock_get.return_value = [{ - "_id": "642f1e4a9b3c2d1f6a1b2c3d", - "name": "Test Group", - "currency": "USD", - "joinCode": "ABC123", - "createdBy": "user123", - "createdAt": "2023-01-01T00:00:00Z", - "imageUrl": None, - "members": [] - }] - + mock_get.return_value = [ + { + "_id": "642f1e4a9b3c2d1f6a1b2c3d", + "name": "Test Group", + "currency": "USD", + "joinCode": "ABC123", + "createdBy": "user123", + "createdAt": "2023-01-01T00:00:00Z", + "imageUrl": None, + "members": [], + } + ] + response = await async_client.get("/groups", headers=auth_headers) - + assert response.status_code == status.HTTP_200_OK data = response.json() assert "groups" in data assert len(data["groups"]) == 1 @pytest.mark.asyncio - async def test_get_group_details(self, async_client: AsyncClient, auth_headers, mock_db): + async def test_get_group_details( + self, async_client: AsyncClient, auth_headers, mock_db + ): """Test getting group details""" group_id = "642f1e4a9b3c2d1f6a1b2c3d" - + with patch("app.groups.service.group_service.get_group_by_id") as mock_get: mock_get.return_value = { "_id": group_id, @@ -106,33 +121,41 @@ async def test_get_group_details(self, async_client: AsyncClient, auth_headers, "createdBy": "user123", "createdAt": "2023-01-01T00:00:00Z", "imageUrl": None, - "members": [] + "members": [], } - - response = await async_client.get(f"/groups/{group_id}", headers=auth_headers) - + + response = await async_client.get( + f"/groups/{group_id}", headers=auth_headers + ) + assert response.status_code == status.HTTP_200_OK data = response.json() assert data["_id"] == group_id @pytest.mark.asyncio - async def test_get_group_not_found(self, async_client: AsyncClient, auth_headers, mock_db): + async def test_get_group_not_found( + self, async_client: AsyncClient, auth_headers, mock_db + ): """Test getting non-existent group""" group_id = "642f1e4a9b3c2d1f6a1b2c3d" - + with patch("app.groups.service.group_service.get_group_by_id") as mock_get: mock_get.return_value = None - - response = await async_client.get(f"/groups/{group_id}", headers=auth_headers) - + + response = await async_client.get( + f"/groups/{group_id}", headers=auth_headers + ) + assert response.status_code == status.HTTP_404_NOT_FOUND @pytest.mark.asyncio - async def test_update_group_metadata(self, async_client: AsyncClient, auth_headers, mock_db): + async def test_update_group_metadata( + self, async_client: AsyncClient, auth_headers, mock_db + ): """Test updating group metadata""" group_id = "642f1e4a9b3c2d1f6a1b2c3d" update_data = {"name": "Updated Group Name"} - + with patch("app.groups.service.group_service.update_group") as mock_update: mock_update.return_value = { "_id": group_id, @@ -142,11 +165,13 @@ async def test_update_group_metadata(self, async_client: AsyncClient, auth_heade "createdBy": "user123", "createdAt": "2023-01-01T00:00:00Z", "imageUrl": None, - "members": [] + "members": [], } - - response = await async_client.patch(f"/groups/{group_id}", json=update_data, headers=auth_headers) - + + response = await async_client.patch( + f"/groups/{group_id}", json=update_data, headers=auth_headers + ) + assert response.status_code == status.HTTP_200_OK data = response.json() assert data["name"] == "Updated Group Name" @@ -155,21 +180,25 @@ async def test_update_group_metadata(self, async_client: AsyncClient, auth_heade async def test_delete_group(self, async_client: AsyncClient, auth_headers, mock_db): """Test deleting a group""" group_id = "642f1e4a9b3c2d1f6a1b2c3d" - + with patch("app.groups.service.group_service.delete_group") as mock_delete: mock_delete.return_value = True - - response = await async_client.delete(f"/groups/{group_id}", headers=auth_headers) - + + response = await async_client.delete( + f"/groups/{group_id}", headers=auth_headers + ) + assert response.status_code == status.HTTP_200_OK data = response.json() assert data["success"] is True @pytest.mark.asyncio - async def test_join_group_by_code(self, async_client: AsyncClient, auth_headers, mock_db): + async def test_join_group_by_code( + self, async_client: AsyncClient, auth_headers, mock_db + ): """Test joining a group by code""" join_data = {"joinCode": "ABC123"} - + with patch("app.groups.service.group_service.join_group_by_code") as mock_join: mock_join.return_value = { "_id": "642f1e4a9b3c2d1f6a1b2c3d", @@ -180,13 +209,23 @@ async def test_join_group_by_code(self, async_client: AsyncClient, auth_headers, "createdAt": "2023-01-01T00:00:00Z", "imageUrl": None, "members": [ - {"userId": "user123", "role": "admin", "joinedAt": "2023-01-01T00:00:00Z"}, - {"userId": "user456", "role": "member", "joinedAt": "2023-01-01T00:00:00Z"} - ] + { + "userId": "user123", + "role": "admin", + "joinedAt": "2023-01-01T00:00:00Z", + }, + { + "userId": "user456", + "role": "member", + "joinedAt": "2023-01-01T00:00:00Z", + }, + ], } - - response = await async_client.post("/groups/join", json=join_data, headers=auth_headers) - + + response = await async_client.post( + "/groups/join", json=join_data, headers=auth_headers + ) + assert response.status_code == status.HTTP_200_OK data = response.json() assert "group" in data @@ -195,64 +234,88 @@ async def test_join_group_by_code(self, async_client: AsyncClient, auth_headers, async def test_leave_group(self, async_client: AsyncClient, auth_headers, mock_db): """Test leaving a group""" group_id = "642f1e4a9b3c2d1f6a1b2c3d" - + with patch("app.groups.service.group_service.leave_group") as mock_leave: mock_leave.return_value = True - - response = await async_client.post(f"/groups/{group_id}/leave", headers=auth_headers) - + + response = await async_client.post( + f"/groups/{group_id}/leave", headers=auth_headers + ) + assert response.status_code == status.HTTP_200_OK data = response.json() assert data["success"] is True @pytest.mark.asyncio - async def test_get_group_members(self, async_client: AsyncClient, auth_headers, mock_db): + async def test_get_group_members( + self, async_client: AsyncClient, auth_headers, mock_db + ): """Test getting group members""" group_id = "642f1e4a9b3c2d1f6a1b2c3d" - - with patch("app.groups.service.group_service.get_group_members") as mock_get_members: + + with patch( + "app.groups.service.group_service.get_group_members" + ) as mock_get_members: mock_get_members.return_value = [ - {"userId": "user123", "role": "admin", "joinedAt": "2023-01-01T00:00:00Z"}, - {"userId": "user456", "role": "member", "joinedAt": "2023-01-01T00:00:00Z"} + { + "userId": "user123", + "role": "admin", + "joinedAt": "2023-01-01T00:00:00Z", + }, + { + "userId": "user456", + "role": "member", + "joinedAt": "2023-01-01T00:00:00Z", + }, ] - - response = await async_client.get(f"/groups/{group_id}/members", headers=auth_headers) - + + response = await async_client.get( + f"/groups/{group_id}/members", headers=auth_headers + ) + assert response.status_code == status.HTTP_200_OK data = response.json() assert len(data) == 2 @pytest.mark.asyncio - async def test_update_member_role(self, async_client: AsyncClient, auth_headers, mock_db): + async def test_update_member_role( + self, async_client: AsyncClient, auth_headers, mock_db + ): """Test updating member role""" group_id = "642f1e4a9b3c2d1f6a1b2c3d" member_id = "user456" role_data = {"role": "admin"} - - with patch("app.groups.service.group_service.update_member_role") as mock_update_role: + + with patch( + "app.groups.service.group_service.update_member_role" + ) as mock_update_role: mock_update_role.return_value = True - + response = await async_client.patch( - f"/groups/{group_id}/members/{member_id}", + f"/groups/{group_id}/members/{member_id}", json=role_data, - headers=auth_headers + headers=auth_headers, ) - + assert response.status_code == status.HTTP_200_OK data = response.json() assert "message" in data @pytest.mark.asyncio - async def test_remove_member(self, async_client: AsyncClient, auth_headers, mock_db): + async def test_remove_member( + self, async_client: AsyncClient, auth_headers, mock_db + ): """Test removing a member from group""" group_id = "642f1e4a9b3c2d1f6a1b2c3d" member_id = "user456" - + with patch("app.groups.service.group_service.remove_member") as mock_remove: mock_remove.return_value = True - - response = await async_client.delete(f"/groups/{group_id}/members/{member_id}", headers=auth_headers) - + + response = await async_client.delete( + f"/groups/{group_id}/members/{member_id}", headers=auth_headers + ) + assert response.status_code == status.HTTP_200_OK data = response.json() assert data["success"] is True diff --git a/backend/tests/groups/test_groups_service.py b/backend/tests/groups/test_groups_service.py index 4d3d1178..0bc9daa3 100644 --- a/backend/tests/groups/test_groups_service.py +++ b/backend/tests/groups/test_groups_service.py @@ -1,8 +1,9 @@ +from unittest.mock import AsyncMock, MagicMock, patch + import pytest -from unittest.mock import AsyncMock, patch, MagicMock -from fastapi import HTTPException -from bson import ObjectId from app.groups.service import GroupService +from bson import ObjectId +from fastapi import HTTPException class TestGroupService: @@ -29,11 +30,11 @@ def test_transform_group_document(self): "createdBy": "user123", "createdAt": "2023-01-01T00:00:00Z", "imageUrl": None, - "members": [] + "members": [], } - + result = self.service.transform_group_document(group_doc) - + assert result["_id"] == "642f1e4a9b3c2d1f6a1b2c3d" assert result["name"] == "Test Group" assert result["currency"] == "USD" @@ -50,15 +51,15 @@ async def test_create_group_success(self): mock_db = AsyncMock() mock_collection = AsyncMock() mock_db.groups = mock_collection - + # Mock find_one to return None (no existing join code) mock_collection.find_one.return_value = None - + # Mock insert_one mock_result = MagicMock() mock_result.inserted_id = ObjectId("642f1e4a9b3c2d1f6a1b2c3d") mock_collection.insert_one.return_value = mock_result - + # Mock find_one for created group created_group = { "_id": ObjectId("642f1e4a9b3c2d1f6a1b2c3d"), @@ -68,20 +69,21 @@ async def test_create_group_success(self): "createdBy": "user123", "createdAt": "2023-01-01T00:00:00Z", "imageUrl": None, - "members": [{ - "userId": "user123", - "role": "admin", - "joinedAt": "2023-01-01T00:00:00Z" - }] + "members": [ + { + "userId": "user123", + "role": "admin", + "joinedAt": "2023-01-01T00:00:00Z", + } + ], } mock_collection.find_one.side_effect = [None, created_group] - - with patch.object(self.service, 'get_db', return_value=mock_db): + + with patch.object(self.service, "get_db", return_value=mock_db): result = await self.service.create_group( - {"name": "Test Group", "currency": "USD"}, - "user123" + {"name": "Test Group", "currency": "USD"}, "user123" ) - + assert result["name"] == "Test Group" assert result["currency"] == "USD" assert "joinCode" in result @@ -92,32 +94,34 @@ async def test_get_user_groups(self): mock_db = MagicMock() # Use MagicMock instead of AsyncMock mock_collection = MagicMock() # Use MagicMock instead of AsyncMock mock_db.groups = mock_collection - + # Mock groups data - mock_groups = [{ - "_id": ObjectId("642f1e4a9b3c2d1f6a1b2c3d"), - "name": "Test Group", - "currency": "USD", - "joinCode": "ABC123", - "createdBy": "user123", - "createdAt": "2023-01-01T00:00:00Z", - "imageUrl": None, - "members": [] - }] - + mock_groups = [ + { + "_id": ObjectId("642f1e4a9b3c2d1f6a1b2c3d"), + "name": "Test Group", + "currency": "USD", + "joinCode": "ABC123", + "createdBy": "user123", + "createdAt": "2023-01-01T00:00:00Z", + "imageUrl": None, + "members": [], + } + ] + # Create a proper async iterator mock async def mock_async_iter(): for group in mock_groups: yield group - + # Mock cursor with proper __aiter__ method mock_cursor = MagicMock() mock_cursor.__aiter__ = lambda self: mock_async_iter() mock_collection.find.return_value = mock_cursor - - with patch.object(self.service, 'get_db', return_value=mock_db): + + with patch.object(self.service, "get_db", return_value=mock_db): result = await self.service.get_user_groups("user123") - + assert len(result) == 1 assert result[0]["name"] == "Test Group" @@ -127,34 +131,44 @@ async def test_join_group_by_code_success(self): mock_db = AsyncMock() mock_collection = AsyncMock() mock_db.groups = mock_collection - + existing_group = { "_id": ObjectId("642f1e4a9b3c2d1f6a1b2c3d"), "name": "Test Group", "joinCode": "ABC123", - "members": [{ - "userId": "user123", - "role": "admin", - "joinedAt": "2023-01-01T00:00:00Z" - }] + "members": [ + { + "userId": "user123", + "role": "admin", + "joinedAt": "2023-01-01T00:00:00Z", + } + ], } - + updated_group = { "_id": ObjectId("642f1e4a9b3c2d1f6a1b2c3d"), "name": "Test Group", "joinCode": "ABC123", "members": [ - {"userId": "user123", "role": "admin", "joinedAt": "2023-01-01T00:00:00Z"}, - {"userId": "user456", "role": "member", "joinedAt": "2023-01-01T00:00:00Z"} - ] + { + "userId": "user123", + "role": "admin", + "joinedAt": "2023-01-01T00:00:00Z", + }, + { + "userId": "user456", + "role": "member", + "joinedAt": "2023-01-01T00:00:00Z", + }, + ], } - + mock_collection.find_one.return_value = existing_group mock_collection.find_one_and_update.return_value = updated_group - - with patch.object(self.service, 'get_db', return_value=mock_db): + + with patch.object(self.service, "get_db", return_value=mock_db): result = await self.service.join_group_by_code("ABC123", "user456") - + assert result is not None assert len(result["members"]) == 2 @@ -164,13 +178,13 @@ async def test_join_group_invalid_code(self): mock_db = AsyncMock() mock_collection = AsyncMock() mock_db.groups = mock_collection - + mock_collection.find_one.return_value = None - - with patch.object(self.service, 'get_db', return_value=mock_db): + + with patch.object(self.service, "get_db", return_value=mock_db): with pytest.raises(HTTPException) as exc_info: await self.service.join_group_by_code("INVALID", "user456") - + assert exc_info.value.status_code == 404 assert "Invalid join code" in str(exc_info.value.detail) @@ -180,24 +194,26 @@ async def test_join_group_already_member(self): mock_db = AsyncMock() mock_collection = AsyncMock() mock_db.groups = mock_collection - + existing_group = { "_id": ObjectId("642f1e4a9b3c2d1f6a1b2c3d"), "name": "Test Group", "joinCode": "ABC123", - "members": [{ - "userId": "user456", - "role": "member", - "joinedAt": "2023-01-01T00:00:00Z" - }] + "members": [ + { + "userId": "user456", + "role": "member", + "joinedAt": "2023-01-01T00:00:00Z", + } + ], } - + mock_collection.find_one.return_value = existing_group - - with patch.object(self.service, 'get_db', return_value=mock_db): + + with patch.object(self.service, "get_db", return_value=mock_db): with pytest.raises(HTTPException) as exc_info: await self.service.join_group_by_code("ABC123", "user456") - + assert exc_info.value.status_code == 400 assert "already a member" in str(exc_info.value.detail) @@ -207,13 +223,15 @@ async def test_update_group_not_admin(self): mock_db = AsyncMock() mock_collection = AsyncMock() mock_db.groups = mock_collection - + mock_collection.find_one.return_value = None # User not admin - - with patch.object(self.service, 'get_db', return_value=mock_db): + + with patch.object(self.service, "get_db", return_value=mock_db): with pytest.raises(HTTPException) as exc_info: - await self.service.update_group("642f1e4a9b3c2d1f6a1b2c3d", {"name": "New Name"}, "user456") - + await self.service.update_group( + "642f1e4a9b3c2d1f6a1b2c3d", {"name": "New Name"}, "user456" + ) + assert exc_info.value.status_code == 403 assert "Only group admins" in str(exc_info.value.detail) @@ -223,25 +241,37 @@ async def test_update_member_role_prevent_last_admin_demotion(self): mock_db = AsyncMock() mock_collection = AsyncMock() mock_db.groups = mock_collection - + # Mock group with only one admin group_with_one_admin = { "_id": ObjectId("642f1e4a9b3c2d1f6a1b2c3d"), "name": "Test Group", "members": [ - {"userId": "user123", "role": "admin", "joinedAt": "2023-01-01T00:00:00Z"}, - {"userId": "user456", "role": "member", "joinedAt": "2023-01-01T00:00:00Z"} - ] + { + "userId": "user123", + "role": "admin", + "joinedAt": "2023-01-01T00:00:00Z", + }, + { + "userId": "user456", + "role": "member", + "joinedAt": "2023-01-01T00:00:00Z", + }, + ], } - + mock_collection.find_one.return_value = group_with_one_admin - - with patch.object(self.service, 'get_db', return_value=mock_db): + + with patch.object(self.service, "get_db", return_value=mock_db): with pytest.raises(HTTPException) as exc_info: - await self.service.update_member_role("642f1e4a9b3c2d1f6a1b2c3d", "user123", "member", "user123") - + await self.service.update_member_role( + "642f1e4a9b3c2d1f6a1b2c3d", "user123", "member", "user123" + ) + assert exc_info.value.status_code == 400 - assert "Cannot demote yourself when you are the only admin" in str(exc_info.value.detail) + assert "Cannot demote yourself when you are the only admin" in str( + exc_info.value.detail + ) @pytest.mark.asyncio async def test_update_member_role_allow_admin_demotion_with_other_admins(self): @@ -249,26 +279,40 @@ async def test_update_member_role_allow_admin_demotion_with_other_admins(self): mock_db = AsyncMock() mock_collection = AsyncMock() mock_db.groups = mock_collection - + # Mock group with multiple admins group_with_multiple_admins = { "_id": ObjectId("642f1e4a9b3c2d1f6a1b2c3d"), "name": "Test Group", "members": [ - {"userId": "user123", "role": "admin", "joinedAt": "2023-01-01T00:00:00Z"}, - {"userId": "user456", "role": "admin", "joinedAt": "2023-01-01T00:00:00Z"}, - {"userId": "user789", "role": "member", "joinedAt": "2023-01-01T00:00:00Z"} - ] + { + "userId": "user123", + "role": "admin", + "joinedAt": "2023-01-01T00:00:00Z", + }, + { + "userId": "user456", + "role": "admin", + "joinedAt": "2023-01-01T00:00:00Z", + }, + { + "userId": "user789", + "role": "member", + "joinedAt": "2023-01-01T00:00:00Z", + }, + ], } - + mock_collection.find_one.return_value = group_with_multiple_admins mock_result = MagicMock() mock_result.modified_count = 1 mock_collection.update_one.return_value = mock_result - - with patch.object(self.service, 'get_db', return_value=mock_db): - result = await self.service.update_member_role("642f1e4a9b3c2d1f6a1b2c3d", "user123", "member", "user123") - + + with patch.object(self.service, "get_db", return_value=mock_db): + result = await self.service.update_member_role( + "642f1e4a9b3c2d1f6a1b2c3d", "user123", "member", "user123" + ) + assert result is True @pytest.mark.asyncio @@ -277,14 +321,16 @@ async def test_remove_member_group_not_found(self): mock_db = AsyncMock() mock_collection = AsyncMock() mock_db.groups = mock_collection - + # Mock no group found for admin check and no group exists at all mock_collection.find_one.side_effect = [None, None] - - with patch.object(self.service, 'get_db', return_value=mock_db): + + with patch.object(self.service, "get_db", return_value=mock_db): with pytest.raises(HTTPException) as exc_info: - await self.service.remove_member("642f1e4a9b3c2d1f6a1b2c3d", "user456", "user123") - + await self.service.remove_member( + "642f1e4a9b3c2d1f6a1b2c3d", "user456", "user123" + ) + assert exc_info.value.status_code == 404 assert "Group not found" in str(exc_info.value.detail) @@ -294,25 +340,36 @@ async def test_remove_member_user_not_admin_but_group_exists(self): mock_db = AsyncMock() mock_collection = AsyncMock() mock_db.groups = mock_collection - + existing_group = { "_id": ObjectId("642f1e4a9b3c2d1f6a1b2c3d"), "name": "Test Group", "members": [ - {"userId": "user123", "role": "admin", "joinedAt": "2023-01-01T00:00:00Z"}, - {"userId": "user456", "role": "member", "joinedAt": "2023-01-01T00:00:00Z"} - ] + { + "userId": "user123", + "role": "admin", + "joinedAt": "2023-01-01T00:00:00Z", + }, + { + "userId": "user456", + "role": "member", + "joinedAt": "2023-01-01T00:00:00Z", + }, + ], } - + # First call returns None (user not admin), second call returns the group (group exists) mock_collection.find_one.side_effect = [None, existing_group] - - with patch.object(self.service, 'get_db', return_value=mock_db): + + with patch.object(self.service, "get_db", return_value=mock_db): with pytest.raises(HTTPException) as exc_info: - await self.service.remove_member("642f1e4a9b3c2d1f6a1b2c3d", "user456", "user789") # user789 is not admin - + await self.service.remove_member( + "642f1e4a9b3c2d1f6a1b2c3d", "user456", "user789" + ) # user789 is not admin + assert exc_info.value.status_code == 403 - assert "Only group admins can remove members" in str(exc_info.value.detail) + assert "Only group admins can remove members" in str( + exc_info.value.detail) @pytest.mark.asyncio async def test_leave_group_prevent_last_admin(self): @@ -320,25 +377,35 @@ async def test_leave_group_prevent_last_admin(self): mock_db = AsyncMock() mock_collection = AsyncMock() mock_db.groups = mock_collection - + # Mock group with only one admin group_with_one_admin = { "_id": ObjectId("642f1e4a9b3c2d1f6a1b2c3d"), "name": "Test Group", "members": [ - {"userId": "user123", "role": "admin", "joinedAt": "2023-01-01T00:00:00Z"}, - {"userId": "user456", "role": "member", "joinedAt": "2023-01-01T00:00:00Z"} - ] + { + "userId": "user123", + "role": "admin", + "joinedAt": "2023-01-01T00:00:00Z", + }, + { + "userId": "user456", + "role": "member", + "joinedAt": "2023-01-01T00:00:00Z", + }, + ], } - + mock_collection.find_one.return_value = group_with_one_admin - - with patch.object(self.service, 'get_db', return_value=mock_db): + + with patch.object(self.service, "get_db", return_value=mock_db): with pytest.raises(HTTPException) as exc_info: await self.service.leave_group("642f1e4a9b3c2d1f6a1b2c3d", "user123") - + assert exc_info.value.status_code == 400 - assert "Cannot leave group when you are the only admin" in str(exc_info.value.detail) + assert "Cannot leave group when you are the only admin" in str( + exc_info.value.detail + ) @pytest.mark.asyncio async def test_leave_group_allow_member_to_leave(self): @@ -346,22 +413,32 @@ async def test_leave_group_allow_member_to_leave(self): mock_db = AsyncMock() mock_collection = AsyncMock() mock_db.groups = mock_collection - + group = { "_id": ObjectId("642f1e4a9b3c2d1f6a1b2c3d"), "name": "Test Group", "members": [ - {"userId": "user123", "role": "admin", "joinedAt": "2023-01-01T00:00:00Z"}, - {"userId": "user456", "role": "member", "joinedAt": "2023-01-01T00:00:00Z"} - ] + { + "userId": "user123", + "role": "admin", + "joinedAt": "2023-01-01T00:00:00Z", + }, + { + "userId": "user456", + "role": "member", + "joinedAt": "2023-01-01T00:00:00Z", + }, + ], } - + mock_collection.find_one.return_value = group mock_result = MagicMock() mock_result.modified_count = 1 mock_collection.update_one.return_value = mock_result - - with patch.object(self.service, 'get_db', return_value=mock_db): - result = await self.service.leave_group("642f1e4a9b3c2d1f6a1b2c3d", "user456") - + + with patch.object(self.service, "get_db", return_value=mock_db): + result = await self.service.leave_group( + "642f1e4a9b3c2d1f6a1b2c3d", "user456" + ) + assert result is True diff --git a/backend/tests/logger/test_logger.py b/backend/tests/logger/test_logger.py index 3a166cc1..3cdfe84a 100644 --- a/backend/tests/logger/test_logger.py +++ b/backend/tests/logger/test_logger.py @@ -1,13 +1,14 @@ -import pytest import logging -from app.config import logger,LOGGING_CONFIG from logging.config import dictConfig -from app.config import RequestResponseLoggingMiddleware -from fastapi.testclient import TestClient + +import pytest +from app.config import LOGGING_CONFIG, RequestResponseLoggingMiddleware, logger from fastapi import FastAPI +from fastapi.testclient import TestClient dictConfig(LOGGING_CONFIG) + def test_logger_init(): assert logger is not None assert isinstance(logger, logging.Logger) @@ -17,25 +18,30 @@ def test_logger_init(): assert logger.isEnabledFor(logging.ERROR) assert not logger.isEnabledFor(logging.DEBUG) + def test_logger_logs_info(caplog): with caplog.at_level(logging.INFO): logger.info("Test info message") assert "Test info message" in caplog.text + def test_logger_logs_debug(caplog): with caplog.at_level(logging.DEBUG): logging.debug("Test debug message") assert "Test debug message" in caplog.text + def test_logger_logs_error(caplog): with caplog.at_level(logging.ERROR): logger.error("Test error message") assert "Test error message" in caplog.text + def test_logger_logs_warning(caplog): with caplog.at_level(logging.WARNING): logging.warning("Test warning message") - assert "Test warning message" in caplog.text + assert "Test warning message" in caplog.text + @pytest.mark.asyncio async def test_request_response_logging_middleware_logs(caplog): @@ -45,13 +51,13 @@ async def test_request_response_logging_middleware_logs(caplog): @app.get("/test") async def test_endpoint(): - return {"message":"Test message"} + return {"message": "Test message"} client = TestClient(app) with caplog.at_level(logging.INFO): response = client.get("/test") - + assert response.status_code == 200 assert "Incoming request: GET http://testserver/test" in caplog.text - assert "Response status: 200 for GET http://testserver/test" in caplog.text \ No newline at end of file + assert "Response status: 200 for GET http://testserver/test" in caplog.text diff --git a/backend/tests/user/test_user_routes.py b/backend/tests/user/test_user_routes.py index be2d59a7..17e13c41 100644 --- a/backend/tests/user/test_user_routes.py +++ b/backend/tests/user/test_user_routes.py @@ -1,56 +1,71 @@ +from datetime import datetime, timedelta + import pytest -from fastapi.testclient import TestClient +from app.auth.security import create_access_token from fastapi import status +from fastapi.testclient import TestClient from main import app -from app.auth.security import create_access_token -from datetime import datetime, timedelta # Sample user data for testing TEST_USER_ID = "60c72b2f9b1e8a3f9c8b4567" TEST_USER_EMAIL = "testuser@example.com" + @pytest.fixture(scope="module") def client(): with TestClient(app) as c: yield c + @pytest.fixture(scope="module") def auth_headers(): token = create_access_token( data={"sub": TEST_USER_EMAIL, "_id": TEST_USER_ID}, - expires_delta=timedelta(minutes=15) + expires_delta=timedelta(minutes=15), ) return {"Authorization": f"Bearer {token}"} + @pytest.fixture(autouse=True, scope="function") async def setup_test_user(mocker): iso_date = "2023-01-01T00:00:00Z" iso_date2 = "2023-01-02T00:00:00Z" iso_date3 = "2023-01-03T00:00:00Z" - mocker.patch("app.user.service.user_service.get_user_by_id", return_value={ - "id": TEST_USER_ID, - "name": "Test User", - "email": TEST_USER_EMAIL, - "imageUrl": None, - "currency": "USD", - "createdAt": iso_date, - "updatedAt": iso_date - }) - mocker.patch("app.user.service.user_service.update_user_profile", return_value={ - "id": TEST_USER_ID, - "name": "Updated Test User", - "email": TEST_USER_EMAIL, - "imageUrl": "http://example.com/avatar.png", - "currency": "EUR", - "createdAt": iso_date, - "updatedAt": iso_date2 - }) - mocker.patch("app.user.service.user_service.delete_user", return_value=True) + mocker.patch( + "app.user.service.user_service.get_user_by_id", + return_value={ + "id": TEST_USER_ID, + "name": "Test User", + "email": TEST_USER_EMAIL, + "imageUrl": None, + "currency": "USD", + "createdAt": iso_date, + "updatedAt": iso_date, + }, + ) + mocker.patch( + "app.user.service.user_service.update_user_profile", + return_value={ + "id": TEST_USER_ID, + "name": "Updated Test User", + "email": TEST_USER_EMAIL, + "imageUrl": "http://example.com/avatar.png", + "currency": "EUR", + "createdAt": iso_date, + "updatedAt": iso_date2, + }, + ) + mocker.patch("app.user.service.user_service.delete_user", + return_value=True) yield + # --- Tests for GET /users/me --- -def test_get_current_user_profile_success(client: TestClient, auth_headers: dict, mocker): + +def test_get_current_user_profile_success( + client: TestClient, auth_headers: dict, mocker +): """Test successful retrieval of current user's profile.""" response = client.get("/users/me", headers=auth_headers) assert response.status_code == status.HTTP_200_OK @@ -62,23 +77,32 @@ def test_get_current_user_profile_success(client: TestClient, auth_headers: dict assert "createdAt" in data and data["createdAt"].endswith("Z") assert "updatedAt" in data and data["updatedAt"].endswith("Z") -def test_get_current_user_profile_not_found(client: TestClient, auth_headers: dict, mocker): + +def test_get_current_user_profile_not_found( + client: TestClient, auth_headers: dict, mocker +): """Test retrieval when user is not found in service layer.""" - mocker.patch("app.user.service.user_service.get_user_by_id", return_value=None) + mocker.patch("app.user.service.user_service.get_user_by_id", + return_value=None) response = client.get("/users/me", headers=auth_headers) assert response.status_code == status.HTTP_404_NOT_FOUND - assert response.json() == {"detail": {"error": "NotFound", "message": "User not found"}} + assert response.json() == { + "detail": {"error": "NotFound", "message": "User not found"} + } + # --- Tests for PATCH /users/me --- + def test_update_user_profile_success(client: TestClient, auth_headers: dict, mocker): """Test successful update of user profile.""" update_payload = { "name": "Updated Test User", "imageUrl": "http://example.com/avatar.png", - "currency": "EUR" + "currency": "EUR", } - response = client.patch("/users/me", headers=auth_headers, json=update_payload) + response = client.patch( + "/users/me", headers=auth_headers, json=update_payload) assert response.status_code == status.HTTP_200_OK data = response.json()["user"] assert data["name"] == "Updated Test User" @@ -88,17 +112,28 @@ def test_update_user_profile_success(client: TestClient, auth_headers: dict, moc assert "createdAt" in data and data["createdAt"].endswith("Z") assert "updatedAt" in data and data["updatedAt"].endswith("Z") -def test_update_user_profile_partial_update(client: TestClient, auth_headers: dict, mocker): + +def test_update_user_profile_partial_update( + client: TestClient, auth_headers: dict, mocker +): """Test updating only one field of the user profile.""" iso_date = "2023-01-01T00:00:00Z" iso_date3 = "2023-01-03T00:00:00Z" update_payload = {"name": "Only Name Updated"} - mocker.patch("app.user.service.user_service.update_user_profile", return_value={ - "id": TEST_USER_ID, "name": "Only Name Updated", "email": TEST_USER_EMAIL, - "imageUrl": None, "currency": "USD", - "createdAt": iso_date, "updatedAt": iso_date3 - }) - response = client.patch("/users/me", headers=auth_headers, json=update_payload) + mocker.patch( + "app.user.service.user_service.update_user_profile", + return_value={ + "id": TEST_USER_ID, + "name": "Only Name Updated", + "email": TEST_USER_EMAIL, + "imageUrl": None, + "currency": "USD", + "createdAt": iso_date, + "updatedAt": iso_date3, + }, + ) + response = client.patch( + "/users/me", headers=auth_headers, json=update_payload) assert response.status_code == status.HTTP_200_OK data = response.json()["user"] assert data["name"] == "Only Name Updated" @@ -107,22 +142,34 @@ def test_update_user_profile_partial_update(client: TestClient, auth_headers: di assert "createdAt" in data and data["createdAt"].endswith("Z") assert "updatedAt" in data and data["updatedAt"].endswith("Z") + def test_update_user_profile_no_fields(client: TestClient, auth_headers: dict): """Test updating profile with no fields, expecting a 400 error.""" response = client.patch("/users/me", headers=auth_headers, json={}) assert response.status_code == status.HTTP_400_BAD_REQUEST - assert response.json() == {"detail": {"error": "InvalidInput", "message": "No update fields provided."}} + assert response.json() == { + "detail": {"error": "InvalidInput", "message": "No update fields provided."} + } + -def test_update_user_profile_user_not_found(client: TestClient, auth_headers: dict, mocker): +def test_update_user_profile_user_not_found( + client: TestClient, auth_headers: dict, mocker +): """Test updating profile when user is not found by the service.""" - mocker.patch("app.user.service.user_service.update_user_profile", return_value=None) + mocker.patch( + "app.user.service.user_service.update_user_profile", return_value=None) update_payload = {"name": "Attempted Update"} - response = client.patch("/users/me", headers=auth_headers, json=update_payload) + response = client.patch( + "/users/me", headers=auth_headers, json=update_payload) assert response.status_code == status.HTTP_404_NOT_FOUND - assert response.json() == {"detail": {"error": "NotFound", "message": "User not found"}} + assert response.json() == { + "detail": {"error": "NotFound", "message": "User not found"} + } + # --- Tests for DELETE /users/me --- + def test_delete_user_account_success(client: TestClient, auth_headers: dict, mocker): """Test successful deletion of a user account.""" response = client.delete("/users/me", headers=auth_headers) @@ -131,12 +178,17 @@ def test_delete_user_account_success(client: TestClient, auth_headers: dict, moc assert data["success"] is True assert data["message"] == "User account scheduled for deletion." + def test_delete_user_account_not_found(client: TestClient, auth_headers: dict, mocker): """Test deleting a user account when the user is not found by the service.""" - mocker.patch("app.user.service.user_service.delete_user", return_value=False) + mocker.patch("app.user.service.user_service.delete_user", + return_value=False) response = client.delete("/users/me", headers=auth_headers) assert response.status_code == status.HTTP_404_NOT_FOUND - assert response.json() == {"detail": {"error": "NotFound", "message": "User not found"}} + assert response.json() == { + "detail": {"error": "NotFound", "message": "User not found"} + } + # All route tests are in place, removing the placeholder # def test_placeholder(): diff --git a/backend/tests/user/test_user_service.py b/backend/tests/user/test_user_service.py index ab6ebfb7..0a64d738 100644 --- a/backend/tests/user/test_user_service.py +++ b/backend/tests/user/test_user_service.py @@ -1,22 +1,25 @@ +from datetime import datetime, timedelta, timezone +from unittest.mock import AsyncMock, MagicMock + import pytest -from app.user.service import UserService from app.database import get_database +from app.user.service import UserService from bson import ObjectId -from datetime import datetime, timezone, timedelta -from unittest.mock import AsyncMock, MagicMock # Initialize UserService instance for testing user_service = UserService() # --- Fixtures --- + @pytest.fixture def mock_db_client(): """Fixture to create a mock database client with an async users collection.""" db_client = MagicMock() - db_client.users = AsyncMock() # Mock the 'users' collection + db_client.users = AsyncMock() # Mock the 'users' collection return db_client + @pytest.fixture(autouse=True) def mock_get_database(mocker, mock_db_client): """Autouse fixture to mock get_database and return the mock_db_client.""" @@ -56,10 +59,12 @@ def mock_get_database(mocker, mock_db_client): # --- Tests for transform_user_document --- + def test_transform_user_document_all_fields(): transformed = user_service.transform_user_document(RAW_USER_FROM_DB) assert transformed == TRANSFORMED_USER_EXPECTED + def test_transform_user_document_missing_optional_fields(): raw_user_minimal = { "_id": TEST_OBJECT_ID, @@ -79,13 +84,14 @@ def test_transform_user_document_missing_optional_fields(): transformed = user_service.transform_user_document(raw_user_minimal) assert transformed == expected_transformed_minimal + def test_transform_user_document_with_updated_at_different_from_created_at(): raw_user_updated = { "_id": TEST_OBJECT_ID, "name": "Updated User", "email": "updated@example.com", "created_at": NOW, - "updated_at": LATER + "updated_at": LATER, } expected_transformed_updated = { "id": TEST_OBJECT_ID_STR, @@ -99,25 +105,30 @@ def test_transform_user_document_with_updated_at_different_from_created_at(): transformed = user_service.transform_user_document(raw_user_updated) assert transformed == expected_transformed_updated + def test_transform_user_document_none_input(): assert user_service.transform_user_document(None) is None + def test_transform_user_document_iso_none(): user = {"_id": "x", "created_at": None} result = user_service.transform_user_document(user) assert result["createdAt"] is None + def test_transform_user_document_iso_str(): user = {"_id": "x", "created_at": "2025-06-28T12:00:00Z"} result = user_service.transform_user_document(user) assert result["createdAt"] == "2025-06-28T12:00:00Z" + def test_transform_user_document_iso_naive_datetime(): dt = datetime(2025, 6, 28, 12, 0, 0) user = {"_id": "x", "created_at": dt} result = user_service.transform_user_document(user) assert result["createdAt"].endswith("Z") + def test_transform_user_document_iso_aware_datetime_utc(): dt = datetime(2025, 6, 28, 12, 0, 0, tzinfo=timezone.utc) user = {"_id": "x", "created_at": dt} @@ -125,6 +136,7 @@ def test_transform_user_document_iso_aware_datetime_utc(): assert result["createdAt"].endswith("Z") assert result["createdAt"].startswith("2025-06-28T12:00:00") + def test_transform_user_document_iso_aware_datetime_non_utc(): dt = datetime(2025, 6, 28, 14, 0, 0, tzinfo=timezone(timedelta(hours=2))) user = {"_id": "x", "created_at": dt} @@ -132,35 +144,45 @@ def test_transform_user_document_iso_aware_datetime_non_utc(): assert result["createdAt"].endswith("Z") assert result["createdAt"].startswith("2025-06-28T12:00:00") + def test_transform_user_document_iso_unexpected_type(): - class Dummy: pass + class Dummy: + pass + dummy = Dummy() user = {"_id": "x", "created_at": dummy} result = user_service.transform_user_document(user) assert result["createdAt"] == str(dummy) + # --- Tests for get_user_by_id --- + @pytest.mark.asyncio async def test_get_user_by_id_found(mock_db_client, mock_get_database): mock_db_client.users.find_one.return_value = RAW_USER_FROM_DB user = await user_service.get_user_by_id(TEST_OBJECT_ID_STR) - mock_db_client.users.find_one.assert_called_once_with({"_id": TEST_OBJECT_ID}) + mock_db_client.users.find_one.assert_called_once_with( + {"_id": TEST_OBJECT_ID}) assert user == TRANSFORMED_USER_EXPECTED + @pytest.mark.asyncio async def test_get_user_by_id_not_found(mock_db_client, mock_get_database): mock_db_client.users.find_one.return_value = None user = await user_service.get_user_by_id(TEST_OBJECT_ID_STR) - mock_db_client.users.find_one.assert_called_once_with({"_id": TEST_OBJECT_ID}) + mock_db_client.users.find_one.assert_called_once_with( + {"_id": TEST_OBJECT_ID}) assert user is None + # --- Tests for update_user_profile --- + @pytest.mark.asyncio async def test_update_user_profile_success(mock_db_client, mock_get_database): update_data = {"name": "New Name", "currency": "CAD"} @@ -172,17 +194,23 @@ async def test_update_user_profile_success(mock_db_client, mock_get_database): mock_db_client.users.find_one_and_update.return_value = updated_user_doc_from_db # Expected transformed output - expected_transformed = user_service.transform_user_document(updated_user_doc_from_db) + expected_transformed = user_service.transform_user_document( + updated_user_doc_from_db + ) - updated_user = await user_service.update_user_profile(TEST_OBJECT_ID_STR, update_data) + updated_user = await user_service.update_user_profile( + TEST_OBJECT_ID_STR, update_data + ) args, kwargs = mock_db_client.users.find_one_and_update.call_args assert args[0] == {"_id": TEST_OBJECT_ID} assert "$set" in args[1] assert args[1]["$set"]["name"] == "New Name" assert args[1]["$set"]["currency"] == "CAD" - assert "updated_at" in args[1]["$set"] # Check that updated_at was added - assert kwargs["return_document"] is True # from pymongo import ReturnDocument (True means ReturnDocument.AFTER) + assert "updated_at" in args[1]["$set"] # Check that updated_at was added + assert ( + kwargs["return_document"] is True + ) # from pymongo import ReturnDocument (True means ReturnDocument.AFTER) assert updated_user is not None assert updated_user["name"] == "New Name" @@ -193,11 +221,15 @@ async def test_update_user_profile_success(mock_db_client, mock_get_database): @pytest.mark.asyncio async def test_update_user_profile_user_not_found(mock_db_client, mock_get_database): - mock_db_client.users.find_one_and_update.return_value = None # Simulate user not found + mock_db_client.users.find_one_and_update.return_value = ( + None # Simulate user not found + ) update_data = {"name": "New Name"} NON_EXISTENT_VALID_OID = "123456789012345678901234" - updated_user = await user_service.update_user_profile(NON_EXISTENT_VALID_OID, update_data) + updated_user = await user_service.update_user_profile( + NON_EXISTENT_VALID_OID, update_data + ) args, kwargs = mock_db_client.users.find_one_and_update.call_args assert args[0] == {"_id": ObjectId(NON_EXISTENT_VALID_OID)} @@ -207,8 +239,10 @@ async def test_update_user_profile_user_not_found(mock_db_client, mock_get_datab assert kwargs["return_document"] is True assert updated_user is None + # --- Tests for delete_user --- + @pytest.mark.asyncio async def test_delete_user_success(mock_db_client, mock_get_database): mock_delete_result = MagicMock() @@ -217,9 +251,11 @@ async def test_delete_user_success(mock_db_client, mock_get_database): result = await user_service.delete_user(TEST_OBJECT_ID_STR) - mock_db_client.users.delete_one.assert_called_once_with({"_id": TEST_OBJECT_ID}) + mock_db_client.users.delete_one.assert_called_once_with( + {"_id": TEST_OBJECT_ID}) assert result is True + @pytest.mark.asyncio async def test_delete_user_not_found(mock_db_client, mock_get_database): mock_delete_result = MagicMock() @@ -228,5 +264,6 @@ async def test_delete_user_not_found(mock_db_client, mock_get_database): result = await user_service.delete_user(TEST_OBJECT_ID_STR) - mock_db_client.users.delete_one.assert_called_once_with({"_id": TEST_OBJECT_ID}) + mock_db_client.users.delete_one.assert_called_once_with( + {"_id": TEST_OBJECT_ID}) assert result is False