Skip to content

Commit

Permalink
Adapt code to .lzma file
Browse files Browse the repository at this point in the history
  • Loading branch information
aidanacquah committed Feb 4, 2024
1 parent 18444e8 commit 14343aa
Show file tree
Hide file tree
Showing 6 changed files with 107 additions and 64 deletions.
2 changes: 1 addition & 1 deletion src/actinet/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
__maintainer_email__ = "shing.chan@ndph.ox.ac.uk"
__license__ = "See LICENSE file."

__model_version__ = "ssl-ukb-c24-rw"
__model_version__ = "ssl_ukb_c24_rw_20240204"
__model_md5__ = ""

from . import _version
Expand Down
8 changes: 5 additions & 3 deletions src/actinet/actinet.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@
from actinet.summarisation import getActivitySummary, ACTIVITY_LABELS
from actinet.utils.utils import infer_freq

BASE_URL = "https://zenodo.org/records/10616280/files/"


def main():

Expand Down Expand Up @@ -86,7 +88,7 @@ def main():
model_path = pathlib.Path(__file__).parent / f"{__model_version__}.joblib.lzma"
check_md5 = args.model_path is None
model: ActivityClassifier = load_model(
args.model_path or model_path, args.model_type, check_md5, args.force_download
args.model_path or model_path, check_md5, args.force_download
)

model.verbose = verbose
Expand Down Expand Up @@ -222,14 +224,14 @@ def resolve_path(path):
return dirname, filename, extension


def load_model(model_path, model_type, check_md5=True, force_download=False):
def load_model(model_path, check_md5=True, force_download=False):
"""Load trained model. Download if not exists."""

pth = pathlib.Path(model_path)

if force_download or not pth.exists():

url = f"https://wearables-files.ndph.ox.ac.uk/files/models/stepcount/{__model_version__}.joblib.lzma"
url = f"{BASE_URL}{__model_version__}.joblib.lzma"

print(f"Downloading {url}...")

