Skip to content

Commit

Permalink
Merge pull request #109 from mjq2020/dev
Browse files Browse the repository at this point in the history
Optimizing Dataset Sampling
  • Loading branch information
mjq2020 authored Jul 13, 2023
2 parents 325d7f7 + 478c1b5 commit d907922
Show file tree
Hide file tree
Showing 4 changed files with 79 additions and 17 deletions.
19 changes: 11 additions & 8 deletions configs/accelerometer/3axes_accelerometer_62.5Hz_1s_classify.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,20 +4,19 @@

num_classes = 3
num_axes = 3
frequency = 62.5
window = 1000
window_size = 30
stride = 20

model = dict(
type='AccelerometerClassifier',
backbone=dict(
type='AxesNet',
num_axes=num_axes,
frequency=frequency,
window=window,
window_size=window_size,
num_classes=num_classes,
),
head=dict(
type='edgelab.ClsHead',
type='edgelab.AxesClsHead',
loss=dict(type='mmcls.CrossEntropyLoss', loss_weight=1.0),
topk=(1, 5),
),
Expand All @@ -29,15 +28,15 @@
batch_size = 1
workers = 1

shape = num_classes * int(62.5 * 1000 / 1000)
shape = [1, num_axes * window_size]

train_pipeline = [
dict(type='edgelab.LoadSensorFromFile'),
# dict(type='edgelab.LoadSensorFromFile'),
dict(type='edgelab.PackSensorInputs'),
]

test_pipeline = [
dict(type='edgelab.LoadSensorFromFile'),
# dict(type='edgelab.LoadSensorFromFile'),
dict(type='edgelab.PackSensorInputs'),
]

Expand All @@ -49,6 +48,8 @@
data_root=data_root,
data_prefix='training',
ann_file='info.labels',
window_size=window_size,
stride=stride,
pipeline=train_pipeline,
),
sampler=dict(type='DefaultSampler', shuffle=True),
Expand All @@ -61,6 +62,8 @@
dataset=dict(
type=dataset_type,
data_root=data_root,
window_size=window_size,
stride=stride,
data_prefix='testing',
ann_file='info.labels',
pipeline=test_pipeline,
Expand Down
67 changes: 62 additions & 5 deletions edgelab/datasets/sensordataset.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
import json
import os
from typing import Optional, Union
import os.path as osp
from typing import List, Optional, Union

import cbor
import numpy as np
from mmcls.datasets import CustomDataset

from edgelab.registry import DATASETS
Expand All @@ -17,6 +20,11 @@ def __init__(
metainfo: Optional[dict] = None,
data_root: str = '',
data_prefix: Union[str, dict] = '',
window_size: int = 80,
stride: int = 30,
retention: float = 0.8,
# source: str = 'EI',
flatten: bool = True,
multi_label: bool = False,
**kwargs,
):
Expand All @@ -26,6 +34,12 @@ def __init__(
self.data_root = data_root
self.ann_file = ann_file
self.data_prefix = data_prefix
self.window_size = window_size
self.stride = stride
self.retention = retention
self.flatten = flatten

self.data_dir = osp.join(self.data_root, self.data_prefix)

self.info_lables = json.load(open(os.path.join(self.data_root, self.data_prefix, self.ann_file)))

Expand Down Expand Up @@ -58,7 +72,6 @@ def _find_samples(self):
gt_label = j
break
samples.append((filename, gt_label))
print(samples)
return samples

def load_data_list(self):
Expand All @@ -75,12 +88,56 @@ def load_data_list(self):

data_list = []
for filename, gt_label in samples:
img_path = os.path.join(self.img_prefix, filename)
info = {'file_path': img_path, 'gt_label': int(gt_label)}
data_list.append(info)
ann_path = os.path.join(self.data_dir, filename)
data_list.extend(
[{'data': np.asanyarray([data]), 'gt_label': int(gt_label)} for data in self.read_split_data(ann_path)]
)

return data_list

def read_split_data(self, file_path: str) -> List:
    """Load one sensor recording and cut it into fixed-length windows.

    The recording is read from a ``.cbor`` or ``.json`` file (the extension
    is matched case-insensitively) and its ``payload.values`` entry is
    sliced into windows of ``self.window_size`` rows, with window starts
    ``self.stride`` rows apart.

    * A recording no longer than one window yields a single window padded
      to ``self.window_size`` by ``self.pad_data``.
    * A trailing partial window is padded and kept only when
      ``self.retention * self.window_size < remaining + 1``; otherwise it
      is dropped.
    * An empty recording yields no windows.
    * Every window is flattened to 1-D when ``self.flatten`` is true.

    Args:
        file_path: Path of the recording to load.

    Returns:
        A list of numpy arrays, one per retained window.

    Raises:
        ValueError: If ``file_path`` ends in neither ``.cbor`` nor ``.json``.
    """
    lower_path = file_path.lower()
    if lower_path.endswith('.cbor'):
        with open(file_path, 'rb') as f:
            sample = cbor.loads(f.read())
    elif lower_path.endswith('.json'):
        with open(file_path, 'r') as f:
            sample = json.load(f)
    else:
        # Fail loudly instead of hitting an unbound local further down.
        raise ValueError(f'Unsupported sensor file format: {file_path!r} (expected .cbor or .json)')

    values = np.asanyarray(sample['payload']['values'])
    values_len = len(values)
    if values_len == 0:
        # Nothing to window; padding an empty 1-D array would fail.
        return []

    window_size = self.window_size
    if values_len <= window_size:
        windows = [self.pad_data(values, window_size)]
    else:
        windows = []
        # Minimum amount of real data a partial window must carry.
        min_remaining = self.retention * window_size
        for start in range(0, values_len, self.stride):
            remaining = values_len - start
            if remaining >= window_size:
                # A full window is available (also covers stride > window_size,
                # where the last start may leave more than one window of data).
                windows.append(values[start:start + window_size])
            elif min_remaining < remaining + 1:
                windows.append(self.pad_data(values[start:], window_size))
            # else: the tail is too short to be worth keeping.

    if self.flatten:
        # NOTE(review): transpose(0, 1) is the identity axis order for a 2-D
        # array; if channel-major flattening was intended this should be
        # transpose(1, 0) — confirm against the model's expected layout.
        windows = [window.transpose(0, 1).reshape(-1) for window in windows]
    return windows

def pad_data(self, data: np.ndarray, total_len: int, mode='constant', pad_val=0) -> np.ndarray:
    """Pad ``data`` along its first axis so that it has ``total_len`` rows.

    The padding is split between both ends; when the amount is odd, the
    extra row goes at the end. Trailing axes are left untouched.

    Args:
        data: Array-like whose first axis is padded.
        total_len: Length of the first axis after padding.
        mode: Padding mode forwarded to ``numpy.pad``.
        pad_val: Fill value, used only when ``mode`` is ``'constant'``.

    Returns:
        The padded array.

    Raises:
        ValueError: If ``data`` already has more than ``total_len`` rows.
    """
    data = np.asanyarray(data)
    pad_len = total_len - len(data)
    if pad_len < 0:
        raise ValueError(f'Cannot pad data of length {len(data)} to shorter length {total_len}')

    front = pad_len // 2
    after = pad_len - front
    # Pad the first axis only, whatever the rank of the input.
    pad_width = [(front, after)] + [(0, 0)] * (data.ndim - 1)

    if mode == 'constant':
        return np.pad(data, pad_width, mode=mode, constant_values=pad_val)
    # numpy.pad rejects ``constant_values`` for every other mode.
    return np.pad(data, pad_width, mode=mode)

def is_valid_file(self, filename: str) -> bool:
    """Report whether ``filename`` is accepted as a sample.

    No filtering is applied: the result is ``True`` for every ``filename``.
    """
    accepted = True
    return accepted
7 changes: 3 additions & 4 deletions edgelab/models/backbones/AxesNet.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,10 @@

@MODELS.register_module()
class AxesNet(nn.Module):
def __init__(
self, num_axes=3, frequency=62.5, window=1000, num_classes=-1 # axes number # sample frequency # window size
):
def __init__(self, num_axes=3, window_size=80, num_classes=-1): # axes number # sample frequency # window size
super().__init__()
self.num_classes = num_classes
self.intput_feature = num_axes * int(frequency * window / 1000)
self.intput_feature = num_axes * window_size
liner_feature = self.liner_feature_fit()
self.fc1 = nn.Linear(in_features=self.intput_feature, out_features=liner_feature, bias=True)
self.fc2 = nn.Linear(in_features=liner_feature, out_features=liner_feature, bias=True)
Expand All @@ -23,6 +21,7 @@ def liner_feature_fit(self):
return (int(self.intput_feature / 1024) + 1) * 256

def forward(self, x):
x = x[0] if isinstance(x, list) else x
x = F.relu(self.fc1(x))
x = F.relu(self.fc2(x))

Expand Down
3 changes: 3 additions & 0 deletions requirements/base.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
# common
albumentations>=1.3.0

# sensor
cbor
numpy>=1.23.5

# vision
Expand Down

0 comments on commit d907922

Please sign in to comment.