Skip to content

Commit

Permalink
ehrshot
Browse files Browse the repository at this point in the history
  • Loading branch information
Miking98 committed Mar 28, 2024
1 parent 0515ef4 commit 13bf8ab
Show file tree
Hide file tree
Showing 5 changed files with 1,262 additions and 105 deletions.
120 changes: 119 additions & 1 deletion src/femr/labelers/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,42 @@
import warnings
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import List, Optional, Tuple
from typing import Any, Callable, List, Optional, Tuple

import datasets
import meds

import femr.hf_utils
from femr.labelers.omop import identity
import femr.ontology

##########################################################
##########################################################
#
# Helper functions
#
##########################################################
##########################################################

def identity(x: Any) -> Any:
    """Return `x` unchanged.

    Serves as the no-op default wherever an optional transformation callable
    (e.g. a prediction-time adjustment function) is not supplied.
    """
    return x


def get_death_concepts() -> List[str]:
    """Return the codes that denote a patient's death.

    Currently this is a single-element list holding `meds.death_code`.
    A new list is built on every call, so callers may mutate the result.
    """
    return [meds.death_code]

def move_datetime_to_end_of_day(date: datetime.datetime) -> datetime.datetime:
    """Return `date` moved to 23:59:00 on the same calendar day.

    Args:
        date: The timestamp to adjust. Its year/month/day and tzinfo are kept.

    Returns:
        A new `datetime.datetime` at exactly 23:59:00.000000 on the same day.
    """
    # Microseconds must be cleared as well: otherwise two timestamps from the
    # same day map to different "end of day" values, which defeats the
    # equality-based de-duplication of prediction times.
    return date.replace(hour=23, minute=59, second=0, microsecond=0)

##########################################################
##########################################################
#
# Shared classes
#
##########################################################
##########################################################

@dataclass(frozen=True)
class TimeHorizon:
Expand Down Expand Up @@ -88,6 +117,9 @@ def apply(
num_proc=num_proc,
)

def get_patient_start_end_times(patient):
    """Return the datetimes that we consider the (start, end) of this patient.

    The start is the `time` of the patient's first event and the end is the
    `time` of the last event, in the order the events are stored.
    Raises `IndexError` if the patient has no events.
    """
    events = patient["events"]
    first_event = events[0]
    last_event = events[-1]
    return (first_event["time"], last_event["time"])

##########################################################
# Specific Labeler Superclasses
Expand Down Expand Up @@ -287,6 +319,91 @@ def label(self, patient: meds.Patient) -> List[meds.Label]:
return n_labels



