From b36cd460a97d20233b3a83324887d8fda35a9302 Mon Sep 17 00:00:00 2001 From: Mikhail Lebedev Date: Fri, 12 Apr 2024 13:28:11 +0200 Subject: [PATCH] chat and fix --- alphastats/DataSet_Preprocess.py | 127 ++++-- alphastats/gui/pages/02_Import Data.py | 60 +-- alphastats/gui/pages/03_Data Overview.py | 23 +- ...3_Preprocessing.py => 04_Preprocessing.py} | 90 +++-- .../pages/{04_Analysis.py => 05_Analysis.py} | 0 .../{05_GPT.py => 06_Analysis with GPT.py} | 49 ++- alphastats/gui/pages/07_Chat with GPT.py | 378 ++++++++++++++++++ .../pages/{06_Results.py => 08_Results.py} | 0 alphastats/gui/utils/analysis_helper.py | 4 +- alphastats/gui/utils/gpt_helper.py | 119 +++--- tests/test_DataSet.py | 94 +++-- 11 files changed, 718 insertions(+), 226 deletions(-) rename alphastats/gui/pages/{03_Preprocessing.py => 04_Preprocessing.py} (69%) rename alphastats/gui/pages/{04_Analysis.py => 05_Analysis.py} (100%) rename alphastats/gui/pages/{05_GPT.py => 06_Analysis with GPT.py} (90%) create mode 100644 alphastats/gui/pages/07_Chat with GPT.py rename alphastats/gui/pages/{06_Results.py => 08_Results.py} (100%) diff --git a/alphastats/DataSet_Preprocess.py b/alphastats/DataSet_Preprocess.py index 1e9fd665..f592249a 100644 --- a/alphastats/DataSet_Preprocess.py +++ b/alphastats/DataSet_Preprocess.py @@ -9,6 +9,7 @@ from sklearn.experimental import enable_iterative_imputer import itertools +import streamlit as st class Preprocess: @@ -30,32 +31,45 @@ def preprocess_print_info(self): print(pd.DataFrame(self.preprocessing_info.items())) def _remove_na_values(self, cut_off): + if ( + self.preprocessing_info.get("Missing values were removed") + and self.preprocessing_info.get("Data completeness cut-off") == cut_off + ): + logging.info("Missing values have already been filtered.") + st.warning( + "Missing values have already been filtered. To apply another cutoff, reset preprocessing." + ) + return cut = 1 - cut_off - limit = self.mat.shape[0] * cut - + + num_samples, num_proteins = self.mat.shape + limit = num_samples * cut + + self.mat.replace(0, np.nan, inplace=True) keep_list = list() invalid = 0 for column_name in self.mat.columns: column = self.mat[column_name] - # Get the count of Zeros in column - count = (column == 0).sum() + count = column.isna().sum() try: count = count.item() if isinstance(count, int): if count < limit: keep_list += [column_name] - + except ValueError: - invalid +=1 + invalid += 1 continue - - self.mat= self.mat[keep_list] + self.mat = self.mat[keep_list] + self.preprocessing_info.update( - {"Data completeness cut-off": cut_off} + { + "Number of removed ProteinGroups due to data completeness cutoff": num_proteins + - self.mat.shape[1], + "Missing values were removed": True, + "Data completeness cut-off": cut_off, + } ) - percentage = cut_off * 100 - print(f"Proteins with a data completeness across all samples of less than {percentage} % have been removed.") - def _filter(self): if len(self.filter_columns) == 0: @@ -105,15 +119,18 @@ def _imputation(self, method: str): logging.info( f" {len(protein_group_na)} Protein Groups were removed due to missing values." ) - logging.info("Imputing data...") if method == "mean": - imp = sklearn.impute.SimpleImputer(missing_values=np.nan, strategy="mean", keep_empty_features=True) + imp = sklearn.impute.SimpleImputer( + missing_values=np.nan, strategy="mean", keep_empty_features=True + ) imputation_array = imp.fit_transform(self.mat.values) elif method == "median": - imp = sklearn.impute.SimpleImputer(missing_values=np.nan, strategy="median", keep_empty_features=True) + imp = sklearn.impute.SimpleImputer( + missing_values=np.nan, strategy="median", keep_empty_features=True + ) imputation_array = imp.fit_transform(self.mat.values) elif method == "knn": @@ -155,6 +172,22 @@ def _imputation(self, method: str): ) self.preprocessing_info.update({"Imputation": method}) + def _linear_normalization(self, array): + """Normalize data using l2 norm without breaking when encoutering nones + l2 = sqrt(sum(x**2)) + + Args: + array (pd.Series): array to normalize (1D array) + + Returns: + np.array: normalized array + """ + square_sum_per_row = array.pow(2).sum(axis=1, skipna=True) + + l2_norms = np.sqrt(square_sum_per_row) + normalized_vals = array.div(l2_norms.replace(0, 1), axis=0) + return normalized_vals.values + @ignore_warning(UserWarning) @ignore_warning(RuntimeWarning) def _normalization(self, method: str): @@ -168,13 +201,13 @@ def _normalization(self, method: str): normalized_array = qt.fit_transform(self.mat.values) elif method == "linear": - normalized_array = sklearn.preprocessing.normalize( - self.mat.values, norm="l2" - ) + normalized_array = self._linear_normalization(self.mat) elif method == "vst": - scaler = sklearn.preprocessing.PowerTransformer(standardize=False) - normalized_array = scaler.fit_transform(self.mat.values) + minmax = sklearn.preprocessing.MinMaxScaler() + scaler = sklearn.preprocessing.PowerTransformer() + minmaxed_array = minmax.fit_transform(self.mat.values) + normalized_array = scaler.fit_transform(minmaxed_array) else: raise ValueError( @@ -189,20 +222,19 @@ def _normalization(self, method: str): self.preprocessing_info.update({"Normalization": method}) def reset_preprocessing(self): - """ Reset all preprocessing steps - """ - #  reset all preprocessing steps + """Reset all preprocessing steps""" self.create_matrix() print("All preprocessing steps are reset.") - + @ignore_warning(RuntimeWarning) def _compare_preprocessing_modes(self, func, params_for_func) -> list: dataset = self imputation_methods = ["mean", "median", "knn", "randomforest"] - normalization_methods = ["vst","zscore", "quantile" ] - - preprocessing_modes = list(itertools.product(normalization_methods, imputation_methods)) + normalization_methods = ["vst", "zscore", "quantile"] + preprocessing_modes = list( + itertools.product(normalization_methods, imputation_methods) + ) results_list = [] @@ -212,7 +244,9 @@ def _compare_preprocessing_modes(self, func, params_for_func) -> list: for preprocessing_mode in preprocessing_modes: # reset preprocessing dataset.reset_preprocessing() - print(f"Normalization {preprocessing_mode[0]}, Imputation {str(preprocessing_mode[1])}") + print( + f"Normalization {preprocessing_mode[0]}, Imputation {str(preprocessing_mode[1])}" + ) dataset.mat.replace([np.inf, -np.inf], np.nan, inplace=True) dataset.preprocess( @@ -223,7 +257,7 @@ def _compare_preprocessing_modes(self, func, params_for_func) -> list: res = func(**params_for_func) results_list.append(res) - + print("\t") return results_list @@ -232,8 +266,8 @@ def _log2_transform(self): self.mat = np.log2(self.mat + 0.1) self.preprocessing_info.update({"Log2-transformed": True}) print("Data has been log2-transformed.") - - def batch_correction(self, batch:str): + + def batch_correction(self, batch: str): """Correct for technical bias/batch effects Behdenna A, Haziza J, Azencot CA and Nordor A. (2020) pyComBat, a Python tool for batch effects correction in high-throughput molecular data using empirical Bayes methods. bioRxiv doi: 10.1101/2020.03.17.995431 Args: @@ -241,20 +275,23 @@ def batch_correction(self, batch:str): """ import combat from combat.pycombat import pycombat + data = self.mat.transpose() - series_of_batches = self.metadata.set_index(self.sample).reindex(data.columns.to_list())[batch] + series_of_batches = self.metadata.set_index(self.sample).reindex( + data.columns.to_list() + )[batch] self.mat = pycombat(data=data, batch=series_of_batches).transpose() @ignore_warning(RuntimeWarning) def preprocess( self, - log2_transform: bool=True, - remove_contaminations: bool=False, - subset: bool=False, - data_completeness: float=0, - normalization: str=None, - imputation: str=None, - remove_samples: list=None, + log2_transform: bool = True, + remove_contaminations: bool = False, + subset: bool = False, + data_completeness: float = 0, + normalization: str = None, + imputation: str = None, + remove_samples: list = None, ): """Preprocess Protein data @@ -300,15 +337,14 @@ def preprocess( """ if remove_contaminations: self._filter() - + if remove_samples is not None: self._remove_sampels(sample_list=remove_samples) if subset: self.mat = self._subset() - - if data_completeness> 0: + if data_completeness > 0: self._remove_na_values(cut_off=data_completeness) if log2_transform and self.preprocessing_info.get("Log2-transformed") is False: @@ -317,9 +353,14 @@ def preprocess( if normalization is not None: self._normalization(method=normalization) self.mat = self.mat.replace([np.inf, -np.inf], np.nan) - + if imputation is not None: self._imputation(method=imputation) self.mat = self.mat.loc[:, (self.mat != 0).any(axis=0)] + self.preprocessing_info.update( + { + "Matrix: Number of ProteinIDs/ProteinGroups": self.mat.shape[1], + } + ) self.preprocessed = True diff --git a/alphastats/gui/pages/02_Import Data.py b/alphastats/gui/pages/02_Import Data.py index 9960114e..18b336e5 100644 --- a/alphastats/gui/pages/02_Import Data.py +++ b/alphastats/gui/pages/02_Import Data.py @@ -163,6 +163,11 @@ def select_sample_column_metadata(df, software): submitted = st.form_submit_button("Create DataSet") if submitted: + if len(df[st.session_state.sample_column].to_list()) != len( + df[st.session_state.sample_column].unique() + ): + st.error("Sample names have to be unique.") + st.stop() return True @@ -209,7 +214,6 @@ def create_metadata_file(): # Write each dataframe to a different worksheet. metadata.to_excel(writer, sheet_name="Sheet1", index=False) # Close the Pandas Excel writer and output the Excel file to the buffer - writer.close() st.download_button( label="Download metadata template as Excel", @@ -248,8 +252,6 @@ def upload_metadatafile(software): load_options() - display_loaded_dataset() - if st.session_state.loader is not None: create_metadata_file() st.write( @@ -265,8 +267,6 @@ def upload_metadatafile(software): load_options() - display_loaded_dataset() - def load_sample_data(): _this_file = os.path.abspath(__file__) @@ -279,9 +279,11 @@ def load_sample_data(): loader = MaxQuantLoader(file=filepath) ds = DataSet(loader=loader, metadata_path=metadatapath, sample_column="sample") - metadatapath = os.path.join(_this_directory, "sample_data", "metadata.xlsx").replace( - "pages/", "" - ).replace("pages\\", "") + metadatapath = ( + os.path.join(_this_directory, "sample_data", "metadata.xlsx") + .replace("pages/", "") + .replace("pages\\", "") + ) loader = MaxQuantLoader(file=filepath) ds = DataSet(loader=loader, metadata_path=metadatapath, sample_column="sample") @@ -305,17 +307,19 @@ def load_sample_data(): def import_data(): options = ["": - upload_softwarefile(software=software) - + if st.session_state.software != "" from streamlit.runtime import get_instance from streamlit.runtime.scriptrunner.script_run_context import get_script_run_ctx @@ -367,7 +372,6 @@ def empty_session_state(): sidebar_info() - if "dataset" not in st.session_state: st.markdown("### Import Proteomics Data") @@ -375,8 +379,17 @@ def empty_session_state(): "Create a DataSet with the output of your proteomics software package and the corresponding metadata (optional). " ) - import_data() - st.markdown("### Or Load sample Dataset") +import_data() + +if "dataset" in st.session_state: + st.info("DataSet has been imported") + + if "distribution_plot" not in st.session_state: + save_plot_sampledistribution_rawdata() + + display_loaded_dataset() + +st.markdown("### Or Load sample Dataset") if st.button("Load sample DataSet - PXD011839"): st.write( @@ -407,16 +420,9 @@ def empty_session_state(): load_sample_data() -if "dataset" in st.session_state: - st.info("DataSet has been imported") - - if "distribution_plot" not in st.session_state: - save_plot_sampledistribution_rawdata() - - if st.button("New Session: Import new dataset"): - empty_session_state() - import_data() +st.markdown("### To start a new session:") - if "dataset" in st.session_state: - display_loaded_dataset() +if st.button("New Session: Import new dataset"): + empty_session_state() + st.rerun() diff --git a/alphastats/gui/pages/03_Data Overview.py b/alphastats/gui/pages/03_Data Overview.py index 821b5132..a0216bbc 100644 --- a/alphastats/gui/pages/03_Data Overview.py +++ b/alphastats/gui/pages/03_Data Overview.py @@ -44,15 +44,17 @@ def display_matrix(): @st.cache_data -def get_sample_histogram_matrix(user_session_id = st.session_state.user_session_id): +def get_sample_histogram_matrix(user_session_id=st.session_state.user_session_id): return st.session_state.dataset.plot_samplehistograms() + @st.cache_data -def get_intensity_distribution_processed(user_session_id = st.session_state.user_session_id): +def get_intensity_distribution_processed( + user_session_id=st.session_state.user_session_id, +): return st.session_state.dataset.plot_sampledistribution() - if "dataset" in st.session_state: st.markdown("## DataSet overview") @@ -70,16 +72,19 @@ def get_intensity_distribution_processed(user_session_id = st.session_state.user st.markdown("**Intensity distribution data per sample used for analysis**") st.plotly_chart( - get_intensity_distribution_processed(user_session_id = st.session_state.user_session_id) - .update_layout(plot_bgcolor="white"), use_container_width=True + get_intensity_distribution_processed( + user_session_id=st.session_state.user_session_id + ).update_layout(plot_bgcolor="white"), + use_container_width=True, ) - + st.plotly_chart( - get_sample_histogram_matrix(user_session_id = st.session_state.user_session_id) - .update_layout(plot_bgcolor="white"), use_container_width=True + get_sample_histogram_matrix( + user_session_id=st.session_state.user_session_id + ).update_layout(plot_bgcolor="white"), + use_container_width=True, ) - display_matrix() diff --git a/alphastats/gui/pages/03_Preprocessing.py b/alphastats/gui/pages/04_Preprocessing.py similarity index 69% rename from alphastats/gui/pages/03_Preprocessing.py rename to alphastats/gui/pages/04_Preprocessing.py index 8f2937d7..b56a388c 100644 --- a/alphastats/gui/pages/03_Preprocessing.py +++ b/alphastats/gui/pages/04_Preprocessing.py @@ -12,7 +12,7 @@ def preprocessing(): st.markdown( "Before analyzing your data, consider normalizing and imputing your data as well as the removal of contaminants. " - + "A more detailed description about the preprocessing methods can be found in the AlphaPeptStats " + + "A more detailed description about the preprocessing methods can be found in the AlphaPeptStats " + "[documentation](https://alphapeptstats.readthedocs.io/en/main/data_preprocessing.html)." ) @@ -30,17 +30,23 @@ def preprocessing(): ) remove_samples = st.multiselect( - "Remove samples from analysis", - options=st.session_state.dataset.metadata[st.session_state.dataset.sample].to_list() + "Remove samples from analysis", + options=st.session_state.dataset.metadata[ + st.session_state.dataset.sample + ].to_list(), ) data_completeness = st.number_input( f"Data completeness across samples cut-off \n(0.7 -> protein has to be detected in at least 70% of the samples)", - value=0., min_value=0., max_value=1., step=0.1 + value=0.0, + min_value=0.0, + max_value=1.0, + step=0.1, ) log2_transform = st.selectbox( - "Log2-transform dataset", options=[True, False], + "Log2-transform dataset", + options=[True, False], ) normalization = st.selectbox( @@ -56,58 +62,66 @@ def preprocessing(): if submitted: if len(remove_samples) == 0: remove_samples = None - + st.session_state.dataset.preprocess( remove_contaminations=remove_contaminations, log2_transform=log2_transform, - remove_samples = remove_samples, + remove_samples=remove_samples, data_completeness=data_completeness, subset=subset, normalization=normalization, imputation=imputation, ) - preprocessing = st.session_state.dataset.preprocessing_info - st.info( - "Data has been processed. " - + datetime.datetime.now().strftime("%d/%m/%Y %H:%M:%S") + + st.session_state["preprocessing_info"] = ( + st.session_state.dataset.preprocessing_info ) - st.dataframe( - pd.DataFrame.from_dict(preprocessing, orient="index").astype(str), + + if submitted or "preprocessing_info" in st.session_state: + st.info( + "Data has been processed. " + + datetime.datetime.now().strftime("%d/%m/%Y %H:%M:%S") + ) + st.dataframe( + pd.DataFrame.from_dict( + st.session_state["preprocessing_info"], orient="index" + ).astype(str), + use_container_width=True, + ) + with c2: + + if submitted: + st.markdown("**Intensity Distribution after preprocessing per sample**") + fig_processed = st.session_state.dataset.plot_sampledistribution() + st.plotly_chart( + fig_processed.update_layout(plot_bgcolor="white"), + use_container_width=True, + ) + + else: + st.markdown("**Intensity Distribution per sample**") + fig_none_processed = st.session_state.dataset.plot_sampledistribution() + st.plotly_chart( + fig_none_processed.update_layout(plot_bgcolor="white"), use_container_width=True, ) - + c1, c2 = st.columns(2) + with c1: st.markdown("#### Batch correction: correct for technical bias") with st.form("Batch correction: correct for technical bias"): batch = st.selectbox( - "Batch", - options= st.session_state.dataset.metadata.columns.to_list() + "Batch", options=st.session_state.dataset.metadata.columns.to_list() ) submit_batch_correction = st.form_submit_button("Submit") - + if submit_batch_correction: - st.session_state.dataset.batch_correction( - batch=batch - ) + st.session_state.dataset.batch_correction(batch=batch) st.info( "Data has been processed. " + datetime.datetime.now().strftime("%d/%m/%Y %H:%M:%S") ) - - with c2: - - if submitted: - st.markdown("**Intensity Distribution after preprocessing per sample**") - fig_processed = st.session_state.dataset.plot_sampledistribution() - st.plotly_chart(fig_processed.update_layout(plot_bgcolor="white"), use_container_width=True) - - else: - st.markdown("**Intensity Distribution per sample**") - fig_none_processed = st.session_state.dataset.plot_sampledistribution() - st.plotly_chart(fig_none_processed.update_layout(plot_bgcolor="white"), use_container_width=True) - - reset_steps = st.button("Reset all Preprocessing steps") if reset_steps: @@ -116,14 +130,12 @@ def preprocessing(): def reset_preprocessing(): st.session_state.dataset.create_matrix() - preprocessing = st.session_state.dataset.preprocessing_info st.info( "Data has been reset. " + datetime.datetime.now().strftime("%d/%m/%Y %H:%M:%S") ) - st.dataframe( - pd.DataFrame.from_dict(preprocessing, orient="index").astype(str), - use_container_width=True, - ) + st.session_state["preprocessing_info"] = st.session_state.dataset.preprocessing_info + # reset the page + st.rerun() def main_preprocessing(): diff --git a/alphastats/gui/pages/04_Analysis.py b/alphastats/gui/pages/05_Analysis.py similarity index 100% rename from alphastats/gui/pages/04_Analysis.py rename to alphastats/gui/pages/05_Analysis.py diff --git a/alphastats/gui/pages/05_GPT.py b/alphastats/gui/pages/06_Analysis with GPT.py similarity index 90% rename from alphastats/gui/pages/05_GPT.py rename to alphastats/gui/pages/06_Analysis with GPT.py index 5195c3c4..7d167a08 100644 --- a/alphastats/gui/pages/05_GPT.py +++ b/alphastats/gui/pages/06_Analysis with GPT.py @@ -28,6 +28,7 @@ wait_for_run_completion, send_message_save_thread, try_to_set_api_key, + get_general_assistant_functions, ) from alphastats.gui.utils.ui_helper import sidebar_info @@ -56,6 +57,7 @@ wait_for_run_completion, send_message_save_thread, try_to_set_api_key, + get_general_assistant_functions, ) from utils.ui_helper import sidebar_info @@ -88,7 +90,7 @@ def select_analysis(): # set background to white so downloaded pngs dont have grey background -styl = f""" +style = f""" """ -st.markdown(styl, unsafe_allow_html=True) +st.markdown(style, unsafe_allow_html=True) -if "plot_list" not in st.session_state: - st.session_state["plot_list"] = [] +if "plot_list_GPT" not in st.session_state: + st.session_state["plot_list_GPT"] = [] if "openai_model" not in st.session_state: @@ -225,7 +227,7 @@ def select_analysis(): st.stop() print("genes_of_interest", genes_of_interest_colored) - save_plot_to_session_state(volcano_plot, method) + save_plot_to_session_state(volcano_plot, method, "plot_list_GPT") st.session_state["genes_of_interest_colored"] = genes_of_interest_colored # st.session_state["gene_functions"] = get_info(genes_of_interest_colored, organism) st.session_state["upregulated"] = [ @@ -251,10 +253,10 @@ def select_analysis(): st.session_state["plot_submitted_counter"] > 0 and st.session_state["plot_submitted_counter"] == st.session_state["plot_submitted_clicked"] - and len(st.session_state["plot_list"]) > 0 + and len(st.session_state["plot_list_GPT"]) > 0 ): with c2: - display_figure(st.session_state["plot_list"][-1][1].plot) + display_figure(st.session_state["plot_list_GPT"][-1][1].plot) st.subheader("Genes of interest") c1, c2 = st.columns((1, 2), gap="medium") @@ -305,7 +307,6 @@ def select_analysis(): if gpt_submitted: st.session_state["gpt_submitted_clicked"] += 1 -# creating new assistant only once TODO: add a button to create new assistant if ( st.session_state["gpt_submitted_clicked"] > st.session_state["gpt_submitted_counter"] @@ -319,13 +320,16 @@ def select_analysis(): instructions=st.session_state["instructions"], name="Proteomics interpreter", model=st.session_state["openai_model"], - tools=get_assistant_functions( - gene_to_prot_id_dict=st.session_state["gene_to_prot_id"], - metadata=st.session_state["dataset"].metadata, - subgroups_for_each_group=get_subgroups_for_each_group( - st.session_state["dataset"].metadata + tools=[ + *get_general_assistant_functions(), + *get_assistant_functions( + gene_to_prot_id_dict=st.session_state["gene_to_prot_id"], + metadata=st.session_state["dataset"].metadata, + subgroups_for_each_group=get_subgroups_for_each_group( + st.session_state["dataset"].metadata + ), ), - ), + ], ) except AuthenticationError: st.warning( @@ -345,8 +349,14 @@ def select_analysis(): ] st.session_state["artefact_enum_dict"] = {} thread = client.beta.threads.create() - st.session_state["thread_id"] = thread.id - artefacts = send_message_save_thread(client, st.session_state["user_prompt"]) + thread_id = thread.id + st.session_state["thread_id"] = thread_id + artefacts = send_message_save_thread( + client, + st.session_state["user_prompt"], + st.session_state["assistant"].id, + st.session_state["thread_id"], + ) if artefacts: st.session_state["artefact_enum_dict"][ len(st.session_state.messages) - 1 @@ -355,7 +365,12 @@ def select_analysis(): if st.session_state["gpt_submitted_clicked"] > 0: if prompt := st.chat_input("Say something"): st.session_state.messages.append({"role": "user", "content": prompt}) - artefacts = send_message_save_thread(client, prompt) + artefacts = send_message_save_thread( + client, + prompt, + st.session_state["assistant"].id, + st.session_state["thread_id"], + ) if artefacts: st.session_state["artefact_enum_dict"][ len(st.session_state.messages) - 1 diff --git a/alphastats/gui/pages/07_Chat with GPT.py b/alphastats/gui/pages/07_Chat with GPT.py new file mode 100644 index 00000000..b417dba5 --- /dev/null +++ b/alphastats/gui/pages/07_Chat with GPT.py @@ -0,0 +1,378 @@ +import os +import streamlit as st +import pandas as pd +from openai import OpenAI, OpenAIError, AuthenticationError + +try: + from alphastats.gui.utils.analysis_helper import ( + check_if_options_are_loaded, + convert_df, + display_df, + display_figure, + download_figure, + download_preprocessing_info, + get_analysis, + load_options, + save_plot_to_session_state, + gui_volcano_plot_differential_expression_analysis, + helper_compare_two_groups, + ) + from alphastats.gui.utils.gpt_helper import ( + get_assistant_functions, + display_proteins, + get_gene_function, + get_info, + get_subgroups_for_each_group, + turn_args_to_float, + perform_dimensionality_reduction, + wait_for_run_completion, + send_message_save_thread, + try_to_set_api_key, + get_general_assistant_functions, + ) + from alphastats.gui.utils.ui_helper import sidebar_info + +except ModuleNotFoundError: + from utils.analysis_helper import ( + check_if_options_are_loaded, + convert_df, + display_df, + display_figure, + download_figure, + download_preprocessing_info, + get_analysis, + load_options, + save_plot_to_session_state, + gui_volcano_plot_differential_expression_analysis, + helper_compare_two_groups, + ) + from utils.gpt_helper import ( + get_assistant_functions, + display_proteins, + get_gene_function, + get_info, + get_subgroups_for_each_group, + turn_args_to_float, + perform_dimensionality_reduction, + wait_for_run_completion, + send_message_save_thread, + try_to_set_api_key, + get_general_assistant_functions, + ) + from utils.ui_helper import sidebar_info + + +st.session_state.plot_dict = {} + + +@check_if_options_are_loaded +def select_analysis(): + """ + select box + loads keys from option dicts + """ + method = st.selectbox( + "Analysis", + options=["Volcano plot"], + # options=list(st.session_state.interpretation_options.keys()), + ) + return method + + +st.markdown("### Chat with GPT") + +sidebar_info() + + +# set background to white so downloaded pngs dont have grey background +style = f""" + + """ +st.markdown(style, unsafe_allow_html=True) + + +if "plot_list_chat" not in st.session_state: + st.session_state["plot_list_chat"] = [] + +if "plotting_options" not in st.session_state: + st.session_state["plotting_options"] = {} + +if "openai_model" not in st.session_state: + # st.session_state["openai_model"] = "gpt-3.5-turbo-16k" + st.session_state["openai_model"] = "gpt-4-0125-preview" # "gpt-4-1106-preview" + +if "messages_chat" not in st.session_state: + st.session_state["messages_chat"] = [] + +if "prompt_submitted_clicked_chat" not in st.session_state: + st.session_state["prompt_submitted_clicked_chat"] = 0 + st.session_state["prompt_submitted_counter_chat"] = 0 + +if "gpt_submitted_clicked_chat" not in st.session_state: + st.session_state["gpt_submitted_clicked_chat"] = 0 + st.session_state["gpt_submitted_counter_chat"] = 0 + +c1, c2 = st.columns((1, 2)) +st.subheader("Necessary context:") +st.session_state["upregulated_chat"] = st.text_area( + "List of upregulated proteins / genes:", height=75 +) + +st.session_state["downregulated_chat"] = st.text_area( + "List of downregulated proteins / genes:", height=75 +) + + +import re + + +def custom_string_list_parser(input_str): + """ + Parses a string that represents a list of strings, handling misformatted double quotes. + + Args: + - input_str (str): The string representation of the list to be parsed. + + Returns: + - list: A Python list of strings that have been extracted and corrected. + """ + + # Trim leading and trailing whitespaces and brackets if present + trimmed_str = re.sub(r"^[\[\{\(]", "", input_str.strip()) + trimmed_str = re.sub(r"[\]\}\)]$", "", trimmed_str) + + # Split by separators, considering quotes but ignoring misformatted ones + items = re.split(r',(?=(?:[^"]*"[^"]*")*[^"]*$)', trimmed_str) + + clean_items = [] + for item in items: + # Remove leading and trailing whitespaces + item = item.strip() + # Remove leading and trailing quotes + item = re.sub(r'^"|"$', "", item) + # Handle escaping properly. Replace two double quotes with one. + # If quotes were meant to be part of the string, they should be doubled. + item = item.replace('""', '"') + + clean_items.append(item) + + return clean_items + + +# Example usage +input_str = '["VCL;HEL114", "P4HB", "An erroneous ""string"" here", "PSME2"]' +parsed_list = custom_string_list_parser(input_str) +print(parsed_list) + +st.session_state["upregulated_chat"] = custom_string_list_parser( + st.session_state["upregulated_chat"] +) +st.session_state["downregulated_chat"] = custom_string_list_parser( + st.session_state["downregulated_chat"] +) + +with c1: + + api_key = st.text_input("API Key", type="password") + +with c2: + organism = st.number_input( + label="UniProt organism ID, e.g. human is 9606, R. norvegicus is 10116", + value=9606, + ) + st.session_state["organism"] = organism + +try_to_set_api_key(api_key) + +try: + client = OpenAI(api_key=st.secrets["openai_api_key"]) +except OpenAIError: + pass + + # TODO streamlit doesnt allow nested columns check for updates + + # st.session_state["prot_id_to_gene"] = dict( + # zip( + # genes_of_interest_colored_df[prot_ids_colname].tolist(), + # genes_of_interest_colored_df[gene_names_colname].tolist(), + # ) + # ) + # st.session_state["gene_to_prot_id"] = dict( + # zip( + # genes_of_interest_colored_df[gene_names_colname].tolist(), + # genes_of_interest_colored_df[prot_ids_colname].tolist(), + # ) + # ) + + # st.session_state["genes_of_interest_colored"] = genes_of_interest_colored + # # st.session_state["gene_functions"] = get_info(genes_of_interest_colored, organism) + # st.session_state["upregulated"] = [ + # key + # for key in genes_of_interest_colored + # if genes_of_interest_colored[key] == "up" + # ] + # st.session_state["downregulated"] = [ + # key + # for key in genes_of_interest_colored + # if genes_of_interest_colored[key] == "down" + # ] + +c1, c2 = st.columns((1, 1)) + +with c1: + group1 = st.text_input("Comparison group 1. e.g. healthy") + st.session_state["group1_chat"] = group1 + +with c2: + group2 = st.text_input("Comparison group 2. e.g. diseased") + st.session_state["group2_chat"] = group2 + +prompt_submitted = st.button("Create user prompt") +st.session_state["prompt_submitted_counter"] = 0 + + +if prompt_submitted: + st.session_state["prompt_submitted_clicked_chat"] += 1 + +if st.session_state["prompt_submitted_clicked_chat"] == 0: + st.stop() + +if ( + not st.session_state["upregulated_chat"] + or not st.session_state["downregulated_chat"] +): + st.warning("Please enter upregulated and downregulated proteins") + st.stop() + +if not st.session_state["group1_chat"] or not st.session_state["group2_chat"]: + st.warning("Please enter group names") + st.stop() + +if ( + st.session_state["prompt_submitted_clicked_chat"] + > st.session_state["prompt_submitted_counter_chat"] +): + st.session_state["prompt_submitted_counter_chat"] = st.session_state[ + "prompt_submitted_clicked_chat" + ] + st.session_state["instructions"] = ( + f"You are an expert biologist and have extensive experience in molecular biology, medicine and biochemistry.{os.linesep}" + "A user will present you with data regarding proteins upregulated in certain cells " + "sourced from UniProt and abstracts from scientific publications. They seek your " + "expertise in understanding the connections between these proteins and their potential role " + f"in disease genesis. {os.linesep}Provide a detailed and insightful, yet concise response based on the given information. " + "Plots are visualized using a graphical environment capable of rendering images, you don't need to worry about that." + ) + st.session_state["user_prompt_chat"] = ( + f"We've recently identified several proteins that appear to be differently regulated in cells " + f"when comparing {st.session_state['group1_chat']} and {st.session_state['group2_chat']}. " + f"From our proteomics experiments, we know that the following ones are upregulated: {', '.join(st.session_state['upregulated_chat'])}.{os.linesep}{os.linesep}" + f"Here is the list of proteins that are downregulated: {', '.join(st.session_state['downregulated_chat'])}.{os.linesep}{os.linesep}" + f"Help us understand the potential connections between these proteins and how they might be contributing " + f"to the differences. After that provide a high level summary" + ) + + +if "user_prompt_chat" in st.session_state: + st.subheader("Automatically generated prompt:") + with st.expander("Adjust system prompt (see example below)", expanded=False): + st.session_state["instructions"] = st.text_area( + "", value=st.session_state["instructions"], height=150 + ) + + with st.expander("Adjust user prompt", expanded=True): + st.session_state["user_prompt_chat"] = st.text_area( + "", value=st.session_state["user_prompt_chat"], height=200 + ) + +gpt_submitted = st.button("Run GPT analysis") + +if gpt_submitted and "user_prompt_chat" not in st.session_state: + st.warning("Please enter a user prompt first") + st.stop() + +if gpt_submitted: + st.session_state["gpt_submitted_clicked_chat"] += 1 + +if ( + st.session_state["gpt_submitted_clicked_chat"] + > st.session_state["gpt_submitted_counter_chat"] +): + try_to_set_api_key() + + client = OpenAI(api_key=st.secrets["openai_api_key"]) + + try: + st.session_state["assistant_chat"] = client.beta.assistants.create( + instructions=st.session_state["instructions"], + name="Proteomics interpreter", + model=st.session_state["openai_model"], + tools=get_general_assistant_functions(), + ) + print( + st.session_state["assistant_chat"], type(st.session_state["assistant_chat"]) + ) + except AuthenticationError: + st.warning( + "Incorrect API key provided. Please enter a valid API key, it should look like this: sk-XXXXX" + ) + st.stop() + +if "artefact_enum_dict_chat" not in st.session_state: + st.session_state["artefact_enum_dict_chat"] = {} + +if ( + st.session_state["gpt_submitted_counter_chat"] + < st.session_state["gpt_submitted_clicked_chat"] +): + st.session_state["gpt_submitted_counter_chat"] = st.session_state[ + "gpt_submitted_clicked_chat" + ] + st.session_state["artefact_enum_dict_chat"] = {} + thread = client.beta.threads.create() + thread_id = thread.id + st.session_state["thread_id_chat"] = thread_id + artefacts = send_message_save_thread( + client, + st.session_state["user_prompt_chat"], + st.session_state["assistant_chat"].id, + st.session_state["thread_id_chat"], + "messages_chat", + ) + if artefacts: + st.session_state["artefact_enum_dict_chat"][ + len(st.session_state.messages_chat) - 1 + ] = artefacts + +if st.session_state["gpt_submitted_clicked_chat"] > 0: + if prompt := st.chat_input("Say something"): + st.session_state.messages_chat.append({"role": "user", "content": prompt}) + artefacts = send_message_save_thread( + client, + prompt, + st.session_state["assistant_chat"].id, + st.session_state["thread_id_chat"], + "messages_chat", + ) + if artefacts: + st.session_state["artefact_enum_dict_chat"][ + len(st.session_state.messages_chat) - 1 + ] = artefacts + for num, role_content_dict in enumerate(st.session_state.messages_chat): + with st.chat_message(role_content_dict["role"]): + st.markdown(role_content_dict["content"]) + if num in st.session_state["artefact_enum_dict_chat"]: + for artefact in st.session_state["artefact_enum_dict_chat"][num]: + if isinstance(artefact, pd.DataFrame): + st.dataframe(artefact) + else: + st.plotly_chart(artefact) + print(st.session_state["artefact_enum_dict_chat"]) diff --git a/alphastats/gui/pages/06_Results.py b/alphastats/gui/pages/08_Results.py similarity index 100% rename from alphastats/gui/pages/06_Results.py rename to alphastats/gui/pages/08_Results.py diff --git a/alphastats/gui/utils/analysis_helper.py b/alphastats/gui/utils/analysis_helper.py index 4ce1386c..1bf72b5e 100644 --- a/alphastats/gui/utils/analysis_helper.py +++ b/alphastats/gui/utils/analysis_helper.py @@ -27,11 +27,11 @@ def display_figure(plot): st.pyplot(plot) -def save_plot_to_session_state(plot, method): +def save_plot_to_session_state(plot, method, variable_name="plot_list"): """ save plot with method to session state to retrieve old results """ - st.session_state["plot_list"] += [(method, plot)] + st.session_state[variable_name].append((method, plot)) def display_df(df): diff --git a/alphastats/gui/utils/gpt_helper.py b/alphastats/gui/utils/gpt_helper.py index 526c9078..26af9806 100644 --- a/alphastats/gui/utils/gpt_helper.py +++ b/alphastats/gui/utils/gpt_helper.py @@ -135,23 +135,7 @@ def display_proteins(overexpressed: list[str], underexpressed: list[str]) -> Non st.markdown(full_html, unsafe_allow_html=True) -def get_assistant_functions( - gene_to_prot_id_dict: dict, - metadata: pd.DataFrame, - subgroups_for_each_group: dict, -) -> list[dict]: - """ - Get a list of assistant functions for function calling in the ChatGPT model. - You can call this function with no arguments, arguments are given for clarity on what changes the behavior of the function. - For more information on how to format functions for Assistants, see https://platform.openai.com/docs/assistants/tools/function-calling - - Args: - gene_to_prot_id_dict (dict, optional): A dictionary with gene names as keys and protein IDs as values. - metadata (pd.DataFrame, optional): The metadata dataframe (which sample has which disease/treatment/condition/etc). - subgroups_for_each_group (dict, optional): A dictionary with the column names as keys and a list of unique values as values. Defaults to get_subgroups_for_each_group(). - Returns: - list[dict]: A list of assistant functions. - """ +def get_general_assistant_functions() -> list[dict]: return [ { "type": "function", @@ -170,6 +154,54 @@ def get_assistant_functions( }, }, }, + { + "type": "function", + "function": { + "name": "get_enrichment_data", + "description": "Get enrichment data for a list of differentially expressed genes", + "parameters": { + "type": "object", + "properties": { + "difexpressed": { + "type": "array", + "items": {"type": "string"}, + "description": "A list of differentially expressed gene names to search for", + }, + "organism_id": { + "type": "string", + "description": "The Uniprot organism ID to search in", + }, + "tool": { + "type": "string", + "description": "The tool to use for enrichment analysis", + "enum": ["gprofiler", "string"], + }, + }, + "required": ["difexpressed", "organism_id"], + }, + }, + }, + ] + + +def get_assistant_functions( + gene_to_prot_id_dict: dict, + metadata: pd.DataFrame, + subgroups_for_each_group: dict, +) -> list[dict]: + """ + Get a list of assistant functions for function calling in the ChatGPT model. + You can call this function with no arguments, arguments are given for clarity on what changes the behavior of the function. + For more information on how to format functions for Assistants, see https://platform.openai.com/docs/assistants/tools/function-calling + + Args: + gene_to_prot_id_dict (dict, optional): A dictionary with gene names as keys and protein IDs as values. + metadata (pd.DataFrame, optional): The metadata dataframe (which sample has which disease/treatment/condition/etc). + subgroups_for_each_group (dict, optional): A dictionary with the column names as keys and a list of unique values as values. Defaults to get_subgroups_for_each_group(). + Returns: + list[dict]: A list of assistant functions. + """ + return [ { "type": "function", "function": { @@ -322,33 +354,6 @@ def get_assistant_functions( }, }, }, - { - "type": "function", - "function": { - "name": "get_enrichment_data", - "description": "Get enrichment data for a list of differentially expressed genes", - "parameters": { - "type": "object", - "properties": { - "difexpressed": { - "type": "array", - "items": {"type": "string"}, - "description": "A list of differentially expressed gene names to search for", - }, - "organism_id": { - "type": "string", - "description": "The Uniprot organism ID to search in", - }, - "tool": { - "type": "string", - "description": "The tool to use for enrichment analysis", - "enum": ["gprofiler", "string"], - }, - }, - "required": ["difexpressed", "organism_id"], - }, - }, - }, {"type": "code_interpreter"}, ] @@ -762,6 +767,7 @@ def wait_for_run_completion( run_status = client.beta.threads.runs.retrieve( thread_id=thread_id, run_id=run_id ) + print(run_status.status, run_id, run_status.required_action) assistant_functions = { "create_intensity_plot", "perform_dimensionality_reduction", @@ -827,12 +833,19 @@ def wait_for_run_completion( _run = client.beta.threads.runs.submit_tool_outputs( thread_id=thread_id, run_id=run_id, tool_outputs=tool_outputs ) + print("submitted") else: print("Run is not yet completed. Waiting...", run_status.status, run_id) time.sleep(check_interval) -def send_message_save_thread(client: openai.OpenAI, message: str) -> Optional[list]: +def send_message_save_thread( + client: openai.OpenAI, + message: str, + assistant_id: str, + thread_id: str, + storing_variable: str = "messages", +) -> Optional[list]: """ Send a message to the OpenAI ChatGPT model and save the thread in the session state, return plots if GPT called a function to create them. @@ -844,29 +857,27 @@ def send_message_save_thread(client: openai.OpenAI, message: str) -> Optional[li Optional[list]: A list of plots, if any. """ message = client.beta.threads.messages.create( - thread_id=st.session_state["thread_id"], role="user", content=message + thread_id=thread_id, role="user", content=message ) run = client.beta.threads.runs.create( - thread_id=st.session_state["thread_id"], - assistant_id=st.session_state["assistant"].id, + thread_id=thread_id, + assistant_id=assistant_id, ) try: - plots = wait_for_run_completion(client, st.session_state["thread_id"], run.id) + plots = wait_for_run_completion(client, thread_id, run.id) except KeyError as e: print(e) plots = None - messages = client.beta.threads.messages.list( - thread_id=st.session_state["thread_id"] - ) - st.session_state.messages = [] + messages = client.beta.threads.messages.list(thread_id=thread_id) + st.session_state[storing_variable] = [] for num, message in enumerate(messages.data[::-1]): role = message.role if message.content: content = message.content[0].text.value else: content = "Sorry, I was unable to process this message. Try again or change your request." - st.session_state.messages.append({"role": role, "content": content}) + st.session_state[storing_variable].append({"role": role, "content": content}) if not plots: return return plots diff --git a/tests/test_DataSet.py b/tests/test_DataSet.py index 2002456e..cc954505 100644 --- a/tests/test_DataSet.py +++ b/tests/test_DataSet.py @@ -237,7 +237,11 @@ def test_remove_misc_samples_in_metadata(self, mock): df = pd.DataFrame( {"sample": ["A", "B", "C"], "b": ["disease", "health", "disease"]} ) - obj = DataSet(loader=self.loader, metadata_path=df, sample_column="sample",) + obj = DataSet( + loader=self.loader, + metadata_path=df, + sample_column="sample", + ) #  is sample C removed self.assertEqual(self.obj.metadata.shape, (2, 2)) mock.assert_called_once() @@ -247,7 +251,11 @@ def test_load_metadata_df(self): df = pd.read_csv(self.metadata_path) else: df = pd.read_excel(self.metadata_path) - obj = DataSet(loader=self.loader, metadata_path=df, sample_column="sample",) + obj = DataSet( + loader=self.loader, + metadata_path=df, + sample_column="sample", + ) self.assertIsInstance(obj.metadata, pd.DataFrame) self.assertFalse(obj.metadata.empty) @@ -297,9 +305,9 @@ def test_preprocess_normalize_vst(self): self.obj.preprocess(log2_transform=False, normalization="vst") expected_mat = pd.DataFrame( { - "a": [ 3.19059101, 11.591763, 8.365096], - "b": [0.084829, 0.084829, 0.084829], - "c": [0.000000, 7.850074, 6.435102], + "a": [-1.307734, 1.120100, 0.187634], + "b": [1.414214, -0.707107, -0.707107], + "c": [-1.360307, 1.015077, 0.345230], } ) pd._testing.assert_frame_equal(self.obj.mat.round(2), expected_mat.round(2)) @@ -341,9 +349,9 @@ def test_preprocess_imputation_randomforest_values(self): self.obj.preprocess(log2_transform=False, imputation="randomforest") expected_mat = pd.DataFrame( { - "a": [2.00000000e00, -9.22337204e12, 4.00000000e00], + "a": [2.00000000e00, 0, 4.00000000e00], "b": [5.00000000e00, 4.00000000e00, 4.0], - "c": [-9.22337204e12, 1.00000000e01, -9.22337204e12], + "c": [0, 1.00000000e01, 0], } ) pd._testing.assert_frame_equal(self.obj.mat, expected_mat) @@ -409,17 +417,19 @@ def test_load_evidence_wrong_sample_names(self): evidence_file="testfiles/maxquant_go/evidence.txt", ) DataSet( - loader=loader, metadata_path=self.metadata_path, sample_column="sample", + loader=loader, + metadata_path=self.metadata_path, + sample_column="sample", ) def test_plot_pca_group(self): pca_plot = self.obj.plot_pca(group=self.comparison_column) # 5 different disease self.assertEqual(len(pca_plot.to_plotly_json().get("data")), 5) - + def test_data_completeness(self): self.obj.preprocess(log2_transform=False, data_completeness=0.7) - self.assertEqual(self.obj.mat.shape[1], 517) + self.assertEqual(self.obj.mat.shape[1], 159) def test_plot_pca_circles(self): pca_plot = self.obj.plot_pca(group=self.comparison_column, circle=True) @@ -460,9 +470,7 @@ def test_plot_volcano_compare_preprocessing_modes(self): group2=["1_71_F10", "1_73_F12"], compare_preprocessing_modes=True, ) - - self.assertEqual(len(result_list), 12) - + self.assertEqual(len(result_list), 12) def test_preprocess_subset(self): self.obj.preprocess(subset=True, log2_transform=False) @@ -488,6 +496,9 @@ def test_plot_intenstity_subgroup(self): @patch("logging.Logger.warning") def test_plot_intenstity_subgroup_significance_warning(self, mock): + import streamlit as st + + st.session_state["gene_to_prot_id"] = {} plot = self.obj.plot_intensity( protein_id="K7ERI9;A0A024R0T8;P02654;K7EJI9;K7ELM9;K7EPF9;K7EKP1", group="disease", @@ -535,7 +546,7 @@ def test_plot_volcano_with_labels(self): draw_line=False, ) n_labels = len(plot.to_plotly_json().get("layout").get("annotations")) - #self.assertTrue(n_labels > 20) + # self.assertTrue(n_labels > 20) def test_plot_volcano_wald(self): """ @@ -570,13 +581,15 @@ def test_plot_volcano_sam(self): self.assertEqual(line_1, "spline") self.assertEqual(line_2, "spline") - + def test_plot_volcano_list(self): self.obj.preprocess(imputation="mean") - plot = self.obj.plot_volcano( method="ttest", + plot = self.obj.plot_volcano( + method="ttest", group1=["1_31_C6", "1_32_C7", "1_57_E8"], group2=["1_71_F10", "1_73_F12"], - color_list=self.obj.mat.columns.to_list()[0:20]) + color_list=self.obj.mat.columns.to_list()[0:20], + ) self.assertEqual(len(plot.to_plotly_json()["data"][0]["x"]), 20) def test_plot_clustermap_significant(self): @@ -602,7 +615,7 @@ def test_plot_volcano_with_labels_proteins(self): labels=True, ) n_labels = len(plot.to_plotly_json().get("layout").get("annotations")) - + def test_plot_volcano_with_labels_proteins_welch_ttest(self): # remove gene names self.obj.gene_names = None @@ -614,7 +627,7 @@ def test_plot_volcano_with_labels_proteins_welch_ttest(self): labels=True, ) n_labels = len(plot.to_plotly_json().get("layout").get("annotations")) - #self.assertTrue(n_labels > 20) + # self.assertTrue(n_labels > 20) def test_calculate_diff_exp_wrong(self): # get groups from comparison column @@ -711,28 +724,34 @@ def test_plot_intensity_sign_001(self): self.assertEqual(annotation, "***") def test_plot_intensity_all(self): - plot = self.obj.plot_intensity(protein_id="Q9BWP8", - group="disease", + plot = self.obj.plot_intensity( + protein_id="Q9BWP8", + group="disease", subgroups=["liver cirrhosis", "healthy"], method="all", - add_significance=True) + add_significance=True, + ) self.assertEqual(plot.to_plotly_json()["data"][0]["points"], "all") - def test_plot_samplehistograms(self): fig = self.obj.plot_samplehistograms().to_plotly_json() self.assertEqual(312, len(fig["data"])) - + def test_batch_correction(self): self.obj.preprocess(subset=True, imputation="knn", normalization="quantile") self.obj.batch_correction(batch="batch_artifical_added") - first_value = self.obj.mat.values[0,0] - self.assertAlmostEqual(0.0111, first_value, places=2) + first_value = self.obj.mat.values[0, 0] + self.assertAlmostEqual(0.0111, first_value, places=2) def test_multicova_analysis_invalid_covariates(self): self.obj.preprocess(imputation="knn", normalization="zscore", subset=True) res, _ = self.obj.multicova_analysis( - covariates=["disease", "Alkaline phosphatase measurement", "Body mass index ", "not here"], + covariates=[ + "disease", + "Alkaline phosphatase measurement", + "Body mass index ", + "not here", + ], subset={"disease": ["healthy", "liver cirrhosis"]}, ) self.assertEqual(res.shape[1], 45) @@ -743,7 +762,7 @@ def test_multicova_analysis_invalid_covariates(self): # group2="liver cirrhosis", # gene_sets= 'KEGG_2019_Human') - # cholesterol_enhanced = 'Cholesterol metabolism' in df.index.to_list() + # cholersterol_enhanced = 'Cholesterol metabolism' in df.index.to_list() # self.assertTrue(cholersterol_enhanced) @@ -910,7 +929,8 @@ def tearDownClass(cls): shutil.rmtree("testfiles/spectronaut/__MACOSX") os.remove("testfiles/spectronaut/results.tsv") - + + class TestGenericDataSet(BaseTestDataSet.BaseTest): @classmethod def setUpClass(cls): @@ -922,12 +942,17 @@ def setUpClass(cls): cls.cls_loader = GenericLoader( file="testfiles/fragpipe/combined_proteins.tsv", intensity_column=[ - "S1 Razor Intensity", "S2 Razor Intensity", "S3 Razor Intensity", - "S4 Razor Intensity", "S5 Razor Intensity", "S6 Razor Intensity", - "S7 Razor Intensity", "S8 Razor Intensity" + "S1 Razor Intensity", + "S2 Razor Intensity", + "S3 Razor Intensity", + "S4 Razor Intensity", + "S5 Razor Intensity", + "S6 Razor Intensity", + "S7 Razor Intensity", + "S8 Razor Intensity", ], index_column="Protein", - sep="\t" + sep="\t", ) cls.cls_metadata_path = "testfiles/fragpipe/metadata2.xlsx" cls.cls_obj = DataSet( @@ -935,7 +960,7 @@ def setUpClass(cls): metadata_path=cls.cls_metadata_path, sample_column="analytical_sample external_id", ) - + def setUp(self): self.loader = copy.deepcopy(self.cls_loader) self.metadata_path = copy.deepcopy(self.cls_metadata_path) @@ -949,7 +974,6 @@ def tearDownClass(cls): if os.path.isdir("testfiles/fragpipe/__MACOSX"): shutil.rmtree("testfiles/fragpipe/__MACOSX") - if __name__ == "__main__": unittest.main()