diff --git a/macrostrat_db_insertion/database.py b/macrostrat_db_insertion/database.py new file mode 100644 index 0000000..9c57dab --- /dev/null +++ b/macrostrat_db_insertion/database.py @@ -0,0 +1,42 @@ + +from sqlalchemy import create_engine, MetaData, Engine +from sqlalchemy.orm import sessionmaker, declarative_base, Session + +engine: Engine | None = None +base: declarative_base = None +session: Session | None = None + + +def get_engine() -> Engine: + return engine + + +def get_base() -> declarative_base: + return base + + +def connect_engine(uri: str, schema: str): + global engine + global session + global base + + engine = create_engine(uri) + session = None  # was "session = session" (a no-op); sessions are created per-request via get_session_maker() + + base = declarative_base() + base.metadata.reflect(get_engine()) + base.metadata.reflect(get_engine(), schema=schema, views=True) + + +def dispose_engine(): + global engine + engine.dispose() + + +def get_session_maker() -> sessionmaker: + return sessionmaker(autocommit=False, autoflush=False, bind=get_engine()) + + +def get_session() -> Session: + with get_session_maker()() as s: + yield s diff --git a/macrostrat_db_insertion/environment.yml b/macrostrat_db_insertion/environment.yml index d564100..2161fbe 100644 --- a/macrostrat_db_insertion/environment.yml +++ b/macrostrat_db_insertion/environment.yml @@ -47,4 +47,6 @@ dependencies: - urllib3==2.2.1 - werkzeug==3.0.2 - zipp==3.18.1 + - fuzzysearch + - uvicorn prefix: /conda/envs/db_insert_env diff --git a/macrostrat_db_insertion/security-v1.py b/macrostrat_db_insertion/security-v1.py new file mode 100644 index 0000000..3481601 --- /dev/null +++ b/macrostrat_db_insertion/security-v1.py @@ -0,0 +1,114 @@ +import os +from datetime import datetime +from typing import Annotated, Optional + +import bcrypt +from fastapi import APIRouter, Depends, HTTPException, Request +from fastapi.security import ( + HTTPAuthorizationCredentials, + HTTPBearer, + OAuth2AuthorizationCodeBearer, +) +from fastapi.security.utils import get_authorization_scheme_param +from 
jose import JWTError, jwt +from pydantic import BaseModel +from sqlalchemy import select, update +from starlette.status import HTTP_401_UNAUTHORIZED + +from macrostrat_db_insertion.security.db import get_access_token +from macrostrat_db_insertion.security.model import TokenData + +ALGORITHM = "HS256" +ACCESS_TOKEN_EXPIRE_MINUTES = 1440 # 24 hours +GROUP_TOKEN_LENGTH = 32 +GROUP_TOKEN_SALT = b'$2b$12$yQrslvQGWDFjwmDBMURAUe' # Hardcode salt so hashes are consistent + + +class OAuth2AuthorizationCodeBearerWithCookie(OAuth2AuthorizationCodeBearer): + """Tweak FastAPI's OAuth2AuthorizationCodeBearer to use a cookie instead of a header""" + + async def __call__(self, request: Request) -> Optional[str]: + authorization = request.cookies.get("Authorization") # authorization = request.headers.get("Authorization") + scheme, param = get_authorization_scheme_param(authorization) + if not authorization or scheme.lower() != "bearer": + if self.auto_error: + raise HTTPException( + status_code=HTTP_401_UNAUTHORIZED, + detail="Not authenticated", + headers={ + "WWW-Authenticate": "Bearer" + }, + ) + else: + return None # pragma: nocover + return param + + +oauth2_scheme = OAuth2AuthorizationCodeBearerWithCookie( + authorizationUrl='/security/login', + tokenUrl="/security/callback", + auto_error=False +) + +http_bearer = HTTPBearer(auto_error=False) + + +def get_groups_from_header_token( + header_token: Annotated[HTTPAuthorizationCredentials, Depends(http_bearer)]) -> int | None: + """Get the groups from the bearer token in the header""" + + if header_token is None: + return None + + token_hash = bcrypt.hashpw(header_token.credentials.encode(), GROUP_TOKEN_SALT) + token_hash_string = token_hash.decode('utf-8') + + token = get_access_token(token=token_hash_string) + + if token is None: + return None + + return token.group + + +def get_user_token_from_cookie(token: Annotated[str | None, Depends(oauth2_scheme)]): + """Get the current user from the JWT token in the cookies""" + + # If 
there wasn't a token include in the request + if token is None: + return None + + try: + payload = jwt.decode(token, os.environ['SECRET_KEY'], algorithms=[os.environ['JWT_ENCRYPTION_ALGORITHM']]) + sub: str = payload.get("sub") + groups = payload.get("groups", []) + token_data = TokenData(sub=sub, groups=groups) + except JWTError as e: + return None + + return token_data + + +def get_groups( + user_token_data: TokenData | None = Depends(get_user_token_from_cookie), + header_token: int | None = Depends(get_groups_from_header_token) +) -> list[int]: + """Get the groups from both the cookies and header""" + + groups = [] + if user_token_data is not None: + groups = user_token_data.groups + + if header_token is not None: + groups.append(header_token) + + return groups + + +async def has_access(groups: list[int] = Depends(get_groups)) -> bool: + """Check if the user has access to the group""" + + if 'ENVIRONMENT' in os.environ and os.environ['ENVIRONMENT'] == 'development': + return True + + return 1 in groups diff --git a/macrostrat_db_insertion/security/__init__.py b/macrostrat_db_insertion/security/__init__.py new file mode 100644 index 0000000..31f3b93 --- /dev/null +++ b/macrostrat_db_insertion/security/__init__.py @@ -0,0 +1,6 @@ +from macrostrat_db_insertion.security.main import ( + get_groups_from_header_token, + get_user_token_from_cookie, + get_groups, + has_access +) diff --git a/macrostrat_db_insertion/security/db.py b/macrostrat_db_insertion/security/db.py new file mode 100644 index 0000000..b64493b --- /dev/null +++ b/macrostrat_db_insertion/security/db.py @@ -0,0 +1,30 @@ +import datetime + +from sqlalchemy import select, update + +from macrostrat_db_insertion.database import get_session_maker, get_engine +from macrostrat_db_insertion.security.schema import Token + + +def get_access_token(token: str): + """Fetch a Token row by its (hashed) token string; None if unknown or expired.""" + + session_maker = get_session_maker() + with session_maker() as session: + + select_stmt = select(Token).where(Token.token == token) + + # Check that the token exists BEFORE dereferencing it (previously expires_on was read first, raising AttributeError for unknown tokens) + result = (session.scalars(select_stmt)).first() + if result is None: + return None + + # Check if it has expired (both sides timezone-aware) + if result.expires_on < datetime.datetime.now(datetime.timezone.utc): + return None + + # Update the used_on column with an aware UTC timestamp (was naive utcnow()) + stmt = update(Token).where(Token.token == token).values(used_on=datetime.datetime.now(datetime.timezone.utc)) + session.execute(stmt) + session.commit() + return result diff --git a/macrostrat_db_insertion/security/main.py b/macrostrat_db_insertion/security/main.py new file mode 100644 index 0000000..2037b31 --- /dev/null +++ b/macrostrat_db_insertion/security/main.py @@ -0,0 +1,114 @@ +import os +from datetime import datetime +from typing import Annotated, Optional + +import bcrypt +from fastapi import APIRouter, Depends, HTTPException, Request +from fastapi.security import ( + HTTPAuthorizationCredentials, + HTTPBearer, + OAuth2AuthorizationCodeBearer, +) +from fastapi.security.utils import get_authorization_scheme_param +from jose import JWTError, jwt +from pydantic import BaseModel +from sqlalchemy import select, update +from starlette.status import HTTP_401_UNAUTHORIZED + +from macrostrat_db_insertion.security.db import get_access_token +from macrostrat_db_insertion.security.model import TokenData + +ALGORITHM = "HS256" +ACCESS_TOKEN_EXPIRE_MINUTES = 1440 # 24 hours +GROUP_TOKEN_LENGTH = 32 +GROUP_TOKEN_SALT = b'$2b$12$yQrslvQGWDFjwmDBMURAUe' # Hardcode salt so hashes are consistent + + +class OAuth2AuthorizationCodeBearerWithCookie(OAuth2AuthorizationCodeBearer): + """Tweak FastAPI's OAuth2AuthorizationCodeBearer to use a cookie instead of a header""" + + async def __call__(self, request: Request) -> Optional[str]: + authorization = request.cookies.get("Authorization") # authorization = request.headers.get("Authorization") + scheme, param = get_authorization_scheme_param(authorization) + if not authorization or scheme.lower() != "bearer": + if self.auto_error: + raise 
HTTPException( + status_code=HTTP_401_UNAUTHORIZED, + detail="Not authenticated", + headers={ + "WWW-Authenticate": "Bearer" + }, + ) + else: + return None # pragma: nocover + return param + + +oauth2_scheme = OAuth2AuthorizationCodeBearerWithCookie( + authorizationUrl='/security/login', + tokenUrl="/security/callback", + auto_error=False +) + +http_bearer = HTTPBearer(auto_error=False) + + +def get_groups_from_header_token( + header_token: Annotated[HTTPAuthorizationCredentials, Depends(http_bearer)]) -> int | None: + """Get the groups from the bearer token in the header""" + + if header_token is None: + return None + + token_hash = bcrypt.hashpw(header_token.credentials.encode(), GROUP_TOKEN_SALT) + token_hash_string = token_hash.decode('utf-8') + + token = get_access_token(token=token_hash_string) + + if token is None: + return None + + return token.group + + +def get_user_token_from_cookie(token: Annotated[str | None, Depends(oauth2_scheme)]): + """Get the current user from the JWT token in the cookies""" + + # If there wasn't a token include in the request + if token is None: + return None + + try: + payload = jwt.decode(token, os.environ['SECRET_KEY'], algorithms=[os.environ['JWT_ENCRYPTION_ALGORITHM']]) + sub: str = payload.get("sub") + groups = payload.get("groups", []) + token_data = TokenData(sub=sub, groups=groups) + except JWTError as e: + return None + + return token_data + + +def get_groups( + user_token_data: TokenData | None = Depends(get_user_token_from_cookie), + header_token: int | None = Depends(get_groups_from_header_token) +) -> list[int]: + """Get the groups from both the cookies and header""" + + groups = [] + if user_token_data is not None: + groups = user_token_data.groups + + if header_token is not None: + groups.append(header_token) + + return groups + + +def has_access(groups: list[int] = Depends(get_groups)) -> bool: + """Check if the user has access to the group""" + + if 'ENVIRONMENT' in os.environ and os.environ['ENVIRONMENT'] == 
'development': + return True + + return 1 in groups diff --git a/macrostrat_db_insertion/security/model.py b/macrostrat_db_insertion/security/model.py new file mode 100644 index 0000000..d610881 --- /dev/null +++ b/macrostrat_db_insertion/security/model.py @@ -0,0 +1,23 @@ +from pydantic import BaseModel + + +class TokenData(BaseModel): + sub: str + groups: list[int] = [] + + +class User(BaseModel): + username: str + email: str | None = None + full_name: str | None = None + disabled: bool | None = None + + +class AccessToken(BaseModel): + group: int + token: str + + +class GroupTokenRequest(BaseModel): + expiration: int + group_id: int diff --git a/macrostrat_db_insertion/security/schema.py b/macrostrat_db_insertion/security/schema.py new file mode 100644 index 0000000..f1e015d --- /dev/null +++ b/macrostrat_db_insertion/security/schema.py @@ -0,0 +1,71 @@ +import enum +from typing import List +import datetime +from sqlalchemy import ForeignKey, func, DateTime, Enum, PrimaryKeyConstraint, UniqueConstraint +from sqlalchemy.dialects.postgresql import VARCHAR, TEXT, INTEGER, ARRAY, BOOLEAN, JSON, JSONB +from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship + + +class Base(DeclarativeBase): + pass + + +class GroupMembers(Base): + __tablename__ = "group_members" + __table_args__ = { + 'schema': 'macrostrat_auth' + } + id: Mapped[int] = mapped_column(primary_key=True, autoincrement=True) + group_id: Mapped[int] = mapped_column(ForeignKey("macrostrat_auth.group.id")) + user_id: Mapped[int] = mapped_column(ForeignKey("macrostrat_auth.user.id")) + + +class Group(Base): + __tablename__ = "group" + __table_args__ = { + 'schema': 'macrostrat_auth' + } + id: Mapped[int] = mapped_column(primary_key=True, autoincrement=True) + name: Mapped[str] = mapped_column(VARCHAR(255)) + users: Mapped[List["User"]] = relationship(secondary="macrostrat_auth.group_members", lazy="joined", + back_populates="groups") + + +class User(Base): + __tablename__ = "user" + 
__table_args__ = { + 'schema': 'macrostrat_auth' + } + id: Mapped[int] = mapped_column(primary_key=True, autoincrement=True) + sub: Mapped[str] = mapped_column(VARCHAR(255)) + name: Mapped[str] = mapped_column(VARCHAR(255)) + email: Mapped[str] = mapped_column(VARCHAR(255)) + groups: Mapped[List[Group]] = relationship(secondary="macrostrat_auth.group_members", lazy="joined", + back_populates="users") + created_on: Mapped[datetime.datetime] = mapped_column( + DateTime(timezone=True), server_default=func.now() + ) + updated_on: Mapped[datetime.datetime] = mapped_column( + DateTime(timezone=True), server_default=func.now(), onupdate=func.now() + ) + + +class Token(Base): + __tablename__ = "token" + __table_args__ = { + 'schema': 'macrostrat_auth' + } + id: Mapped[int] = mapped_column(primary_key=True, autoincrement=True) + token: Mapped[str] = mapped_column(VARCHAR(255), unique=True) + group: Mapped[Group] = mapped_column(ForeignKey("macrostrat_auth.group.id")) + used_on: Mapped[datetime.datetime] = mapped_column( + DateTime(timezone=True), nullable=True + ) + expires_on: Mapped[datetime.datetime] = mapped_column( + DateTime(timezone=True) + ) + created_on: Mapped[datetime.datetime] = mapped_column( + DateTime(timezone=True), server_default=func.now() + ) + + diff --git a/macrostrat_db_insertion/server.py b/macrostrat_db_insertion/server.py index 4eb8e3f..bfd169e 100644 --- a/macrostrat_db_insertion/server.py +++ b/macrostrat_db_insertion/server.py @@ -1,58 +1,61 @@ -from flask import Flask, jsonify, request -from flask_cors import CORS -from flask_sqlalchemy import SQLAlchemy -import sqlalchemy -from sqlalchemy import inspect + +from contextlib import asynccontextmanager + +from fastapi import FastAPI, Request, HTTPException, Depends +from fastapi.responses import JSONResponse +from fastapi.middleware.cors import CORSMiddleware +from pydantic_settings import BaseSettings from sqlalchemy.dialects.postgresql import insert as INSERT_STATEMENT -from sqlalchemy import 
select as SELECT_STATEMENT -from sqlalchemy.orm import declarative_base -from flask import Flask +from sqlalchemy import select as SELECT_STATEMENT, text +from sqlalchemy.orm import Session from fuzzysearch import find_near_matches -from flask_sqlalchemy import SQLAlchemy -from sqlalchemy import inspect, MetaData -from sqlalchemy.ext.declarative import declarative_base -from sqlalchemy.orm import declarative_base -import json import traceback import hashlib import requests -from datetime import datetime, timezone -import os - -def load_flask_app(schema_name): - # Create the app - app = Flask(__name__) - uri = os.environ['uri'] - app.config["SQLALCHEMY_DATABASE_URI"] = uri - print("Loading schema", schema_name, "from uri", uri) - - # Create the db - Base = declarative_base(metadata = sqlalchemy.MetaData(schema = schema_name)) - db = SQLAlchemy(model_class=Base) - db.init_app(app) - with app.app_context(): - db.metadata.reflect(bind = db.engine, schema = schema_name, views = True) - print("Finished loading schema", schema_name) - - return app, db - -# Connect to the database -SCHEMA_NAME = os.environ['macrostrat_xdd_schema_name'] -MAX_TRIES = 5 -app, db = load_flask_app(SCHEMA_NAME) -CORS(app) + +from macrostrat_db_insertion.database import connect_engine, dispose_engine, get_base, get_session +from macrostrat_db_insertion.security import has_access + + +class Settings(BaseSettings): + uri: str + SCHEMA: str + max_tries: int = 5 + + +settings = Settings() + + +@asynccontextmanager +async def setup_engine(a: FastAPI): + """Return database client instance.""" + connect_engine(settings.uri, settings.SCHEMA) + yield + dispose_engine() + +app = FastAPI( + lifespan=setup_engine +) + +app.add_middleware( + CORSMiddleware, + allow_origins=["*"], + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"], +) def get_complete_table_name(table_name): - return SCHEMA_NAME + "." + table_name + return settings.SCHEMA + "." 
+ table_name def verify_key_presents(input, required_keys): for key in required_keys: if key not in input: return "Request missing field " + key - + return None -def get_model_metadata(request_data, additional_data): +def get_model_metadata(request_data, additional_data, session: Session): # Verify that we have the required metada verify_result = verify_key_presents(request_data, ["model_name", "model_version"]) if verify_result is not None: @@ -60,13 +63,13 @@ def get_model_metadata(request_data, additional_data): # Verify that the model already exists and gets its internal model id models_table_name = get_complete_table_name("model") - models_table = db.metadata.tables[models_table_name] + models_table = get_base().metadata.tables[models_table_name] model_name = request_data["model_name"] try: # Execute the select query models_select_statement = SELECT_STATEMENT(models_table) models_select_statement = models_select_statement.where(models_table.c.name == model_name) - models_result = db.session.execute(models_select_statement).all() + models_result = session.execute(models_select_statement).all() # Ensure we got a result and if so get the model id if len(models_result) > 0: @@ -75,7 +78,7 @@ def get_model_metadata(request_data, additional_data): except: error_msg = "Failed to get id for model " + model_name + " from table " + models_table_name + " due to error: " + traceback.format_exc() return False, error_msg - + # If not then insert the model and get its id if "internal_model_id" not in additional_data: try: @@ -84,8 +87,8 @@ def get_model_metadata(request_data, additional_data): } model_insert_statement = INSERT_STATEMENT(models_table).values(**model_insert_values) model_insert_statement = model_insert_statement.on_conflict_do_nothing(index_elements = list(model_insert_values.keys())) - result = db.session.execute(model_insert_statement) - db.session.commit() + result = session.execute(model_insert_statement) + session.commit() # Get the internal id 
result_insert_keys = result.inserted_primary_key @@ -95,11 +98,11 @@ def get_model_metadata(request_data, additional_data): additional_data["internal_model_id"] = result_insert_keys[0] except: return False, "Failed to insert model " + model_name + " into table " + models_table_name + " due to error: " + traceback.format_exc() - + # Try to insert the model version into the the table model_version = str(request_data["model_version"]) versions_table_name = get_complete_table_name("model_version") - versions_table = db.metadata.tables[versions_table_name] + versions_table = get_base().metadata.tables[versions_table_name] try: # Try to insert the model version insert_request_values = { @@ -108,19 +111,19 @@ def get_model_metadata(request_data, additional_data): } version_insert_statement = INSERT_STATEMENT(versions_table).values(**insert_request_values) version_insert_statement = version_insert_statement.on_conflict_do_nothing(index_elements = list(insert_request_values.keys())) - db.session.execute(version_insert_statement) - db.session.commit() + session.execute(version_insert_statement) + session.commit() except: error_msg = "Failed to insert version " + str(model_version) + " for model " + model_name + " into table " + versions_table_name + " due to error: " + traceback.format_exc() return False, error_msg - + # Get the version id for this model and version try: # Execute the select query version_select_statement = SELECT_STATEMENT(versions_table) version_select_statement = version_select_statement.where(versions_table.c.model_id == additional_data["internal_model_id"]) version_select_statement = version_select_statement.where(versions_table.c.name == model_version) - version_result = db.session.execute(version_select_statement).all() + version_result = session.execute(version_select_statement).all() # Ensure we got a result if len(version_result) == 0: @@ -165,20 +168,20 @@ def find_link(bibjson): return None PUBLICATIONS_URL = "https://xdd.wisc.edu/api/articles" 
-def record_publication(source_text, request_additional_data): +def record_publication(source_text, request_additional_data, session: Session): # See if already have a result for this publication paper_id = source_text["paper_id"] publication_table_name = get_complete_table_name("publication") - publications_table = db.metadata.tables[publication_table_name] + publications_table = get_base().metadata.tables[publication_table_name] found_existing_publication = False try: publication_select_statement = SELECT_STATEMENT(publications_table) publication_select_statement = publication_select_statement.where(publications_table.c.paper_id == paper_id) - publication_result = db.session.execute(publication_select_statement).all() + publication_result = session.execute(publication_select_statement).all() found_existing_publication = len(publication_result) > 0 except: return False, "Failed to check for paper " + paper_id + " in table " + publication_table_name + " due to error: " + traceback.format_exc() - + if found_existing_publication: return True, None @@ -203,25 +206,25 @@ def record_publication(source_text, request_additional_data): url = find_link(citation_json) if url is not None: insert_request_values["url"] = url - + # Make the insert request publication_insert_request = INSERT_STATEMENT(publications_table).values(**insert_request_values) - result = db.session.execute(publication_insert_request) - db.session.commit() + result = session.execute(publication_insert_request) + session.commit() except: - return False, "Failed to insert publication for paper " + paper_id + " into table " + publication_table_name + " due to error: " + traceback.format_exc() - + return False, "Failed to insert publication for paper " + paper_id + " into table " + publication_table_name + " due to error: " + traceback.format_exc() + return True, None -def get_weaviate_text_id(source_text, request_additional_data): +def get_weaviate_text_id(source_text, request_additional_data, session: 
Session): # Verify that we have the required fields required_source_fields = ["preprocessor_id", "paper_id", "hashed_text", "weaviate_id", "paragraph_text"] source_verify_result = verify_key_presents(source_text, required_source_fields) if source_verify_result is not None: return False, source_verify_result - + # First record the publication - sucess, error_msg = record_publication(source_text, request_additional_data) + sucess, error_msg = record_publication(source_text, request_additional_data, session) if not sucess: return sucess, error_msg @@ -229,18 +232,18 @@ def get_weaviate_text_id(source_text, request_additional_data): curr_text_type = source_text["text_type"] paragraph_hash = source_text["hashed_text"] sources_table_name = get_complete_table_name("source_text") - sources_table = db.metadata.tables[sources_table_name] + sources_table = get_base().metadata.tables[sources_table_name] try: # Get the sources values sources_values = {} for key_name in required_source_fields: sources_values[key_name] = source_text[key_name] sources_values["source_text_type"] = curr_text_type - + sources_insert_statement = INSERT_STATEMENT(sources_table).values(**sources_values) sources_insert_statement = sources_insert_statement.on_conflict_do_nothing(index_elements = ["source_text_type", "paragraph_text"]) - db.session.execute(sources_insert_statement) - db.session.commit() + session.execute(sources_insert_statement) + session.commit() except: return False, "Failed to insert paragraph with weaviate id " + source_text["weaviate_id"] + " into table " + sources_table_name + " due to error: " + traceback.format_exc() @@ -249,7 +252,7 @@ def get_weaviate_text_id(source_text, request_additional_data): source_id_select_statement = SELECT_STATEMENT(sources_table.c.id) source_id_select_statement = source_id_select_statement.where(sources_table.c.source_text_type == curr_text_type) source_id_select_statement = source_id_select_statement.where(sources_table.c.paragraph_text == 
source_text["paragraph_text"]) - source_id_result = db.session.execute(source_id_select_statement).all() + source_id_result = session.execute(source_id_select_statement).all() # Ensure we got a result if len(source_id_result) == 0: @@ -259,7 +262,7 @@ def get_weaviate_text_id(source_text, request_additional_data): except: return False, "Failed to find internal source id for weaviate paragraph having hash " + paragraph_hash + " due to error: " + traceback.format_exc() -def get_map_description_id(source_text, request_additional_data): +def get_map_description_id(source_text, request_additional_data, session: Session): # Verify that we have the required fields required_source_fields = ["paragraph_text", "legend_id"] source_verify_result = verify_key_presents(source_text, required_source_fields) @@ -271,7 +274,7 @@ def get_map_description_id(source_text, request_additional_data): text_hash = str(hashlib.sha256(paragraph_text.encode("ascii")).hexdigest()) curr_text_type = source_text["text_type"] sources_table_name = get_complete_table_name("source_text") - sources_table = db.metadata.tables[sources_table_name] + sources_table = get_base().metadata.tables[sources_table_name] legend_id = source_text["legend_id"] try: # Get the sources values @@ -281,11 +284,11 @@ def get_map_description_id(source_text, request_additional_data): "map_legend_id" : legend_id, "hashed_text" : text_hash } - + sources_insert_statement = INSERT_STATEMENT(sources_table).values(**sources_values) sources_insert_statement = sources_insert_statement.on_conflict_do_nothing(index_elements = ["source_text_type", "paragraph_text"]) - db.session.execute(sources_insert_statement) - db.session.commit() + session.execute(sources_insert_statement) + session.commit() except: return False, "Failed to insert paragraph with legend id " + str(legend_id) + " into table " + sources_table_name + " due to error: " + traceback.format_exc() @@ -294,7 +297,7 @@ def get_map_description_id(source_text, 
request_additional_data): source_id_select_statement = SELECT_STATEMENT(sources_table.c.id) source_id_select_statement = source_id_select_statement.where(sources_table.c.source_text_type == curr_text_type) source_id_select_statement = source_id_select_statement.where(sources_table.c.paragraph_text == source_text["paragraph_text"]) - source_id_result = db.session.execute(source_id_select_statement).all() + source_id_result = session.execute(source_id_select_statement).all() # Ensure we got a result if len(source_id_result) == 0: @@ -308,7 +311,7 @@ def get_map_description_id(source_text, request_additional_data): "weaviate_text" : get_weaviate_text_id, "map_descriptions" : get_map_description_id } -def get_source_text_id(source_text, request_additional_data): +def get_source_text_id(source_text, request_additional_data, session: Session): # Ensure that we have the required metadata fields text_required_text = verify_key_presents(source_text, ["text_type", "paragraph_text"]) if text_required_text is not None: @@ -319,8 +322,8 @@ def get_source_text_id(source_text, request_additional_data): text_type = source_text["text_type"] if text_type not in METHOD_TO_PROCESS_TEXT: return False, "Server currently doesn't support text of type " + text_type - - return METHOD_TO_PROCESS_TEXT[text_type](source_text, request_additional_data) + + return METHOD_TO_PROCESS_TEXT[text_type](source_text, request_additional_data, session) def get_lith_id(lithology): try: @@ -373,15 +376,15 @@ def get_strat_id(strat_name): "strat_name" : ("strat_name_id", get_strat_id) } -def get_entity_type_id(entity_type): +def get_entity_type_id(entity_type, session: Session): entity_type_table_name = get_complete_table_name("entity_type") - entity_type_table = db.metadata.tables[entity_type_table_name] + entity_type_table = get_base().metadata.tables[entity_type_table_name] # First try to get the entity type try: entity_type_id_select_statement = SELECT_STATEMENT(entity_type_table) 
entity_type_id_select_statement = entity_type_id_select_statement.where(entity_type_table.c.name == entity_type) - entity_type_result = db.session.execute(entity_type_id_select_statement).all() + entity_type_result = session.execute(entity_type_id_select_statement).all() # Ensure we got a result if len(entity_type_result) > 0: @@ -397,8 +400,8 @@ def get_entity_type_id(entity_type): } entity_type_insert_statement = INSERT_STATEMENT(entity_type_table).values(**entity_type_insert_values) entity_type_insert_statement = entity_type_insert_statement.on_conflict_do_nothing(index_elements = list(entity_type_insert_values.keys())) - result = db.session.execute(entity_type_insert_statement) - db.session.commit() + result = session.execute(entity_type_insert_statement) + session.commit() # Get the internal id result_insert_keys = result.inserted_primary_key @@ -409,12 +412,12 @@ def get_entity_type_id(entity_type): except: return False, "Failed to insert entity type " + entity_type + " into table " + entity_type_table_name + " due to error: " + traceback.format_exc() -def get_entity_id(entity_name, entity_type, request_additional_data, provided_start_idx = None, provided_end_idx = None): +def get_entity_id(entity_name, entity_type, request_additional_data, session: Session, provided_start_idx = None, provided_end_idx = None): # Record the entity type - success, entity_type_id = get_entity_type_id(entity_type) + success, entity_type_id = get_entity_type_id(entity_type, session) if not success: return success, entity_type_id - + # Determine the values to write to the entities table entity_insert_request_values = { "name" : entity_name, @@ -442,12 +445,12 @@ def get_entity_id(entity_name, entity_type, request_additional_data, provided_st for idx in range(1, len(matches)): if matches[idx].dist < matches[best_match_idx].dist: best_match_idx = idx - + # Record the idx start_idx, end_idx = matches[best_match_idx].start, matches[best_match_idx].end curr_max_l_dist *= 2 - + # Record 
the results entity_start_idx, entity_end_idx, str_match_type = start_idx, end_idx, "fuzzy" @@ -467,19 +470,19 @@ def get_entity_id(entity_name, entity_type, request_additional_data, provided_st # Else ensure we can record the value if id_val is not None: entity_insert_request_values[key_name] = id_val - + # Insert in the result into the table entity_table_name = get_complete_table_name("entity") - entity_table = db.metadata.tables[entity_table_name] + entity_table = get_base().metadata.tables[entity_table_name] try: # Execute the request entity_insert_request = INSERT_STATEMENT(entity_table).values(**entity_insert_request_values) entity_insert_request = entity_insert_request.on_conflict_do_nothing(index_elements = ["name", "model_run_id", "entity_type_id", "start_index", "end_index"]) - result = db.session.execute(entity_insert_request) - db.session.commit() + result = session.execute(entity_insert_request) + session.commit() except: return False, "Failed to entity " + entity_name + " of type " + entity_type + " into table " + entity_table_name + " due to error:" + traceback.format_exc() - + # Get the entity id for the inserted value try: entity_select_statement = SELECT_STATEMENT(entity_table.c.id) @@ -488,7 +491,7 @@ def get_entity_id(entity_name, entity_type, request_additional_data, provided_st entity_select_statement = entity_select_statement.where(entity_table.c.entity_type_id == entity_insert_request_values["entity_type_id"]) entity_select_statement = entity_select_statement.where(entity_table.c.start_index == entity_insert_request_values["start_index"]) entity_select_statement = entity_select_statement.where(entity_table.c.end_index == entity_insert_request_values["end_index"]) - entity_id_result = db.session.execute(entity_select_statement).all() + entity_id_result = session.execute(entity_select_statement).all() # Ensure we got a result if len(entity_id_result) == 0: @@ -497,7 +500,7 @@ def get_entity_id(entity_name, entity_type, request_additional_data, 
provided_st return True, first_row["id"] except: return False, "Failed to find internal entity id for entity " + entity_name + " due to error: " + traceback.format_exc() - + def extract_indicies(request_values, expected_prefix): provided_start_idx, provided_end_idx = None, None start_search_term, end_search_term = expected_prefix + "start_idx", expected_prefix + "end_idx" @@ -519,32 +522,32 @@ def extract_indicies(request_values, expected_prefix): return False, "Failed to parse " + request_values[end_search_term] + " as an integer due an error " + traceback.format_exc() else: return False, f'Provided {start_search_term} but not {end_search_term} for entity ' + request_values["entity"] - + return True, (provided_start_idx, provided_end_idx) -def record_single_entity(entity, request_additional_data): +def record_single_entity(entity, request_additional_data, session: Session): # Ensure that we have the required metadata fields entity_required_fields = verify_key_presents(entity, ["entity", "entity_type"]) if entity_required_fields is not None: return False, entity_required_fields - + # See if the range is provided success, indicies_results = extract_indicies(entity, "") if not success: return success, indicies_results provided_start_idx, provided_end_idx = indicies_results - return get_entity_id(entity["entity"], entity["entity_type"], request_additional_data, provided_start_idx, provided_end_idx) + return get_entity_id(entity["entity"], entity["entity_type"], request_additional_data, session, provided_start_idx, provided_end_idx) -def get_relationship_type_id(relationship_type): +def get_relationship_type_id(relationship_type, session: Session): relationship_type_table_name = get_complete_table_name("relationship_type") - relationship_type_table = db.metadata.tables[relationship_type_table_name] + relationship_type_table = get_base().metadata.tables[relationship_type_table_name] - # First try to get the relationship type + # First try to get the relationship type 
try: relationship_type_id_select_statement = SELECT_STATEMENT(relationship_type_table) relationship_type_id_select_statement = relationship_type_id_select_statement.where(relationship_type_table.c.name == relationship_type) - relationship_type_result = db.session.execute(relationship_type_id_select_statement).all() + relationship_type_result = session.execute(relationship_type_id_select_statement).all() # Ensure we got a result if len(relationship_type_result) > 0: @@ -560,8 +563,8 @@ def get_relationship_type_id(relationship_type): } relationship_type_insert_statement = INSERT_STATEMENT(relationship_type_table).values(**relationship_type_insert_values) relationship_type_insert_statement = relationship_type_insert_statement.on_conflict_do_nothing(index_elements = list(relationship_type_insert_values.keys())) - result = db.session.execute(relationship_type_insert_statement) - db.session.commit() + result = session.execute(relationship_type_insert_statement) + session.commit() # Get the internal id result_insert_keys = result.inserted_primary_key @@ -577,12 +580,12 @@ def get_relationship_type_id(relationship_type): "strat" : ("strat_to_lith", "strat_name", "lith"), "att" : ("lith_to_attribute", "lith", "lith_att") } -def record_relationship(relationship, request_additional_data): +def record_relationship(relationship, request_additional_data, session: Session): # Ensure that we have the required metadata fields relationship_required_fields = verify_key_presents(relationship, ["src", "relationship_type", "dst"]) if relationship_required_fields is not None: return False, relationship_required_fields - + # Extract the types provided_relationship_type = relationship["relationship_type"] db_relationship_type, src_entity_type, dst_entity_type = provided_relationship_type, UNKNOWN_ENTITY_TYPE, UNKNOWN_ENTITY_TYPE @@ -590,12 +593,12 @@ def record_relationship(relationship, request_additional_data): if provided_relationship_type.startswith(key_name): db_relationship_type, 
src_entity_type, dst_entity_type = RELATIONSHIP_DETAILS[key_name] break - + # Record the relationship type - success, relationship_type_id = get_relationship_type_id(db_relationship_type) + success, relationship_type_id = get_relationship_type_id(db_relationship_type, session) if not success: return success, relationship_type_id - + # Extract the source indicies success, indicies_results = extract_indicies(relationship, "src_") if not success: @@ -603,23 +606,23 @@ def record_relationship(relationship, request_additional_data): src_provided_start_idx, src_provided_end_idx = indicies_results # Get the src entity ids - success, src_entity_id = get_entity_id(relationship["src"], src_entity_type, request_additional_data, src_provided_start_idx, src_provided_end_idx) + success, src_entity_id = get_entity_id(relationship["src"], src_entity_type, request_additional_data, session, src_provided_start_idx, src_provided_end_idx) if not success: return success, src_entity_id - + # Extract the dest indicies success, indicies_results = extract_indicies(relationship, "dst_") if not success: return success, indicies_results dst_provided_start_idx, dst_provided_end_idx = indicies_results - success, dst_entity_id = get_entity_id(relationship["dst"], dst_entity_type, request_additional_data, dst_provided_start_idx, dst_provided_end_idx) + success, dst_entity_id = get_entity_id(relationship["dst"], dst_entity_type, request_additional_data, session, dst_provided_start_idx, dst_provided_end_idx) if not success: return success, dst_entity_id # Now record the relationship relationship_table_name = get_complete_table_name("relationship") - relationship_table = db.metadata.tables[relationship_table_name] + relationship_table = get_base().metadata.tables[relationship_table_name] try: # Get the sources values relationship_insert_values = { @@ -631,27 +634,27 @@ def record_relationship(relationship, request_additional_data): if "reasoning" in relationship: 
relationship_insert_values["reasoning"] = relationship["reasoning"] - + relationship_insert_statement = INSERT_STATEMENT(relationship_table).values(**relationship_insert_values) relationship_insert_statement = relationship_insert_statement.on_conflict_do_nothing(index_elements = ["model_run_id", "src_entity_id", "dst_entity_id", "relationship_type_id"]) - db.session.execute(relationship_insert_statement) - db.session.commit() + session.execute(relationship_insert_statement) + session.commit() except: return False, "Failed to insert relationship with src " + relationship["src"] + " and dst " + relationship["dst"] + " into table " + relationship_table_name + " due to error: " + traceback.format_exc() return True, None -def get_previous_run(source_text_id): +def get_previous_run(source_text_id, session: Session): # Load the latest run table latest_run_table_name = get_complete_table_name("latest_run_per_text") - latest_run_table = db.metadata.tables[latest_run_table_name] + latest_run_table = get_base().metadata.tables[latest_run_table_name] # Get the latest for the current source text prev_run_id = None try: previous_id_for_source_select_statement = SELECT_STATEMENT(latest_run_table) previous_id_for_source_select_statement = previous_id_for_source_select_statement.where(latest_run_table.c.source_text_id == source_text_id) - previous_id_result = db.session.execute(previous_id_for_source_select_statement).all() + previous_id_result = session.execute(previous_id_for_source_select_statement).all() # Ensure we got a result if len(previous_id_result) > 0: @@ -662,7 +665,7 @@ def get_previous_run(source_text_id): return True, prev_run_id -def process_input_request(request_data): +def process_input_request(request_data, session): # Ensure that we have the required metadata fields run_verify_result = verify_key_presents(request_data, ["run_id", "results"]) if run_verify_result is not None: @@ -670,7 +673,7 @@ def process_input_request(request_data): # Get the model metadata 
for this model request_additional_data = {} - sucess, err_msg = get_model_metadata(request_data, request_additional_data) + sucess, err_msg = get_model_metadata(request_data, request_additional_data, session) if not sucess: return sucess, err_msg @@ -679,21 +682,21 @@ def process_input_request(request_data): all_results = request_data["results"] base_model_run_id = request_data["run_id"] model_run_table_name = get_complete_table_name("model_run") - model_run_table = db.metadata.tables[model_run_table_name] + model_run_table = get_base().metadata.tables[model_run_table_name] for idx, current_result in enumerate(all_results): # Ensure that we have the required metadata fields result_required_fields = verify_key_presents(current_result, ["text"]) if result_required_fields is not None: return False, result_required_fields - + # First get the source text id for this result - success, source_text_id = get_source_text_id(current_result["text"], request_additional_data) + success, source_text_id = get_source_text_id(current_result["text"], request_additional_data, session) if not success: return success, source_text_id # Then get the previous run for this result - success, previous_run_id = get_previous_run(source_text_id) + success, previous_run_id = get_previous_run(source_text_id, session) if not success: return success, previous_run_id @@ -712,8 +715,8 @@ def process_input_request(request_data): model_run_insert_request = INSERT_STATEMENT(model_run_table).values(**model_run_insert_values) model_run_insert_request = model_run_insert_request.on_conflict_do_update(constraint = "no_duplicate_runs", set_ = model_run_insert_values) - result = db.session.execute(model_run_insert_request) - db.session.commit() + result = session.execute(model_run_insert_request) + session.commit() # Now get the interal run id for this new run result_insert_keys = result.inserted_primary_key @@ -726,37 +729,72 @@ def process_input_request(request_data): # Now actually record the graph if 
"relationships" in current_result: for relationship in current_result["relationships"]: - sucessful, message = record_relationship(relationship, request_additional_data) + sucessful, message = record_relationship(relationship, request_additional_data, session) if not sucessful: return sucessful, message - + # Record just the entities if "just_entities" in current_result: for entity in current_result["just_entities"]: - sucessful, err_msg = record_single_entity(entity, request_additional_data) + sucessful, err_msg = record_single_entity(entity, request_additional_data, session) if not sucessful: return sucessful, err_msg - + return True, None # Opentially take in user id -@app.route("/record_run", methods=["POST"]) -def record_run(): +@app.post("/record_run") +async def record_run( + request: Request, + user_has_access: bool = Depends(has_access), + session: Session = Depends(get_session) +): + + if not user_has_access: + raise HTTPException(status_code=403, detail="User does not have access to record run") + # Record the run - request_data = request.get_json() - sucessful, error_msg = process_input_request(request_data) - if not sucessful: - print("Returning error of", error_msg) - return jsonify({"error" : error_msg}), 400 - return jsonify({"sucess" : "Sucessfully processed the run"}), 200 + request_data = await request.json() + + successful, error_msg = process_input_request(request_data, session) + if not successful: + raise HTTPException(status_code=400, detail=error_msg) + + return JSONResponse(content={"success": "Successfully processed the run"}) + + +@app.get("/health") +async def health( + session = Depends(get_session) +): + health_checks = {} -@app.route("/health", methods=["GET"]) -def health(): - """Health check endpoint""" + try: + session.execute(text("SELECT 1")) + except: + health_checks["database"] = False + else: + health_checks["database"] = True + + # Test that we can get metadata + try: + models_table_name = get_complete_table_name("model") + 
+        models_table = get_base().metadata.tables[models_table_name]
+        models_select_statement = SELECT_STATEMENT(models_table)
+        models_select_statement = models_select_statement.limit(1)
+        session.execute(models_select_statement)
+    except:
+        health_checks["metadata"] = False
+    else:
+        health_checks["metadata"] = True
 
-    return jsonify({"status": "Server Running"}), 200
+    return JSONResponse(content={
+        "webserver": "ok",
+        "healthy": health_checks
+    })
 
 if __name__ == "__main__":
-    app.run(host = "0.0.0.0", port = 9543, debug = True)
\ No newline at end of file
+    import uvicorn
+    uvicorn.run(app, host="0.0.0.0", port=9543)
diff --git a/macrostrat_db_insertion/test.py b/macrostrat_db_insertion/test.py
new file mode 100644
index 0000000..d560ca7
--- /dev/null
+++ b/macrostrat_db_insertion/test.py
@@ -0,0 +1,46 @@
+# Test the insertion of data into the macrostrat database
+
+import json
+import glob
+from types import SimpleNamespace
+
+import pytest
+from fastapi.testclient import TestClient
+
+from macrostrat_db_insertion.server import app
+from macrostrat_db_insertion.security import get_groups_from_header_token
+
+TEST_GROUP_TOKEN = "vFWCCodpP8hFF6LxFrpYQTqcJjCGOWyn"
+TEST_GROUP_ID = 2
+
+
+@pytest.fixture
+def api_client() -> TestClient:
+    with TestClient(app) as api_client:
+        yield api_client
+
+
+class TestAPI:
+
+    def test_insert(self, api_client: TestClient):
+
+        for file in glob.glob("example_requests/**/*.json"):
+            data = json.loads(open(file, "r").read())
+
+            response = api_client.post(
+                "/record_run",
+                json=data
+            )
+
+            assert response.status_code == 200
+
+
+class TestSecurity:
+
+    def test_get_groups_from_header_token(self, api_client: TestClient):
+
+        mock_header = SimpleNamespace(**{
+            'credentials': TEST_GROUP_TOKEN
+        })
+
+        assert get_groups_from_header_token(mock_header) == TEST_GROUP_ID