From 7a5c0474181efd68e139b4037b23ecc443091b4b Mon Sep 17 00:00:00 2001
From: Tim Adams
Date: Mon, 21 Oct 2024 11:35:45 +0200
Subject: [PATCH] Add function, tests & documentation to visualize vector
 embeddings for data dictionaries

---
 README.md                      | 14 ++++++++
 datastew/visualisation.py      | 63 ++++++++++++++++++++++++++++++++--
 tests/resources/data_dict1.csv | 51 +++++++++++++++++++++++++++
 tests/resources/data_dict2.csv | 51 +++++++++++++++++++++++++++
 tests/test_visualisation.py    | 10 +++++-
 5 files changed, 186 insertions(+), 3 deletions(-)
 create mode 100644 tests/resources/data_dict1.csv
 create mode 100644 tests/resources/data_dict2.csv

diff --git a/README.md b/README.md
index f6903bc..028ac20 100644
--- a/README.md
+++ b/README.md
@@ -44,6 +44,20 @@ embedding_model = GPT4Adapter(key="your_api_key")
 df = map_dictionary_to_dictionary(source, target, embedding_model=embedding_model)
 ```
 
+You can also retrieve embeddings from data dictionaries and visualize them as an interactive scatter plot to
+explore semantic neighborhoods:
+
+```python
+from datastew.visualisation import plot_embeddings
+
+# Get embedding vectors for your dictionaries
+source_embeddings = source.get_embeddings()
+
+# Plot embedding neighborhoods for several dictionaries
+plot_embeddings(data_dictionaries=[source, target])
+
+```
+
 ### Creating and using stored mappings
 
 A simple example how to initialize an in memory database and compute a similarity mapping is shown in
diff --git a/datastew/visualisation.py b/datastew/visualisation.py
index e07f7d1..7b8f0e3 100644
--- a/datastew/visualisation.py
+++ b/datastew/visualisation.py
@@ -8,8 +8,10 @@
 import plotly.graph_objects as go
 import seaborn as sns
 from sklearn.manifold import TSNE
-from typing import Optional
+from typing import Optional, List
 
+from datastew.process.parsing import DataDictionarySource
+from datastew.embedding import EmbeddingModel, MPNetAdapter
 from datastew.conf import COLORS_AD, COLORS_PD
 from datastew.mapping import MappingTable
 from datastew.repository.base import BaseRepository
@@ -177,7 +179,8 @@ def scatter_plot_all_cohorts(tables1: [MappingTable], tables2: [MappingTable], l
     fig.show()
 
 
-def get_plot_for_current_database_state(repository: BaseRepository, terminology: Optional[str] = None, perplexity: int = 5, return_type="html") -> str:
+def get_plot_for_current_database_state(repository: BaseRepository, terminology: Optional[str] = None,
+                                        perplexity: int = 5, return_type="html") -> str:
     if not terminology:
         mappings = repository.get_mappings()
     else:
@@ -218,3 +221,59 @@ def get_plot_for_current_database_state(repository: BaseRepository, terminology:
     else:
         plot = "Too few database entries to visualize"
     return plot
+
+
+def plot_embeddings(data_dictionaries: List[DataDictionarySource], embedding_model: Optional[EmbeddingModel] = None,
+                    perplexity: int = 5):
+    """
+    Plots a t-SNE representation of embeddings from multiple data dictionaries and displays the plot.
+
+    :param data_dictionaries: A list of DataDictionarySource objects to extract embeddings from.
+    :param embedding_model: The embedding model used to compute embeddings. Defaults to MPNetAdapter.
+    :param perplexity: The perplexity for the t-SNE algorithm. Higher values give more global structure.
+    """
+    if embedding_model is None:
+        embedding_model = MPNetAdapter()
+    all_embeddings = []
+    all_texts = []
+    all_colors = []
+    plotly_colors = px.colors.qualitative.Plotly
+    for idx, dictionary in enumerate(data_dictionaries):
+        embeddings_dict = dictionary.get_embeddings(embedding_model=embedding_model)
+        embeddings = list(embeddings_dict.values())
+        texts = dictionary.to_dataframe()['description']
+        color = plotly_colors[idx % len(plotly_colors)]
+        all_embeddings.extend(embeddings)
+        all_texts.extend(texts)
+        all_colors.extend([color] * len(embeddings))
+    embeddings_array = np.array(all_embeddings)
+    # Adjust perplexity if there are enough points
+    if embeddings_array.shape[0] > 30:
+        perplexity = min(perplexity, 30)
+    if embeddings_array.shape[0] > perplexity:
+        # Compute t-SNE embeddings
+        tsne_embeddings = TSNE(n_components=2, perplexity=perplexity).fit_transform(embeddings_array)
+        # Create Plotly scatter plot
+        scatter_plot = go.Scatter(
+            x=tsne_embeddings[:, 0],
+            y=tsne_embeddings[:, 1],
+            mode="markers",
+            marker=dict(
+                size=8,
+                color=all_colors,  # Use the assigned colors from Plotly palette
+                opacity=0.7
+            ),
+            text=all_texts,
+            hoverinfo="text"
+        )
+        layout = go.Layout(
+            title="t-SNE Embeddings of Data Dictionaries",
+            xaxis=dict(title="t-SNE Component 1"),
+            yaxis=dict(title="t-SNE Component 2"),
+        )
+        fig = go.Figure(data=[scatter_plot], layout=layout)
+        # Display the plot
+        fig.show()
+    else:
+        print("Too few data dictionary entries to visualize. Adjust param 'perplexity' to a value less than the "
+              "number of data points.")
diff --git a/tests/resources/data_dict1.csv b/tests/resources/data_dict1.csv
new file mode 100644
index 0000000..a39c552
--- /dev/null
+++ b/tests/resources/data_dict1.csv
@@ -0,0 +1,51 @@
+VAR,DESC
+VAR_1,VAR_1 amounts measured under controlled conditions.
+VAR_2,VAR_2 data collected on a monthly basis.
+VAR_3,Summarizes VAR_3 occurrences during the experiment.
+VAR_4,Summarizes VAR_4 occurrences during the experiment.
+VAR_5,This is the index value for VAR_5 activity.
+VAR_6,Tracks fluctuations in VAR_6 over time.
+VAR_7,This is the index value for VAR_7 activity.
+VAR_8,It records the count of VAR_8 for each observed period.
+VAR_9,VAR_9 data collected on a monthly basis.
+VAR_10,Summarizes VAR_10 occurrences during the experiment.
+VAR_11,VAR_11 amounts measured under controlled conditions.
+VAR_12,It records the count of VAR_12 for each observed period.
+VAR_13,Summarizes VAR_13 occurrences during the experiment.
+VAR_14,Summarizes VAR_14 occurrences during the experiment.
+VAR_15,VAR_15 data collected on a monthly basis.
+VAR_16,VAR_16 amounts measured under controlled conditions.
+VAR_17,It records the count of VAR_17 for each observed period.
+VAR_18,VAR_18 data collected on a monthly basis.
+VAR_19,Tracks fluctuations in VAR_19 over time.
+VAR_20,An indicator representing the levels of VAR_20 in a sample.
+VAR_21,Variable associated with the change in VAR_21 throughout the day.
+VAR_22,Variable associated with the change in VAR_22 throughout the day.
+VAR_23,This variable captures the daily average VAR_23 readings.
+VAR_24,This is the index value for VAR_24 activity.
+VAR_25,Tracks fluctuations in VAR_25 over time.
+VAR_26,Tracks fluctuations in VAR_26 over time.
+VAR_27,Variable associated with the change in VAR_27 throughout the day.
+VAR_28,VAR_28 measurement gathered every quarter.
+VAR_29,VAR_29 measurement gathered every quarter.
+VAR_30,VAR_30 data collected on a monthly basis.
+VAR_31,It records the count of VAR_31 for each observed period. +VAR_32,Summarizes VAR_32 occurrences during the experiment. +VAR_33,Variable associated with the change in VAR_33 throughout the day. +VAR_34,VAR_34 amounts measured under controlled conditions. +VAR_35,It records the count of VAR_35 for each observed period. +VAR_36,VAR_36 amounts measured under controlled conditions. +VAR_37,An indicator representing the levels of VAR_37 in a sample. +VAR_38,Summarizes VAR_38 occurrences during the experiment. +VAR_39,This variable captures the daily average VAR_39 readings. +VAR_40,Variable associated with the change in VAR_40 throughout the day. +VAR_41,VAR_41 amounts measured under controlled conditions. +VAR_42,Variable associated with the change in VAR_42 throughout the day. +VAR_43,Variable associated with the change in VAR_43 throughout the day. +VAR_44,This variable captures the daily average VAR_44 readings. +VAR_45,Summarizes VAR_45 occurrences during the experiment. +VAR_46,Variable associated with the change in VAR_46 throughout the day. +VAR_47,Variable associated with the change in VAR_47 throughout the day. +VAR_48,An indicator representing the levels of VAR_48 in a sample. +VAR_49,VAR_49 amounts measured under controlled conditions. +VAR_50,VAR_50 amounts measured under controlled conditions. diff --git a/tests/resources/data_dict2.csv b/tests/resources/data_dict2.csv new file mode 100644 index 0000000..f906e6d --- /dev/null +++ b/tests/resources/data_dict2.csv @@ -0,0 +1,51 @@ +VAR,DESC +VAR_1,Measurement of VAR_1 on a daily timescale. +VAR_2,Measurement of VAR_2 on a daily timescale. +VAR_3,Monthly statistics for VAR_3 collection. +VAR_4,Variable linked to fluctuations in VAR_4 across different times. +VAR_5,Variable linked to fluctuations in VAR_5 across different times. +VAR_6,Captures the total number of VAR_6 per session. +VAR_7,Tracks variations in VAR_7 over multiple intervals. +VAR_8,Measurement of VAR_8 on a daily timescale. +VAR_9,Captures the total number of VAR_9 per session. +VAR_10,Measurement of VAR_10 on a daily timescale. +VAR_11,Monthly statistics for VAR_11 collection. +VAR_12,Captures the total number of VAR_12 per session. +VAR_13,The count of VAR_13 instances recorded throughout the study. +VAR_14,Quarterly data focused on VAR_14 trends. +VAR_15,Monthly statistics for VAR_15 collection. +VAR_16,VAR_16 index reflecting seasonal variation. +VAR_17,Measurement of VAR_17 on a daily timescale. +VAR_18,Variable linked to fluctuations in VAR_18 across different times. +VAR_19,Captures the total number of VAR_19 per session. +VAR_20,The count of VAR_20 instances recorded throughout the study. +VAR_21,Captures the total number of VAR_21 per session. +VAR_22,An estimate of the concentration of VAR_22 in specific regions. +VAR_23,Captures the total number of VAR_23 per session. +VAR_24,The count of VAR_24 instances recorded throughout the study. +VAR_25,Quarterly data focused on VAR_25 trends. +VAR_26,VAR_26 index reflecting seasonal variation. +VAR_27,Variable linked to fluctuations in VAR_27 across different times. +VAR_28,Variable linked to fluctuations in VAR_28 across different times. +VAR_29,Tracks variations in VAR_29 over multiple intervals. +VAR_30,VAR_30 values measured under precise situations. +VAR_31,Tracks variations in VAR_31 over multiple intervals. +VAR_32,Monthly statistics for VAR_32 collection. +VAR_33,VAR_33 values measured under precise situations. +VAR_34,Tracks variations in VAR_34 over multiple intervals. 
+VAR_35,VAR_35 values measured under precise situations. +VAR_36,Tracks variations in VAR_36 over multiple intervals. +VAR_37,An estimate of the concentration of VAR_37 in specific regions. +VAR_38,VAR_38 index reflecting seasonal variation. +VAR_39,VAR_39 index reflecting seasonal variation. +VAR_40,Variable linked to fluctuations in VAR_40 across different times. +VAR_41,Quarterly data focused on VAR_41 trends. +VAR_42,VAR_42 values measured under precise situations. +VAR_43,An estimate of the concentration of VAR_43 in specific regions. +VAR_44,Variable linked to fluctuations in VAR_44 across different times. +VAR_45,The count of VAR_45 instances recorded throughout the study. +VAR_46,Monthly statistics for VAR_46 collection. +VAR_47,Monthly statistics for VAR_47 collection. +VAR_48,An estimate of the concentration of VAR_48 in specific regions. +VAR_49,Tracks variations in VAR_49 over multiple intervals. +VAR_50,Variable linked to fluctuations in VAR_50 across different times. diff --git a/tests/test_visualisation.py b/tests/test_visualisation.py index ea7e663..1ee0b41 100644 --- a/tests/test_visualisation.py +++ b/tests/test_visualisation.py @@ -10,7 +10,7 @@ bar_chart_average_acc_two_distributions, enrichment_plot, scatter_plot_all_cohorts, - scatter_plot_two_distributions, + scatter_plot_two_distributions, plot_embeddings, ) @@ -127,3 +127,11 @@ def test_bar_chart_average_acc_two_distributions(self): mpnet_2 = pd.DataFrame({"M1": [0.9, 0.78, 0.68], "M2": [0.69, 0.9, 0.68], "M3": [0.71, 0.75, 0.9]}, index=labels).T bar_chart_average_acc_two_distributions(fuzzy_1, gpt_1, mpnet_1, fuzzy_2, gpt_2, mpnet_2, "title", "AD", "PD") + + def test_plot_data_dict(self): + TEST_DIR_PATH = os.path.dirname(os.path.realpath(__file__)) + data_dictionary_source_1 = DataDictionarySource(os.path.join(TEST_DIR_PATH, "resources", "data_dict1.csv"), + "VAR", "DESC") + data_dictionary_source_2 = DataDictionarySource(os.path.join(TEST_DIR_PATH, "resources", "data_dict2.csv"), + "VAR", "DESC") + plot_embeddings([data_dictionary_source_1, data_dictionary_source_2]) \ No newline at end of file