Skip to content

Commit

Permalink
ruff fixes everywhere
Browse files Browse the repository at this point in the history
  • Loading branch information
azliu0 committed Apr 14, 2024
1 parent 3eca242 commit ef02179
Show file tree
Hide file tree
Showing 12 changed files with 164 additions and 70 deletions.
6 changes: 5 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,11 @@ select = [


[tool.ruff.lint.per-file-ignores]

# D415: API endpoint docstrings shouldn't end with a period
"server/controllers/*" = ["D415"]
# not migrated yet
"server/models/*" = ["E", "W", "F", "UP", "B", "SIM", "I", "N", "D"]
"server/nlp/*" = ["E", "W", "F", "UP", "B", "SIM", "I", "N", "D"]

[tool.ruff.lint.pydocstyle]
convention = "google"
23 changes: 23 additions & 0 deletions scripts/ci_cd.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@

#!/bin/bash

# Usage: bash scripts/ci_cd.sh
# Description: Run all ci/cd tests locally.
# These are the same tests that run on GitHub when you push code.

cd "/workspaces/ballot"

echo "Running ci/cd checks..."
echo "(1) pre-commit checks..."

pre-commit run --all-files

echo "(2) Run Pyright..."

npx pyright

echo "(3) Run Pytests..."

pytest -v -s

echo "Tests complete."
35 changes: 27 additions & 8 deletions server/__init__.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,28 @@
"""Initialize the Flask app."""

import numpy
from apiflask import APIFlask
from flask import redirect, render_template
from flask_cors import CORS
from flask_sqlalchemy import SQLAlchemy
from flask import redirect, render_template

# https://stackoverflow.com/questions/50626058/psycopg2-cant-adapt-type-numpy-int64
import numpy
from psycopg2.extensions import register_adapter, AsIs
from psycopg2.extensions import AsIs, register_adapter


def addapt_numpy_float64(numpy_float64):
"""Adapt numpy.float64 to SQL syntax.
See here:
https://stackoverflow.com/questions/50626058/psycopg2-cant-adapt-type-numpy-int64
"""
return AsIs(numpy_float64)


def addapt_numpy_int64(numpy_int64):
"""Adapt numpy.int64 to SQL syntax.
See here:
https://stackoverflow.com/questions/50626058/psycopg2-cant-adapt-type-numpy-int64
"""
return AsIs(numpy_int64)


Expand All @@ -31,14 +41,16 @@ def addapt_numpy_int64(numpy_int64):
# # initialize with some default documents to create the index
# document = Document(
# "what is hackmit?",
# "HackMIT is a weekend-long event where thousands of students from around the world come together to work on cool new software and/or hardware projects.",
# "HackMIT is a weekend-long event where thousands of students from around the
# world come together to work on cool new software and/or hardware projects.",
# "https://hackmit.org",
# "what is hackmit?",
# )
# db.session.add(document)
# document = Document(
# "what is blueprint?",
# "Blueprint is a weekend-long learnathon and hackathon for high school students hosted at MIT",
# "Blueprint is a weekend-long learnathon and hackathon for high school
# students hosted at MIT",
# "https://blueprint.hackmit.org",
# "what is blueprint?",
# )
Expand All @@ -47,6 +59,7 @@ def addapt_numpy_int64(numpy_int64):


def create_app():
"""Create the Flask app."""
app = APIFlask(
__name__,
docs_path="/api/docs",
Expand All @@ -62,8 +75,14 @@ def create_app():

with app.app_context():
db.init_app(app)

allowed_domains = app.config.get("ALLOWED_DOMAINS")
assert type(allowed_domains) is list[str]

cors.init_app(
app, origins=app.config.get("ALLOWED_DOMAINS"), supports_credentials=True
app,
origins=allowed_domains,
supports_credentials=True,
)

from server.controllers import api
Expand Down
17 changes: 11 additions & 6 deletions server/cli.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,26 @@
"""Flask CLI commands."""

import datetime

from flask import Blueprint

from server import db
from server.models.email import Email
from server.models.thread import Thread
from server.models.response import Response
from server.nlp.responses import generate_response
import datetime
from server.controllers.emails import (
thread_emails_to_openai_messages,
document_data,
increment_response_count,
thread_emails_to_openai_messages,
)
from server.models.email import Email
from server.models.response import Response
from server.models.thread import Thread
from server.nlp.responses import generate_response

seed = Blueprint("seed", __name__)


@seed.cli.command()
def email():
"""Seed the database with a test email."""
subject = "Test Email Subject"
body = "Hello! What is blueprint?"
body = "Dear Blueprint Team,\n\n" + body + "\n\nBest regards,\nAndrew\n\n"
Expand Down
7 changes: 5 additions & 2 deletions server/controllers/__init__.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
"""Provides controllers module."""

from flask import Blueprint
from server.controllers.emails import emails

from server.controllers.admin import admin
from server.controllers.faq import faq
from server.controllers.auth import auth
from server.controllers.emails import emails
from server.controllers.faq import faq

api = Blueprint("api", __name__, url_prefix="/api")
api.register_blueprint(emails)
Expand Down
33 changes: 25 additions & 8 deletions server/controllers/admin.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,22 @@
from server import db
from flask import request
from apiflask import APIBlueprint
from server.models.document import Document
from server.nlp.embeddings import embed_corpus
from ast import literal_eval
"""The admin controller handles admin-related routes."""

import json
from ast import literal_eval

import pandas as pd
from apiflask import APIBlueprint
from flask import request

from server import db
from server.models.document import Document
from server.nlp.embeddings import embed_corpus

admin = APIBlueprint("admin", __name__, url_prefix="/admin", tag="Admin")


@admin.route("/upload_document", methods=["POST"])
def upload_text():
"""POST /admin/upload_document"""
data = request.form
document = Document(
data["question"], data["content"], data["source"], data["label"]
Expand All @@ -24,8 +28,11 @@ def upload_text():

@admin.route("/delete_document", methods=["POST"])
def delete_text():
"""POST /admin/delete_document"""
data = request.form
document = Document.query.get(data["id"])
if document is None:
return {"error": "Document not found"}, 404
if document.response_count > 0:
document.to_delete = True
db.session.commit()
Expand All @@ -38,8 +45,11 @@ def delete_text():

@admin.route("/edit_document", methods=["POST"])
def update_text():
"""POST /admin/edit_document"""
data = request.form
document = Document.query.get(data["id"])
if document is None:
return {"error": "Document not found"}, 404
document.question = data["question"]
document.content = data["content"]
document.source = data["source"]
Expand All @@ -50,12 +60,14 @@ def update_text():

@admin.route("/get_documents", methods=["GET"])
def get_all():
"""GET /admin/get_documents"""
documents = Document.query.order_by(Document.id.desc()).all()
return [document.map() for document in documents]


@admin.route("/update_embeddings", methods=["GET"])
def update_embeddings():
"""GET /admin/update_embeddings"""
documents = Document.query.order_by(Document.id.desc()).all()

docs = [document.map() for document in documents if not document.to_delete]
Expand All @@ -74,15 +86,16 @@ def update_embeddings():

@admin.route("/import_json", methods=["POST"])
def upload_json():
"""POST /admin/import_json"""
try:
file = request.data
json_data = literal_eval(file.decode("utf8"))
for doc in json_data:
document = Document(
doc["question"] if "question" in doc else "",
doc.get("question", ""),
doc["content"],
doc["source"],
doc["label"] if "label" in doc else "",
doc.get("label", ""),
)
db.session.add(document)
db.session.commit()
Expand All @@ -93,6 +106,7 @@ def upload_json():

@admin.route("/export_json", methods=["GET"])
def export_json():
"""GET /admin/export_json"""
documents = Document.query.order_by(Document.id.desc()).all()
return json.dumps(
[
Expand All @@ -111,6 +125,7 @@ def export_json():

@admin.route("/import_csv", methods=["POST"])
def import_csv():
"""POST /admin/import_csv"""
try:
file = request.files["file"]
df = pd.read_csv(file.stream)
Expand All @@ -131,6 +146,7 @@ def import_csv():

@admin.route("/export_csv", methods=["GET"])
def export_csv():
"""GET /admin/export_csv"""
documents = Document.query.order_by(Document.id.desc()).all()
df = pd.DataFrame(
[
Expand All @@ -150,6 +166,7 @@ def export_csv():

@admin.route("/clear_documents", methods=["POST"])
def clear_documents():
"""POST /admin/clear_documents"""
try:
documents = Document.query.order_by(Document.id.desc()).all()
for document in documents:
Expand Down
32 changes: 20 additions & 12 deletions server/controllers/auth.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,15 @@
from flask import current_app as app, request, session, redirect, url_for
"""The auth controller handles authentication-related routes."""

from apiflask import APIBlueprint, abort
from authlib.integrations.flask_client import OAuth
from flask import current_app as app
from flask import redirect, request, session, url_for

auth = APIBlueprint("auth", __name__, url_prefix="/auth", tag="Auth")

app.secret_key = app.config["SESSION_SECRET"]

oauth = OAuth(app)
oauth: OAuth = OAuth(app)
google = oauth.register(
name="google",
client_id=app.config["AUTH_CLIENT_ID"],
Expand All @@ -23,15 +26,14 @@


def auth_required_decorator(roles):
"""
middleware for protected routes
"""
"""Middleware for protected routes."""

def auth_required(func):
def wrapper(*args, **kwargs):
if not dict(session).get("user", 0):
return abort(401)
elif dict(session).get("user").get("role") not in roles:
if (
not dict(session).get("user", 0)
or dict(session).get("user").get("role") not in roles # type: ignore
):
return abort(401)
return func(*args, **kwargs)

Expand All @@ -46,6 +48,7 @@ def wrapper(*args, **kwargs):
@auth.route("/whoami")
def whoami():
"""GET /whoami
Returns user if they are logged in, otherwise returns nothing.
"""
if dict(session).get("user", 0):
Expand All @@ -56,22 +59,25 @@ def whoami():
@auth.route("/login")
def login():
"""GET /login
launches google authentication.
"""
google = oauth.create_client("google")
scheme = "https" if app.config["ENV"] == "production" else "http"
redirect_uri = url_for("api.auth.authorize", _external=True, _scheme=scheme)
return google.authorize_redirect(redirect_uri)
return google.authorize_redirect(redirect_uri) # type: ignore


@auth.route("/authorize")
def authorize():
"""GET /authorize
callback function after google authentication. verifies user token, then returns user data if it is in the database.
callback function after google authentication. verifies user token, then returns
user data if it is in the database.
"""
google = oauth.create_client("google")
token = google.authorize_access_token()
user_info = oauth.google.userinfo(token=token)
token = google.authorize_access_token() # type: ignore
user_info = oauth.google.userinfo(token=token) # type: ignore
for admin in app.config["AUTH_ADMINS"]:
if admin["email"] == user_info["email"]:
session["user"] = {"role": "Admin"}
Expand All @@ -83,6 +89,7 @@ def authorize():
@auth.doc(tags=["Auth"])
def login_admin():
"""POST /login_admin
log in with admin credentials
"""
data = request.get_json()
Expand All @@ -100,6 +107,7 @@ def login_admin():
@auth.post("/logout")
def logout():
"""POST /logout
clears current user session
"""
session.clear()
Expand Down
Loading

0 comments on commit ef02179

Please sign in to comment.