From 222e5e4b64abb8e5cac705231315d23f3674031f Mon Sep 17 00:00:00 2001 From: mschwoerer <82171591+mschwoer@users.noreply.github.com> Date: Tue, 17 Sep 2024 13:38:05 +0200 Subject: [PATCH 01/13] first bunch of session state replacements --- alphastats/gui/pages/02_Import Data.py | 5 ++-- alphastats/gui/pages/03_Data Overview.py | 4 +-- alphastats/gui/pages/03_Preprocessing.py | 16 +++++++---- alphastats/gui/pages/04_Analysis.py | 14 ++++++---- alphastats/gui/pages/05_LLM.py | 26 +++++++++--------- alphastats/gui/utils/analysis_helper.py | 29 +++++++++++--------- alphastats/gui/utils/gpt_helper.py | 4 +-- alphastats/gui/utils/import_helper.py | 5 ++-- alphastats/gui/utils/openai_utils.py | 10 ++++--- alphastats/gui/utils/overview_helper.py | 22 ++++++++------- alphastats/gui/utils/ui_helper.py | 35 +++++++++++++++++++++--- alphastats/gui/utils/uniprot_utils.py | 3 +- tests/gui/test_02_import_data.py | 25 +++++++++-------- tests/gui/test_03_data_overview.py | 3 +- tests/gui/test_04_preprocessing.py | 6 ++-- tests/test_DataSet.py | 4 +-- tests/test_gpt.py | 3 +- 17 files changed, 131 insertions(+), 83 deletions(-) diff --git a/alphastats/gui/pages/02_Import Data.py b/alphastats/gui/pages/02_Import Data.py index bfa32263..e255f560 100644 --- a/alphastats/gui/pages/02_Import Data.py +++ b/alphastats/gui/pages/02_Import Data.py @@ -18,6 +18,7 @@ sidebar_info, empty_session_state, init_session_state, + StateKeys, ) @@ -27,11 +28,11 @@ def _finalize_data_loading( dataset: DataSet, ) -> None: """Finalize the data loading process.""" - st.session_state["loader"] = ( + st.session_state[StateKeys.LOADER] = ( loader # TODO: Figure out if we even need the loader here, as the dataset has the loader as an attribute. ) st.session_state["metadata_columns"] = metadata_columns - st.session_state["dataset"] = dataset + st.session_state[StateKeys.DATASET] = dataset load_options() sidebar_info() diff --git a/alphastats/gui/pages/03_Data Overview.py b/alphastats/gui/pages/03_Data Overview.py index 9340768d..241bce87 100644 --- a/alphastats/gui/pages/03_Data Overview.py +++ b/alphastats/gui/pages/03_Data Overview.py @@ -7,7 +7,7 @@ get_intensity_distribution_unprocessed, display_loaded_dataset, ) -from alphastats.gui.utils.ui_helper import sidebar_info, init_session_state +from alphastats.gui.utils.ui_helper import sidebar_info, init_session_state, StateKeys init_session_state() sidebar_info() @@ -18,7 +18,7 @@ st.markdown("### DataSet Info") -display_loaded_dataset(st.session_state["dataset"]) +display_loaded_dataset(st.session_state[StateKeys.DATASET]) st.markdown("## DataSet overview") diff --git a/alphastats/gui/pages/03_Preprocessing.py b/alphastats/gui/pages/03_Preprocessing.py index e5cb3e41..8d0228d0 100644 --- a/alphastats/gui/pages/03_Preprocessing.py +++ b/alphastats/gui/pages/03_Preprocessing.py @@ -11,7 +11,7 @@ reset_preprocessing, PREPROCESSING_STEPS, ) -from alphastats.gui.utils.ui_helper import sidebar_info, init_session_state +from alphastats.gui.utils.ui_helper import sidebar_info, init_session_state, StateKeys init_session_state() sidebar_info() @@ -28,7 +28,7 @@ with c2: if "dataset" in st.session_state: - settings = configure_preprocessing(dataset=st.session_state["dataset"]) + settings = configure_preprocessing(dataset=st.session_state[StateKeys.DATASET]) new_workflow = update_workflow(settings) if new_workflow != st.session_state.workflow: st.session_state.workflow = new_workflow @@ -44,12 +44,16 @@ else: c11, c12 = st.columns([1, 1]) if c11.button("Run preprocessing", key="_run_preprocessing"): - run_preprocessing(settings, st.session_state["dataset"]) + run_preprocessing(settings, st.session_state[StateKeys.DATASET]) # TODO show more info about the preprocessing steps - display_preprocessing_info(st.session_state["dataset"].preprocessing_info) + display_preprocessing_info( + st.session_state[StateKeys.DATASET].preprocessing_info + ) if c12.button("Reset all Preprocessing steps", key="_reset_preprocessing"): - reset_preprocessing(st.session_state["dataset"]) - display_preprocessing_info(st.session_state["dataset"].preprocessing_info) + reset_preprocessing(st.session_state[StateKeys.DATASET]) + display_preprocessing_info( + st.session_state[StateKeys.DATASET].preprocessing_info + ) # TODO: Add comparison plot of intensity distribution before and after preprocessing diff --git a/alphastats/gui/pages/04_Analysis.py b/alphastats/gui/pages/04_Analysis.py index fcb5bb60..e186b6db 100644 --- a/alphastats/gui/pages/04_Analysis.py +++ b/alphastats/gui/pages/04_Analysis.py @@ -4,6 +4,7 @@ sidebar_info, init_session_state, convert_df, + StateKeys, ) from alphastats.gui.utils.analysis_helper import ( get_analysis, @@ -24,8 +25,8 @@ def select_analysis(): load_options() method = st.selectbox( "Analysis", - options=list(st.session_state.plotting_options.keys()) - + list(st.session_state.statistic_options.keys()), + options=list(st.session_state[StateKeys.PLOTTING_OPTIONS].keys()) + + list(st.session_state[StateKeys.STATISTIC_OPTIONS].keys()), ) return method @@ -64,15 +65,16 @@ def select_analysis(): with c1: method = select_analysis() - if method in st.session_state.plotting_options.keys(): + if method in st.session_state[StateKeys.PLOTTING_OPTIONS].keys(): analysis_result = get_analysis( - method=method, options_dict=st.session_state.plotting_options + method=method, options_dict=st.session_state[StateKeys.PLOTTING_OPTIONS] ) plot_to_display = True - elif method in st.session_state.statistic_options.keys(): + elif method in st.session_state[StateKeys.STATISTIC_OPTIONS].keys(): analysis_result = get_analysis( - method=method, options_dict=st.session_state.statistic_options + method=method, + options_dict=st.session_state[StateKeys.STATISTIC_OPTIONS], ) df_to_display = True diff --git a/alphastats/gui/pages/05_LLM.py b/alphastats/gui/pages/05_LLM.py index 143b3786..021d6357 100644 --- a/alphastats/gui/pages/05_LLM.py +++ b/alphastats/gui/pages/05_LLM.py @@ -21,7 +21,7 @@ ) from alphastats.gui.utils.ollama_utils import LLMIntegration from alphastats.gui.utils.options import interpretation_options -from alphastats.gui.utils.ui_helper import sidebar_info, init_session_state +from alphastats.gui.utils.ui_helper import sidebar_info, init_session_state, StateKeys init_session_state() sidebar_info() @@ -127,7 +127,7 @@ def select_analysis(): label="UniProt organism ID, for example human is 9606, R. norvegicus is 10116", value=9606, ) - st.session_state["organism"] = organism + st.session_state[StateKeys.ORGANISM] = organism min_fc = st.select_slider("Foldchange cutoff", range(0, 3), value=1) @@ -169,8 +169,8 @@ def select_analysis(): genes_of_interest_colored_df = volcano_plot.get_colored_labels_df() print(genes_of_interest_colored_df) - gene_names_colname = st.session_state["loader"].gene_names - prot_ids_colname = st.session_state["loader"].index_column + gene_names_colname = st.session_state[StateKeys.LOADER].gene_names + prot_ids_colname = st.session_state[StateKeys.LOADER].index_column st.session_state["prot_id_to_gene"] = dict( zip( @@ -178,7 +178,7 @@ def select_analysis(): genes_of_interest_colored_df[gene_names_colname].tolist(), ) ) - st.session_state["gene_to_prot_id"] = dict( + st.session_state[StateKeys.GENE_TO_PROT_ID] = dict( zip( genes_of_interest_colored_df[gene_names_colname].tolist(), genes_of_interest_colored_df[prot_ids_colname].tolist(), @@ -239,7 +239,7 @@ def select_analysis(): "sourced from UniProt and abstracts from scientific publications. They seek your " "expertise in understanding the connections between these proteins and their potential role " f"in disease genesis. {os.linesep}Provide a detailed and insightful, yet concise response based on the given information. Use formatting to make your response more human readable." - f"The data you have has following groups and respective subgroups: {str(get_subgroups_for_each_group(st.session_state.dataset.metadata))}." + f"The data you have has following groups and respective subgroups: {str(get_subgroups_for_each_group(st.session_state[StateKeys.DATASET].metadata))}." "Plots are visualized using a graphical environment capable of rendering images, you don't need to worry about that. If the data coming to" " you from a function has references to the literature (for example, PubMed), always quote the references in your response." ) @@ -287,15 +287,15 @@ def select_analysis(): st.session_state["llm_integration"] = LLMIntegration( api_type="gpt", api_key=st.secrets["openai_api_key"], - dataset=st.session_state["dataset"], - metadata=st.session_state["dataset"].metadata, + dataset=st.session_state[StateKeys.DATASET], + metadata=st.session_state[StateKeys.DATASET].metadata, ) else: st.session_state["llm_integration"] = LLMIntegration( api_type="ollama", base_url=base_url, - dataset=st.session_state["dataset"], - metadata=st.session_state["dataset"].metadata, + dataset=st.session_state[StateKeys.DATASET], + metadata=st.session_state[StateKeys.DATASET].metadata, ) st.success( f"{st.session_state['api_type'].upper()} integration initialized successfully!" @@ -316,10 +316,10 @@ def select_analysis(): llm.tools = [ *get_general_assistant_functions(), *get_assistant_functions( - gene_to_prot_id_dict=st.session_state["gene_to_prot_id"], - metadata=st.session_state["dataset"].metadata, + gene_to_prot_id_dict=st.session_state[StateKeys.GENE_TO_PROT_ID], + metadata=st.session_state[StateKeys.DATASET].metadata, subgroups_for_each_group=get_subgroups_for_each_group( - st.session_state["dataset"].metadata + st.session_state[StateKeys.DATASET].metadata ), ), ] diff --git a/alphastats/gui/utils/analysis_helper.py b/alphastats/gui/utils/analysis_helper.py index 078727a0..28fef901 100644 --- a/alphastats/gui/utils/analysis_helper.py +++ b/alphastats/gui/utils/analysis_helper.py @@ -2,7 +2,7 @@ import streamlit as st import io -from alphastats.gui.utils.ui_helper import convert_df +from alphastats.gui.utils.ui_helper import convert_df, StateKeys from alphastats.plots.VolcanoPlot import VolcanoPlot @@ -77,7 +77,9 @@ def download_preprocessing_info(plot): def get_unique_values_from_column(column): - unique_values = st.session_state.dataset.metadata[column].unique().tolist() + unique_values = ( + st.session_state[StateKeys.DATASET].metadata[column].unique().tolist() + ) return unique_values @@ -114,7 +116,7 @@ def gui_volcano_plot_differential_expression_analysis( initalize volcano plot object with differential expression analysis results """ volcano_plot = VolcanoPlot( - dataset=st.session_state.dataset, **chosen_parameter_dict, plot=False + dataset=st.session_state[StateKeys.DATASET], **chosen_parameter_dict, plot=False ) volcano_plot._perform_differential_expression_analysis() volcano_plot._add_hover_data_columns() @@ -200,7 +202,7 @@ def get_analysis_options_from_dict(method, options_dict): return st_plot_umap(method_dict) elif "settings" not in method_dict.keys(): - if st.session_state.dataset.mat.isna().values.any() == True: + if st.session_state[StateKeys.DATASET].mat.isna().values.any() == True: st.error( "Data contains missing values impute your data before plotting (Preprocessing - Imputation)." ) @@ -290,7 +292,8 @@ def helper_compare_two_groups(): chosen_parameter_dict = {} group = st.selectbox( "Grouping variable", - options=["< None >"] + st.session_state.dataset.metadata.columns.to_list(), + options=["< None >"] + + st.session_state[StateKeys.DATASET].metadata.columns.to_list(), ) if group != "< None >": @@ -312,18 +315,18 @@ def helper_compare_two_groups(): else: group1 = st.multiselect( "Group 1 samples:", - options=st.session_state.dataset.metadata[ - st.session_state.dataset.sample - ].to_list(), + options=st.session_state[StateKeys.DATASET] + .metadata[st.session_state[StateKeys.DATASET].sample] + .to_list(), ) group2 = st.multiselect( "Group 2 samples:", options=list( reversed( - st.session_state.dataset.metadata[ - st.session_state.dataset.sample - ].to_list() + st.session_state[StateKeys.DATASET] + .metadata[st.session_state[StateKeys.DATASET].sample] + .to_list() ) ), ) @@ -378,8 +381,8 @@ def load_options(): # interpretation_options, ) - st.session_state["plotting_options"] = plotting_options(st.session_state) - st.session_state["statistic_options"] = statistic_options(st.session_state) + st.session_state[StateKeys.PLOTTING_OPTIONS] = plotting_options(st.session_state) + st.session_state[StateKeys.STATISTIC_OPTIONS] = statistic_options(st.session_state) # st.session_state["interpretation_options"] = interpretation_options diff --git a/alphastats/gui/utils/gpt_helper.py b/alphastats/gui/utils/gpt_helper.py index 8c09b010..42f5fe8c 100644 --- a/alphastats/gui/utils/gpt_helper.py +++ b/alphastats/gui/utils/gpt_helper.py @@ -9,7 +9,7 @@ import streamlit as st from alphastats.plots.DimensionalityReduction import DimensionalityReduction - +from gui.utils.ui_helper import StateKeys Entrez.email = "lebedev_mikhail@outlook.com" # Always provide your email address when using NCBI services. @@ -302,7 +302,7 @@ def get_assistant_functions( def perform_dimensionality_reduction(group, method, circle, **kwargs): dr = DimensionalityReduction( - st.session_state.dataset, group, method, circle, **kwargs + st.session_state[StateKeys.DATASET], group, method, circle, **kwargs ) return dr.plot diff --git a/alphastats/gui/utils/import_helper.py b/alphastats/gui/utils/import_helper.py index ab8437f1..61c11d05 100644 --- a/alphastats/gui/utils/import_helper.py +++ b/alphastats/gui/utils/import_helper.py @@ -11,14 +11,15 @@ from alphastats.DataSet import DataSet from alphastats.gui.utils.options import SOFTWARE_OPTIONS from alphastats.loader.MaxQuantLoader import MaxQuantLoader, BaseLoader +from gui.utils.ui_helper import StateKeys def load_options(): # TODO move import to top from alphastats.gui.utils.options import plotting_options, statistic_options - st.session_state["plotting_options"] = plotting_options(st.session_state) - st.session_state["statistic_options"] = statistic_options(st.session_state) + st.session_state[StateKeys.PLOTTING_OPTIONS] = plotting_options(st.session_state) + st.session_state[StateKeys.STATISTIC_OPTIONS] = statistic_options(st.session_state) def load_proteomics_data(uploaded_file, intensity_column, index_column, software): diff --git a/alphastats/gui/utils/openai_utils.py b/alphastats/gui/utils/openai_utils.py index d7f5edbd..3f7b5bdd 100644 --- a/alphastats/gui/utils/openai_utils.py +++ b/alphastats/gui/utils/openai_utils.py @@ -71,8 +71,8 @@ def wait_for_run_completion( print("requires_action", run_status) print( [ - st.session_state.plotting_options[i]["function"].__name__ - for i in st.session_state.plotting_options + st.session_state[StateKeys.PLOTTING_OPTIONS][i]["function"].__name__ + for i in st.session_state[StateKeys.PLOTTING_OPTIONS] ] ) tool_calls = run_status.required_action.submit_tool_outputs.tool_calls @@ -94,8 +94,10 @@ def wait_for_run_completion( elif ( tool_call.function.name in [ - st.session_state.plotting_options[i]["function"].__name__ - for i in st.session_state.plotting_options + st.session_state[StateKeys.PLOTTING_OPTIONS][i][ + "function" + ].__name__ + for i in st.session_state[StateKeys.PLOTTING_OPTIONS] ] or tool_call.function.name in assistant_functions ): diff --git a/alphastats/gui/utils/overview_helper.py b/alphastats/gui/utils/overview_helper.py index dc7a91d2..acdd025c 100644 --- a/alphastats/gui/utils/overview_helper.py +++ b/alphastats/gui/utils/overview_helper.py @@ -2,29 +2,29 @@ import pandas as pd from alphastats import DataSet -from alphastats.gui.utils.ui_helper import convert_df +from alphastats.gui.utils.ui_helper import convert_df, StateKeys # @st.cache_data # TODO check if caching is sensible here and if so, reimplement with dataset-hash def get_sample_histogram_matrix(): - return st.session_state.dataset.plot_samplehistograms() + return st.session_state[StateKeys.DATASET].plot_samplehistograms() # @st.cache_data # TODO check if caching is sensible here and if so, reimplement with dataset-hash def get_intensity_distribution_unprocessed(): - return st.session_state.dataset.plot_sampledistribution(use_raw=True) + return st.session_state[StateKeys.DATASET].plot_sampledistribution(use_raw=True) # @st.cache_data # TODO check if caching is sensible here and if so, reimplement with dataset-hash def get_intensity_distribution_processed(): - return st.session_state.dataset.plot_sampledistribution() + return st.session_state[StateKeys.DATASET].plot_sampledistribution() # @st.cache_data # TODO check if caching is sensible here and if so, reimplement with dataset-hash def get_display_matrix(): processed_df = pd.DataFrame( - st.session_state.dataset.mat.values, - index=st.session_state.dataset.mat.index.to_list(), + st.session_state[StateKeys.DATASET].mat.values, + index=st.session_state[StateKeys.DATASET].mat.index.to_list(), ).head(10) return processed_df @@ -33,18 +33,20 @@ def get_display_matrix(): def display_matrix(): text = ( "Normalization: " - + str(st.session_state.dataset.preprocessing_info["Normalization"]) + + str(st.session_state[StateKeys.DATASET].preprocessing_info["Normalization"]) + ", Imputation: " - + str(st.session_state.dataset.preprocessing_info["Imputation"]) + + str(st.session_state[StateKeys.DATASET].preprocessing_info["Imputation"]) + ", Log2-transformed: " - + str(st.session_state.dataset.preprocessing_info["Log2-transformed"]) + + str( + st.session_state[StateKeys.DATASET].preprocessing_info["Log2-transformed"] + ) ) st.markdown("**DataFrame used for analysis** *preview*") st.markdown(text) df = get_display_matrix() - csv = convert_df(st.session_state.dataset.mat) + csv = convert_df(st.session_state[StateKeys.DATASET].mat) st.dataframe(df) diff --git a/alphastats/gui/utils/ui_helper.py b/alphastats/gui/utils/ui_helper.py index ed9b9e94..f1ab6d98 100644 --- a/alphastats/gui/utils/ui_helper.py +++ b/alphastats/gui/utils/ui_helper.py @@ -40,7 +40,7 @@ def _display_sidebar_html_table(): if "dataset" not in st.session_state: return - preprocessing_dict = st.session_state.dataset.preprocessing_info + preprocessing_dict = st.session_state[StateKeys.DATASET].preprocessing_info html_string = ( "" @@ -80,10 +80,37 @@ def init_session_state() -> None: """Initialize the session state if not done yet.""" if "user_session_id" not in st.session_state: - st.session_state["user_session_id"] = str(uuid.uuid4()) + st.session_state[StateKeys.USER_SESSION_ID] = str(uuid.uuid4()) if "gene_to_prot_id" not in st.session_state: - st.session_state["gene_to_prot_id"] = {} + st.session_state[StateKeys.GENE_TO_PROT_ID] = {} if "organism" not in st.session_state: - st.session_state["organism"] = 9606 # human + st.session_state[StateKeys.ORGANISM] = 9606 # human + + +class StateKeys: + ## 02_Data Import + # on 1st run + ORGANISM = "organism" + GENE_TO_PROT_ID = "gene_to_prot_id" + USER_SESSION_ID = "user_session_id" + LOADER = "loader" + # on sample run (function load_sample_data), removed on new session click + DATASET = "dataset" # functions upload_metadatafile + PLOTTING_OPTIONS = "plotting_options" # function load_options + STATISTIC_OPTIONS = "statistic_options" # function load_options + # on metadata upload + SAMPLE_COLUMN = "sample_column" + # "workflow" + # "plot_list" + # "openai_model" + # + # "plot_submitted_clicked" + # "plot_submitted_counter" + # + # "lookup_submitted_clicked" + # "lookup_submitted_counter" + # + # "gpt_submitted_clicked" + # "gpt_submitted_counter" diff --git a/alphastats/gui/utils/uniprot_utils.py b/alphastats/gui/utils/uniprot_utils.py index 736f11ab..85636778 100644 --- a/alphastats/gui/utils/uniprot_utils.py +++ b/alphastats/gui/utils/uniprot_utils.py @@ -4,6 +4,7 @@ import streamlit as st +from gui.utils.ui_helper import StateKeys uniprot_fields = [ # Names & Taxonomy @@ -318,7 +319,7 @@ def get_gene_function(gene_name: Union[str, Dict], organism_id=9606) -> str: str: The gene function and description. """ if "organism" in st.session_state: - organism_id = st.session_state["organism"] + organism_id = st.session_state[StateKeys.ORGANISM] if isinstance(gene_name, dict): gene_name = gene_name["gene_name"] result = get_uniprot_data(gene_name, organism_id) diff --git a/tests/gui/test_02_import_data.py b/tests/gui/test_02_import_data.py index fb52ee82..51738777 100644 --- a/tests/gui/test_02_import_data.py +++ b/tests/gui/test_02_import_data.py @@ -2,7 +2,7 @@ from pathlib import Path from unittest.mock import MagicMock, patch from .conftest import APP_FOLDER, data_buf, metadata_buf - +from .utils.ui_helper import StateKeys TESTED_PAGE = f"{APP_FOLDER}/pages/02_Import Data.py" @@ -14,9 +14,9 @@ def test_page_02_loads_without_input(): assert not at.exception - assert at.session_state.organism == 9606 - assert at.session_state.user_session_id is not None - assert at.session_state.gene_to_prot_id == {} + assert at.session_state[StateKeys.ORGANISM] == 9606 + assert at.session_state[StateKeys.USER_SESSION_ID] is not None + assert at.session_state[StateKeys.GENE_TO_PROT_ID] == {} @patch("streamlit.file_uploader") @@ -27,9 +27,9 @@ def test_patched_page_02_loads_without_input(mock_file_uploader: MagicMock): assert not at.exception - assert at.session_state.organism == 9606 - assert at.session_state.user_session_id is not None - assert at.session_state.gene_to_prot_id == {} + assert at.session_state[StateKeys.ORGANISM] == 9606 + assert at.session_state[StateKeys.USER_SESSION_ID] is not None + assert at.session_state[StateKeys.GENE_TO_PROT_ID] == {} @patch( @@ -51,9 +51,12 @@ def test_page_02_loads_example_data(mock_page_link: MagicMock): "Drug therapy (procedure) (416608005)", "Lipid-lowering therapy (134350008)", ] - assert str(type(at.session_state.dataset)) == "" assert ( - str(type(at.session_state.loader)) + str(type(at.session_state[StateKeys.DATASET])) + == "" + ) + assert ( + str(type(at.session_state[StateKeys.LOADER])) == "" ) assert "plotting_options" in at.session_state @@ -107,7 +110,7 @@ def test_page_02_loads_maxquant_testfiles( assert not at.exception - dataset = at.session_state.dataset + dataset = at.session_state[StateKeys.DATASET] assert dataset.gene_names == "Gene names" assert dataset.index_column == "Protein IDs" assert dataset.intensity_column == "LFQ intensity [sample]" @@ -115,7 +118,7 @@ def test_page_02_loads_maxquant_testfiles( assert dataset.software == "MaxQuant" assert dataset.sample == "sample" assert ( - str(type(at.session_state.loader)) + str(type(at.session_state[StateKeys.LOADER])) == "" ) assert "plotting_options" in at.session_state diff --git a/tests/gui/test_03_data_overview.py b/tests/gui/test_03_data_overview.py index 109f4584..df20b9b1 100644 --- a/tests/gui/test_03_data_overview.py +++ b/tests/gui/test_03_data_overview.py @@ -2,6 +2,7 @@ from pathlib import Path from unittest.mock import MagicMock, patch from .conftest import create_dataset_alphapept, APP_FOLDER +from .utils.ui_helper import StateKeys TESTED_PAGE = f"{APP_FOLDER}/pages/03_Data Overview.py" @@ -19,7 +20,7 @@ def test_page_03_loads_with_input(): at = AppTest(TESTED_PAGE, default_timeout=200) at.run() - at.session_state["dataset"] = create_dataset_alphapept() + at.session_state[StateKeys.DATASET] = create_dataset_alphapept() at.run() assert not at.exception diff --git a/tests/gui/test_04_preprocessing.py b/tests/gui/test_04_preprocessing.py index bbe6a78b..091eec85 100644 --- a/tests/gui/test_04_preprocessing.py +++ b/tests/gui/test_04_preprocessing.py @@ -2,7 +2,7 @@ from pathlib import Path from unittest.mock import MagicMock, patch from .conftest import create_dataset_alphapept, APP_FOLDER - +from .utils.ui_helper import StateKeys TESTED_PAGE = f"{APP_FOLDER}/pages/03_Preprocessing.py" @@ -20,7 +20,7 @@ def test_page_04_loads_with_input(): at = AppTest(TESTED_PAGE, default_timeout=200) at.run() - at.session_state["dataset"] = create_dataset_alphapept() + at.session_state[StateKeys.DATASET] = create_dataset_alphapept() at.run() assert not at.exception @@ -33,7 +33,7 @@ def test_page_04_runs_preprocessreset_alphapept(): at = AppTest(TESTED_PAGE, default_timeout=200) at.run() - at.session_state["dataset"] = create_dataset_alphapept() + at.session_state[StateKeys.DATASET] = create_dataset_alphapept() at.run() at.button(key="_run_preprocessing").click() diff --git a/tests/test_DataSet.py b/tests/test_DataSet.py index 51765554..cc16aaac 100644 --- a/tests/test_DataSet.py +++ b/tests/test_DataSet.py @@ -20,7 +20,7 @@ from alphastats.DataSet_Statistics import Statistics from alphastats.utils import LoaderError - +from gui.utils.ui_helper import StateKeys logger = logging.getLogger(__name__) @@ -523,7 +523,7 @@ def test_plot_intenstity_subgroup(self): def test_plot_intensity_subgroup_gracefully_handle_one_group(self): import streamlit as st - st.session_state["gene_to_prot_id"] = {} + st.session_state[StateKeys.GENE_TO_PROT_ID] = {} plot = self.obj.plot_intensity( protein_id="K7ERI9;A0A024R0T8;P02654;K7EJI9;K7ELM9;K7EPF9;K7EKP1", group="disease", diff --git a/tests/test_gpt.py b/tests/test_gpt.py index e8453ec5..5a5988e3 100644 --- a/tests/test_gpt.py +++ b/tests/test_gpt.py @@ -9,9 +9,10 @@ from alphastats.gui.utils.uniprot_utils import get_uniprot_data, extract_data from alphastats.loader.MaxQuantLoader import MaxQuantLoader from alphastats.DataSet import DataSet +from gui.utils.ui_helper import StateKeys if "gene_to_prot_id" not in st.session_state: - st.session_state["gene_to_prot_id"] = {} + st.session_state[StateKeys.GENE_TO_PROT_ID] = {} logger = logging.getLogger(__name__) From 583b1e915cf4d824bd592d17c887ce7af142c455 Mon Sep 17 00:00:00 2001 From: mschwoerer <82171591+mschwoer@users.noreply.github.com> Date: Tue, 17 Sep 2024 13:42:04 +0200 Subject: [PATCH 02/13] remove unused metadata_columns --- alphastats/gui/pages/02_Import Data.py | 6 ++---- alphastats/gui/utils/import_helper.py | 3 +-- tests/gui/test_02_import_data.py | 6 ------ tests/test_gpt.py | 1 - 4 files changed, 3 insertions(+), 13 deletions(-) diff --git a/alphastats/gui/pages/02_Import Data.py b/alphastats/gui/pages/02_Import Data.py index e255f560..1a543d90 100644 --- a/alphastats/gui/pages/02_Import Data.py +++ b/alphastats/gui/pages/02_Import Data.py @@ -24,14 +24,12 @@ def _finalize_data_loading( loader: BaseLoader, - metadata_columns: List[str], dataset: DataSet, ) -> None: """Finalize the data loading process.""" st.session_state[StateKeys.LOADER] = ( loader # TODO: Figure out if we even need the loader here, as the dataset has the loader as an attribute. ) - st.session_state["metadata_columns"] = metadata_columns st.session_state[StateKeys.DATASET] = dataset load_options() @@ -58,9 +56,9 @@ def _finalize_data_loading( if c2.button("Start new Session with example DataSet", key="_load_example_data"): empty_session_state() init_session_state() - loader, metadata_columns, dataset = load_example_data() + loader, dataset = load_example_data() - _finalize_data_loading(loader, metadata_columns, dataset) + _finalize_data_loading(loader, dataset) st.stop() diff --git a/alphastats/gui/utils/import_helper.py b/alphastats/gui/utils/import_helper.py index 61c11d05..77fedacd 100644 --- a/alphastats/gui/utils/import_helper.py +++ b/alphastats/gui/utils/import_helper.py @@ -129,8 +129,7 @@ def load_example_data(): ] ] dataset.preprocess(subset=True) - metadata_columns = dataset.metadata.columns.to_list() - return loader, metadata_columns, dataset + return loader, dataset def _check_softwarefile_df(df: pd.DataFrame, software: str) -> None: diff --git a/tests/gui/test_02_import_data.py b/tests/gui/test_02_import_data.py index 51738777..97eb3004 100644 --- a/tests/gui/test_02_import_data.py +++ b/tests/gui/test_02_import_data.py @@ -45,12 +45,6 @@ def test_page_02_loads_example_data(mock_page_link: MagicMock): assert not at.exception - assert at.session_state.metadata_columns == [ - "sample", - "disease", - "Drug therapy (procedure) (416608005)", - "Lipid-lowering therapy (134350008)", - ] assert ( str(type(at.session_state[StateKeys.DATASET])) == "" diff --git a/tests/test_gpt.py b/tests/test_gpt.py index 5a5988e3..23107745 100644 --- a/tests/test_gpt.py +++ b/tests/test_gpt.py @@ -31,7 +31,6 @@ def setUp(self): self.matrix_dim = (312, 2596) self.matrix_dim_filtered = (312, 2397) self.comparison_column = "disease" - st.session_state.metadata_columns = [self.comparison_column] class TestGetUniProtData(unittest.TestCase): From 1be635f632061d74ba52a687e1483906016d28d7 Mon Sep 17 00:00:00 2001 From: mschwoerer <82171591+mschwoer@users.noreply.github.com> Date: Tue, 17 Sep 2024 14:03:09 +0200 Subject: [PATCH 03/13] rest of keys --- .secrets.baseline | 4 +- alphastats/gui/pages/03_Preprocessing.py | 10 +- alphastats/gui/pages/04_Analysis.py | 2 +- alphastats/gui/pages/05_LLM.py | 119 ++++++++++++----------- alphastats/gui/pages/06_Results.py | 3 +- alphastats/gui/utils/analysis_helper.py | 2 +- alphastats/gui/utils/ollama_utils.py | 5 +- alphastats/gui/utils/openai_utils.py | 16 +-- alphastats/gui/utils/options.py | 1 + alphastats/gui/utils/ui_helper.py | 39 +++++--- 10 files changed, 111 insertions(+), 90 deletions(-) diff --git a/.secrets.baseline b/.secrets.baseline index 854fbfdf..cdb49dcf 100644 --- a/.secrets.baseline +++ b/.secrets.baseline @@ -145,7 +145,7 @@ "filename": "alphastats/gui/utils/ollama_utils.py", "hashed_secret": "8ed4322e8e2790b8c928d381ce8d07cfd966e909", "is_verified": false, - "line_number": 68, + "line_number": 70, "is_secret": false } ], @@ -160,5 +160,5 @@ } ] }, - "generated_at": "2024-09-12T14:19:09Z" + "generated_at": "2024-09-17T12:01:59Z" } diff --git a/alphastats/gui/pages/03_Preprocessing.py b/alphastats/gui/pages/03_Preprocessing.py index 8d0228d0..34f95c47 100644 --- a/alphastats/gui/pages/03_Preprocessing.py +++ b/alphastats/gui/pages/03_Preprocessing.py @@ -16,8 +16,8 @@ init_session_state() sidebar_info() -if "workflow" not in st.session_state: - st.session_state["workflow"] = [ +if StateKeys.WORKFLOW not in st.session_state: + st.session_state[StateKeys.WORKFLOW] = [ PREPROCESSING_STEPS.REMOVE_CONTAMINATIONS, PREPROCESSING_STEPS.SUBSET, PREPROCESSING_STEPS.LOG2_TRANSFORM, @@ -30,13 +30,13 @@ if "dataset" in st.session_state: settings = configure_preprocessing(dataset=st.session_state[StateKeys.DATASET]) new_workflow = update_workflow(settings) - if new_workflow != st.session_state.workflow: - st.session_state.workflow = new_workflow + if new_workflow != st.session_state[StateKeys.WORKFLOW]: + st.session_state[StateKeys.WORKFLOW] = new_workflow with c1: st.write("#### Flowchart of preprocessing workflow:") - selected_nodes = draw_workflow(st.session_state.workflow) + selected_nodes = draw_workflow(st.session_state[StateKeys.WORKFLOW]) if "dataset" not in st.session_state: st.info("Import data first to configure and run preprocessing") diff --git a/alphastats/gui/pages/04_Analysis.py b/alphastats/gui/pages/04_Analysis.py index e186b6db..6bf47d67 100644 --- a/alphastats/gui/pages/04_Analysis.py +++ b/alphastats/gui/pages/04_Analysis.py @@ -52,7 +52,7 @@ def select_analysis(): if "plot_list" not in st.session_state: - st.session_state["plot_list"] = [] + st.session_state[StateKeys.PLOT_LIST] = [] if "dataset" in st.session_state: diff --git a/alphastats/gui/pages/05_LLM.py b/alphastats/gui/pages/05_LLM.py index 021d6357..92a70396 100644 --- a/alphastats/gui/pages/05_LLM.py +++ b/alphastats/gui/pages/05_LLM.py @@ -25,7 +25,7 @@ init_session_state() sidebar_info() -st.session_state.plot_dict = {} +st.session_state[StateKeys.PLOT_DICT] = {} @check_if_options_are_loaded @@ -68,27 +68,27 @@ def select_analysis(): # Initialize session state variables if "llm_integration" not in st.session_state: - st.session_state["llm_integration"] = None + st.session_state[StateKeys.LLM_INTEGRATION] = None if "api_type" not in st.session_state: - st.session_state["api_type"] = "gpt" + st.session_state[StateKeys.API_TYPE] = "gpt" if "plot_list" not in st.session_state: - st.session_state["plot_list"] = [] + st.session_state[StateKeys.PLOT_LIST] = [] if "messages" not in st.session_state: - st.session_state["messages"] = [] + st.session_state[StateKeys.MESSAGES] = [] if "plot_submitted_clicked" not in st.session_state: - st.session_state["plot_submitted_clicked"] = 0 - st.session_state["plot_submitted_counter"] = 0 + st.session_state[StateKeys.PLOT_SUBMITTED_CLICKED] = 0 + st.session_state[StateKeys.PLOT_SUBMITTED_COUNTER] = 0 if "lookup_submitted_clicked" not in st.session_state: - st.session_state["lookup_submitted_clicked"] = 0 - st.session_state["lookup_submitted_counter"] = 0 + st.session_state[StateKeys.LOOKUP_SUBMITTED_CLICKED] = 0 + st.session_state[StateKeys.LOOKUP_SUBMITTED_COUNTER] = 0 if "gpt_submitted_clicked" not in st.session_state: - st.session_state["gpt_submitted_clicked"] = 0 - st.session_state["gpt_submitted_counter"] = 0 + st.session_state[StateKeys.GPT_SUBMITTED_CLICKED] = 0 + st.session_state[StateKeys.GPT_SUBMITTED_COUNTER] = 0 c1, c2 = st.columns((1, 2)) @@ -97,13 +97,13 @@ def select_analysis(): method = select_analysis() chosen_parameter_dict = helper_compare_two_groups() - st.session_state["api_type"] = st.selectbox( + st.session_state[StateKeys.API_TYPE] = st.selectbox( "Select LLM", ["gpt4o", "llama3.1 70b"], - index=0 if st.session_state["api_type"] == "gpt4o" else 1, + index=0 if st.session_state[StateKeys.API_TYPE] == "gpt4o" else 1, ) base_url = "http://localhost:11434/v1" - if st.session_state["api_type"] == "gpt4o": + if st.session_state[StateKeys.API_TYPE] == "gpt4o": api_key = st.text_input("Enter OpenAI API Key", type="password") try_to_set_api_key(api_key) @@ -149,14 +149,14 @@ def select_analysis(): plot_submitted = st.button("Plot") if plot_submitted: - st.session_state["plot_submitted_clicked"] += 1 + st.session_state[StateKeys.PLOT_SUBMITTED_CLICKED] += 1 if ( - st.session_state["plot_submitted_counter"] - < st.session_state["plot_submitted_clicked"] + st.session_state[StateKeys.PLOT_SUBMITTED_COUNTER] + < st.session_state[StateKeys.PLOT_SUBMITTED_CLICKED] ): - st.session_state["plot_submitted_counter"] = st.session_state[ + st.session_state[StateKeys.PLOT_SUBMITTED_COUNTER] = st.session_state[ "plot_submitted_clicked" ] volcano_plot = gui_volcano_plot_differential_expression_analysis( @@ -172,7 +172,7 @@ def select_analysis(): gene_names_colname = st.session_state[StateKeys.LOADER].gene_names prot_ids_colname = st.session_state[StateKeys.LOADER].index_column - st.session_state["prot_id_to_gene"] = dict( + st.session_state[StateKeys.PROT_ID_TO_GENE] = dict( zip( genes_of_interest_colored_df[prot_ids_colname].tolist(), genes_of_interest_colored_df[gene_names_colname].tolist(), @@ -194,14 +194,14 @@ def select_analysis(): print("genes_of_interest", genes_of_interest_colored) save_plot_to_session_state(volcano_plot, method) - st.session_state["genes_of_interest_colored"] = genes_of_interest_colored + st.session_state[StateKeys.GENES_OF_INTEREST_COLORED] = genes_of_interest_colored # st.session_state["gene_functions"] = get_info(genes_of_interest_colored, organism) - st.session_state["upregulated"] = [ + st.session_state[StateKeys.UPREGULATED] = [ key for key in genes_of_interest_colored if genes_of_interest_colored[key] == "up" ] - st.session_state["downregulated"] = [ + st.session_state[StateKeys.DOWNREGULATED] = [ key for key in genes_of_interest_colored if genes_of_interest_colored[key] == "down" @@ -210,30 +210,30 @@ def select_analysis(): c1, c2 = st.columns((1, 2), gap="medium") with c1: st.write("Upregulated genes") - display_proteins(st.session_state["upregulated"], []) + display_proteins(st.session_state[StateKeys.UPREGULATED], []) with c2: st.write("Downregulated genes") - display_proteins([], st.session_state["downregulated"]) + display_proteins([], st.session_state[StateKeys.DOWNREGULATED]) elif ( - st.session_state["plot_submitted_counter"] > 0 - and st.session_state["plot_submitted_counter"] - == st.session_state["plot_submitted_clicked"] - and len(st.session_state["plot_list"]) > 0 + st.session_state[StateKeys.PLOT_SUBMITTED_COUNTER] > 0 + and st.session_state[StateKeys.PLOT_SUBMITTED_COUNTER] + == st.session_state[StateKeys.PLOT_SUBMITTED_CLICKED] + and len(st.session_state[StateKeys.PLOT_LIST]) > 0 ): with c2: - display_figure(st.session_state["plot_list"][-1][1].plot) + display_figure(st.session_state[StateKeys.PLOT_LIST][-1][1].plot) st.subheader("Genes of interest") c1, c2 = st.columns((1, 2), gap="medium") with c1: st.write("Upregulated genes") - display_proteins(st.session_state["upregulated"], []) + display_proteins(st.session_state[StateKeys.UPREGULATED], []) with c2: st.write("Downregulated genes") - display_proteins([], st.session_state["downregulated"]) + display_proteins([], st.session_state[StateKeys.DOWNREGULATED]) -st.session_state["instructions"] = ( +st.session_state[StateKeys.INSTRUCTIONS] = ( f"You are an expert biologist and have extensive experience in molecular biology, medicine and biochemistry.{os.linesep}" "A user will present you with data regarding proteins upregulated in certain cells " "sourced from UniProt and abstracts from scientific publications. They seek your " @@ -244,7 +244,7 @@ def select_analysis(): " you from a function has references to the literature (for example, PubMed), always quote the references in your response." ) if "column" in chosen_parameter_dict and "upregulated" in st.session_state: - st.session_state["user_prompt"] = ( + st.session_state[StateKeys.USER_PROMPT] = ( f"We've recently identified several proteins that appear to be differently regulated in cells " f"when comparing {chosen_parameter_dict['group1']} and {chosen_parameter_dict['group2']} in the {chosen_parameter_dict['column']} group. " f"From our proteomics experiments, we know that the following ones are upregulated: {', '.join(st.session_state['upregulated'])}.{os.linesep}{os.linesep}" @@ -256,13 +256,13 @@ def select_analysis(): if "user_prompt" in st.session_state: st.subheader("Automatically generated prompt based on gene functions:") with st.expander("Adjust system prompt (see example below)", expanded=False): - st.session_state["instructions"] = st.text_area( - "", value=st.session_state["instructions"], height=150 + st.session_state[StateKeys.INSTRUCTIONS] = st.text_area( + "", value=st.session_state[StateKeys.INSTRUCTIONS], height=150 ) with st.expander("Adjust user prompt", expanded=True): - st.session_state["user_prompt"] = st.text_area( - "", value=st.session_state["user_prompt"], height=200 + st.session_state[StateKeys.USER_PROMPT] = st.text_area( + "", value=st.session_state[StateKeys.USER_PROMPT], height=200 ) gpt_submitted = st.button("Run GPT analysis") @@ -272,26 +272,26 @@ def select_analysis(): st.stop() if gpt_submitted: - st.session_state["gpt_submitted_clicked"] += 1 + st.session_state[StateKeys.GPT_SUBMITTED_CLICKED] += 1 # creating new assistant only once TODO: add a button to create new assistant if ( - st.session_state["gpt_submitted_clicked"] - > st.session_state["gpt_submitted_counter"] + st.session_state[StateKeys.GPT_SUBMITTED_CLICKED] + > st.session_state[StateKeys.GPT_SUBMITTED_COUNTER] ): - if st.session_state["api_type"] == "gpt4o": + if st.session_state[StateKeys.API_TYPE] == "gpt4o": try_to_set_api_key() try: - if st.session_state["api_type"] == "gpt4o": - st.session_state["llm_integration"] = LLMIntegration( + if st.session_state[StateKeys.API_TYPE] == "gpt4o": + st.session_state[StateKeys.LLM_INTEGRATION] = LLMIntegration( api_type="gpt", api_key=st.secrets["openai_api_key"], dataset=st.session_state[StateKeys.DATASET], metadata=st.session_state[StateKeys.DATASET].metadata, ) else: - st.session_state["llm_integration"] = LLMIntegration( + st.session_state[StateKeys.LLM_INTEGRATION] = LLMIntegration( api_type="ollama", base_url=base_url, dataset=st.session_state[StateKeys.DATASET], @@ -306,11 +306,14 @@ def select_analysis(): ) st.stop() -if "llm_integration" not in st.session_state or not st.session_state["llm_integration"]: +if ( + "llm_integration" not in st.session_state + or not st.session_state[StateKeys.LLM_INTEGRATION] +): st.warning("Please initialize the model first") st.stop() -llm = st.session_state["llm_integration"] +llm = st.session_state[StateKeys.LLM_INTEGRATION] # Set instructions and update tools llm.tools = [ @@ -325,31 +328,33 @@ def select_analysis(): ] if "artifacts" not in st.session_state: - st.session_state["artifacts"] = {} + st.session_state[StateKeys.ARTIFACTS] = {} if ( - st.session_state["gpt_submitted_counter"] - < st.session_state["gpt_submitted_clicked"] + st.session_state[StateKeys.GPT_SUBMITTED_COUNTER] + < st.session_state[StateKeys.GPT_SUBMITTED_CLICKED] ): - st.session_state["gpt_submitted_counter"] = st.session_state[ + st.session_state[StateKeys.GPT_SUBMITTED_COUNTER] = st.session_state[ "gpt_submitted_clicked" ] - st.session_state["artifacts"] = {} - llm.messages = [{"role": "system", "content": st.session_state["instructions"]}] - response = llm.chat_completion(st.session_state["user_prompt"]) + st.session_state[StateKeys.ARTIFACTS] = {} + llm.messages = [ + {"role": "system", "content": st.session_state[StateKeys.INSTRUCTIONS]} + ] + response = llm.chat_completion(st.session_state[StateKeys.USER_PROMPT]) -if st.session_state["gpt_submitted_clicked"] > 0: +if st.session_state[StateKeys.GPT_SUBMITTED_CLICKED] > 0: if prompt := st.chat_input("Say something"): response = llm.chat_completion(prompt) - for num, role_content_dict in enumerate(st.session_state.messages): + for num, role_content_dict in enumerate(st.session_state[StateKeys.MESSAGES]): if role_content_dict["role"] == "tool" or role_content_dict["role"] == "system": continue if "tool_calls" in role_content_dict: continue with st.chat_message(role_content_dict["role"]): st.markdown(role_content_dict["content"]) - if num in st.session_state["artifacts"]: - for artefact in st.session_state["artifacts"][num]: + if num in st.session_state[StateKeys.ARTIFACTS]: + for artefact in st.session_state[StateKeys.ARTIFACTS][num]: if isinstance(artefact, pd.DataFrame): st.dataframe(artefact) elif "plotly" in str(type(artefact)): diff --git a/alphastats/gui/pages/06_Results.py b/alphastats/gui/pages/06_Results.py index 83a687b3..3606b11f 100644 --- a/alphastats/gui/pages/06_Results.py +++ b/alphastats/gui/pages/06_Results.py @@ -6,6 +6,7 @@ sidebar_info, init_session_state, convert_df, + StateKeys, ) @@ -45,7 +46,7 @@ def download_preprocessing_info(plot, name, count): if "plot_list" in st.session_state: - for count, plot in enumerate(st.session_state.plot_list): + for count, plot in enumerate(st.session_state[StateKeys.PLOT_LIST]): print("plot", type(plot), count) count = str(count) diff --git a/alphastats/gui/utils/analysis_helper.py b/alphastats/gui/utils/analysis_helper.py index 28fef901..c95782cc 100644 --- a/alphastats/gui/utils/analysis_helper.py +++ b/alphastats/gui/utils/analysis_helper.py @@ -32,7 +32,7 @@ def save_plot_to_session_state(plot, method): """ save plot with method to session state to retrieve old results """ - st.session_state["plot_list"] += [(method, plot)] + st.session_state[StateKeys.PLOT_LIST] += [(method, plot)] def display_df(df): diff --git a/alphastats/gui/utils/ollama_utils.py b/alphastats/gui/utils/ollama_utils.py index 49572eb9..fd778cca 100644 --- a/alphastats/gui/utils/ollama_utils.py +++ b/alphastats/gui/utils/ollama_utils.py @@ -13,6 +13,7 @@ # from alphastats.gui.utils.artefacts import ArtifactManager from alphastats.gui.utils.uniprot_utils import get_gene_function from alphastats.gui.utils.enrichment_analysis import get_enrichment_data +from gui.utils.ui_helper import StateKeys class LLMIntegration: @@ -137,8 +138,8 @@ def update_session_state(self): ------- None """ - st.session_state["messages"] = self.messages - st.session_state["artifacts"] = self.artifacts + st.session_state[StateKeys.MESSAGES] = self.messages + st.session_state[StateKeys.ARTIFACTS] = self.artifacts def parse_model_response(self, response: Any) -> Dict[str, Any]: """ diff --git a/alphastats/gui/utils/openai_utils.py b/alphastats/gui/utils/openai_utils.py index 3f7b5bdd..b5989f27 100644 --- a/alphastats/gui/utils/openai_utils.py +++ b/alphastats/gui/utils/openai_utils.py @@ -8,6 +8,8 @@ import openai import streamlit as st +from gui.utils.ui_helper import StateKeys + try: from alphastats.gui.utils.gpt_helper import ( turn_args_to_float, @@ -53,12 +55,12 @@ def wait_for_run_completion( "create_intensity_plot", "perform_dimensionality_reduction", "create_sample_histogram", - "st.session_state.dataset.plot_volcano", - "st.session_state.dataset.plot_sampledistribution", - "st.session_state.dataset.plot_intensity", - "st.session_state.dataset.plot_pca", - "st.session_state.dataset.plot_umap", - "st.session_state.dataset.plot_tsne", + "st.session_state[StateKeys.DATASET].plot_volcano", + "st.session_state[StateKeys.DATASET].plot_sampledistribution", + "st.session_state[StateKeys.DATASET].plot_intensity", + "st.session_state[StateKeys.DATASET].plot_pca", + "st.session_state[StateKeys.DATASET].plot_umap", + "st.session_state[StateKeys.DATASET].plot_tsne", "get_enrichment_data", } if run_status.status == "completed": @@ -178,7 +180,7 @@ def try_to_set_api_key(api_key: str = None) -> None: None """ if api_key and "api_key" not in st.session_state: - st.session_state["openai_api_key"] = api_key + st.session_state[StateKeys.OPENAI_API_KEY] = api_key secret_path = Path(st.secrets._file_paths[-1]) secret_path.parent.mkdir(parents=True, exist_ok=True) with open(secret_path, "w") as f: diff --git a/alphastats/gui/utils/options.py b/alphastats/gui/utils/options.py index 7142a101..91eda0a4 100644 --- a/alphastats/gui/utils/options.py +++ b/alphastats/gui/utils/options.py @@ -185,6 +185,7 @@ def statistic_options(state): } +# TODO unused def interpretation_options(state): return { "Volcano Plot": { diff --git a/alphastats/gui/utils/ui_helper.py b/alphastats/gui/utils/ui_helper.py index f1ab6d98..526299d1 100644 --- a/alphastats/gui/utils/ui_helper.py +++ b/alphastats/gui/utils/ui_helper.py @@ -100,17 +100,28 @@ class StateKeys: DATASET = "dataset" # functions upload_metadatafile PLOTTING_OPTIONS = "plotting_options" # function load_options STATISTIC_OPTIONS = "statistic_options" # function load_options - # on metadata upload - SAMPLE_COLUMN = "sample_column" - # "workflow" - # "plot_list" - # "openai_model" - # - # "plot_submitted_clicked" - # "plot_submitted_counter" - # - # "lookup_submitted_clicked" - # "lookup_submitted_counter" - # - # "gpt_submitted_clicked" - # "gpt_submitted_counter" + + WORKFLOW = "workflow" + + PLOT_LIST = "plot_list" + OPENAI_API_KEY = "openai_api_key" # pragma: allowlist secret + API_TYPE = "api_type" + LLM_INTEGRATION = "llm_integration" + + PLOT_SUBMITTED_CLICKED = "plot_submitted_clicked" + PLOT_SUBMITTED_COUNTER = "plot_submitted_counter" + + LOOKUP_SUBMITTED_CLICKED = "lookup_submitted_clicked" + LOOKUP_SUBMITTED_COUNTER = "lookup_submitted_counter" + + GPT_SUBMITTED_CLICKED = "gpt_submitted_clicked" + GPT_SUBMITTED_COUNTER = "gpt_submitted_counter" + + INSTRUCTIONS = "instructions" + USER_PROMPT = "user_prompt" + MESSAGES = "messages" + ARTIFACTS = "artifacts" + PROT_ID_TO_GENE = "prot_id_to_gene" + GENES_OF_INTEREST_COLORED = "genes_of_interest_colored" + UPREGULATED = "upregulated" + DOWNREGULATED = "downregulated" From 3c5acb84c499feeccf3aa57e5d0342bf9af35864 Mon Sep 17 00:00:00 2001 From: mschwoerer <82171591+mschwoer@users.noreply.github.com> Date: Tue, 17 Sep 2024 14:12:37 +0200 Subject: [PATCH 04/13] other occurrences --- alphastats/gui/pages/02_Import Data.py | 2 +- alphastats/gui/pages/03_Data Overview.py | 2 +- alphastats/gui/pages/03_Preprocessing.py | 4 ++-- alphastats/gui/pages/04_Analysis.py | 4 ++-- alphastats/gui/pages/05_LLM.py | 28 ++++++++++++------------ alphastats/gui/pages/06_Results.py | 2 +- alphastats/gui/utils/analysis_helper.py | 2 +- alphastats/gui/utils/gpt_helper.py | 10 ++++----- alphastats/gui/utils/ui_helper.py | 11 +++++----- alphastats/gui/utils/uniprot_utils.py | 2 +- tests/gui/test_02_import_data.py | 8 +++---- tests/test_gpt.py | 2 +- 12 files changed, 39 insertions(+), 38 deletions(-) diff --git a/alphastats/gui/pages/02_Import Data.py b/alphastats/gui/pages/02_Import Data.py index 1a543d90..9aa7d87d 100644 --- a/alphastats/gui/pages/02_Import Data.py +++ b/alphastats/gui/pages/02_Import Data.py @@ -63,7 +63,7 @@ def _finalize_data_loading( st.markdown("### Import Proteomics Data") -if "dataset" in st.session_state: +if StateKeys.DATASET in st.session_state: st.info(f"DataSet already present.") st.page_link("pages/03_Data Overview.py", label="=> Go to data overview page..") st.stop() diff --git a/alphastats/gui/pages/03_Data Overview.py b/alphastats/gui/pages/03_Data Overview.py index 241bce87..ba581deb 100644 --- a/alphastats/gui/pages/03_Data Overview.py +++ b/alphastats/gui/pages/03_Data Overview.py @@ -12,7 +12,7 @@ init_session_state() sidebar_info() -if "dataset" not in st.session_state: +if StateKeys.DATASET not in st.session_state: st.info("Import Data first") st.stop() diff --git a/alphastats/gui/pages/03_Preprocessing.py b/alphastats/gui/pages/03_Preprocessing.py index 34f95c47..07d72aa1 100644 --- a/alphastats/gui/pages/03_Preprocessing.py +++ b/alphastats/gui/pages/03_Preprocessing.py @@ -27,7 +27,7 @@ c1, c2 = st.columns([1, 1]) with c2: - if "dataset" in st.session_state: + if StateKeys.DATASET in st.session_state: settings = configure_preprocessing(dataset=st.session_state[StateKeys.DATASET]) new_workflow = update_workflow(settings) if new_workflow != st.session_state[StateKeys.WORKFLOW]: @@ -38,7 +38,7 @@ selected_nodes = draw_workflow(st.session_state[StateKeys.WORKFLOW]) - if "dataset" not in st.session_state: + if StateKeys.DATASET not in st.session_state: st.info("Import data first to configure and run preprocessing") else: diff --git a/alphastats/gui/pages/04_Analysis.py b/alphastats/gui/pages/04_Analysis.py index 6bf47d67..aac6af63 100644 --- a/alphastats/gui/pages/04_Analysis.py +++ b/alphastats/gui/pages/04_Analysis.py @@ -51,11 +51,11 @@ def select_analysis(): st.markdown(styl, unsafe_allow_html=True) -if "plot_list" not in st.session_state: +if StateKeys.PLOT_LIST not in st.session_state: st.session_state[StateKeys.PLOT_LIST] = [] -if "dataset" in st.session_state: +if StateKeys.DATASET in st.session_state: c1, c2 = st.columns((1, 2)) plot_to_display = False diff --git a/alphastats/gui/pages/05_LLM.py b/alphastats/gui/pages/05_LLM.py index 92a70396..9e2e9e20 100644 --- a/alphastats/gui/pages/05_LLM.py +++ b/alphastats/gui/pages/05_LLM.py @@ -67,26 +67,26 @@ def select_analysis(): st.markdown(styl, unsafe_allow_html=True) # Initialize session state variables -if "llm_integration" not in st.session_state: +if StateKeys.LLM_INTEGRATION not in st.session_state: st.session_state[StateKeys.LLM_INTEGRATION] = None -if "api_type" not in st.session_state: +if StateKeys.API_TYPE not in st.session_state: st.session_state[StateKeys.API_TYPE] = "gpt" -if "plot_list" not in st.session_state: +if StateKeys.PLOT_LIST not in st.session_state: st.session_state[StateKeys.PLOT_LIST] = [] -if "messages" not in st.session_state: +if StateKeys.MESSAGES not in st.session_state: st.session_state[StateKeys.MESSAGES] = [] -if "plot_submitted_clicked" not in st.session_state: +if StateKeys.PLOT_SUBMITTED_CLICKED not in st.session_state: st.session_state[StateKeys.PLOT_SUBMITTED_CLICKED] = 0 st.session_state[StateKeys.PLOT_SUBMITTED_COUNTER] = 0 -if "lookup_submitted_clicked" not in st.session_state: +if StateKeys.LOOKUP_SUBMITTED_CLICKED not in st.session_state: st.session_state[StateKeys.LOOKUP_SUBMITTED_CLICKED] = 0 st.session_state[StateKeys.LOOKUP_SUBMITTED_COUNTER] = 0 -if "gpt_submitted_clicked" not in st.session_state: +if StateKeys.GPT_SUBMITTED_CLICKED not in st.session_state: st.session_state[StateKeys.GPT_SUBMITTED_CLICKED] = 0 st.session_state[StateKeys.GPT_SUBMITTED_COUNTER] = 0 @@ -157,7 +157,7 @@ def select_analysis(): < st.session_state[StateKeys.PLOT_SUBMITTED_CLICKED] ): st.session_state[StateKeys.PLOT_SUBMITTED_COUNTER] = st.session_state[ - "plot_submitted_clicked" + StateKeys.PLOT_SUBMITTED_CLICKED ] volcano_plot = gui_volcano_plot_differential_expression_analysis( chosen_parameter_dict @@ -243,7 +243,7 @@ def select_analysis(): "Plots are visualized using a graphical environment capable of rendering images, you don't need to worry about that. If the data coming to" " you from a function has references to the literature (for example, PubMed), always quote the references in your response." ) -if "column" in chosen_parameter_dict and "upregulated" in st.session_state: +if "column" in chosen_parameter_dict and StateKeys.UPREGULATED in st.session_state: st.session_state[StateKeys.USER_PROMPT] = ( f"We've recently identified several proteins that appear to be differently regulated in cells " f"when comparing {chosen_parameter_dict['group1']} and {chosen_parameter_dict['group2']} in the {chosen_parameter_dict['column']} group. " @@ -253,7 +253,7 @@ def select_analysis(): f"to the differences. After that provide a high level summary" ) -if "user_prompt" in st.session_state: +if StateKeys.USER_PROMPT in st.session_state: st.subheader("Automatically generated prompt based on gene functions:") with st.expander("Adjust system prompt (see example below)", expanded=False): st.session_state[StateKeys.INSTRUCTIONS] = st.text_area( @@ -267,7 +267,7 @@ def select_analysis(): gpt_submitted = st.button("Run GPT analysis") -if gpt_submitted and "user_prompt" not in st.session_state: +if gpt_submitted and StateKeys.USER_PROMPT not in st.session_state: st.warning("Please enter a user prompt first") st.stop() @@ -307,7 +307,7 @@ def select_analysis(): st.stop() if ( - "llm_integration" not in st.session_state + StateKeys.LLM_INTEGRATION not in st.session_state or not st.session_state[StateKeys.LLM_INTEGRATION] ): st.warning("Please initialize the model first") @@ -327,7 +327,7 @@ def select_analysis(): ), ] -if "artifacts" not in st.session_state: +if StateKeys.ARTIFACTS not in st.session_state: st.session_state[StateKeys.ARTIFACTS] = {} if ( @@ -335,7 +335,7 @@ def select_analysis(): < st.session_state[StateKeys.GPT_SUBMITTED_CLICKED] ): st.session_state[StateKeys.GPT_SUBMITTED_COUNTER] = st.session_state[ - "gpt_submitted_clicked" + StateKeys.GPT_SUBMITTED_CLICKED ] st.session_state[StateKeys.ARTIFACTS] = {} llm.messages = [ diff --git a/alphastats/gui/pages/06_Results.py b/alphastats/gui/pages/06_Results.py index 3606b11f..c3d2293b 100644 --- a/alphastats/gui/pages/06_Results.py +++ b/alphastats/gui/pages/06_Results.py @@ -45,7 +45,7 @@ def download_preprocessing_info(plot, name, count): st.markdown("### Results") -if "plot_list" in st.session_state: +if StateKeys.PLOT_LIST in st.session_state: for count, plot in enumerate(st.session_state[StateKeys.PLOT_LIST]): print("plot", type(plot), count) count = str(count) diff --git a/alphastats/gui/utils/analysis_helper.py b/alphastats/gui/utils/analysis_helper.py index c95782cc..ca0b1d0b 100644 --- a/alphastats/gui/utils/analysis_helper.py +++ b/alphastats/gui/utils/analysis_helper.py @@ -10,7 +10,7 @@ def check_if_options_are_loaded(f): # decorator to check for missing values # TODO remove this def inner(*args, **kwargs): - if hasattr(st.session_state, "plotting_options") is False: + if hasattr(st.session_state, StateKeys.PLOTTING_OPTIONS) is False: load_options() return f(*args, **kwargs) diff --git a/alphastats/gui/utils/gpt_helper.py b/alphastats/gui/utils/gpt_helper.py index 42f5fe8c..f06aeec9 100644 --- a/alphastats/gui/utils/gpt_helper.py +++ b/alphastats/gui/utils/gpt_helper.py @@ -339,11 +339,11 @@ def get_gene_to_prot_id_mapping(gene_id: str) -> str: import streamlit as st session_state_copy = dict(copy.deepcopy(st.session_state)) - if "gene_to_prot_id" not in session_state_copy: - session_state_copy["gene_to_prot_id"] = {} - if gene_id in session_state_copy["gene_to_prot_id"]: - return session_state_copy["gene_to_prot_id"][gene_id] - for gene, prot_id in session_state_copy["gene_to_prot_id"].items(): + if StateKeys.GENE_TO_PROT_ID not in session_state_copy: + session_state_copy[StateKeys.GENE_TO_PROT_ID] = {} + if gene_id in session_state_copy[StateKeys.GENE_TO_PROT_ID]: + return session_state_copy[StateKeys.GENE_TO_PROT_ID][gene_id] + for gene, prot_id in session_state_copy[StateKeys.GENE_TO_PROT_ID].items(): if gene_id in gene.split(";"): return prot_id return gene_id diff --git a/alphastats/gui/utils/ui_helper.py b/alphastats/gui/utils/ui_helper.py index 526299d1..d5a5adea 100644 --- a/alphastats/gui/utils/ui_helper.py +++ b/alphastats/gui/utils/ui_helper.py @@ -37,7 +37,7 @@ def sidebar_info(): def _display_sidebar_html_table(): - if "dataset" not in st.session_state: + if StateKeys.DATASET not in st.session_state: return preprocessing_dict = st.session_state[StateKeys.DATASET].preprocessing_info @@ -79,13 +79,13 @@ def empty_session_state(): def init_session_state() -> None: """Initialize the session state if not done yet.""" - if "user_session_id" not in st.session_state: + if StateKeys.USER_SESSION_ID not in st.session_state: st.session_state[StateKeys.USER_SESSION_ID] = str(uuid.uuid4()) - if "gene_to_prot_id" not in st.session_state: + if StateKeys.GENE_TO_PROT_ID not in st.session_state: st.session_state[StateKeys.GENE_TO_PROT_ID] = {} - if "organism" not in st.session_state: + if StateKeys.ORGANISM not in st.session_state: st.session_state[StateKeys.ORGANISM] = 9606 # human @@ -102,8 +102,9 @@ class StateKeys: STATISTIC_OPTIONS = "statistic_options" # function load_options WORKFLOW = "workflow" - PLOT_LIST = "plot_list" + + # LLM OPENAI_API_KEY = "openai_api_key" # pragma: allowlist secret API_TYPE = "api_type" LLM_INTEGRATION = "llm_integration" diff --git a/alphastats/gui/utils/uniprot_utils.py b/alphastats/gui/utils/uniprot_utils.py index 85636778..4b3b4699 100644 --- a/alphastats/gui/utils/uniprot_utils.py +++ b/alphastats/gui/utils/uniprot_utils.py @@ -318,7 +318,7 @@ def get_gene_function(gene_name: Union[str, Dict], organism_id=9606) -> str: Returns: str: The gene function and description. """ - if "organism" in st.session_state: + if StateKeys.ORGANISM in st.session_state: organism_id = st.session_state[StateKeys.ORGANISM] if isinstance(gene_name, dict): gene_name = gene_name["gene_name"] diff --git a/tests/gui/test_02_import_data.py b/tests/gui/test_02_import_data.py index 97eb3004..ccc89e8e 100644 --- a/tests/gui/test_02_import_data.py +++ b/tests/gui/test_02_import_data.py @@ -53,8 +53,8 @@ def test_page_02_loads_example_data(mock_page_link: MagicMock): str(type(at.session_state[StateKeys.LOADER])) == "" ) - assert "plotting_options" in at.session_state - assert "statistic_options" in at.session_state + assert StateKeys.PLOTTING_OPTIONS in at.session_state + assert StateKeys.STATISTIC_OPTIONS in at.session_state @patch("streamlit.file_uploader") @@ -115,5 +115,5 @@ def test_page_02_loads_maxquant_testfiles( str(type(at.session_state[StateKeys.LOADER])) == "" ) - assert "plotting_options" in at.session_state - assert "statistic_options" in at.session_state + assert StateKeys.PLOTTING_OPTIONS in at.session_state + assert StateKeys.STATISTIC_OPTIONS in at.session_state diff --git a/tests/test_gpt.py b/tests/test_gpt.py index 23107745..6ca47c76 100644 --- a/tests/test_gpt.py +++ b/tests/test_gpt.py @@ -11,7 +11,7 @@ from alphastats.DataSet import DataSet from gui.utils.ui_helper import StateKeys -if "gene_to_prot_id" not in st.session_state: +if StateKeys.GENE_TO_PROT_ID not in st.session_state: st.session_state[StateKeys.GENE_TO_PROT_ID] = {} From de7ab01efd5efb8588bb1731a24ded253fb4c9ef Mon Sep 17 00:00:00 2001 From: mschwoerer <82171591+mschwoer@users.noreply.github.com> Date: Tue, 17 Sep 2024 14:50:26 +0200 Subject: [PATCH 05/13] other occurrences --- alphastats/gui/utils/openai_utils.py | 1 + 1 file changed, 1 insertion(+) diff --git a/alphastats/gui/utils/openai_utils.py b/alphastats/gui/utils/openai_utils.py index b5989f27..a2d3b3f0 100644 --- a/alphastats/gui/utils/openai_utils.py +++ b/alphastats/gui/utils/openai_utils.py @@ -51,6 +51,7 @@ def wait_for_run_completion( thread_id=thread_id, run_id=run_id ) print(run_status.status, run_id, run_status.required_action) + # TODO check if this still works after introducing StateKeys.DATASET assistant_functions = { "create_intensity_plot", "perform_dimensionality_reduction", From 52194e75829a9947775455ef83504773c6ad68e0 Mon Sep 17 00:00:00 2001 From: mschwoerer <82171591+mschwoer@users.noreply.github.com> Date: Tue, 17 Sep 2024 14:59:09 +0200 Subject: [PATCH 06/13] fix imports --- alphastats/gui/utils/gpt_helper.py | 2 +- alphastats/gui/utils/import_helper.py | 2 +- alphastats/gui/utils/ollama_utils.py | 2 +- alphastats/gui/utils/openai_utils.py | 2 +- alphastats/gui/utils/uniprot_utils.py | 2 +- tests/test_DataSet.py | 2 +- tests/test_gpt.py | 2 +- 7 files changed, 7 insertions(+), 7 deletions(-) diff --git a/alphastats/gui/utils/gpt_helper.py b/alphastats/gui/utils/gpt_helper.py index f06aeec9..09088f38 100644 --- a/alphastats/gui/utils/gpt_helper.py +++ b/alphastats/gui/utils/gpt_helper.py @@ -9,7 +9,7 @@ import streamlit as st from alphastats.plots.DimensionalityReduction import DimensionalityReduction -from gui.utils.ui_helper import StateKeys +from alphastats.gui.utils.ui_helper import StateKeys Entrez.email = "lebedev_mikhail@outlook.com" # Always provide your email address when using NCBI services. diff --git a/alphastats/gui/utils/import_helper.py b/alphastats/gui/utils/import_helper.py index 77fedacd..ebf22793 100644 --- a/alphastats/gui/utils/import_helper.py +++ b/alphastats/gui/utils/import_helper.py @@ -11,7 +11,7 @@ from alphastats.DataSet import DataSet from alphastats.gui.utils.options import SOFTWARE_OPTIONS from alphastats.loader.MaxQuantLoader import MaxQuantLoader, BaseLoader -from gui.utils.ui_helper import StateKeys +from alphastats.gui.utils.ui_helper import StateKeys def load_options(): diff --git a/alphastats/gui/utils/ollama_utils.py b/alphastats/gui/utils/ollama_utils.py index fd778cca..30531748 100644 --- a/alphastats/gui/utils/ollama_utils.py +++ b/alphastats/gui/utils/ollama_utils.py @@ -13,7 +13,7 @@ # from alphastats.gui.utils.artefacts import ArtifactManager from alphastats.gui.utils.uniprot_utils import get_gene_function from alphastats.gui.utils.enrichment_analysis import get_enrichment_data -from gui.utils.ui_helper import StateKeys +from alphastats.gui.utils.ui_helper import StateKeys class LLMIntegration: diff --git a/alphastats/gui/utils/openai_utils.py b/alphastats/gui/utils/openai_utils.py index a2d3b3f0..0bf54336 100644 --- a/alphastats/gui/utils/openai_utils.py +++ b/alphastats/gui/utils/openai_utils.py @@ -8,7 +8,7 @@ import openai import streamlit as st -from gui.utils.ui_helper import StateKeys +from alphastats.gui.utils.ui_helper import StateKeys try: from alphastats.gui.utils.gpt_helper import ( diff --git a/alphastats/gui/utils/uniprot_utils.py b/alphastats/gui/utils/uniprot_utils.py index 4b3b4699..f767ecd8 100644 --- a/alphastats/gui/utils/uniprot_utils.py +++ b/alphastats/gui/utils/uniprot_utils.py @@ -4,7 +4,7 @@ import streamlit as st -from gui.utils.ui_helper import StateKeys +from alphastats.gui.utils.ui_helper import StateKeys uniprot_fields = [ # Names & Taxonomy diff --git a/tests/test_DataSet.py b/tests/test_DataSet.py index cc16aaac..4ac22a46 100644 --- a/tests/test_DataSet.py +++ b/tests/test_DataSet.py @@ -20,7 +20,7 @@ from alphastats.DataSet_Statistics import Statistics from alphastats.utils import LoaderError -from gui.utils.ui_helper import StateKeys +from alphastats.gui.utils.ui_helper import StateKeys logger = logging.getLogger(__name__) diff --git a/tests/test_gpt.py b/tests/test_gpt.py index 6ca47c76..1ec1a41c 100644 --- a/tests/test_gpt.py +++ b/tests/test_gpt.py @@ -9,7 +9,7 @@ from alphastats.gui.utils.uniprot_utils import get_uniprot_data, extract_data from alphastats.loader.MaxQuantLoader import MaxQuantLoader from alphastats.DataSet import DataSet -from gui.utils.ui_helper import StateKeys +from alphastats.gui.utils.ui_helper import StateKeys if StateKeys.GENE_TO_PROT_ID not in st.session_state: st.session_state[StateKeys.GENE_TO_PROT_ID] = {} From 32a68f3b70ce7bc4bc63a1d5050760d6191bf1c4 Mon Sep 17 00:00:00 2001 From: mschwoerer <82171591+mschwoer@users.noreply.github.com> Date: Tue, 17 Sep 2024 15:03:15 +0200 Subject: [PATCH 07/13] fix tests --- tests/gui/test_02_import_data.py | 3 +-- tests/gui/test_03_data_overview.py | 4 +--- tests/gui/test_04_preprocessing.py | 4 +--- 3 files changed, 3 insertions(+), 8 deletions(-) diff --git a/tests/gui/test_02_import_data.py b/tests/gui/test_02_import_data.py index ccc89e8e..17559bad 100644 --- a/tests/gui/test_02_import_data.py +++ b/tests/gui/test_02_import_data.py @@ -1,8 +1,7 @@ from streamlit.testing.v1 import AppTest -from pathlib import Path from unittest.mock import MagicMock, patch from .conftest import APP_FOLDER, data_buf, metadata_buf -from .utils.ui_helper import StateKeys +from alphastats.gui.utils.ui_helper import StateKeys TESTED_PAGE = f"{APP_FOLDER}/pages/02_Import Data.py" diff --git a/tests/gui/test_03_data_overview.py b/tests/gui/test_03_data_overview.py index df20b9b1..c3c3416a 100644 --- a/tests/gui/test_03_data_overview.py +++ b/tests/gui/test_03_data_overview.py @@ -1,8 +1,6 @@ from streamlit.testing.v1 import AppTest -from pathlib import Path -from unittest.mock import MagicMock, patch from .conftest import create_dataset_alphapept, APP_FOLDER -from .utils.ui_helper import StateKeys +from alphastats.gui.utils.ui_helper import StateKeys TESTED_PAGE = f"{APP_FOLDER}/pages/03_Data Overview.py" diff --git a/tests/gui/test_04_preprocessing.py b/tests/gui/test_04_preprocessing.py index 091eec85..d10fa2c4 100644 --- a/tests/gui/test_04_preprocessing.py +++ b/tests/gui/test_04_preprocessing.py @@ -1,8 +1,6 @@ from streamlit.testing.v1 import AppTest -from pathlib import Path -from unittest.mock import MagicMock, patch from .conftest import create_dataset_alphapept, APP_FOLDER -from .utils.ui_helper import StateKeys +from alphastats.gui.utils.ui_helper import StateKeys TESTED_PAGE = f"{APP_FOLDER}/pages/03_Preprocessing.py" From 9606ca6428a4d55f96a0eadd77309fae251be81f Mon Sep 17 00:00:00 2001 From: mschwoerer <82171591+mschwoer@users.noreply.github.com> Date: Tue, 17 Sep 2024 15:14:09 +0200 Subject: [PATCH 08/13] Revert "remove unused metadata_columns" This reverts commit 583b1e915cf4d824bd592d17c887ce7af142c455. --- alphastats/gui/pages/02_Import Data.py | 6 ++++-- alphastats/gui/utils/import_helper.py | 3 ++- tests/gui/test_02_import_data.py | 6 ++++++ tests/test_gpt.py | 1 + 4 files changed, 13 insertions(+), 3 deletions(-) diff --git a/alphastats/gui/pages/02_Import Data.py b/alphastats/gui/pages/02_Import Data.py index 9aa7d87d..b439795c 100644 --- a/alphastats/gui/pages/02_Import Data.py +++ b/alphastats/gui/pages/02_Import Data.py @@ -24,12 +24,14 @@ def _finalize_data_loading( loader: BaseLoader, + metadata_columns: List[str], dataset: DataSet, ) -> None: """Finalize the data loading process.""" st.session_state[StateKeys.LOADER] = ( loader # TODO: Figure out if we even need the loader here, as the dataset has the loader as an attribute. ) + st.session_state["metadata_columns"] = metadata_columns st.session_state[StateKeys.DATASET] = dataset load_options() @@ -56,9 +58,9 @@ def _finalize_data_loading( if c2.button("Start new Session with example DataSet", key="_load_example_data"): empty_session_state() init_session_state() - loader, dataset = load_example_data() + loader, metadata_columns, dataset = load_example_data() - _finalize_data_loading(loader, dataset) + _finalize_data_loading(loader, metadata_columns, dataset) st.stop() diff --git a/alphastats/gui/utils/import_helper.py b/alphastats/gui/utils/import_helper.py index ebf22793..dd313d78 100644 --- a/alphastats/gui/utils/import_helper.py +++ b/alphastats/gui/utils/import_helper.py @@ -129,7 +129,8 @@ def load_example_data(): ] ] dataset.preprocess(subset=True) - return loader, dataset + metadata_columns = dataset.metadata.columns.to_list() + return loader, metadata_columns, dataset def _check_softwarefile_df(df: pd.DataFrame, software: str) -> None: diff --git a/tests/gui/test_02_import_data.py b/tests/gui/test_02_import_data.py index 17559bad..30193793 100644 --- a/tests/gui/test_02_import_data.py +++ b/tests/gui/test_02_import_data.py @@ -44,6 +44,12 @@ def test_page_02_loads_example_data(mock_page_link: MagicMock): assert not at.exception + assert at.session_state.metadata_columns == [ + "sample", + "disease", + "Drug therapy (procedure) (416608005)", + "Lipid-lowering therapy (134350008)", + ] assert ( str(type(at.session_state[StateKeys.DATASET])) == "" diff --git a/tests/test_gpt.py b/tests/test_gpt.py index 1ec1a41c..18ed56a3 100644 --- a/tests/test_gpt.py +++ b/tests/test_gpt.py @@ -31,6 +31,7 @@ def setUp(self): self.matrix_dim = (312, 2596) self.matrix_dim_filtered = (312, 2397) self.comparison_column = "disease" + st.session_state.metadata_columns = [self.comparison_column] class TestGetUniProtData(unittest.TestCase): From 49effbcf48772157e9b69e05f9f1b75b1901f53b Mon Sep 17 00:00:00 2001 From: mschwoerer <82171591+mschwoer@users.noreply.github.com> Date: Tue, 17 Sep 2024 15:18:17 +0200 Subject: [PATCH 09/13] reintroduce metadata_columns, substitute occurrences of "dataset" --- alphastats/gui/pages/02_Import Data.py | 2 +- alphastats/gui/utils/options.py | 62 ++++++++++++++------------ alphastats/gui/utils/ui_helper.py | 1 + tests/gui/test_02_import_data.py | 2 +- tests/test_gpt.py | 2 +- 5 files changed, 37 insertions(+), 32 deletions(-) diff --git a/alphastats/gui/pages/02_Import Data.py b/alphastats/gui/pages/02_Import Data.py index b439795c..389a4fc6 100644 --- a/alphastats/gui/pages/02_Import Data.py +++ b/alphastats/gui/pages/02_Import Data.py @@ -31,7 +31,7 @@ def _finalize_data_loading( st.session_state[StateKeys.LOADER] = ( loader # TODO: Figure out if we even need the loader here, as the dataset has the loader as an attribute. ) - st.session_state["metadata_columns"] = metadata_columns + st.session_state[StateKeys.METADATA_COLUMNS] = metadata_columns st.session_state[StateKeys.DATASET] = dataset load_options() diff --git a/alphastats/gui/utils/options.py b/alphastats/gui/utils/options.py index 91eda0a4..5834f8ed 100644 --- a/alphastats/gui/utils/options.py +++ b/alphastats/gui/utils/options.py @@ -5,24 +5,26 @@ from alphastats.loader.SpectronautLoader import SpectronautLoader from alphastats.loader.GenericLoader import GenericLoader from alphastats.loader.mzTabLoader import mzTabLoader +from gui.utils.ui_helper import StateKeys def plotting_options(state): + dataset = state[StateKeys.DATASET] plotting_options = { "Sampledistribution Plot": { "settings": { "method": {"options": ["violin", "box"], "label": "Plot layout"}, "color": { - "options": [None] + state.metadata_columns, + "options": [None] + state[StateKeys.METADATA_COLUMNS], "label": "Color according to", }, }, - "function": state.dataset.plot_sampledistribution, + "function": dataset.plot_sampledistribution, }, "Intensity Plot": { "settings": { "protein_id": { - "options": state.dataset.mat.columns.to_list(), + "options": dataset.mat.columns.to_list(), "label": "ProteinID/ProteinGroup", }, "method": { @@ -30,105 +32,106 @@ def plotting_options(state): "label": "Plot layout", }, "group": { - "options": [None] + state.metadata_columns, + "options": [None] + state[StateKeys.METADATA_COLUMNS], "label": "Color according to", }, }, - "function": state.dataset.plot_intensity, + "function": dataset.plot_intensity, }, "PCA Plot": { "settings": { "group": { - "options": [None] + state.metadata_columns, + "options": [None] + state[StateKeys.METADATA_COLUMNS], "label": "Color according to", }, "circle": {"label": "Circle"}, }, - "function": state.dataset.plot_pca, + "function": dataset.plot_pca, }, "UMAP Plot": { "settings": { "group": { - "options": [None] + state.metadata_columns, + "options": [None] + state[StateKeys.METADATA_COLUMNS], "label": "Color according to", }, "circle": {"label": "Circle"}, }, - "function": state.dataset.plot_umap, + "function": dataset.plot_umap, }, "t-SNE Plot": { "settings": { "group": { - "options": [None] + state.metadata_columns, + "options": [None] + state[StateKeys.METADATA_COLUMNS], "label": "Color according to", }, "circle": {"label": "Circle"}, }, - "function": state.dataset.plot_tsne, + "function": dataset.plot_tsne, }, "Volcano Plot": { "between_two_groups": True, - "function": state.dataset.plot_volcano, + "function": dataset.plot_volcano, }, - "Clustermap": {"function": state.dataset.plot_clustermap}, - # "Dendrogram": {"function": state.dataset.plot_dendrogram}, # TODO why commented? + "Clustermap": {"function": dataset.plot_clustermap}, + # "Dendrogram": {"function": state[StateKeys.DATASET].plot_dendrogram}, # TODO why commented? } return plotting_options def statistic_options(state): + dataset = state[StateKeys.DATASET] statistic_options = { "Differential Expression Analysis - T-test": { "between_two_groups": True, - "function": state.dataset.diff_expression_analysis, + "function": dataset.diff_expression_analysis, }, "Differential Expression Analysis - Wald-test": { "between_two_groups": True, - "function": state.dataset.diff_expression_analysis, + "function": dataset.diff_expression_analysis, }, "Tukey - Test": { "settings": { "protein_id": { - "options": state.dataset.mat.columns.to_list(), + "options": dataset.mat.columns.to_list(), "label": "ProteinID/ProteinGroup", }, "group": { - "options": state.metadata_columns, + "options": state[StateKeys.METADATA_COLUMNS], "label": "A metadata variable to calculate pairwise tukey", }, }, - "function": state.dataset.tukey_test, + "function": dataset.tukey_test, }, "ANOVA": { "settings": { "column": { - "options": state.metadata_columns, + "options": state[StateKeys.METADATA_COLUMNS], "label": "A variable from the metadata to calculate ANOVA", }, "protein_ids": { - "options": ["all"] + state.dataset.mat.columns.to_list(), + "options": ["all"] + dataset.mat.columns.to_list(), "label": "All ProteinIDs/or specific ProteinID to perform ANOVA", }, "tukey": {"label": "Follow-up Tukey"}, }, - "function": state.dataset.anova, + "function": dataset.anova, }, "ANCOVA": { "settings": { "protein_id": { - "options": [None] + state.dataset.mat.columns.to_list(), + "options": [None] + dataset.mat.columns.to_list(), "label": "Color according to", }, "covar": { - "options": state.metadata_columns, + "options": state[StateKeys.METADATA_COLUMNS], "label": "Name(s) of column(s) in metadata with the covariate.", }, "between": { - "options": state.metadata_columns, + "options": state[StateKeys.METADATA_COLUMNS], "label": "Name of the column in the metadata with the between factor.", }, }, - "function": state.dataset.ancova, + "function": dataset.ancova, }, } return statistic_options @@ -187,17 +190,18 @@ def statistic_options(state): # TODO unused def interpretation_options(state): + dataset = state[StateKeys.DATASET] return { "Volcano Plot": { "between_two_groups": True, - "function": state.dataset.plot_volcano, + "function": dataset.plot_volcano, }, "Differential Expression Analysis - T-test": { "between_two_groups": True, - "function": state.dataset.diff_expression_analysis, + "function": dataset.diff_expression_analysis, }, "Differential Expression Analysis - Wald-test": { "between_two_groups": True, - "function": state.dataset.diff_expression_analysis, + "function": dataset.diff_expression_analysis, }, } diff --git a/alphastats/gui/utils/ui_helper.py b/alphastats/gui/utils/ui_helper.py index d5a5adea..22d5d18b 100644 --- a/alphastats/gui/utils/ui_helper.py +++ b/alphastats/gui/utils/ui_helper.py @@ -101,6 +101,7 @@ class StateKeys: PLOTTING_OPTIONS = "plotting_options" # function load_options STATISTIC_OPTIONS = "statistic_options" # function load_options + METADATA_COLUMNS = "metadata_columns" WORKFLOW = "workflow" PLOT_LIST = "plot_list" diff --git a/tests/gui/test_02_import_data.py b/tests/gui/test_02_import_data.py index 30193793..abdcbcb3 100644 --- a/tests/gui/test_02_import_data.py +++ b/tests/gui/test_02_import_data.py @@ -44,7 +44,7 @@ def test_page_02_loads_example_data(mock_page_link: MagicMock): assert not at.exception - assert at.session_state.metadata_columns == [ + assert at.session_state[StateKeys.METADATA_COLUMNS] == [ "sample", "disease", "Drug therapy (procedure) (416608005)", diff --git a/tests/test_gpt.py b/tests/test_gpt.py index 18ed56a3..4d1e415a 100644 --- a/tests/test_gpt.py +++ b/tests/test_gpt.py @@ -31,7 +31,7 @@ def setUp(self): self.matrix_dim = (312, 2596) self.matrix_dim_filtered = (312, 2397) self.comparison_column = "disease" - st.session_state.metadata_columns = [self.comparison_column] + st.session_state[StateKeys.METADATA_COLUMNS] = [self.comparison_column] class TestGetUniProtData(unittest.TestCase): From c142747dcd5155138bd12e6a95d54c44809800d4 Mon Sep 17 00:00:00 2001 From: mschwoerer <82171591+mschwoer@users.noreply.github.com> Date: Tue, 17 Sep 2024 15:37:12 +0200 Subject: [PATCH 10/13] fix tests --- alphastats/gui/utils/options.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/alphastats/gui/utils/options.py b/alphastats/gui/utils/options.py index 5834f8ed..87687167 100644 --- a/alphastats/gui/utils/options.py +++ b/alphastats/gui/utils/options.py @@ -5,7 +5,7 @@ from alphastats.loader.SpectronautLoader import SpectronautLoader from alphastats.loader.GenericLoader import GenericLoader from alphastats.loader.mzTabLoader import mzTabLoader -from gui.utils.ui_helper import StateKeys +from alphastats.gui.utils.ui_helper import StateKeys def plotting_options(state): From e0ea0f1d1da2c28b3a2ed681bd81fbb0b51fabcc Mon Sep 17 00:00:00 2001 From: Julia Schessner Date: Tue, 17 Sep 2024 16:48:22 +0200 Subject: [PATCH 11/13] Missed some --- alphastats/gui/pages/05_LLM.py | 8 ++++---- alphastats/gui/utils/analysis_helper.py | 1 + alphastats/gui/utils/openai_utils.py | 2 +- 3 files changed, 6 insertions(+), 5 deletions(-) diff --git a/alphastats/gui/pages/05_LLM.py b/alphastats/gui/pages/05_LLM.py index 9e2e9e20..18fed43e 100644 --- a/alphastats/gui/pages/05_LLM.py +++ b/alphastats/gui/pages/05_LLM.py @@ -42,7 +42,7 @@ def select_analysis(): return method -if "dataset" not in st.session_state: +if StateKeys.DATASET not in st.session_state: st.info("Import Data first") st.stop() @@ -247,8 +247,8 @@ def select_analysis(): st.session_state[StateKeys.USER_PROMPT] = ( f"We've recently identified several proteins that appear to be differently regulated in cells " f"when comparing {chosen_parameter_dict['group1']} and {chosen_parameter_dict['group2']} in the {chosen_parameter_dict['column']} group. " - f"From our proteomics experiments, we know that the following ones are upregulated: {', '.join(st.session_state['upregulated'])}.{os.linesep}{os.linesep}" - f"Here is the list of proteins that are downregulated: {', '.join(st.session_state['downregulated'])}.{os.linesep}{os.linesep}" + f"From our proteomics experiments, we know that the following ones are upregulated: {', '.join(st.session_state[StateKeys.UPREGULATED])}.{os.linesep}{os.linesep}" + f"Here is the list of proteins that are downregulated: {', '.join(st.session_state[StateKeys.DOWNREGULATED])}.{os.linesep}{os.linesep}" f"Help us understand the potential connections between these proteins and how they might be contributing " f"to the differences. After that provide a high level summary" ) @@ -298,7 +298,7 @@ def select_analysis(): metadata=st.session_state[StateKeys.DATASET].metadata, ) st.success( - f"{st.session_state['api_type'].upper()} integration initialized successfully!" + f"{st.session_state[StateKeys.API_TYPE].upper()} integration initialized successfully!" ) except AuthenticationError: st.warning( diff --git a/alphastats/gui/utils/analysis_helper.py b/alphastats/gui/utils/analysis_helper.py index ca0b1d0b..eb0f09c4 100644 --- a/alphastats/gui/utils/analysis_helper.py +++ b/alphastats/gui/utils/analysis_helper.py @@ -383,6 +383,7 @@ def load_options(): st.session_state[StateKeys.PLOTTING_OPTIONS] = plotting_options(st.session_state) st.session_state[StateKeys.STATISTIC_OPTIONS] = statistic_options(st.session_state) + # TODO: Check if this should be reintroduced or removed # st.session_state["interpretation_options"] = interpretation_options diff --git a/alphastats/gui/utils/openai_utils.py b/alphastats/gui/utils/openai_utils.py index 0bf54336..a9ff385a 100644 --- a/alphastats/gui/utils/openai_utils.py +++ b/alphastats/gui/utils/openai_utils.py @@ -180,7 +180,7 @@ def try_to_set_api_key(api_key: str = None) -> None: Returns: None """ - if api_key and "api_key" not in st.session_state: + if api_key and [StateKeys.OPENAI_API_KEY] not in st.session_state: st.session_state[StateKeys.OPENAI_API_KEY] = api_key secret_path = Path(st.secrets._file_paths[-1]) secret_path.parent.mkdir(parents=True, exist_ok=True) From abd297a7e7f9c8ecb613e477bb4e63c479596894 Mon Sep 17 00:00:00 2001 From: mschwoerer <82171591+mschwoer@users.noreply.github.com> Date: Tue, 17 Sep 2024 17:21:07 +0200 Subject: [PATCH 12/13] add another missed StateKeys substitution --- alphastats/gui/utils/gpt_helper.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/alphastats/gui/utils/gpt_helper.py b/alphastats/gui/utils/gpt_helper.py index 09088f38..3296dd56 100644 --- a/alphastats/gui/utils/gpt_helper.py +++ b/alphastats/gui/utils/gpt_helper.py @@ -143,6 +143,10 @@ def get_assistant_functions( Returns: list[dict]: A list of assistant functions. """ + # TODO figure out how this relates to the parameter `subgroups_for_each_group` + subgroups_for_each_group_ = str( + get_subgroups_for_each_group(st.session_state[StateKeys.DATASET].metadata) + ) return [ { "type": "function", @@ -165,7 +169,8 @@ def get_assistant_functions( "subgroups": { "type": "array", "items": {"type": "string"}, - "description": f"Specific subgroups within the group to analyze. For each group you need to look up the subgroups in the dict {str(get_subgroups_for_each_group(st.session_state['dataset'].metadata))} or present user with them first if you are not sure what to choose", + "description": f"Specific subgroups within the group to analyze. For each group you need to look up the subgroups in the dict" + f" {subgroups_for_each_group_} or present user with them first if you are not sure what to choose", }, "method": { "type": "string", From 5209c606401ef144577b5b4fa4c8b10f742d261e Mon Sep 17 00:00:00 2001 From: Julia Schessner Date: Tue, 17 Sep 2024 22:59:12 +0200 Subject: [PATCH 13/13] rename misleading private method read_columns_as_string --- alphastats/loader/AlphaPeptLoader.py | 2 +- alphastats/loader/BaseLoader.py | 4 ++-- alphastats/loader/DIANNLoader.py | 2 +- alphastats/loader/GenericLoader.py | 2 +- alphastats/loader/MaxQuantLoader.py | 2 +- alphastats/loader/SpectronautLoader.py | 2 +- 6 files changed, 7 insertions(+), 7 deletions(-) diff --git a/alphastats/loader/AlphaPeptLoader.py b/alphastats/loader/AlphaPeptLoader.py index c10f0367..b2cdd90d 100644 --- a/alphastats/loader/AlphaPeptLoader.py +++ b/alphastats/loader/AlphaPeptLoader.py @@ -41,7 +41,7 @@ def __init__( # add contamination column "Reverse" self._add_contamination_reverse_column() self._add_contamination_column() - self._read_all_columns_as_string() + self._read_all_column_names_as_string() #  make ProteinGroup column self.rawinput["ProteinGroup"] = self.rawinput[self.index_column].map( self._standardize_protein_group_column diff --git a/alphastats/loader/BaseLoader.py b/alphastats/loader/BaseLoader.py index ca86bd8b..d22375f0 100644 --- a/alphastats/loader/BaseLoader.py +++ b/alphastats/loader/BaseLoader.py @@ -45,7 +45,7 @@ def __init__( self.ptm_df = None self._add_contamination_column() self._check_if_columns_are_present() - self._read_all_columns_as_string() + self._read_all_column_names_as_string() def _check_if_columns_are_present(self): """check if given columns present in rawinput""" @@ -61,7 +61,7 @@ def _check_if_columns_are_present(self): "MaxQuant Format: http://www.coxdocs.org/doku.php?id=maxquant:table:proteingrouptable" ) - def _read_all_columns_as_string(self): + def _read_all_column_names_as_string(self): self.rawinput.columns = self.rawinput.columns.astype(str) def _check_if_indexcolumn_is_unique(self): diff --git a/alphastats/loader/DIANNLoader.py b/alphastats/loader/DIANNLoader.py index 1ad43b92..4b22720b 100644 --- a/alphastats/loader/DIANNLoader.py +++ b/alphastats/loader/DIANNLoader.py @@ -44,7 +44,7 @@ def __init__( self._remove_filepath_from_name() self._add_tag_to_sample_columns() self._add_contamination_column() - self._read_all_columns_as_string() + self._read_all_column_names_as_string() def _add_tag_to_sample_columns(self): """ diff --git a/alphastats/loader/GenericLoader.py b/alphastats/loader/GenericLoader.py index 71cbb9f5..4625a172 100644 --- a/alphastats/loader/GenericLoader.py +++ b/alphastats/loader/GenericLoader.py @@ -36,7 +36,7 @@ def __init__( self.ptm_df = None self._add_contamination_column() self._check_if_columns_are_present() - self._read_all_columns_as_string() + self._read_all_column_names_as_string() def _extract_sample_names(self, metadata: pd.DataFrame, sample_column: str): sample_names = metadata[sample_column].to_list() diff --git a/alphastats/loader/MaxQuantLoader.py b/alphastats/loader/MaxQuantLoader.py index 3cad5e6c..fd77ff6b 100644 --- a/alphastats/loader/MaxQuantLoader.py +++ b/alphastats/loader/MaxQuantLoader.py @@ -39,7 +39,7 @@ def __init__( self.confidence_column = confidence_column self.software = "MaxQuant" self._set_filter_columns_to_true_false() - self._read_all_columns_as_string() + self._read_all_column_names_as_string() if gene_names_column in self.rawinput.columns.to_list(): self.gene_names = gene_names_column diff --git a/alphastats/loader/SpectronautLoader.py b/alphastats/loader/SpectronautLoader.py index 4cbabe1d..3bb4f86c 100644 --- a/alphastats/loader/SpectronautLoader.py +++ b/alphastats/loader/SpectronautLoader.py @@ -53,7 +53,7 @@ def __init__( ) self._add_contamination_column() - self._read_all_columns_as_string() + self._read_all_column_names_as_string() def _reshape_spectronaut(self, sample_column, gene_names_column): """