Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
130 changes: 102 additions & 28 deletions cardea/problem_definition/definition.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,31 @@


class ProblemDefinition:
"""A class that defines the prediction problem
by specifying cutoff times and generating the target label if it does not exist.
"""Base class that defines a prediction problem.

Attributes:
target_label_column_name: The target label of the prediction problem.
target_entity: Name of the entity containing the target label.
cutoff_time_label: The cutoff time label of the prediction problem.
cutoff_entity: Name of the entity containing the cutoff time label.
prediction_type: The type of the machine learning prediction.
"""

def __init__(self, target_label_column_name,
target_entity, cutoff_time_label,
cutoff_entity, prediction_type,
updated_es=None, conn=None):

self.target_label_column_name = target_label_column_name
self.target_entity = target_entity
self.cutoff_time_label = cutoff_time_label
self.cutoff_entity = cutoff_entity
self.prediction_type = prediction_type

# optionals
self.conn = conn
self.updated_es = updated_es

def check_target_label(self, entity_set, target_entity, target_label):
"""Checks if target label exists in the entity set.

Expand Down Expand Up @@ -49,11 +70,12 @@ def generate_target_label(self, entity_set, target_entity, target_label):
Target entity with the generated label.
"""

def generate_cutoff_times(self, entity_set):
def generate_cutoff_times(self, entity_set,
cutoff_time_unifier='unify_cutoff_time_admission_time'):
"""Generates cutoff times for the predection problem.

Args:
entity_set: fhir entityset.
entity_set: the FHIR entityset.

Returns:
entity_set, target_entity, series of target_labels and a dataframe of cutoff_times.
Expand All @@ -62,6 +84,59 @@ def generate_cutoff_times(self, entity_set):
ValueError: An error occurs if the cutoff variable does not exist.
"""

loader = DataLoader()

target_label_exists = loader.check_column_existence(
entity_set, self.target_entity, self.target_label_column_name
)

target_label_has_missing_values = loader.check_for_missing_values(
entity_set, self.target_entity, self.target_label_column_name
)

if target_label_exists and not target_label_has_missing_values:
cutoff_time_label_exists = loader.check_column_existence(
entity_set, self.cutoff_entity, self.cutoff_time_label
)

if not cutoff_time_label_exists:
raise ValueError(
'Cutoff time label {} does not exist in table {}'.format(
self.cutoff_time_label,
self.cutoff_entity
)
)

cutoff_time_unifier_func = getattr(self, cutoff_time_unifier)
generated_cts = cutoff_time_unifier_func(
entity_set, self.cutoff_entity, self.cutoff_time_label
)

# new entity set
es = entity_set.entity_from_dataframe(
entity_id=self.cutoff_entity, dataframe=generated_cts, index='object_id'
)

label = es[self.target_entity].df[self.conn].values

instance_id = list(es[self.target_entity].df.index)

# get cutoff_times
cutoff_times = es[self.cutoff_entity].df['ct'].to_frame()
cutoff_times = cutoff_times.reindex(index=label)
cutoff_times = cutoff_times[cutoff_times.index.isin(label)]
cutoff_times['instance_id'] = instance_id
cutoff_times.columns = ['cutoff_time', 'instance_id']
cutoff_times['label'] = list(es[self.target_entity].df[self.target_label_column_name])

return (es, self.target_entity, cutoff_times)

# get a new entity set
self.updated_es = self.generate_target_label(entity_set)

# recursive call
return self.generate_cutoff_times(self.updated_es)

def unify_cutoff_times_hours_admission_time(self, df, cutoff_time_label):
"""Unify records cutoff times based on shared time.

Expand All @@ -75,25 +150,24 @@ def unify_cutoff_times_hours_admission_time(self, df, cutoff_time_label):

if i == 0:

if df.get_value(i, 'checked') is not True:
df.set_value(i, 'ct', df.get_value(i, cutoff_time_label))
df.set_value(i, 'checked', True)
if df.at[i, 'checked'] is not True:
df.at[i, 'ct'] = df.at[i, cutoff_time_label]
df.at[i, 'checked'] = True

elif df.get_value(i, 'checked') is not True:
elif df.at[i, 'checked'] is not True:

ct_val1 = df.get_value(i - 1, 'ct')
end_val1 = df.get_value(i - 1, 'end')
start_val2 = df.get_value(i, cutoff_time_label)
df.get_value(i, 'end')
ct_val1 = df.at[i - 1, 'ct']
end_val1 = df.at[i - 1, 'end']
start_val2 = df.at[i, cutoff_time_label]
df.at[i, 'end']

if ct_val1 < start_val2 < end_val1:
df.set_value(i - 1, 'ct', start_val2)
df.set_value(i, 'ct', start_val2)
df.set_value(i, 'checked', True)

