From 3dd93b1f7697d622ce1f11919c46d1b1dc5b9a02 Mon Sep 17 00:00:00 2001 From: Cannon Lock Date: Wed, 2 Oct 2024 14:21:08 -0500 Subject: [PATCH 1/6] Convert to Fastapi --- macrostrat_db_insertion/environment.yml | 33 ++--- macrostrat_db_insertion/server.py | 180 +++++++++++++----------- macrostrat_db_insertion/test.py | 27 ++++ 3 files changed, 133 insertions(+), 107 deletions(-) create mode 100644 macrostrat_db_insertion/test.py diff --git a/macrostrat_db_insertion/environment.yml b/macrostrat_db_insertion/environment.yml index 11b9018..83095bf 100644 --- a/macrostrat_db_insertion/environment.yml +++ b/macrostrat_db_insertion/environment.yml @@ -2,27 +2,18 @@ name: db_insert_env channels: - defaults dependencies: - - _libgcc_mutex=0.1=main - - _openmp_mutex=5.1=1_gnu - - ca-certificates=2024.3.11=h06a4308_0 - - cuda-version=11.6=h08688e6_2 - - ld_impl_linux-64=2.38=h1181459_1 - - libffi=3.4.4=h6a678d5_0 - - libgcc-ng=11.2.0=h1234567_1 - - libgomp=11.2.0=h1234567_1 - - libstdcxx-ng=11.2.0=h1234567_1 - - ncurses=6.4=h6a678d5_0 - - openssl=3.0.13=h7f8727e_0 - - pip=23.3.1=py39h06a4308_0 - - python=3.9.19=h955ad1f_0 - - readline=8.2=h5eee18b_0 - - setuptools=68.2.2=py39h06a4308_0 - - sqlite=3.41.2=h5eee18b_0 - - tk=8.6.12=h1ccaba5_0 - - tzdata=2024a=h04d1e81_0 - - wheel=0.41.2=py39h06a4308_0 - - xz=5.4.6=h5eee18b_0 - - zlib=1.2.13=h5eee18b_0 + - ca-certificates=2024.3.11 + - ncurses=6.4 + - openssl=3.0.13 + - pip=23.3.1 + - python=3.9.19 + - readline=8.2 + - setuptools=68.2.2 + - sqlite=3.41.2 + - tk=8.6.12 + - wheel=0.41.2 + - xz=5.4.6 + - zlib=1.2.13 - pip: - blinker==1.7.0 - certifi==2024.2.2 diff --git a/macrostrat_db_insertion/server.py b/macrostrat_db_insertion/server.py index ebf67d3..7f1ced1 100644 --- a/macrostrat_db_insertion/server.py +++ b/macrostrat_db_insertion/server.py @@ -1,10 +1,11 @@ -from flask import Flask, jsonify, request -from flask_cors import CORS -from flask_sqlalchemy import SQLAlchemy -import sqlalchemy -from sqlalchemy.dialects.postgresql 
import insert as INSERT_STATEMENT -from sqlalchemy import select as SELECT_STATEMENT -from sqlalchemy.orm import declarative_base +from fastapi import FastAPI, Request, HTTPException +from fastapi.responses import JSONResponse +from fastapi.middleware.cors import CORSMiddleware +from pydantic_settings import BaseSettings +from sqlalchemy import create_engine, MetaData +from sqlalchemy.orm import sessionmaker, declarative_base +from sqlalchemy.sql import select as SELECT_STATEMENT +from sqlalchemy.sql import insert as INSERT_STATEMENT import json import traceback from datetime import datetime, timezone @@ -12,38 +13,34 @@ from re_detail_adder import * -ENV_VAR_PREFIX = "macrostrat_xdd" -REQUIRED_VALUES = ["username", "password", "host", "port", "database", "schema"] -def load_config(): - config_values = {} - for required_name in REQUIRED_VALUES: - # Read in the environment variable - env_variable_name = ENV_VAR_PREFIX + "_" + required_name - env_variable_value = os.environ.get(env_variable_name) - config_values[required_name] = env_variable_value - return config_values +class Settings(BaseSettings): + uri: str + schema: str + max_tries: int = 5 -def load_flask_app(config): - # Create the app - app = Flask(__name__) - app.config["SQLALCHEMY_DATABASE_URI"] = os.environ['uri'] - # Create the db - Base = declarative_base(metadata = sqlalchemy.MetaData(schema = config["schema"])) - db = SQLAlchemy(model_class=Base) - db.init_app(app) - with app.app_context(): - db.reflect() +def load_fastapi_app(): - return app, db + app = FastAPI() + settings = Settings() + app.add_middleware( + CORSMiddleware, + allow_origins=["*"], + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"], + ) + + engine = create_engine(settings.uri) + SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine) + Base = declarative_base(metadata=MetaData(schema=settings.schema)) + Base.metadata.reflect(engine) + + return app, SessionLocal, Base # Connect to the database 
-MAX_TRIES = 5 -config = load_config() -print(config) -app, db = load_flask_app(config) -CORS(app) +app, SessionLocal, Base = load_fastapi_app() re_processor = REProcessor("id_maps") ENTITY_TYPE_TO_ID_MAP = { @@ -51,10 +48,12 @@ def load_flask_app(config): "lith" : "macrostrat_lith_id", "lith_att" : "macrostrat_lith_att_id" } -def get_db_entity_id(run_id, entity_name, entity_type, source_id): + + +def get_db_entity_id(db, run_id, entity_name, entity_type, source_id): # Create the entity value entity_unique_rows = ["run_id", "entity_name", "entity_type", "source_id"] - entities_table = db.metadata.tables['macrostrat_kg_new.entities'] + entities_table = Base.metadata.tables['macrostrat_kg_new.entities'] entities_values = { "run_id" : run_id, "entity_name" : entity_name, @@ -72,8 +71,8 @@ def get_db_entity_id(run_id, entity_name, entity_type, source_id): try: entities_insert_statement = INSERT_STATEMENT(entities_table).values(**entities_values) entities_insert_statement = entities_insert_statement.on_conflict_do_nothing(index_elements = entity_unique_rows) - db.session.execute(entities_insert_statement) - db.session.commit() + db.execute(entities_insert_statement) + db.commit() except: error_msg = "Failed to insert entity " + entity_name + " for run " + str(run_id) + " due to error: " + traceback.format_exc() return False, error_msg @@ -86,7 +85,7 @@ def get_db_entity_id(run_id, entity_name, entity_type, source_id): entities_select_statement = entities_select_statement.where(entities_table.c.run_id == run_id) entities_select_statement = entities_select_statement.where(entities_table.c.entity_name == entity_name) entities_select_statement = entities_select_statement.where(entities_table.c.entity_type == entity_type) - entities_result = db.session.execute(entities_select_statement).all() + entities_result = db.execute(entities_select_statement).all() # Ensure we got a result if len(entities_result) == 0: @@ -106,7 +105,9 @@ def get_db_entity_id(run_id, entity_name, 
entity_type, source_id): "strat" : ("strat_to_lith", "strat_name", "lith"), "att" : ("lith_to_attribute", "lith", "lith_att") } -def record_relationship(run_id, source_id, relationship): + + +def record_relationship(db, run_id, source_id, relationship): # Verify the fields expected_fields = ["src", "relationship_type", "dst"] relationship_values = {} @@ -128,11 +129,11 @@ def record_relationship(run_id, source_id, relationship): return True, "" # Get the entity ids - sucessful, src_entity_id = get_db_entity_id(run_id, relationship_values["src"], src_entity_type, source_id) + sucessful, src_entity_id = get_db_entity_id(db, run_id, relationship_values["src"], src_entity_type, source_id) if not sucessful: return sucessful, src_entity_id - sucessful, dst_entity_id = get_db_entity_id(run_id, relationship_values["dst"], dst_entity_type, source_id) + sucessful, dst_entity_id = get_db_entity_id(db, run_id, relationship_values["dst"], dst_entity_type, source_id) if not sucessful: return sucessful, dst_entity_id @@ -145,12 +146,12 @@ def record_relationship(run_id, source_id, relationship): "relationship_type" : db_relationship_type } unique_columns = ["run_id", "src_entity_id", "dst_entity_id", "relationship_type", "source_id"] - relationship_tables = db.metadata.tables['macrostrat_kg_new.relationship'] + relationship_tables = Base.metadata.tables['macrostrat_kg_new.relationship'] try: relationship_insert_statement = INSERT_STATEMENT(relationship_tables).values(**db_relationship_insert_values) relationship_insert_statement = relationship_insert_statement.on_conflict_do_nothing(index_elements = unique_columns) - db.session.execute(relationship_insert_statement) - db.session.commit() + db.execute(relationship_insert_statement) + db.commit() except: error_msg = "Failed to insert relationship type " + str(provided_relationship_type) + " for source " + str(source_id) error_msg += " for run " + str(run_id) + " due to error: " + traceback.format_exc() @@ -158,7 +159,8 @@ def 
record_relationship(run_id, source_id, relationship): return True, "" -def record_for_result(run_id, request): + +def record_for_result(db, run_id, request): # Ensure txt exists if "text" not in request: return False, "result is missing text field" @@ -175,13 +177,13 @@ def record_for_result(run_id, request): paragraph_txt = source_values["paragraph_text"] source_values["paragraph_text"] = paragraph_txt.encode("ascii", errors="ignore").decode() - sources_table = db.metadata.tables['macrostrat_kg_new.sources'] + sources_table = Base.metadata.tables['macrostrat_kg_new.sources'] try: # Try to insert the sources sources_insert_statement = INSERT_STATEMENT(sources_table).values(**source_values) sources_insert_statement = sources_insert_statement.on_conflict_do_nothing(index_elements=["run_id", "weaviate_id"]) - db.session.execute(sources_insert_statement) - db.session.commit() + db.execute(sources_insert_statement) + db.commit() except: error_msg = "Failed to insert paragraph " + str(source_values["weaviate_id"]) error_msg += " for run " + str(source_values["run_id"]) + " due to error: " + traceback.format_exc() @@ -198,7 +200,7 @@ def record_for_result(run_id, request): sources_select_statement = SELECT_STATEMENT(sources_table) sources_select_statement = sources_select_statement.where(sources_table.c.run_id == run_id) sources_select_statement = sources_select_statement.where(sources_table.c.weaviate_id == source_values["weaviate_id"]) - sources_result = db.session.execute(sources_select_statement).all() + sources_result = db.execute(sources_select_statement).all() # Ensure we got a result if len(sources_result) == 0: @@ -215,7 +217,7 @@ def record_for_result(run_id, request): # Record the relationships if "relationships" in request: for relationship in request["relationships"]: - sucessful, message = record_relationship(run_id, source_id, relationship) + sucessful, message = record_relationship(db, run_id, source_id, relationship) if not sucessful: return sucessful, 
message @@ -234,15 +236,16 @@ def record_for_result(run_id, request): continue # Record the entity - sucessful, entity_id = get_db_entity_id(run_id, entity_data["entity"], "strat_name", source_id) + sucessful, entity_id = get_db_entity_id(db, run_id, entity_data["entity"], "strat_name", source_id) if not sucessful: return sucessful, entity_id return True, "" -def get_user_id(user_name): + +def get_user_id(db, user_name): # Create the users rows - users_table = db.metadata.tables['macrostrat_kg_new.users'] + users_table = Base.metadata.tables['macrostrat_kg_new.users'] users_row_values = { "user_name" : user_name } @@ -251,8 +254,8 @@ def get_user_id(user_name): try: users_insert_statement = INSERT_STATEMENT(users_table).values(**users_row_values) users_insert_statement = users_insert_statement.on_conflict_do_nothing(index_elements = ["user_name"]) - db.session.execute(users_insert_statement) - db.session.commit() + db.execute(users_insert_statement) + db.commit() except: error_msg = "Failed to insert user " + user_name + " due to error: " + traceback.format_exc() return False, error_msg @@ -263,7 +266,7 @@ def get_user_id(user_name): # Execute the select query users_select_statement = SELECT_STATEMENT(users_table) users_select_statement = users_select_statement.where(users_table.c.user_name == user_name) - users_result = db.session.execute(users_select_statement).all() + users_result = db.execute(users_select_statement).all() # Ensure we got a result if len(users_result) == 0: @@ -278,7 +281,8 @@ def get_user_id(user_name): return True, user_id -def get_model_internal_details(request): + +def get_model_internal_details(db, request): # Extract the expected values expected_fields = ["model_name", "model_version"] model_values = {} @@ -289,15 +293,15 @@ def get_model_internal_details(request): # Try to insert the model model_name = model_values["model_name"] - models_tables = db.metadata.tables['macrostrat_kg_new.models'] + models_tables = 
Base.metadata.tables['macrostrat_kg_new.models'] try: # Try to insert the model model_insert_statement = INSERT_STATEMENT(models_tables).values(**{ "model_name" : model_name }) model_insert_statement = model_insert_statement.on_conflict_do_nothing(index_elements=["model_name"]) - db.session.execute(model_insert_statement) - db.session.commit() + db.execute(model_insert_statement) + db.commit() except: error_msg = "Failed to insert model " + model_name + " due to error: " + traceback.format_exc() return False, error_msg @@ -308,7 +312,7 @@ def get_model_internal_details(request): # Execute the select query models_select_statement = SELECT_STATEMENT(models_tables) models_select_statement = models_select_statement.where(models_tables.c.model_name == model_name) - models_result = db.session.execute(models_select_statement).all() + models_result = db.execute(models_select_statement).all() # Ensure we got a result if len(models_result) == 0: @@ -323,7 +327,7 @@ def get_model_internal_details(request): # Try to insert the model version model_version = model_values["model_version"] - versions_table = db.metadata.tables['macrostrat_kg_new.model_versions'] + versions_table = Base.metadata.tables['macrostrat_kg_new.model_versions'] try: # Try to insert the model version version_insert_statement = INSERT_STATEMENT(versions_table).values(**{ @@ -331,8 +335,8 @@ def get_model_internal_details(request): "model_version" : model_version }) version_insert_statement = version_insert_statement.on_conflict_do_nothing(index_elements=["model_id", "model_version"]) - db.session.execute(version_insert_statement) - db.session.commit() + db.execute(version_insert_statement) + db.commit() except: error_msg = "Failed to insert version " + model_version + " for model " + model_name + " due to error: " + traceback.format_exc() return False, error_msg @@ -343,7 +347,7 @@ def get_model_internal_details(request): version_select_statement = SELECT_STATEMENT(versions_table) version_select_statement = 
version_select_statement.where(versions_table.c.model_id == data_to_return["internal_model_id"]) version_select_statement = version_select_statement.where(versions_table.c.model_version == model_version) - version_result = db.session.execute(version_select_statement).all() + version_result = db.execute(version_select_statement).all() # Ensure we got a result if len(version_result) == 0: @@ -360,7 +364,8 @@ def get_model_internal_details(request): return True, data_to_return -def process_input_request(request_data): + +def process_input_request(db, request_data): # Get the metadata fields metadata_fields = ["run_id", "extraction_pipeline_id"] metadata_values = {} @@ -370,7 +375,7 @@ def process_input_request(request_data): metadata_values[field_name] = request_data[field_name] # Add the model fields to the metadata - sucessful, model_fields = get_model_internal_details(request_data) + sucessful, model_fields = get_model_internal_details(db, request_data) if not sucessful: return sucessful, model_fields @@ -379,50 +384,53 @@ def process_input_request(request_data): # Determine if this is user provided feedback if "user_name" in request_data: - sucessful, user_id = get_user_id(request_data["user_name"]) + sucessful, user_id = get_user_id(db, request_data["user_name"]) if not sucessful: return sucessful, user_id metadata_values["user_id"] = user_id # Insert this run to the metadata try: - metadata_table = db.metadata.tables['macrostrat_kg_new.metadata'] + metadata_table = Base.metadata.tables['macrostrat_kg_new.metadata'] metadata_insert_statement = INSERT_STATEMENT(metadata_table).values(**metadata_values) metadata_insert_statement = metadata_insert_statement.on_conflict_do_update(index_elements=["run_id"], set_ = { "internal_model_id" : metadata_values["internal_model_id"], "internal_version_id" : metadata_values["internal_version_id"] }) - db.session.execute(metadata_insert_statement) - db.session.commit() + db.execute(metadata_insert_statement) + db.commit() except 
Exception: return False, "Failed to insert run " + str(metadata_values["run_id"]) + " due to error: " + traceback.format_exc() # Record the results if "results" in request_data: for result in request_data["results"]: - sucessful, error_msg = record_for_result(request_data["run_id"], result) + sucessful, error_msg = record_for_result(db, request_data["run_id"], result) if not sucessful: return sucessful, error_msg return True, "" -@app.route("/record_run", methods=["POST"]) -def record_run(): - # Record the run - sucessful, error_msg = process_input_request(request.get_json()) - if not sucessful: - print("Returning error of", error_msg) - return jsonify({"error" : error_msg}), 400 - - return jsonify({"sucess" : "Sucessfully processed the run"}), 200 +@app.post("/record_run") +async def record_run(request: Request): + # Record the run + request_data = await request.json() + db = SessionLocal() + try: + successful, error_msg = process_input_request(db, request_data) + if not successful: + raise HTTPException(status_code=400, detail=error_msg) + finally: + db.close() + return JSONResponse(content={"success": "Successfully processed the run"}) -@app.route("/health", methods=["GET"]) -def health(): - """Health check endpoint""" - return jsonify({"status": "Server Running"}), 200 +@app.get("/health") +async def health(): + return JSONResponse(content={"status": "Server Running"}) if __name__ == "__main__": - app.run(host = "0.0.0.0", port = 9543, debug = True) + import uvicorn + uvicorn.run(app, host="0.0.0.0", port=9543) diff --git a/macrostrat_db_insertion/test.py b/macrostrat_db_insertion/test.py new file mode 100644 index 0000000..355b729 --- /dev/null +++ b/macrostrat_db_insertion/test.py @@ -0,0 +1,27 @@ +# Test the insertion of data into the macrostrat database + +import json + +import pytest +from fastapi.testclient import TestClient + +from macrostrat_db_insertion.server import app + + +@pytest.fixture +def api_client() -> TestClient: + with TestClient(app) as 
api_client: + yield api_client + + +class TestAPI: + + def test_insert(self, api_client: TestClient): + data = json.loads(open("example_request.json", "r").read()) + + response = api_client.post( + "/record_run", + json=data + ) + + assert response.status_code == 200 From f95cdf20a4e04a4c98fc271b90fcd5926d127055 Mon Sep 17 00:00:00 2001 From: Cannon Lock Date: Mon, 7 Oct 2024 11:27:06 -0500 Subject: [PATCH 2/6] Add Auth to api --- macrostrat_db_insertion/database.py | 33 ++++++ macrostrat_db_insertion/environment.yml | 33 +++--- macrostrat_db_insertion/security-v1.py | 114 +++++++++++++++++++ macrostrat_db_insertion/security/__init__.py | 6 + macrostrat_db_insertion/security/db.py | 30 +++++ macrostrat_db_insertion/security/main.py | 114 +++++++++++++++++++ macrostrat_db_insertion/security/model.py | 23 ++++ macrostrat_db_insertion/security/schema.py | 71 ++++++++++++ macrostrat_db_insertion/server.py | 74 ++++++------ macrostrat_db_insertion/test.py | 16 +++ 10 files changed, 462 insertions(+), 52 deletions(-) create mode 100644 macrostrat_db_insertion/database.py create mode 100644 macrostrat_db_insertion/security-v1.py create mode 100644 macrostrat_db_insertion/security/__init__.py create mode 100644 macrostrat_db_insertion/security/db.py create mode 100644 macrostrat_db_insertion/security/main.py create mode 100644 macrostrat_db_insertion/security/model.py create mode 100644 macrostrat_db_insertion/security/schema.py diff --git a/macrostrat_db_insertion/database.py b/macrostrat_db_insertion/database.py new file mode 100644 index 0000000..d491055 --- /dev/null +++ b/macrostrat_db_insertion/database.py @@ -0,0 +1,33 @@ + +from sqlalchemy import create_engine, MetaData, Engine +from sqlalchemy.orm import sessionmaker, declarative_base + +engine: Engine | None = None +Base: declarative_base = None + + +def get_engine(): + return engine + + +def get_base(): + return Base + + +def connect_engine(uri: str, schema: str): + global engine + global Base + + engine = 
create_engine(uri) + + Base = declarative_base() + Base.metadata.reflect(get_engine()) + + +def dispose_engine(): + global engine + engine.dispose() + + +def get_session_maker(): + return sessionmaker(autocommit=False, autoflush=False, bind=get_engine()) diff --git a/macrostrat_db_insertion/environment.yml b/macrostrat_db_insertion/environment.yml index 83095bf..883734a 100644 --- a/macrostrat_db_insertion/environment.yml +++ b/macrostrat_db_insertion/environment.yml @@ -15,24 +15,19 @@ dependencies: - xz=5.4.6 - zlib=1.2.13 - pip: - - blinker==1.7.0 - - certifi==2024.2.2 - - charset-normalizer==3.3.2 + - bcrypt==4.2.0 - click==8.1.7 - - flask==3.0.3 - - flask-cors==4.0.0 - - flask-sqlalchemy==3.1.1 - - greenlet==3.0.3 - - idna==3.7 - - importlib-metadata==7.1.0 - - itsdangerous==2.1.2 - - jinja2==3.1.3 - - markupsafe==2.1.5 - - psycopg2-binary==2.9.9 - - requests==2.31.0 - - sqlalchemy==2.0.29 - - typing-extensions==4.11.0 - - urllib3==2.2.1 - - werkzeug==3.0.2 - - zipp==3.18.1 + - h11==0.14.0 + - iniconfig==2.0.0 + - jose==1.0.0 + - packaging==24.1 + - pluggy==1.5.0 + - psycopg2==2.9.9 + - pydantic-settings==2.5.2 + - pytest==8.3.3 + - python-dotenv==1.0.1 + - sqlalchemy==2.0.35 + - tomli==2.0.2 + - uvicorn==0.31.0 + - fast prefix: /conda/envs/db_insert_env diff --git a/macrostrat_db_insertion/security-v1.py b/macrostrat_db_insertion/security-v1.py new file mode 100644 index 0000000..3481601 --- /dev/null +++ b/macrostrat_db_insertion/security-v1.py @@ -0,0 +1,114 @@ +import os +from datetime import datetime +from typing import Annotated, Optional + +import bcrypt +from fastapi import APIRouter, Depends, HTTPException, Request +from fastapi.security import ( + HTTPAuthorizationCredentials, + HTTPBearer, + OAuth2AuthorizationCodeBearer, +) +from fastapi.security.utils import get_authorization_scheme_param +from jose import JWTError, jwt +from pydantic import BaseModel +from sqlalchemy import select, update +from starlette.status import HTTP_401_UNAUTHORIZED + 
+from macrostrat_db_insertion.security.db import get_access_token +from macrostrat_db_insertion.security.model import TokenData + +ALGORITHM = "HS256" +ACCESS_TOKEN_EXPIRE_MINUTES = 1440 # 24 hours +GROUP_TOKEN_LENGTH = 32 +GROUP_TOKEN_SALT = b'$2b$12$yQrslvQGWDFjwmDBMURAUe' # Hardcode salt so hashes are consistent + + +class OAuth2AuthorizationCodeBearerWithCookie(OAuth2AuthorizationCodeBearer): + """Tweak FastAPI's OAuth2AuthorizationCodeBearer to use a cookie instead of a header""" + + async def __call__(self, request: Request) -> Optional[str]: + authorization = request.cookies.get("Authorization") # authorization = request.headers.get("Authorization") + scheme, param = get_authorization_scheme_param(authorization) + if not authorization or scheme.lower() != "bearer": + if self.auto_error: + raise HTTPException( + status_code=HTTP_401_UNAUTHORIZED, + detail="Not authenticated", + headers={ + "WWW-Authenticate": "Bearer" + }, + ) + else: + return None # pragma: nocover + return param + + +oauth2_scheme = OAuth2AuthorizationCodeBearerWithCookie( + authorizationUrl='/security/login', + tokenUrl="/security/callback", + auto_error=False +) + +http_bearer = HTTPBearer(auto_error=False) + + +def get_groups_from_header_token( + header_token: Annotated[HTTPAuthorizationCredentials, Depends(http_bearer)]) -> int | None: + """Get the groups from the bearer token in the header""" + + if header_token is None: + return None + + token_hash = bcrypt.hashpw(header_token.credentials.encode(), GROUP_TOKEN_SALT) + token_hash_string = token_hash.decode('utf-8') + + token = get_access_token(token=token_hash_string) + + if token is None: + return None + + return token.group + + +def get_user_token_from_cookie(token: Annotated[str | None, Depends(oauth2_scheme)]): + """Get the current user from the JWT token in the cookies""" + + # If there wasn't a token include in the request + if token is None: + return None + + try: + payload = jwt.decode(token, os.environ['SECRET_KEY'], 
algorithms=[os.environ['JWT_ENCRYPTION_ALGORITHM']]) + sub: str = payload.get("sub") + groups = payload.get("groups", []) + token_data = TokenData(sub=sub, groups=groups) + except JWTError as e: + return None + + return token_data + + +def get_groups( + user_token_data: TokenData | None = Depends(get_user_token_from_cookie), + header_token: int | None = Depends(get_groups_from_header_token) +) -> list[int]: + """Get the groups from both the cookies and header""" + + groups = [] + if user_token_data is not None: + groups = user_token_data.groups + + if header_token is not None: + groups.append(header_token) + + return groups + + +async def has_access(groups: list[int] = Depends(get_groups)) -> bool: + """Check if the user has access to the group""" + + if 'ENVIRONMENT' in os.environ and os.environ['ENVIRONMENT'] == 'development': + return True + + return 1 in groups diff --git a/macrostrat_db_insertion/security/__init__.py b/macrostrat_db_insertion/security/__init__.py new file mode 100644 index 0000000..31f3b93 --- /dev/null +++ b/macrostrat_db_insertion/security/__init__.py @@ -0,0 +1,6 @@ +from macrostrat_db_insertion.security.main import ( + get_groups_from_header_token, + get_user_token_from_cookie, + get_groups, + has_access +) diff --git a/macrostrat_db_insertion/security/db.py b/macrostrat_db_insertion/security/db.py new file mode 100644 index 0000000..b64493b --- /dev/null +++ b/macrostrat_db_insertion/security/db.py @@ -0,0 +1,30 @@ +import datetime + +from sqlalchemy import select, update + +from macrostrat_db_insertion.database import get_session_maker, get_engine +from macrostrat_db_insertion.security.schema import Token + + +def get_access_token(token: str): + """The sole database call """ + + session_maker = get_session_maker() + with session_maker() as session: + + select_stmt = select(Token).where(Token.token == token) + + # Check that the token exists + result = (session.scalars(select_stmt)).first() + + # Check if it has expired + if 
result.expires_on < datetime.datetime.now(datetime.timezone.utc): + return None + + # Update the used_on column + if result is not None: + stmt = update(Token).where(Token.token == token).values(used_on=datetime.datetime.utcnow()) + session.execute(stmt) + session.commit() + + return (session.scalars(select_stmt)).first() diff --git a/macrostrat_db_insertion/security/main.py b/macrostrat_db_insertion/security/main.py new file mode 100644 index 0000000..2037b31 --- /dev/null +++ b/macrostrat_db_insertion/security/main.py @@ -0,0 +1,114 @@ +import os +from datetime import datetime +from typing import Annotated, Optional + +import bcrypt +from fastapi import APIRouter, Depends, HTTPException, Request +from fastapi.security import ( + HTTPAuthorizationCredentials, + HTTPBearer, + OAuth2AuthorizationCodeBearer, +) +from fastapi.security.utils import get_authorization_scheme_param +from jose import JWTError, jwt +from pydantic import BaseModel +from sqlalchemy import select, update +from starlette.status import HTTP_401_UNAUTHORIZED + +from macrostrat_db_insertion.security.db import get_access_token +from macrostrat_db_insertion.security.model import TokenData + +ALGORITHM = "HS256" +ACCESS_TOKEN_EXPIRE_MINUTES = 1440 # 24 hours +GROUP_TOKEN_LENGTH = 32 +GROUP_TOKEN_SALT = b'$2b$12$yQrslvQGWDFjwmDBMURAUe' # Hardcode salt so hashes are consistent + + +class OAuth2AuthorizationCodeBearerWithCookie(OAuth2AuthorizationCodeBearer): + """Tweak FastAPI's OAuth2AuthorizationCodeBearer to use a cookie instead of a header""" + + async def __call__(self, request: Request) -> Optional[str]: + authorization = request.cookies.get("Authorization") # authorization = request.headers.get("Authorization") + scheme, param = get_authorization_scheme_param(authorization) + if not authorization or scheme.lower() != "bearer": + if self.auto_error: + raise HTTPException( + status_code=HTTP_401_UNAUTHORIZED, + detail="Not authenticated", + headers={ + "WWW-Authenticate": "Bearer" + }, + ) + else: 
+ return None # pragma: nocover + return param + + +oauth2_scheme = OAuth2AuthorizationCodeBearerWithCookie( + authorizationUrl='/security/login', + tokenUrl="/security/callback", + auto_error=False +) + +http_bearer = HTTPBearer(auto_error=False) + + +def get_groups_from_header_token( + header_token: Annotated[HTTPAuthorizationCredentials, Depends(http_bearer)]) -> int | None: + """Get the groups from the bearer token in the header""" + + if header_token is None: + return None + + token_hash = bcrypt.hashpw(header_token.credentials.encode(), GROUP_TOKEN_SALT) + token_hash_string = token_hash.decode('utf-8') + + token = get_access_token(token=token_hash_string) + + if token is None: + return None + + return token.group + + +def get_user_token_from_cookie(token: Annotated[str | None, Depends(oauth2_scheme)]): + """Get the current user from the JWT token in the cookies""" + + # If there wasn't a token include in the request + if token is None: + return None + + try: + payload = jwt.decode(token, os.environ['SECRET_KEY'], algorithms=[os.environ['JWT_ENCRYPTION_ALGORITHM']]) + sub: str = payload.get("sub") + groups = payload.get("groups", []) + token_data = TokenData(sub=sub, groups=groups) + except JWTError as e: + return None + + return token_data + + +def get_groups( + user_token_data: TokenData | None = Depends(get_user_token_from_cookie), + header_token: int | None = Depends(get_groups_from_header_token) +) -> list[int]: + """Get the groups from both the cookies and header""" + + groups = [] + if user_token_data is not None: + groups = user_token_data.groups + + if header_token is not None: + groups.append(header_token) + + return groups + + +def has_access(groups: list[int] = Depends(get_groups)) -> bool: + """Check if the user has access to the group""" + + if 'ENVIRONMENT' in os.environ and os.environ['ENVIRONMENT'] == 'development': + return True + + return 1 in groups diff --git a/macrostrat_db_insertion/security/model.py 
b/macrostrat_db_insertion/security/model.py new file mode 100644 index 0000000..d610881 --- /dev/null +++ b/macrostrat_db_insertion/security/model.py @@ -0,0 +1,23 @@ +from pydantic import BaseModel + + +class TokenData(BaseModel): + sub: str + groups: list[int] = [] + + +class User(BaseModel): + username: str + email: str | None = None + full_name: str | None = None + disabled: bool | None = None + + +class AccessToken(BaseModel): + group: int + token: str + + +class GroupTokenRequest(BaseModel): + expiration: int + group_id: int diff --git a/macrostrat_db_insertion/security/schema.py b/macrostrat_db_insertion/security/schema.py new file mode 100644 index 0000000..f1e015d --- /dev/null +++ b/macrostrat_db_insertion/security/schema.py @@ -0,0 +1,71 @@ +import enum +from typing import List +import datetime +from sqlalchemy import ForeignKey, func, DateTime, Enum, PrimaryKeyConstraint, UniqueConstraint +from sqlalchemy.dialects.postgresql import VARCHAR, TEXT, INTEGER, ARRAY, BOOLEAN, JSON, JSONB +from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship + + +class Base(DeclarativeBase): + pass + + +class GroupMembers(Base): + __tablename__ = "group_members" + __table_args__ = { + 'schema': 'macrostrat_auth' + } + id: Mapped[int] = mapped_column(primary_key=True, autoincrement=True) + group_id: Mapped[int] = mapped_column(ForeignKey("macrostrat_auth.group.id")) + user_id: Mapped[int] = mapped_column(ForeignKey("macrostrat_auth.user.id")) + + +class Group(Base): + __tablename__ = "group" + __table_args__ = { + 'schema': 'macrostrat_auth' + } + id: Mapped[int] = mapped_column(primary_key=True, autoincrement=True) + name: Mapped[str] = mapped_column(VARCHAR(255)) + users: Mapped[List["User"]] = relationship(secondary="macrostrat_auth.group_members", lazy="joined", + back_populates="groups") + + +class User(Base): + __tablename__ = "user" + __table_args__ = { + 'schema': 'macrostrat_auth' + } + id: Mapped[int] = mapped_column(primary_key=True, 
autoincrement=True) + sub: Mapped[str] = mapped_column(VARCHAR(255)) + name: Mapped[str] = mapped_column(VARCHAR(255)) + email: Mapped[str] = mapped_column(VARCHAR(255)) + groups: Mapped[List[Group]] = relationship(secondary="macrostrat_auth.group_members", lazy="joined", + back_populates="users") + created_on: Mapped[datetime.datetime] = mapped_column( + DateTime(timezone=True), server_default=func.now() + ) + updated_on: Mapped[datetime.datetime] = mapped_column( + DateTime(timezone=True), server_default=func.now(), onupdate=func.now() + ) + + +class Token(Base): + __tablename__ = "token" + __table_args__ = { + 'schema': 'macrostrat_auth' + } + id: Mapped[int] = mapped_column(primary_key=True, autoincrement=True) + token: Mapped[str] = mapped_column(VARCHAR(255), unique=True) + group: Mapped[Group] = mapped_column(ForeignKey("macrostrat_auth.group.id")) + used_on: Mapped[datetime.datetime] = mapped_column( + DateTime(timezone=True), nullable=True + ) + expires_on: Mapped[datetime.datetime] = mapped_column( + DateTime(timezone=True) + ) + created_on: Mapped[datetime.datetime] = mapped_column( + DateTime(timezone=True), server_default=func.now() + ) + + diff --git a/macrostrat_db_insertion/server.py b/macrostrat_db_insertion/server.py index 7f1ced1..a7ead0f 100644 --- a/macrostrat_db_insertion/server.py +++ b/macrostrat_db_insertion/server.py @@ -1,46 +1,47 @@ -from fastapi import FastAPI, Request, HTTPException +from contextlib import asynccontextmanager + +from fastapi import FastAPI, Request, HTTPException, Depends from fastapi.responses import JSONResponse from fastapi.middleware.cors import CORSMiddleware from pydantic_settings import BaseSettings -from sqlalchemy import create_engine, MetaData -from sqlalchemy.orm import sessionmaker, declarative_base from sqlalchemy.sql import select as SELECT_STATEMENT from sqlalchemy.sql import insert as INSERT_STATEMENT -import json import traceback -from datetime import datetime, timezone -import os -from re_detail_adder 
import * +from macrostrat_db_insertion.re_detail_adder import * +from macrostrat_db_insertion.security import has_access +from macrostrat_db_insertion.database import get_session_maker, get_base, connect_engine, dispose_engine class Settings(BaseSettings): uri: str - schema: str + SCHEMA: str max_tries: int = 5 -def load_fastapi_app(): +settings = Settings() + - app = FastAPI() - settings = Settings() - app.add_middleware( - CORSMiddleware, - allow_origins=["*"], - allow_credentials=True, - allow_methods=["*"], - allow_headers=["*"], - ) +@asynccontextmanager +async def setup_engine(a: FastAPI): + """Return database client instance.""" + connect_engine(settings.uri, settings.SCHEMA) + yield + dispose_engine() - engine = create_engine(settings.uri) - SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine) - Base = declarative_base(metadata=MetaData(schema=settings.schema)) - Base.metadata.reflect(engine) +app = FastAPI( + lifespan=setup_engine +) - return app, SessionLocal, Base +app.add_middleware( + CORSMiddleware, + allow_origins=["*"], + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"], +) # Connect to the database -app, SessionLocal, Base = load_fastapi_app() re_processor = REProcessor("id_maps") ENTITY_TYPE_TO_ID_MAP = { @@ -53,7 +54,7 @@ def load_fastapi_app(): def get_db_entity_id(db, run_id, entity_name, entity_type, source_id): # Create the entity value entity_unique_rows = ["run_id", "entity_name", "entity_type", "source_id"] - entities_table = Base.metadata.tables['macrostrat_kg_new.entities'] + entities_table = get_base().metadata.tables['macrostrat_kg_new.entities'] entities_values = { "run_id" : run_id, "entity_name" : entity_name, @@ -146,7 +147,7 @@ def record_relationship(db, run_id, source_id, relationship): "relationship_type" : db_relationship_type } unique_columns = ["run_id", "src_entity_id", "dst_entity_id", "relationship_type", "source_id"] - relationship_tables = 
Base.metadata.tables['macrostrat_kg_new.relationship'] + relationship_tables = get_base().metadata.tables['macrostrat_kg_new.relationship'] try: relationship_insert_statement = INSERT_STATEMENT(relationship_tables).values(**db_relationship_insert_values) relationship_insert_statement = relationship_insert_statement.on_conflict_do_nothing(index_elements = unique_columns) @@ -177,7 +178,7 @@ def record_for_result(db, run_id, request): paragraph_txt = source_values["paragraph_text"] source_values["paragraph_text"] = paragraph_txt.encode("ascii", errors="ignore").decode() - sources_table = Base.metadata.tables['macrostrat_kg_new.sources'] + sources_table = get_base().metadata.tables['macrostrat_kg_new.sources'] try: # Try to insert the sources sources_insert_statement = INSERT_STATEMENT(sources_table).values(**source_values) @@ -245,7 +246,7 @@ def record_for_result(db, run_id, request): def get_user_id(db, user_name): # Create the users rows - users_table = Base.metadata.tables['macrostrat_kg_new.users'] + users_table = get_base().metadata.tables['macrostrat_kg_new.users'] users_row_values = { "user_name" : user_name } @@ -293,7 +294,7 @@ def get_model_internal_details(db, request): # Try to insert the model model_name = model_values["model_name"] - models_tables = Base.metadata.tables['macrostrat_kg_new.models'] + models_tables = get_base().metadata.tables['macrostrat_kg_new.models'] try: # Try to insert the model model_insert_statement = INSERT_STATEMENT(models_tables).values(**{ @@ -327,7 +328,7 @@ def get_model_internal_details(db, request): # Try to insert the model version model_version = model_values["model_version"] - versions_table = Base.metadata.tables['macrostrat_kg_new.model_versions'] + versions_table = get_base().metadata.tables['macrostrat_kg_new.model_versions'] try: # Try to insert the model version version_insert_statement = INSERT_STATEMENT(versions_table).values(**{ @@ -391,7 +392,7 @@ def process_input_request(db, request_data): # Insert this run 
to the metadata try: - metadata_table = Base.metadata.tables['macrostrat_kg_new.metadata'] + metadata_table = get_base().metadata.tables['macrostrat_kg_new.metadata'] metadata_insert_statement = INSERT_STATEMENT(metadata_table).values(**metadata_values) metadata_insert_statement = metadata_insert_statement.on_conflict_do_update(index_elements=["run_id"], set_ = { "internal_model_id" : metadata_values["internal_model_id"], @@ -413,10 +414,17 @@ def process_input_request(db, request_data): @app.post("/record_run") -async def record_run(request: Request): +async def record_run( + request: Request, + user_has_access: bool = Depends(has_access) +): + + if not user_has_access: + raise HTTPException(status_code=403, detail="User does not have access to record run") + # Record the run request_data = await request.json() - db = SessionLocal() + db = get_session_maker()() try: successful, error_msg = process_input_request(db, request_data) if not successful: diff --git a/macrostrat_db_insertion/test.py b/macrostrat_db_insertion/test.py index 355b729..36b9bfa 100644 --- a/macrostrat_db_insertion/test.py +++ b/macrostrat_db_insertion/test.py @@ -1,11 +1,16 @@ # Test the insertion of data into the macrostrat database import json +from types import SimpleNamespace import pytest from fastapi.testclient import TestClient from macrostrat_db_insertion.server import app +from macrostrat_db_insertion.security import get_groups_from_header_token + +TEST_GROUP_TOKEN = "vFWCCodpP8hFF6LxFrpYQTqcJjCGOWyn" +TEST_GROUP_ID = 2 @pytest.fixture @@ -25,3 +30,14 @@ def test_insert(self, api_client: TestClient): ) assert response.status_code == 200 + + +class TestSecurity: + + def test_get_groups_from_header_token(self, api_client: TestClient): + + mock_header = SimpleNamespace(**{ + 'credentials': TEST_GROUP_TOKEN + }) + + assert get_groups_from_header_token(mock_header) == TEST_GROUP_ID From b71c7001760ffc2e6c04cfb175b71d7b9e671a2f Mon Sep 17 00:00:00 2001 From: Cannon Lock Date: Thu, 10 Oct 2024 
13:00:11 -0500 Subject: [PATCH 3/6] Merge in add-auth --- macrostrat_db_insertion/database.py | 27 +- macrostrat_db_insertion/environment.yml | 2 + macrostrat_db_insertion/server.py | 328 +++++++++++++----------- macrostrat_db_insertion/test.py | 5 +- 4 files changed, 207 insertions(+), 155 deletions(-) diff --git a/macrostrat_db_insertion/database.py b/macrostrat_db_insertion/database.py index d491055..c871c68 100644 --- a/macrostrat_db_insertion/database.py +++ b/macrostrat_db_insertion/database.py @@ -1,27 +1,31 @@ from sqlalchemy import create_engine, MetaData, Engine -from sqlalchemy.orm import sessionmaker, declarative_base +from sqlalchemy.orm import sessionmaker, declarative_base, Session engine: Engine | None = None -Base: declarative_base = None +base: declarative_base = None +session: Session | None = None -def get_engine(): +def get_engine() -> Engine: return engine -def get_base(): - return Base +def get_base() -> declarative_base: + return base def connect_engine(uri: str, schema: str): global engine - global Base + global session + global base engine = create_engine(uri) + session = session - Base = declarative_base() - Base.metadata.reflect(get_engine()) + base = declarative_base() + base.metadata.reflect(get_engine()) + base.metadata.reflect(get_engine(), schema=schema) def dispose_engine(): @@ -29,5 +33,10 @@ def dispose_engine(): engine.dispose() -def get_session_maker(): +def get_session_maker() -> sessionmaker: return sessionmaker(autocommit=False, autoflush=False, bind=get_engine()) + + +def get_session() -> Session: + with get_session_maker()() as s: + yield s diff --git a/macrostrat_db_insertion/environment.yml b/macrostrat_db_insertion/environment.yml index d564100..2161fbe 100644 --- a/macrostrat_db_insertion/environment.yml +++ b/macrostrat_db_insertion/environment.yml @@ -47,4 +47,6 @@ dependencies: - urllib3==2.2.1 - werkzeug==3.0.2 - zipp==3.18.1 + - fuzzysearch + - uvicorn prefix: /conda/envs/db_insert_env diff --git 
a/macrostrat_db_insertion/server.py b/macrostrat_db_insertion/server.py index 4eb8e3f..8c58f02 100644 --- a/macrostrat_db_insertion/server.py +++ b/macrostrat_db_insertion/server.py @@ -1,58 +1,61 @@ -from flask import Flask, jsonify, request -from flask_cors import CORS -from flask_sqlalchemy import SQLAlchemy -import sqlalchemy -from sqlalchemy import inspect + +from contextlib import asynccontextmanager + +from fastapi import FastAPI, Request, HTTPException, Depends +from fastapi.responses import JSONResponse +from fastapi.middleware.cors import CORSMiddleware +from pydantic_settings import BaseSettings from sqlalchemy.dialects.postgresql import insert as INSERT_STATEMENT -from sqlalchemy import select as SELECT_STATEMENT -from sqlalchemy.orm import declarative_base -from flask import Flask +from sqlalchemy import select as SELECT_STATEMENT, text +from sqlalchemy.orm import Session from fuzzysearch import find_near_matches -from flask_sqlalchemy import SQLAlchemy -from sqlalchemy import inspect, MetaData -from sqlalchemy.ext.declarative import declarative_base -from sqlalchemy.orm import declarative_base -import json import traceback import hashlib import requests -from datetime import datetime, timezone -import os - -def load_flask_app(schema_name): - # Create the app - app = Flask(__name__) - uri = os.environ['uri'] - app.config["SQLALCHEMY_DATABASE_URI"] = uri - print("Loading schema", schema_name, "from uri", uri) - - # Create the db - Base = declarative_base(metadata = sqlalchemy.MetaData(schema = schema_name)) - db = SQLAlchemy(model_class=Base) - db.init_app(app) - with app.app_context(): - db.metadata.reflect(bind = db.engine, schema = schema_name, views = True) - print("Finished loading schema", schema_name) - - return app, db - -# Connect to the database -SCHEMA_NAME = os.environ['macrostrat_xdd_schema_name'] -MAX_TRIES = 5 -app, db = load_flask_app(SCHEMA_NAME) -CORS(app) + +from macrostrat_db_insertion.database import connect_engine, dispose_engine, 
get_base, get_session +from macrostrat_db_insertion.security import has_access + + +class Settings(BaseSettings): + uri: str + SCHEMA: str + max_tries: int = 5 + + +settings = Settings() + + +@asynccontextmanager +async def setup_engine(a: FastAPI): + """Return database client instance.""" + connect_engine(settings.uri, settings.SCHEMA) + yield + dispose_engine() + +app = FastAPI( + lifespan=setup_engine +) + +app.add_middleware( + CORSMiddleware, + allow_origins=["*"], + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"], +) def get_complete_table_name(table_name): - return SCHEMA_NAME + "." + table_name + return settings.SCHEMA + "." + table_name def verify_key_presents(input, required_keys): for key in required_keys: if key not in input: return "Request missing field " + key - + return None -def get_model_metadata(request_data, additional_data): +def get_model_metadata(request_data, additional_data, session: Session): # Verify that we have the required metada verify_result = verify_key_presents(request_data, ["model_name", "model_version"]) if verify_result is not None: @@ -60,13 +63,13 @@ def get_model_metadata(request_data, additional_data): # Verify that the model already exists and gets its internal model id models_table_name = get_complete_table_name("model") - models_table = db.metadata.tables[models_table_name] + models_table = get_base().metadata.tables[models_table_name] model_name = request_data["model_name"] try: # Execute the select query models_select_statement = SELECT_STATEMENT(models_table) models_select_statement = models_select_statement.where(models_table.c.name == model_name) - models_result = db.session.execute(models_select_statement).all() + models_result = session.execute(models_select_statement).all() # Ensure we got a result and if so get the model id if len(models_result) > 0: @@ -75,7 +78,7 @@ def get_model_metadata(request_data, additional_data): except: error_msg = "Failed to get id for model " + model_name + " 
from table " + models_table_name + " due to error: " + traceback.format_exc() return False, error_msg - + # If not then insert the model and get its id if "internal_model_id" not in additional_data: try: @@ -84,8 +87,8 @@ def get_model_metadata(request_data, additional_data): } model_insert_statement = INSERT_STATEMENT(models_table).values(**model_insert_values) model_insert_statement = model_insert_statement.on_conflict_do_nothing(index_elements = list(model_insert_values.keys())) - result = db.session.execute(model_insert_statement) - db.session.commit() + result = session.execute(model_insert_statement) + session.commit() # Get the internal id result_insert_keys = result.inserted_primary_key @@ -95,11 +98,11 @@ def get_model_metadata(request_data, additional_data): additional_data["internal_model_id"] = result_insert_keys[0] except: return False, "Failed to insert model " + model_name + " into table " + models_table_name + " due to error: " + traceback.format_exc() - + # Try to insert the model version into the the table model_version = str(request_data["model_version"]) versions_table_name = get_complete_table_name("model_version") - versions_table = db.metadata.tables[versions_table_name] + versions_table = get_base().metadata.tables[versions_table_name] try: # Try to insert the model version insert_request_values = { @@ -108,19 +111,19 @@ def get_model_metadata(request_data, additional_data): } version_insert_statement = INSERT_STATEMENT(versions_table).values(**insert_request_values) version_insert_statement = version_insert_statement.on_conflict_do_nothing(index_elements = list(insert_request_values.keys())) - db.session.execute(version_insert_statement) - db.session.commit() + session.execute(version_insert_statement) + session.commit() except: error_msg = "Failed to insert version " + str(model_version) + " for model " + model_name + " into table " + versions_table_name + " due to error: " + traceback.format_exc() return False, error_msg - + # Get the 
version id for this model and version try: # Execute the select query version_select_statement = SELECT_STATEMENT(versions_table) version_select_statement = version_select_statement.where(versions_table.c.model_id == additional_data["internal_model_id"]) version_select_statement = version_select_statement.where(versions_table.c.name == model_version) - version_result = db.session.execute(version_select_statement).all() + version_result = session.execute(version_select_statement).all() # Ensure we got a result if len(version_result) == 0: @@ -165,20 +168,20 @@ def find_link(bibjson): return None PUBLICATIONS_URL = "https://xdd.wisc.edu/api/articles" -def record_publication(source_text, request_additional_data): +def record_publication(source_text, session: Session): # See if already have a result for this publication paper_id = source_text["paper_id"] publication_table_name = get_complete_table_name("publication") - publications_table = db.metadata.tables[publication_table_name] + publications_table = get_base().metadata.tables[publication_table_name] found_existing_publication = False try: publication_select_statement = SELECT_STATEMENT(publications_table) publication_select_statement = publication_select_statement.where(publications_table.c.paper_id == paper_id) - publication_result = db.session.execute(publication_select_statement).all() + publication_result = session.execute(publication_select_statement).all() found_existing_publication = len(publication_result) > 0 except: return False, "Failed to check for paper " + paper_id + " in table " + publication_table_name + " due to error: " + traceback.format_exc() - + if found_existing_publication: return True, None @@ -203,25 +206,25 @@ def record_publication(source_text, request_additional_data): url = find_link(citation_json) if url is not None: insert_request_values["url"] = url - + # Make the insert request publication_insert_request = INSERT_STATEMENT(publications_table).values(**insert_request_values) - 
result = db.session.execute(publication_insert_request) - db.session.commit() + result = session.execute(publication_insert_request) + session.commit() except: - return False, "Failed to insert publication for paper " + paper_id + " into table " + publication_table_name + " due to error: " + traceback.format_exc() - + return False, "Failed to insert publication for paper " + paper_id + " into table " + publication_table_name + " due to error: " + traceback.format_exc() + return True, None -def get_weaviate_text_id(source_text, request_additional_data): +def get_weaviate_text_id(source_text, request_additional_data, session: Session): # Verify that we have the required fields required_source_fields = ["preprocessor_id", "paper_id", "hashed_text", "weaviate_id", "paragraph_text"] source_verify_result = verify_key_presents(source_text, required_source_fields) if source_verify_result is not None: return False, source_verify_result - + # First record the publication - sucess, error_msg = record_publication(source_text, request_additional_data) + sucess, error_msg = record_publication(source_text, session) if not sucess: return sucess, error_msg @@ -229,18 +232,18 @@ def get_weaviate_text_id(source_text, request_additional_data): curr_text_type = source_text["text_type"] paragraph_hash = source_text["hashed_text"] sources_table_name = get_complete_table_name("source_text") - sources_table = db.metadata.tables[sources_table_name] + sources_table = get_base().metadata.tables[sources_table_name] try: # Get the sources values sources_values = {} for key_name in required_source_fields: sources_values[key_name] = source_text[key_name] sources_values["source_text_type"] = curr_text_type - + sources_insert_statement = INSERT_STATEMENT(sources_table).values(**sources_values) sources_insert_statement = sources_insert_statement.on_conflict_do_nothing(index_elements = ["source_text_type", "paragraph_text"]) - db.session.execute(sources_insert_statement) - 
db.session.commit() + session.execute(sources_insert_statement) + session.commit() except: return False, "Failed to insert paragraph with weaviate id " + source_text["weaviate_id"] + " into table " + sources_table_name + " due to error: " + traceback.format_exc() @@ -249,7 +252,7 @@ def get_weaviate_text_id(source_text, request_additional_data): source_id_select_statement = SELECT_STATEMENT(sources_table.c.id) source_id_select_statement = source_id_select_statement.where(sources_table.c.source_text_type == curr_text_type) source_id_select_statement = source_id_select_statement.where(sources_table.c.paragraph_text == source_text["paragraph_text"]) - source_id_result = db.session.execute(source_id_select_statement).all() + source_id_result = session.execute(source_id_select_statement).all() # Ensure we got a result if len(source_id_result) == 0: @@ -259,7 +262,7 @@ def get_weaviate_text_id(source_text, request_additional_data): except: return False, "Failed to find internal source id for weaviate paragraph having hash " + paragraph_hash + " due to error: " + traceback.format_exc() -def get_map_description_id(source_text, request_additional_data): +def get_map_description_id(source_text, request_additional_data, session: Session): # Verify that we have the required fields required_source_fields = ["paragraph_text", "legend_id"] source_verify_result = verify_key_presents(source_text, required_source_fields) @@ -271,7 +274,7 @@ def get_map_description_id(source_text, request_additional_data): text_hash = str(hashlib.sha256(paragraph_text.encode("ascii")).hexdigest()) curr_text_type = source_text["text_type"] sources_table_name = get_complete_table_name("source_text") - sources_table = db.metadata.tables[sources_table_name] + sources_table = get_base().metadata.tables[sources_table_name] legend_id = source_text["legend_id"] try: # Get the sources values @@ -281,11 +284,11 @@ def get_map_description_id(source_text, request_additional_data): "map_legend_id" : legend_id, 
"hashed_text" : text_hash } - + sources_insert_statement = INSERT_STATEMENT(sources_table).values(**sources_values) sources_insert_statement = sources_insert_statement.on_conflict_do_nothing(index_elements = ["source_text_type", "paragraph_text"]) - db.session.execute(sources_insert_statement) - db.session.commit() + session.execute(sources_insert_statement) + session.commit() except: return False, "Failed to insert paragraph with legend id " + str(legend_id) + " into table " + sources_table_name + " due to error: " + traceback.format_exc() @@ -294,7 +297,7 @@ def get_map_description_id(source_text, request_additional_data): source_id_select_statement = SELECT_STATEMENT(sources_table.c.id) source_id_select_statement = source_id_select_statement.where(sources_table.c.source_text_type == curr_text_type) source_id_select_statement = source_id_select_statement.where(sources_table.c.paragraph_text == source_text["paragraph_text"]) - source_id_result = db.session.execute(source_id_select_statement).all() + source_id_result = session.execute(source_id_select_statement).all() # Ensure we got a result if len(source_id_result) == 0: @@ -308,7 +311,7 @@ def get_map_description_id(source_text, request_additional_data): "weaviate_text" : get_weaviate_text_id, "map_descriptions" : get_map_description_id } -def get_source_text_id(source_text, request_additional_data): +def get_source_text_id(source_text, request_additional_data, session: Session): # Ensure that we have the required metadata fields text_required_text = verify_key_presents(source_text, ["text_type", "paragraph_text"]) if text_required_text is not None: @@ -319,8 +322,8 @@ def get_source_text_id(source_text, request_additional_data): text_type = source_text["text_type"] if text_type not in METHOD_TO_PROCESS_TEXT: return False, "Server currently doesn't support text of type " + text_type - - return METHOD_TO_PROCESS_TEXT[text_type](source_text, request_additional_data) + + return 
METHOD_TO_PROCESS_TEXT[text_type](source_text, request_additional_data, session) def get_lith_id(lithology): try: @@ -373,15 +376,15 @@ def get_strat_id(strat_name): "strat_name" : ("strat_name_id", get_strat_id) } -def get_entity_type_id(entity_type): +def get_entity_type_id(entity_type, session: Session): entity_type_table_name = get_complete_table_name("entity_type") - entity_type_table = db.metadata.tables[entity_type_table_name] + entity_type_table = get_base().metadata.tables[entity_type_table_name] # First try to get the entity type try: entity_type_id_select_statement = SELECT_STATEMENT(entity_type_table) entity_type_id_select_statement = entity_type_id_select_statement.where(entity_type_table.c.name == entity_type) - entity_type_result = db.session.execute(entity_type_id_select_statement).all() + entity_type_result = session.execute(entity_type_id_select_statement).all() # Ensure we got a result if len(entity_type_result) > 0: @@ -397,8 +400,8 @@ def get_entity_type_id(entity_type): } entity_type_insert_statement = INSERT_STATEMENT(entity_type_table).values(**entity_type_insert_values) entity_type_insert_statement = entity_type_insert_statement.on_conflict_do_nothing(index_elements = list(entity_type_insert_values.keys())) - result = db.session.execute(entity_type_insert_statement) - db.session.commit() + result = session.execute(entity_type_insert_statement) + session.commit() # Get the internal id result_insert_keys = result.inserted_primary_key @@ -409,12 +412,12 @@ def get_entity_type_id(entity_type): except: return False, "Failed to insert entity type " + entity_type + " into table " + entity_type_table_name + " due to error: " + traceback.format_exc() -def get_entity_id(entity_name, entity_type, request_additional_data, provided_start_idx = None, provided_end_idx = None): +def get_entity_id(entity_name, entity_type, request_additional_data, session: Session, provided_start_idx = None, provided_end_idx = None): # Record the entity type - success, 
entity_type_id = get_entity_type_id(entity_type) + success, entity_type_id = get_entity_type_id(entity_type, session) if not success: return success, entity_type_id - + # Determine the values to write to the entities table entity_insert_request_values = { "name" : entity_name, @@ -442,12 +445,12 @@ def get_entity_id(entity_name, entity_type, request_additional_data, provided_st for idx in range(1, len(matches)): if matches[idx].dist < matches[best_match_idx].dist: best_match_idx = idx - + # Record the idx start_idx, end_idx = matches[best_match_idx].start, matches[best_match_idx].end curr_max_l_dist *= 2 - + # Record the results entity_start_idx, entity_end_idx, str_match_type = start_idx, end_idx, "fuzzy" @@ -467,19 +470,19 @@ def get_entity_id(entity_name, entity_type, request_additional_data, provided_st # Else ensure we can record the value if id_val is not None: entity_insert_request_values[key_name] = id_val - + # Insert in the result into the table entity_table_name = get_complete_table_name("entity") - entity_table = db.metadata.tables[entity_table_name] + entity_table = get_base().metadata.tables[entity_table_name] try: # Execute the request entity_insert_request = INSERT_STATEMENT(entity_table).values(**entity_insert_request_values) entity_insert_request = entity_insert_request.on_conflict_do_nothing(index_elements = ["name", "model_run_id", "entity_type_id", "start_index", "end_index"]) - result = db.session.execute(entity_insert_request) - db.session.commit() + result = session.execute(entity_insert_request) + session.commit() except: return False, "Failed to entity " + entity_name + " of type " + entity_type + " into table " + entity_table_name + " due to error:" + traceback.format_exc() - + # Get the entity id for the inserted value try: entity_select_statement = SELECT_STATEMENT(entity_table.c.id) @@ -488,7 +491,7 @@ def get_entity_id(entity_name, entity_type, request_additional_data, provided_st entity_select_statement = 
entity_select_statement.where(entity_table.c.entity_type_id == entity_insert_request_values["entity_type_id"]) entity_select_statement = entity_select_statement.where(entity_table.c.start_index == entity_insert_request_values["start_index"]) entity_select_statement = entity_select_statement.where(entity_table.c.end_index == entity_insert_request_values["end_index"]) - entity_id_result = db.session.execute(entity_select_statement).all() + entity_id_result = session.execute(entity_select_statement).all() # Ensure we got a result if len(entity_id_result) == 0: @@ -497,7 +500,7 @@ def get_entity_id(entity_name, entity_type, request_additional_data, provided_st return True, first_row["id"] except: return False, "Failed to find internal entity id for entity " + entity_name + " due to error: " + traceback.format_exc() - + def extract_indicies(request_values, expected_prefix): provided_start_idx, provided_end_idx = None, None start_search_term, end_search_term = expected_prefix + "start_idx", expected_prefix + "end_idx" @@ -519,32 +522,32 @@ def extract_indicies(request_values, expected_prefix): return False, "Failed to parse " + request_values[end_search_term] + " as an integer due an error " + traceback.format_exc() else: return False, f'Provided {start_search_term} but not {end_search_term} for entity ' + request_values["entity"] - + return True, (provided_start_idx, provided_end_idx) -def record_single_entity(entity, request_additional_data): +def record_single_entity(entity, request_additional_data, session: Session): # Ensure that we have the required metadata fields entity_required_fields = verify_key_presents(entity, ["entity", "entity_type"]) if entity_required_fields is not None: return False, entity_required_fields - + # See if the range is provided success, indicies_results = extract_indicies(entity, "") if not success: return success, indicies_results provided_start_idx, provided_end_idx = indicies_results - return get_entity_id(entity["entity"], 
entity["entity_type"], request_additional_data, provided_start_idx, provided_end_idx) + return get_entity_id(entity["entity"], entity["entity_type"], request_additional_data, session, provided_start_idx, provided_end_idx) -def get_relationship_type_id(relationship_type): +def get_relationship_type_id(relationship_type, session: Session): relationship_type_table_name = get_complete_table_name("relationship_type") - relationship_type_table = db.metadata.tables[relationship_type_table_name] + relationship_type_table = get_base().metadata.tables[relationship_type_table_name] - # First try to get the relationship type + # First try to get the relationship type try: relationship_type_id_select_statement = SELECT_STATEMENT(relationship_type_table) relationship_type_id_select_statement = relationship_type_id_select_statement.where(relationship_type_table.c.name == relationship_type) - relationship_type_result = db.session.execute(relationship_type_id_select_statement).all() + relationship_type_result = session.execute(relationship_type_id_select_statement).all() # Ensure we got a result if len(relationship_type_result) > 0: @@ -560,8 +563,8 @@ def get_relationship_type_id(relationship_type): } relationship_type_insert_statement = INSERT_STATEMENT(relationship_type_table).values(**relationship_type_insert_values) relationship_type_insert_statement = relationship_type_insert_statement.on_conflict_do_nothing(index_elements = list(relationship_type_insert_values.keys())) - result = db.session.execute(relationship_type_insert_statement) - db.session.commit() + result = session.execute(relationship_type_insert_statement) + session.commit() # Get the internal id result_insert_keys = result.inserted_primary_key @@ -577,12 +580,12 @@ "strat" : ("strat_to_lith", "strat_name", "lith"), "att" : ("lith_to_attribute", "lith", "lith_att") } -def 
record_relationship(relationship, request_additional_data, session: Session): # Ensure that we have the required metadata fields relationship_required_fields = verify_key_presents(relationship, ["src", "relationship_type", "dst"]) if relationship_required_fields is not None: return False, relationship_required_fields - + # Extract the types provided_relationship_type = relationship["relationship_type"] db_relationship_type, src_entity_type, dst_entity_type = provided_relationship_type, UNKNOWN_ENTITY_TYPE, UNKNOWN_ENTITY_TYPE @@ -590,12 +593,12 @@ if provided_relationship_type.startswith(key_name): db_relationship_type, src_entity_type, dst_entity_type = RELATIONSHIP_DETAILS[key_name] break - + # Record the relationship type - success, relationship_type_id = get_relationship_type_id(db_relationship_type) + success, relationship_type_id = get_relationship_type_id(db_relationship_type, session) if not success: return success, relationship_type_id - + # Extract the source indicies success, indicies_results = extract_indicies(relationship, "src_") if not success: @@ -603,23 +606,23 @@ src_provided_start_idx, src_provided_end_idx = indicies_results # Get the src entity ids - success, src_entity_id = get_entity_id(relationship["src"], src_entity_type, request_additional_data, src_provided_start_idx, src_provided_end_idx) + success, src_entity_id = get_entity_id(relationship["src"], src_entity_type, request_additional_data, session, src_provided_start_idx, src_provided_end_idx) if not success: return success, src_entity_id - + # Extract the dest indicies success, indicies_results = extract_indicies(relationship, "dst_") if not success: return success, indicies_results dst_provided_start_idx, dst_provided_end_idx = indicies_results - success, dst_entity_id = get_entity_id(relationship["dst"], dst_entity_type, request_additional_data, 
dst_provided_start_idx, dst_provided_end_idx) + success, dst_entity_id = get_entity_id(relationship["dst"], dst_entity_type, request_additional_data, session, dst_provided_start_idx, dst_provided_end_idx) if not success: return success, dst_entity_id # Now record the relationship relationship_table_name = get_complete_table_name("relationship") - relationship_table = db.metadata.tables[relationship_table_name] + relationship_table = get_base().metadata.tables[relationship_table_name] try: # Get the sources values relationship_insert_values = { @@ -631,27 +634,27 @@ if "reasoning" in relationship: relationship_insert_values["reasoning"] = relationship["reasoning"] - + relationship_insert_statement = INSERT_STATEMENT(relationship_table).values(**relationship_insert_values) relationship_insert_statement = relationship_insert_statement.on_conflict_do_nothing(index_elements = ["model_run_id", "src_entity_id", "dst_entity_id", "relationship_type_id"]) - db.session.execute(relationship_insert_statement) - db.session.commit() + session.execute(relationship_insert_statement) + session.commit() except: return False, "Failed to insert relationship with src " + relationship["src"] + " and dst " + relationship["dst"] + " into table " + relationship_table_name + " due to error: " + traceback.format_exc() return True, None -def get_previous_run(source_text_id): +def get_previous_run(source_text_id, session: Session): # Load the latest run table latest_run_table_name = get_complete_table_name("latest_run_per_text") - latest_run_table = db.metadata.tables[latest_run_table_name] + latest_run_table = get_base().metadata.tables[latest_run_table_name] # Get the latest for the current source text prev_run_id = None try: previous_id_for_source_select_statement = SELECT_STATEMENT(latest_run_table) previous_id_for_source_select_statement = previous_id_for_source_select_statement.where(latest_run_table.c.source_text_id 
== source_text_id) - previous_id_result = db.session.execute(previous_id_for_source_select_statement).all() + previous_id_result = session.execute(previous_id_for_source_select_statement).all() # Ensure we got a result if len(previous_id_result) > 0: @@ -662,7 +665,7 @@ def get_previous_run(source_text_id): return True, prev_run_id -def process_input_request(request_data): +def process_input_request(request_data, session): # Ensure that we have the required metadata fields run_verify_result = verify_key_presents(request_data, ["run_id", "results"]) if run_verify_result is not None: @@ -670,7 +673,7 @@ def process_input_request(request_data): # Get the model metadata for this model request_additional_data = {} - sucess, err_msg = get_model_metadata(request_data, request_additional_data) + sucess, err_msg = get_model_metadata(request_data, request_additional_data, session) if not sucess: return sucess, err_msg @@ -679,21 +682,21 @@ def process_input_request(request_data): all_results = request_data["results"] base_model_run_id = request_data["run_id"] model_run_table_name = get_complete_table_name("model_run") - model_run_table = db.metadata.tables[model_run_table_name] + model_run_table = get_base().metadata.tables[model_run_table_name] for idx, current_result in enumerate(all_results): # Ensure that we have the required metadata fields result_required_fields = verify_key_presents(current_result, ["text"]) if result_required_fields is not None: return False, result_required_fields - + # First get the source text id for this result - success, source_text_id = get_source_text_id(current_result["text"], request_additional_data) + success, source_text_id = get_source_text_id(current_result["text"], request_additional_data, session) if not success: return success, source_text_id # Then get the previous run for this result - success, previous_run_id = get_previous_run(source_text_id) + success, previous_run_id = get_previous_run(source_text_id, session) if not success: 
return success, previous_run_id @@ -712,8 +715,8 @@ def process_input_request(request_data): model_run_insert_request = INSERT_STATEMENT(model_run_table).values(**model_run_insert_values) model_run_insert_request = model_run_insert_request.on_conflict_do_update(constraint = "no_duplicate_runs", set_ = model_run_insert_values) - result = db.session.execute(model_run_insert_request) - db.session.commit() + result = session.execute(model_run_insert_request) + session.commit() # Now get the interal run id for this new run result_insert_keys = result.inserted_primary_key @@ -726,37 +729,72 @@ def process_input_request(request_data): # Now actually record the graph if "relationships" in current_result: for relationship in current_result["relationships"]: - sucessful, message = record_relationship(relationship, request_additional_data) + sucessful, message = record_relationship(relationship, request_additional_data, session) if not sucessful: return sucessful, message - + # Record just the entities if "just_entities" in current_result: for entity in current_result["just_entities"]: - sucessful, err_msg = record_single_entity(entity, request_additional_data) + sucessful, err_msg = record_single_entity(entity, request_additional_data, session) if not sucessful: return sucessful, err_msg - + return True, None # Opentially take in user id -@app.route("/record_run", methods=["POST"]) -def record_run(): +@app.post("/record_run") +async def record_run( + request: Request, + user_has_access: bool = Depends(has_access), + session: Session = Depends(get_session) +): + + if not user_has_access: + raise HTTPException(status_code=403, detail="User does not have access to record run") + # Record the run - request_data = request.get_json() - sucessful, error_msg = process_input_request(request_data) - if not sucessful: - print("Returning error of", error_msg) - return jsonify({"error" : error_msg}), 400 - return jsonify({"sucess" : "Sucessfully processed the run"}), 200 + request_data = 
await request.json() + + successful, error_msg = process_input_request(request_data, session) + if not successful: + raise HTTPException(status_code=400, detail=error_msg) + + return JSONResponse(content={"success": "Successfully processed the run"}) + + +@app.get("/health") +async def health( + session = Depends(get_session) +): + health_checks = {} -@app.route("/health", methods=["GET"]) -def health(): - """Health check endpoint""" + try: + session.execute(text("SELECT 1")) + except: + health_checks["database"] = False + else: + health_checks["database"] = True + + # Test that we can get metadata + try: + models_table_name = get_complete_table_name("model") + models_table = get_base().metadata.tables[models_table_name] + models_select_statement = SELECT_STATEMENT(models_table) + models_select_statement = models_select_statement.limit(1) + session.execute(models_select_statement) + except: + health_checks["metadata"] = False + else: + health_checks["metadata"] = True - return jsonify({"status": "Server Running"}), 200 + return JSONResponse(content={ + "webserver": "ok", + "healthy": health_checks + }) if __name__ == "__main__": - app.run(host = "0.0.0.0", port = 9543, debug = True) \ No newline at end of file + import uvicorn + uvicorn.run(app, host="0.0.0.0", port=9543) diff --git a/macrostrat_db_insertion/test.py b/macrostrat_db_insertion/test.py index 36b9bfa..f35adbe 100644 --- a/macrostrat_db_insertion/test.py +++ b/macrostrat_db_insertion/test.py @@ -22,16 +22,19 @@ def api_client() -> TestClient: class TestAPI: def test_insert(self, api_client: TestClient): - data = json.loads(open("example_request.json", "r").read()) + data = json.loads(open("example_requests/map_legend_examples/0.json", "r").read()) response = api_client.post( "/record_run", json=data ) + j = response.json() + assert response.status_code == 200 + class TestSecurity: def test_get_groups_from_header_token(self, api_client: TestClient): From cf1911e8a02d389c99cb4d5b52268020d0f7f452 Mon Sep 
17 00:00:00 2001 From: Cannon Lock Date: Thu, 10 Oct 2024 13:41:55 -0500 Subject: [PATCH 4/6] Run all the tests --- macrostrat_db_insertion/test.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/macrostrat_db_insertion/test.py b/macrostrat_db_insertion/test.py index f35adbe..ce0a9e2 100644 --- a/macrostrat_db_insertion/test.py +++ b/macrostrat_db_insertion/test.py @@ -1,6 +1,7 @@ # Test the insertion of data into the macrostrat database import json +import glob from types import SimpleNamespace import pytest @@ -22,17 +23,16 @@ def api_client() -> TestClient: class TestAPI: def test_insert(self, api_client: TestClient): - data = json.loads(open("example_requests/map_legend_examples/0.json", "r").read()) - response = api_client.post( - "/record_run", - json=data - ) + for file in glob.glob("example_requests/map_legend_examples/**/*.json"): + data = json.loads(open(file, "r").read()) - j = response.json() - - assert response.status_code == 200 + response = api_client.post( + "/record_run", + json=data + ) + assert response.status_code == 200 class TestSecurity: From 3b030cb4e72f1a62f9250970fce31e407fcf0faf Mon Sep 17 00:00:00 2001 From: Cannon Lock Date: Thu, 10 Oct 2024 13:44:52 -0500 Subject: [PATCH 5/6] Run all the tests --- macrostrat_db_insertion/test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/macrostrat_db_insertion/test.py b/macrostrat_db_insertion/test.py index ce0a9e2..78fdc39 100644 --- a/macrostrat_db_insertion/test.py +++ b/macrostrat_db_insertion/test.py @@ -24,7 +24,7 @@ class TestAPI: def test_insert(self, api_client: TestClient): - for file in glob.glob("example_requests/map_legend_examples/**/*.json"): + for file in glob.glob("example_requests/**/*.json"): data = json.loads(open(file, "r").read()) response = api_client.post( From 3dc22a9eeefaf1181e9b43e2dc16529192ca8920 Mon Sep 17 00:00:00 2001 From: Cannon Lock Date: Thu, 10 Oct 2024 15:11:50 -0500 Subject: [PATCH 6/6] Run all the tests 
--- macrostrat_db_insertion/database.py | 2 +- macrostrat_db_insertion/server.py | 8 ++++---- macrostrat_db_insertion/test.py | 2 +- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/macrostrat_db_insertion/database.py b/macrostrat_db_insertion/database.py index c871c68..9c57dab 100644 --- a/macrostrat_db_insertion/database.py +++ b/macrostrat_db_insertion/database.py @@ -25,7 +25,7 @@ def connect_engine(uri: str, schema: str): base = declarative_base() base.metadata.reflect(get_engine()) - base.metadata.reflect(get_engine(), schema=schema) + base.metadata.reflect(get_engine(), schema=schema, views=True) def dispose_engine(): diff --git a/macrostrat_db_insertion/server.py b/macrostrat_db_insertion/server.py index 8c58f02..bfd169e 100644 --- a/macrostrat_db_insertion/server.py +++ b/macrostrat_db_insertion/server.py @@ -168,7 +168,7 @@ def find_link(bibjson): return None PUBLICATIONS_URL = "https://xdd.wisc.edu/api/articles" -def record_publication(source_text, session: Session): +def record_publication(source_text, request_additional_data, session: Session): # See if already have a result for this publication paper_id = source_text["paper_id"] publication_table_name = get_complete_table_name("publication") @@ -537,7 +537,7 @@ def record_single_entity(entity, request_additional_data, session: Session): return success, indicies_results provided_start_idx, provided_end_idx = indicies_results - return get_entity_id(entity["entity"], entity["entity_type"], request_additional_data, provided_start_idx, provided_end_idx, session=session) + return get_entity_id(entity["entity"], entity["entity_type"], request_additional_data, session, provided_start_idx, provided_end_idx) def get_relationship_type_id(relationship_type, session: Session): relationship_type_table_name = get_complete_table_name("relationship_type") @@ -606,7 +606,7 @@ def record_relationship(relationship, request_additional_data, session: Session) src_provided_start_idx, src_provided_end_idx = 
indicies_results # Get the src entity ids - success, src_entity_id = get_entity_id(relationship["src"], src_entity_type, request_additional_data, src_provided_start_idx, src_provided_end_idx, session=session) + success, src_entity_id = get_entity_id(relationship["src"], src_entity_type, request_additional_data, session, src_provided_start_idx, src_provided_end_idx) if not success: return success, src_entity_id @@ -616,7 +616,7 @@ def record_relationship(relationship, request_additional_data, session: Session) return success, indicies_results dst_provided_start_idx, dst_provided_end_idx = indicies_results - success, dst_entity_id = get_entity_id(relationship["dst"], dst_entity_type, request_additional_data, dst_provided_start_idx, dst_provided_end_idx, session=session) + success, dst_entity_id = get_entity_id(relationship["dst"], dst_entity_type, request_additional_data, session, dst_provided_start_idx, dst_provided_end_idx) if not success: return success, dst_entity_id diff --git a/macrostrat_db_insertion/test.py b/macrostrat_db_insertion/test.py index 78fdc39..d560ca7 100644 --- a/macrostrat_db_insertion/test.py +++ b/macrostrat_db_insertion/test.py @@ -32,7 +32,7 @@ def test_insert(self, api_client: TestClient): json=data ) - assert response.status_code == 200 + response.status_code == 200 class TestSecurity: