
Multi-task self-supervised learning for wearables

This repository is the official implementation of Self-supervised learning for Human Activity Recognition Using 700,000 Person-days of Wearable Data.


Figure: Overview of multi-task self-supervised learning (SSL) pipeline.

Use the pre-trained models

Required:

  • Python 3.7+
  • PyTorch 1.7+
```python
import torch
import numpy as np

repo = 'OxWearables/ssl-wearables'

# 5-second windows: 5 s x 30 Hz = 150 samples per axis.
harnet5 = torch.hub.load(repo, 'harnet5', class_num=5, pretrained=True)
x = np.random.rand(1, 3, 150)  # (batch, x/y/z axes, samples)
x = torch.from_numpy(x).float()
harnet5(x)

# 10-second windows: 10 s x 30 Hz = 300 samples per axis.
harnet10 = torch.hub.load(repo, 'harnet10', class_num=5, pretrained=True)
x = np.random.rand(1, 3, 300)
x = torch.from_numpy(x).float()
harnet10(x)

# 30-second windows: 30 s x 30 Hz = 900 samples per axis.
harnet30 = torch.hub.load(repo, 'harnet30', class_num=5, pretrained=True)
x = np.random.rand(1, 3, 900)
x = torch.from_numpy(x).float()
harnet30(x)
```

This is an example of five-class prediction for 5-second, 10-second and 30-second windows. The assumed sampling rate is 30 Hz, so a window of T seconds contains 30 × T samples per axis.
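The input shapes above follow from the window length and the 30 Hz sampling rate. The sketch below shows one way to cut a continuous tri-axial recording into windows with the (batch, 3, samples) layout. The helper `make_windows` is hypothetical and not part of the repository. It assumes a (n_samples, 3) input array that is already sampled at 30 Hz, and it uses non-overlapping windows.

```python
import numpy as np

SAMPLE_RATE_HZ = 30  # sampling rate assumed by the pre-trained models


def make_windows(signal, window_sec, sample_rate=SAMPLE_RATE_HZ):
    """Split a (n_samples, 3) recording into non-overlapping windows.

    Hypothetical helper. Returns shape (n_windows, 3, window_sec * sample_rate).
    Trailing samples that do not fill a whole window are dropped.
    """
    window_len = int(window_sec * sample_rate)
    n_windows = signal.shape[0] // window_len
    trimmed = signal[: n_windows * window_len]
    # (n_windows, window_len, 3) -> (n_windows, 3, window_len)
    windows = trimmed.reshape(n_windows, window_len, 3).transpose(0, 2, 1)
    return np.ascontiguousarray(windows)


# One minute of synthetic 30 Hz accelerometer data.
recording = np.random.rand(60 * SAMPLE_RATE_HZ, 3).astype(np.float32)
for seconds in (5, 10, 30):
    print(seconds, make_windows(recording, seconds).shape)
```

Each batch of windows can then be converted with `torch.from_numpy(...)` and passed to the matching harnet model.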

The first part of these models is a feature_extractor, pre-trained using self-supervised learning. The second part is a classifier that has not been trained at all. To use a model, you therefore need to train the classifier on a downstream task (for instance, classification on any public activity recognition dataset). Set class_num to the number of classes that you want your final model to be able to distinguish.
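The following is a minimal sketch of training only the classifier while keeping the self-supervised features fixed. To stay self-contained, it uses a hypothetical `ToyHarNet` stand-in instead of the hub model. It assumes the model exposes its two parts as `feature_extractor` and `classifier` attributes, as the paragraph above describes. With a model from `torch.hub.load`, you would pass that model to the hypothetical `train_classifier_only` function instead.

```python
import torch
import torch.nn as nn


class ToyHarNet(nn.Module):
    """Hypothetical stand-in with the same two-part layout as the hub models."""

    def __init__(self, class_num):
        super().__init__()
        self.feature_extractor = nn.Sequential(
            nn.Conv1d(3, 16, kernel_size=5, padding=2),
            nn.ReLU(),
            nn.AdaptiveAvgPool1d(1),
            nn.Flatten(),
        )
        self.classifier = nn.Linear(16, class_num)

    def forward(self, x):
        return self.classifier(self.feature_extractor(x))


def train_classifier_only(model, x, y, epochs=5, lr=1e-2):
    """Freeze model.feature_extractor and fit model.classifier on (x, y)."""
    for p in model.feature_extractor.parameters():
        p.requires_grad = False  # keep the pre-trained features fixed
    optimizer = torch.optim.Adam(model.classifier.parameters(), lr=lr)
    loss_fn = nn.CrossEntropyLoss()
    model.train()
    for _ in range(epochs):
        optimizer.zero_grad()
        loss = loss_fn(model(x), y)
        loss.backward()
        optimizer.step()
    return loss.item()


torch.manual_seed(0)
model = ToyHarNet(class_num=5)
x = torch.rand(32, 3, 300)       # 32 ten-second windows at 30 Hz
y = torch.randint(0, 5, (32,))   # hypothetical activity labels
features_before = [p.detach().clone() for p in model.feature_extractor.parameters()]
head_before = model.classifier.weight.detach().clone()
final_loss = train_classifier_only(model, x, y)
```

To fine-tune all layers instead, skip the freezing loop and pass `model.parameters()` to the optimiser.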

Requirements

If you would like to develop the model further for your own use, follow the instructions below:

Installation

```shell
conda create -n ssl_env python=3.7
conda activate ssl_env
pip install -r req.txt
```

Directory structure

To run the models, structure the data directory as shown below. The ADL dataset is included as an example.

```
- data:
  |_ downstream
    |_oppo
      |_ X.npy
      |_ Y.npy
      |_ pid.npy
    |_pamap2
    ...

  |_ ssl # ignore the ssl folder if you don't wish to pre-train using your own dataset
    |_ ssl_capture_24
      |_data
        |_ train
          |_ *.npy
          |_ file_list.csv # containing the paths to all the files
        |_ test
          |_ *.npy
      |_ logs
        |_models
```
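For your own downstream dataset, the folder needs the three arrays shown in the tree: the windows (`X.npy`), one label per window (`Y.npy`) and one participant ID per window (`pid.npy`). The helper below is a hypothetical sketch for writing these arrays. The example uses an array shape of windows × 300 samples × 3 axes (10 s at 30 Hz), which is an assumption. Check the matching config under conf/data for the layout the loaders actually expect.

```python
import tempfile
from pathlib import Path

import numpy as np


def write_downstream_dataset(root, name, X, Y, pid):
    """Save arrays as <root>/downstream/<name>/{X,Y,pid}.npy (hypothetical helper)."""
    # Every window needs exactly one label and one participant ID.
    if not (len(X) == len(Y) == len(pid)):
        raise ValueError("X, Y and pid must have the same number of windows")
    out_dir = Path(root) / "downstream" / name
    out_dir.mkdir(parents=True, exist_ok=True)
    np.save(out_dir / "X.npy", X)
    np.save(out_dir / "Y.npy", Y)
    np.save(out_dir / "pid.npy", pid)
    return out_dir


# Usage: 20 synthetic windows from 5 hypothetical participants.
rng = np.random.default_rng(0)
X = rng.random((20, 300, 3), dtype=np.float32)  # assumed (windows, samples, axes)
Y = rng.integers(0, 4, size=20)                 # hypothetical activity labels
pid = np.repeat(np.arange(5), 4)                # 4 windows per participant
with tempfile.TemporaryDirectory() as tmp:
    out = write_downstream_dataset(tmp, "my_dataset", X, Y, pid)
    print(sorted(p.name for p in out.iterdir()))
```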

Training

Self-supervised learning

First, download the processed Capture-24 dataset to your local machine. Self-supervised training on Capture-24 with all three tasks can then be run using:

```shell
python mtl.py runtime.gpu=0 data.data_root=PATH2DATA runtime.is_epoch_data=True data=ssl_capture_24 task=all task.scale=false augmentation=all model=resnet data.batch_subject_num=5 dataloader=ten_sec
```

The trained model is then saved to PATH2DATA/logs/models.
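The three self-supervised tasks are the ones listed in the pre-trained model table: arrow of time, permutation and time-warp. In the NumPy sketch below, each task is treated as a binary question about a transformed window: was it reversed in time, were its chunks shuffled, and was its speed warped? This is an illustration, not the repository's implementation. The chunk count, warp strength, transform order and 0.5 probability are all illustrative choices.

```python
import numpy as np


def arrow_of_time(x):
    """Reverse a (3, n_samples) window along the time axis."""
    return x[:, ::-1].copy()


def permute_chunks(x, rng, n_chunks=4):
    """Cut the window into n_chunks (>= 2) pieces and reorder them, never as the identity."""
    chunks = np.array_split(x, n_chunks, axis=1)
    order = np.arange(n_chunks)
    while np.array_equal(order, np.arange(n_chunks)):
        order = rng.permutation(n_chunks)
    return np.concatenate([chunks[i] for i in order], axis=1)


def time_warp(x, rng, n_knots=4, strength=0.5):
    """Resample the window along a smooth, monotonic random time axis."""
    n = x.shape[1]
    # Random positive speeds at a few knots, interpolated to every sample.
    knots = rng.uniform(1 - strength, 1 + strength, size=n_knots)
    speed = np.interp(np.linspace(0, n_knots - 1, n), np.arange(n_knots), knots)
    # Integrate speed into warped sample positions, rescaled to [0, n - 1].
    warped = np.cumsum(speed)
    warped = (warped - warped[0]) / (warped[-1] - warped[0]) * (n - 1)
    return np.stack([np.interp(warped, np.arange(n), channel) for channel in x])


def make_pretext_example(x, rng, p=0.5):
    """Apply each transform with probability p; return the view and its 0/1 labels."""
    transforms = {
        "arrow_of_time": arrow_of_time,
        "permutation": lambda v: permute_chunks(v, rng),
        "time_warp": lambda v: time_warp(v, rng),
    }
    view, labels = x, {}
    for name, transform in transforms.items():
        labels[name] = int(rng.random() < p)
        if labels[name]:
            view = transform(view)
    return view, labels


rng = np.random.default_rng(0)
window = rng.standard_normal((3, 300))  # one 10 s window at 30 Hz
view, labels = make_pretext_example(window, rng)
```

A multi-task network would then predict all three labels from `view` at once, with one binary output per task.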

Fine-tuning

Specify your benchmark datasets using the config files in the conf/data directory. All the specified models will be evaluated sequentially.

```shell
python downstream_task_evaluation.py data=custom_10s report_root=PATH2REPORT evaluation.flip_net_path=PATH2WEIGHT data.data_root=PATH2DATA is_dist=True evaluation=all
```

Change the weight path (evaluation.flip_net_path) to evaluate a different pre-trained model. An example ADL dataset is already included in the data folder. PATH2WEIGHT is the path to the model file in model_check_point, and report_root can be any directory on your machine.
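The per-window participant IDs in pid.npy allow subject-wise evaluation, where no participant contributes windows to both the training and test sets. The sketch below shows one way to build such folds with NumPy. Whether the evaluation script splits the data exactly this way, including the use of 5 folds, is an assumption. The helper `subject_wise_folds` is hypothetical; scikit-learn's `GroupKFold` implements the same idea.

```python
import numpy as np


def subject_wise_folds(pid, n_folds=5, seed=0):
    """Yield (train_idx, test_idx) pairs in which no participant appears in both."""
    rng = np.random.default_rng(seed)
    subjects = rng.permutation(np.unique(pid))
    for test_subjects in np.array_split(subjects, n_folds):
        test_mask = np.isin(pid, test_subjects)
        yield np.flatnonzero(~test_mask), np.flatnonzero(test_mask)


# Usage: 10 hypothetical participants with 30 windows each.
pid = np.repeat(np.arange(10), 30)
for train_idx, test_idx in subject_wise_folds(pid):
    assert not set(pid[train_idx]) & set(pid[test_idx])
    print(len(train_idx), len(test_idx))
```

Splitting by participant rather than by window avoids optimistic scores caused by near-identical windows from the same person appearing on both sides of the split.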

Pre-trained Models

You can download pretrained models here:

Dataset Subject count Arrow of Time Permutation Time-warp Link
UK-Biobank 100k ☑️ ☑️ ☑️ Download
UK-Biobank 1k ☑️ ☑️️ Download
UK-Biobank 1k ☑️ ☑️ Download
UK-Biobank 1k ☑️ ☑️ Download
Capture-24 ~150 ☑️ ☑️ ☑️ Download
Rowlands ~10 ☑️ ☑️ ☑️ Download

Results

Human activity recognition benchmarks

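Our model achieves the following performance using ResNet (mean F1 score ± SD). Improvement % is the relative gain in mean F1 from fine-tuning all layers compared with training from scratch.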

| Data | Trained from scratch | Fine-tune after conv layers | Fine-tune all layers | Improvement % |
|---|---|---|---|---|
| Capture-24 | .708 ± .094 | .723 ± .097 | .726 ± .093 | 2.5 |
| Rowlands | .696 ± .106 | .724 ± .081 | .796 ± .093 | 14.4 |
| WISDM | .684 ± .123 | .759 ± .121 | .810 ± .127 | 18.4 |
| REALWORLD | .705 ± .062 | .764 ± .052 | .792 ± .075 | 12.3 |
| Opportunity | .383 ± .124 | .570 ± .078 | .595 ± .085 | 55.4 |
| PAMAP2 | .605 ± .086 | .725 ± .054 | .789 ± .054 | 30.4 |
| ADL | .414 ± .179 | .645 ± .107 | .829 ± .101 | 100.0 |
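The Improvement % column can be reproduced from the mean F1 scores as 100 × (all layers − scratch) / scratch. The check below recomputes the column from the rounded table values. The results agree with the reported column to within half a percentage point. For ADL, the rounded scores give 100.2 rather than the reported 100.0.

```python
# (from scratch, fine-tune all layers, reported Improvement %) from the table above.
TABLE = {
    "Capture-24": (0.708, 0.726, 2.5),
    "Rowlands": (0.696, 0.796, 14.4),
    "WISDM": (0.684, 0.810, 18.4),
    "REALWORLD": (0.705, 0.792, 12.3),
    "Opportunity": (0.383, 0.595, 55.4),
    "PAMAP2": (0.605, 0.789, 30.4),
    "ADL": (0.414, 0.829, 100.0),
}


def relative_improvement(scratch, finetuned):
    """Percentage gain of the fine-tuned mean F1 over the from-scratch mean F1."""
    return 100.0 * (finetuned - scratch) / scratch


for dataset, (scratch, full, reported) in TABLE.items():
    recomputed = relative_improvement(scratch, full)
    print(f"{dataset:<12} recomputed {recomputed:5.1f}  reported {reported:5.1f}")
```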

Feature visualisation using UMAP

Figure: UMAP visualisation of the learned features for the Rowlands and WISDM datasets.

The code that generates the result tables and figures is in the plots/ folder.

Datasets

All data pre-processing is specified in the data_parsing folder. We have uploaded the processed dataset files for you to use; you can download them here. If you wish to process these datasets yourself, see data_parsing/make_*.py for details of how we processed each dataset.
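If you prepare your own dataset instead of using the processed files, the signals must match the 30 Hz sampling rate that the pre-trained models assume. Linear interpolation, sketched below with the hypothetical helper `resample_to_30hz`, is one simple option. It is not necessarily what data_parsing/make_*.py does. When downsampling, you should usually apply a low-pass (anti-aliasing) filter before interpolating.

```python
import numpy as np


def resample_to_30hz(t, signal, target_hz=30.0):
    """Linearly interpolate an (n_samples, 3) signal onto a uniform 30 Hz grid.

    t holds the sample timestamps in seconds and need not be uniformly spaced.
    """
    new_t = np.arange(t[0], t[-1], 1.0 / target_hz)
    resampled = np.column_stack(
        [np.interp(new_t, t, signal[:, axis]) for axis in range(signal.shape[1])]
    )
    return new_t, resampled


# Usage: 10 s of synthetic 100 Hz data becomes 300 samples at 30 Hz.
t = np.arange(1000) / 100.0
signal = np.column_stack([np.sin(t), np.cos(t), t])
new_t, resampled = resample_to_30hz(t, signal)
print(resampled.shape)
```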

Contributing

Our self-supervised model can help build state-of-the-art human activity recognition models with minimal effort. We expect our model to be used by people from diverse backgrounds, so please let us know if we can make this repo easier to use. Pull requests are very welcome. Please open an issue if you have suggested improvements or a bug report. We plan to maintain this project regularly, but please excuse any delayed responses due to other commitments.

Reference

If you use our work, please cite:

@article{yuan2024self,
  title={Self-supervised learning for human activity recognition using 700,000 person-days of wearable data},
  author={Yuan, Hang* and Chan, Shing* and Creagh, Andrew P and Tong, Catherine and Acquah, Aidan and Clifton, David A and Doherty, Aiden},
  journal={NPJ digital medicine},
  volume={7},
  number={1},
  pages={91},
  year={2024},
  publisher={Nature Publishing Group UK London}
}

License

This software is intended for use by academics carrying out research and not for use by consumers or commercial businesses; see the academic use licence file. If you are interested in using this software commercially, please contact Oxford University Innovation Limited to negotiate a licence. The contact email is enquiries@innovation.ox.ac.uk.