Merge with updated documentation
Alexander-Aghili committed Nov 12, 2023
2 parents 937364c + 4ca6d96 commit 44142b7
Showing 20 changed files with 784 additions and 86 deletions.
23 changes: 23 additions & 0 deletions .github/workflows/pylint.yml
@@ -0,0 +1,23 @@
name: Pylint

on: [push]

jobs:
  build:
    runs-on: ubuntu-latest
    strategy:
      matrix:
        python-version: ["3.8", "3.9", "3.10"]
    steps:
    - uses: actions/checkout@v3
    - name: Set up Python ${{ matrix.python-version }}
      uses: actions/setup-python@v3
      with:
        python-version: ${{ matrix.python-version }}
    - name: Install dependencies
      run: |
        python -m pip install --upgrade pip
        pip install pylint
    - name: Analysing the code with pylint
      run: |
        pylint $(git ls-files '*.py')
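
To reproduce this check locally before pushing, here is a rough Python equivalent of the workflow's lint step, a sketch that assumes git and pylint are installed and on PATH:

import subprocess

# Collect the tracked Python files, mirroring: pylint $(git ls-files '*.py')
files = subprocess.run(
    ["git", "ls-files", "*.py"],
    capture_output=True, text=True, check=True,
).stdout.splitlines()

if files:
    # pylint exits nonzero when it reports issues, so don't raise on failure
    subprocess.run(["pylint", *files], check=False)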
1 change: 1 addition & 0 deletions .gitignore
@@ -3,3 +3,4 @@ input_doc/index.faiss
__pycache__
./lantern
data
google_sheets_credentials.json
2 changes: 1 addition & 1 deletion README.md
@@ -53,4 +53,4 @@ Feel free to contribute to the development of StructHunt by submitting issues, f

## License

This project is licensed under the [MIT License](LICENSE).
65 changes: 61 additions & 4 deletions VectorDatabase.py
@@ -1,6 +1,38 @@
import psycopg2
from fragment import Fragment
from publication import Publication


# Class to represent a publication with attributes id, title, pmc, pubmed, and doi
class Publication:
    id = ""
    title = ""
    pmc = ""
    pubmed = ""
    doi = ""

    def __init__(self, id, title, pmc, pubmed, doi):
        self.id = id  # (DOI) Unique identifier for the publication
        self.title = title  # Title of the publication
        self.pmc = pmc  # PubMed Central (PMC) link
        self.pubmed = pubmed  # PubMed link
        self.doi = doi  # Digital Object Identifier (DOI) link for the publication


# Class to represent a fragment of a publication with attributes id, header, content, and vector
class Fragment:
    # Class variables to store default values for attributes
    id = ""
    header = ""
    content = ""
    vector = ""

    def __init__(self, id, header, content, vector):
        # Constructor to initialize the attributes of the Fragment object
        self.id = id  # (DOI) Unique identifier of the parent publication
        self.header = header  # Header or title of the fragment
        self.content = content  # Content or text of the fragment
        self.vector = vector  # Vector representation of the fragment


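# Editor's illustration (not part of the commit): constructing these records
# directly; every field value below is a placeholder.
#
#   pub = Publication(
#       id="10.1000/example-doi",
#       title="An example publication",
#       pmc="https://www.ncbi.nlm.nih.gov/pmc/articles/PMC0000000/",
#       pubmed="https://pubmed.ncbi.nlm.nih.gov/00000000/",
#       doi="https://doi.org/10.1000/example-doi",
#   )
#   frag = Fragment("10.1000/example-doi", "Methods", "Sample text ...", [0.01, -0.02])
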
# Lantern class that exposes functionality of database to application
@@ -231,14 +263,16 @@ def getAllFragmentsOfPublication(self, id):

"""
Retrieves unread publications from the 'publications' table.
Parameters:
- delete_unread_entries: bool, decides if entries are deleted from the "unread" table
Returns:
- List[Publication], a list of Publication objects representing the unread publications.
Notes:
- Performs a left join between 'publications' and 'unread' tables to retrieve unread publications.
- Clears the 'unread' table after retrieving the unread publications.
"""

def getUnreadPublications(self):
def getUnreadPublications(self, delete_unread_entries=True):
conn = self.conn
cursor = conn.cursor()

@@ -247,7 +281,9 @@ def getUnreadPublications(self):

        publications = cursor.fetchall()

        cursor.execute('DELETE FROM unread;')
        if delete_unread_entries:
            cursor.execute('DELETE FROM unread;')

        conn.commit()
        cursor.close()

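The new delete_unread_entries flag makes it possible to peek at the unread queue without consuming it. A minimal usage sketch, assuming a reachable database behind Lantern:

from VectorDatabase import Lantern

lantern = Lantern()
# Peek without clearing the 'unread' table
pending = lantern.getUnreadPublications(delete_unread_entries=False)
print(f"{len(pending)} unread publications")
# The default still consumes the queue, preserving the old behavior
processed = lantern.getUnreadPublications()
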
@@ -281,3 +317,24 @@ def publicationExists(self, id):
        cursor.close()

        return count[0] == 1

"""
Fetches the content and embeddings of a publication by id
Parameters:
- id: Text, the unique identifier of the publication.
Returns:
- [(text, embedding)] content of a publication's embeddings
Notes:
"""

def get_embeddings_for_pub(self, id):
texts = []
embeddings = []
if not self.publicationExists(id):
return
fragments = self.getAllFragmentsOfPublication(id)
for fragment in fragments:
texts.append(fragment.content)
embeddings.append(fragment.vector)
text_embeddings = list(zip(texts, embeddings))
return text_embeddings
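
A short usage sketch for the new accessor, assuming an already-ingested publication; the id below is a placeholder, not a real DOI:

from VectorDatabase import Lantern

lantern = Lantern()
pairs = lantern.get_embeddings_for_pub("10.1000/example-doi")  # hypothetical id
if pairs is None:
    print("publication not found")
else:
    for text, vector in pairs[:3]:
        print(text[:60])  # opening text of the first few fragments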
4 changes: 4 additions & 0 deletions config.json
@@ -0,0 +1,4 @@
{
    "emails": ["aozalevsky@gmail.com", "steveurkel99@gmail.com"],
    "DEBUG": false
}
190 changes: 190 additions & 0 deletions document_analysis.py
@@ -0,0 +1,190 @@
import json
import re

from VectorDatabase import Lantern, Publication
from google_sheets import SheetsApiClient
from prompts import get_qbi_hackathon_prompt, METHODS_KEYWORDS

from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.chat_models import ChatOpenAI
from langchain.chains import RetrievalQA
from langchain.embeddings.openai import OpenAIEmbeddings
from langchain import PromptTemplate
from datetime import date
from langchain.vectorstores import FAISS


class DocumentAnalyzer:
"""Takes in a list of publications to analyze, then prompts the chatbot, processes the response,
aggregates the results, and reports the results to the spreadsheet
"""

CONFIG_PATH = "./config.json"

def __init__(self):
self.lantern = Lantern()
self.sheets = SheetsApiClient()
self.llm = LlmHandler()

self.email_addresses, self.notification_via_email = self.parse_config()

    @staticmethod
    def parse_config():
        try:
            with open(DocumentAnalyzer.CONFIG_PATH, 'r') as config_file:
                config_data = json.load(config_file)

            # Extract the expected fields from config_data
            emails = config_data.get('emails', [])  # Default to an empty list if 'emails' is not present
            debug = config_data.get('DEBUG', False)  # Default to False if 'DEBUG' is not present

            return emails, debug

        except FileNotFoundError:
            print(f"Config file '{DocumentAnalyzer.CONFIG_PATH}' not found. Using defaults (no email addresses)")
            return [], False
        except json.JSONDecodeError as e:
            print(f"Error decoding JSON in '{DocumentAnalyzer.CONFIG_PATH}': {e}. Using defaults (no email addresses)")
            return [], False

    def analyze_all_unread(self):
        """Pulls all new files from the Lantern database, evaluates them, and publishes the results to Google Sheets.
        """
        publications = self.lantern.getUnreadPublications()
        self.process_publications(publications)

    def process_publications(self, publications: [Publication]):
        """Takes a list of publications, applies retrieval QA to each, and processes the responses.
        NOTE: completely untested, just refactored code from hackathon
        Args:
            publications ([Publication]): list of publications
        """
        rows = []
        hits = 0
        for pub in publications:
            text_embeddings = self.lantern.get_embeddings_for_pub(pub.id)
            classification, response = 0, ''
            if text_embeddings and self.paper_about_cryoem(text_embeddings):
                classification, response = self.analyze_publication(text_embeddings)
                hits += classification
            else:
                # paper is not about cryo-EM; keep the default negative classification
                pass
            # TODO: fill in the date column once publication dates are tracked
            rows.append([pub.doi, pub.title, "", str(date.today()), "", int(classification), response, ""])

        self.update_spreadsheet(rows, hits)

    def update_spreadsheet(self, rows: [], hits: int):
        """Pushes a list of rows to the spreadsheet and notifies via email.
        Args:
            rows ([]): rows of data to be uploaded to the sheet
            hits (int): number of positive classifications in the rows
        """
        if hits > len(rows):
            raise ValueError(f"Number of hits ({hits}) is greater than the number of entries ({len(rows)})")

        self.sheets.append_rows(rows)

        if self.notification_via_email:
            self.email(hits, len(rows))

    def email(self, hits: int, total: int):
        msg = f"""
        This batch of paper analysis has concluded.
        {total} papers were analyzed in total over the date range 11/2 - 11/3.
        {hits} {"were" if (hits != 1) else "was"} classified as having multi-method structural data."""

        self.sheets.email(msg, self.email_addresses)

    def analyze_publication(self, text_embeddings: []):
        """Poses a question about the document, processes the result, and returns it.
        NOTE: for now, only uses the hackathon question, might add more later
        Args:
            text_embeddings ([]): list of (text, embedding) pairs from the document to be analyzed
        Returns:
            bool: classification of the response to the query as positive (True) or negative (False)
            str: response from ChatGPT
        """
        # NOTE: These very likely need to change
        open_ai_emb = OpenAIEmbeddings()
        query = get_qbi_hackathon_prompt(METHODS_KEYWORDS)
        faiss_index = FAISS.from_embeddings(text_embeddings=text_embeddings, embedding=open_ai_emb)
        # evaluate_queries expects a list of queries, so wrap the single prompt
        response = self.llm.evaluate_queries(faiss_index, [query])[0]
        return self.classify_response(response), response

    @staticmethod
    def classify_response(response: str):
        """Converts a text response from GPT into a boolean.
        Args:
            response (str): response from ChatGPT to the query
        Returns:
            bool: True if the answer to the question is "yes"
        """
        if response is None:
            return False
        # This filter was once used to reject cases where ChatGPT said
        # "Yes, Cryo-EM was used...", which is wrong because the question
        # asks about the inclusion of non-cryo-EM methods:
        #
        # if "cryo" in response.lower():
        #     return (False, None)
        return response.lower().startswith('yes')
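
    # Editor's illustration (not part of the commit) of classify_response;
    # the example responses are hypothetical:
    #   classify_response("Yes, the paper reports X-ray and NMR data.")  -> True
    #   classify_response("No, only cryo-EM maps are presented.")        -> False
    #   classify_response(None)                                          -> False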

    @staticmethod
    def paper_about_cryoem(text_embeddings: []):
        """Checks whether the string "cryoem" or "cryo-em" is present in the text.
        Args:
            text_embeddings ([(text, embedding)]): text and embeddings of a publication
        Returns:
            bool: True if the text mentions cryo-EM
        """
        return any(re.search("cryo-?em", text, re.IGNORECASE) for text, _ in text_embeddings)
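
    # Editor's illustration (not part of the commit): the regex accepts both
    # spellings, case-insensitively; the fragments below are hypothetical:
    #   paper_about_cryoem([("CryoEM structure of a ribosome", [0.1])])  -> True
    #   paper_about_cryoem([("Cryo-EM density was fitted", [0.2])])      -> True
    #   paper_about_cryoem([("X-ray diffraction data only", [0.3])])     -> False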


class LlmHandler:
    """Handles creation of the langchain and evaluation of queries.
    """

    def __init__(self):
        self.llm = ChatOpenAI(
            temperature=0, model_name="gpt-4", max_tokens=300, request_timeout=30, max_retries=3
        )

    def evaluate_queries(self, embedding, queries):
        chatbot = RetrievalQA.from_chain_type(
            llm=self.llm,
            chain_type="stuff",
            retriever=embedding.as_retriever(search_type="similarity", search_kwargs={"k": 3})
        )

        template = """ {query}? """
        responses = []
        for q in queries:
            prompt = PromptTemplate(
                input_variables=["query"],
                template=template,
            )

            responses.append(chatbot.run(
                prompt.format(query=q)
            ))
        return responses


def main():
    document_analyzer = DocumentAnalyzer()
    # document_analyzer.analyze_all_unread()  # analyzes all new files in the lantern db


if __name__ == '__main__':
    main()
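
For orientation, a minimal end-to-end sketch of how the pieces above compose, assuming a populated Lantern database, an OPENAI_API_KEY in the environment, and a placeholder publication id (not a real DOI):

from VectorDatabase import Lantern
from document_analysis import DocumentAnalyzer, LlmHandler
from langchain.embeddings.openai import OpenAIEmbeddings
from langchain.vectorstores import FAISS

lantern = Lantern()
pairs = lantern.get_embeddings_for_pub("10.1000/example-doi")  # hypothetical id
if pairs and DocumentAnalyzer.paper_about_cryoem(pairs):
    # Rebuild a FAISS index from the stored (text, embedding) pairs
    index = FAISS.from_embeddings(text_embeddings=pairs, embedding=OpenAIEmbeddings())
    answer = LlmHandler().evaluate_queries(
        index, ["Does the paper report structural methods other than cryo-EM"]
    )[0]
    print(DocumentAnalyzer.classify_response(answer), answer)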
18 changes: 0 additions & 18 deletions fragment.py

This file was deleted.

3 changes: 0 additions & 3 deletions get_database.sh

This file was deleted.
