From 6f169e052cc67a3988556cf4098ed7f733528c82 Mon Sep 17 00:00:00 2001 From: Mikhail Lebedev Date: Wed, 28 Aug 2024 22:36:12 +0200 Subject: [PATCH 01/13] MVP local models --- .streamlit/config.toml | 26 + .streamlit/secrets.toml | 1 + alphastats/DataSet_Preprocess.py | 20 +- alphastats/gui/.streamlit/config.toml | 2 + alphastats/gui/pages/02_Import Data.py | 3 +- alphastats/gui/pages/{05_GPT.py => 05_LLM.py} | 152 ++-- alphastats/gui/utils/__init__.py | 0 alphastats/gui/utils/enrichment_analysis.py | 77 ++ alphastats/gui/utils/gpt_helper.py | 664 ++---------------- alphastats/gui/utils/ollama_utils.py | 365 ++++++++++ alphastats/gui/utils/openai_utils.py | 205 ++++++ alphastats/gui/utils/options.py | 16 +- alphastats/gui/utils/uniprot_utils.py | 328 +++++++++ 13 files changed, 1185 insertions(+), 674 deletions(-) create mode 100644 .streamlit/config.toml create mode 100644 .streamlit/secrets.toml rename alphastats/gui/pages/{05_GPT.py => 05_LLM.py} (74%) create mode 100644 alphastats/gui/utils/__init__.py create mode 100644 alphastats/gui/utils/enrichment_analysis.py create mode 100644 alphastats/gui/utils/ollama_utils.py create mode 100644 alphastats/gui/utils/openai_utils.py create mode 100644 alphastats/gui/utils/uniprot_utils.py diff --git a/.streamlit/config.toml b/.streamlit/config.toml new file mode 100644 index 00000000..78decd10 --- /dev/null +++ b/.streamlit/config.toml @@ -0,0 +1,26 @@ +[theme] + +# Primary accent for interactive elements +primaryColor = '#005358' + +# Background color for the main content area +backgroundColor = '#FFFFFF' + +# Background color for sidebar and most interactive widgets +secondaryBackgroundColor = '#f2f2f2' + +# Color used for almost all text +textColor = '#302E30' + +# Font family for all text in the app, except code blocks +# Accepted values (serif | sans serif | monospace) +# Default: "sans serif" +font = "sans serif" + +[server] +maxUploadSize = 500 +enableXsrfProtection = false +enableCORS = false + +[browser] +gatherUsageStats = true diff --git a/.streamlit/secrets.toml b/.streamlit/secrets.toml new file mode 100644 index 00000000..ce30c061 --- /dev/null +++ b/.streamlit/secrets.toml @@ -0,0 +1 @@ +openai_api_key = "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" \ No newline at end of file diff --git a/alphastats/DataSet_Preprocess.py b/alphastats/DataSet_Preprocess.py index 3ebac6df..9e0e92cf 100644 --- a/alphastats/DataSet_Preprocess.py +++ b/alphastats/DataSet_Preprocess.py @@ -179,22 +179,26 @@ def _linear_normalization(self, dataframe: pd.DataFrame): def _normalization(self, method: str): if method == "zscore": scaler = sklearn.preprocessing.StandardScaler() - normalized_array = scaler.fit_transform( - self.mat.values.transpose() - ).transpose() + # normalized_array = scaler.fit_transform( + # self.mat.values.transpose() + # ).transpose() + normalized_array = scaler.fit_transform(self.mat.values) elif method == "quantile": qt = sklearn.preprocessing.QuantileTransformer(random_state=0) - normalized_array = qt.fit_transform(self.mat.values.transpose()).transpose() + # normalized_array = qt.fit_transform(self.mat.values.transpose()).transpose() + normalized_array = qt.fit_transform(self.mat.values) elif method == "linear": - normalized_array = self._linear_normalization(self.mat) + normalized_array = self._linear_normalization(self.mat.transpose()).transpose() elif method == "vst": minmax = sklearn.preprocessing.MinMaxScaler() scaler = sklearn.preprocessing.PowerTransformer() - minmaxed_array = 
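Note on the normalization hunks above: scikit-learn transformers standardize column-wise, so dropping the `.transpose()` calls flips which axis is normalized. A minimal sketch of the difference, assuming `self.mat` holds samples as rows and proteins as columns:

```python
import numpy as np
from sklearn.preprocessing import StandardScaler

mat = np.array([[1.0, 10.0], [2.0, 20.0], [3.0, 30.0]])  # 3 samples x 2 proteins

# New behavior: z-score each protein (column) across samples
per_protein = StandardScaler().fit_transform(mat)

# Old behavior: z-score each sample (row) across proteins
per_sample = StandardScaler().fit_transform(mat.T).T

print(per_protein)  # each column now has mean 0 and unit variance
print(per_sample)   # each row now has mean 0 and unit variance
```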
minmax.fit_transform(self.mat.values.transpose()) - normalized_array = scaler.fit_transform(minmaxed_array).transpose() + # minmaxed_array = minmax.fit_transform(self.mat.values.transpose()) + # normalized_array = scaler.fit_transform(minmaxed_array).transpose() + minmaxed_array = minmax.fit_transform(self.mat.values) + normalized_array = scaler.fit_transform(minmaxed_array) else: raise ValueError( @@ -271,7 +275,7 @@ def batch_correction(self, batch: str): @ignore_warning(RuntimeWarning) def preprocess( self, - log2_transform: bool = True, + log2_transform: bool = False, remove_contaminations: bool = False, subset: bool = False, data_completeness: float = 0, diff --git a/alphastats/gui/.streamlit/config.toml b/alphastats/gui/.streamlit/config.toml index 117bf136..78decd10 100644 --- a/alphastats/gui/.streamlit/config.toml +++ b/alphastats/gui/.streamlit/config.toml @@ -19,6 +19,8 @@ font = "sans serif" [server] maxUploadSize = 500 +enableXsrfProtection = false +enableCORS = false [browser] gatherUsageStats = true diff --git a/alphastats/gui/pages/02_Import Data.py b/alphastats/gui/pages/02_Import Data.py index 9312d141..6112fc69 100644 --- a/alphastats/gui/pages/02_Import Data.py +++ b/alphastats/gui/pages/02_Import Data.py @@ -45,10 +45,11 @@ def load_options(): - from alphastats.gui.utils.options import plotting_options, statistic_options + from alphastats.gui.utils.options import plotting_options, statistic_options, interpretation_options st.session_state["plotting_options"] = plotting_options st.session_state["statistic_options"] = statistic_options + st.session_state["interpretation_options"] = interpretation_options def check_software_file(df, software): diff --git a/alphastats/gui/pages/05_GPT.py b/alphastats/gui/pages/05_LLM.py similarity index 74% rename from alphastats/gui/pages/05_GPT.py rename to alphastats/gui/pages/05_LLM.py index 2d541447..d324b500 100644 --- a/alphastats/gui/pages/05_GPT.py +++ b/alphastats/gui/pages/05_LLM.py @@ -20,15 +20,22 @@ from alphastats.gui.utils.gpt_helper import ( get_assistant_functions, display_proteins, - get_gene_function, - get_info, get_subgroups_for_each_group, turn_args_to_float, perform_dimensionality_reduction, + get_general_assistant_functions, + ) + from alphastats.gui.utils.uniprot_utils import ( + get_gene_function, + get_info, + ) + from alphastats.gui.utils.enrichment_analysis import get_enrichment_data + from alphastats.gui.utils.openai_utils import ( wait_for_run_completion, send_message_save_thread, try_to_set_api_key, ) + from alphastats.gui.utils.ollama_utils import LLMIntegration from alphastats.gui.utils.ui_helper import sidebar_info except ModuleNotFoundError: @@ -48,18 +55,24 @@ from utils.gpt_helper import ( get_assistant_functions, display_proteins, - get_gene_function, - get_info, get_subgroups_for_each_group, turn_args_to_float, perform_dimensionality_reduction, + get_general_assistant_functions, + ) + from utils.uniprot_utils import ( + get_gene_function, + get_info, + ) + from utils.enrichment_analysis import get_enrichment_data + from utils.openai_utils import ( wait_for_run_completion, send_message_save_thread, try_to_set_api_key, ) + from utils.ollama_utils import LLMIntegration from utils.ui_helper import sidebar_info - st.session_state.plot_dict = {} @@ -71,8 +84,8 @@ def select_analysis(): """ method = st.selectbox( "Analysis", - options=["Volcano plot"], - # options=list(st.session_state.interpretation_options.keys()), + # options=["Volcano plot"], + 
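+        # interpretation_options lives in utils/options.py and maps display names
+        # (e.g. "Volcano Plot") to the DataSet method that produces the analysis.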
options=list(st.session_state.interpretation_options.keys()), ) return method @@ -82,7 +95,7 @@ def select_analysis(): st.stop() -st.markdown("### GPT4 Analysis") +st.markdown("### LLM Analysis") sidebar_info() @@ -101,15 +114,15 @@ def select_analysis(): """ st.markdown(styl, unsafe_allow_html=True) +# Initialize session state variables +if "llm_integration" not in st.session_state: + st.session_state["llm_integration"] = None +if "api_type" not in st.session_state: + st.session_state["api_type"] = "gpt" if "plot_list" not in st.session_state: st.session_state["plot_list"] = [] - -if "openai_model" not in st.session_state: - # st.session_state["openai_model"] = "gpt-3.5-turbo-16k" - st.session_state["openai_model"] = "gpt-4-0125-preview" # "gpt-4-1106-preview" - if "messages" not in st.session_state: st.session_state["messages"] = [] @@ -131,14 +144,17 @@ def select_analysis(): with c1: method = select_analysis() chosen_parameter_dict = helper_compare_two_groups() - api_key = st.text_input("API Key", type="password") - - try_to_set_api_key(api_key) + + st.session_state["api_type"] = st.selectbox( + "Select LLM", + ["gpt4o", "llama3.1 70b"], + index=0 if st.session_state["api_type"] == "gpt4o" else 1 + ) + base_url = "http://localhost:11434/v1" + if st.session_state["api_type"] == "gpt4o": + api_key = st.text_input("Enter OpenAI API Key", type="password") + try_to_set_api_key(api_key) - try: - client = OpenAI(api_key=st.secrets["openai_api_key"]) - except OpenAIError: - pass method = st.selectbox( "Differential Analysis using:", options=["ttest", "anova", "wald", "sam", "paired-ttest", "welch-ttest"], @@ -270,9 +286,10 @@ def select_analysis(): "A user will present you with data regarding proteins upregulated in certain cells " "sourced from UniProt and abstracts from scientific publications. They seek your " "expertise in understanding the connections between these proteins and their potential role " - f"in disease genesis. {os.linesep}Provide a detailed and insightful, yet concise response based on the given information. " + f"in disease genesis. {os.linesep}Provide a detailed and insightful, yet concise response based on the given information. Use formatting to make your response more human readable." f"The data you have has following groups and respective subgroups: {str(get_subgroups_for_each_group(st.session_state.dataset.metadata))}." - "Plots are visualized using a graphical environment capable of rendering images, you don't need to worry about that." + "Plots are visualized using a graphical environment capable of rendering images, you don't need to worry about that. If the data coming to" + " you from a function has references to the literature (for example, PubMed), always quote the references in your response." 
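+        # Illustrative shape of the subgroups mapping interpolated above
+        # (hypothetical values): {"disease": ["healthy", "cancer"], "sex": ["female", "male"]}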
) if "column" in chosen_parameter_dict and "upregulated" in st.session_state: st.session_state["user_prompt"] = ( @@ -310,32 +327,55 @@ def select_analysis(): st.session_state["gpt_submitted_clicked"] > st.session_state["gpt_submitted_counter"] ): - try_to_set_api_key() - - client = OpenAI(api_key=st.secrets["openai_api_key"]) + if st.session_state["api_type"] == "gpt4o": + try_to_set_api_key() try: - st.session_state["assistant"] = client.beta.assistants.create( - instructions=st.session_state["instructions"], - name="Proteomics interpreter", - model=st.session_state["openai_model"], - tools=get_assistant_functions( - gene_to_prot_id_dict=st.session_state["gene_to_prot_id"], - metadata=st.session_state["dataset"].metadata, - subgroups_for_each_group=get_subgroups_for_each_group( - st.session_state["dataset"].metadata - ), - ), - ) + if st.session_state["api_type"] == "gpt4o": + st.session_state["llm_integration"] = LLMIntegration( + api_type='gpt', + api_key=st.secrets["openai_api_key"], + dataset=st.session_state["dataset"], + metadata=st.session_state["dataset"].metadata + ) + else: + st.session_state["llm_integration"] = LLMIntegration( + api_type='ollama', + base_url=base_url, + dataset=st.session_state["dataset"], + metadata=st.session_state["dataset"].metadata + ) + st.success(f"{st.session_state['api_type'].upper()} integration initialized successfully!") except AuthenticationError: st.warning( "Incorrect API key provided. Please enter a valid API key, it should look like this: sk-XXXXX" ) st.stop() -if "artefact_enum_dict" not in st.session_state: - st.session_state["artefact_enum_dict"] = {} +if "llm_integration" not in st.session_state or not st.session_state["llm_integration"]: + st.warning("Please initialize the model first") + st.stop() +llm = st.session_state["llm_integration"] + +# Set instructions and update tools +llm.tools = [ + *get_general_assistant_functions(), + *get_assistant_functions( + gene_to_prot_id_dict=st.session_state["gene_to_prot_id"], + metadata=st.session_state["dataset"].metadata, + subgroups_for_each_group=get_subgroups_for_each_group( + st.session_state["dataset"].metadata + ), + ) +] + +if "artifacts" not in st.session_state: + st.session_state["artifacts"] = {} +import time +start = time.time() +# 4o 23.52 +# llama 239.94 if ( st.session_state["gpt_submitted_counter"] < st.session_state["gpt_submitted_clicked"] @@ -343,30 +383,28 @@ def select_analysis(): st.session_state["gpt_submitted_counter"] = st.session_state[ "gpt_submitted_clicked" ] - st.session_state["artefact_enum_dict"] = {} - thread = client.beta.threads.create() - st.session_state["thread_id"] = thread.id - artefacts = send_message_save_thread(client, st.session_state["user_prompt"]) - if artefacts: - st.session_state["artefact_enum_dict"][len(st.session_state.messages) - 1] = ( - artefacts - ) + st.session_state["artifacts"] = {} + llm.messages = [{ + "role": "system", + "content": st.session_state["instructions"] + }] + response = llm.chat_completion(st.session_state["user_prompt"]) if st.session_state["gpt_submitted_clicked"] > 0: if prompt := st.chat_input("Say something"): - st.session_state.messages.append({"role": "user", "content": prompt}) - artefacts = send_message_save_thread(client, prompt) - if artefacts: - st.session_state["artefact_enum_dict"][ - len(st.session_state.messages) - 1 - ] = artefacts + response = llm.chat_completion(prompt) for num, role_content_dict in enumerate(st.session_state.messages): + if role_content_dict["role"] == "tool" or 
role_content_dict["role"] == "system": + continue + if "tool_calls" in role_content_dict: + continue with st.chat_message(role_content_dict["role"]): st.markdown(role_content_dict["content"]) - if num in st.session_state["artefact_enum_dict"]: - for artefact in st.session_state["artefact_enum_dict"][num]: + if num in st.session_state["artifacts"]: + for artefact in st.session_state["artifacts"][num]: if isinstance(artefact, pd.DataFrame): st.dataframe(artefact) - else: + elif "plotly" in str(type(artefact)): st.plotly_chart(artefact) - print(st.session_state["artefact_enum_dict"]) + stop = time.time() + print("time", stop-start, "\n\n\n\n") diff --git a/alphastats/gui/utils/__init__.py b/alphastats/gui/utils/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/alphastats/gui/utils/enrichment_analysis.py b/alphastats/gui/utils/enrichment_analysis.py new file mode 100644 index 00000000..c2203013 --- /dev/null +++ b/alphastats/gui/utils/enrichment_analysis.py @@ -0,0 +1,77 @@ +from typing import List +import requests + + +from gprofiler import GProfiler + +import pandas as pd + + +def get_functional_annotation_STRING(identifier, species_id="9606") -> pd.DataFrame: + """ + Get functional annotation from STRING for a gene identifier. + + Args: + identifier (str): A gene identifier. + species_id (str, optional): The Uniprot organism ID to search in. + + Returns: + pd.DataFrame: The functional annotation data. + """ + url = f"https://string-db.org/api/json/enrichment?identifiers={identifier}&species={int(species_id)}&caller_identity=alphapeptstats" + response = requests.get(url) + + if response.status_code == 200: + data = response.json() + df = pd.DataFrame(data) + return df + else: + print(f"Request failed with status code {response.status_code}") + return None + + +def get_functional_annotation_GProfiler(identifiers: List[str]) -> pd.DataFrame: + """ + Get functional annotation from g:Profiler for a list of gene identifiers. + + Args: + identifiers (list[str]): A list of gene identifiers. + + Returns: + pd.DataFrame: The functional annotation data. + """ + gp = GProfiler( + user_agent="AlphaPeptStats", + return_dataframe=True, + ) + df = gp.profile(query=identifiers) + return df + + +def get_enrichment_data( + difexpressed: List[str], organism_id: str = 9606, tool: str = "gprofiler" +) -> pd.DataFrame: + """ + Get enrichment data for a list of differentially expressed genes. + + Args: + difexpressed (list[str]): A list of differentially expressed genes. + organism_id (str, optional): The Uniprot organism ID to search in. + tool (str, optional): The tool to use for enrichment analysis. + + Returns: + pd.DataFrame: The enrichment data. 
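+
+    Example (illustrative sketch; assumes network access to the g:Profiler / STRING web APIs):
+        >>> df = get_enrichment_data(["TP53", "EGFR", "MYC"], organism_id="9606", tool="gprofiler")
+        >>> df[["source", "name", "p_value"]].head()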
+ """ + enrichment_data = {} + assert tool in [ + "gprofiler", + "string", + ], "Tool must be either 'gprofiler' or 'string'" + if tool == "gprofiler": + enrichment_data = get_functional_annotation_GProfiler(difexpressed) + else: + enrichment_data = get_functional_annotation_STRING( + "%0d".join(difexpressed), organism_id + ) + + return enrichment_data diff --git a/alphastats/gui/utils/gpt_helper.py b/alphastats/gui/utils/gpt_helper.py index 865ef554..8c09b010 100644 --- a/alphastats/gui/utils/gpt_helper.py +++ b/alphastats/gui/utils/gpt_helper.py @@ -1,16 +1,10 @@ import copy -from typing import Optional, Union, List -from pathlib import Path -import requests +from typing import Union, List, Dict -import time -import random import json from Bio import Entrez -from gprofiler import GProfiler -import openai import pandas as pd import streamlit as st @@ -19,62 +13,10 @@ Entrez.email = "lebedev_mikhail@outlook.com" # Always provide your email address when using NCBI services. -uniprot_fields = [ - # Names & Taxonomy - "gene_names", - "organism_name", - "protein_name", - # Function - "cc_function", - "cc_catalytic_activity", - "cc_activity_regulation", - "cc_pathway", - "kinetics", - "ph_dependence", - "temp_dependence", - # Interaction - "cc_interaction", - "cc_subunit", - # Expression - "cc_tissue_specificity", - "cc_developmental_stage", - "cc_induction", - # Gene Ontology (GO) - "go", - "go_p", - "go_c", - "go_f", - # Pathology & Biotech - "cc_disease", - "cc_disruption_phenotype", - "cc_pharmaceutical", - "ft_mutagen", - "ft_act_site", - # Structure - "cc_subcellular_location", - "organelle", - "absorption", - # Publications - "lit_pubmed_id", - # Family & Domains - "protein_families", - "cc_domain", - "ft_domain", - # Protein-Protein Interaction Databases - "xref_biogrid", - "xref_intact", - "xref_mint", - "xref_string", - # Chemistry Databases - "xref_drugbank", - "xref_chembl", - "reviewed", -] - def get_subgroups_for_each_group( metadata: pd.DataFrame, -) -> dict: +) -> Dict: """ Get the unique values for each column in the metadata file. @@ -101,7 +43,7 @@ def get_unique_values_from_column(column: str, metadata: pd.DataFrame) -> List[s metadata (pd.DataFrame, optional): The metadata dataframe (which sample has which disease/treatment/condition/etc). Returns: - list[str]: A list of unique values from the column. + List[str]: A list of unique values from the column. """ unique_values = metadata[column].unique().tolist() return [str(i) for i in unique_values] @@ -135,23 +77,7 @@ def display_proteins(overexpressed: List[str], underexpressed: List[str]) -> Non st.markdown(full_html, unsafe_allow_html=True) -def get_assistant_functions( - gene_to_prot_id_dict: dict, - metadata: pd.DataFrame, - subgroups_for_each_group: dict, -) -> List[dict]: - """ - Get a list of assistant functions for function calling in the ChatGPT model. - You can call this function with no arguments, arguments are given for clarity on what changes the behavior of the function. - For more information on how to format functions for Assistants, see https://platform.openai.com/docs/assistants/tools/function-calling - - Args: - gene_to_prot_id_dict (dict, optional): A dictionary with gene names as keys and protein IDs as values. - metadata (pd.DataFrame, optional): The metadata dataframe (which sample has which disease/treatment/condition/etc). - subgroups_for_each_group (dict, optional): A dictionary with the column names as keys and a list of unique values as values. Defaults to get_subgroups_for_each_group(). 
- Returns: - list[dict]: A list of assistant functions. - """ +def get_general_assistant_functions() -> List[Dict]: return [ { "type": "function", @@ -173,7 +99,55 @@ def get_assistant_functions( { "type": "function", "function": { - "name": "st.session_state.dataset.plot_intensity", + "name": "get_enrichment_data", + "description": "Get enrichment data for a list of differentially expressed genes", + "parameters": { + "type": "object", + "properties": { + "difexpressed": { + "type": "array", + "items": {"type": "string"}, + "description": "A list of differentially expressed gene names to search for", + }, + "organism_id": { + "type": "string", + "description": "The Uniprot organism ID to search in, e.g. 9606 for human", + }, + "tool": { + "type": "string", + "description": "The tool to use for enrichment analysis", + "enum": ["gprofiler", "string"], + }, + }, + "required": ["difexpressed", "organism_id"], + }, + }, + }, + ] + + +def get_assistant_functions( + gene_to_prot_id_dict: Dict, + metadata: pd.DataFrame, + subgroups_for_each_group: Dict, +) -> List[Dict]: + """ + Get a list of assistant functions for function calling in the ChatGPT model. + You can call this function with no arguments, arguments are given for clarity on what changes the behavior of the function. + For more information on how to format functions for Assistants, see https://platform.openai.com/docs/assistants/tools/function-calling + + Args: + gene_to_prot_id_dict (dict, optional): A dictionary with gene names as keys and protein IDs as values. + metadata (pd.DataFrame, optional): The metadata dataframe (which sample has which disease/treatment/condition/etc). + subgroups_for_each_group (dict, optional): A dictionary with the column names as keys and a list of unique values as values. Defaults to get_subgroups_for_each_group(). + Returns: + list[dict]: A list of assistant functions. 
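+
+    Example (illustrative; each entry follows the OpenAI tools schema):
+        >>> tools = get_assistant_functions(gene_to_prot_id_dict, metadata, subgroups_for_each_group)
+        >>> tools[0]["function"]["name"]
+        'plot_intensity'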
+ """ + return [ + { + "type": "function", + "function": { + "name": "plot_intensity", "description": "Create an intensity plot based on protein data and analytical methods.", "parameters": { "type": "object", @@ -241,7 +215,7 @@ def get_assistant_functions( { "type": "function", "function": { - "name": "st.session_state.dataset.plot_sampledistribution", + "name": "plot_sampledistribution", "description": "Generates a histogram plot for each sample in the dataset matrix.", "parameters": { "type": "object", @@ -264,7 +238,7 @@ def get_assistant_functions( { "type": "function", "function": { - "name": "st.session_state.dataset.plot_volcano", + "name": "plot_volcano", "description": "Generates a volcano plot based on two subgroups of the same group", "parameters": { "type": "object", @@ -322,34 +296,7 @@ def get_assistant_functions( }, }, }, - { - "type": "function", - "function": { - "name": "get_enrichment_data", - "description": "Get enrichment data for a list of differentially expressed genes", - "parameters": { - "type": "object", - "properties": { - "difexpressed": { - "type": "array", - "items": {"type": "string"}, - "description": "A list of differentially expressed gene names to search for", - }, - "organism_id": { - "type": "string", - "description": "The Uniprot organism ID to search in", - }, - "tool": { - "type": "string", - "description": "The tool to use for enrichment analysis", - "enum": ["gprofiler", "string"], - }, - }, - "required": ["difexpressed", "organism_id"], - }, - }, - }, - {"type": "code_interpreter"}, + # {"type": "code_interpreter"}, ] @@ -360,347 +307,7 @@ def perform_dimensionality_reduction(group, method, circle, **kwargs): return dr.plot -def get_uniprot_data( - gene_name: str, - organism_id: str, - fields: List[str] = uniprot_fields, -) -> dict: - """ - Get data from UniProt for a given gene name and organism ID. - - Args: - gene_name (str): The gene name to search for. - organism_id (str, optional): The organism ID to search in. Defaults to streamlit session state. - fields (list[str], optional): The fields to retrieve from UniProt. Defaults to uniprot_fields defined above. - - Returns: - dict: The data retrieved from UniProt. - """ - base_url = "https://rest.uniprot.org/uniprotkb/search" - query = f"(gene:{gene_name}) AND (reviewed:true) AND (organism_id:{organism_id})" - - response = requests.get( - base_url, params={"query": query, "format": "json", "fields": ",".join(fields)} - ) - - if response.status_code != 200: - print( - f"Failed to retrieve data for {gene_name}. Status code: {response.status_code}" - ) - print(response.text) - return None - - data = response.json() - - if not data.get("results"): - print(f"No UniProt entry found for {gene_name}") - return None - - # Return the first result as a dictionary (assuming it's the most relevant) - data = data["results"][0] - for key, value in data.items(): - print(f"data - {key}: {value}, {type(value)}") - return data - - -def extract_data(data: dict) -> dict: - """ - Extract relevant data from a UniProt entry. - - Args: - data (dict): The data retrieved from UniProt. - - Returns: - dict: The extracted data. - """ - extracted = {} - - # 1. Entry Type - extracted["entryType"] = data.get("entryType", None) - - # 2. Primary Accession - extracted["primaryAccession"] = data.get("primaryAccession", None) - - # 3. 
Organism Details - organism = data.get("organism", {}) - extracted["organism"] = { - "scientificName": organism.get("scientificName", None), - "commonName": organism.get("commonName", None), - "taxonId": organism.get("taxonId", None), - "lineage": organism.get("lineage", []), - } - - # 4. Protein Details - protein_description = data.get("proteinDescription", {}) - recommended_name = ( - protein_description.get("recommendedName", {}) - .get("fullName", {}) - .get("value", None) - ) - alternative_names = [ - alt_name["fullName"]["value"] - for alt_name in protein_description.get("alternativeNames", []) - ] - extracted["protein"] = { - "recommendedName": recommended_name, - "alternativeNames": alternative_names, - "flag": protein_description.get("flag", None), - } - - # 5. Gene Details - genes = data.get("genes", [{}])[0] - extracted["genes"] = { - "geneName": genes.get("geneName", {}).get("value", None), - "synonyms": [syn["value"] for syn in genes.get("synonyms", [])], - } - - # 6. Functional Comments - function_comments = [ - text["value"] - for comment in data.get("comments", []) - if comment["commentType"] == "FUNCTION" - for text in comment.get("texts", []) - ] - extracted["functionComments"] = function_comments - - # 7. Subunit Details - subunit_comments = [ - text["value"] - for comment in data.get("comments", []) - if comment["commentType"] == "SUBUNIT" - for text in comment.get("texts", []) - ] - extracted["subunitComments"] = subunit_comments - - # 8. Protein Interactions - interactions = [] - - for c in data.get("comments", []): - if c["commentType"] == "INTERACTION": - for interaction in c.get("interactions", []): - interactantOne = interaction["interactantOne"].get( - "uniProtKBAccession", None - ) - interactantTwo = interaction["interactantTwo"].get( - "uniProtKBAccession", None - ) - - # Only append if both interactants are present - if interactantOne and interactantTwo: - interactions.append( - { - "interactantOne": interactantOne, - "interactantTwo": interactantTwo, - "numberOfExperiments": interaction["numberOfExperiments"], - } - ) - extracted["interactions"] = interactions - - # 9. Subcellular Locations - subcellular_locations_comments = [ - c["subcellularLocations"] - for c in data.get("comments", []) - if c["commentType"] == "SUBCELLULAR LOCATION" - ] - locations = [ - location["location"]["value"] - for locations_comment in subcellular_locations_comments - for location in locations_comment - ] - extracted["subcellularLocations"] = locations - - tissue_specificities = [ - text["value"] - for comment in data.get("comments", []) - if comment["commentType"] == "TISSUE SPECIFICITY" - for text in comment.get("texts", []) - ] - extracted["tissueSpecificity"] = tissue_specificities - - # 11. Protein Features - features = [ - { - "type": feature["type"], - "description": feature["description"], - "location_start": feature["location"]["start"]["value"], - "location_end": feature["location"]["end"]["value"], - } - for feature in data.get("features", []) - ] - extracted["features"] = features - - # 12. References - references = [ - { - "authors": ref["citation"].get("authors", []), - "title": ref["citation"].get("title", ""), - "journal": ref["citation"].get("journal", ""), - "publicationDate": ref["citation"].get("publicationDate", ""), - "comments": [c["value"] for c in ref.get("referenceComments", [])], - } - for ref in data.get("references", []) - ] - extracted["references"] = references - - # 13. 
Cross References - cross_references = [ - { - "database": ref["database"], - "id": ref["id"], - "properties": { - prop["key"]: prop["value"] for prop in ref.get("properties", []) - }, - } - for ref in data.get("uniProtKBCrossReferences", []) - ] - extracted["crossReferences"] = cross_references - return extracted - - -def get_info(genes_list: List[str], organism_id: str) -> List[str]: - """ - Get info from UniProt for a list of genes. - - Args: - genes_list (list[str]): A list of gene names to search for. - organism_id (str, optional): The Uniprot organism ID to search in. - - Returns: - list[str]: A list of gene functions.""" - results = {} - - for gene in genes_list: - result = get_uniprot_data(gene, organism_id) - - # If result is retrieved for the gene, extract data and continue with the next gene - if result: - results[gene] = extract_data(result) - continue - - # If no result is retrieved for the gene and the gene string does not contain a ";", continue with the next gene - if ";" not in gene: - print(f"Failed to retrieve data for {gene}") - continue - - # If no result is retrieved for the gene and the gene string contains a ";", try to get data for each split gene - split_genes = gene.split(";") - for split_gene in split_genes: - result = get_uniprot_data(split_gene.strip(), organism_id) - if result: - print( - f"Successfully retrieved data for {split_gene} (from split gene: {gene})" - ) - results[gene] = extract_data(result) - break - - # If still no result after trying split genes - if not result: - print(f"Failed to retrieve data for all parts of split gene: {gene}") - # TODO: Handle this case further if necessary - - gene_functions = [] - for gene in results: - if results[gene]["functionComments"]: - gene_functions.append(f"{gene}: {results[gene]['functionComments']}") - else: - gene_functions.append(f"{gene}: ?") - - return gene_functions - - -def get_functional_annotation_STRING(identifier, species_id="9606") -> pd.DataFrame: - """ - Get functional annotation from STRING for a gene identifier. - - Args: - identifier (str): A gene identifier. - species_id (str, optional): The Uniprot organism ID to search in. - - Returns: - pd.DataFrame: The functional annotation data. - """ - url = f"https://string-db.org/api/json/enrichment?identifiers={identifier}&species={int(species_id)}&caller_identity=alphapeptstats" - response = requests.get(url) - - if response.status_code == 200: - data = response.json() - df = pd.DataFrame(data) - return df - else: - print(f"Request failed with status code {response.status_code}") - return None - - -def get_functional_annotation_GProfiler(identifiers: List[str]) -> pd.DataFrame: - """ - Get functional annotation from g:Profiler for a list of gene identifiers. - - Args: - identifiers (list[str]): A list of gene identifiers. - - Returns: - pd.DataFrame: The functional annotation data. - """ - gp = GProfiler( - user_agent="AlphaPeptStats", - return_dataframe=True, - ) - df = gp.profile(query=identifiers) - return df - - -def get_enrichment_data( - difexpressed: List[str], organism_id: str = 9606, tool: str = "gprofiler" -) -> pd.DataFrame: - """ - Get enrichment data for a list of differentially expressed genes. - - Args: - difexpressed (list[str]): A list of differentially expressed genes. - organism_id (str, optional): The Uniprot organism ID to search in. - tool (str, optional): The tool to use for enrichment analysis. - - Returns: - pd.DataFrame: The enrichment data. 
- """ - enrichment_data = {} - assert tool in [ - "gprofiler", - "string", - ], "Tool must be either 'gprofiler' or 'string'" - if tool == "gprofiler": - enrichment_data = get_functional_annotation_GProfiler(difexpressed) - else: - enrichment_data = get_functional_annotation_STRING( - "%0d".join(difexpressed), organism_id - ) - - return enrichment_data - - -def get_gene_function(gene_name: Union[str, dict], organism_id: str = None) -> str: - """ - Get the gene function and description by UniProt lookup of gene identifier / name. - - Args: - gene_name (Union[str, dict]): Gene identifier / name for UniProt lookup. - organism_id (str): The UniProt organism ID to search in. - - Returns: - str: The gene function and description. - """ - if not organism_id: - organism_id = st.session_state["organism"] - if type(gene_name) == dict: - gene_name = gene_name["gene_name"] - result = get_uniprot_data(gene_name, organism_id) - if result and extract_data(result)["functionComments"]: - return str(extract_data(result)["functionComments"]) - else: - return "No data found" - - -def turn_args_to_float(json_string: Union[str, bytes, bytearray]) -> dict: +def turn_args_to_float(json_string: Union[str, bytes, bytearray]) -> Dict: """ Turn all values in a JSON string to floats if possible. @@ -740,160 +347,3 @@ def get_gene_to_prot_id_mapping(gene_id: str) -> str: if gene_id in gene.split(";"): return prot_id return gene_id - - -def wait_for_run_completion( - client: openai.OpenAI, thread_id: int, run_id: int, check_interval: int = 2 -) -> Optional[list]: - """ - Wait for a run and function calls to complete and return the plots, if they were created by function calling. - - Args: - client (openai.OpenAI): The OpenAI client. - thread_id (int): The thread ID. - run_id (int): The run ID. - check_interval (int, optional): The interval to check for run completion. Defaults to 2 seconds. - - Returns: - Optional[list]: A list of plots, if any. 
- """ - artefacts = [] - while True: - run_status = client.beta.threads.runs.retrieve( - thread_id=thread_id, run_id=run_id - ) - assistant_functions = { - "create_intensity_plot", - "perform_dimensionality_reduction", - "create_sample_histogram", - "st.session_state.dataset.plot_volcano", - "st.session_state.dataset.plot_sampledistribution", - "st.session_state.dataset.plot_intensity", - "st.session_state.dataset.plot_pca", - "st.session_state.dataset.plot_umap", - "st.session_state.dataset.plot_tsne", - "get_enrichment_data", - } - if run_status.status == "completed": - print("Run is completed!") - if artefacts: - print("Returning artefacts") - return artefacts - break - elif run_status.status == "requires_action": - print("requires_action", run_status) - print( - [ - st.session_state.plotting_options[i]["function"].__name__ - for i in st.session_state.plotting_options - ] - ) - tool_calls = run_status.required_action.submit_tool_outputs.tool_calls - tool_outputs = [] - for tool_call in tool_calls: - print("### calling:", tool_call.function.name) - if tool_call.function.name == "get_gene_function": - print( - type(tool_call.function.arguments), tool_call.function.arguments - ) - prompt = json.loads(tool_call.function.arguments)["gene_name"] - gene_function = get_gene_function(prompt) - tool_outputs.append( - { - "tool_call_id": tool_call.id, - "output": gene_function, - }, - ) - elif ( - tool_call.function.name - in [ - st.session_state.plotting_options[i]["function"].__name__ - for i in st.session_state.plotting_options - ] - or tool_call.function.name in assistant_functions - ): - args = tool_call.function.arguments - args = turn_args_to_float(args) - print(f"{tool_call.function.name}(**{args})") - artefact = eval(f"{tool_call.function.name}(**{args})") - artefact_json = artefact.to_json() - - tool_outputs.append( - {"tool_call_id": tool_call.id, "output": artefact_json}, - ) - artefacts.append(artefact) - - if tool_outputs: - _run = client.beta.threads.runs.submit_tool_outputs( - thread_id=thread_id, run_id=run_id, tool_outputs=tool_outputs - ) - else: - print("Run is not yet completed. Waiting...", run_status.status, run_id) - time.sleep(check_interval) - - -def send_message_save_thread(client: openai.OpenAI, message: str) -> Optional[list]: - """ - Send a message to the OpenAI ChatGPT model and save the thread in the session state, return plots if GPT called a function to create them. - - Args: - client (openai.OpenAI): The OpenAI client. - message (str): The message to send to the ChatGPT model. - - Returns: - Optional[list]: A list of plots, if any. - """ - message = client.beta.threads.messages.create( - thread_id=st.session_state["thread_id"], role="user", content=message - ) - - run = client.beta.threads.runs.create( - thread_id=st.session_state["thread_id"], - assistant_id=st.session_state["assistant"].id, - ) - try: - plots = wait_for_run_completion(client, st.session_state["thread_id"], run.id) - except KeyError as e: - print(e) - plots = None - messages = client.beta.threads.messages.list( - thread_id=st.session_state["thread_id"] - ) - st.session_state.messages = [] - for num, message in enumerate(messages.data[::-1]): - role = message.role - if message.content: - content = message.content[0].text.value - else: - content = "Sorry, I was unable to process this message. Try again or change your request." 
- st.session_state.messages.append({"role": role, "content": content}) - if not plots: - return - return plots - - -def try_to_set_api_key(api_key: str = None) -> None: - """ - Checks if the OpenAI API key is available in the environment / system variables. - If the API key is not available, saves the key to secrets.toml in the repository root directory. - - Args: - api_key (str, optional): The OpenAI API key. Defaults to None. - - Returns: - None - """ - if api_key and "api_key" not in st.session_state: - st.session_state["openai_api_key"] = api_key - secret_path = Path(st.secrets._file_paths[-1]) - secret_path.parent.mkdir(parents=True, exist_ok=True) - with open(secret_path, "w") as f: - f.write(f'openai_api_key = "{api_key}"') - openai.OpenAI.api_key = api_key - return - try: - openai.OpenAI.api_key = st.secrets["openai_api_key"] - except: - st.write( - "OpenAI API key not found in environment variables. Please enter your API key to continue." - ) diff --git a/alphastats/gui/utils/ollama_utils.py b/alphastats/gui/utils/ollama_utils.py new file mode 100644 index 00000000..3935db06 --- /dev/null +++ b/alphastats/gui/utils/ollama_utils.py @@ -0,0 +1,365 @@ +from openai import OpenAI +from typing import List, Dict, Any, Optional, Tuple +import json +import streamlit as st +import pandas as pd +from IPython.display import display, Markdown, HTML +import plotly.io as pio +from alphastats.gui.utils.gpt_helper import ( + get_assistant_functions, + get_subgroups_for_each_group, + turn_args_to_float, + perform_dimensionality_reduction, + get_general_assistant_functions, +) +from alphastats.gui.utils.artefacts import ArtifactManager +from alphastats.gui.utils.uniprot_utils import get_gene_function +from alphastats.gui.utils.enrichment_analysis import get_enrichment_data + +class LLMIntegration: + """ + A class to integrate different Language Model APIs and handle chat interactions. + + This class provides methods to interact with GPT and Ollama APIs, manage conversation + history, handle function calls, and manage artifacts. 
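+
+    Example (illustrative; assumes a local Ollama server on http://localhost:11434
+    and ``ds`` as a placeholder for a loaded DataSet):
+        >>> llm = LLMIntegration(api_type="ollama", dataset=ds, metadata=ds.metadata)
+        >>> reply, artifacts = llm.chat_completion("Which proteins separate the two groups?")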
+ + Parameters + ---------- + api_type : str, optional + The type of API to use ('gpt' or 'ollama'), by default 'gpt' + base_url : str, optional + The base URL for the API, by default None + api_key : str, optional + The API key for authentication, by default None + dataset : Any, optional + The dataset to be used in the conversation, by default None + metadata : Any, optional + Metadata associated with the dataset, by default None + + Attributes + ---------- + api_type : str + The type of API being used + client : OpenAI + The OpenAI client instance + model : str + The name of the language model being used + messages : List[Dict[str, Any]] + The conversation history + dataset : Any + The dataset being used + metadata : Any + Metadata associated with the dataset + tools : List[Dict[str, Any]] + List of available tools or functions + artifacts : Dict[str, Any] + Dictionary to store conversation artifacts + """ + + def __init__(self, api_type: str = 'gpt', base_url: Optional[str] = None, api_key: Optional[str] = None, dataset=None, metadata=None): + self.api_type = api_type + if api_type == 'ollama': + self.client = OpenAI(base_url=base_url or 'http://localhost:11434/v1', api_key='ollama') + self.model = "llama3.1:70b" + else: + self.client = OpenAI(api_key=api_key) + # self.model = "gpt-4-0125-preview" + self.model = "gpt-4o" + + self.messages = [] + self.dataset = dataset + self.metadata = metadata + self.tools = self._get_tools() + self.artifacts = {} + self.artifact_manager = ArtifactManager() + self.message_artifact_map = {} + + def set_api_key(self, api_key: str): + """ + Set the API key for GPT API. + + Parameters + ---------- + api_key : str + The API key to be set + + Returns + ------- + None + """ + if self.api_type == 'gpt': + self.client.api_key = api_key + st.secrets["openai_api_key"] = api_key + + def _get_tools(self) -> List[Dict[str, Any]]: + """ + Get the list of available tools or functions. + + Returns + ------- + List[Dict[str, Any]] + A list of dictionaries describing the available tools + """ + general_tools = get_general_assistant_functions() + return general_tools + + def truncate_conversation_history(self, max_tokens: int = 100000): + """ + Truncate the conversation history to stay within token limits. + + Parameters + ---------- + max_tokens : int, optional + The maximum number of tokens to keep in history, by default 100000 + + Returns + ------- + None + """ + total_tokens = sum(len(m['content'].split()) for m in self.messages) + while total_tokens > max_tokens and len(self.messages) > 1: + removed_message = self.messages.pop(0) + total_tokens -= len(removed_message['content'].split()) + + def update_session_state(self): + """ + Update the Streamlit session state with current conversation data. + + Returns + ------- + None + """ + st.session_state['messages'] = self.messages + st.session_state['artifacts'] = self.artifacts + + def parse_model_response(self, response: Any) -> Dict[str, Any]: + """ + Parse the response from the language model. + + Parameters + ---------- + response : Any + The raw response from the language model + + Returns + ------- + Dict[str, Any] + A dictionary containing the parsed content and tool calls + """ + return { + 'content': response.choices[0].message.content, + 'tool_calls': response.choices[0].message.tool_calls + } + + def execute_function(self, function_name: str, function_args: Dict[str, Any]) -> Any: + """ + Execute a function based on its name and arguments. 
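+
+        Dispatch is name-based: ``get_gene_function``, ``get_enrichment_data`` and
+        ``perform_dimensionality_reduction`` are called directly, while any other
+        ``plot_*`` or ``perform_*`` name is resolved as a method on ``self.dataset``,
+        e.g. ``execute_function("plot_sampledistribution", {"method": "box"})`` calls
+        ``self.dataset.plot_sampledistribution(method="box")``.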
+ + Parameters + ---------- + function_name : str + The name of the function to execute + function_args : Dict[str, Any] + The arguments to pass to the function + + Returns + ------- + Any + The result of the function execution + + Raises + ------ + ValueError + If the function is not implemented or the dataset is not available + """ + try: + if function_name == "get_gene_function": + return get_gene_function(**function_args) + elif function_name == "get_enrichment_data": + return get_enrichment_data(**function_args) + elif function_name == "perform_dimensionality_reduction": + return perform_dimensionality_reduction(**function_args) + elif function_name.startswith("plot_") or function_name.startswith("perform_"): + plot_function = getattr(self.dataset, function_name.split('.')[-1], None) + if plot_function: + return plot_function(**function_args) + raise ValueError(f"Function {function_name} not implemented or dataset not available") + except Exception as e: + return f"Error executing {function_name}: {str(e)}" + + def handle_function_calls(self, tool_calls: List[Any], ) -> Dict[str, Any]: + """ + Handle function calls from the language model and manage resulting artifacts. + + Parameters + ---------- + tool_calls : List[Any] + List of tool calls from the language model + + Returns + ------- + Dict[str, Any] + The parsed response after handling function calls, including any new artifacts + + """ + function_messages = [] + new_artifacts = {} + print(len(tool_calls)) + funcs_and_args = '\n'.join([f"Calling function: {tool_call.function.name} with arguments: {tool_call.function.arguments}" for tool_call in tool_calls]) + self.messages.append({ + "role": "assistant", + "content": funcs_and_args, + "tool_calls": tool_calls + }) + + for tool_call in tool_calls: + print(tool_call.id) + function_name = tool_call.function.name + print(f"Calling function: {function_name}") + function_args = json.loads(tool_call.function.arguments) + + function_result = self.execute_function(function_name, function_args) + artifact_id = f"{function_name}_{tool_call.id}" + + new_artifacts[artifact_id] = function_result + + self.messages.append({ + "role": "tool", + "content": json.dumps({"result": str(function_result), "artifact_id": artifact_id}), + "tool_call_id": tool_call.id + }) + post_artefact_message_idx = len(self.messages) + self.artifacts[post_artefact_message_idx] = new_artifacts.values() + response = self.client.chat.completions.create( + model=self.model, + messages=self.messages, + tools=self.tools, + ) + parsed_response = self.parse_model_response(response) + parsed_response['new_artifacts'] = new_artifacts + + return parsed_response + + def chat_completion(self, prompt: str, role: str = "user") -> Tuple[str, Dict[str, Any]]: + """ + Generate a chat completion based on the given prompt and manage any resulting artifacts. 
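+
+        Example (illustrative):
+            >>> text, new_artifacts = llm.chat_completion("Plot the sample distribution as boxplots")
+            >>> list(new_artifacts)  # artifact ids of the form "<function_name>_<tool_call_id>"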
+ + Parameters + ---------- + prompt : str + The user's input prompt + role : str, optional + The role of the message sender, by default "user" + + Returns + ------- + Tuple[str, Dict[str, Any]] + A tuple containing the generated response and a dictionary of new artifacts + + Raises + ------ + ArithmeticError + If there's an error in chat completion + """ + self.messages.append({"role": role, "content": prompt}) + self.truncate_conversation_history() + + try: + response = self.client.chat.completions.create( + model=self.model, + messages=self.messages, + tools=self.tools, + ) + + parsed_response = self.parse_model_response(response) + new_artifacts = {} + + if parsed_response['tool_calls']: + parsed_response = self.handle_function_calls(parsed_response['tool_calls']) + new_artifacts = parsed_response.pop('new_artifacts', {}) + + self.messages.append({"role": "assistant", "content": parsed_response['content']}) + self.update_session_state() + return parsed_response['content'], new_artifacts + + except ArithmeticError as e: + error_message = f"Error in chat completion: {str(e)}" + self.messages.append({"role": "system", "content": error_message}) + self.update_session_state() + return error_message, {} + + def switch_backend(self, new_api_type: str, base_url: Optional[str] = None, api_key: Optional[str] = None): + """ + Switch between different API backends. + + Parameters + ---------- + new_api_type : str + The new API type to switch to ('gpt' or 'ollama') + base_url : str, optional + The base URL for the new API, by default None + api_key : str, optional + The API key for the new API, by default None + + Returns + ------- + None + """ + self.__init__(api_type=new_api_type, base_url=base_url, api_key=api_key, dataset=self.dataset, metadata=self.metadata) + + + def display_chat_history(self): + """ + Display the chat history, including messages, function calls, and associated artifacts. + + This method renders the chat history in a structured format, aligning artifacts + with their corresponding messages and the model's interpretation. + + Returns + ------- + None + """ + for i, message in enumerate(self.messages): + role = message['role'].capitalize() + content = message['content'] + + if role == 'Assistant' and 'tool_calls' in message: + display(Markdown(f"**{role}**: {content}")) + for tool_call in message['tool_calls']: + function_name = tool_call.function.name + function_args = tool_call.function.arguments + display(Markdown(f"*Function Call*: `{function_name}`")) + display(Markdown(f"*Arguments*: ```json\n{function_args}\n```")) + + elif role == 'Tool': + tool_result = json.loads(content) + artifact_id = tool_result.get('artifact_id') + if artifact_id and artifact_id in self.artifacts: + artifact = self.artifacts[artifact_id] + display(Markdown(f"**Function Result** (Artifact ID: {artifact_id}):")) + self._display_artifact(artifact) + else: + display(Markdown(f"**Function Result**: {content}")) + + else: + display(Markdown(f"**{role}**: {content}")) + + def _display_artifact(self, artifact): + """ + Display an artifact based on its type. 
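+
+        Example (illustrative):
+            >>> llm._display_artifact(pd.DataFrame({"gene": ["TP53"], "log2fc": [1.8]}))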
+
+        Parameters
+        ----------
+        artifact : Any
+            The artifact to display
+
+        Returns
+        -------
+        None
+        """
+        if isinstance(artifact, pd.DataFrame):
+            display(artifact)
+        elif "plotly" in str(type(artifact)):
+            display(HTML(pio.to_html(artifact, full_html=False)))
+        else:
+            display(Markdown(f"```\n{str(artifact)}\n```"))
diff --git a/alphastats/gui/utils/openai_utils.py b/alphastats/gui/utils/openai_utils.py
new file mode 100644
index 00000000..7a4bd255
--- /dev/null
+++ b/alphastats/gui/utils/openai_utils.py
@@ -0,0 +1,205 @@
+from typing import Optional, List
+from pathlib import Path
+
+import time
+import json
+
+
+import openai
+import streamlit as st
+
+try:
+    from alphastats.gui.utils.gpt_helper import (
+        turn_args_to_float,
+        get_assistant_functions,
+        display_proteins,
+        get_subgroups_for_each_group,
+        perform_dimensionality_reduction,
+        get_general_assistant_functions,
+    )
+    from alphastats.gui.utils.uniprot_utils import (
+        get_gene_function,
+    )
+    from alphastats.gui.utils.enrichment_analysis import get_enrichment_data
+
+except ModuleNotFoundError:
+    from utils.gpt_helper import (
+        turn_args_to_float,
+        get_assistant_functions,
+        display_proteins,
+        get_subgroups_for_each_group,
+        perform_dimensionality_reduction,
+        get_general_assistant_functions,
+    )
+    from utils.uniprot_utils import (
+        get_gene_function,
+    )
+    from utils.enrichment_analysis import get_enrichment_data
+
+
+def wait_for_run_completion(
+    client: openai.OpenAI, thread_id: str, run_id: str, check_interval: int = 2
+) -> Optional[List]:
+    """
+    Wait for a run and function calls to complete and return the plots, if they were created by function calling.
+
+    Args:
+        client (openai.OpenAI): The OpenAI client.
+        thread_id (str): The thread ID.
+        run_id (str): The run ID.
+        check_interval (int, optional): The interval to check for run completion. Defaults to 2 seconds.
+
+    Returns:
+        Optional[list]: A list of plots, if any.
+ """ + artefacts = [] + while True: + run_status = client.beta.threads.runs.retrieve( + thread_id=thread_id, run_id=run_id + ) + print(run_status.status, run_id, run_status.required_action) + assistant_functions = { + "create_intensity_plot", + "perform_dimensionality_reduction", + "create_sample_histogram", + "st.session_state.dataset.plot_volcano", + "st.session_state.dataset.plot_sampledistribution", + "st.session_state.dataset.plot_intensity", + "st.session_state.dataset.plot_pca", + "st.session_state.dataset.plot_umap", + "st.session_state.dataset.plot_tsne", + "get_enrichment_data", + } + if run_status.status == "completed": + print("Run is completed!") + if artefacts: + print("Returning artefacts") + return artefacts + break + elif run_status.status == "requires_action": + print("requires_action", run_status) + print( + [ + st.session_state.plotting_options[i]["function"].__name__ + for i in st.session_state.plotting_options + ] + ) + tool_calls = run_status.required_action.submit_tool_outputs.tool_calls + tool_outputs = [] + for tool_call in tool_calls: + print("### calling:", tool_call.function.name) + if tool_call.function.name == "get_gene_function": + print( + type(tool_call.function.arguments), tool_call.function.arguments + ) + prompt = json.loads(tool_call.function.arguments)["gene_name"] + gene_function = get_gene_function(prompt) + tool_outputs.append( + { + "tool_call_id": tool_call.id, + "output": gene_function, + }, + ) + elif ( + tool_call.function.name + in [ + st.session_state.plotting_options[i]["function"].__name__ + for i in st.session_state.plotting_options + ] + or tool_call.function.name in assistant_functions + ): + args = tool_call.function.arguments + args = turn_args_to_float(args) + print(f"{tool_call.function.name}(**{args})") + artefact = eval(f"{tool_call.function.name}(**{args})") + artefact_json = artefact.to_json() + + tool_outputs.append( + {"tool_call_id": tool_call.id, "output": artefact_json}, + ) + artefacts.append(artefact) + + if tool_outputs: + _run = client.beta.threads.runs.submit_tool_outputs( + thread_id=thread_id, run_id=run_id, tool_outputs=tool_outputs + ) + print("submitted") + else: + print("Run is not yet completed. Waiting...", run_status.status, run_id) + time.sleep(check_interval) + + +def send_message_save_thread( + client: openai.OpenAI, + message: str, + assistant_id: str, + thread_id: str, + storing_variable: str = "messages", +) -> Optional[List]: + """ + Send a message to the OpenAI ChatGPT model and save the thread in the session state, return plots if GPT called a function to create them. + + Args: + client (openai.OpenAI): The OpenAI client. + message (str): The message to send to the ChatGPT model. + + Returns: + Optional[list]: A list of plots, if any. + """ + message = client.beta.threads.messages.create( + thread_id=thread_id, role="user", content=message + ) + + run = client.beta.threads.runs.create( + thread_id=thread_id, + assistant_id=assistant_id, + ) + try: + plots = wait_for_run_completion(client, thread_id, run.id) + except KeyError as e: + print(e) + plots = None + messages = client.beta.threads.messages.list(thread_id=thread_id) + st.session_state[storing_variable] = [] + for num, message in enumerate(messages.data[::-1]): + role = message.role + if message.content: + content = message.content[0].text.value + else: + content = "Sorry, I was unable to process this message. Try again or change your request." 
+ st.session_state[storing_variable].append({"role": role, "content": content}) + if not plots: + return + return plots + + +def try_to_set_api_key(api_key: str = None) -> None: + """ + Checks if the OpenAI API key is available in the environment / system variables. + If the API key is not available, saves the key to secrets.toml in the repository root directory. + + Args: + api_key (str, optional): The OpenAI API key. Defaults to None. + + Returns: + None + """ + if api_key and "api_key" not in st.session_state: + st.session_state["openai_api_key"] = api_key + secret_path = Path(st.secrets._file_paths[-1]) + secret_path.parent.mkdir(parents=True, exist_ok=True) + with open(secret_path, "w") as f: + f.write(f'openai_api_key = "{api_key}"') + openai.OpenAI.api_key = api_key + return + try: + openai.OpenAI.api_key = st.secrets["openai_api_key"] + except: + st.write( + "OpenAI API key not found in environment variables. Please enter your API key to continue." + ) diff --git a/alphastats/gui/utils/options.py b/alphastats/gui/utils/options.py index 9d3c09c4..a4944067 100644 --- a/alphastats/gui/utils/options.py +++ b/alphastats/gui/utils/options.py @@ -64,7 +64,7 @@ "function": st.session_state.dataset.plot_volcano, }, "Clustermap": {"function": st.session_state.dataset.plot_clustermap}, - "Dendrogram": {"function": st.session_state.dataset.plot_dendrogram}, + # "Dendrogram": {"function": st.session_state.dataset.plot_dendrogram}, } statistic_options = { @@ -121,3 +121,17 @@ "function": st.session_state.dataset.ancova, }, } + +interpretation_options = { + "Volcano Plot": { + "between_two_groups": True, + "function": st.session_state.dataset.plot_volcano, + }, + "Differential Expression Analysis - T-test": { + "between_two_groups": True, + "function": st.session_state.dataset.diff_expression_analysis, + }, + "Differential Expression Analysis - Wald-test": { + "between_two_groups": True, + "function": st.session_state.dataset.diff_expression_analysis, + },} diff --git a/alphastats/gui/utils/uniprot_utils.py b/alphastats/gui/utils/uniprot_utils.py new file mode 100644 index 00000000..7d3e2dc4 --- /dev/null +++ b/alphastats/gui/utils/uniprot_utils.py @@ -0,0 +1,328 @@ +from typing import Union, List, Dict +import requests + + +import streamlit as st + + +uniprot_fields = [ + # Names & Taxonomy + "gene_names", + "organism_name", + "protein_name", + # Function + "cc_function", + "cc_catalytic_activity", + "cc_activity_regulation", + "cc_pathway", + "kinetics", + "ph_dependence", + "temp_dependence", + # Interaction + "cc_interaction", + "cc_subunit", + # Expression + "cc_tissue_specificity", + "cc_developmental_stage", + "cc_induction", + # Gene Ontology (GO) + "go", + "go_p", + "go_c", + "go_f", + # Pathology & Biotech + "cc_disease", + "cc_disruption_phenotype", + "cc_pharmaceutical", + "ft_mutagen", + "ft_act_site", + # Structure + "cc_subcellular_location", + "organelle", + "absorption", + # Publications + "lit_pubmed_id", + # Family & Domains + "protein_families", + "cc_domain", + "ft_domain", + # Protein-Protein Interaction Databases + "xref_biogrid", + "xref_intact", + "xref_mint", + "xref_string", + # Chemistry Databases + "xref_drugbank", + "xref_chembl", + "reviewed", +] + + +def get_uniprot_data( + gene_name: str, + organism_id: str, + fields: List[str] = uniprot_fields, +) -> Dict: + """ + Get data from UniProt for a given gene name and organism ID. + + Args: + gene_name (str): The gene name to search for. + organism_id (str, optional): The organism ID to search in. 
Defaults to streamlit session state. + fields (list[str], optional): The fields to retrieve from UniProt. Defaults to uniprot_fields defined above. + + Returns: + dict: The data retrieved from UniProt. + """ + base_url = "https://rest.uniprot.org/uniprotkb/search" + query = f"(gene:{gene_name}) AND (reviewed:true) AND (organism_id:{organism_id})" + + response = requests.get( + base_url, params={"query": query, "format": "json", "fields": ",".join(fields)} + ) + + if response.status_code != 200: + print( + f"Failed to retrieve data for {gene_name}. Status code: {response.status_code}" + ) + print(response.text) + return None + + data = response.json() + + if not data.get("results"): + print(f"No UniProt entry found for {gene_name}") + return None + + # Return the first result as a dictionary (assuming it's the most relevant) + data = data["results"][0] + # for key, value in data.items(): + # print(f"data - {key}: {value}, {type(value)}") + return data + + +def extract_data(data: Dict) -> Dict: + """ + Extract relevant data from a UniProt entry. + + Args: + data (dict): The data retrieved from UniProt. + + Returns: + dict: The extracted data. + """ + extracted = {} + + # 1. Entry Type + extracted["entryType"] = data.get("entryType", None) + + # 2. Primary Accession + extracted["primaryAccession"] = data.get("primaryAccession", None) + + # 3. Organism Details + organism = data.get("organism", {}) + extracted["organism"] = { + "scientificName": organism.get("scientificName", None), + "commonName": organism.get("commonName", None), + "taxonId": organism.get("taxonId", None), + "lineage": organism.get("lineage", []), + } + + # 4. Protein Details + protein_description = data.get("proteinDescription", {}) + recommended_name = ( + protein_description.get("recommendedName", {}) + .get("fullName", {}) + .get("value", None) + ) + alternative_names = [ + alt_name["fullName"]["value"] + for alt_name in protein_description.get("alternativeNames", []) + ] + extracted["protein"] = { + "recommendedName": recommended_name, + "alternativeNames": alternative_names, + "flag": protein_description.get("flag", None), + } + + # 5. Gene Details + genes = data.get("genes", [{}])[0] + extracted["genes"] = { + "geneName": genes.get("geneName", {}).get("value", None), + "synonyms": [syn["value"] for syn in genes.get("synonyms", [])], + } + + # 6. Functional Comments + function_comments = [ + text["value"] + for comment in data.get("comments", []) + if comment["commentType"] == "FUNCTION" + for text in comment.get("texts", []) + ] + extracted["functionComments"] = function_comments + + # 7. Subunit Details + subunit_comments = [ + text["value"] + for comment in data.get("comments", []) + if comment["commentType"] == "SUBUNIT" + for text in comment.get("texts", []) + ] + extracted["subunitComments"] = subunit_comments + + # 8. Protein Interactions + interactions = [] + + for c in data.get("comments", []): + if c["commentType"] == "INTERACTION": + for interaction in c.get("interactions", []): + interactantOne = interaction["interactantOne"].get( + "uniProtKBAccession", None + ) + interactantTwo = interaction["interactantTwo"].get( + "uniProtKBAccession", None + ) + + # Only append if both interactants are present + if interactantOne and interactantTwo: + interactions.append( + { + "interactantOne": interactantOne, + "interactantTwo": interactantTwo, + "numberOfExperiments": interaction["numberOfExperiments"], + } + ) + extracted["interactions"] = interactions + + # 9. 
Subcellular Locations + subcellular_locations_comments = [ + c["subcellularLocations"] + for c in data.get("comments", []) + if c["commentType"] == "SUBCELLULAR LOCATION" + ] + locations = [ + location["location"]["value"] + for locations_comment in subcellular_locations_comments + for location in locations_comment + ] + extracted["subcellularLocations"] = locations + + tissue_specificities = [ + text["value"] + for comment in data.get("comments", []) + if comment["commentType"] == "TISSUE SPECIFICITY" + for text in comment.get("texts", []) + ] + extracted["tissueSpecificity"] = tissue_specificities + + # 11. Protein Features + features = [ + { + "type": feature["type"], + "description": feature["description"], + "location_start": feature["location"]["start"]["value"], + "location_end": feature["location"]["end"]["value"], + } + for feature in data.get("features", []) + ] + extracted["features"] = features + + # 12. References + references = [ + { + "authors": ref["citation"].get("authors", []), + "title": ref["citation"].get("title", ""), + "journal": ref["citation"].get("journal", ""), + "publicationDate": ref["citation"].get("publicationDate", ""), + "comments": [c["value"] for c in ref.get("referenceComments", [])], + } + for ref in data.get("references", []) + ] + extracted["references"] = references + + # 13. Cross References + cross_references = [ + { + "database": ref["database"], + "id": ref["id"], + "properties": { + prop["key"]: prop["value"] for prop in ref.get("properties", []) + }, + } + for ref in data.get("uniProtKBCrossReferences", []) + ] + extracted["crossReferences"] = cross_references + return extracted + + +def get_info(genes_list: List[str], organism_id: str) -> List[str]: + """ + Get info from UniProt for a list of genes. + + Args: + genes_list (list[str]): A list of gene names to search for. + organism_id (str, optional): The Uniprot organism ID to search in. + + Returns: + list[str]: A list of gene functions.""" + results = {} + + for gene in genes_list: + result = get_uniprot_data(gene, organism_id) + + # If result is retrieved for the gene, extract data and continue with the next gene + if result: + results[gene] = extract_data(result) + continue + + # If no result is retrieved for the gene and the gene string does not contain a ";", continue with the next gene + if ";" not in gene: + print(f"Failed to retrieve data for {gene}") + continue + + # If no result is retrieved for the gene and the gene string contains a ";", try to get data for each split gene + split_genes = gene.split(";") + for split_gene in split_genes: + result = get_uniprot_data(split_gene.strip(), organism_id) + if result: + print( + f"Successfully retrieved data for {split_gene} (from split gene: {gene})" + ) + results[gene] = extract_data(result) + break + + # If still no result after trying split genes + if not result: + print(f"Failed to retrieve data for all parts of split gene: {gene}") + # TODO: Handle this case further if necessary + + gene_functions = [] + for gene in results: + if results[gene]["functionComments"]: + gene_functions.append(f"{gene}: {results[gene]['functionComments']}") + else: + gene_functions.append(f"{gene}: ?") + + return gene_functions + + +def get_gene_function(gene_name: Union[str, Dict], organism_id=9606) -> str: + """ + Get the gene function and description by UniProt lookup of gene identifier / name. + + Args: + gene_name (Union[str, dict]): Gene identifier / name for UniProt lookup. + organism_id (str): The UniProt organism ID to search in. 
+ + Returns: + str: The gene function and description. + """ + if "organism" in st.session_state: + organism_id = st.session_state["organism"] + if type(gene_name) == dict: + gene_name = gene_name["gene_name"] + result = get_uniprot_data(gene_name, organism_id) + if result and extract_data(result)["functionComments"]: + return str(extract_data(result)["functionComments"]) + else: + return "No data found" From 8e9b4048a56d9330f2e7d366d59138de0bc801cb Mon Sep 17 00:00:00 2001 From: Mikhail Lebedev Date: Wed, 28 Aug 2024 22:57:47 +0200 Subject: [PATCH 02/13] ruff --- alphastats/gui/pages/02_Import Data.py | 6 +- alphastats/gui/pages/05_LLM.py | 74 +++--------- alphastats/gui/utils/ollama_utils.py | 154 ++++++++++++++++--------- alphastats/gui/utils/openai_utils.py | 18 +-- alphastats/gui/utils/options.py | 3 +- alphastats/gui/utils/uniprot_utils.py | 2 +- 6 files changed, 126 insertions(+), 131 deletions(-) diff --git a/alphastats/gui/pages/02_Import Data.py b/alphastats/gui/pages/02_Import Data.py index 6112fc69..ba251699 100644 --- a/alphastats/gui/pages/02_Import Data.py +++ b/alphastats/gui/pages/02_Import Data.py @@ -45,7 +45,11 @@ def load_options(): - from alphastats.gui.utils.options import plotting_options, statistic_options, interpretation_options + from alphastats.gui.utils.options import ( + plotting_options, + statistic_options, + interpretation_options, + ) st.session_state["plotting_options"] = plotting_options st.session_state["statistic_options"] = statistic_options diff --git a/alphastats/gui/pages/05_LLM.py b/alphastats/gui/pages/05_LLM.py index d324b500..fde8f1fc 100644 --- a/alphastats/gui/pages/05_LLM.py +++ b/alphastats/gui/pages/05_LLM.py @@ -1,18 +1,12 @@ import os import streamlit as st import pandas as pd -from openai import OpenAI, OpenAIError, AuthenticationError +from openai import AuthenticationError try: from alphastats.gui.utils.analysis_helper import ( check_if_options_are_loaded, - convert_df, - display_df, display_figure, - download_figure, - download_preprocessing_info, - get_analysis, - load_options, save_plot_to_session_state, gui_volcano_plot_differential_expression_analysis, helper_compare_two_groups, @@ -21,18 +15,9 @@ get_assistant_functions, display_proteins, get_subgroups_for_each_group, - turn_args_to_float, - perform_dimensionality_reduction, get_general_assistant_functions, ) - from alphastats.gui.utils.uniprot_utils import ( - get_gene_function, - get_info, - ) - from alphastats.gui.utils.enrichment_analysis import get_enrichment_data from alphastats.gui.utils.openai_utils import ( - wait_for_run_completion, - send_message_save_thread, try_to_set_api_key, ) from alphastats.gui.utils.ollama_utils import LLMIntegration @@ -41,13 +26,7 @@ except ModuleNotFoundError: from utils.analysis_helper import ( check_if_options_are_loaded, - convert_df, - display_df, display_figure, - download_figure, - download_preprocessing_info, - get_analysis, - load_options, save_plot_to_session_state, gui_volcano_plot_differential_expression_analysis, helper_compare_two_groups, @@ -56,18 +35,9 @@ get_assistant_functions, display_proteins, get_subgroups_for_each_group, - turn_args_to_float, - perform_dimensionality_reduction, get_general_assistant_functions, ) - from utils.uniprot_utils import ( - get_gene_function, - get_info, - ) - from utils.enrichment_analysis import get_enrichment_data from utils.openai_utils import ( - wait_for_run_completion, - send_message_save_thread, try_to_set_api_key, ) from utils.ollama_utils import LLMIntegration @@ -101,15 +71,15 @@ def 
select_analysis(): # set background to white so downloaded pngs dont have grey background -styl = f""" +styl = """ """ st.markdown(styl, unsafe_allow_html=True) @@ -144,11 +114,11 @@ def select_analysis(): with c1: method = select_analysis() chosen_parameter_dict = helper_compare_two_groups() - + st.session_state["api_type"] = st.selectbox( - "Select LLM", - ["gpt4o", "llama3.1 70b"], - index=0 if st.session_state["api_type"] == "gpt4o" else 1 + "Select LLM", + ["gpt4o", "llama3.1 70b"], + index=0 if st.session_state["api_type"] == "gpt4o" else 1, ) base_url = "http://localhost:11434/v1" if st.session_state["api_type"] == "gpt4o": @@ -333,19 +303,21 @@ def select_analysis(): try: if st.session_state["api_type"] == "gpt4o": st.session_state["llm_integration"] = LLMIntegration( - api_type='gpt', + api_type="gpt", api_key=st.secrets["openai_api_key"], dataset=st.session_state["dataset"], - metadata=st.session_state["dataset"].metadata + metadata=st.session_state["dataset"].metadata, ) else: st.session_state["llm_integration"] = LLMIntegration( - api_type='ollama', + api_type="ollama", base_url=base_url, dataset=st.session_state["dataset"], - metadata=st.session_state["dataset"].metadata + metadata=st.session_state["dataset"].metadata, ) - st.success(f"{st.session_state['api_type'].upper()} integration initialized successfully!") + st.success( + f"{st.session_state['api_type'].upper()} integration initialized successfully!" + ) except AuthenticationError: st.warning( "Incorrect API key provided. Please enter a valid API key, it should look like this: sk-XXXXX" @@ -367,15 +339,12 @@ def select_analysis(): subgroups_for_each_group=get_subgroups_for_each_group( st.session_state["dataset"].metadata ), - ) + ), ] if "artifacts" not in st.session_state: st.session_state["artifacts"] = {} -import time -start = time.time() -# 4o 23.52 -# llama 239.94 + if ( st.session_state["gpt_submitted_counter"] < st.session_state["gpt_submitted_clicked"] @@ -384,10 +353,7 @@ def select_analysis(): "gpt_submitted_clicked" ] st.session_state["artifacts"] = {} - llm.messages = [{ - "role": "system", - "content": st.session_state["instructions"] - }] + llm.messages = [{"role": "system", "content": st.session_state["instructions"]}] response = llm.chat_completion(st.session_state["user_prompt"]) if st.session_state["gpt_submitted_clicked"] > 0: @@ -404,7 +370,5 @@ def select_analysis(): for artefact in st.session_state["artifacts"][num]: if isinstance(artefact, pd.DataFrame): st.dataframe(artefact) - elif "plotly" in str(type(artefact)): + elif "plotly" in str(type(artefact)): st.plotly_chart(artefact) - stop = time.time() - print("time", stop-start, "\n\n\n\n") diff --git a/alphastats/gui/utils/ollama_utils.py b/alphastats/gui/utils/ollama_utils.py index 3935db06..8f0ca124 100644 --- a/alphastats/gui/utils/ollama_utils.py +++ b/alphastats/gui/utils/ollama_utils.py @@ -6,9 +6,6 @@ from IPython.display import display, Markdown, HTML import plotly.io as pio from alphastats.gui.utils.gpt_helper import ( - get_assistant_functions, - get_subgroups_for_each_group, - turn_args_to_float, perform_dimensionality_reduction, get_general_assistant_functions, ) @@ -16,6 +13,7 @@ from alphastats.gui.utils.uniprot_utils import get_gene_function from alphastats.gui.utils.enrichment_analysis import get_enrichment_data + class LLMIntegration: """ A class to integrate different Language Model APIs and handle chat interactions. 
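For orientation between hunks: a minimal usage sketch of the LLMIntegration class introduced above, based on the constructor arguments this diff adds. Illustrative only; `my_dataset` stands in for a loaded alphastats DataSet, and the base URL and model names mirror the defaults hard-coded in the patch.

    # Illustrative sketch, not part of the patch.
    from alphastats.gui.utils.ollama_utils import LLMIntegration

    # Local model served by Ollama (OpenAI-compatible endpoint, default port).
    local_llm = LLMIntegration(
        api_type="ollama",
        base_url="http://localhost:11434/v1",
        dataset=my_dataset,            # assumed: an existing DataSet
        metadata=my_dataset.metadata,
    )  # selects the hard-coded "llama3.1:70b" model

    # Hosted OpenAI model.
    hosted_llm = LLMIntegration(
        api_type="gpt",
        api_key="sk-...",              # placeholder, never commit a real key
        dataset=my_dataset,
        metadata=my_dataset.metadata,
    )  # selects the hard-coded "gpt-4o" model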
@@ -56,16 +54,25 @@ class LLMIntegration: Dictionary to store conversation artifacts """ - def __init__(self, api_type: str = 'gpt', base_url: Optional[str] = None, api_key: Optional[str] = None, dataset=None, metadata=None): + def __init__( + self, + api_type: str = "gpt", + base_url: Optional[str] = None, + api_key: Optional[str] = None, + dataset=None, + metadata=None, + ): self.api_type = api_type - if api_type == 'ollama': - self.client = OpenAI(base_url=base_url or 'http://localhost:11434/v1', api_key='ollama') + if api_type == "ollama": + self.client = OpenAI( + base_url=base_url or "http://localhost:11434/v1", api_key="ollama" + ) self.model = "llama3.1:70b" else: self.client = OpenAI(api_key=api_key) # self.model = "gpt-4-0125-preview" self.model = "gpt-4o" - + self.messages = [] self.dataset = dataset self.metadata = metadata @@ -87,7 +94,7 @@ def set_api_key(self, api_key: str): ------- None """ - if self.api_type == 'gpt': + if self.api_type == "gpt": self.client.api_key = api_key st.secrets["openai_api_key"] = api_key @@ -116,10 +123,10 @@ def truncate_conversation_history(self, max_tokens: int = 100000): ------- None """ - total_tokens = sum(len(m['content'].split()) for m in self.messages) + total_tokens = sum(len(m["content"].split()) for m in self.messages) while total_tokens > max_tokens and len(self.messages) > 1: removed_message = self.messages.pop(0) - total_tokens -= len(removed_message['content'].split()) + total_tokens -= len(removed_message["content"].split()) def update_session_state(self): """ @@ -129,8 +136,8 @@ def update_session_state(self): ------- None """ - st.session_state['messages'] = self.messages - st.session_state['artifacts'] = self.artifacts + st.session_state["messages"] = self.messages + st.session_state["artifacts"] = self.artifacts def parse_model_response(self, response: Any) -> Dict[str, Any]: """ @@ -147,11 +154,13 @@ def parse_model_response(self, response: Any) -> Dict[str, Any]: A dictionary containing the parsed content and tool calls """ return { - 'content': response.choices[0].message.content, - 'tool_calls': response.choices[0].message.tool_calls + "content": response.choices[0].message.content, + "tool_calls": response.choices[0].message.tool_calls, } - def execute_function(self, function_name: str, function_args: Dict[str, Any]) -> Any: + def execute_function( + self, function_name: str, function_args: Dict[str, Any] + ) -> Any: """ Execute a function based on its name and arguments. 
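The truncate_conversation_history method in the hunk above approximates token usage by whitespace word count rather than a real tokenizer. A small sketch of the behavior, with made-up message sizes, shows that the oldest entry, system prompt included, is evicted first:

    # Illustrative sketch of the truncation loop above (values made up).
    messages = [
        {"role": "system", "content": "You are a proteomics assistant."},
        {"role": "user", "content": "word " * 120_000},  # oversized turn
        {"role": "user", "content": "short follow-up"},
    ]
    total_tokens = sum(len(m["content"].split()) for m in messages)
    while total_tokens > 100_000 and len(messages) > 1:
        removed = messages.pop(0)  # evicts oldest first, system prompt included
        total_tokens -= len(removed["content"].split())
    # Only the last user message survives here; whitespace word count is a
    # rough proxy, and a tokenizer such as tiktoken would count more accurately.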
@@ -179,15 +188,24 @@ def execute_function(self, function_name: str, function_args: Dict[str, Any]) -> return get_enrichment_data(**function_args) elif function_name == "perform_dimensionality_reduction": return perform_dimensionality_reduction(**function_args) - elif function_name.startswith("plot_") or function_name.startswith("perform_"): - plot_function = getattr(self.dataset, function_name.split('.')[-1], None) + elif function_name.startswith("plot_") or function_name.startswith( + "perform_" + ): + plot_function = getattr( + self.dataset, function_name.split(".")[-1], None + ) if plot_function: return plot_function(**function_args) - raise ValueError(f"Function {function_name} not implemented or dataset not available") + raise ValueError( + f"Function {function_name} not implemented or dataset not available" + ) except Exception as e: return f"Error executing {function_name}: {str(e)}" - def handle_function_calls(self, tool_calls: List[Any], ) -> Dict[str, Any]: + def handle_function_calls( + self, + tool_calls: List[Any], + ) -> Dict[str, Any]: """ Handle function calls from the language model and manage resulting artifacts. @@ -202,33 +220,37 @@ def handle_function_calls(self, tool_calls: List[Any], ) -> Dict[str, Any]: The parsed response after handling function calls, including any new artifacts """ - function_messages = [] new_artifacts = {} - print(len(tool_calls)) - funcs_and_args = '\n'.join([f"Calling function: {tool_call.function.name} with arguments: {tool_call.function.arguments}" for tool_call in tool_calls]) - self.messages.append({ - "role": "assistant", - "content": funcs_and_args, - "tool_calls": tool_calls - }) - + funcs_and_args = "\n".join( + [ + f"Calling function: {tool_call.function.name} with arguments: {tool_call.function.arguments}" + for tool_call in tool_calls + ] + ) + self.messages.append( + {"role": "assistant", "content": funcs_and_args, "tool_calls": tool_calls} + ) + for tool_call in tool_calls: - print(tool_call.id) function_name = tool_call.function.name print(f"Calling function: {function_name}") function_args = json.loads(tool_call.function.arguments) - + function_result = self.execute_function(function_name, function_args) artifact_id = f"{function_name}_{tool_call.id}" - + new_artifacts[artifact_id] = function_result - - self.messages.append({ - "role": "tool", - "content": json.dumps({"result": str(function_result), "artifact_id": artifact_id}), - "tool_call_id": tool_call.id - }) - post_artefact_message_idx = len(self.messages) + + self.messages.append( + { + "role": "tool", + "content": json.dumps( + {"result": str(function_result), "artifact_id": artifact_id} + ), + "tool_call_id": tool_call.id, + } + ) + post_artefact_message_idx = len(self.messages) self.artifacts[post_artefact_message_idx] = new_artifacts.values() response = self.client.chat.completions.create( model=self.model, @@ -236,11 +258,13 @@ def handle_function_calls(self, tool_calls: List[Any], ) -> Dict[str, Any]: tools=self.tools, ) parsed_response = self.parse_model_response(response) - parsed_response['new_artifacts'] = new_artifacts + parsed_response["new_artifacts"] = new_artifacts return parsed_response - def chat_completion(self, prompt: str, role: str = "user") -> Tuple[str, Dict[str, Any]]: + def chat_completion( + self, prompt: str, role: str = "user" + ) -> Tuple[str, Dict[str, Any]]: """ Generate a chat completion based on the given prompt and manage any resulting artifacts. 
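For reference, the handle_function_calls logic in the hunk above appends a pair of transcript entries per tool round, roughly of this shape (identifiers and results invented for illustration):

    import json

    tool_calls = []  # placeholder for the tool_call objects from the model response
    assistant_msg = {
        "role": "assistant",
        "content": "Calling function: plot_volcano with arguments: {...}",
        "tool_calls": tool_calls,
    }
    tool_msg = {
        "role": "tool",
        "content": json.dumps(
            {"result": "Figure(...)", "artifact_id": "plot_volcano_call_abc123"}
        ),
        "tool_call_id": "call_abc123",
    }
    # The figure or dataframe itself is stored under self.artifacts, keyed by
    # the message index, so large objects stay out of the transcript sent back
    # to the model; only a string summary and the artifact_id go into the chat.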
@@ -270,17 +294,21 @@ def chat_completion(self, prompt: str, role: str = "user") -> Tuple[str, Dict[st messages=self.messages, tools=self.tools, ) - + parsed_response = self.parse_model_response(response) new_artifacts = {} - if parsed_response['tool_calls']: - parsed_response = self.handle_function_calls(parsed_response['tool_calls']) - new_artifacts = parsed_response.pop('new_artifacts', {}) - - self.messages.append({"role": "assistant", "content": parsed_response['content']}) + if parsed_response["tool_calls"]: + parsed_response = self.handle_function_calls( + parsed_response["tool_calls"] + ) + new_artifacts = parsed_response.pop("new_artifacts", {}) + + self.messages.append( + {"role": "assistant", "content": parsed_response["content"]} + ) self.update_session_state() - return parsed_response['content'], new_artifacts + return parsed_response["content"], new_artifacts except ArithmeticError as e: error_message = f"Error in chat completion: {str(e)}" @@ -288,7 +316,12 @@ def chat_completion(self, prompt: str, role: str = "user") -> Tuple[str, Dict[st self.update_session_state() return error_message, {} - def switch_backend(self, new_api_type: str, base_url: Optional[str] = None, api_key: Optional[str] = None): + def switch_backend( + self, + new_api_type: str, + base_url: Optional[str] = None, + api_key: Optional[str] = None, + ): """ Switch between different API backends. @@ -305,8 +338,13 @@ def switch_backend(self, new_api_type: str, base_url: Optional[str] = None, api_ ------- None """ - self.__init__(api_type=new_api_type, base_url=base_url, api_key=api_key, dataset=self.dataset, metadata=self.metadata) - + self.__init__( + api_type=new_api_type, + base_url=base_url, + api_key=api_key, + dataset=self.dataset, + metadata=self.metadata, + ) def display_chat_history(self): """ @@ -320,23 +358,25 @@ def display_chat_history(self): None """ for i, message in enumerate(self.messages): - role = message['role'].capitalize() - content = message['content'] + role = message["role"].capitalize() + content = message["content"] - if role == 'Assistant' and 'tool_calls' in message: + if role == "Assistant" and "tool_calls" in message: display(Markdown(f"**{role}**: {content}")) - for tool_call in message['tool_calls']: + for tool_call in message["tool_calls"]: function_name = tool_call.function.name function_args = tool_call.function.arguments display(Markdown(f"*Function Call*: `{function_name}`")) display(Markdown(f"*Arguments*: ```json\n{function_args}\n```")) - elif role == 'Tool': + elif role == "Tool": tool_result = json.loads(content) - artifact_id = tool_result.get('artifact_id') + artifact_id = tool_result.get("artifact_id") if artifact_id and artifact_id in self.artifacts: artifact = self.artifacts[artifact_id] - display(Markdown(f"**Function Result** (Artifact ID: {artifact_id}):")) + display( + Markdown(f"**Function Result** (Artifact ID: {artifact_id}):") + ) self._display_artifact(artifact) else: display(Markdown(f"**Function Result**: {content}")) diff --git a/alphastats/gui/utils/openai_utils.py b/alphastats/gui/utils/openai_utils.py index 7a4bd255..d7f5edbd 100644 --- a/alphastats/gui/utils/openai_utils.py +++ b/alphastats/gui/utils/openai_utils.py @@ -11,27 +11,14 @@ try: from alphastats.gui.utils.gpt_helper import ( turn_args_to_float, - get_assistant_functions, - display_proteins, - get_subgroups_for_each_group, - turn_args_to_float, - perform_dimensionality_reduction, - get_general_assistant_functions, ) from alphastats.gui.utils.uniprot_utils import ( get_gene_function, ) - 
from alphastats.gui.utils.enrichment_analysis import get_enrichment_data
 except ModuleNotFoundError:
     from utils.gpt_helper import (
         turn_args_to_float,
-        get_assistant_functions,
-        display_proteins,
-        get_subgroups_for_each_group,
-        turn_args_to_float,
-        perform_dimensionality_reduction,
-        get_general_assistant_functions,
     )
     from utils.uniprot_utils import (
         get_gene_function,
@@ -39,9 +26,8 @@
     from utils.openai_utils import (
         wait_for_run_completion,
     )
-    from utils.enrichment_analysis import get_enrichment_data
-
+
 def wait_for_run_completion(
     client: openai.OpenAI, thread_id: int, run_id: int, check_interval: int = 2
 ) -> Optional[List]:
@@ -199,7 +185,7 @@ def try_to_set_api_key(api_key: str = None) -> None:
         return
     try:
         openai.OpenAI.api_key = st.secrets["openai_api_key"]
-    except:
+    except KeyError:
         st.write(
             "OpenAI API key not found in environment variables. Please enter your API key to continue."
         )
diff --git a/alphastats/gui/utils/options.py b/alphastats/gui/utils/options.py
index a4944067..3c365aa8 100644
--- a/alphastats/gui/utils/options.py
+++ b/alphastats/gui/utils/options.py
@@ -134,4 +134,5 @@
     "Differential Expression Analysis - Wald-test": {
         "between_two_groups": True,
         "function": st.session_state.dataset.diff_expression_analysis,
-    },}
+    },
+}
diff --git a/alphastats/gui/utils/uniprot_utils.py b/alphastats/gui/utils/uniprot_utils.py
index 7d3e2dc4..736f11ab 100644
--- a/alphastats/gui/utils/uniprot_utils.py
+++ b/alphastats/gui/utils/uniprot_utils.py
@@ -319,7 +319,7 @@ def get_gene_function(gene_name: Union[str, Dict], organism_id=9606) -> str:
     """
     if "organism" in st.session_state:
         organism_id = st.session_state["organism"]
-    if type(gene_name) == dict:
+    if isinstance(gene_name, dict):
         gene_name = gene_name["gene_name"]
     result = get_uniprot_data(gene_name, organism_id)
     if result and extract_data(result)["functionComments"]:

From c06f4aa7e57c6c90fee2aeae9fb85c250cef6379 Mon Sep 17 00:00:00 2001
From: Mikhail Lebedev
Date: Wed, 11 Sep 2024 17:44:54 +0200
Subject: [PATCH 03/13] ollama mvp

---
 .streamlit/secrets.toml | 1 -
 README.md               | 2 ++
 requirements.txt        | 1 +
 3 files changed, 3 insertions(+), 1 deletion(-)
 delete mode 100644 .streamlit/secrets.toml

diff --git a/.streamlit/secrets.toml b/.streamlit/secrets.toml
deleted file mode 100644
index ce30c061..00000000
--- a/.streamlit/secrets.toml
+++ /dev/null
@@ -1 +0,0 @@
-openai_api_key = "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"
\ No newline at end of file
diff --git a/README.md b/README.md
index 2d70f05b..19675670 100644
--- a/README.md
+++ b/README.md
@@ -77,6 +77,8 @@ In case you want to use the Graphical User Interface, use following command in t
 alphastats gui
 ```

+To use local Large Language Models to help interpret the data, download and install Ollama (https://ollama.com/download).
+
 AlphaStats can be imported as a Python package into any Python script or notebook with the command `import alphastats`. A brief [Jupyter notebook tutorial](nbs/getting_started.ipynb) on how to use the API is also present in the [nbs folder](nbs).
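The requirements.txt hunk below pins the ollama Python client that this patch relies on. A quick smoke test that a local model is reachable might look like the following; a hedged sketch, assuming `ollama pull llama3.1` has been run and the Ollama server is listening on its default port:

    # Not part of the patch: verify the local Ollama server responds before
    # pointing the GUI at it.
    import ollama

    reply = ollama.chat(
        model="llama3.1",
        messages=[{"role": "user", "content": "Reply with 'ready'."}],
    )
    print(reply["message"]["content"])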
diff --git a/requirements.txt b/requirements.txt index bf108e2d..7586131e 100644 --- a/requirements.txt +++ b/requirements.txt @@ -26,3 +26,4 @@ nbformat>=5.0 biopython==1.83 openai==1.12.0 gprofiler-official==1.0.0 +ollama==0.3.3 From 273ca27e3d2c2772533d8fade8c54de120248381 Mon Sep 17 00:00:00 2001 From: mschwoerer <82171591+mschwoer@users.noreply.github.com> Date: Thu, 12 Sep 2024 08:16:59 +0200 Subject: [PATCH 04/13] formatting --- alphastats/DataSet_Preprocess.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/alphastats/DataSet_Preprocess.py b/alphastats/DataSet_Preprocess.py index af192fe1..305106a2 100644 --- a/alphastats/DataSet_Preprocess.py +++ b/alphastats/DataSet_Preprocess.py @@ -197,7 +197,9 @@ def _normalization(self, method: str): normalized_array = qt.fit_transform(self.mat.values) elif method == "linear": - normalized_array = self._linear_normalization(self.mat.transpose()).transpose() + normalized_array = self._linear_normalization( + self.mat.transpose() + ).transpose() elif method == "vst": minmax = sklearn.preprocessing.MinMaxScaler() From bba57c10a2c31063266667f8cb1241fc75563d7e Mon Sep 17 00:00:00 2001 From: mschwoerer <82171591+mschwoer@users.noreply.github.com> Date: Thu, 12 Sep 2024 09:17:50 +0200 Subject: [PATCH 05/13] adapt to latest changes --- alphastats/gui/pages/05_LLM.py | 4 +++- alphastats/gui/utils/options.py | 30 ++++++++++++++++-------------- 2 files changed, 19 insertions(+), 15 deletions(-) diff --git a/alphastats/gui/pages/05_LLM.py b/alphastats/gui/pages/05_LLM.py index fde8f1fc..201b2952 100644 --- a/alphastats/gui/pages/05_LLM.py +++ b/alphastats/gui/pages/05_LLM.py @@ -3,6 +3,8 @@ import pandas as pd from openai import AuthenticationError +from alphastats.gui.utils.options import interpretation_options + try: from alphastats.gui.utils.analysis_helper import ( check_if_options_are_loaded, @@ -55,7 +57,7 @@ def select_analysis(): method = st.selectbox( "Analysis", # options=["Volcano plot"], - options=list(st.session_state.interpretation_options.keys()), + options=list(interpretation_options(st.session_state).keys()), ) return method diff --git a/alphastats/gui/utils/options.py b/alphastats/gui/utils/options.py index f43c32e3..7142a101 100644 --- a/alphastats/gui/utils/options.py +++ b/alphastats/gui/utils/options.py @@ -184,17 +184,19 @@ def statistic_options(state): }, } -interpretation_options = { - "Volcano Plot": { - "between_two_groups": True, - "function": st.session_state.dataset.plot_volcano, - }, - "Differential Expression Analysis - T-test": { - "between_two_groups": True, - "function": st.session_state.dataset.diff_expression_analysis, - }, - "Differential Expression Analysis - Wald-test": { - "between_two_groups": True, - "function": st.session_state.dataset.diff_expression_analysis, - }, -} + +def interpretation_options(state): + return { + "Volcano Plot": { + "between_two_groups": True, + "function": state.dataset.plot_volcano, + }, + "Differential Expression Analysis - T-test": { + "between_two_groups": True, + "function": state.dataset.diff_expression_analysis, + }, + "Differential Expression Analysis - Wald-test": { + "between_two_groups": True, + "function": state.dataset.diff_expression_analysis, + }, + } From b9f510e510ff565a378da53485999c1a204dfced Mon Sep 17 00:00:00 2001 From: mschwoerer <82171591+mschwoer@users.noreply.github.com> Date: Thu, 12 Sep 2024 09:18:11 +0200 Subject: [PATCH 06/13] minor changes --- alphastats/gui/utils/enrichment_analysis.py | 9 ++++----- alphastats/gui/utils/ollama_utils.py 
| 1 + 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/alphastats/gui/utils/enrichment_analysis.py b/alphastats/gui/utils/enrichment_analysis.py index c2203013..e487b406 100644 --- a/alphastats/gui/utils/enrichment_analysis.py +++ b/alphastats/gui/utils/enrichment_analysis.py @@ -7,7 +7,7 @@ import pandas as pd -def get_functional_annotation_STRING(identifier, species_id="9606") -> pd.DataFrame: +def _get_functional_annotation_string(identifier, species_id="9606") -> pd.DataFrame: """ Get functional annotation from STRING for a gene identifier. @@ -30,7 +30,7 @@ def get_functional_annotation_STRING(identifier, species_id="9606") -> pd.DataFr return None -def get_functional_annotation_GProfiler(identifiers: List[str]) -> pd.DataFrame: +def _get_functional_annotation_gprofiler(identifiers: List[str]) -> pd.DataFrame: """ Get functional annotation from g:Profiler for a list of gene identifiers. @@ -62,15 +62,14 @@ def get_enrichment_data( Returns: pd.DataFrame: The enrichment data. """ - enrichment_data = {} assert tool in [ "gprofiler", "string", ], "Tool must be either 'gprofiler' or 'string'" if tool == "gprofiler": - enrichment_data = get_functional_annotation_GProfiler(difexpressed) + enrichment_data = _get_functional_annotation_gprofiler(difexpressed) else: - enrichment_data = get_functional_annotation_STRING( + enrichment_data = _get_functional_annotation_string( "%0d".join(difexpressed), organism_id ) diff --git a/alphastats/gui/utils/ollama_utils.py b/alphastats/gui/utils/ollama_utils.py index 8f0ca124..27062719 100644 --- a/alphastats/gui/utils/ollama_utils.py +++ b/alphastats/gui/utils/ollama_utils.py @@ -183,6 +183,7 @@ def execute_function( """ try: if function_name == "get_gene_function": + # TODO log whats going on return get_gene_function(**function_args) elif function_name == "get_enrichment_data": return get_enrichment_data(**function_args) From f55428da5e59f4f0a98856fc4ff04b325874816c Mon Sep 17 00:00:00 2001 From: mschwoerer <82171591+mschwoer@users.noreply.github.com> Date: Thu, 12 Sep 2024 09:34:35 +0200 Subject: [PATCH 07/13] fix tests --- tests/test_gpt.py | 26 +------------------------- 1 file changed, 1 insertion(+), 25 deletions(-) diff --git a/tests/test_gpt.py b/tests/test_gpt.py index d0dfad18..e8453ec5 100644 --- a/tests/test_gpt.py +++ b/tests/test_gpt.py @@ -1,42 +1,18 @@ -import unittest -import pandas as pd -import logging import logging -import numpy as np -import pandas as pd -import openai -import json -import plotly -from contextlib import contextmanager -# import dictdiffer import unittest from unittest.mock import patch, MagicMock -# from pandas.api.types import is_object_dtype, is_numeric_dtype, is_bool_dtype import streamlit as st +from alphastats.gui.utils.uniprot_utils import get_uniprot_data, extract_data from alphastats.loader.MaxQuantLoader import MaxQuantLoader from alphastats.DataSet import DataSet if "gene_to_prot_id" not in st.session_state: st.session_state["gene_to_prot_id"] = {} -from alphastats.gui.utils.gpt_helper import ( - get_assistant_functions, - display_proteins, - get_gene_function, - get_info, - get_subgroups_for_each_group, - turn_args_to_float, - perform_dimensionality_reduction, - wait_for_run_completion, - send_message_save_thread, - try_to_set_api_key, - get_uniprot_data, - extract_data, -) logger = logging.getLogger(__name__) From b3498656f544a5359192e1973f7e78c51d05ffe5 Mon Sep 17 00:00:00 2001 From: mschwoerer <82171591+mschwoer@users.noreply.github.com> Date: Thu, 12 Sep 2024 15:21:55 +0200 Subject: [PATCH 
08/13] turn the normalization around again

---
 alphastats/DataSet_Preprocess.py | 36 +++++++++++++++++++-------------
 1 file changed, 21 insertions(+), 15 deletions(-)

diff --git a/alphastats/DataSet_Preprocess.py b/alphastats/DataSet_Preprocess.py
index 305106a2..0003da93 100644
--- a/alphastats/DataSet_Preprocess.py
+++ b/alphastats/DataSet_Preprocess.py
@@ -183,35 +183,41 @@ def _linear_normalization(self, dataframe: pd.DataFrame):

     @ignore_warning(UserWarning)
     @ignore_warning(RuntimeWarning)
-    def _normalization(self, method: str):
+    def _normalization(self, method: str) -> None:
+        """Normalize across samples."""
+        # TODO make both sample and protein normalization available
         if method == "zscore":
             scaler = sklearn.preprocessing.StandardScaler()
-            # normalized_array = scaler.fit_transform(
-            #     self.mat.values.transpose()
-            # ).transpose()
-            normalized_array = scaler.fit_transform(self.mat.values)
+            # normalize samples => for preprocessing
+            normalized_array = scaler.fit_transform(
+                self.mat.values.transpose()
+            ).transpose()
+            # normalize proteins => for downstream processing
+            # normalized_array = scaler.fit_transform(self.mat.values)

         elif method == "quantile":
             qt = sklearn.preprocessing.QuantileTransformer(random_state=0)
-            # normalized_array = qt.fit_transform(self.mat.values.transpose()).transpose()
-            normalized_array = qt.fit_transform(self.mat.values)
+            normalized_array = qt.fit_transform(self.mat.values.transpose()).transpose()
+            # normalized_array = qt.fit_transform(self.mat.values)  # normalize proteins

         elif method == "linear":
-            normalized_array = self._linear_normalization(
-                self.mat.transpose()
-            ).transpose()
+            normalized_array = self._linear_normalization(self.mat)
+
+            # normalized_array = self._linear_normalization(
+            #     self.mat.transpose()
+            # ).transpose()  # normalize proteins

         elif method == "vst":
             minmax = sklearn.preprocessing.MinMaxScaler()
             scaler = sklearn.preprocessing.PowerTransformer()
-            # minmaxed_array = minmax.fit_transform(self.mat.values.transpose())
-            # normalized_array = scaler.fit_transform(minmaxed_array).transpose()
-            minmaxed_array = minmax.fit_transform(self.mat.values)
-            normalized_array = scaler.fit_transform(minmaxed_array)
+            minmaxed_array = minmax.fit_transform(self.mat.values.transpose())
+            normalized_array = scaler.fit_transform(minmaxed_array).transpose()
+            # minmaxed_array = minmax.fit_transform(self.mat.values)  # normalize proteins
+            # normalized_array = scaler.fit_transform(minmaxed_array)  # normalize proteins

         else:
             raise ValueError(
-                "Normalization method: {method} is invalid"
+                f"Normalization method: {method} is invalid. "
                 "Choose from 'zscore', 'quantile', 'linear' normalization, 
or 'vst' for variance stabilization transformation" ) From 56d72a94379b38955347f142be328eab3d1cf35d Mon Sep 17 00:00:00 2001 From: mschwoerer <82171591+mschwoer@users.noreply.github.com> Date: Thu, 12 Sep 2024 15:22:18 +0200 Subject: [PATCH 09/13] set log2 transform in tests explicitly to account for new default of method --- tests/test_DataSet.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/tests/test_DataSet.py b/tests/test_DataSet.py index d693d85e..51765554 100644 --- a/tests/test_DataSet.py +++ b/tests/test_DataSet.py @@ -762,7 +762,9 @@ def test_plot_samplehistograms(self): self.assertEqual(312, len(fig["data"])) def test_batch_correction(self): - self.obj.preprocess(subset=True, imputation="knn", normalization="linear") + self.obj.preprocess( + subset=True, imputation="knn", normalization="linear", log2_transform=True + ) self.obj.batch_correction(batch="batch_artifical_added") first_value = self.obj.mat.values[0, 0] self.assertAlmostEqual(-0.00555, first_value, places=3) @@ -865,7 +867,7 @@ def test_plot_dendrogram_not_imputed(self): self.obj.plot_dendrogram() def test_volcano_plot_anova(self): - self.obj.preprocess(imputation="knn") + self.obj.preprocess(imputation="knn", log2_transform=True) plot = self.obj.plot_volcano( column="grouping1", group1="Healthy", group2="Disease", method="anova" ) From 8c708acc0efe055d3692d7e921fd6f3af0a0973e Mon Sep 17 00:00:00 2001 From: mschwoerer <82171591+mschwoer@users.noreply.github.com> Date: Thu, 12 Sep 2024 15:30:06 +0200 Subject: [PATCH 10/13] fix merge conflicts --- alphastats/gui/pages/05_LLM.py | 62 ++++++++++++---------------------- 1 file changed, 21 insertions(+), 41 deletions(-) diff --git a/alphastats/gui/pages/05_LLM.py b/alphastats/gui/pages/05_LLM.py index 201b2952..143b3786 100644 --- a/alphastats/gui/pages/05_LLM.py +++ b/alphastats/gui/pages/05_LLM.py @@ -3,48 +3,28 @@ import pandas as pd from openai import AuthenticationError +from alphastats.gui.utils.analysis_helper import ( + check_if_options_are_loaded, + display_figure, + save_plot_to_session_state, + gui_volcano_plot_differential_expression_analysis, + helper_compare_two_groups, +) +from alphastats.gui.utils.gpt_helper import ( + get_assistant_functions, + display_proteins, + get_subgroups_for_each_group, + get_general_assistant_functions, +) +from alphastats.gui.utils.openai_utils import ( + try_to_set_api_key, +) +from alphastats.gui.utils.ollama_utils import LLMIntegration from alphastats.gui.utils.options import interpretation_options +from alphastats.gui.utils.ui_helper import sidebar_info, init_session_state -try: - from alphastats.gui.utils.analysis_helper import ( - check_if_options_are_loaded, - display_figure, - save_plot_to_session_state, - gui_volcano_plot_differential_expression_analysis, - helper_compare_two_groups, - ) - from alphastats.gui.utils.gpt_helper import ( - get_assistant_functions, - display_proteins, - get_subgroups_for_each_group, - get_general_assistant_functions, - ) - from alphastats.gui.utils.openai_utils import ( - try_to_set_api_key, - ) - from alphastats.gui.utils.ollama_utils import LLMIntegration - from alphastats.gui.utils.ui_helper import sidebar_info - -except ModuleNotFoundError: - from utils.analysis_helper import ( - check_if_options_are_loaded, - display_figure, - save_plot_to_session_state, - gui_volcano_plot_differential_expression_analysis, - helper_compare_two_groups, - ) - from utils.gpt_helper import ( - get_assistant_functions, - display_proteins, - get_subgroups_for_each_group, - 
get_general_assistant_functions, - ) - from utils.openai_utils import ( - try_to_set_api_key, - ) - from utils.ollama_utils import LLMIntegration - from utils.ui_helper import sidebar_info - +init_session_state() +sidebar_info() st.session_state.plot_dict = {} @@ -180,7 +160,7 @@ def select_analysis(): "plot_submitted_clicked" ] volcano_plot = gui_volcano_plot_differential_expression_analysis( - chosen_parameter_dict, user_session_id=st.session_state.user_session_id + chosen_parameter_dict ) volcano_plot._update(plotting_parameter_dict) volcano_plot._annotate_result_df() From 677377e54439762e96203b8aa2cb2d881c5164b9 Mon Sep 17 00:00:00 2001 From: mschwoerer <82171591+mschwoer@users.noreply.github.com> Date: Tue, 17 Sep 2024 09:35:42 +0200 Subject: [PATCH 11/13] comment out ArtifactManager code --- alphastats/gui/utils/ollama_utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/alphastats/gui/utils/ollama_utils.py b/alphastats/gui/utils/ollama_utils.py index 27062719..e8481f7f 100644 --- a/alphastats/gui/utils/ollama_utils.py +++ b/alphastats/gui/utils/ollama_utils.py @@ -9,7 +9,7 @@ perform_dimensionality_reduction, get_general_assistant_functions, ) -from alphastats.gui.utils.artefacts import ArtifactManager +# from alphastats.gui.utils.artefacts import ArtifactManager from alphastats.gui.utils.uniprot_utils import get_gene_function from alphastats.gui.utils.enrichment_analysis import get_enrichment_data @@ -78,7 +78,7 @@ def __init__( self.metadata = metadata self.tools = self._get_tools() self.artifacts = {} - self.artifact_manager = ArtifactManager() + # self.artifact_manager = ArtifactManager() self.message_artifact_map = {} def set_api_key(self, api_key: str): From 164c1d67f8a2bc429f5f12bd0b7a4683b894d63b Mon Sep 17 00:00:00 2001 From: mschwoerer <82171591+mschwoer@users.noreply.github.com> Date: Tue, 17 Sep 2024 11:51:11 +0200 Subject: [PATCH 12/13] fix pre-commit hooks --- pyproject.toml | 24 ++++++++++++++++++++++++ 1 file changed, 24 insertions(+) create mode 100644 pyproject.toml diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 00000000..af671d70 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,24 @@ +[tool.ruff] +extend-exclude = [".bumpversion.cfg", ".secrets.baseline"] + +# [tool.ruff.lint] +# select = [ +# # pycodestyle +# "E", +# # Pyflakes +# "F", +# # pyupgrade +# "UP", +# # flake8-bugbear +# "B", +# # flake8-simplify +# "SIM", +# # isort +# "I", +# ] + +# ignore = [ +# "E501", # Line too long (ruff wraps code, but not docstrings) +# "B028", # No explicit `stacklevel` keyword argument found (for warnings) +# "B905" # This causes problems in numba code: `zip()` without an explicit `strict=` parameter +# ] From 47eb7c713d566972781585c2973a92984e98d10b Mon Sep 17 00:00:00 2001 From: mschwoerer <82171591+mschwoer@users.noreply.github.com> Date: Tue, 17 Sep 2024 11:57:27 +0200 Subject: [PATCH 13/13] fix pre-commit hooks --- alphastats/gui/utils/ollama_utils.py | 1 + 1 file changed, 1 insertion(+) diff --git a/alphastats/gui/utils/ollama_utils.py b/alphastats/gui/utils/ollama_utils.py index e8481f7f..49572eb9 100644 --- a/alphastats/gui/utils/ollama_utils.py +++ b/alphastats/gui/utils/ollama_utils.py @@ -9,6 +9,7 @@ perform_dimensionality_reduction, get_general_assistant_functions, ) + # from alphastats.gui.utils.artefacts import ArtifactManager from alphastats.gui.utils.uniprot_utils import get_gene_function from alphastats.gui.utils.enrichment_analysis import get_enrichment_data
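Taken together, the series yields a workflow along these lines. This closing sketch is illustrative only: `ds` stands in for a loaded DataSet, and a running Ollama server is assumed.

    from alphastats.gui.utils.ollama_utils import LLMIntegration

    llm = LLMIntegration(api_type="ollama", dataset=ds, metadata=ds.metadata)
    llm.messages = [
        {"role": "system", "content": "You help interpret proteomics results."}
    ]
    # chat_completion returns the assistant text plus any artifacts produced
    # by tool calls (plotly figures, pandas DataFrames, ...).
    answer, artifacts = llm.chat_completion(
        "Compare the 'Healthy' and 'Disease' groups in a volcano plot."
    )
    print(answer)
    for artifact in artifacts.values():
        print(type(artifact))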