Skip to content

Commit

Permalink
refactor: flake8
Browse files Browse the repository at this point in the history
  • Loading branch information
chanshing committed Nov 19, 2024
1 parent a1b206d commit 04d4ba5
Showing 1 changed file with 33 additions and 31 deletions.
64 changes: 33 additions & 31 deletions utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ def encode_one_hot(y):
3 -> 0,0,0,1,0
4 -> 0,0,0,0,1
'''
return (y.reshape(-1,1) == np.arange(NUM_CLASSES)).astype(int)
return (y.reshape(-1, 1) == np.arange(NUM_CLASSES)).astype(int)


def train_hmm(Y_pred, y_true):
Expand All @@ -87,12 +87,12 @@ def train_hmm(Y_pred, y_true):
if Y_pred.ndim == 1 or Y_pred.shape[1] == 1:
Y_pred = encode_one_hot(Y_pred)

prior = np.mean(y_true.reshape(-1,1) == np.arange(NUM_CLASSES), axis=0)
prior = np.mean(y_true.reshape(-1, 1) == np.arange(NUM_CLASSES), axis=0)
emission = np.vstack(
[np.mean(Y_pred[y_true==i], axis=0) for i in range(NUM_CLASSES)]
[np.mean(Y_pred[y_true == i], axis=0) for i in range(NUM_CLASSES)]
)
transition = np.vstack(
[np.mean(y_true[1:][(y_true==i)[:-1]].reshape(-1,1) == np.arange(NUM_CLASSES), axis=0)
[np.mean(y_true[1:][(y_true == i)[:-1]].reshape(-1, 1) == np.arange(NUM_CLASSES), axis=0)
for i in range(NUM_CLASSES)]
)
return prior, emission, transition
Expand All @@ -107,19 +107,19 @@ def log(x):

num_obs = len(y_pred)
probs = np.zeros((num_obs, NUM_CLASSES))
probs[0,:] = log(prior) + log(emission[:, y_pred[0]])
probs[0, :] = log(prior) + log(emission[:, y_pred[0]])
for j in range(1, num_obs):
for i in range(NUM_CLASSES):
probs[j,i] = np.max(
log(emission[i, y_pred[j]]) + \
log(transition[:, i]) + \
probs[j-1,:]) # probs already in log scale
probs[j, i] = np.max(
log(emission[i, y_pred[j]]) +
log(transition[:, i]) +
probs[j - 1, :]) # probs already in log scale
viterbi_path = np.zeros_like(y_pred)
viterbi_path[-1] = np.argmax(probs[-1,:])
for j in reversed(range(num_obs-1)):
viterbi_path[-1] = np.argmax(probs[-1, :])
for j in reversed(range(num_obs - 1)):
viterbi_path[j] = np.argmax(
log(transition[:, viterbi_path[j+1]]) + \
probs[j,:]) # probs already in log scale
log(transition[:, viterbi_path[j + 1]]) +
probs[j, :]) # probs already in log scale

return viterbi_path

Expand All @@ -132,11 +132,11 @@ def compute_scores(y_true, y_pred):
balanced_acuracy = metrics.balanced_accuracy_score(y_true, y_pred)
kappa = metrics.cohen_kappa_score(y_true, y_pred)
return {
'confusion':confusion,
'per_class_recall':per_class_recall,
'confusion': confusion,
'per_class_recall': per_class_recall,
'accuracy': accuracy,
'balanced_accuracy': balanced_acuracy,
'kappa':kappa,
'kappa': kappa,
}


Expand Down Expand Up @@ -166,9 +166,11 @@ def print_scores(scores):
from pandas.plotting import register_matplotlib_converters
register_matplotlib_converters()
from datetime import datetime, timedelta, time


def plot_activity(x, y, t):
''' Plot activity timeseries '''
BACKGROUND_COLOR = '#d3d3d3' # lightgray
BACKGROUND_COLOR = '#d3d3d3' # lightgray

def split_by_timegap(group, seconds=30):
subgroupIDs = (group.index.to_series().diff() > timedelta(seconds=seconds)).cumsum()
Expand All @@ -177,9 +179,9 @@ def split_by_timegap(group, seconds=30):

convert_date = np.vectorize(
lambda day, x: matplotlib.dates.date2num(datetime.combine(day, x)))
timeseries = pd.DataFrame(data={'x':x, 'y':y, 't':t})
timeseries = pd.DataFrame(data={'x': x, 'y': y, 't': t})
timeseries.set_index('t', inplace=True)
timeseries['x'] = timeseries['x'].rolling(window=12, min_periods=1).mean() #! inplace?
timeseries['x'] = timeseries['x'].rolling(window=12, min_periods=1).mean() # ! inplace?
ylim_min, ylim_max = np.min(x), np.max(x)
groups = timeseries.groupby(timeseries.index.date)
fig, axs = plt.subplots(nrows=len(groups) + 1)
Expand All @@ -192,16 +194,16 @@ def split_by_timegap(group, seconds=30):

ax.get_xaxis().grid(True, which='major', color='grey', alpha=0.5)
ax.get_xaxis().grid(True, which='minor', color='grey', alpha=0.25)
ax.set_xlim((datetime.combine(day,time(0, 0, 0, 0)),
datetime.combine(day + timedelta(days=1), time(0, 0, 0, 0))))
ax.set_xticks(pd.date_range(start=datetime.combine(day,time(0, 0, 0, 0)),
end=datetime.combine(day + timedelta(days=1), time(0, 0, 0, 0)),
freq='4H'))
ax.set_xticks(pd.date_range(start=datetime.combine(day,time(0, 0, 0, 0)),
end=datetime.combine(day + timedelta(days=1), time(0, 0, 0, 0)),
freq='1H'), minor=True)
ax.set_xlim((datetime.combine(day, time(0, 0, 0, 0)),
datetime.combine(day + timedelta(days=1), time(0, 0, 0, 0))))
ax.set_xticks(pd.date_range(start=datetime.combine(day, time(0, 0, 0, 0)),
end=datetime.combine(day + timedelta(days=1), time(0, 0, 0, 0)),
freq='4H'))
ax.set_xticks(pd.date_range(start=datetime.combine(day, time(0, 0, 0, 0)),
end=datetime.combine(day + timedelta(days=1), time(0, 0, 0, 0)),
freq='1H'), minor=True)
ax.set_ylim((ylim_min, ylim_max))
ax.get_yaxis().set_ticks([]) # hide y-axis lables
ax.get_yaxis().set_ticks([]) # hide y-axis lables
ax.spines['top'].set_color(BACKGROUND_COLOR)
ax.spines['right'].set_visible(False)
ax.spines['left'].set_visible(False)
Expand All @@ -225,16 +227,16 @@ def split_by_timegap(group, seconds=30):
for color, label in zip(COLORS, CLASSES):
legend_patches.append(mpatches.Patch(facecolor=color, label=label, alpha=0.5))
axs[-1].legend(handles=legend_patches, bbox_to_anchor=(0., 0., 1., 1.),
loc='center', ncol=min(3,len(legend_patches)), mode="best",
borderaxespad=0, framealpha=0.6, frameon=True, fancybox=True)
loc='center', ncol=min(3, len(legend_patches)), mode="best",
borderaxespad=0, framealpha=0.6, frameon=True, fancybox=True)
axs[-1].spines['left'].set_visible(False)
axs[-1].spines['right'].set_visible(False)
axs[-1].spines['top'].set_visible(False)
axs[-1].spines['bottom'].set_visible(False)

# format x-axis to show hours
fig.autofmt_xdate()
hours = [(str(hr) + 'am') if hr<=12 else (str(hr-12) + 'pm') for hr in range(0,24,4)]
hours = [(str(hr) + 'am') if hr <= 12 else (str(hr - 12) + 'pm') for hr in range(0, 24, 4)]
axs[0].set_xticklabels(hours)
axs[0].tick_params(labelbottom=False, labeltop=True, labelleft=False)

Expand Down

0 comments on commit 04d4ba5

Please sign in to comment.