Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 32 additions & 1 deletion backend/app/auth/routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,43 @@
UserResponse, ErrorResponse
)
from app.auth.service import auth_service
from app.auth.security import create_access_token
from app.auth.security import create_access_token, oauth2_scheme # Import oauth2_scheme
from fastapi.security import OAuth2PasswordRequestForm # Import OAuth2PasswordRequestForm
from datetime import timedelta
from app.config import settings

router = APIRouter(prefix="/auth", tags=["Authentication"])

@router.post("/token", response_model=TokenResponse, include_in_schema=False) # include_in_schema=False to hide from docs if desired, or True to show
async def login_for_access_token(form_data: OAuth2PasswordRequestForm = Depends()):
"""
OAuth2 compatible token login, get an access token for future requests.
This endpoint is used by Swagger UI for authorization.
It expects username (email) and password in form-data.
"""
try:

Check warning on line 23 in backend/app/auth/routes.py

View check run for this annotation

Codecov / codecov/patch

backend/app/auth/routes.py#L23

Added line #L23 was not covered by tests
# Note: OAuth2PasswordRequestForm uses 'username' field for the user identifier.
# We'll treat it as email here.
result = await auth_service.authenticate_user_with_email(

Check warning on line 26 in backend/app/auth/routes.py

View check run for this annotation

Codecov / codecov/patch

backend/app/auth/routes.py#L26

Added line #L26 was not covered by tests
email=form_data.username, # form_data.username is the email
password=form_data.password
)

access_token = create_access_token(

Check warning on line 31 in backend/app/auth/routes.py

View check run for this annotation

Codecov / codecov/patch

backend/app/auth/routes.py#L31

Added line #L31 was not covered by tests
data={"sub": str(result["user"]["_id"])},
expires_delta=timedelta(minutes=settings.access_token_expire_minutes)
)

return TokenResponse(access_token=access_token, token_type="bearer")
except HTTPException:
raise
except Exception as e:

Check warning on line 39 in backend/app/auth/routes.py

View check run for this annotation

Codecov / codecov/patch

backend/app/auth/routes.py#L36-L39

Added lines #L36 - L39 were not covered by tests
# It's good practice to log the exception here
raise HTTPException(

Check warning on line 41 in backend/app/auth/routes.py

View check run for this annotation

Codecov / codecov/patch

backend/app/auth/routes.py#L41

Added line #L41 was not covered by tests
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Authentication failed: {str(e)}"
)

@router.post("/signup/email", response_model=AuthResponse)
async def signup_with_email(request: EmailSignupRequest):
"""
Expand Down
30 changes: 26 additions & 4 deletions backend/app/auth/security.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
from datetime import datetime, timedelta
from datetime import datetime, timedelta, timezone
from typing import Optional, Dict, Any
from jose import JWTError, jwt
from passlib.context import CryptContext
from fastapi import HTTPException, status
from fastapi import HTTPException, status, Depends
from fastapi.security import OAuth2PasswordBearer
from app.config import settings
import secrets

Expand All @@ -13,6 +14,8 @@
# Fallback for bcrypt version compatibility issues
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto", bcrypt__rounds=12)

oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/auth/token") # Updated tokenUrl

def verify_password(plain_password: str, hashed_password: str) -> bool:
"""
Verifies whether a plaintext password matches a given hashed password.
Expand Down Expand Up @@ -53,9 +56,9 @@
"""
to_encode = data.copy()
if expires_delta:
expire = datetime.utcnow() + expires_delta
expire = datetime.now(timezone.utc) + expires_delta
else:
expire = datetime.utcnow() + timedelta(minutes=settings.access_token_expire_minutes)
expire = datetime.now(timezone.utc) + timedelta(minutes=settings.access_token_expire_minutes)

Check warning on line 61 in backend/app/auth/security.py

View check run for this annotation

Codecov / codecov/patch

backend/app/auth/security.py#L61

Added line #L61 was not covered by tests

to_encode.update({"exp": expire, "type": "access"})
encoded_jwt = jwt.encode(to_encode, settings.secret_key, algorithm=settings.algorithm)
Expand Down Expand Up @@ -95,3 +98,22 @@
A random 32-byte URL-safe string suitable for use as a password reset token.
"""
return secrets.token_urlsafe(32)

def get_current_user(token: str = Depends(oauth2_scheme)) -> Dict[str, Any]:
"""
Retrieves the current user based on the provided JWT token using centralized verification.

Args:
token: The JWT token from which to extract the user information.

Returns:
A dictionary containing the current user's information.

Raises:
HTTPException: If the token is invalid or user information cannot be extracted.
"""
payload = verify_token(token) # Centralized JWT validation and error handling
user_id = payload.get("sub")
if user_id is None:
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid token payload")

Check warning on line 118 in backend/app/auth/security.py

View check run for this annotation

Codecov / codecov/patch

backend/app/auth/security.py#L118

