From 48369cd75ed7a8f66e7814f29f88b0b93fcae1b3 Mon Sep 17 00:00:00 2001 From: vhabhsgieraa Date: Tue, 20 Feb 2024 14:13:34 -0500 Subject: [PATCH] fix: Add serialization for UmlsMatch --- quickumls/core.py | 3 +-- quickumls/spacy_component.py | 14 ++++++++------ quickumls/toolbox.py | 2 +- quickumls/umls_match.py | 36 +++++++++++++++++++++++++++++++++--- requirements.txt | 1 + 5 files changed, 44 insertions(+), 12 deletions(-) diff --git a/quickumls/core.py b/quickumls/core.py index 2bcff69..0135277 100644 --- a/quickumls/core.py +++ b/quickumls/core.py @@ -316,12 +316,11 @@ def _get_all_matches(self, ngrams): if not self.to_lowercase_flag and ngram_normalized.isupper() and not self.keep_uppercase: ngram_normalized = ngram_normalized.lower() - prev_cui = None ngram_cands = list(self.ss_db.get(ngram_normalized)) ngram_dict = {} for match in ngram_cands: - cuisem_match = sorted(self.cuisem_db.get(match)) + cuisem_match = self.cuisem_db.get(match) match_similarity = toolbox.get_similarity( x=ngram_normalized, diff --git a/quickumls/spacy_component.py b/quickumls/spacy_component.py index 32747bd..1ece0b6 100644 --- a/quickumls/spacy_component.py +++ b/quickumls/spacy_component.py @@ -114,13 +114,13 @@ def __init__(self, nlp, name = "medspacy_quickumls", quickumls_fp=None, # umls_matches below which contains more information and enables overlapping if not Span.has_extension("similarity"): Span.set_extension('similarity', default = -1.0) - if not Span.has_extension("semtypes"): + if not Span.has_extension("semtypes"): Span.set_extension('semtypes', default = -1.0) # match objects are a set, since span objects with the same start/end keys # would have the same values for custom attributes in spacy if not Span.has_extension("umls_matches"): - Span.set_extension('umls_matches', default=set()) + Span.set_extension('umls_matches', default=[]) @property def result_type(self) -> str: @@ -205,11 +205,13 @@ def __call__(self, doc): span._.semtypes = ngram_match_dict['semtypes'] # let's create this more fully featured match object - umls_match = UmlsMatch(cui, - ngram_match_dict['semtypes'], - ngram_match_dict['similarity']) + umls_match = UmlsMatch( + cui, + ngram_match_dict['semtypes'], + ngram_match_dict['similarity'], + ) - span._.umls_matches.add(umls_match) + span._.umls_matches.append(umls_match) if self.result_type.lower() == "ents": doc.ents = list(doc.ents) + [span] diff --git a/quickumls/toolbox.py b/quickumls/toolbox.py index d44937a..ddcf1ca 100644 --- a/quickumls/toolbox.py +++ b/quickumls/toolbox.py @@ -283,7 +283,7 @@ def get(self, term): matches = ( ( cui, - pickle.loads(self.semtypes_db_get(db_key_encode(cui))), + list(pickle.loads(self.semtypes_db_get(db_key_encode(cui)))), is_preferred ) for cui, is_preferred in cuis diff --git a/quickumls/umls_match.py b/quickumls/umls_match.py index ca52f73..7816f9d 100644 --- a/quickumls/umls_match.py +++ b/quickumls/umls_match.py @@ -1,11 +1,12 @@ -from typing import Set +from typing import Any, Dict, List +import srsly class UmlsMatch: def __init__(self, cui: str, - semtypes: Set[str], + semtypes: List[str], similarity: float): """Instantiate UmlsMatch object @@ -15,10 +16,39 @@ def __init__(self, Args: cui: UMLS controlled unique identifier (CUI) value (e.g., "C0243095") - semtypes (Set[str]): List of UMLS semantic types as Type Unique Identifier values (TUI) + semtypes (List[str]): List of UMLS semantic types as Type Unique Identifier values (TUI) for this matched concept (e.g., "T203") similarity (float): Similarity score between match and UMLS concept """ self.cui = cui self.semtypes = semtypes self.similarity = similarity + + def __repr__(self): + return f"UmlsMatch({str(self.cui), str(self.semtypes), str(self.similarity)})" + + def serialized_representation(self) -> Dict[str, Any]: + """ + Returns the serialized representation of the UmlsMatch + """ + return self.__dict__ + + @classmethod + def from_serialized_representation(cls, serialized_representation): + """ + Creates the UmlsMatch from the serialized representation + """ + return UmlsMatch(**serialized_representation) + +@srsly.msgpack_encoders("umls_match") +def serialize_context_graph(obj, chain=None): + if isinstance(obj, UmlsMatch): + return {"umls_match": obj.serialized_representation()} + return obj if chain is None else chain(obj) + + +@srsly.msgpack_decoders("umls_match") +def deserialize_context_graph(obj, chain=None): + if "umls_match" in obj: + return UmlsMatch.from_serialized_representation(obj["umls_match"]) + return obj if chain is None else chain(obj) diff --git a/requirements.txt b/requirements.txt index ad986a7..c0a4ff6 100644 --- a/requirements.txt +++ b/requirements.txt @@ -5,4 +5,5 @@ nltk>=3.3 medspacy_simstring>=2.1 unqlite>=0.8.1 pytest>=6 +srsly>=2.4.8 six \ No newline at end of file