Skip to content

Commit ddc2b58

Browse files
committed
feat: add functionality to plot time series of steps
1 parent f9e9119 commit ddc2b58

File tree

2 files changed

+75
-1
lines changed

2 files changed

+75
-1
lines changed

setup.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,8 @@ def get_string(string, rel_path="src/stepcount/__init__.py"):
6262
"torch==1.13.*",
6363
"torchvision==0.14.*",
6464
"transforms3d==0.4.*",
65-
"numba==0.58.*"
65+
"numba==0.58.*",
66+
"matplotlib==3.7.*",
6667
],
6768
extras_require={
6869
"dev": [

src/stepcount/stepcount.py

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,8 @@
1111
import numpy as np
1212
import pandas as pd
1313
import joblib
14+
import matplotlib.pyplot as plt
15+
import matplotlib.dates as mdates
1416
from numba import njit
1517

1618
from stepcount import utils
@@ -388,6 +390,10 @@ def main():
388390
print(daily_adj.set_index('Date').drop(columns='Filename'))
389391
print("\nOutput files saved in:", outdir)
390392

393+
print("\nPlotting...")
394+
fig = plot(Y, title=basename)
395+
fig.savefig(f"{outdir}/{basename}-Steps.png", bbox_inches='tight', pad_inches=0)
396+
391397
after = time.time()
392398
print(f"Done! ({round(after - before,2)}s)")
393399

@@ -1085,6 +1091,73 @@ def numba_detect_bouts(
10851091
return bouts
10861092

10871093

1094+
def plot(Y, title=None):
1095+
"""
1096+
Plot time series of steps per minute for each day.
1097+
1098+
Parameters:
1099+
- Y: pandas Series or DataFrame with a 'Steps' column. Must have a DatetimeIndex.
1100+
1101+
Returns:
1102+
- fig: matplotlib figure object
1103+
"""
1104+
1105+
MAX_STEPS_PER_MINUTE = 180
1106+
1107+
if isinstance(Y, pd.DataFrame):
1108+
Y = Y['Steps']
1109+
1110+
assert isinstance(Y, pd.Series), "Y must be a pandas Series, or a DataFrame with a 'Steps' column"
1111+
1112+
# Resample to 1 minute intervals
1113+
# Note: .sum() returns 0 when all values are NaN, so we need to use a custom function
1114+
def _sum(x):
1115+
if x.isna().all():
1116+
return np.nan
1117+
return x.sum()
1118+
1119+
Y = Y.resample('1T').agg(_sum)
1120+
1121+
dates_index = Y.index.normalize()
1122+
unique_dates = dates_index.unique()
1123+
1124+
# Set the plot figure and size
1125+
fig = plt.figure(figsize=(10, len(unique_dates) * 2))
1126+
1127+
# Group by each day
1128+
for i, (day, y) in enumerate(Y.groupby(dates_index)):
1129+
ax = fig.add_subplot(len(unique_dates), 1, i + 1)
1130+
1131+
# Plot steps
1132+
ax.plot(y.index, y, label='steps/min')
1133+
1134+
# Grey shading where NA
1135+
ax.fill_between(y.index, 0, MAX_STEPS_PER_MINUTE, where=y.isna(), color='grey', alpha=0.3, interpolate=True, label='missing')
1136+
1137+
# Formatting the x-axis to show hours and minutes
1138+
ax.xaxis.set_major_locator(mdates.HourLocator(interval=1))
1139+
ax.xaxis.set_minor_locator(mdates.MinuteLocator(interval=15))
1140+
ax.xaxis.set_major_formatter(mdates.DateFormatter("%H:%M"))
1141+
1142+
# Set x-axis limits to start at 00:00 and end at 24:00
1143+
ax.set_xlim(day, day + pd.DateOffset(days=1))
1144+
# Set y-axis limits
1145+
ax.set_ylim(-10, MAX_STEPS_PER_MINUTE)
1146+
1147+
ax.tick_params(axis='x', rotation=45)
1148+
ax.set_ylabel('steps/min')
1149+
ax.set_title(day.strftime('%Y-%m-%d'))
1150+
ax.grid(True)
1151+
ax.legend(loc='upper left')
1152+
1153+
if title:
1154+
fig.suptitle(title)
1155+
1156+
fig.tight_layout()
1157+
1158+
return fig
1159+
1160+
10881161

10891162
if __name__ == '__main__':
10901163
main()

0 commit comments

Comments
 (0)