diff --git a/Procfile b/Procfile index 82be1677..8ad3fb0d 100644 --- a/Procfile +++ b/Procfile @@ -1 +1 @@ -web: gunicorn -b :$PORT training.main:app --workers $NUM_WORKERS --worker-class uvicorn.workers.UvicornWorker \ No newline at end of file +web: gunicorn -b :$PORT training.main:app --workers $NUM_WORKERS --worker-class uvicorn.workers.UvicornWorker --timeout 1200 \ No newline at end of file diff --git a/alembic/versions/db581ea0a1c3_gspc_updates.py b/alembic/versions/db581ea0a1c3_gspc_updates.py new file mode 100644 index 00000000..fe742555 --- /dev/null +++ b/alembic/versions/db581ea0a1c3_gspc_updates.py @@ -0,0 +1,59 @@ +"""gspc_updates + +Revision ID: db581ea0a1c3 +Revises: 12049328fd0a +Create Date: 2025-02-12 09:56:53.397539 + +""" +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision = 'db581ea0a1c3' +down_revision = '12049328fd0a' +branch_labels = None +depends_on = None + + +def upgrade() -> None: + # Add new columns + op.add_column('gspc_invite', sa.Column('gspc_invite_id', sa.UUID(as_uuid=True), nullable=True)) + op.add_column('gspc_invite', sa.Column('second_invite_date', sa.DateTime(timezone=True), nullable=True)) + op.add_column('gspc_invite', sa.Column('final_invite_date', sa.DateTime(timezone=True), nullable=True)) + op.add_column('gspc_invite', sa.Column('completed_date', sa.DateTime(timezone=True), nullable=True)) + + # Create unique index for gspc_invite_id + op.create_index( + 'ix_gspc_invite_gspc_invite_id', + 'gspc_invite', + ['gspc_invite_id'], + unique=True + ) + + # Add foreign key column to gspc_completions + op.add_column('gspc_completions', sa.Column('gspc_invite_id', sa.UUID(as_uuid=True), nullable=True)) + + # Create foreign key constraint + op.create_foreign_key( + 'gspc_completions_x_gspc_invite', + 'gspc_completions', + 'gspc_invite', + ['gspc_invite_id'], + ['gspc_invite_id'] + ) + + +def downgrade() -> None: + # Drop foreign key constraint and column from gspc_completions + op.drop_constraint(None, 'gspc_completions', type_='foreignkey') + op.drop_column('gspc_completions', 'gspc_invite_id') + + # Drop index + op.drop_index('ix_gspc_invite_gspc_invite_id', table_name='gspc_invite') + + # Drop columns + op.drop_column('gspc_invite', 'completed_date') + op.drop_column('gspc_invite', 'final_invite_date') + op.drop_column('gspc_invite', 'second_invite_date') + op.drop_column('gspc_invite', 'gspc_invite_id') diff --git a/dev/uaa/uaa.yml b/dev/uaa/uaa.yml index 58d32ae4..644bf07e 100644 --- a/dev/uaa/uaa.yml +++ b/dev/uaa/uaa.yml @@ -37,7 +37,7 @@ scim: users: - paul|wombat|paul@uaa.test|Paul|Smith|openid - stefan|wallaby|stefan@uaa.test|Stefan|Schmidt|openid - + - mark|wombat|mark.meyer@gsa.gov|Mark|openid oauth: user: authorities: diff --git a/package.json b/package.json index 77bb691a..65010631 100644 --- a/package.json +++ b/package.json @@ -2,7 +2,7 @@ "scripts": { "build:frontend": "cd training-front-end && npm install && npm run build && cd ..", "federalist": "npm run build:frontend", - "dev": "(trap 'kill 0' SIGINT; npm run dev:frontend & npm run dev:backend)", + "dev": "(npm run dev:frontend & npm run dev:backend)", "dev:frontend": "cd training-front-end && npm run dev", "dev:backend": "uvicorn training.main:app --reload", "dev:db-start": "docker-compose up -d", diff --git a/training-front-end/src/components/AdminGSPC.vue b/training-front-end/src/components/AdminGSPC.vue index f4005324..e197f3e2 100644 --- a/training-front-end/src/components/AdminGSPC.vue +++ b/training-front-end/src/components/AdminGSPC.vue @@ -173,7 +173,7 @@ class="usa-alert--slim" :has-heading="false" > - Emails successfully sent to {{ successCount }} people. + Emails sending to {{ successCount }} people.
diff --git a/training-front-end/src/components/GspcRegistration.vue b/training-front-end/src/components/GspcRegistration.vue index 7d8d5cdf..e7fe7f1f 100644 --- a/training-front-end/src/components/GspcRegistration.vue +++ b/training-front-end/src/components/GspcRegistration.vue @@ -45,8 +45,8 @@ const quizStarted = ref(false) const quizSubmitted = ref(false) const error = ref(props.error) - let redirectExpirationDateString = "" - let expirationDate = "" + let gspcInviteId = "" + let redirectGspcInviteIdString = "" const certTypeGspc = 2 const questions = @@ -63,8 +63,8 @@ onBeforeMount(async () => { const urlParams = new URLSearchParams(window.location.search); - expirationDate = urlParams.get('expirationDate') - redirectExpirationDateString = 'expirationDate=' + expirationDate + gspcInviteId = urlParams.get('gspcInviteId') + redirectGspcInviteIdString = 'gspcInviteId=' + gspcInviteId }) function startLoading() { @@ -91,7 +91,7 @@ 'Content-Type': 'application/json', 'Authorization': `Bearer ${user.value.jwt}` }, - body: JSON.stringify({'responses':{'responses': user_answers}, 'expiration_date': expirationDate}) + body: JSON.stringify({'responses':{'responses': user_answers}, 'gspc_invite_id': gspcInviteId}) }) } catch { const err = new Error("There was a problem connecting with the server") @@ -145,7 +145,7 @@ title="gspc_registration" :header="header" link-destination-text="the GSA SmartPay Program Certification (GSPC)" - :parameters="redirectExpirationDateString" + :parameters="redirectGspcInviteIdString" @start-loading="startLoading" @error="setError" > diff --git a/training/api/api_v1/auth.py b/training/api/api_v1/auth.py index c8517dce..1f6fd2a2 100644 --- a/training/api/api_v1/auth.py +++ b/training/api/api_v1/auth.py @@ -26,9 +26,7 @@ def auth_exchange( ): db_user = user_repo.find_by_email(uaa_user.get("email")) if not db_user: - logging.info( - f"UAA authenticated, but not found in database: {uaa_user['email']}" - ) + logging.info("UAA authenticated, but not found in database", extra={'user': uaa_user['email']}) raise HTTPException( status_code=status.HTTP_403_FORBIDDEN, detail="Invalid user." @@ -36,7 +34,7 @@ def auth_exchange( user = User.model_validate(db_user) if not user.is_admin(): - logging.info(f"UAA authenticated, but not an admin: {uaa_user['email']}") + logging.info("UAA authenticated, but not an admin", extra={'user': uaa_user['email']}) raise HTTPException( status_code=status.HTTP_403_FORBIDDEN, detail="Not authorized to login." @@ -44,5 +42,5 @@ def auth_exchange( jwt_user = UserJWT.model_validate(db_user) encoded_jwt = jwt.encode(jwt_user.model_dump(), settings.JWT_SECRET, algorithm="HS256") - logging.info(f"Token exchange success for {db_user.email}") + logging.info("Token exchange success", extra={'user': db_user.email}) return {'user': jwt_user, 'jwt': encoded_jwt} diff --git a/training/api/api_v1/gspc.py b/training/api/api_v1/gspc.py index 81f8a85a..a60743e9 100644 --- a/training/api/api_v1/gspc.py +++ b/training/api/api_v1/gspc.py @@ -1,5 +1,4 @@ from typing import Any -import logging import csv from io import StringIO from fastapi import APIRouter, status, HTTPException, Response, Depends @@ -7,10 +6,11 @@ from training.services import GspcService from training.repositories import GspcInviteRepository, GspcCompletionRepository from training.api.deps import gspc_invite_repository, gspc_completion_repository, gspc_service -from training.api.email import send_gspc_invite_email +from training.api.email import InviteTuple, send_gspc_invite_emails from training.api.auth import RequireRole -from training.config import settings from training.api.auth import JWTUser +from fastapi import BackgroundTasks +from training.config import settings router = APIRouter() @@ -19,27 +19,26 @@ @router.post("/gspc-invite") async def gspc_admin_invite( gspcInvite: GspcInvite, + background_tasks: BackgroundTasks, repo: GspcInviteRepository = Depends(gspc_invite_repository), user=Depends(RequireRole(["Admin"])) ): ''' Given a list of emails we parse them into two list (valid and invalid). - Then we log each of the valid emails to the db and shoot of an email to each. + Then we log each of the valid emails to the db and shoot off an email to each. ''' try: # Parse emails string into valid and invalid email list gspcInvite.parse() - for email in gspcInvite.valid_emails: - repo.create(email=email, certification_expiration_date=gspcInvite.certification_expiration_date) - # If performance becomes an issue use multithreading to send the emails - try: - params = gspcInvite.certification_expiration_date.strftime('%Y-%m-%d') - link = f"{settings.BASE_URL}/gspc_registration/?expirationDate={params}" - send_gspc_invite_email(to_email=email, link=link) - logging.info(f"Sent gspc invite email to {email}") - except Exception as e: - logging.error("Error sending gspc invite email", e) + entities = repo.bulk_create(emails=gspcInvite.valid_emails, certification_expiration_date=gspcInvite.certification_expiration_date) + + # Explicitly load needed props into memory before passing to the background task + entities_data = [InviteTuple(entity.gspc_invite_id, entity.email) for entity in entities] + + # Add email sending to background tasks + # note: passing in settings as the task looses the current app context once triggered + background_tasks.add_task(send_gspc_invite_emails, invites=entities_data, app_settings=settings) # Return object with both list for success and failure messages return gspcInvite diff --git a/training/api/api_v1/loginless_flow.py b/training/api/api_v1/loginless_flow.py index defd79b7..b908d4fe 100644 --- a/training/api/api_v1/loginless_flow.py +++ b/training/api/api_v1/loginless_flow.py @@ -89,8 +89,10 @@ def send_link( # and try the link. if not all(role in role_names for role in required_roles): logging.info( - f"{user.email} does not have the required role to access {page_id_lookup[dest.page_id]['path']}" + "unauthorized access attempt", + extra={'user': user.email, 'path': page_id_lookup[dest.page_id]['path']} ) + raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, detail="Unauthorized" @@ -114,7 +116,7 @@ def send_link( url = f"{settings.BASE_URL}{path}?{parameters}" try: send_email(to_email=user.email, name=user.name, link=url, training_title=dest.title) - logging.info(f"Sent confirmation email to {user.email} for {path}") + logging.info("Sent confirmation email", extra={'user': user.email, 'path': path}) except Exception as e: logging.error("Error sending mail", e) raise HTTPException( @@ -151,6 +153,6 @@ async def get_user( if not db_user: db_user = repo.create(user) user_return = UserJWT.model_validate(db_user) - logging.info(f"Confirmed email token for {user.email}") + logging.info("Confirmed email token", extra={'user': user.email}) encoded_jwt = jwt.encode(user_return.model_dump(), settings.JWT_SECRET, algorithm="HS256") return {'user': user_return, 'jwt': encoded_jwt} diff --git a/training/api/auth.py b/training/api/auth.py index 6768a191..b7021fa9 100644 --- a/training/api/auth.py +++ b/training/api/auth.py @@ -1,4 +1,5 @@ import json +import logging from typing import Annotated from urllib.request import urlopen @@ -29,6 +30,7 @@ async def __call__(self, request: Request): user = self.decode_jwt(credentials.credentials) if user is None: + JWTUser.log_invalid_jwt(credentials.credentials, request.url.path) raise HTTPException(status_code=403, detail="Invalid or expired token.") return user @@ -38,27 +40,28 @@ def decode_jwt(self, token: str): except InvalidTokenError: return + @staticmethod + def log_invalid_jwt(token: str, path: str): + try: + invalid_claim = jwt.decode(token, options={"verify_signature": False}) + logging.info("Invalid token", extra={'decoded': invalid_claim, 'path': path}) + except InvalidTokenError: + logging.warning("Unprocessable token", extra={'token': token, 'path': path}) + -class UAAJWTUser(HTTPBearer): +class UAAJWTUser(JWTUser): ''' Represents a JWT issued by an OAuth server. Used as part of the Admin SecureAuth flow ''' - async def __call__(self, request: Request): - - credentials: HTTPAuthorizationCredentials | None = await super().__call__(request) - user = self.decode_jwt(credentials.credentials) - if user is None: - raise HTTPException(status_code=403, detail="Invalid or expired token.") - return user - def decode_jwt(self, token: str): token_header = jwt.get_unverified_header(token) key_id = token_header.get("kid") jwk = self.get_jwks().get(key_id) if jwk is None: + logging.info("Unknown jwk", extra={'key_id': key_id}) raise HTTPException( status_code=status.HTTP_403_FORBIDDEN, detail="Unrecognized token." @@ -88,6 +91,7 @@ def get_jwks(self): jwks = json.load(res) if jwks.get("keys") is None: + logging.warning("Unable to get JSON Web keys from server") raise HTTPException( status_code=status.HTTP_503_SERVICE_UNAVAILABLE, detail="Unable to get required data from authentication server (public keys)." @@ -118,6 +122,7 @@ def discover_jwks_endpoint(self) -> str: jwks_endpoint_uri = data.get("jwks_uri") if jwks_endpoint_uri is None: + logging.warning("Unable to get jwks endpoint from server") raise HTTPException( status_code=status.HTTP_503_SERVICE_UNAVAILABLE, detail="Unable to get required data from authentication server (JWKS URI)." @@ -144,15 +149,17 @@ def __call__(self, user=Depends(JWTUser())): try: user_roles = user['roles'] except KeyError: + logging.info("Unauthorized Access", extra={'user': user, 'required_roles': self.required_roles}) raise HTTPException(status_code=401, detail="Not Authorized") if all(role in user_roles for role in self.required_roles): return user else: + logging.info("Unauthorized Access", extra={'user': user, 'required_roles': self.required_roles}) raise HTTPException(status_code=401, detail="Not Authorized") -def user_from_form(jwtToken: Annotated[str, Form()]): +def user_from_form(jwtToken: Annotated[str, Form()], request: Request): ''' This allows POST requests to send a token as part of form-encoded request. There are places in the front-end where we want to download a file, but we also @@ -163,5 +170,6 @@ def user_from_form(jwtToken: Annotated[str, Form()]): ''' try: return jwt.decode(jwtToken, settings.JWT_SECRET, algorithms=["HS256"]) - except jwt.exceptions.InvalidTokenError: + except InvalidTokenError: + JWTUser.log_invalid_jwt(jwtToken, request.url.path) raise HTTPException(status_code=401, detail="Not Authorized") diff --git a/training/api/email.py b/training/api/email.py index a6b7cc4e..8a3de6e6 100644 --- a/training/api/email.py +++ b/training/api/email.py @@ -1,10 +1,15 @@ +from itertools import islice +import logging from string import Template +from typing import Iterator, List, NamedTuple +import uuid from pydantic import EmailStr from smtplib import SMTP from email.message import EmailMessage -from training.config import settings +from training.config import settings, Settings from training.errors import SendEmailError +import time # We also use jinja template. # See: https://sabuhish.github.io/fastapi-mail/example/#using-jinja2-html-templates @@ -105,3 +110,78 @@ def send_gspc_invite_email(to_email: EmailStr, link: str) -> None: raise SendEmailError from e finally: smtp.quit() + + +class InviteTuple(NamedTuple): + gspc_invite_id: uuid.UUID + email: str + + +def send_gspc_invite_emails(invites: list[InviteTuple], app_settings: Settings) -> None: + """Background task designed to do a bulk send of GSPC invite emails""" + # Start timer to track how long the job takes + start_time = time.time() + logging.info(f"Starting gspc invite job, number of invites:{len(invites)}") + + email_messages = [create_email_message(invite, app_settings) for invite in invites] + send_emails_in_batches(email_messages=email_messages, batch_size=25, app_settings=app_settings) + + end_time = time.time() + execution_time = end_time - start_time + # Format time as minutes and seconds for better readability + minutes = int(execution_time // 60) + seconds = int(execution_time % 60) + logging.info(f"Finished gspc invite job. Total execution time: {minutes} minutes and {seconds} seconds for {len(invites)} emails") + + +def create_email_message(invite: InviteTuple, app_settings: Settings) -> EmailMessage: + """Create an EmailMessage object for a given invite.""" + + link = f"{app_settings.BASE_URL}/gspc_registration/?gspcInviteId={invite.gspc_invite_id}" + body = GSPC_INVITE_EMAIL_TEMPLATE.substitute({"link": link}) + + message = EmailMessage() + message.set_content(body, subtype="html") + message["Subject"] = "Verify your GSA SmartPay Program Certification (GSPC) Coursework and Experience" + message["From"] = f"{app_settings.EMAIL_FROM_NAME} <{app_settings.EMAIL_FROM}>" + message["To"] = invite.email + + return message + + +def batch_iterator(items: List, batch_size: int) -> Iterator: + """Create an iterator that yields batches of the specified size.""" + iterator = iter(items) + batch = list(islice(iterator, batch_size)) + while batch: + yield batch + batch = list(islice(iterator, batch_size)) + + +def send_emails_in_batches(email_messages: List[EmailMessage], batch_size: int, app_settings: Settings) -> None: + """Chunks the list into batches and attempts to send each back of emails.""" + for batch in batch_iterator(email_messages, batch_size): + max_retries = 3 + # Attempt to trigger email batch, retries on failure + for attempt in range(max_retries): + try: + with SMTP(app_settings.SMTP_SERVER, port=app_settings.SMTP_PORT, timeout=30) as smtp: + smtp.starttls() + if app_settings.SMTP_USER and app_settings.SMTP_PASSWORD: + smtp.login(user=app_settings.SMTP_USER, password=app_settings.SMTP_PASSWORD) + + # Send messages in current batch + for message in batch: + smtp.send_message(message) + smtp.quit() + break # Exit retry loop if successful + except Exception as e: + if attempt < max_retries - 1: + # Wait with backoff: 1, 2, 4 seconds... + sleep_time = 1 * (attempt + 1) + time.sleep(sleep_time) + # Log the error after all retries failed + else: + # Extract all email addresses from the batch and join them with commas + addresses_list = ", ".join([message['To'] for message in batch]) + logging.error(f"Failed to send batch after {max_retries} attempts: {str(e)}. Addresses: {addresses_list}") diff --git a/training/models/gspc_completion.py b/training/models/gspc_completion.py index 12bed4ab..11e5e3c9 100644 --- a/training/models/gspc_completion.py +++ b/training/models/gspc_completion.py @@ -1,8 +1,8 @@ from typing import Any -from datetime import datetime +from datetime import date, datetime from training.models import Base from sqlalchemy.orm import Mapped, mapped_column -from sqlalchemy import Column, Date, ForeignKey, func +from sqlalchemy import ForeignKey, func class GspcCompletion(Base): @@ -11,6 +11,6 @@ class GspcCompletion(Base): id: Mapped[int] = mapped_column(primary_key=True) user_id: Mapped[int] = mapped_column(ForeignKey("users.id")) passed: Mapped[bool] = mapped_column() - certification_expiration_date = Column(Date(), nullable=False) + certification_expiration_date: Mapped[date] = mapped_column(nullable=False) submit_ts: Mapped[datetime] = mapped_column(server_default=func.now()) responses: Mapped[dict[str, Any]] = mapped_column() diff --git a/training/models/gspc_invite.py b/training/models/gspc_invite.py index f124c318..04746a99 100644 --- a/training/models/gspc_invite.py +++ b/training/models/gspc_invite.py @@ -1,5 +1,7 @@ +from datetime import datetime, date +import uuid from training.models import Base -from sqlalchemy import Column, DateTime, Date, func +from sqlalchemy import UUID, DateTime, Date, func from sqlalchemy.orm import Mapped, mapped_column @@ -7,6 +9,22 @@ class GspcInvite(Base): __tablename__ = "gspc_invite" id: Mapped[int] = mapped_column(primary_key=True) + gspc_invite_id: Mapped[uuid.UUID | None] = mapped_column(UUID(as_uuid=True), unique=True, nullable=True) email: Mapped[str] = mapped_column(unique=False) - created_date = Column(DateTime(timezone=True), default=func.now(), nullable=False) - certification_expiration_date = Column(Date(), nullable=False) + created_date: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=func.now(), nullable=False) + certification_expiration_date: Mapped[date] = mapped_column(Date(), nullable=False) + second_invite_date: Mapped[datetime | None] = mapped_column(DateTime(timezone=True), nullable=True) + final_invite_date: Mapped[datetime | None] = mapped_column(DateTime(timezone=True), nullable=True) + completed_date: Mapped[datetime | None] = mapped_column(DateTime(timezone=True), nullable=True) + + def to_dict(self): + return { + "id": self.id, + "gspc_invite_id": str(self.gspc_invite_id) if self.gspc_invite_id else None, + "email": self.email, + "created_date": self.created_date.isoformat() if self.created_date else None, + "certification_expiration_date": self.certification_expiration_date.isoformat() if self.certification_expiration_date else None, + "second_invite_date": self.second_invite_date.isoformat() if self.second_invite_date else None, + "final_invite_date": self.final_invite_date.isoformat() if self.final_invite_date else None, + "completed_date": self.completed_date.isoformat() if self.completed_date else None + } diff --git a/training/repositories/base.py b/training/repositories/base.py index 33753f10..61d8a50f 100644 --- a/training/repositories/base.py +++ b/training/repositories/base.py @@ -1,4 +1,4 @@ -from typing import Any, Generic, Type, TypeVar +from typing import Any, Generic, List, Type, TypeVar from sqlalchemy.orm import Session from training import models @@ -18,6 +18,16 @@ def save(self, item: T) -> T: self._session.refresh(item) return item + def bulk_save(self, items: List[T]) -> List[T]: + """Bulk saves multiple items efficiently.""" + self._session.add_all(items) + self._session.commit() + + for item in items: + self._session.refresh(item) + + return items + def find_by_id(self, id: int) -> T | None: return self._session.query(self._model).filter_by(id=id).first() diff --git a/training/repositories/gspc_invite.py b/training/repositories/gspc_invite.py index 43a46ff2..62e0c4a0 100644 --- a/training/repositories/gspc_invite.py +++ b/training/repositories/gspc_invite.py @@ -1,7 +1,12 @@ +from itertools import islice +import logging +from typing import Iterator, List +import uuid from sqlalchemy.orm import Session from training import models -from datetime import datetime +from datetime import datetime, date from .base import BaseRepository +import time class GspcInviteRepository(BaseRepository[models.GspcInvite]): @@ -9,8 +14,90 @@ class GspcInviteRepository(BaseRepository[models.GspcInvite]): def __init__(self, session: Session): super().__init__(session, models.GspcInvite) - def create(self, email: str, certification_expiration_date: datetime) -> models.GspcInvite: + def create(self, email: str, certification_expiration_date: date) -> models.GspcInvite: return self.save(models.GspcInvite( email=email, - certification_expiration_date=certification_expiration_date + certification_expiration_date=certification_expiration_date, + gspc_invite_id=uuid.uuid4() )) + + def bulk_create(self, emails: list[str], certification_expiration_date: date) -> list[models.GspcInvite]: + """Bulk insert GspcInvite records for multiple emails.""" + logging.info(f"Starting gspc bulk create, number of invites:{len(emails)}") + invites = [ + models.GspcInvite( + email=email, + certification_expiration_date=certification_expiration_date, + gspc_invite_id=uuid.uuid4() + ) + for email in emails + ] + + try: + # Insert 50 at a time + for batch in GspcInviteRepository.batch_iterator(invites, 50): + self.bulk_save(batch) + time.sleep(1) + + return invites + + except Exception as e: + raise Exception(f"Batch insert failed: {str(e)}") from e + + def batch_iterator(items: List, batch_size: int) -> Iterator: + """Create an iterator that yields batches of the specified size.""" + iterator = iter(items) + batch = list(islice(iterator, batch_size)) + while batch: + yield batch + batch = list(islice(iterator, batch_size)) + + def get_by_gspc_invite_id(self, gspc_invite_id: uuid.UUID) -> models.GspcInvite: + """ + retrieves a GspcInvite by its gspc_invite_id + param gspc_invite_id: (uuid.UUID): The UUID tied to each invite + returns:models.GspcInvite: The matching invite record + """ + result = self._session.query(models.GspcInvite).filter( + models.GspcInvite.gspc_invite_id == gspc_invite_id + ).first() + + if result is None: + raise ValueError("No invite found with the given gspc_invite_id") + return result + + def set_second_invite_date(self, id: int) -> None: + """ + Sets second_invite_date to now + :param id: gspc_invite ID to update + :return: None + """ + gspc_invite = self.find_by_id(id) + if gspc_invite is None: + raise ValueError("invalid gspc invite id") + gspc_invite.second_invite_date = datetime.now() + self._session.commit() + + def set_final_invite_date(self, id: int) -> None: + """ + Sets final_invite_date to now + :param id: gspc_invite ID to update + :return: None + """ + gspc_invite = self.find_by_id(id) + if gspc_invite is None: + raise ValueError("invalid gspc invite id") + gspc_invite.final_invite_date = datetime.now() + self._session.commit() + + def set_completion_date(self, id: int) -> None: + """ + Sets completed_date to now + :param id: gspc_invite ID to update + :return: None + """ + gspc_invite = self.find_by_id(id) + if gspc_invite is None: + raise ValueError("invalid gspc invite id") + gspc_invite.completed_date = datetime.now() + self._session.commit() diff --git a/training/schemas/gspc_certificate.py b/training/schemas/gspc_certificate.py index 3a99c66a..84ea242d 100644 --- a/training/schemas/gspc_certificate.py +++ b/training/schemas/gspc_certificate.py @@ -1,4 +1,4 @@ -from datetime import datetime +from datetime import datetime, date from pydantic import ConfigDict, BaseModel @@ -6,6 +6,6 @@ class GspcCertificate(BaseModel): user_id: int user_name: str agency: str - certification_expiration_date: str + certification_expiration_date: date completion_date: datetime model_config = ConfigDict(from_attributes=True) diff --git a/training/schemas/gspc_completion.py b/training/schemas/gspc_completion.py index 10ab39a4..ef80bfe6 100644 --- a/training/schemas/gspc_completion.py +++ b/training/schemas/gspc_completion.py @@ -1,9 +1,13 @@ + +from datetime import date from typing import Any +import uuid from pydantic import BaseModel class GspcCompletion(BaseModel): user_id: int passed: bool - certification_expiration_date: str + gspc_invite_id: uuid.UUID + certification_expiration_date: date responses: dict[str, Any] diff --git a/training/schemas/gspc_submission.py b/training/schemas/gspc_submission.py index db9c4663..77b4898c 100644 --- a/training/schemas/gspc_submission.py +++ b/training/schemas/gspc_submission.py @@ -1,3 +1,4 @@ +import uuid from pydantic import BaseModel, ConfigDict @@ -15,6 +16,6 @@ class GspcSubmissionQuestions(BaseModel): class GspcSubmission(BaseModel): - expiration_date: str + gspc_invite_id: uuid.UUID responses: GspcSubmissionQuestions model_config = ConfigDict(from_attributes=True) diff --git a/training/services/gspc.py b/training/services/gspc.py index c7d4baab..f587071f 100644 --- a/training/services/gspc.py +++ b/training/services/gspc.py @@ -1,5 +1,5 @@ import logging -from training.repositories import GspcCompletionRepository, UserRepository +from training.repositories import GspcCompletionRepository, UserRepository, GspcInviteRepository from training.schemas import GspcSubmission, GspcResult, GspcCompletion from sqlalchemy.orm import Session from training.services import Certificate @@ -28,6 +28,7 @@ class GspcService(): def __init__(self, db: Session): self.gspc_completion_repo = GspcCompletionRepository(db) + self.gspc_invite_repo = GspcInviteRepository(db) self.user_repo = UserRepository(db) self.certificate_service = Certificate() @@ -39,14 +40,17 @@ def grade(self, user_id: int, submission: GspcSubmission) -> GspcResult: :return: GspcResult model which includes the final result """ + gspc_invite = self.gspc_invite_repo.get_by_gspc_invite_id(submission.gspc_invite_id) + passed = all(question.correct for question in submission.responses.responses) responses_dict = submission.responses.model_dump() result = self.gspc_completion_repo.create(GspcCompletion( user_id=user_id, passed=passed, - certification_expiration_date=submission.expiration_date, - responses=responses_dict + certification_expiration_date=gspc_invite.certification_expiration_date, + responses=responses_dict, + gspc_invite_id=gspc_invite.gspc_invite_id )) if (passed): @@ -61,6 +65,8 @@ def grade(self, user_id: int, submission: GspcSubmission) -> GspcResult: self.email_certificate(user.name, user.email, pdf_bytes) logging.info(f"Sent confirmation email to {user.email} for passing training quiz") + + self.gspc_invite_repo.set_completion_date(gspc_invite.id) except Exception as e: logging.error("Error sending quiz confirmation mail", e) raise