Added line #L118 was not covered by tests
return {"_id": user_id}
16 changes: 8 additions & 8 deletions backend/app/auth/service.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from datetime import datetime, timedelta
from datetime import datetime, timedelta, timezone
from typing import Optional, Dict, Any
from pymongo.errors import DuplicateKeyError
from bson import ObjectId
Expand Down Expand Up @@ -95,7 +95,7 @@
"name": name,
"avatar": None,
"currency": "USD",
"created_at": datetime.utcnow(),
"created_at": datetime.now(timezone.utc),
"auth_provider": "email",
"firebase_uid": None
}
Expand Down Expand Up @@ -198,7 +198,7 @@
"name": name,
"avatar": picture,
"currency": "USD",
"created_at": datetime.utcnow(),
"created_at": datetime.now(timezone.utc),
"auth_provider": "google",
"firebase_uid": firebase_uid,
"hashed_password": None
Expand Down Expand Up @@ -245,7 +245,7 @@
token_record = await db.refresh_tokens.find_one({
"token": refresh_token,
"revoked": False,
"expires_at": {"$gt": datetime.utcnow()}
"expires_at": {"$gt": datetime.now(timezone.utc)}
})

if not token_record:
Expand Down Expand Up @@ -322,7 +322,7 @@

# Generate reset token
reset_token = generate_reset_token()
reset_expires = datetime.utcnow() + timedelta(hours=1) # 1 hour expiry
reset_expires = datetime.now(timezone.utc) + timedelta(hours=1) # 1 hour expiry

Check warning on line 325 in backend/app/auth/service.py

View check run for this annotation

Codecov / codecov/patch

backend/app/auth/service.py#L325

Added line #L325 was not covered by tests