Expand Down
16 changes: 12 additions & 4 deletions src/actinet/hmm.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,15 +8,23 @@ class HMM:
Implement a basic HMM model with parameter saving/loading.
"""

def __init__(self, labels=None, uniform_prior=True):
self.prior = None
self.emission = None
self.transition = None
def __init__(
self,
prior=None,
emission=None,
transition=None,
labels=None,
uniform_prior=True,
):
self.prior = prior
self.emission = emission
self.transition = transition
self.labels = labels
self.uniform_prior = uniform_prior

def __str__(self):
return (
"Hidden Markov Model\n"
"prior: {prior}\n"
"emission: {emission}\n"
"transition: {transition}\n"
Expand Down
50 changes: 32 additions & 18 deletions src/actinet/models.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import numpy as np
import pandas as pd
from sklearn.preprocessing import LabelEncoder
from tqdm.auto import tqdm
from torch.utils.data import DataLoader

Expand All @@ -16,6 +15,7 @@ def __init__(
window_sec=30,
weights_path="state_dict.pt",
labels=[],
ssl_repo=None,
repo_tag="v1.0.0",
hmm_params=None,
verbose=False,
Expand All @@ -26,14 +26,26 @@ def __init__(
self.batch_size = batch_size
self.window_sec = window_sec
self.state_dict = None
self.label_encoder = LabelEncoder().fit(labels)
self.labels = labels
self.window_len = int(np.ceil(self.window_sec * sslmodel.SAMPLE_RATE))

self.verbose = verbose

self.model = self._load_ssl(ssl_repo, weights_path)

hmm_params = hmm_params or dict()
self.hmms = hmm.HMM(**hmm_params)

def __str__(self):
return (
"Activity Classifier\n"
"class_labels: {self.labels}\n"
"window_length: {self.window_sec}\n"
"batch_size: {self.batch_size}\n"
"device: {self.device}\n"
"hmm: {self.hmms}\n"
"model: {self.model}".format(self=self)
)

def predict_from_frame(self, data):

def fn(chunk):
Expand All @@ -55,39 +67,41 @@ def fn(chunk):
data, self.window_sec, fn=fn, return_index=True, verbose=self.verbose
)

Y_labels = self.label_encoder.inverse_transform(self._predict(X))

Y = raw_to_df(X, Y_labels, T, self.label_encoder.classes_, reindex=False)
Y = raw_to_df(X, self._predict(X), T, self.labels, reindex=False)

return Y

def _predict(self, X, groups=None):
def _predict(self, X):
sslmodel.verbose = self.verbose

dataset = sslmodel.NormalDataset(X, name="prediction")
dataset = sslmodel.NormalDataset(X)
dataloader = DataLoader(
dataset,
batch_size=self.batch_size,
shuffle=False,
num_workers=0,
)

_, y_pred, _ = sslmodel.predict(
self.model, dataloader, self.device, output_logits=False
)

y_pred = self.hmms.predict(y_pred)

return y_pred

def _load_ssl(self, ssl_repo, weights):
model = sslmodel.get_sslnet(
self.device,
tag=self.repo_tag,
pretrained=False,
local_repo_path=ssl_repo,
pretrained=weights,
window_sec=self.window_sec,
num_labels=len(self.label_encoder.classes_),
num_labels=len(self.labels),
)
model.load_state_dict(self.state_dict)
model.to(self.device)

_, y_pred, _ = sslmodel.predict(
model, dataloader, self.device, output_logits=False
)

y_pred = self.hmms.predict(y_pred, groups=groups)

return y_pred
return model


def make_windows(data, window_sec, fn=None, return_index=False, verbose=True):
Expand Down
91 changes: 56 additions & 35 deletions src/actinet/sslmodel.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,42 +184,26 @@ def save_checkpoint(self, val_loss, model):


def get_sslnet(
tag="v1.0.0", pretrained=False, window_sec: int = 30, num_labels: int = 4
device,
tag="v1.0.0",
local_repo_path=None,
pretrained=False,
window_sec: int = 30,
num_labels: int = 4,
):
"""
Load and return the Self Supervised Learning (SSL) model from pytorch hub.
Load and return the Self Supervised Learning (SSL) model from pytorch hub or local storage.
:param str device: PyTorch device to use
:param str tag: Tag on the ssl-wearables repo to check out
:param bool pretrained: Initialise the model with UKB self-supervised pretrained weights
:param str local_repo_path: Path to local version of the SSL repo for offline usage
:param bool/str pretrained: Initialise the model with UKB self-supervised pretrained weights
:param int window_sec: The length of the window of data in seconds (limited to 5, 10 or 30)
:param int num_labels: The number of labels to predict
:return: pytorch SSL model
:rtype: nn.Module
"""

repo_name = "ssl-wearables"
repo = f"OxWearables/{repo_name}:{tag}"

if not torch_cache_path.exists():
Path.mkdir(torch_cache_path, parents=True, exist_ok=True)

torch.hub.set_dir(str(torch_cache_path))

# find repo cache dir that matches repo name and tag
cache_dirs = [f for f in torch_cache_path.iterdir() if f.is_dir()]
repo_path = next(
(f for f in cache_dirs if repo_name in f.name and tag in f.name), None
)

if repo_path is None:
repo_path = repo
source = "github"
else:
repo_path = str(repo_path)
source = "local"
if verbose:
print(f"Using local {repo_path}")

if window_sec not in [5, 10, 30]:
raise ValueError(
"Length of window in seconds must be either 5, 10 or 30 seconds"
Expand All @@ -228,15 +212,52 @@ def get_sslnet(
if num_labels < 1:
raise ValueError("Numer of class labels should be > 0")

sslnet: nn.Module = torch.hub.load(
repo_path,
f"harnet{window_sec}",
trust_repo=True,
source=source,
class_num=num_labels,
pretrained=pretrained,
verbose=verbose,
)
if local_repo_path is not None:
sslnet: nn.Module = torch.hub.load(
local_repo_path,
f"harnet{window_sec}",
source="local",
class_num=num_labels,
pretrained=pretrained == True,
)

else:
repo_name = "ssl-wearables"
repo = f"OxWearables/{repo_name}:{tag}"

if not torch_cache_path.exists():
Path.mkdir(torch_cache_path, parents=True, exist_ok=True)

torch.hub.set_dir(str(torch_cache_path))

# find repo cache dir that matches repo name and tag
cache_dirs = [f for f in torch_cache_path.iterdir() if f.is_dir()]
repo_path = next(
(f for f in cache_dirs if repo_name in f.name and tag in f.name), None
)

if repo_path is None:
repo_path = repo
source = "github"
else:
repo_path = str(repo_path)
source = "local"
if verbose:
print(f"Using local {repo_path}")

sslnet: nn.Module = torch.hub.load(
repo_path,
f"harnet{window_sec}",
trust_repo=True,
source=source,
class_num=num_labels,
pretrained=pretrained == True,
verbose=verbose,
)

model_dict = torch.load(pretrained, map_location=device)
sslnet.load_state_dict(model_dict)

return sslnet


Expand Down
4 changes: 1 addition & 3 deletions src/actinet/summarisation.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from actinet.utils.utils import date_parser, toScreen
from actinet import circadian

ACTIVITY_LABELS = ["light", "MVPA", "sedentary", "sleep"]
ACTIVITY_LABELS = ["light", "moderate-vigorous", "sedentary", "sleep"]


def getActivitySummary(
Expand Down Expand Up @@ -98,8 +98,6 @@ def _summarise(
for col in cols:
summary[f"day{i}-recorded-{col}(hrs)"] = row.loc[col]

summary["day_avg"]

# Calculate empirical cumulative distribution function of vector magnitudes
if intensityDistribution:
summary = calculateECDF(data["acc"], summary)
Expand Down

0 comments on commit 14343aa

Please sign in to comment.