Skip to content

Commit

Permalink
sqlalchemy 1.4
Browse files Browse the repository at this point in the history
  • Loading branch information
amandine-sahl committed Jun 6, 2024
1 parent 5b66457 commit 15f08ae
Show file tree
Hide file tree
Showing 7 changed files with 167 additions and 65 deletions.
26 changes: 12 additions & 14 deletions backend/core/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,44 +2,43 @@
import json

from flask import current_app
from flask_sqlalchemy import BaseQuery
from sqlalchemy import func, or_
from sqlalchemy.ext.hybrid import hybrid_property
from sqlalchemy.orm import aliased

from .env import db


class GTEventsQuery(BaseQuery):
def filter_properties(self, filters):
class GTEventsQuery:
def filter_properties(self, query, filters):
if filters.get("search_name", None):
search_name = filters.get("search_name", None)
self = self.filter(
query = query.filter(
func.unaccent(GTEvents.name).ilike(func.unaccent(f"%{search_name}%"))
)
filters.pop("search_name")

if "begin_date" in filters:
self = self.filter(GTEvents.begin_date >= filters.pop("begin_date"))
query = query.filter(GTEvents.begin_date >= filters.pop("begin_date"))

if "end_date" in filters:
# set the end_date at 23h59 because a hour can be set in timestamp
end_date = datetime.datetime.strptime(
filters.pop("end_date")[:10], "%Y-%m-%d"
)
end_date = end_date.replace(hour=23, minute=59, second=59)
self = self.filter(GTEvents.end_date <= end_date)
query = query.filter(GTEvents.end_date <= end_date)

if "bilan.annulation" in filters:
canceled = json.loads(filters.pop("bilan.annulation"))
tbilan = aliased(getattr(GTEvents, "bilan"))
self = self.outerjoin(tbilan)
query = query.outerjoin(tbilan)
if canceled:
self = self.filter(
query = query.filter(
tbilan.annulation == True,
)
else:
self = self.filter(
query = query.filter(
or_(tbilan.annulation == None, tbilan.annulation == False)
)

Expand All @@ -48,21 +47,20 @@ def filter_properties(self, filters):
if hasattr(GTEvents, param) and filters.get(param):
# Split multi choice
if len(filters.get(param).split(",")) > 1:
self = self.filter(
query = query.filter(
getattr(GTEvents, param).in_(filters.get(param).split(","))
)
else:
self = self.filter(getattr(GTEvents, param) == filters.get(param))
query = query.filter(getattr(GTEvents, param) == filters.get(param))

# Filter not deleted
self = self.filter(GTEvents.deleted != True)
return self
query = query.filter(GTEvents.deleted != True)
return query


class GTEvents(db.Model):
__tablename__ = "tourism_touristicevent"
__table_args__ = {"schema": "public"}
query_class = GTEventsQuery

id = db.Column(db.Integer, primary_key=True)
name = db.Column(db.Unicode, nullable=False)
Expand Down
13 changes: 7 additions & 6 deletions backend/core/repository.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,18 @@
from datetime import datetime

from sqlalchemy import func
from sqlalchemy import func, select

from core.models import db, GTEvents, TAnimationsBilans


def query_stats_bilan(params):
query = GTEvents.query.filter(GTEvents.deleted != True)
query = select(GTEvents).filter(GTEvents.deleted != True)
if "year" in params:
query = query.filter(
func.date_part("year", GTEvents.begin_date) == params["year"]
)
nb_events = query.count()
events = query.all()
nb_events = db.session.scalar(select(func.count()).select_from(query.subquery()))
events = db.session.scalars(query).unique().all()

# events with capacity
events_capacity = [e for e in events if e.capacity and e.capacity > 0]
Expand Down Expand Up @@ -41,12 +41,13 @@ def query_stats_bilan(params):
taux_remplissage_passe = (
round(taux_remplissage_passe, 3) if taux_remplissage_passe else 0
)
query = db.session.query(func.count(GTEvents.id)).filter(GTEvents.cancelled == True)

query = select(func.count(GTEvents.id)).filter(GTEvents.cancelled == True)
if "year" in params:
query = query.filter(
func.date_part("year", GTEvents.begin_date) == params["year"]
)
nb_annulation = query.scalar()
nb_annulation = db.session.execute(query).scalar_one()

return {
"nb_animations": nb_events,
Expand Down
93 changes: 70 additions & 23 deletions backend/core/routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,17 @@
from email_validator import validate_email, EmailNotValidError, EmailSyntaxError
from flask import jsonify, request, Blueprint, render_template, session, current_app

from core.models import db, GTEvents, TReservations, VExportBilan, TTokens, TEventInfo
from sqlalchemy import select

from core.models import (
db,
GTEvents,
TReservations,
VExportBilan,
TTokens,
TEventInfo,
GTEventsQuery,
)
from core.repository import query_stats_bilan
from core.schemas import (
GTEventsSchema,
Expand Down Expand Up @@ -166,15 +176,16 @@ def get_events():
except KeyError:
sort_order = "desc"

events = GTEvents.query.filter_properties(query_params)
query = select(GTEvents)
query = GTEventsQuery().filter_properties(query, query_params)
if hasattr(GTEvents, sort_col):
model_sort_col = getattr(GTEvents, sort_col)
else:
model_sort_col = GTEvents.begin_date

events = events.order_by(getattr(model_sort_col, sort_order)(), GTEvents.id.asc())
query = query.order_by(getattr(model_sort_col, sort_order)(), GTEvents.id.asc())

events = events.paginate(page=page, per_page=limit)
events = db.paginate(query, page=page, per_page=limit, error_out=False)

results = GTEventsSchema(many=True, only=fields).dump(events.items)

Expand All @@ -193,7 +204,7 @@ def get_events():
@app_routes.route("/events/<int:event_id>")
def get_one_event(event_id):
"""Retourne un événement de Geotrek par son identifiant."""
event = GTEvents.query.get(event_id)
event = db.session.get(GTEvents, event_id)
if not event:
return jsonify({"error": f"Event #{event_id} not found"}), 404
return GTEventsSchema().dumps(event)
Expand Down Expand Up @@ -226,12 +237,12 @@ def get_reservations():
email = session["user"]
is_admin = is_user_admin()

query = db.session.query(TReservations)
query = select(TReservations)
if event_id:
query = query.filter_by(id_event=event_id)
if not is_admin:
query = query.filter_by(email=email)
query = query.paginate(page=page, per_page=limit)
query = db.paginate(query, page=page, per_page=limit)
results = TReservationsSchema(many=True).dump(query.items)

return jsonify(
Expand All @@ -257,7 +268,8 @@ class BodyParamValidationError(Exception):
def _post_reservations_by_user(post_data):
reservation = TReservationsSchema().load(post_data, session=db.session)

event = GTEvents.query.get(reservation.id_event)
event = db.session.get(GTEvents, reservation.id_event)

if not event:
raise BodyParamValidationError(
f"Event with ID {reservation.id_event} not found"
Expand Down Expand Up @@ -289,7 +301,8 @@ def _post_reservations_by_user(post_data):
def _post_reservations_by_admin(post_data):
reservation = TReservationsCreateByAdminSchema().load(post_data, session=db.session)

event = GTEvents.query.get(reservation.id_event)
event = db.session.get(GTEvents, reservation.id_event)

if not event:
raise BodyParamValidationError(
f"Event with ID {reservation.id_event} not found"
Expand Down Expand Up @@ -363,7 +376,13 @@ def confirm_reservation():
reservation exists it is confirmed and a confirmation mail is sent."""
token = request.get_json()["resa_token"]

resa = TReservations.query.filter_by(token=token).first()
resa = (
db.session.execute(
select(TReservations).where(TReservations.token == token).limit(1)
)
.scalars()
.first()
)
if not resa:
return jsonify({"error": "The token is invalid"}), 404

Expand Down Expand Up @@ -394,7 +413,7 @@ def confirm_reservation():
@login_admin_required
def update_reservation(reservation_id):
# Check : la réservation existe
reservation = TReservations.query.get(reservation_id)
reservation = db.session.get(TReservations, reservation_id)
if not reservation:
return jsonify({"error": f"Reservation #{reservation_id} not found"}), 404

Expand All @@ -420,7 +439,7 @@ def cancel_reservation(reservation_id):
is_admin = is_user_admin()

# Check : la réservation existe
reservation = TReservations.query.get(reservation_id)
reservation = db.session.get(TReservations, reservation_id)
if not reservation:
return jsonify({"error": f"Reservation #{reservation_id} not found"}), 404

Expand Down Expand Up @@ -517,11 +536,17 @@ def login():

login_token_lifespan = current_app.config["LOGIN_TOKEN_LIFETIME"]
limit = datetime.now() - login_token_lifespan

token = (
TTokens.query.filter_by(used=False)
.filter_by(token=login_token)
.filter(TTokens.created_at > limit)
db.session.execute(
select(TTokens)
.where(
TTokens.used == False,
TTokens.token == login_token,
TTokens.created_at > limit,
)
.limit(1)
)
.scalars()
.first()
)
if not token:
Expand Down Expand Up @@ -573,7 +598,15 @@ def logout():
@app_routes.route("/export_reservation/<id>", methods=["GET"])
@login_admin_required
def export_reservation(id):
resa = TReservations.query.filter_by(id_event=id).filter_by(cancelled=False).all()
resa = (
db.session.scalars(
select(TReservations).where(
TReservations.id_event == id, TReservations.cancelled == False
)
)
.unique()
.all()
)
export_fields = [
"id_reservation",
"id_event",
Expand Down Expand Up @@ -653,7 +686,7 @@ def get_stats_global():
@app_routes.route("/export/events")
@login_admin_required
def get_export_events():
events = VExportBilan.query.all()
events = db.session.scalars(select(VExportBilan)).all()
results = VExportBilanSchema(many=True).dump(events)
fields = VExportBilan.__table__.columns.keys()
return to_csv_resp("export_bilan", results, fields, ";")
Expand All @@ -666,10 +699,16 @@ def get_event_info(event_id):
S'il n'y a pas d'infos enregistrées un TEventInfo vide est créé et enregistré.
"""
event_info = TEventInfo.query.filter_by(id_event=event_id).first()
event_info = (
db.session.execute(
select(TEventInfo).where(TEventInfo.id_event == event_id).limit(1)
)
.scalars()
.first()
)

if not event_info:
event = GTEvents.query.get(event_id)
event = db.session.get(GTEvents, event_id)
if not event:
return jsonify({"error": f"Event #{event_id} not found"}), 404
event_info = TEventInfo(id_event=event_id)
Expand All @@ -683,12 +722,19 @@ def get_event_info(event_id):
@app_routes.route("/events/<int:event_id>/info", methods=["PUT"])
@login_admin_required
def set_event_info(event_id):
# TODO ADD TEST
"""Met à jour les infos liées à l'événement indiqué."""
post_data = request.get_json()
event_info = TEventInfo.query.filter_by(id_event=event_id).first()
event_info = (
db.session.execute(
select(TEventInfo).where(TEventInfo.id_event == event_id).limit(1)
)
.scalars()
.first()
)

if not event_info:
event = GTEvents.query.get(event_id)
event = db.session.get(GTEvents, event_id)
if not event:
return jsonify({"error": f"Event #{event_id} not found"}), 400
event_info = TEventInfo(id_event=event_id)
Expand All @@ -706,7 +752,8 @@ def set_event_info(event_id):
@app_routes.route("/events/<int:event_id>/cancel-reservations", methods=["POST"])
@login_admin_required
def send_event_cancellation_emails(event_id):
event = GTEvents.query.get(event_id)
# TODO ADD TEST
event = db.session.get(GTEvents, event_id)
if not event:
return jsonify({"error": f"Event #{event_id} not found"}), 404

Expand Down
11 changes: 8 additions & 3 deletions backend/test/fixtures.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

from datetime import date
from flask import url_for
from sqlalchemy import select
from sqlalchemy.sql import text
from app import create_app

Expand Down Expand Up @@ -90,9 +91,13 @@ def get_token(client):

# Get token manually
token = (
TTokens.query.filter_by(used=False)
.filter_by(email=ADMIN_EMAIL)
.order_by(TTokens.created_at.desc())
db.session.execute(
select(TTokens)
.where(TTokens.used == False, TTokens.email == ADMIN_EMAIL)
.order_by(TTokens.created_at.desc())
.limit(1)
)
.scalars()
.first()
)

Expand Down
Loading

0 comments on commit 15f08ae

Please sign in to comment.