Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

streamline the LLM workflow a bit #330

Merged
merged 10 commits into from
Sep 18, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions .secrets.baseline
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@
"filename": "alphastats/gui/utils/ollama_utils.py",
"hashed_secret": "8ed4322e8e2790b8c928d381ce8d07cfd966e909",
"is_verified": false,
"line_number": 68,
"line_number": 69,
"is_secret": false
}
],
Expand All @@ -160,5 +160,5 @@
}
]
},
"generated_at": "2024-09-12T14:19:09Z"
"generated_at": "2024-09-18T07:09:01Z"
}
35 changes: 20 additions & 15 deletions alphastats/gui/pages/05_LLM.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
get_general_assistant_functions,
)
from alphastats.gui.utils.openai_utils import (
try_to_set_api_key,
set_api_key,
)
from alphastats.gui.utils.ollama_utils import LLMIntegration
from alphastats.gui.utils.options import interpretation_options
Expand Down Expand Up @@ -50,6 +50,7 @@ def select_analysis():
st.markdown("### LLM Analysis")

sidebar_info()
init_session_state()


# set background to white so downloaded pngs dont have grey background
Expand Down Expand Up @@ -90,23 +91,27 @@ def select_analysis():
st.session_state["gpt_submitted_clicked"] = 0
st.session_state["gpt_submitted_counter"] = 0

c1, c2 = st.columns((1, 2))

st.markdown("#### Configure LLM")
mschwoer marked this conversation as resolved.
Show resolved Hide resolved
st.session_state["api_type"] = st.selectbox(
"Select LLM",
["gpt4o", "llama3.1 70b"],
index=0 if st.session_state["api_type"] == "gpt4o" else 1,
)
base_url = "http://localhost:11434/v1"
api_key = None
if st.session_state["api_type"] == "gpt4o":
api_key = st.text_input("Enter OpenAI API Key", type="password")
set_api_key(api_key)


c1, c2 = st.columns((1, 2))

with c1:
st.markdown("#### Analysis")
method = select_analysis()
chosen_parameter_dict = helper_compare_two_groups()

st.session_state["api_type"] = st.selectbox(
"Select LLM",
["gpt4o", "llama3.1 70b"],
index=0 if st.session_state["api_type"] == "gpt4o" else 1,
)
base_url = "http://localhost:11434/v1"
if st.session_state["api_type"] == "gpt4o":
api_key = st.text_input("Enter OpenAI API Key", type="password")
try_to_set_api_key(api_key)

method = st.selectbox(
"Differential Analysis using:",
options=["ttest", "anova", "wald", "sam", "paired-ttest", "welch-ttest"],
Expand Down Expand Up @@ -265,7 +270,7 @@ def select_analysis():
"", value=st.session_state["user_prompt"], height=200
)

gpt_submitted = st.button("Run GPT analysis")
gpt_submitted = st.button("Run LLM analysis")

if gpt_submitted and "user_prompt" not in st.session_state:
st.warning("Please enter a user prompt first")
Expand All @@ -280,13 +285,13 @@ def select_analysis():
> st.session_state["gpt_submitted_counter"]
):
if st.session_state["api_type"] == "gpt4o":
try_to_set_api_key()
set_api_key()

try:
if st.session_state["api_type"] == "gpt4o":
st.session_state["llm_integration"] = LLMIntegration(
api_type="gpt",
api_key=st.secrets["openai_api_key"],
api_key=st.session_state["openai_api_key"],
dataset=st.session_state["dataset"],
metadata=st.session_state["dataset"].metadata,
)
Expand Down
5 changes: 3 additions & 2 deletions alphastats/gui/utils/analysis_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -288,12 +288,13 @@ def helper_compare_two_groups():
"""

chosen_parameter_dict = {}
default_option = "<select>"
group = st.selectbox(
"Grouping variable",
options=["< None >"] + st.session_state.dataset.metadata.columns.to_list(),
options=[default_option] + st.session_state.dataset.metadata.columns.to_list(),
)

if group != "< None >":
if group != default_option:
unique_values = get_unique_values_from_column(column=group)

group1 = st.selectbox("Group 1", options=unique_values)
Expand Down
49 changes: 29 additions & 20 deletions alphastats/gui/utils/openai_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,28 +164,37 @@ def send_message_save_thread(
return plots


def try_to_set_api_key(api_key: str = None) -> None:
"""
Checks if the OpenAI API key is available in the environment / system variables.
If the API key is not available, saves the key to secrets.toml in the repository root directory.
def set_api_key(api_key: str = None) -> None:
"""Put the OpenAI API key in the session state.

If provided, use the `api_key`.
If not, take the key from the secrets.toml file.
Show a message if the file is not found.

Args:
api_key (str, optional): The OpenAI API key. Defaults to None.

Returns:
None
"""
if api_key and "api_key" not in st.session_state:
if api_key:
st.session_state["openai_api_key"] = api_key
secret_path = Path(st.secrets._file_paths[-1])
secret_path.parent.mkdir(parents=True, exist_ok=True)
with open(secret_path, "w") as f:
f.write(f'openai_api_key = "{api_key}"')
openai.OpenAI.api_key = api_key
return
try:
openai.OpenAI.api_key = st.secrets["openai_api_key"]
except KeyError:
st.write(
"OpenAI API key not found in environment variables. Please enter your API key to continue."
)
# TODO we should not write secrets to disk without user consent
# secret_path = Path("./.streamlit/secrets.toml")
# secret_path.parent.mkdir(parents=True, exist_ok=True)
# with open(secret_path, "w") as f:
# f.write(f'openai_api_key = "{api_key}"')
# openai.OpenAI.api_key = st.session_state["openai_api_key"]
# return
else:
try:
api_key = st.secrets["openai_api_key"]
except FileNotFoundError:
st.info(
"Please enter an OpenAI key or provide it in a secrets.toml file in the "
"alphastats/gui/.streamlit directory like "
"`openai_api_key = <key>`"
)
except KeyError:
st.write("OpenAI API key not found in secrets.")
except Exception as e:
st.write(f"Error loading OpenAI API key: {e}.")

openai.OpenAI.api_key = api_key
Loading