diff --git a/.github/workflows/pylint.yml b/.github/workflows/pylint.yml new file mode 100644 index 0000000..383e65c --- /dev/null +++ b/.github/workflows/pylint.yml @@ -0,0 +1,23 @@ +name: Pylint + +on: [push] + +jobs: + build: + runs-on: ubuntu-latest + strategy: + matrix: + python-version: ["3.8", "3.9", "3.10"] + steps: + - uses: actions/checkout@v3 + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v3 + with: + python-version: ${{ matrix.python-version }} + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install pylint + - name: Analysing the code with pylint + run: | + pylint $(git ls-files '*.py') diff --git a/.gitignore b/.gitignore index 662ecf4..10b4283 100644 --- a/.gitignore +++ b/.gitignore @@ -3,3 +3,4 @@ input_doc/index.faiss __pycache__ ./lantern data +google_sheets_credentials.json diff --git a/README.md b/README.md index edc75ba..64fac56 100644 --- a/README.md +++ b/README.md @@ -53,4 +53,4 @@ Feel free to contribute to the development of StructHunt by submitting issues, f ## License -This project is licensed under the [MIT License](LICENSE). \ No newline at end of file +This project is licensed under the [MIT License](LICENSE). diff --git a/VectorDatabase.py b/VectorDatabase.py index 395bdc6..f924c0d 100644 --- a/VectorDatabase.py +++ b/VectorDatabase.py @@ -1,6 +1,38 @@ import psycopg2 -from fragment import Fragment -from publication import Publication + + +# Class to represent a publication with attributes id, title, pmc, pubmed, and doi +class Publication: + id = "" + title = "" + pmc = "" + pubmed = "" + doi = "" + + def __init__(self, id, title, pmc, pubmed, doi): + self.id = id # (DOI) Unique identifier for the publication + self.title = title # Title of the publication + self.pmc = pmc # PubMed Central (PMC) Link + self.pubmed = pubmed # PubMed Link + self.doi = doi # Digital Object Identifier (DOI) Link for the publication + + +# Class to represent a fragment of a publication with attributes id, header, content, and vector +class Fragment: + # Class variables to store default values for attributes + id = "" + header = "" + content = "" + vector = "" + + def __init__(self, id, header, content, vector): + # Constructor to initialize the attributes of the Fragment object + + # Set the attributes of the object with the values provided during instantiation + self.id = id # (DOI) Unique identifier for the fragment + self.header = header # Header or title of the fragment + self.content = content # Content or text of the fragment + self.vector = vector # Vector representation of the fragment # Lantern class that exposes functionality of database to application @@ -231,6 +263,8 @@ def getAllFragmentsOfPublication(self, id): """ Retrieves unread publications from the 'publications' table. + Parameters: + - delete_unread_entries: bool, decides if entries are deleted from the "unread" table Returns: - List[Publication], a list of Publication objects representing the unread publications. Notes: @@ -238,7 +272,7 @@ def getAllFragmentsOfPublication(self, id): - Clears the 'unread' table after retrieving the unread publications. 
""" - def getUnreadPublications(self): + def getUnreadPublications(self, delete_unread_entries=True): conn = self.conn cursor = conn.cursor() @@ -247,7 +281,9 @@ def getUnreadPublications(self): publications = cursor.fetchall() - cursor.execute('DELETE FROM unread;') + if delete_unread_entries: + cursor.execute('DELETE FROM unread;') + conn.commit() cursor.close() @@ -281,3 +317,24 @@ def publicationExists(self, id): cursor.close() return count[0] == 1 + + """ + Fetches the content and embeddings of a publication by id + Parameters: + - id: Text, the unique identifier of the publication. + Returns: + - [(text, embedding)] content of a publication's embeddings + Notes: + """ + + def get_embeddings_for_pub(self, id): + texts = [] + embeddings = [] + if not self.publicationExists(id): + return + fragments = self.getAllFragmentsOfPublication(id) + for fragment in fragments: + texts.append(fragment.content) + embeddings.append(fragment.vector) + text_embeddings = list(zip(texts, embeddings)) + return text_embeddings diff --git a/config.json b/config.json new file mode 100644 index 0000000..f40fb96 --- /dev/null +++ b/config.json @@ -0,0 +1,4 @@ +{ + "emails": ["aozalevsky@gmail.com", "steveurkel99@gmail.com"], + "DEBUG": false +} \ No newline at end of file diff --git a/document_analysis.py b/document_analysis.py new file mode 100644 index 0000000..0326e24 --- /dev/null +++ b/document_analysis.py @@ -0,0 +1,190 @@ +import json +import re + +from VectorDatabase import Lantern, Publication +from google_sheets import SheetsApiClient +from prompts import get_qbi_hackathon_prompt, METHODS_KEYWORDS + +from langchain.text_splitter import RecursiveCharacterTextSplitter +from langchain.chat_models import ChatOpenAI +from langchain.chains import RetrievalQA +from langchain.embeddings.openai import OpenAIEmbeddings +from langchain import PromptTemplate +from datetime import date +from langchain.vectorstores import FAISS + + +class DocumentAnalyzer: + """Takes in a list of publications to analyze, then prompts the chatbot, processes the response, + aggregates the results, and reports the results to the spreadsheet + """ + + CONFIG_PATH = "./config.json" + + def __init__(self): + self.lantern = Lantern() + self.sheets = SheetsApiClient() + self.llm = LlmHandler() + + self.email_addresses, self.notification_via_email = self.parse_config() + + @staticmethod + def parse_config(): + try: + with open(DocumentAnalyzer.CONFIG_PATH, 'r') as config_file: + config_data = json.load(config_file) + + # Extracting fields from the config_data + my_list = config_data.get('emails', []) # Default to an empty list if 'my_list' is not present + my_bool = config_data.get('DEBUG', False) # Default to False if 'my_bool' is not present + + return my_list, my_bool + + except FileNotFoundError: + print(f"Config file '{DocumentAnalyzer.CONFIG_PATH}' not found. 
Using defaults (no email addresses)") + return [], False + except json.JSONDecodeError as e: + print(f"Error decoding JSON in '{DocumentAnalyzer.CONFIG_PATH}': {e}") + return None, None + + def analyze_all_unread(self): + """pulls all new files from Lantern database, evaluates them, and publishes results to google sheets + """ + publications = self.lantern.getUnreadPublications() + self.process_publications(publications) + + def process_publications(self, publications: [Publication]): + """takes a list of publications, applies retrievalQA and processes responses + NOTE: completely untested, just refactored code from hackathon + + Args: + publications ([]): list of publications + """ + + rows = [] + hits = 0 + for pub in publications: + text_embeddings = self.lantern.get_embeddings_for_pub(pub.id) + classification, response = 0, '' + if self.paper_about_cryoem(text_embeddings): + classification, response = self.analyze_publication(text_embeddings) + hits += classification + else: + # print('paper not about cryo-em') + pass + # add date if it's added + rows.append([pub.doi, pub.title, "", str(date.today()), "", int(classification), response, ""]) + + self.update_spreadsheet(rows, hits) + + def update_spreadsheet(self, rows: [], hits: int): + """pushes a list of rows to the spreadsheet and notifies via email + + Args: + rows ([]): rows of data to be uploaded to sheet + hits (int): number of positive classifications in the rows + """ + if hits > len(rows): + raise ValueError(f"Number of hits ({hits}) is greater than the number of entries ({len(rows)}), sus") + + self.sheets.append_rows(rows) + + if self.notification_via_email: + self.email(hits, len(rows)) + + def email(self, hits: int, total: int): + msg = f""" +This batch of paper analysis has concluded. 
+{total} papers were analyzed in total over the date range 11/2 - 11/3 +{hits} {"were" if (hits != 1) else "was"} classified as having multi-method structural data""" + + self.sheets.email(msg, self.email_addresses) + + def analyze_publication(self, text_embeddings: []): + """poses a question about the document, processes the result and returns it + NOTE: for now, only uses the hackathon question, might add more later + + Args: + text_embeddings ([]): list of (embedding, text) pairs from document to be analyzed + + Returns: + bool: classification of response to query as positive (True) or negative (False) + str: response from chatGPT + """ + # NOTE: These very likely need to change + open_ai_emb = OpenAIEmbeddings() + query = get_qbi_hackathon_prompt(METHODS_KEYWORDS) + faiss_index = FAISS.from_embeddings(text_embeddings=text_embeddings, embedding=open_ai_emb) + response = self.llm.evaluate_queries(faiss_index, query)[0] + return self.classify_response(response), response + + @staticmethod + def classify_response(response: str): + """converting text response from GPT into boolean + + Args: + response (str): response from ChatGPT to the query + + Returns: + bool: True if answer to question is "yes" + """ + if response is None: + return False + # this was used to filter out cases where ChatGPT said "Yes, Cryo-EM was used...", + # which is wrong because we asked it about + # inclusion of non-cryo-em stuff + # + # if "cryo" in response.lower(): + # return (False, None) + return response.lower().startswith('yes') + + @staticmethod + def paper_about_cryoem(text_embeddings: []): + """checks if the string "cryoem" or "cryo-em" is present in the text + + Args: + text_embeddings [(text, embedding)]: text and embeddings of a publication + + Returns: + bool: True if the text mentions cryo-em + """ + return any(re.search("cryo-?em", text, re.IGNORECASE) for text, _ in text_embeddings) + + +class LlmHandler: + """Handles creation of langchain and evaluation of queries + """ + + def __init__(self): + self.llm = ChatOpenAI( + temperature=0, model_name="gpt-4", max_tokens=300, request_timeout=30, max_retries=3 + ) + + def evaluate_queries(self, embedding, queries): + chatbot = RetrievalQA.from_chain_type( + llm=self.llm, + chain_type="stuff", + retriever=embedding.as_retriever(search_type="similarity", search_kwargs={"k": 3}) + ) + + template = """ {query}? 
""" + responses = [] + for q in queries: + prompt = PromptTemplate( + input_variables=["query"], + template=template, + ) + + responses.append(chatbot.run( + prompt.format(query=q) + )) + return responses + + +def main(): + document_analyzer = DocumentAnalyzer() + #document_analyzer.analyze_all_unread() # analyzes all new files in lantern db + + +if __name__ == '__main__': + main() diff --git a/fragment.py b/fragment.py deleted file mode 100644 index fad1193..0000000 --- a/fragment.py +++ /dev/null @@ -1,18 +0,0 @@ -# Class to represent a fragment of a publication with attributes id, header, content, and vector -class Fragment: - - - # Class variables to store default values for attributes - id = "" - header = "" - content = "" - vector = "" - - def __init__(self, id, header, content, vector): - # Constructor to initialize the attributes of the Fragment object - - # Set the attributes of the object with the values provided during instantiation - self.id = id # (DOI) Unique identifier for the fragment - self.header = header # Header or title of the fragment - self.content = content # Content or text of the fragment - self.vector = vector # Vector representation of the fragment diff --git a/get_database.sh b/get_database.sh deleted file mode 100755 index 3790b3f..0000000 --- a/get_database.sh +++ /dev/null @@ -1,3 +0,0 @@ -#!/bin/bash - -sudo -u postgres pg_dump structdb > structdb.sql \ No newline at end of file diff --git a/google_sheets.py b/google_sheets.py index 10533d0..0d0e38d 100644 --- a/google_sheets.py +++ b/google_sheets.py @@ -1,61 +1,103 @@ +import os import gspread -class SpreadsheetUpdater(): - - credentials = { - "type": "service_account", - "project_id": "durable-sky-396700", - "private_key_id": "5735737884fa17f981beceb424001445e2476ae3", - "private_key": "-----BEGIN PRIVATE KEY-----\nMIIEvQIBADANBgkqhkiG9w0BAQEFAASCBKcwggSjAgEAAoIBAQDcN1SaQ3pgvu/Z\nYL4VHQSoztL5gUBtDFHvICFGLMBLwxj1SI9zWeI0uH4QvW32QkqTwnYO8cW2jCpZ\n7Lu2ZDdW5copVt3p7GCacYC++hYjH8Y13FSotE0yWpxh8qIQexzgTcHenrCr8nAd\nhkeHyNwwUpmjOASIqOtIHj7cGqp43jxMSuwh8fK94ef+Aemo5+7h+tXlHqwMIFap\nyjdE4TNdJ+mYp2nm17PUYiP0Y+WbYEOeeo29So9P/Ir1gMpH5Fyu3RcBI7jloFZd\nM8Hfrj17KKMnVOKItfnYeLlFaBDSfgYP17v9NUfzHVUSDEa2T67mSfUG63aIiYSm\nXrtAuignAgMBAAECggEACDHe6hnjIfQfazcLm8mHdNvnEFxCkExKRQ9f1AN/HGw9\nyR+47UTF0DE7yVYWed8gDon8Aef2JyoY7ioksILfzeuhld9vq3BqbK59aTeK2PL2\n80yOfsCtTSRmEWPWeBQjKcDhaAfLva2F7CaLeH59aY1WJLOSJ57xmOHXQP8uozsm\nm8dMs1PgEELl2B2zc+6JtHnWH2CAxZiA2b9yh+iZi3kiaJyLIW1bgx4U3suDnsFc\n+Igk+AYsIZ4UuPwFxlb+2mvYpiZd/Br0ASnBDQvgXDA4Xlu8wBeukfun8VZOviE4\nFjdxYkHMLeCsu15Xsc3E3UOt8wIXbr6b9Wi7mitY4QKBgQD7oMvTGl3SMMBynYHx\nYbbqW15UksGx6oPXBeUsqmCc64qBTiTZwLh2gY0TOELa0Evlf5kVtEVvMPviFynf\nvEvEc7ZV9rqkpgD1YRk2oi98wgAPG2/xU/asdSNblLVT/tK21/a7agDuyh6CfRO6\nfzQf4GYITKjtI34kjkEYa4Y6TwKBgQDgCtGNu2TpqKhle6v59EQzyB5aUEg4LdJn\n+YTGppohtbtbbW2N4nhoOi+ibLtN0dIDetfSdZtXe7CC13WSVS7T9QAOC/u6g0rj\nQstqktfUUyasIPYKdWe64rNtNJkIW+x+bgz2p8fOTGKwkTSFUFtPYwvajwIUv6Zc\n2/Vjtt82qQKBgBKLcTonsU5yZVyNGyyNBQwUm8kj376bCAhq2M8H54LpIRYSiki6\nGV4yghEujk7OFync041z8cIWHBo3ltB0cikSVhfTzUGhMmTjORZ7sYBCU/rJDOD+\nTSm8oFR5izubhjAPjpGVaGgw4TrAuRl/knne8eYesDx55ywOh+Gi2wulAoGBANdW\nrqnKtyi6mfjI0LhzpmYa78mgpnmQ2U5kjtEc6sKB2S38VLNuPIr5ejVkyvb2OCRu\nGyjHL2L7mOF51CCtTVAeiUn3DKHtdbpPxhKOR3Jl5aLGH5ZX2DbRlOHfD0PwjrPK\ndR1SkIJh+u1484E7hjgcnBUbJUXqGy3foNGRwKPZAoGAZ1Ig6vyIbZk9Lnh08COS\nOQ6JrTEdDCfr1i3CapHAW+rN6oHlM+S7PmzTFuxrWhGAHDWDOBczrPa+ohUAmLWa\niSJDC+bBJvj/L0jD4qIm39ifDCSyZfoAkshvpEPe010tw3IuO64pV9wowbwyu+wN\nieOoIE/RPaDtfFb2IZG7pGA=\n-----END PRIVATE KEY-----\n", - 
"client_email": "csv-updater@durable-sky-396700.iam.gserviceaccount.com", - "client_id": "116349894744257971396", - "auth_uri": "https://accounts.google.com/o/oauth2/auth", - "token_uri": "https://oauth2.googleapis.com/token", - "auth_provider_x509_cert_url": "https://www.googleapis.com/oauth2/v1/certs", - "client_x509_cert_url": "https://www.googleapis.com/robot/v1/metadata/x509/csv-updater%40durable-sky-396700.iam.gserviceaccount.com", - "universe_domain": "googleapis.com" - } - SPREADSHEET_NAME = "PDB-DEV_ChatGPT" +class SheetsApiClient: + """interface for all functionality with google sheets + enables connection, append, and notification + """ - def __init__(self): - self.gc = gspread.service_account_from_dict(type(self).credentials) - sh = self.gc.open(type(self).SPREADSHEET_NAME) - self.worksheet = sh.get_worksheet(0) + SPREADSHEET_NAME = "PDB-DEV_ChatGPT" + SCHEMA = [ + "DOI", + "Title", + "date of publishing", + "date of analysis", + "authors", + "classification", + "methods used", + "software", + ] + + def __init__(self): + self.client = self.connect() + self.spreadsheet = self.client.open(type(self).SPREADSHEET_NAME) + self.worksheet = self.spreadsheet.get_worksheet(0) + + @staticmethod + def connect(): + """connects to Google Sheets API service using private key file + """ + try: + secret_file = os.path.join(os.getcwd(), "google_sheets_credentials.json") + return gspread.service_account(secret_file) + except OSError as e: + print(e) def append_row(self, row: [str]): - ''' - Adds a row to the spreadsheet, must follow schema: - ['DOI', 'Title', 'date of publishing', 'date of analysis', 'authors', 'classification', 'methods used', 'software'] - ''' + """ + Adds a row to the spreadsheet, must follow SCHEMA: + """ self._check_row(row) self.worksheet.append_row(row) - + def append_rows(self, rows: [[str]]): - ''' - Adds a list of rows to the spreadsheet, each row must follow schema: - ['DOI', 'Title', 'date of publishing', 'date of analysis', 'authors', 'classification', 'methods used', 'software'] - ''' + """ + Adds a list of rows to the spreadsheet, each row must follow SCHEMA: + WARNING: Assumes that the [rows] list will never exceed the maximum throughput of one api call + """ for row in rows: self._check_row(row) self.worksheet.append_rows(rows) - + + def email(self, message: str, email_addresses: [str]): + """Shares the spreadsheet with arthur, along with the message in an email + Args: + message (str): message to be sent + email_addresses ([str]): recipients of notification + """ + for email_address in email_addresses: + self.spreadsheet.share( + email_address, + perm_type="user", + role="reader", + notify=True, + email_message=message, + ) + @staticmethod - def _check_row(row: [str]): - if len(row) != 8: - raise ValueError("Row must have 8 fields in the order specified") + def _check_row(row: []): + """Checks row - def notify_arthur(self, message: str): - self.sh.share('aozalevsky@gmail.com', perm_type='user', role='writer', notify=True, email_message=message) + Args: + row ([]): row of values to be added to worksheet + Raises: + ValueError: number of values in rows doesn't match schema + """ + if len(row) != len(SheetsApiClient.SCHEMA): + raise ValueError( + f"Row must have {len(SheetsApiClient.SCHEMA)} fields in the order specified\n{SheetsApiClient.SCHEMA}" + ) def main(): - spread = SpreadsheetUpdater() - dummy_row = ['DOI', 'Title', 'date of publishing', 'date of analysis', 'authors', 'classification', 'methods used', 'software'] + # some test code which initializes the 
client, then appends rows to the worksheet, then pings arthur + spread = SheetsApiClient() + dummy_row = [ + "DOI", + "Title", + "date of publishing", + "date of analysis", + "authors", + "classification", + "methods used", + "software", + ] spread.append_row(dummy_row) spread.append_rows([dummy_row, dummy_row, dummy_row]) - spread.notify_arthur("testing out my dope code") + # spread.notify_arthur("testing out the code") + -if __name__ == '__main__': - main() \ No newline at end of file +if __name__ == "__main__": + main() diff --git a/hackathon_runner.py b/hackathon_runner.py new file mode 100644 index 0000000..27b8319 --- /dev/null +++ b/hackathon_runner.py @@ -0,0 +1,268 @@ + +import os +import pandas as pd +import PyPDF2 +from paperscraper.pdf import save_pdf +from paperscraper.get_dumps import biorxiv + +from VectorDatabase import Lantern, Fragment, Publication +import openai +from langchain.document_loaders.csv_loader import CSVLoader +from langchain.embeddings.openai import OpenAIEmbeddings +from langchain.text_splitter import RecursiveCharacterTextSplitter +from langchain.vectorstores import FAISS +from langchain.document_loaders import TextLoader + +from langchain.embeddings.openai import OpenAIEmbeddings +from langchain.vectorstores import FAISS +from langchain.chat_models import ChatOpenAI +from langchain.chains import RetrievalQA +from langchain import PromptTemplate +import PyPDF2 + +keywords_groups = { + 'CX-MS': ['cross-link', 'crosslink', 'XL-MS', 'CX-MS', 'CL-MS', 'XLMS', 'CXMS', 'CLMS', "chemical crosslinking mass spectrometry", 'photo-crosslinking', 'crosslinking restraints', 'crosslinking-derived restraints', 'chemical crosslinking', 'in vivo crosslinking', 'crosslinking data'], + 'HDX': ['Hydrogen–deuterium exchange mass spectrometry', 'Hydrogen/deuterium exchange mass spectrometry', 'HDX', 'HDXMS', 'HDX-MS'], + 'EPR': ['electron paramagnetic resonance spectroscopy', 'EPR', 'DEER', "Double electron electron resonance spectroscopy"], + 'FRET': ['FRET', "forster resonance energy transfer", "fluorescence resonance energy transfer"], + 'AFM': ['AFM', "atomic force microscopy"], + 'SAS': ['SAS', 'SAXS', 'SANS', "Small angle solution scattering", "solution scattering", "SEC-SAXS", "SEC-SAS", "SASBDB", "Small angle X-ray scattering", "Small angle neutron scattering"], + '3DGENOME': ['HiC', 'Hi-C', "chromosome conformation capture"], + 'Y2H': ['Y2H', "yeast two-hybrid"], + 'DNA_FOOTPRINTING': ["DNA Footprinting", "hydroxyl radical footprinting"], + 'XRAY_TOMOGRAPHY': ["soft x-ray tomography"], + 'FTIR': ["FTIR", "Infrared spectroscopy", "Fourier-transform infrared spectroscopy"], + 'FLUORESCENCE': ["Fluorescence imaging", "fluorescence microscopy", "TIRF"], + 'EVOLUTION': ['coevolution', "evolutionary covariance"], + 'PREDICTED': ["predicted contacts"], + 'INTEGRATIVE': ["integrative structure", "hybrid structure", "integrative modeling", "hybrid modeling"], + 'SHAPE': ['Hydroxyl Acylation analyzed by Primer Extension'] +} + +import re + +class LlmHandler: + + def __init__(self): + self.text_splitter = RecursiveCharacterTextSplitter(separators = ["\n\n", "\n", ".", ","], chunk_size=300, chunk_overlap=100) + # the OpenAI API key is read from the OPENAI_API_KEY environment variable + self.llm=ChatOpenAI( + temperature=0, model_name="gpt-4", max_tokens=300, request_timeout = 30, max_retries=3 + ) + + + def evaluate_queries(self, embedding, queries): + chatbot = RetrievalQA.from_chain_type( + llm=self.llm, + chain_type="stuff", + retriever=embedding.as_retriever(search_type="similarity", search_kwargs={"k":3}) + ) + + 
template = """ {query}? """ + response = [] + for q in queries: + prompt = PromptTemplate( + input_variables=["query"], + template=template, + ) + + response.append(chatbot.run( + prompt.format(query=q) + )) + return response + + +llm = LlmHandler() + +methods_string = '' +for i, (k, v) in enumerate(keywords_groups.items()): + if i > 0: + methods_string += ' or ' + methods_string += f'{k} ({", ".join(v)})' + + +def get_embeddings(fname): + """ + """ + loader = TextLoader(fname) + documents = loader.load() + text_splitter = RecursiveCharacterTextSplitter(separators = ["\n\n", "\n", ".", ","],chunk_size = 300, chunk_overlap=100) + docs = text_splitter.split_documents(documents) + + emb = OpenAIEmbeddings() + input_texts = [d.page_content for d in docs] + + input_embeddings = emb.embed_documents(input_texts) + text_embeddings = list(zip(input_texts, input_embeddings)) + return text_embeddings, emb + +def retreiveTextFromPdf(inp_file): + + + json = pd.read_json(path_or_buf=inp_file, lines=True) + lantern = Lantern() + + for n, doi in enumerate(json['doi']): + #print(n, doi) + + + ##NOTE: This is for example purpose only + if n > 0: + break + + if lantern.publicationExists(doi): + continue + + paper_data = {'doi': doi} + doi = doi.replace("/", "-") + pdf_dir = './papers/' + if not os.path.exists(pdf_dir): + os.mkdir(pdf_dir) + + pdfsavefile='./papers/' + doi +'.pdf' + save_pdf(paper_data, filepath=pdfsavefile) + + # creating a pdf reader object + reader = PyPDF2.PdfReader(pdfsavefile) + save_txt_path = 'scrapped_txts/' + if not os.path.exists(save_txt_path): + os.mkdir(save_txt_path) + extract_text = '' + for page in reader.pages: + extract_text+=page.extract_text() + + txt_file = str('{}.txt'.format(doi)) + with open(save_txt_path+txt_file, 'w') as file: + file.write(extract_text) + + + txt_embs, emb = get_embeddings(save_txt_path+txt_file) + + fragments = [] + for txt, embs in txt_embs: + fragment = Fragment(doi, 'methods', txt, embs) + fragments.append(fragment) + + title = "" + pmc = "" + pubmed = "" + + publication = Publication(doi, title, pmc, pubmed, doi) + + lantern.insertEmbeddings(fragments) + lantern.insertPublication(publication) + + os.remove(pdfsavefile) + + +def add_publication_by_doi(doi): + lantern = Lantern() + if lantern.publicationExists(doi): + return + + paper_data = {'doi': doi} + doi = doi.replace("/", "-") + pdf_dir = './papers/' + if not os.path.exists(pdf_dir): + os.mkdir(pdf_dir) + + pdfsavefile='./papers/' + doi +'.pdf' + save_pdf(paper_data, filepath=pdfsavefile) + + # creating a pdf reader object + reader = PyPDF2.PdfReader(pdfsavefile) + save_txt_path = 'scrapped_txts/' + if not os.path.exists(save_txt_path): + os.mkdir(save_txt_path) + extract_text = '' + for page in reader.pages: + extract_text+=page.extract_text() + + txt_file = str('{}.txt'.format(doi)) + with open(save_txt_path+txt_file, 'w') as file: + file.write(extract_text) + + + txt_embs, emb = get_embeddings(save_txt_path+txt_file) + + fragments = [] + for txt, embs in txt_embs: + fragment = Fragment(doi, 'methods', txt, embs) + fragments.append(fragment) + + title = "" + pmc = "" + pubmed = "" + + publication = Publication(doi, title, pmc, pubmed, doi) + + lantern.insertEmbeddings(fragments) + lantern.insertPublication(publication) + #print(fragments) + os.remove(pdfsavefile) + + +def process_result(result): + if result == None: + return (False, None) + for response in result: + if "cryo" in response.lower(): + return (False, None) + return (response.lower().startswith('yes'), response) + +lantern = 
Lantern() +def get_embeddings_for_pub(id): + input_texts = [] + input_embeddings = [] + if lantern.publicationExists(id): + fragments = lantern.getAllFragmentsOfPublication(id) + for fragment in fragments: + input_texts.append(fragment.content) + input_embeddings.append(fragment.vector) + text_embeddings = list(zip(input_texts, input_embeddings)) + return text_embeddings + +def main(): + open_ai_emb = OpenAIEmbeddings() + #add_publication_by_doi('10.1101/2023.10.31.564925') + #add_publication_by_doi('10.1101/2023.03.03.531047') + query = [f"You are reading a materials and methods section of a scientific paper. Here is the list of structural biology methods {methods_string}.\n\n Did the authors use any methods from the list? \n\n Answer with Yes or No followed by the names of the methods."] + lantern = Lantern() + publications = lantern.getUnreadPublications() + + all_results = [] + rows = [] + hits = 0 + for pub in publications[5:]: + text_embeddings = get_embeddings_for_pub(pub.id) + flag = False + for text, _ in text_embeddings: + if re.search("cryo-?em", text, re.IGNORECASE): + flag = True + break + if flag: + faissIndex = FAISS.from_embeddings(text_embeddings=text_embeddings, embedding=open_ai_emb) + result = llm.evaluate_queries(faissIndex, query) + classification, response = process_result(result) + hits += classification + else: + classification, response = process_result(None) + #print('paper not about cryo-em') + rows.append([pub.doi, pub.title, "11-2-2023", "11-5-2023", "", int(classification), response, ""]) + + from google_sheets import SheetsApiClient + gs = SheetsApiClient() + print(rows) + gs.append_rows(rows) + msg = f""" + This batch of paper analysis has concluded. + {len(rows)} papers were analyzed in total over the date range 11/2 - 11/3 + {hits} {"were" if (hits != 1) else "was"} classified as having multi-method structural data +""" + print(msg) + gs.email(msg, ["aozalevsky@gmail.com", "steveurkel99@gmail.com"]) # recipients as listed in config.json + + +main() + + diff --git a/prompts.py b/prompts.py new file mode 100644 index 0000000..d3196d1 --- /dev/null +++ b/prompts.py @@ -0,0 +1,150 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- + +"""A collection of various PDB-related prompts +""" + +# A list of abbreviated names and synonyms +# for various biophysical methods +# that are typically used for integrative modeling + +METHODS_KEYWORDS = { + 'CX-MS': [ + 'cross-link', 'crosslink', + 'XL-MS', 'CX-MS', 'CL-MS', 'XLMS', 'CXMS', 'CLMS', + "chemical crosslinking mass spectrometry", + 'photo-crosslinking', 'crosslinking restraints', + 'crosslinking-derived restraints', 'chemical crosslinking', + 'in vivo crosslinking', 'crosslinking data', + ], + + 'HDX': [ + 'Hydrogen–deuterium exchange mass spectrometry', + 'Hydrogen/deuterium exchange mass spectrometry', + 'HDX', 'HDXMS', 'HDX-MS', + ], + + 'EPR': [ + 'electron paramagnetic resonance spectroscopy', + 'EPR', 'DEER', + "Double electron electron resonance spectroscopy", + ], + + 'FRET': [ + 'FRET', + "forster resonance energy transfer", + "fluorescence resonance energy transfer", + ], + + 'AFM': [ + 'AFM', "atomic force microscopy", + ], + + 'SAS': [ + 'SAS', 'SAXS', 'SANS', "Small angle solution scattering", + "solution scattering", "SEC-SAXS", "SEC-SAS", "SASBDB", + "Small angle X-ray scattering", "Small angle neutron scattering", + ], + + '3DGENOME': [ + 'HiC', 'Hi-C', "chromosome conformation capture", + ], + + 'Y2H': [ + 'Y2H', + "yeast two-hybrid", + ], + + 'DNA_FOOTPRINTING': [ + "DNA Footprinting", + "hydroxyl radical footprinting", + ], + + 'XRAY_TOMOGRAPHY': [ + "soft 
x-ray tomography", + ], + + 'FTIR': [ + "FTIR", "Infrared spectroscopy", + "Fourier-transform infrared spectroscopy", + ], + + 'FLUORESCENCE': [ + "Fluorescence imaging", + "fluorescence microscopy", "TIRF", + ], + + 'EVOLUTION': [ + 'coevolution', "evolutionary covariance", + ], + + 'PREDICTED': [ + "predicted contacts", + ], + + 'INTEGRATIVE': [ + "integrative structure", "hybrid structure", + "integrative modeling", "hybrid modeling", + ], + + 'SHAPE': [ + 'Hydroxyl Acylation analyzed by Primer Extension', + ] +} + + +def keywords_dict_to_string(keywords: dict) -> str: + """ + Convert dictionary with method keywords and synonyms + to a string + + Example: + + keywords = { + 'AFM': [ + 'AFM', "atomic force microscopy", + ], + + 'SAS': [ + 'SAS', "solution scattering", + ], + } + + Result: + + 'AFM (AFM, atomic force microscopy) or SAS (SAS, solution scattering)' + """ + + methods_string = '' + for i, (k, v) in enumerate(keywords.items()): + if i > 0: + methods_string += ' or ' + methods_string += f'{k} ({", ".join(v)})' + + return methods_string + +def get_qbi_hackathon_prompt(keywords: dict) -> str: + """ + Returns a prompt that was initially developed + during the QBI Hackathon. + """ + + if len(keywords) == 0: + raise(ValueError("Keywords dict can't be empty")) + + methods_string = keywords_dict_to_string(keywords) + + prompt = ( + "You are reading a materials and methods section " + "of a scientific paper. " + f"Here is the list of methods {methods_string}.\n\n" + "Did the authors use any of them? " + "Answer Yes or No, followed by the name(s) of methods. " + "Use only abbreviations." + ) + + return prompt + +if __name__ == '__main__': + # Just call an example function + print(get_qbi_hackathon_prompt(METHODS_KEYWORDS)) diff --git a/publication.py b/publication.py deleted file mode 100644 index 5cb96a8..0000000 --- a/publication.py +++ /dev/null @@ -1,15 +0,0 @@ -# Class to represent a publication with attributes id, title, pmc, pubmed, and doi -class Publication: - - id = "" - title = "" - pmc = "" - pubmed = "" - doi = "" - - def __init__(self, id, title, pmc, pubmed, doi): - self.id = id # (DOI) Unique identifier for the publication - self.title = title # Title of the publication - self.pmc = pmc # PubMed Central (PMC) Link - self.pubmed = pubmed # PubMed Link - self.doi = doi # Digital Object Identifier (DOI) Link for the publication diff --git a/ScrapperPipeline.ipynb b/tests/ScrapperPipeline.ipynb similarity index 100% rename from ScrapperPipeline.ipynb rename to tests/ScrapperPipeline.ipynb diff --git a/context_retrieve.py b/tests/context_retrieve.py similarity index 98% rename from context_retrieve.py rename to tests/context_retrieve.py index aa7c5d2..f4c4fc2 100644 --- a/context_retrieve.py +++ b/tests/context_retrieve.py @@ -120,5 +120,6 @@ def main(): faissIndex = FAISS.from_embeddings(text_embeddings=txt_embs, embedding=emb) inp_query = "Does the paper report a new structure of a biomolecule or biomolecular complex modeled using experimental data" Query(inp_query, faissIndex) - -main() \ No newline at end of file + +if __name__ == '__main__': + main() \ No newline at end of file diff --git a/faiss.ipynb b/tests/faiss.ipynb similarity index 100% rename from faiss.ipynb rename to tests/faiss.ipynb diff --git a/input_file.txt b/tests/input_file.txt similarity index 100% rename from input_file.txt rename to tests/input_file.txt diff --git a/latern_test.ipynb b/tests/latern_test.ipynb similarity index 100% rename from latern_test.ipynb rename to tests/latern_test.ipynb diff 
--git a/test.py b/tests/test.py similarity index 93% rename from test.py rename to tests/test.py index 8347787..b782d81 100644 --- a/test.py +++ b/tests/test.py @@ -1,6 +1,4 @@ -from fragment import Fragment -from publication import Publication -from VectorDatabase import Latern +from VectorDatabase import Lantern, Fragment, Publication from tqdm.auto import tqdm from sentence_transformers import SentenceTransformer import torch diff --git a/updated_prompt.py b/tests/test_qbi_hackathon_promt.py similarity index 100% rename from updated_prompt.py rename to tests/test_qbi_hackathon_promt.py diff --git a/testing.py b/tests/testing.py similarity index 100% rename from testing.py rename to tests/testing.py
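Usage sketch (not part of the diff): a minimal example of how the pieces introduced in this change are expected to fit together, with Lantern queuing unread publications, DocumentAnalyzer classifying them, and results reported via SheetsApiClient and config.json. It assumes a reachable Postgres instance for Lantern, a google_sheets_credentials.json service-account file, the config.json added above, and OPENAI_API_KEY set in the environment; the run_once helper name is illustrative only.

from VectorDatabase import Lantern
from document_analysis import DocumentAnalyzer

def run_once():
    # Peek at the unread queue without clearing it, using the new delete_unread_entries flag.
    lantern = Lantern()
    pending = lantern.getUnreadPublications(delete_unread_entries=False)
    print(f"{len(pending)} unread publications queued for analysis")

    # Full pipeline: classify each unread publication, append one row per paper to the
    # PDB-DEV_ChatGPT sheet, and notify the addresses from config.json when enabled.
    analyzer = DocumentAnalyzer()
    analyzer.analyze_all_unread()

if __name__ == "__main__":
    run_once()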