df.at[i - 1, 'ct'] = start_val2
df.at[i, 'ct'] = start_val2
df.at[i, 'checked'] = True
else:
df.set_value(i, 'ct', df.get_value(i, cutoff_time_label))
df.set_value(i, 'checked', True)
df.at[i, 'ct'] = df.at[i, cutoff_time_label]
df.at[i, 'checked'] = True

if i + 1 == len(df):
break
Expand All @@ -118,14 +192,14 @@ def unify_cutoff_times_days_admission_time(self, df, cutoff_time_label):
final_date = sub_duration_greater.iloc[-1][cutoff_time_label]

for i in sub_duration_greater.index:
sub_duration_greater.set_value(i, 'ct', final_date)
sub_duration_greater.set_value(i, 'checked', True)
sub_duration_greater.at[i, 'ct'] = final_date
sub_duration_greater.at[i, 'checked'] = True

frames.append(sub_duration_greater)

for i in sub_duration_less.index:
sub_duration_less.set_value(i, 'ct', pd.NaT)
sub_duration_less.set_value(i, 'checked', False)
sub_duration_less.at[i, 'ct'] = pd.NaT
sub_duration_less.at[i, 'checked'] = False

frames.append(sub_duration_less)

Expand Down Expand Up @@ -181,13 +255,13 @@ def unify_cutoff_times_days_discharge_time(self, df, cutoff_time_label):
first_date = sub_duration_greater.iloc[0][cutoff_time_label]

for i in sub_duration_greater.index:
sub_duration_greater.set_value(i, 'ct', first_date)
sub_duration_greater.set_value(i, 'checked', True)
sub_duration_greater.at[i, 'ct'] = first_date
sub_duration_greater.at[i, 'checked'] = True
frames.append(sub_duration_greater)

for i in sub_duration_less.index:
sub_duration_less.set_value(i, 'ct', pd.NaT)
sub_duration_less.set_value(i, 'checked', False)
sub_duration_less.at[i, 'ct'] = pd.NaT
sub_duration_less.at[i, 'checked'] = False
frames.append(sub_duration_less)

result = pd.concat(frames)
Expand All @@ -212,8 +286,8 @@ def unify_cutoff_times_hours_discharge_time(self, df, cutoff_time_label):
if len(sub_hour) != 0:
first_date = sub_hour.iloc[0][cutoff_time_label]
for i in sub_hour.index:
sub_hour.set_value(i, 'ct', first_date)
sub_hour.set_value(i, 'checked', True)
sub_hour.at[i, 'ct'] = first_date
sub_hour.at[i, 'checked'] = True

frames.append(sub_hour)

Expand Down
80 changes: 13 additions & 67 deletions cardea/problem_definition/length_of_stay.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,77 +5,23 @@
from cardea.problem_definition import ProblemDefinition


class LengthOfStay (ProblemDefinition):
"""Defines the problem of length of stay, predicting how many days
the patient will be in the hospital.

Attributes:
target_label_column_name: The target label of the prediction problem.
target_entity: Name of the entity containing the target label.
cutoff_time_label: The cutoff time label of the prediction problem.
cutoff_entity: Name of the entity containing the cutoff time label.
prediction_type: The type of the machine learning prediction.
class LengthOfStay(ProblemDefinition):
"""Defines the problem of Length of Stay.

It predicts how many days the patient will be in the hospital.
"""

__name__ = 'los'

updated_es = None
target_label_column_name = 'length'
target_entity = 'Encounter'
cutoff_time_label = 'start'
cutoff_entity = 'Period'
conn = 'period'
prediction_type = 'regression'

def generate_cutoff_times(self, es):
"""Generates cutoff times for the predection problem.

Args:
es: fhir entityset.

Returns:
entity_set, target_entity, and a dataframe of cutoff_times and target_labels.

Raises:
ValueError: An error occurs if the cutoff variable does not exist.
"""

if (self.check_target_label(es,
self.target_entity,
self.target_label_column_name) and not
self.check_for_missing_values_in_target_label(es,
self.target_entity,
self.target_label_column_name)):
if DL().check_column_existence(es,
self.cutoff_entity,
self.cutoff_time_label):
generated_cts = self.unify_cutoff_time_admission_time(
es, self.cutoff_entity, self.cutoff_time_label)

es = es.entity_from_dataframe(entity_id=self.cutoff_entity,
dataframe=generated_cts,
index='object_id')

cutoff_times = es[self.cutoff_entity].df['ct'].to_frame()

label = es[self.target_entity].df[self.conn].values
instance_id = list(es[self.target_entity].df.index)
cutoff_times = cutoff_times.reindex(index=label)
cutoff_times = cutoff_times[cutoff_times.index.isin(label)]
cutoff_times['instance_id'] = instance_id
cutoff_times.columns = ['cutoff_time', 'instance_id']

cutoff_times['label'] = list(
es[self.target_entity].df[self.target_label_column_name])
return(es, self.target_entity, cutoff_times)
else:
raise ValueError('Cutoff time label {} in table {}' +
'does not exist'.format(self.cutoff_time_label,
self.target_entity))

else:
updated_es = self.generate_target_label(es)
return self.generate_cutoff_times(updated_es)
def __init__(self):
super().__init__(
'length', # target_label_column_name
'Encounter', # target_entity
'start', # cutoff_time_label
'Period', # cutoff_entity
'regression', # prediction_type
conn='period'
)

def generate_target_label(self, es):
"""Generates target labels in the case of having missing label in the entityset.
Expand Down
85 changes: 23 additions & 62 deletions cardea/problem_definition/mortality_prediction.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,86 +3,47 @@
from cardea.data_loader import DataLoader
from cardea.problem_definition import ProblemDefinition

