Merge pull request #311 from MannLabs/ollama
MVP: local models via ollama
mschwoer authored Sep 17, 2024
2 parents f7cbcf9 + 47eb7c7 commit 856ef3d
Showing 16 changed files with 1,231 additions and 703 deletions.
26 changes: 26 additions & 0 deletions .streamlit/config.toml
@@ -0,0 +1,26 @@
[theme]

# Primary accent for interactive elements
primaryColor = '#005358'

# Background color for the main content area
backgroundColor = '#FFFFFF'

# Background color for sidebar and most interactive widgets
secondaryBackgroundColor = '#f2f2f2'

# Color used for almost all text
textColor = '#302E30'

# Font family for all text in the app, except code blocks
# Accepted values (serif | sans serif | monospace)
# Default: "sans serif"
font = "sans serif"

[server]
maxUploadSize = 500
enableXsrfProtection = false
enableCORS = false

[browser]
gatherUsageStats = true
2 changes: 2 additions & 0 deletions README.md
@@ -78,6 +78,8 @@ alphastats gui
```
If you get an `AxiosError: Request failed with status code 403` when uploading files, try running `DISABLE_XSRF=1 alphastats gui`.

If you want to use a local Large Language Model to help interpret the data, you need to download and install Ollama (https://ollama.com/download).
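Once installed, a minimal Python sketch for checking that the local Ollama server is reachable (http://localhost:11434 is Ollama's default address and an assumption if you have reconfigured it):

```python
# Hedged sketch: verify that the local Ollama server the LLM page talks to is up.
# http://localhost:11434 is Ollama's default address (adjust if reconfigured).
import urllib.request

try:
    with urllib.request.urlopen("http://localhost:11434", timeout=2) as resp:
        print("Ollama server is up, HTTP status:", resp.status)
except OSError as err:
    print("Ollama server not reachable:", err)
```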

AlphaStats can be imported as a Python package into any Python script or notebook with the command `import alphastats`.
A brief [Jupyter notebook tutorial](nbs/getting_started.ipynb) on how to use the API is also present in the [nbs folder](nbs).

18 changes: 15 additions & 3 deletions alphastats/DataSet_Preprocess.py
@@ -183,29 +183,41 @@ def _linear_normalization(self, dataframe: pd.DataFrame):

@ignore_warning(UserWarning)
@ignore_warning(RuntimeWarning)
-def _normalization(self, method: str):
+def _normalization(self, method: str) -> None:
+    """Normalize across samples."""
    # TODO make both sample and protein normalization available
    if method == "zscore":
        scaler = sklearn.preprocessing.StandardScaler()
+        # normalize samples => for preprocessing
        normalized_array = scaler.fit_transform(
            self.mat.values.transpose()
        ).transpose()
+        # normalize proteins => for downstream processing
+        # normalized_array = scaler.fit_transform(self.mat.values)

    elif method == "quantile":
        qt = sklearn.preprocessing.QuantileTransformer(random_state=0)
        normalized_array = qt.fit_transform(self.mat.values.transpose()).transpose()
+        # normalized_array = qt.fit_transform(self.mat.values)  # normalize proteins

    elif method == "linear":
        normalized_array = self._linear_normalization(self.mat)

+        # normalized_array = self._linear_normalization(
+        #     self.mat.transpose()
+        # ).transpose()  # normalize proteins

    elif method == "vst":
        minmax = sklearn.preprocessing.MinMaxScaler()
        scaler = sklearn.preprocessing.PowerTransformer()
        minmaxed_array = minmax.fit_transform(self.mat.values.transpose())
        normalized_array = scaler.fit_transform(minmaxed_array).transpose()
+        # minmaxed_array = minmax.fit_transform(self.mat.values)  # normalize proteins
+        # normalized_array = scaler.fit_transform(minmaxed_array)  # normalize proteins

    else:
        raise ValueError(
-            "Normalization method: {method} is invalid"
+            f"Normalization method: {method} is invalid. "
+            "Choose from 'zscore', 'quantile', 'linear' normalization, or 'vst' for variance stabilization transformation."
        )
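As an aside, a self-contained sketch of the transpose pattern used above: scikit-learn scalers standardize column-wise, so transposing before and after `fit_transform` switches between per-protein and per-sample normalization (the toy matrix is illustrative):

```python
import numpy as np
import sklearn.preprocessing

mat = np.array([[1.0, 2.0, 3.0],
                [4.0, 6.0, 8.0]])  # rows = samples, columns = proteins

scaler = sklearn.preprocessing.StandardScaler()
per_protein = scaler.fit_transform(mat)     # scales each column (one protein across samples)
per_sample = scaler.fit_transform(mat.T).T  # scales each row (one sample across proteins)
```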

@@ -277,7 +289,7 @@ def batch_correction(self, batch: str):
@ignore_warning(RuntimeWarning)
def preprocess(
    self,
-    log2_transform: bool = True,
+    log2_transform: bool = False,
    remove_contaminations: bool = False,
    subset: bool = False,
    data_completeness: float = 0,
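Note the changed default: `log2_transform` is now opt-in. A hedged usage sketch, assuming `dataset` is an alphastats `DataSet` (the other argument values are illustrative):

```python
def preprocess_like_before(dataset) -> None:
    """Sketch: restore the pre-#311 behavior by requesting the log2 transform explicitly."""
    dataset.preprocess(
        log2_transform=True,         # was the default before this commit
        remove_contaminations=True,  # illustrative; defaults to False
        data_completeness=0.3,       # illustrative; defaults to 0
    )
```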
2 changes: 2 additions & 0 deletions alphastats/gui/.streamlit/config.toml
@@ -19,6 +19,8 @@ font = "sans serif"

[server]
maxUploadSize = 500
+enableXsrfProtection = false
+enableCORS = false

[browser]
gatherUsageStats = true
144 changes: 79 additions & 65 deletions alphastats/gui/pages/05_GPT.py → alphastats/gui/pages/05_LLM.py
@@ -1,35 +1,28 @@
import os
import streamlit as st
import pandas as pd
-from openai import OpenAI, OpenAIError, AuthenticationError
+from openai import AuthenticationError

from alphastats.gui.utils.analysis_helper import (
    check_if_options_are_loaded,
    display_df,
    display_figure,
    download_figure,
    download_preprocessing_info,
    get_analysis,
    load_options,
    save_plot_to_session_state,
    gui_volcano_plot_differential_expression_analysis,
    helper_compare_two_groups,
)
from alphastats.gui.utils.gpt_helper import (
    get_assistant_functions,
    display_proteins,
    get_gene_function,
    get_info,
    get_subgroups_for_each_group,
    turn_args_to_float,
    perform_dimensionality_reduction,
    wait_for_run_completion,
    send_message_save_thread,
    get_general_assistant_functions,
)
from alphastats.gui.utils.openai_utils import (
    try_to_set_api_key,
)
from alphastats.gui.utils.ollama_utils import LLMIntegration
from alphastats.gui.utils.options import interpretation_options
from alphastats.gui.utils.ui_helper import sidebar_info, init_session_state


init_session_state()
sidebar_info()
st.session_state.plot_dict = {}
@@ -43,8 +36,8 @@ def select_analysis():
"""
method = st.selectbox(
"Analysis",
options=["Volcano plot"],
# options=list(st.session_state.interpretation_options.keys()),
# options=["Volcano plot"],
options=list(interpretation_options(st.session_state).keys()),
)
return method

Expand All @@ -54,32 +47,34 @@ def select_analysis():
st.stop()


st.markdown("### GPT4 Analysis")
st.markdown("### LLM Analysis")

sidebar_info()


# set background to white so downloaded pngs don't have a grey background
-styl = f"""
+styl = """
<style>
-    .css-jc5rf5 {{
+    .css-jc5rf5 {
        position: absolute;
        background: rgb(255, 255, 255);
        color: rgb(48, 46, 48);
        inset: 0px;
        overflow: hidden;
-    }}
+    }
</style>
</style>
"""
st.markdown(styl, unsafe_allow_html=True)

# Initialize session state variables
+if "llm_integration" not in st.session_state:
+    st.session_state["llm_integration"] = None
+if "api_type" not in st.session_state:
+    st.session_state["api_type"] = "gpt"

if "plot_list" not in st.session_state:
    st.session_state["plot_list"] = []

-if "openai_model" not in st.session_state:
-    # st.session_state["openai_model"] = "gpt-3.5-turbo-16k"
-    st.session_state["openai_model"] = "gpt-4-0125-preview"  # "gpt-4-1106-preview"

if "messages" not in st.session_state:
st.session_state["messages"] = []

@@ -101,14 +96,17 @@ def select_analysis():
with c1:
    method = select_analysis()
    chosen_parameter_dict = helper_compare_two_groups()
-    api_key = st.text_input("API Key", type="password")

-    try_to_set_api_key(api_key)
+    st.session_state["api_type"] = st.selectbox(
+        "Select LLM",
+        ["gpt4o", "llama3.1 70b"],
+        index=0 if st.session_state["api_type"] == "gpt4o" else 1,
+    )
+    base_url = "http://localhost:11434/v1"
+    if st.session_state["api_type"] == "gpt4o":
+        api_key = st.text_input("Enter OpenAI API Key", type="password")
+        try_to_set_api_key(api_key)

-    try:
-        client = OpenAI(api_key=st.secrets["openai_api_key"])
-    except OpenAIError:
-        pass
    method = st.selectbox(
        "Differential Analysis using:",
        options=["ttest", "anova", "wald", "sam", "paired-ttest", "welch-ttest"],
@@ -240,9 +238,10 @@ def select_analysis():
"A user will present you with data regarding proteins upregulated in certain cells "
"sourced from UniProt and abstracts from scientific publications. They seek your "
"expertise in understanding the connections between these proteins and their potential role "
f"in disease genesis. {os.linesep}Provide a detailed and insightful, yet concise response based on the given information. "
f"in disease genesis. {os.linesep}Provide a detailed and insightful, yet concise response based on the given information. Use formatting to make your response more human readable."
f"The data you have has following groups and respective subgroups: {str(get_subgroups_for_each_group(st.session_state.dataset.metadata))}."
"Plots are visualized using a graphical environment capable of rendering images, you don't need to worry about that."
"Plots are visualized using a graphical environment capable of rendering images, you don't need to worry about that. If the data coming to"
" you from a function has references to the literature (for example, PubMed), always quote the references in your response."
)
if "column" in chosen_parameter_dict and "upregulated" in st.session_state:
st.session_state["user_prompt"] = (
@@ -280,31 +279,53 @@ def select_analysis():
st.session_state["gpt_submitted_clicked"]
> st.session_state["gpt_submitted_counter"]
):
-    try_to_set_api_key()
-
-    client = OpenAI(api_key=st.secrets["openai_api_key"])
+    if st.session_state["api_type"] == "gpt4o":
+        try_to_set_api_key()

    try:
-        st.session_state["assistant"] = client.beta.assistants.create(
-            instructions=st.session_state["instructions"],
-            name="Proteomics interpreter",
-            model=st.session_state["openai_model"],
-            tools=get_assistant_functions(
-                gene_to_prot_id_dict=st.session_state["gene_to_prot_id"],
+        if st.session_state["api_type"] == "gpt4o":
+            st.session_state["llm_integration"] = LLMIntegration(
+                api_type="gpt",
+                api_key=st.secrets["openai_api_key"],
+                dataset=st.session_state["dataset"],
+                metadata=st.session_state["dataset"].metadata,
+            )
+        else:
+            st.session_state["llm_integration"] = LLMIntegration(
+                api_type="ollama",
+                base_url=base_url,
+                dataset=st.session_state["dataset"],
+                metadata=st.session_state["dataset"].metadata,
                subgroups_for_each_group=get_subgroups_for_each_group(
                    st.session_state["dataset"].metadata
                ),
-            ),
            )
+        st.success(
+            f"{st.session_state['api_type'].upper()} integration initialized successfully!"
+        )
    except AuthenticationError:
        st.warning(
            "Incorrect API key provided. Please enter a valid API key, it should look like this: sk-XXXXX"
        )
        st.stop()

if "artefact_enum_dict" not in st.session_state:
st.session_state["artefact_enum_dict"] = {}
if "llm_integration" not in st.session_state or not st.session_state["llm_integration"]:
st.warning("Please initialize the model first")
st.stop()

+llm = st.session_state["llm_integration"]

+# Set instructions and update tools
+llm.tools = [
+    *get_general_assistant_functions(),
+    *get_assistant_functions(
+        gene_to_prot_id_dict=st.session_state["gene_to_prot_id"],
+        metadata=st.session_state["dataset"].metadata,
+        subgroups_for_each_group=get_subgroups_for_each_group(
+            st.session_state["dataset"].metadata
+        ),
+    ),
+]

+if "artifacts" not in st.session_state:
+    st.session_state["artifacts"] = {}

if (
    st.session_state["gpt_submitted_counter"]
@@ -313,30 +334,23 @@ def select_analysis():
    st.session_state["gpt_submitted_counter"] = st.session_state[
        "gpt_submitted_clicked"
    ]
st.session_state["artefact_enum_dict"] = {}
thread = client.beta.threads.create()
st.session_state["thread_id"] = thread.id
artefacts = send_message_save_thread(client, st.session_state["user_prompt"])
if artefacts:
st.session_state["artefact_enum_dict"][len(st.session_state.messages) - 1] = (
artefacts
)
st.session_state["artifacts"] = {}
llm.messages = [{"role": "system", "content": st.session_state["instructions"]}]
response = llm.chat_completion(st.session_state["user_prompt"])

if st.session_state["gpt_submitted_clicked"] > 0:
if prompt := st.chat_input("Say something"):
st.session_state.messages.append({"role": "user", "content": prompt})
artefacts = send_message_save_thread(client, prompt)
if artefacts:
st.session_state["artefact_enum_dict"][
len(st.session_state.messages) - 1
] = artefacts
response = llm.chat_completion(prompt)
for num, role_content_dict in enumerate(st.session_state.messages):
if role_content_dict["role"] == "tool" or role_content_dict["role"] == "system":
continue
if "tool_calls" in role_content_dict:
continue
with st.chat_message(role_content_dict["role"]):
st.markdown(role_content_dict["content"])
if num in st.session_state["artefact_enum_dict"]:
for artefact in st.session_state["artefact_enum_dict"][num]:
if num in st.session_state["artifacts"]:
for artefact in st.session_state["artifacts"][num]:
if isinstance(artefact, pd.DataFrame):
st.dataframe(artefact)
else:
elif "plotly" in str(type(artefact)):
st.plotly_chart(artefact)
print(st.session_state["artefact_enum_dict"])
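For readers wiring this up outside Streamlit, a hedged sketch based only on the calls visible in this diff (whether `dataset` and `metadata` are required by the constructor is an assumption):

```python
from alphastats.gui.utils.ollama_utils import LLMIntegration

def ask_local_llm(dataset, question: str):
    """Sketch: query a local Ollama model about an alphastats DataSet."""
    llm = LLMIntegration(
        api_type="ollama",
        base_url="http://localhost:11434/v1",  # Ollama's OpenAI-compatible endpoint
        dataset=dataset,
        metadata=dataset.metadata,
    )
    # Mirror the GUI: seed the conversation with a system message, then chat.
    llm.messages = [{"role": "system", "content": "You are a proteomics expert."}]
    return llm.chat_completion(question)
```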