class CodeLabeler(TimeHorizonEventLabeler):
    """Apply a label based on 1+ outcome_codes' occurrence(s) over a fixed time horizon."""

    def __init__(
        self,
        outcome_codes: List[str],
        time_horizon: TimeHorizon,
        prediction_codes: Optional[List[str]] = None,
        prediction_time_adjustment_func: Optional[Callable] = None,
    ):
        """Create a CodeLabeler, which labels events whose code is in `outcome_codes`.

        Args:
            outcome_codes (List[str]): Codes that count as an occurrence of the outcome.
            time_horizon (TimeHorizon): An interval of time. If the event occurs during this time horizon, then
                the label is TRUE. Otherwise, FALSE.
            prediction_codes (Optional[List[str]]): If not None, limit events at which you make predictions to
                only events carrying one of these codes.
            prediction_time_adjustment_func (Optional[Callable]): A function that takes in a `datetime.datetime`
                and returns a different `datetime.datetime`. Defaults to the identity function.
        """
        self.outcome_codes: List[str] = outcome_codes
        self.time_horizon: TimeHorizon = time_horizon
        self.prediction_codes: Optional[List[str]] = prediction_codes
        self.prediction_time_adjustment_func: Callable = (
            prediction_time_adjustment_func if prediction_time_adjustment_func is not None else identity  # type: ignore
        )

    @staticmethod
    def _event_codes(event: Any) -> List[str]:
        """Return every code recorded on a dict-shaped `event`.

        MEDS events keep their codes on a nested `measurements` list; an event
        that instead exposes a flat `code` key is accepted as well.
        NOTE(review): nested layout taken from the MEDS schema -- confirm against
        the `meds` version pinned by this project.
        """
        measurements = event.get("measurements")
        if measurements is not None:
            return [measurement["code"] for measurement in measurements]
        return [event["code"]] if "code" in event else []

    def get_prediction_times(self, patient: meds.Patient) -> List[datetime.datetime]:
        """Return each event's time (possibly modified by prediction_time_adjustment_func)
        as the time to make a prediction. Default to all events; if `self.prediction_codes`
        is set, only events carrying one of those codes are used.

        Consecutive events that map to the same adjusted time yield a single prediction time.
        """
        # Build the lookup set once per call so each membership test is O(1).
        allowed_codes = None if self.prediction_codes is None else set(self.prediction_codes)
        times: List[datetime.datetime] = []
        last_time = None
        # Patients and events are dicts (TypedDicts), so use key access throughout.
        for event in patient["events"]:
            prediction_time: datetime.datetime = self.prediction_time_adjustment_func(event["time"])
            if prediction_time == last_time:
                # Same adjusted time as the previously emitted prediction: skip the duplicate.
                continue
            if allowed_codes is None or not allowed_codes.isdisjoint(self._event_codes(event)):
                times.append(prediction_time)
                last_time = prediction_time
        return times

    def get_time_horizon(self) -> TimeHorizon:
        """Return the time horizon over which the outcome is evaluated."""
        return self.time_horizon

    def get_outcome_times(self, patient: meds.Patient) -> List[datetime.datetime]:
        """Return the times of this patient's events that carry a code in `self.outcome_codes`.

        Each matching event contributes its time once, in event order.
        """
        outcome_codes = set(self.outcome_codes)
        return [
            event["time"]
            for event in patient["events"]
            if not outcome_codes.isdisjoint(self._event_codes(event))
        ]

    def allow_same_time_labels(self) -> bool:
        """Return False: predictions may not share a timestamp with the outcome code."""
        # We cannot allow labels at the same time as the codes since they will generally be available as features ...
        return False


class OMOPConceptCodeLabeler(CodeLabeler):
    """A CodeLabeler whose outcome codes are derived from OMOP concept codes.

    Subclasses list their parent concepts in `original_omop_concept_codes`; the
    constructor asks the ontology for all children of those concepts and uses
    the result as the labeler's outcome codes.
    """

    # Parent OMOP concept codes from which all outcome codes are derived
    # (as children in our ontology). Subclasses are expected to override this.
    original_omop_concept_codes: List[str] = []

    def __init__(
        self,
        ontology: femr.ontology.Ontology,
        time_horizon: TimeHorizon,
        prediction_codes: Optional[List[str]] = None,
        prediction_time_adjustment_func: Optional[Callable] = None,
    ):
        """Expand the class's OMOP concepts through `ontology` and build the underlying CodeLabeler.

        Args:
            ontology: Ontology used to look up the children of `original_omop_concept_codes`.
            time_horizon: Interval of time over which the outcome is evaluated.
            prediction_codes: If not None, restrict prediction times to events with these codes.
            prediction_time_adjustment_func: Optional function applied to each prediction time;
                falls back to `identity` when not provided.
        """
        expanded_codes: List[str] = ontology.get_all_children(self.original_omop_concept_codes)

        # Resolve the fallback up front so the parent always receives a callable.
        if prediction_time_adjustment_func:
            adjustment_func = prediction_time_adjustment_func
        else:
            adjustment_func = identity

        super().__init__(
            outcome_codes=expanded_codes,
            time_horizon=time_horizon,
            prediction_codes=prediction_codes,
            prediction_time_adjustment_func=adjustment_func,
        )


def compute_random_num(seed: int, num_1: int, num_2: int, modulus: int = 100):
network_num_1 = struct.pack("!q", num_1)
network_num_2 = struct.pack("!q", num_2)
Expand All @@ -303,3 +420,4 @@ def compute_random_num(seed: int, num_1: int, num_2: int, modulus: int = 100):
result = (result * 256 + hash_value[i]) % modulus

return result

Loading

0 comments on commit 13bf8ab

Please sign in to comment.