Skip to content

Commit

Permalink
Update how database env variables are consumed
Browse files Browse the repository at this point in the history
  • Loading branch information
CannonLock committed Sep 13, 2024
1 parent b5f31d1 commit d0c28d9
Showing 1 changed file with 25 additions and 32 deletions.
57 changes: 25 additions & 32 deletions macrostrat_db_insertion/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,20 +23,13 @@ def load_config():
if env_variable_value is None:
raise Exception("Environment variable " + env_variable_name + " is not set")
config_values[required_name] = env_variable_value

return config_values

def load_flask_app(config):
# Create the app
# Create the app
app = Flask(__name__)
app.config["SQLALCHEMY_DATABASE_URI"] = sqlalchemy.URL.create(
"postgresql",
username = config["username"],
password = config["password"],
host = config["host"],
port = config["port"],
database = config["database"],
)
app.config["SQLALCHEMY_DATABASE_URI"] = os.environ['uri']

# Create the db
Base = declarative_base(metadata = sqlalchemy.MetaData(schema = config["schema"]))
Expand Down Expand Up @@ -86,7 +79,7 @@ def get_db_entity_id(run_id, entity_name, entity_type, source_id):
except:
error_msg = "Failed to insert entity " + entity_name + " for run " + str(run_id) + " due to error: " + traceback.format_exc()
return False, error_msg

# Get this entity id
entity_id = ""
try:
Expand All @@ -100,7 +93,7 @@ def get_db_entity_id(run_id, entity_name, entity_type, source_id):
# Ensure we got a result
if len(entities_result) == 0:
raise Exception("Got zero rows matching query " + str(entities_select_statement))

# Extract the sources id
first_row = entities_result[0]._mapping
entity_id = str(first_row["entity_id"])
Expand Down Expand Up @@ -131,16 +124,16 @@ def record_relationship(run_id, source_id, relationship):
if provided_relationship_type.startswith(key_name):
db_relationship_type, src_entity_type, dst_entity_type = RELATIONSHIP_DETAILS[key_name]
break

# Ignore this type
if len(db_relationship_type) == 0 or len(src_entity_type) == 0 or len(dst_entity_type) == 0:
return True, ""

# Get the entity ids
sucessful, src_entity_id = get_db_entity_id(run_id, relationship_values["src"], src_entity_type, source_id)
if not sucessful:
return sucessful, src_entity_id

sucessful, dst_entity_id = get_db_entity_id(run_id, relationship_values["dst"], dst_entity_type, source_id)
if not sucessful:
return sucessful, dst_entity_id
Expand All @@ -164,22 +157,22 @@ def record_relationship(run_id, source_id, relationship):
error_msg = "Failed to insert relationship type " + str(provided_relationship_type) + " for source " + str(source_id)
error_msg += " for run " + str(run_id) + " due to error: " + traceback.format_exc()
return False, error_msg

return True, ""

def record_for_result(run_id, request):
# Ensure txt exists
if "text" not in request:
return False, "result is missing text field"

source_fields = ["preprocessor_id", "paper_id", "hashed_text", "weaviate_id", "paragraph_text"]
source_values = {"run_id" : run_id}
text_data = request["text"]
for field_name in source_fields:
if field_name not in text_data:
return False, "Request text is missing field " + str(field_name)
source_values[field_name] = text_data[field_name]

# Remove non ascii data from text
paragraph_txt = source_values["paragraph_text"]
source_values["paragraph_text"] = paragraph_txt.encode("ascii", errors="ignore").decode()
Expand All @@ -195,11 +188,11 @@ def record_for_result(run_id, request):
error_msg = "Failed to insert paragraph " + str(source_values["weaviate_id"])
error_msg += " for run " + str(source_values["run_id"]) + " due to error: " + traceback.format_exc()
return False, error_msg

# Deal with case if we have no relationships
if "relationships" not in request:
return True, ""

# Get the sources id
source_id = ""
try:
Expand All @@ -212,15 +205,15 @@ def record_for_result(run_id, request):
# Ensure we got a result
if len(sources_result) == 0:
raise Exception("Got zero rows matching query " + str(sources_select_statement))

# Extract the sources id
first_row = sources_result[0]._mapping
source_id = str(first_row["source_id"])
except:
error_msg = "Failed to get sources id for paragraph " + str(source_values["weaviate_id"])
error_msg += " for run " + str(source_values["run_id"]) + " due to error: " + traceback.format_exc()
return False, error_msg

# Record the relationships
if "relationships" in request:
for relationship in request["relationships"]:
Expand All @@ -235,13 +228,13 @@ def record_for_result(run_id, request):
# Ensure that it has all the required keys
for key in required_entity_keys:
if key not in entity_data:
return False, "Provided just entities missing key " + str(key)
return False, "Provided just entities missing key " + str(key)

# Only record strats
entity_type = entity_data["entity_type"]
if not entity_type.startswith("strat"):
continue

# Record the entity
sucessful, entity_id = get_db_entity_id(run_id, entity_data["entity"], "strat_name", source_id)
if not sucessful:
Expand All @@ -265,7 +258,7 @@ def get_user_id(user_name):
except:
error_msg = "Failed to insert user " + user_name + " due to error: " + traceback.format_exc()
return False, error_msg

# Get this entity id
user_id = ""
try:
Expand All @@ -277,7 +270,7 @@ def get_user_id(user_name):
# Ensure we got a result
if len(users_result) == 0:
raise Exception("Got zero rows matching query " + str(users_select_statement))

# Extract the sources id
first_row = users_result[0]._mapping
user_id = str(first_row["user_id"])
Expand All @@ -295,7 +288,7 @@ def get_model_internal_details(request):
if field not in request:
return False, "Request missing field " + field
model_values[field] = str(request[field])

# Try to insert the model
model_name = model_values["model_name"]
models_tables = db.metadata.tables['macrostrat_kg_new.models']
Expand All @@ -322,14 +315,14 @@ def get_model_internal_details(request):
# Ensure we got a result
if len(models_result) == 0:
raise Exception("Got zero rows matching query " + str(models_select_statement))

# Extract the sources id
first_row = models_result[0]._mapping
data_to_return["internal_model_id"] = str(first_row["model_id"])
except:
error_msg = "Failed to get id for model " + model_name + " due to error: " + traceback.format_exc()
return False, error_msg

# Try to insert the model version
model_version = model_values["model_version"]
versions_table = db.metadata.tables['macrostrat_kg_new.model_versions']
Expand Down Expand Up @@ -357,7 +350,7 @@ def get_model_internal_details(request):
# Ensure we got a result
if len(version_result) == 0:
raise Exception("Got zero rows matching query " + str(version_select_statement))

# Extract the sources id
first_row = version_result[0]._mapping
data_to_return["internal_version_id"] = str(first_row["version_id"])
Expand Down Expand Up @@ -422,8 +415,8 @@ def record_run():
if not sucessful:
print("Returning error of", error_msg)
return jsonify({"error" : error_msg}), 400
return jsonify({"sucess" : "Sucessfully processed the run"}), 200

return jsonify({"sucess" : "Sucessfully processed the run"}), 200

if __name__ == "__main__":
app.run(host = "0.0.0.0", port = 9543, debug = True)

0 comments on commit d0c28d9

Please sign in to comment.