DEFAULT_CAUSES = ['X60', 'X84', 'Y87.0', 'X85', 'Y09', 'Y87.1',
'V02', 'V04', 'V09.0', 'V09.2', 'V12', 'V14']

class MortalityPrediction (ProblemDefinition):
"""Defines the problem of diagnosis Prediction.

Finding whether a patient will be diagnosed with a specifed diagnosis.
class MortalityPrediction(ProblemDefinition):
"""Defines the problem of Diagnosis Prediction.

It finds whether a patient will be diagnosed with a specifed diagnosis.

Note:
The patient visit is considered a readmission if he visits
the hospital again within 30 days.

The readmission diagnosis does not have to be the same as the initial visit diagnosis,
(he could be diagnosed of something that is a complication of the initial diagnosis).

Attributes:

target_label_column_name: The target label of the prediction problem.
target_entity: Name of the entity containing the target label.
cutoff_time_label: The cutoff time label of the prediction problem.
cutoff_entity: Name of the entity containing the cutoff time label.
prediction_type: The type of the machine learning prediction.
"""
__name__ = 'mortality'

updated_es = None
target_label_column_name = 'diagnosis'
target_entity = 'Encounter'
cutoff_time_label = 'start'
cutoff_entity = 'Period'
prediction_type = 'classification'
conn = 'period'
causes_of_death = ['X60', 'X84', 'Y87.0', 'X85', 'Y09',
'Y87.1', 'V02', 'V04', 'V09.0', 'V09.2', 'V12', 'V14']
def __init__(self, causes_of_death=DEFAULT_CAUSES):
self.causes_of_death = causes_of_death

def generate_cutoff_times(self, es):
"""Generates cutoff times for the predection problem.

Args:
es: fhir entityset.

Returns:
entity_set, target_entity, and a dataframe of cutoff_times and target_labels.

Raises:
ValueError: An error occurs if the cutoff variable does not exist.
"""
super().__init__(
'diagnosis', # target_label_column_name
'Encounter', # target_entity
'start', # cutoff_time_label
'Period', # cutoff_entity
'classification', # prediction_type
conn='period'
)

def generate_cutoff_times(self, es):
es = self.generate_target_label(es)

if DataLoader().check_column_existence(
es,
self.cutoff_entity,
self.cutoff_time_label): # check the existance of the cutoff label

generated_cts = self.unify_cutoff_time_admission_time(
es, self.cutoff_entity, self.cutoff_time_label)

es = es.entity_from_dataframe(entity_id=self.cutoff_entity,
dataframe=generated_cts,
index='object_id')
entity_set, target_entity, cutoff_times = super().generate_cutoff_times(es)

cutoff_times = es[self.cutoff_entity].df['ct'].to_frame()
# post-processing step
for (idx, row) in cutoff_times.iterrows():
new_val = row.loc['label'] in self.causes_of_death
cutoff_times.at[idx, 'label'] = new_val

label = es[self.target_entity].df[self.conn].values
instance_id = list(es[self.target_entity].df.index)
cutoff_times = cutoff_times.reindex(index=label)

cutoff_times = cutoff_times[cutoff_times.index.isin(label)]
cutoff_times['instance_id'] = instance_id
cutoff_times.columns = ['cutoff_time', 'instance_id']

cutoff_times['label'] = list(es[self.target_entity].df[self.target_label_column_name])

for (idx, row) in cutoff_times.iterrows():
new_val = row.loc['label'] in self.causes_of_death
cutoff_times.set_value(idx, 'label', new_val)

return(es, self.target_entity, cutoff_times)
else:
raise ValueError('Cutoff time label {} in table {} does not exist'
.format(self.cutoff_time_label, self.target_entity))
return (entity_set, target_entity, cutoff_times)

def generate_target_label(self, es):
"""Generates target labels in the case of having missing label in the entityset.
Expand Down
Loading