Commit 195b1f1

First

EthanSteinberg committed Mar 5, 2024
1 parent 5e8db07 commit 195b1f1
Showing 5 changed files with 257 additions and 27 deletions.
pyproject.toml (3 changes: 2 additions & 1 deletion)

@@ -15,9 +15,10 @@ dependencies = [
"icecream == 2.1.3",
"nptyping == 2.4.1",
"msgpack >= 1.0.5",
"polars",
]
requires-python=">3.9"
version = "0.2.0"
version = "0.1.314"

[project.scripts]

src/femr/models/processor.py (17 changes: 15 additions & 2 deletions)

@@ -63,6 +63,11 @@ def start_batch(self, num_patients, max_length):
         self.offsets = np.zeros(num_patients, dtype=np.int32)
 
         self.tokens = np.zeros((num_patients, max_length), dtype=np.int32)
+
+        self.hier_tokens = []
+        self.hier_token_indices = [0]
+        self.hier_token_offsets = np.zeros((num_patients, max_length), dtype=np.int32)
+
         self.valid_tokens = np.zeros((num_patients, max_length), dtype=np.bool_)
 
         self.ages = np.zeros((num_patients, max_length), dtype=np.float32)
@@ -121,7 +126,12 @@ def add_patient(self, patient, offset):
                 break
 
             if self.tokenizer.is_hierarchical:
-                assert False  # TODO: Implement this
+                self.hier_token_offsets[self.patient_index, self.length_index - offset] = len(self.hier_token_indices) - 1
+                for t in features:
+                    self.hier_tokens.append(t)
+
+                self.hier_token_indices.append(len(self.hier_tokens))
+
             else:
                 self.tokens[self.patient_index, self.length_index - offset] = features[0]
 
@@ -145,13 +155,16 @@ def add_patient(self, patient, offset):
             num_added = self.task.add_event(last_time, None, None)
             for i in range(num_added):
                 self.label_indices.append(self.patient_index * self.max_length + self.length_index - offset - 1)
 
         self.patient_index += 1
 
     def get_batch_data(self):
         transformer = {
             "length": self.max_length,
             "tokens": self.tokens,
+            "hier_tokens": np.array(self.hier_tokens, dtype=np.int32),
+            "hier_token_indices": np.array(self.hier_token_indices, dtype=np.int32),
+            "hier_token_offsets": self.hier_token_offsets,
             "valid_tokens": self.valid_tokens,
             "ages": self.ages,
             "integer_ages": self.integer_ages,
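Note: the three new batch fields encode a ragged "bag of features per position" layout. The following standalone sketch is illustrative only (not part of the commit); the token values are made up:

import numpy as np

# One patient, two positions: position 0 produced features [3, 7], position 1 produced [5].
hier_tokens = np.array([3, 7, 5], dtype=np.int32)          # all features, concatenated
hier_token_indices = np.array([0, 2, 3], dtype=np.int32)   # bag i spans [indices[i], indices[i + 1])
hier_token_offsets = np.zeros((1, 2), dtype=np.int32)      # (num_patients, max_length) -> bag id
hier_token_offsets[0, 1] = 1

# Recover the features for patient 0, position 1:
bag = hier_token_offsets[0, 1]
assert hier_tokens[hier_token_indices[bag]:hier_token_indices[bag + 1]].tolist() == [5]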
src/femr/models/tokenizer.py (71 changes: 51 additions & 20 deletions)

@@ -1,6 +1,7 @@
 from __future__ import annotations
 
 import collections
+import pickle
 import datetime
 import functools
 import math
@@ -148,9 +149,8 @@ def convert_statistics_to_msgpack(statistics, vocab_size, is_hierarchical):


 class FEMRTokenizer(transformers.utils.PushToHubMixin):
-    def __init__(self, dictionary):
-        assert not dictionary["is_hierarchical"], "Currently not supported"
-
+    def __init__(self, dictionary, ontology=None):
+        self.ontology = ontology
         self.is_hierarchical = dictionary["is_hierarchical"]
 
         self.dictionary = dictionary
@@ -197,10 +197,20 @@ def from_pretrained(self, pretrained_model_name_or_path: Union[str, os.PathLike]

         with open(dictionary_file, "rb") as f:
             dictionary = msgpack.load(f)
+
+        ontology_file = transformers.utils.hub.cached_file(
+            pretrained_model_name_or_path, "ontology.pkl", **kwargs
+        )
+
+        if os.path.exists(ontology_file):
+            with open(ontology_file, 'rb') as f:
+                ontology = pickle.load(f)
+        else:
+            ontology = None
 
-        return FEMRTokenizer(dictionary)
+        return FEMRTokenizer(dictionary, ontology=ontology)
 
-    def save_pretrained(self, save_directory: Union[str, os.PathLike], push_to_hub: bool = False, **kwargs):
+    def save_pretrained(self, save_directory: Union[str, os.PathLike], push_to_hub: bool = False, save_ontology: bool = False, **kwargs):
         """
         Save the FEMR tokenizer.
@@ -232,6 +242,10 @@ def save_pretrained(self, save_directory: Union[str, os.PathLike], push_to_hub:
         with open(os.path.join(save_directory, "dictionary.msgpack"), "wb") as f:
             msgpack.dump(self.dictionary, f)
 
+        if save_ontology:
+            with open(os.path.join(save_directory, "ontology.pkl"), "wb") as f:
+                pickle.dump(self.ontology, f)
+
         if push_to_hub:
             self._upload_modified_files(
                 save_directory,
@@ -242,24 +256,41 @@ def save_pretrained(self, save_directory: Union[str, os.PathLike], push_to_hub:
             )
 
     def get_feature_codes(self, measurement):
-        if measurement.get("numeric_value") is not None:
-            for start, end, i in self.numeric_lookup.get(measurement["code"], []):
-                if start <= measurement["numeric_value"] < end:
-                    return [i]
-            else:
-                return []
-        elif measurement.get("text_value") is not None:
-            value = self.string_lookup.get((measurement["code"], measurement["text_value"]))
-            if value is not None:
-                return [value]
-            else:
-                return []
-        else:
-            value = self.code_lookup.get(measurement["code"])
-            if value is not None:
-                return [value]
-            else:
-                return []
+        if not self.is_hierarchical:
+            if measurement.get("numeric_value") is not None:
+                for start, end, i in self.numeric_lookup.get(measurement["code"], []):
+                    if start <= measurement["numeric_value"] < end:
+                        return [i]
+                else:
+                    return []
+            elif measurement.get("text_value") is not None:
+                value = self.string_lookup.get((measurement["code"], measurement["text_value"]))
+                if value is not None:
+                    return [value]
+                else:
+                    return []
+            else:
+                value = self.code_lookup.get(measurement["code"])
+                if value is not None:
+                    return [value]
+                else:
+                    return []
+        else:
+            result = []
+            if measurement.get("numeric_value") is not None:
+                for start, end, i in self.numeric_lookup.get(measurement["code"], []):
+                    if start <= measurement["numeric_value"] < end:
+                        result.append(i)
+            elif measurement.get("text_value") is not None:
+                value = self.string_lookup.get((measurement["code"], measurement["text_value"]))
+                if value is not None:
+                    result.append(value)
+                else:
+                    return []
+            for parent in self.ontology.get_all_parents(measurement['code']):
+                value = self.code_lookup.get(parent)
+                if value is not None:
+                    result.append(value)
+            return result
 
     def normalize_age(self, age):
         return (age - self.dictionary["age_stats"]["mean"]) / (self.dictionary["age_stats"]["std"])
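Note: in the hierarchical path, get_feature_codes now expands one measurement into one feature per ancestor concept, via Ontology.get_all_parents (which includes the code itself). A rough usage sketch; the dictionary, ontology, and code values here are illustrative, not from the commit:

# Assumes `dictionary` was built with is_hierarchical=True and `ontology` is a
# femr.ontology.Ontology whose parents_map links ICD10CM/E11.9 -> ICD10CM/E11.
tokenizer = FEMRTokenizer(dictionary, ontology=ontology)
codes = tokenizer.get_feature_codes({"code": "ICD10CM/E11.9"})
# `codes` now holds the vocabulary ids of E11.9 plus every ancestor in the vocabulary.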
src/femr/models/transformer.py (11 changes: 7 additions & 4 deletions)

@@ -137,13 +137,16 @@ def __init__(self, config: FEMRTransformerConfig):
         if not self.config.is_hierarchical:
             self.embed = nn.Embedding(self.config.vocab_size, self.config.hidden_size)
         else:
-            # Need to be using an embedding bag here
-            assert False
+            self.embed = nn.EmbeddingBag(self.config.vocab_size, self.config.hidden_size)
 
         self.layers = nn.ModuleList([FEMREncoderLayer(config) for _ in range(self.config.n_layers)])
 
     def forward(self, batch):
-        x = self.embed(batch["tokens"])
+        if not self.config.is_hierarchical:
+            x = self.embed(batch["tokens"])
+        else:
+            embedded = self.embed(batch["hier_tokens"], batch["hier_token_indices"])
+            x = embedded[batch['hier_token_offsets'], :]
 
         x = self.in_norm(x)
         pos_embed = fixed_pos_embedding(batch["ages"], self.config.hidden_size // self.config.n_heads, x.dtype)
@@ -169,7 +172,7 @@ def forward(self, features: torch.Tensor, batch: Mapping[str, torch.Tensor]):


 class CLMBRTaskHead(nn.Module):
-    def __init__(self, hidden_size: int, clmbr_vocab_size: int):
+    def __init__(self, hidden_size: int, clmbr_vocab_size: int, **kwargs):
         super().__init__()
 
         self.final_layer = nn.Linear(hidden_size, clmbr_vocab_size)
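Note: nn.EmbeddingBag takes a flat index tensor plus per-bag start offsets, pools each bag into a single vector (mean by default), and hier_token_offsets then gathers the pooled rows back into (patient, position) order. A standalone sketch with toy sizes (illustration only; processor.py builds hier_token_indices with a trailing total-length entry, which this sketch drops to match EmbeddingBag's offsets contract, PyTorch's include_last_offset=True being the variant that consumes it):

import torch
import torch.nn as nn

embed = nn.EmbeddingBag(num_embeddings=10, embedding_dim=4)  # default mode="mean"

hier_tokens = torch.tensor([3, 7, 5])         # bag 0 = {3, 7}, bag 1 = {5}
hier_token_indices = torch.tensor([0, 2, 3])  # bag starts, plus total length appended
pooled = embed(hier_tokens, hier_token_indices[:-1])  # shape (2, 4): one pooled row per bag

hier_token_offsets = torch.tensor([[0, 1]])   # (num_patients, max_length) bag ids
x = pooled[hier_token_offsets]                # shape (1, 2, 4): per-position embeddings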
src/femr/ontology.py (new file: 182 additions & 0 deletions)

@@ -0,0 +1,182 @@
from __future__ import annotations

import collections
import functools
import os
from typing import Any, Dict, Iterable, Optional, Set

import datasets
import polars as pl

import femr.hf_utils


def _get_all_codes_map(batch) -> Set[str]:
result = set()
for events in batch["events"]:
for event in events:
for measurement in event["measurements"]:
result.add(measurement["code"])
return result


def _get_all_codes_agg(first: Set[str], second: Set[str]) -> Set[str]:
first |= second
return first


class Ontology:
    def __init__(self, athena_path: str, code_metadata={}):
        """Create an Ontology from an Athena download and an optional meds Code Metadata structure.

        NOTE: This is an expensive operation.
        It is recommended to create an ontology once and then save/load it as necessary.
        """
# Load from code metadata
self.description_map: Dict[str, str] = {}
self.parents_map: Dict[str, Set[str]] = collections.defaultdict(set)

# Load from the athena path ...
concept = pl.scan_csv(os.path.join(athena_path, "CONCEPT.csv"), separator="\t", infer_schema_length=0)
code_col = pl.col("vocabulary_id") + "/" + pl.col("concept_code")
description_col = pl.col("concept_name")
concept_id_col = pl.col("concept_id").cast(pl.Int64)

processed_concepts = (
concept.select(code_col, concept_id_col, description_col, pl.col("standard_concept").is_null())
.collect()
.rows()
)

concept_id_to_code_map = {}

non_standard_concepts = set()

for code, concept_id, description, is_non_standard in processed_concepts:
concept_id_to_code_map[concept_id] = code

# We don't want to override code metadata
if code not in self.description_map:
self.description_map[code] = description

if is_non_standard:
non_standard_concepts.add(concept_id)

relationship = pl.scan_csv(
os.path.join(athena_path, "CONCEPT_RELATIONSHIP.csv"), separator="\t", infer_schema_length=0
)
relationship_id = pl.col("relationship_id")
relationship = relationship.filter(
relationship_id == "Maps to", pl.col("concept_id_1") != pl.col("concept_id_2")
)
for concept_id_1, concept_id_2 in (
relationship.select(pl.col("concept_id_1").cast(pl.Int64), pl.col("concept_id_2").cast(pl.Int64))
.collect()
.rows()
):
if concept_id_1 in non_standard_concepts:
if concept_id_1 in concept_id_to_code_map and concept_id_2 in concept_id_to_code_map:
self.parents_map[concept_id_to_code_map[concept_id_1]].add(concept_id_to_code_map[concept_id_2])

ancestor = pl.scan_csv(os.path.join(athena_path, "CONCEPT_ANCESTOR.csv"), separator="\t", infer_schema_length=0)
ancestor = ancestor.filter(pl.col("min_levels_of_separation") == "1")
for concept_id, parent_concept_id in (
ancestor.select(
pl.col("descendant_concept_id").cast(pl.Int64), pl.col("ancestor_concept_id").cast(pl.Int64)
)
.collect()
.rows()
):
if concept_id in concept_id_to_code_map and parent_concept_id in concept_id_to_code_map:
self.parents_map[concept_id_to_code_map[concept_id]].add(concept_id_to_code_map[parent_concept_id])

# Have to add after OMOP to overwrite ...
for code, code_info in code_metadata.items():
if code_info.get("description") is not None:
self.description_map[code] = code_info["description"]
if code_info.get("parent_codes") is not None:
self.parents_map[code] = set(code_info["parent_codes"])

self.children_map = collections.defaultdict(set)
for code, parents in self.parents_map.items():
for parent in parents:
self.children_map[parent].add(code)

self.all_parents_map: Dict[str, Set[str]] = {}
self.all_children_map: Dict[str, Set[str]] = {}

def prune_to_dataset(
self,
dataset: datasets.Dataset,
num_proc: int = 1,
prune_all_descriptions: bool = False,
remove_ontologies: Set[str] = set(),
) -> None:
valid_codes = femr.hf_utils.aggregate_over_dataset(
dataset,
functools.partial(_get_all_codes_map),
_get_all_codes_agg,
num_proc=num_proc,
batch_size=1_000,
)

if prune_all_descriptions:
self.description_map = {}

all_parents = set()

for code in valid_codes:
all_parents |= self.get_all_parents(code)

def is_valid(code):
ontology = code.split("/")[0]
return (code in valid_codes) or ((ontology not in remove_ontologies) and (code in all_parents))

codes = self.children_map.keys() | self.parents_map.keys() | self.description_map.keys()
for code in codes:
m: Any
if is_valid(code):
for m in (self.children_map, self.parents_map):
m[code] = {a for a in m[code] if is_valid(a)}
else:
for m in (self.children_map, self.parents_map, self.description_map):
if code in m:
del m[code]

self.all_parents_map = {}
self.all_children_map = {}

# Prime the pump
for code in self.children_map.keys() | self.parents_map.keys():
self.get_all_parents(code)

def get_description(self, code: str) -> Optional[str]:
"""Get a description of a code."""
return self.description_map.get(code)

def get_children(self, code: str) -> Iterable[str]:
"""Get the children for a given code."""
return self.children_map.get(code, set())

def get_parents(self, code: str) -> Iterable[str]:
"""Get the parents for a given code."""
return self.parents_map.get(code, set())

def get_all_children(self, code: str) -> Set[str]:
"""Get all children, including through the ontology."""
if code not in self.all_children_map:
result = {code}
for child in self.children_map.get(code, set()):
result |= self.get_all_children(child)
self.all_children_map[code] = result
return self.all_children_map[code]

def get_all_parents(self, code: str) -> Set[str]:
"""Get all parents, including through the ontology."""
if code not in self.all_parents_map:
result = {code}
for parent in self.parents_map.get(code, set()):
result |= self.get_all_parents(parent)
self.all_parents_map[code] = result

return self.all_parents_map[code]
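Note: get_all_parents computes the transitive closure over parents_map, memoizing results in all_parents_map, and always includes the queried code itself. A self-contained sketch of the same recursion (toy codes, illustration only):

import functools

parents_map = {
    "ICD10CM/E11.9": {"ICD10CM/E11"},
    "ICD10CM/E11": {"ICD10CM/E08-E13"},
}

@functools.lru_cache(maxsize=None)
def get_all_parents(code):
    result = {code}  # a code counts as its own ancestor, matching Ontology.get_all_parents
    for parent in parents_map.get(code, set()):
        result |= get_all_parents(parent)
    return result

assert get_all_parents("ICD10CM/E11.9") == {"ICD10CM/E11.9", "ICD10CM/E11", "ICD10CM/E08-E13"}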
