1
1
from abc import ABC
2
+ from typing import Dict
3
+ from datastew .embedding import EmbeddingModel , MPNetAdapter
2
4
3
5
import pandas as pd
4
6
import numpy as np
@@ -52,16 +54,34 @@ def to_dataframe(self) -> pd.DataFrame:
52
54
53
55
54
56
class DataDictionarySource (Source ):
55
- """
56
- Contains mapping of variable -> description
57
- """
58
57
59
58
def __init__ (self , file_path : str , variable_field : str , description_field : str ):
60
- self .file_path = file_path
61
- self .variable_field = variable_field
62
- self .description_field = description_field
63
-
64
- def to_dataframe (self ) -> pd .DataFrame :
59
+ """
60
+ Initialize the DataDictionarySource with the path to the data dictionary file
61
+ and the fields that represent the variables and their descriptions.
62
+
63
+ :param file_path: Path to the data dictionary file.
64
+ :param variable_field: The column that contains the variable names.
65
+ :param description_field: The column that contains the variable descriptions.
66
+ """
67
+ self .file_path : str = file_path
68
+ self .variable_field : str = variable_field
69
+ self .description_field : str = description_field
70
+
71
+ def to_dataframe (self , dropna : bool = True ) -> pd .DataFrame :
72
+ """
73
+ Load the data dictionary file into a pandas DataFrame, select the variable and
74
+ description fields, and ensure they exist. Optionally remove rows with missing
75
+ variables or descriptions based on the 'dropna' parameter.
76
+
77
+ :param dropna: If True, rows with missing 'variable' or 'description' values are
78
+ dropped. Defaults to True.
79
+ :return: A DataFrame containing two columns:
80
+ - 'variable': The variable names from the data dictionary.
81
+ - 'description': The descriptions corresponding to each variable.
82
+ :raises ValueError: If either the variable field or the description field is not
83
+ found in the data dictionary file.
84
+ """
65
85
df = super ().to_dataframe ()
66
86
# sanity check
67
87
if self .variable_field not in df .columns :
@@ -70,9 +90,32 @@ def to_dataframe(self) -> pd.DataFrame:
70
90
raise ValueError (f"Description field { self .description_field } not found in { self .file_path } " )
71
91
df = df [[self .variable_field , self .description_field ]]
72
92
df = df .rename (columns = {self .variable_field : "variable" , self .description_field : "description" })
73
- df .dropna (subset = ["variable" , "description" ], inplace = True )
93
+ if dropna :
94
+ df .dropna (subset = ["variable" , "description" ], inplace = True )
74
95
return df
75
-
96
+
97
+ def get_embeddings (self , embedding_model : EmbeddingModel = None ) -> Dict [str , list ]:
98
+ """
99
+ Compute embedding vectors for each description in the data dictionary. The
100
+ resulting vectors are mapped to their respective variables and returned as a
101
+ dictionary.
102
+
103
+ :param embedding_model: The embedding model used to compute embeddings for the descriptions.
104
+ Defaults to MPNetAdapter.
105
+ :return: A dictionary where each key is a variable name and the value is the
106
+ embedding vector for the corresponding description.
107
+ :rtype: Dict[str, list]
108
+ """
109
+ # Compute vectors for all descriptions
110
+ df : pd .DataFrame = self .to_dataframe ()
111
+ descriptions : list [str ] = df ["description" ].tolist ()
112
+ if embedding_model is None :
113
+ embedding_model = MPNetAdapter ()
114
+ embeddings = embedding_model .get_embeddings (descriptions )
115
+ # variable identify descriptions -> variable to embedding
116
+ variable_to_embedding : Dict [str , list ] = dict (zip (df ["variable" ], embeddings ))
117
+ return variable_to_embedding
118
+
76
119
77
120
class EmbeddingSource :
78
121
def __init__ (self , source_path : str ):
0 commit comments