diff --git a/macrostrat_db_insertion/server.py b/macrostrat_db_insertion/server.py index c2092d8..efd0270 100644 --- a/macrostrat_db_insertion/server.py +++ b/macrostrat_db_insertion/server.py @@ -285,34 +285,52 @@ def get_strat_id(strat_name): "strat_name" : ("strat_name_id", get_strat_id) } -def record_entity_type(entity_type): - # Record this is an entity type +def get_entity_type_id(entity_type): entity_type_table_name = get_complete_table_name("entity_type") entity_type_table = db.metadata.tables[entity_type_table_name] + + # First try to get the entity type + try: + entity_type_id_select_statement = SELECT_STATEMENT(entity_type_table) + entity_type_id_select_statement = entity_type_id_select_statement.where(entity_type_table.c.name == entity_type) + entity_type_result = db.session.execute(entity_type_id_select_statement).all() + + # Ensure we got a result + if len(entity_type_result) > 0: + first_row = entity_type_result[0]._mapping + return True, first_row["id"] + except: + return False, "Failed to find entity type " + entity_type + " in table " + entity_type_table_name + " due to error: " + traceback.format_exc() + + # Try to insert the entity type try: - # Try to insert the model version entity_type_insert_values = { "name" : entity_type, } entity_type_insert_statement = INSERT_STATEMENT(entity_type_table).values(**entity_type_insert_values) entity_type_insert_statement = entity_type_insert_statement.on_conflict_do_nothing(index_elements = list(entity_type_insert_values.keys())) - db.session.execute(entity_type_insert_statement) + result = db.session.execute(entity_type_insert_statement) db.session.commit() + + # Get the internal id + result_insert_keys = result.inserted_primary_key + if len(result_insert_keys) == 0: + return False, "Insert statement " + str(entity_type_insert_statement) + " returned zero primary keys" + + return True, result_insert_keys[0] except: return False, "Failed to insert entity type " + entity_type + " into table " + entity_type_table_name + " due to error: " + traceback.format_exc() - return True, None - def get_entity_id(entity_name, entity_type, request_additional_data): # Record the entity type - success, err_msg = record_entity_type(entity_type) + success, entity_type_id = get_entity_type_id(entity_type) if not success: - return success, err_msg + return success, entity_type_id # Determine the values to write to the entities table entity_insert_request_values = { "name" : entity_name, - "type" : entity_type, + "entity_type_id" : entity_type_id, "model_run_id" : request_additional_data["internal_run_id"], "source_id" : request_additional_data["internal_source_id"] } @@ -365,24 +383,42 @@ def record_single_entity(entity, request_additional_data): return get_entity_id(entity["entity"], entity["entity_type"], request_additional_data) -def record_relationship_type(relationship_type): - # Record this is an entity type +def get_relationship_type_id(relationship_type): relationship_type_table_name = get_complete_table_name("relationship_type") relationship_type_table = db.metadata.tables[relationship_type_table_name] + + # First try to get the relationship type + try: + relationship_type_id_select_statement = SELECT_STATEMENT(relationship_type_table) + relationship_type_id_select_statement = relationship_type_id_select_statement.where(relationship_type_table.c.name == relationship_type) + relationship_type_result = db.session.execute(relationship_type_id_select_statement).all() + + # Ensure we got a result + if len(relationship_type_result) > 0: + first_row = relationship_type_result[0]._mapping + return True, first_row["id"] + except: + return False, "Failed to find entity type " + relationship_type + " in table " + relationship_type_table_name + " due to error: " + traceback.format_exc() + + # If not try to insert the relationship type into the table try: - # Try to insert the model version relationship_type_insert_values = { "name" : relationship_type, } relationship_type_insert_statement = INSERT_STATEMENT(relationship_type_table).values(**relationship_type_insert_values) relationship_type_insert_statement = relationship_type_insert_statement.on_conflict_do_nothing(index_elements = list(relationship_type_insert_values.keys())) - db.session.execute(relationship_type_insert_statement) + result = db.session.execute(relationship_type_insert_statement) db.session.commit() + + # Get the internal id + result_insert_keys = result.inserted_primary_key + if len(result_insert_keys) == 0: + return False, "Insert statement " + str(relationship_type_insert_statement) + " returned zero primary keys" + + return True, result_insert_keys[0] except: return False, "Failed to insert entity type " + relationship_type + " into table " + relationship_type_table_name + " due to error: " + traceback.format_exc() - return True, None - RELATIONSHIP_DETAILS = { "strat" : ("strat_to_lith", "strat_name", "lith"), "att" : ("lith_to_attribute", "lith", "lith_att") @@ -406,9 +442,9 @@ def record_relationship(relationship, request_additional_data): return True, "" # Record the relationship type - success, err_msg = record_relationship_type(db_relationship_type) + success, relationship_type_id = get_relationship_type_id(db_relationship_type) if not success: - return success, err_msg + return success, relationship_type_id # Get the entity ids success, src_entity_id = get_entity_id(relationship["src"], src_entity_type, request_additional_data) @@ -425,7 +461,7 @@ def record_relationship(relationship, request_additional_data): try: # Get the sources values relationship_insert_values = { - "type" : db_relationship_type, + "relationship_type_id" : relationship_type_id, "model_run_id" : request_additional_data["internal_run_id"], "source_id" : request_additional_data["internal_source_id"], "src_entity_id" : src_entity_id, @@ -532,7 +568,7 @@ def process_input_request(request_data): def record_run(): # Record the run request_data = request.get_json() - print("Got request of", request_data) + print("Got request", request_data) sucessful, error_msg = process_input_request(request_data) if not sucessful: print("Returning error of", error_msg)