# Store reset token
await db.password_resets.insert_one({
Expand Down Expand Up @@ -362,7 +362,7 @@
reset_record = await db.password_resets.find_one({
"token": reset_token,
"used": False,
"expires_at": {"$gt": datetime.utcnow()}
"expires_at": {"$gt": datetime.now(timezone.utc)}
})

if not reset_record:
Expand Down Expand Up @@ -406,14 +406,14 @@
db = self.get_db()

refresh_token = create_refresh_token()
expires_at = datetime.utcnow() + timedelta(days=settings.refresh_token_expire_days)
expires_at = datetime.now(timezone.utc) + timedelta(days=settings.refresh_token_expire_days)

await db.refresh_tokens.insert_one({
"token": refresh_token,
"user_id": ObjectId(user_id) if isinstance(user_id, str) else user_id,
"expires_at": expires_at,
"revoked": False,
"created_at": datetime.utcnow()
"created_at": datetime.now(timezone.utc)
})

return refresh_token
Expand Down
34 changes: 34 additions & 0 deletions backend/app/user/routes.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
from fastapi import APIRouter, Depends, HTTPException, status
from app.user.schemas import UserProfileResponse, UserProfileUpdateRequest, DeleteUserResponse
from app.user.service import user_service
from app.auth.security import get_current_user
from typing import Dict, Any

router = APIRouter(prefix="/users", tags=["User"])

@router.get("/me", response_model=UserProfileResponse)
async def get_current_user_profile(current_user: Dict[str, Any] = Depends(get_current_user)):
user = await user_service.get_user_by_id(current_user["_id"])
if not user:
raise HTTPException(status_code=404, detail="User not found")
return user

@router.patch("/me", response_model=Dict[str, Any])
async def update_user_profile(
updates: UserProfileUpdateRequest,
current_user: Dict[str, Any] = Depends(get_current_user)
):
update_data = updates.model_dump(exclude_unset=True)
if not update_data:
raise HTTPException(status_code=400, detail="No update fields provided.")
updated_user = await user_service.update_user_profile(current_user["_id"], update_data)
if not updated_user:
raise HTTPException(status_code=404, detail="User not found")
return {"user": updated_user}

@router.delete("/me", response_model=DeleteUserResponse)
async def delete_user_account(current_user: Dict[str, Any] = Depends(get_current_user)):
deleted = await user_service.delete_user(current_user["_id"])
if not deleted:
raise HTTPException(status_code=404, detail="User not found")
return DeleteUserResponse(success=True, message="User account scheduled for deletion.")
23 changes: 23 additions & 0 deletions backend/app/user/schemas.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
from pydantic import BaseModel, EmailStr, Field
from typing import Optional
from datetime import datetime

class UserProfileResponse(BaseModel):
id: str = Field(alias="_id")
name: str
email: EmailStr
imageUrl: Optional[str] = Field(default=None, alias="avatar")
currency: str = "USD"
createdAt: datetime
updatedAt: datetime

model_config = {"populate_by_name": True}

class UserProfileUpdateRequest(BaseModel):
name: Optional[str] = None
imageUrl: Optional[str] = None
currency: Optional[str] = None

class DeleteUserResponse(BaseModel):
success: bool = True
message: Optional[str] = None
63 changes: 63 additions & 0 deletions backend/app/user/service.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
from fastapi import HTTPException, status, Depends
from app.database import get_database
from bson import ObjectId
from datetime import datetime, timezone
from typing import Optional, Dict, Any

class UserService:
def __init__(self):
pass

def get_db(self):
return get_database()

def transform_user_document(self, user: dict) -> dict:
if not user:
return None
try:
user_id = str(user["_id"])
except Exception:
return None # Handle invalid ObjectId gracefully

Check warning on line 20 in backend/app/user/service.py

View check run for this annotation

Codecov / codecov/patch

backend/app/user/service.py#L19-L20

Added lines #L19 - L20 were not covered by tests
return {
"_id": user_id,
"name": user.get("name"),
"email": user.get("email"),
"avatar": user.get("imageUrl") or user.get("avatar"),
"currency": user.get("currency", "USD"),
"createdAt": user.get("created_at"),
"updatedAt": user.get("updated_at") or user.get("created_at"),
}

async def get_user_by_id(self, user_id: str) -> Optional[dict]:
db = self.get_db()
try:
obj_id = ObjectId(user_id)
except Exception:
return None # Handle invalid ObjectId gracefully

Check warning on line 36 in backend/app/user/service.py

View check run for this annotation

Codecov / codecov/patch

backend/app/user/service.py#L35-L36

Added lines #L35 - L36 were not covered by tests
user = await db.users.find_one({"_id": obj_id})
return self.transform_user_document(user)

async def update_user_profile(self, user_id: str, updates: dict) -> Optional[dict]:
db = self.get_db()
try:
obj_id = ObjectId(user_id)
except Exception:
return None # Handle invalid ObjectId gracefully

Check warning on line 45 in backend/app/user/service.py

View check run for this annotation

Codecov / codecov/patch

backend/app/user/service.py#L44-L45

Added lines #L44 - L45 were not covered by tests
updates["updated_at"] = datetime.now(timezone.utc)
result = await db.users.find_one_and_update(
{"_id": obj_id},
{"$set": updates},
return_document=True
)
return self.transform_user_document(result)

async def delete_user(self, user_id: str) -> bool:
db = self.get_db()
try:
obj_id = ObjectId(user_id)
except Exception:
return False # Handle invalid ObjectId gracefully

Check warning on line 59 in backend/app/user/service.py

View check run for this annotation

Codecov / codecov/patch

backend/app/user/service.py#L58-L59

Added lines #L58 - L59 were not covered by tests
result = await db.users.delete_one({"_id": obj_id})
return result.deleted_count == 1

user_service = UserService()
33 changes: 17 additions & 16 deletions backend/main.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,31 @@
from fastapi import FastAPI, HTTPException, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import Response
from contextlib import asynccontextmanager
from app.database import connect_to_mongo, close_mongo_connection
from app.auth.routes import router as auth_router
from app.user.routes import router as user_router
from app.config import settings

@asynccontextmanager
async def lifespan(app: FastAPI):
# Startup
print("Lifespan: Connecting to MongoDB...")
await connect_to_mongo()
print("Lifespan: MongoDB connected.")
yield
# Shutdown
print("Lifespan: Closing MongoDB connection...")
await close_mongo_connection()
print("Lifespan: MongoDB connection closed.")

app = FastAPI(
title="Splitwiser API",
description="Backend API for Splitwiser expense tracking application",
version="1.0.0",
docs_url="/docs",
redoc_url="/redoc"
redoc_url="/redoc",
lifespan=lifespan
)

# CORS middleware - Enhanced configuration for production
Expand Down Expand Up @@ -74,21 +89,6 @@ async def options_handler(request: Request, path: str):

return response

# Database events
@app.on_event("startup")
async def startup_event():
"""
Initializes the MongoDB connection when the application starts.
"""
await connect_to_mongo()

@app.on_event("shutdown")
async def shutdown_event():
"""
Closes the MongoDB connection when the application shuts down.
"""
await close_mongo_connection()

# Health check
@app.get("/health")
async def health_check():
Expand All @@ -101,6 +101,7 @@ async def health_check():

# Include routers
app.include_router(auth_router)
app.include_router(user_router)

if __name__ == "__main__":
import uvicorn
Expand Down
1 change: 1 addition & 0 deletions backend/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -17,3 +17,4 @@ httpx
mongomock-motor
pytest-env
pytest-cov
pytest-mock
6 changes: 3 additions & 3 deletions backend/tests/auth/test_auth_routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from main import app # Assuming your FastAPI app instance is here
from app.config import settings # To potentially override settings if needed, or check values
from app.auth.security import verify_password, get_password_hash # For checking hashed password if necessary
from datetime import datetime
from datetime import datetime, timezone
from bson import ObjectId

# It's good practice to set a specific test secret key if not relying on external env vars
Expand Down Expand Up @@ -135,7 +135,7 @@ async def test_login_with_email_success(mock_db):
"name": "Login User",
"avatar": None,
"currency": "USD",
"created_at": datetime.utcnow(), # Ensure datetime is used
"created_at": datetime.now(timezone.utc), # Ensure datetime is used
"auth_provider": "email",
"firebase_uid": None
})
Expand Down Expand Up @@ -173,7 +173,7 @@ async def test_login_with_incorrect_password(mock_db):
"email": user_email,
"hashed_password": get_password_hash(correct_password),
"name": "Wrong Pass User",
"created_at": datetime.utcnow()
"created_at": datetime.now(timezone.utc)
})

async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as ac:
Expand Down
Loading