Skip to content

Commit

Permalink
Refactor verification service and fix type issues
Browse files Browse the repository at this point in the history
  • Loading branch information
zobweyt committed Aug 15, 2024
1 parent c700d5f commit 0b439c0
Show file tree
Hide file tree
Showing 4 changed files with 36 additions and 31 deletions.
4 changes: 2 additions & 2 deletions backend/src/api/v1/auth/routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from src.api.v1.auth.schemas import Token
from src.api.v1.users.schemas import UserCreate, UserPasswordReset
from src.api.v1.users.service import create_user, get_user_by_email, is_email_registered, update_password
from src.api.v1.verification.service import expire_code_if_valid
from src.api.v1.verification.service import expire_code_if_correct
from src.security import create_access_token, is_valid_password

router = APIRouter(prefix="/auth", tags=["Authentication"])
Expand Down Expand Up @@ -41,7 +41,7 @@ def reset_password(schema: UserPasswordReset, session: Session) -> Token:

if not user:
raise HTTPException(status.HTTP_404_NOT_FOUND, "User not found")
if not expire_code_if_valid(schema.email, schema.code):
if not expire_code_if_correct(schema.email, schema.code):
raise HTTPException(status.HTTP_406_NOT_ACCEPTABLE, "The code is invalid")

update_password(session, user, schema.password)
Expand Down
4 changes: 2 additions & 2 deletions backend/src/api/v1/users/me/routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
verify_user,
)
from src.api.v1.verification.schemas import Code
from src.api.v1.verification.service import expire_code_if_valid
from src.api.v1.verification.service import expire_code_if_correct
from src.config import settings
from src.storage import fs, mimetype

Expand All @@ -31,7 +31,7 @@ def get_current_user(current_user: CurrentUser) -> User:
def verify_current_user(current_user: CurrentUser, schema: Code, session: Session) -> User:
if current_user.is_verified:
raise HTTPException(status.HTTP_422_UNPROCESSABLE_ENTITY, "User has already been verified")
if not expire_code_if_valid(current_user.email, schema.code):
if not expire_code_if_correct(current_user.email, schema.code):
raise HTTPException(status.HTTP_406_NOT_ACCEPTABLE, "The code is invalid")

verify_user(session, current_user)
Expand Down
4 changes: 2 additions & 2 deletions backend/src/api/v1/verification/routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,14 @@
from src.api.v1.users.schemas import UserEmail
from src.api.v1.verification.mailings import send_verification_message
from src.api.v1.verification.schemas import CodeResponse, CodeVerify
from src.api.v1.verification.service import is_valid_code
from src.api.v1.verification.service import is_code_correct

router = APIRouter(prefix="/verification", tags=["Verification"])


@router.post("/verify")
def verify_code(schema: CodeVerify) -> Response:
if not is_valid_code(schema.email, schema.code):
if not is_code_correct(schema.email, schema.code):
raise HTTPException(status.HTTP_406_NOT_ACCEPTABLE, "The code is invalid")
return Response(status_code=status.HTTP_202_ACCEPTED)

Expand Down
55 changes: 30 additions & 25 deletions backend/src/api/v1/verification/service.py
Original file line number Diff line number Diff line change
@@ -1,47 +1,52 @@
from datetime import timedelta
from random import randint
from typing import Any

from redis.typing import KeyT, ResponseT

from src.cache import redis, separate
from src.config import settings

__CACHE_KEY_PREFIX = "codes"


def set_code(email: str) -> int:
code = generate_code()
name = generate_code_key(email)
redis.set(name, code)
redis.expire(name, timedelta(minutes=settings.otp.expire_minutes))
def set_code(subject: KeyT) -> int:
name = generate_code_cache_name(subject)
code = generate_random_code()
expires_at = timedelta(minutes=settings.otp.expire_minutes)

redis.set(name, code, expires_at)

return code


def get_code(email: str) -> Any: # TODO: Use proper type conversion.
return redis.get(generate_code_key(email))
def get_code(subject: KeyT) -> ResponseT:
name = generate_code_cache_name(subject)
code = redis.get(name)

return code


def expire_code_if_valid(email: str, code: int) -> bool:
is_valid = is_valid_code(email, code)
if is_valid:
expire_code(email)
return is_valid
def expire_code(subject: KeyT) -> ResponseT:
name = generate_code_cache_name(subject)
response = redis.delete(name)

return response

def is_valid_code(email: str, code: int) -> bool:
cached_code = get_code(email)
if cached_code:
# TODO: Use proper type conversion.
return int(cached_code) == code # type: ignore
return False

def expire_code_if_correct(subject: KeyT, code: int) -> bool:
is_correct = is_code_correct(subject, code)
if is_correct:
expire_code(subject)
return is_correct

def generate_code() -> int:
return randint(settings.otp.min, settings.otp.max)

def is_code_correct(subject: KeyT, code: int) -> bool:
return get_code(subject) == str(code).encode()

def generate_code_key(email: str) -> str:
return separate(__CACHE_KEY_PREFIX, email)

def generate_random_code() -> int:
return randint(settings.otp.min, settings.otp.max)


def expire_code(email: str) -> None:
redis.delete(generate_code_key(email))
def generate_code_cache_name(subject: KeyT) -> KeyT:
return separate(__CACHE_KEY_PREFIX, subject)

0 comments on commit 0b439c0

Please sign in to comment.