Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

MVP: local models via ollama #311

Merged
merged 16 commits into from
Sep 17, 2024
26 changes: 26 additions & 0 deletions .streamlit/config.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
# Streamlit app-wide theme colors.
[theme]

# Primary accent for interactive elements
primaryColor = '#005358'

# Background color for the main content area
backgroundColor = '#FFFFFF'

# Background color for sidebar and most interactive widgets
secondaryBackgroundColor = '#f2f2f2'

# Color used for almost all text
textColor = '#302E30'

# Font family for all text in the app, except code blocks
# Accepted values (serif | sans serif | monospace)
# Default: "sans serif"
font = "sans serif"

# Streamlit server settings.
[server]
# Maximum size of uploaded files, in megabytes.
maxUploadSize = 500
# NOTE(review): disabling XSRF protection and CORS removes two browser-side
# security layers. This is presumably to work around upload 403 errors in
# local single-user use — confirm this config is never used on a publicly
# exposed deployment.
enableXsrfProtection = false
enableCORS = false

# Browser-side client settings.
[browser]
# Send anonymous usage statistics to Streamlit; set to false to opt out.
gatherUsageStats = true
2 changes: 2 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,8 @@ alphastats gui
```
If you get an `AxiosError: Request failed with status code 403` when uploading files, try running `DISABLE_XSRF=1 alphastats gui`.

If you want to use local Large Language Models to help interpret the data, you need to download and install Ollama (https://ollama.com/download).

AlphaStats can be imported as a Python package into any Python script or notebook with the command `import alphastats`.
A brief [Jupyter notebook tutorial](nbs/getting_started.ipynb) on how to use the API is also present in the [nbs folder](nbs).

Expand Down
18 changes: 15 additions & 3 deletions alphastats/DataSet_Preprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,29 +183,41 @@ def _linear_normalization(self, dataframe: pd.DataFrame):

@ignore_warning(UserWarning)
@ignore_warning(RuntimeWarning)
def _normalization(self, method: str):
def _normalization(self, method: str) -> None:
"""Normalize across samples."""
# TODO make both sample and protein normalization available
if method == "zscore":
scaler = sklearn.preprocessing.StandardScaler()
# normalize samples => for preprocessing
normalized_array = scaler.fit_transform(
self.mat.values.transpose()
).transpose()
# normalize proteins => for downstream processing
JuliaS92 marked this conversation as resolved.
Show resolved Hide resolved
# normalized_array = scaler.fit_transform(self.mat.values)

elif method == "quantile":
qt = sklearn.preprocessing.QuantileTransformer(random_state=0)
normalized_array = qt.fit_transform(self.mat.values.transpose()).transpose()
# normalized_array = qt.fit_transform(self.mat.values) # normalize proteins

elif method == "linear":
normalized_array = self._linear_normalization(self.mat)

# normalized_array = self._linear_normalization(
# self.mat.transpose()
# ).transpose() # normalize proteins

elif method == "vst":
minmax = sklearn.preprocessing.MinMaxScaler()
scaler = sklearn.preprocessing.PowerTransformer()
minmaxed_array = minmax.fit_transform(self.mat.values.transpose())
normalized_array = scaler.fit_transform(minmaxed_array).transpose()
# minmaxed_array = minmax.fit_transform(self.mat.values) # normalize proteins
# normalized_array = scaler.fit_transform(minmaxed_array) # normalize proteins

else:
raise ValueError(
"Normalization method: {method} is invalid"
f"Normalization method: {method} is invalid. "
"Choose from 'zscore', 'quantile', 'linear' normalization. or 'vst' for variance stabilization transformation"
)

Expand Down Expand Up @@ -277,7 +289,7 @@ def batch_correction(self, batch: str):
@ignore_warning(RuntimeWarning)
def preprocess(
self,
log2_transform: bool = True,
log2_transform: bool = False,
remove_contaminations: bool = False,
subset: bool = False,
data_completeness: float = 0,
Expand Down
2 changes: 2 additions & 0 deletions alphastats/gui/.streamlit/config.toml
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@ font = "sans serif"

[server]
maxUploadSize = 500
enableXsrfProtection = false
enableCORS = false

[browser]
gatherUsageStats = true
144 changes: 79 additions & 65 deletions alphastats/gui/pages/05_GPT.py → alphastats/gui/pages/05_LLM.py
Original file line number Diff line number Diff line change
@@ -1,35 +1,28 @@
import os
import streamlit as st
import pandas as pd
from openai import OpenAI, OpenAIError, AuthenticationError
from openai import AuthenticationError

from alphastats.gui.utils.analysis_helper import (
check_if_options_are_loaded,
display_df,
display_figure,
download_figure,
download_preprocessing_info,
get_analysis,
load_options,
save_plot_to_session_state,
gui_volcano_plot_differential_expression_analysis,
helper_compare_two_groups,
)
from alphastats.gui.utils.gpt_helper import (
get_assistant_functions,
display_proteins,
get_gene_function,
get_info,
get_subgroups_for_each_group,
turn_args_to_float,
perform_dimensionality_reduction,
wait_for_run_completion,
send_message_save_thread,
get_general_assistant_functions,
)
from alphastats.gui.utils.openai_utils import (
try_to_set_api_key,
)
from alphastats.gui.utils.ollama_utils import LLMIntegration
from alphastats.gui.utils.options import interpretation_options
from alphastats.gui.utils.ui_helper import sidebar_info, init_session_state


init_session_state()
sidebar_info()
st.session_state.plot_dict = {}
Expand All @@ -43,8 +36,8 @@ def select_analysis():
"""
method = st.selectbox(
"Analysis",
options=["Volcano plot"],
# options=list(st.session_state.interpretation_options.keys()),
# options=["Volcano plot"],
options=list(interpretation_options(st.session_state).keys()),
)
return method

Expand All @@ -54,32 +47,34 @@ def select_analysis():
st.stop()


st.markdown("### GPT4 Analysis")
st.markdown("### LLM Analysis")

sidebar_info()


# set background to white so downloaded pngs dont have grey background
styl = f"""
styl = """
<style>
.css-jc5rf5 {{
.css-jc5rf5 {
position: absolute;
background: rgb(255, 255, 255);
color: rgb(48, 46, 48);
inset: 0px;
overflow: hidden;
}}
}
</style>
"""
st.markdown(styl, unsafe_allow_html=True)

# Initialize session state variables
if "llm_integration" not in st.session_state:
st.session_state["llm_integration"] = None
if "api_type" not in st.session_state:
st.session_state["api_type"] = "gpt"

if "plot_list" not in st.session_state:
st.session_state["plot_list"] = []


if "openai_model" not in st.session_state:
# st.session_state["openai_model"] = "gpt-3.5-turbo-16k"
st.session_state["openai_model"] = "gpt-4-0125-preview" # "gpt-4-1106-preview"

if "messages" not in st.session_state:
st.session_state["messages"] = []

Expand All @@ -101,14 +96,17 @@ def select_analysis():
with c1:
method = select_analysis()
chosen_parameter_dict = helper_compare_two_groups()
api_key = st.text_input("API Key", type="password")

try_to_set_api_key(api_key)
st.session_state["api_type"] = st.selectbox(
"Select LLM",
["gpt4o", "llama3.1 70b"],
index=0 if st.session_state["api_type"] == "gpt4o" else 1,
)
base_url = "http://localhost:11434/v1"
if st.session_state["api_type"] == "gpt4o":
api_key = st.text_input("Enter OpenAI API Key", type="password")
try_to_set_api_key(api_key)

try:
client = OpenAI(api_key=st.secrets["openai_api_key"])
except OpenAIError:
pass
method = st.selectbox(
"Differential Analysis using:",
options=["ttest", "anova", "wald", "sam", "paired-ttest", "welch-ttest"],
Expand Down Expand Up @@ -240,9 +238,10 @@ def select_analysis():
"A user will present you with data regarding proteins upregulated in certain cells "
"sourced from UniProt and abstracts from scientific publications. They seek your "
"expertise in understanding the connections between these proteins and their potential role "
f"in disease genesis. {os.linesep}Provide a detailed and insightful, yet concise response based on the given information. "
f"in disease genesis. {os.linesep}Provide a detailed and insightful, yet concise response based on the given information. Use formatting to make your response more human readable."
f"The data you have has following groups and respective subgroups: {str(get_subgroups_for_each_group(st.session_state.dataset.metadata))}."
"Plots are visualized using a graphical environment capable of rendering images, you don't need to worry about that."
"Plots are visualized using a graphical environment capable of rendering images, you don't need to worry about that. If the data coming to"
" you from a function has references to the literature (for example, PubMed), always quote the references in your response."
)
if "column" in chosen_parameter_dict and "upregulated" in st.session_state:
st.session_state["user_prompt"] = (
Expand Down Expand Up @@ -280,31 +279,53 @@ def select_analysis():
st.session_state["gpt_submitted_clicked"]
> st.session_state["gpt_submitted_counter"]
):
try_to_set_api_key()

client = OpenAI(api_key=st.secrets["openai_api_key"])
if st.session_state["api_type"] == "gpt4o":
try_to_set_api_key()

try:
st.session_state["assistant"] = client.beta.assistants.create(
instructions=st.session_state["instructions"],
name="Proteomics interpreter",
model=st.session_state["openai_model"],
tools=get_assistant_functions(
gene_to_prot_id_dict=st.session_state["gene_to_prot_id"],
if st.session_state["api_type"] == "gpt4o":
st.session_state["llm_integration"] = LLMIntegration(
api_type="gpt",
api_key=st.secrets["openai_api_key"],
dataset=st.session_state["dataset"],
metadata=st.session_state["dataset"].metadata,
)
else:
st.session_state["llm_integration"] = LLMIntegration(
api_type="ollama",
base_url=base_url,
dataset=st.session_state["dataset"],
metadata=st.session_state["dataset"].metadata,
subgroups_for_each_group=get_subgroups_for_each_group(
st.session_state["dataset"].metadata
),
),
)
st.success(
f"{st.session_state['api_type'].upper()} integration initialized successfully!"
)
except AuthenticationError:
st.warning(
"Incorrect API key provided. Please enter a valid API key, it should look like this: sk-XXXXX"
)
st.stop()

if "artefact_enum_dict" not in st.session_state:
st.session_state["artefact_enum_dict"] = {}
if "llm_integration" not in st.session_state or not st.session_state["llm_integration"]:
st.warning("Please initialize the model first")
st.stop()

llm = st.session_state["llm_integration"]

# Set instructions and update tools
llm.tools = [
*get_general_assistant_functions(),
*get_assistant_functions(
gene_to_prot_id_dict=st.session_state["gene_to_prot_id"],
metadata=st.session_state["dataset"].metadata,
subgroups_for_each_group=get_subgroups_for_each_group(
st.session_state["dataset"].metadata
),
),
]

if "artifacts" not in st.session_state:
st.session_state["artifacts"] = {}

if (
st.session_state["gpt_submitted_counter"]
Expand All @@ -313,30 +334,23 @@ def select_analysis():
st.session_state["gpt_submitted_counter"] = st.session_state[
"gpt_submitted_clicked"
]
st.session_state["artefact_enum_dict"] = {}
thread = client.beta.threads.create()
st.session_state["thread_id"] = thread.id
artefacts = send_message_save_thread(client, st.session_state["user_prompt"])
if artefacts:
st.session_state["artefact_enum_dict"][len(st.session_state.messages) - 1] = (
artefacts
)
st.session_state["artifacts"] = {}
llm.messages = [{"role": "system", "content": st.session_state["instructions"]}]
response = llm.chat_completion(st.session_state["user_prompt"])

if st.session_state["gpt_submitted_clicked"] > 0:
if prompt := st.chat_input("Say something"):
st.session_state.messages.append({"role": "user", "content": prompt})
artefacts = send_message_save_thread(client, prompt)
if artefacts:
st.session_state["artefact_enum_dict"][
len(st.session_state.messages) - 1
] = artefacts
response = llm.chat_completion(prompt)
for num, role_content_dict in enumerate(st.session_state.messages):
if role_content_dict["role"] == "tool" or role_content_dict["role"] == "system":
continue
if "tool_calls" in role_content_dict:
continue
with st.chat_message(role_content_dict["role"]):
st.markdown(role_content_dict["content"])
if num in st.session_state["artefact_enum_dict"]:
for artefact in st.session_state["artefact_enum_dict"][num]:
if num in st.session_state["artifacts"]:
for artefact in st.session_state["artifacts"][num]:
if isinstance(artefact, pd.DataFrame):
st.dataframe(artefact)
else:
elif "plotly" in str(type(artefact)):
st.plotly_chart(artefact)
print(st.session_state["artefact_enum_dict"])
Empty file.
Loading
Loading