
Slurm #52

Merged: 38 commits, Jun 3, 2024

Commits
e055d9e
Files modified to be run on arc and slurm
Benjamin-Walker Aug 4, 2023
1bfe93a
Slurm file
Benjamin-Walker Aug 4, 2023
e956e98
Merge with main
Benjamin-Walker Aug 4, 2023
fe57c17
Tidied up implementation
Benjamin-Walker Aug 4, 2023
eb3329f
Updated slurm file and added lr_scheduler as argument
Benjamin-Walker Aug 4, 2023
25648b4
Added data_dir as argument
Benjamin-Walker Aug 4, 2023
793e005
Slurm training working.
Benjamin-Walker Aug 4, 2023
133497d
Hyperparameter search for lstm, lru, and ssm.
Benjamin-Walker Aug 4, 2023
b3a4c8c
Added training accuracy and schedule flag in output name
Benjamin-Walker Aug 14, 2023
a5fa24a
Merge with main
Benjamin-Walker Aug 15, 2023
013a9b8
Merge branch 'main' into slurm
Benjamin-Walker Aug 16, 2023
448219e
Log-NCDE experiments
Benjamin-Walker Aug 24, 2023
4a37742
Merged presplit data fix from main
Benjamin-Walker Aug 24, 2023
3656122
Merge with main
Benjamin-Walker Aug 25, 2023
08be4c9
Merge branch 'main' into slurm
Benjamin-Walker Aug 25, 2023
f2fea4e
Merge branch 'main' into slurm
Benjamin-Walker Sep 6, 2023
5d49af6
Merged with main
Benjamin-Walker Sep 6, 2023
9177642
Merged with main
Benjamin-Walker Sep 11, 2023
65118eb
Merge branch 'main' into slurm
Benjamin-Walker Sep 15, 2023
a445763
Slurm experiments for log_ncde, ncde, and nrde
Benjamin-Walker Nov 7, 2023
068ef29
Fixed jvp
Benjamin-Walker Nov 7, 2023
6f20ebf
test for log_ncde
Benjamin-Walker Nov 8, 2023
337623d
Log-NCDE Experiments
Benjamin-Walker Nov 22, 2023
537fb51
Added slurm loop for repeat experiments
Benjamin-Walker Dec 1, 2023
57ecf4b
Merge branch 'slurm' of https://github.com/Benjamin-Walker/Log-Neural…
Benjamin-Walker Dec 1, 2023
1ec2c73
Repeat experiments
Benjamin-Walker Dec 18, 2023
7e14e58
One channel missing
Benjamin-Walker Dec 18, 2023
d1bffbc
Multiple missing channels
Benjamin-Walker Dec 19, 2023
d0174f9
Added flag for missing channels
Benjamin-Walker Jan 10, 2024
a060ca1
Merge with main
Benjamin-Walker Jan 10, 2024
abe6a4c
Modified NRDE
Benjamin-Walker Jan 19, 2024
88c5051
Repeat experiments with fixed splits
Benjamin-Walker Mar 15, 2024
a414301
Added speech dataset
Benjamin-Walker Mar 15, 2024
a12e903
Final repeat experiments
Benjamin-Walker Apr 16, 2024
29c0482
Merge with main
Benjamin-Walker Apr 16, 2024
d0c11dd
Updated NeuralRDE
Benjamin-Walker Jun 3, 2024
5582e46
Merge branch 'main' into slurm
Benjamin-Walker Jun 3, 2024
0e667ad
Pre-commit fix in datasets_LogNCDE.py
Benjamin-Walker Jun 3, 2024
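
Several of the commits above ("Slurm file", "Slurm training working.", "Added slurm loop for repeat experiments") add a Slurm workflow for launching hyperparameter sweeps and repeat runs. The submission scripts themselves are not part of the diff shown below, so the following is only a minimal, hypothetical Python sketch of such a loop; the script name run_experiment.py, its flags, and the sweep values are illustrative assumptions rather than code from this PR.

import itertools
import subprocess

# Hypothetical sweep: submit one Slurm job per (model, dataset, seed) combination.
# "run_experiment.py" and its flags are placeholders, not files in this repository.
models = ["log_ncde", "ncde", "nrde", "lru", "ssm"]
datasets = ["EigenWorms", "Heartbeat", "MotorImagery"]
seeds = [1234, 2345, 3456]

for model, dataset, seed in itertools.product(models, datasets, seeds):
    wrapped = f"python run_experiment.py --model {model} --dataset {dataset} --seed {seed}"
    subprocess.run(
        ["sbatch", f"--job-name={model}_{dataset}_{seed}", f"--wrap={wrapped}"],
        check=True,
    )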
2 changes: 1 addition & 1 deletion .gitignore
@@ -244,4 +244,4 @@ fabric.properties
.idea/webServers.xml

# No Images
*.png
*.png
30 changes: 30 additions & 0 deletions best_hyperparameters.txt
@@ -0,0 +1,30 @@
lru EigenWorms T 1.00 time True nsteps 100000 lr 0.001 num blocks 6 hidden dim 128 vf depth 1 vf width 1 ssm dim 16 ssm blocks 8 dt0 1 solver Heun stepsize controller ConstantStepSize scale 1 lambd 0 seed 1234
lru EthanolConcentration T 1.00 time False nsteps 100000 lr 0.0001 num blocks 2 hidden dim 128 vf depth 1 vf width 1 ssm dim 16 ssm blocks 8 dt0 1 solver None stepsize controller None scale 1 lambd 0.0 seed 1234
lru Heartbeat T 1.00 time False nsteps 100000 lr 0.0001 num blocks 6 hidden dim 128 vf depth 1 vf width 1 ssm dim 256 ssm blocks 8 dt0 1 solver None stepsize controller None scale 1 lambd 0.0 seed 1234
lru MotorImagery T 1.00 time False nsteps 100000 lr 0.001 num blocks 6 hidden dim 64 vf depth 1 vf width 1 ssm dim 256 ssm blocks 8 dt0 1 solver None stepsize controller None scale 1 lambd 0.0 seed 1234
lru SelfRegulationSCP1 T 1.00 time True nsteps 100000 lr 0.0001 num blocks 6 hidden dim 64 vf depth 1 vf width 1 ssm dim 256 ssm blocks 8 dt0 1 solver None stepsize controller None scale 1 lambd 0.0 seed 1234
lru SelfRegulationSCP2 T 1.00 time False nsteps 100000 lr 0.001 num blocks 6 hidden dim 128 vf depth 1 vf width 1 ssm dim 64 ssm blocks 8 dt0 1 solver None stepsize controller None scale 1 lambd 0.0 seed 1234
ncde EigenWorms T 1.00 time True nsteps 10000 lr 0.0001 num blocks 1 hidden dim 128 vf depth 3 vf width 128 ssm dim 128 ssm blocks 1 dt0 5.560189046427579e-05 solver Heun stepsize controller ConstantStepSize scale 1 lambd 0 seed 1234
ncde EthanolConcentration T 1.00 time True nsteps 10000 lr 1e-05 num blocks 1 hidden dim 16 vf depth 3 vf width 64 ssm dim 128 ssm blocks 1 dt0 0.0005707762557077625 solver Heun stepsize controller ConstantStepSize scale 1 lambd 0 seed 1234
ncde Heartbeat T 1.00 time False nsteps 10000 lr 0.001 num blocks 1 hidden dim 16 vf depth 2 vf width 32 ssm dim 128 ssm blocks 1 dt0 0.002 solver Heun stepsize controller ConstantStepSize scale 1 lambd 0 seed 1234
ncde MotorImagery T 1.00 time False nsteps 10000 lr 0.0001 num blocks 1 hidden dim 128 vf depth 4 vf width 128 ssm dim 128 ssm blocks 1 dt0 0.0003332222592469177 solver Heun stepsize controller ConstantStepSize scale 1 lambd 0.0 seed 1234
ncde SelfRegulationSCP1 T 1.00 time False nsteps 10000 lr 0.0001 num blocks 1 hidden dim 16 vf depth 3 vf width 128 ssm dim 128 ssm blocks 1 dt0 0.0011148272017837235 solver Heun stepsize controller ConstantStepSize scale 1 lambd 0 seed 1234
ncde SelfRegulationSCP2 T 1.00 time True nsteps 10000 lr 1e-05 num blocks 1 hidden dim 64 vf depth 4 vf width 128 ssm dim 128 ssm blocks 1 dt0 0.0008673026886383347 solver Heun stepsize controller ConstantStepSize scale 1 lambd 0 seed 1234
ssm EigenWorms T 1.00 time True nsteps 100000 lr 0.001 num blocks 2 hidden dim 128 vf depth 1 vf width 1 ssm dim 256 ssm blocks 8 dt0 1 solver Heun stepsize controller ConstantStepSize scale 1 lambd 0 seed 1234
ssm EthanolConcentration T 1.00 time True nsteps 100000 lr 0.001 num blocks 6 hidden dim 16 vf depth 1 vf width 1 ssm dim 16 ssm blocks 2 dt0 1 solver None stepsize controller None scale 1 lambd 0.0 seed 1234
ssm Heartbeat T 1.00 time False nsteps 100000 lr 0.0001 num blocks 2 hidden dim 64 vf depth 1 vf width 1 ssm dim 16 ssm blocks 4 dt0 1 solver None stepsize controller None scale 1 lambd 0.0 seed 1234
ssm MotorImagery T 1.00 time True nsteps 100000 lr 0.0001 num blocks 4 hidden dim 128 vf depth 1 vf width 1 ssm dim 256 ssm blocks 4 dt0 1 solver None stepsize controller None scale 1 lambd 0.0 seed 1234
ssm SelfRegulationSCP1 T 1.00 time False nsteps 100000 lr 0.0001 num blocks 4 hidden dim 16 vf depth 1 vf width 1 ssm dim 64 ssm blocks 8 dt0 1 solver None stepsize controller None scale 1 lambd 0.0 seed 1234
ssm SelfRegulationSCP2 T 1.00 time True nsteps 100000 lr 0.0001 num blocks 2 hidden dim 64 vf depth 1 vf width 1 ssm dim 256 ssm blocks 8 dt0 1 solver None stepsize controller None scale 1 lambd 0.0 seed 1234
nrde EigenWorms T 1.00 time True nsteps 10000 lr 0.001 stepsize 16.00 depth 2 num blocks 1 hidden dim 16 vf depth 4 vf width 128 ssm dim 128 ssm blocks 1 dt0 0.0008888888888888889 solver Heun stepsize controller ConstantStepSize scale 1 lambd 0 seed 1234
nrde EthanolConcentration T 1.00 time False nsteps 100000 lr 0.0001 stepsize 1.00 depth 1 num blocks 1 hidden dim 128 vf depth 3 vf width 128 ssm dim 128 ssm blocks 1 dt0 0.0005707762557077625 solver Heun stepsize controller ConstantStepSize scale 1 lambd 0.0 seed 1234
nrde Heartbeat T 1.00 time False nsteps 100000 lr 0.0001 stepsize 2.00 depth 2 num blocks 1 hidden dim 128 vf depth 4 vf width 128 ssm dim 128 ssm blocks 1 dt0 0.002 solver Heun stepsize controller ConstantStepSize scale 1 lambd 0.0 seed 1234
nrde MotorImagery T 1.00 time False nsteps 100000 lr 0.0001 stepsize 1.00 depth 1 num blocks 1 hidden dim 128 vf depth 4 vf width 128 ssm dim 128 ssm blocks 1 dt0 0.0003332222592469177 solver Heun stepsize controller ConstantStepSize scale 1 lambd 0.0 seed 1234
nrde SelfRegulationSCP1 T 1.00 time False nsteps 100000 lr 0.001 stepsize 4.00 depth 2 num blocks 1 hidden dim 16 vf depth 2 vf width 32 ssm dim 128 ssm blocks 1 dt0 0.002 solver Heun stepsize controller ConstantStepSize scale 1 lambd 0.0 seed 1234
nrde SelfRegulationSCP2 T 1.00 time False nsteps 100000 lr 0.001 stepsize 1.00 depth 1 num blocks 1 hidden dim 16 vf depth 3 vf width 128 ssm dim 128 ssm blocks 1 dt0 0.0008673026886383347 solver Heun stepsize controller ConstantStepSize scale 1 lambd 0.0 seed 1234
log_ncde EigenWorms T 1.00 time False nsteps 10000 lr 0.0001 stepsize 12.00 depth 2 num blocks 1 hidden dim 64 vf depth 4 vf width 128 ssm dim 128 ssm blocks 1 dt0 0.00066711140760507 solver Heun stepsize controller ConstantStepSize scale 1000 lambd 1e-06 seed 1234
log_ncde EthanolConcentration T 1.00 time True nsteps 10000 lr 0.0001 stepsize 2.00 depth 2 num blocks 1 hidden dim 64 vf depth 3 vf width 128 ssm dim 128 ssm blocks 1 dt0 0.001141552511415525 solver Heun stepsize controller ConstantStepSize scale 1000 lambd 0 seed 1234
log_ncde Heartbeat T 1.00 time True nsteps 10000 lr 0.001 stepsize 2.00 depth 2 num blocks 1 hidden dim 16 vf depth 2 vf width 32 ssm dim 128 ssm blocks 1 dt0 0.002 solver Heun stepsize controller ConstantStepSize scale 1000 lambd 0 seed 1234
log_ncde MotorImagery T 1.00 time True nsteps 10000 lr 0.0001 stepsize 4.00 depth 2 num blocks 1 hidden dim 64 vf depth 3 vf width 128 ssm dim 128 ssm blocks 1 dt0 0.0013315579227696406 solver Heun stepsize controller ConstantStepSize scale 1000 lambd 1e-06 seed 1234
log_ncde SelfRegulationSCP1 T 1.00 time False nsteps 10000 lr 0.001 stepsize 12.00 depth 2 num blocks 1 hidden dim 128 vf depth 3 vf width 64 ssm dim 128 ssm blocks 1 dt0 0.002 solver Heun stepsize controller ConstantStepSize scale 1000 lambd 0.001 seed 1234
log_ncde SelfRegulationSCP2 T 1.00 time True nsteps 10000 lr 0.001 stepsize 8.00 depth 2 num blocks 1 hidden dim 64 vf depth 3 vf width 128 ssm dim 128 ssm blocks 1 dt0 0.002 solver Heun stepsize controller ConstantStepSize scale 1000 lambd 0.001 seed 1234
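
Each line of best_hyperparameters.txt above is a whitespace-separated record: model name, dataset name, then alternating flag names and values, where some flag names span two tokens (e.g. "num blocks", "stepsize controller"). A minimal parsing sketch, inferred from the listing itself rather than taken from the repository:

MULTIWORD_KEYS = {
    ("num", "blocks"), ("hidden", "dim"), ("vf", "depth"), ("vf", "width"),
    ("ssm", "dim"), ("ssm", "blocks"), ("stepsize", "controller"),
}


def parse_best_hyperparameters(line):
    # e.g. "lru EigenWorms T 1.00 time True nsteps 100000 lr 0.001 ..."
    tokens = line.split()
    record = {"model": tokens[0], "dataset": tokens[1]}
    i = 2
    while i < len(tokens):
        # Two-token flag names are joined with an underscore, e.g. "num_blocks".
        if i + 2 < len(tokens) and (tokens[i], tokens[i + 1]) in MULTIWORD_KEYS:
            record[tokens[i] + "_" + tokens[i + 1]] = tokens[i + 2]
            i += 3
        else:
            record[tokens[i]] = tokens[i + 1]
            i += 2
    return record


with open("best_hyperparameters.txt") as f:
    configs = [parse_best_hyperparameters(line) for line in f if line.strip()]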
10 changes: 5 additions & 5 deletions conda_requirements.txt
@@ -1,5 +1,5 @@
pre-commit=3.3.1
sktime=0.17.2
jaxlib=0.4.7
jax=0.4.9
tqdm=4.65.0
pre-commit
sktime
jaxlib
jax
tqdm
197 changes: 135 additions & 62 deletions data/datasets_LogNCDE.py
@@ -1,3 +1,4 @@
import math
import os
import pickle
from dataclasses import dataclass
@@ -95,7 +96,6 @@ def dataset_generator(
include_time,
T,
inmemory=True,
coeffs_needed=True,
idxs=None,
use_presplit=False,
*,
@@ -125,46 +125,70 @@ def dataset_generator(
val_data, val_labels = data[idxs[1]], labels[idxs[1]]
test_data, test_labels = None, None

zero_channels_out = False

if zero_channels_out:
n_channels = max(1, math.floor(0.15 * train_data.shape[-1]))

train_key, val_key, test_key, key = jr.split(key, 4)
train_all_zero = jr.choice(
train_key,
jnp.arange(1, train_data.shape[2]),
shape=(len(train_data), n_channels),
)
val_all_zero = jr.choice(
val_key,
jnp.arange(1, val_data.shape[2]),
shape=(len(val_data), n_channels),
)

for i in range(len(train_data)):
train_data = train_data.at[i, :, train_all_zero[i]].set(0)
for i in range(len(val_data)):
val_data = val_data.at[i, :, val_all_zero[i]].set(0)

if test_data is not None:
test_all_zero = jr.choice(
test_key,
jnp.arange(1, test_data.shape[2]),
shape=(len(test_data), n_channels),
)
for i in range(len(test_data)):
test_data = test_data.at[i, :, test_all_zero[i]].set(0)

train_paths = batch_calc_paths(train_data, stepsize, depth)
val_paths = batch_calc_paths(val_data, stepsize, depth)
test_paths = batch_calc_paths(test_data, stepsize, depth)
intervals = jnp.arange(0, train_data.shape[1], stepsize)
intervals = jnp.concatenate((intervals, jnp.array([train_data.shape[1]])))
intervals = intervals * (T / train_data.shape[1])

if coeffs_needed:
train_coeffs = calc_coeffs(train_data, include_time, T)
val_coeffs = calc_coeffs(val_data, include_time, T)
test_coeffs = calc_coeffs(test_data, include_time, T)
train_coeff_data = (
(T / train_data.shape[1])
* jnp.repeat(
jnp.arange(train_data.shape[1])[None, :], train_data.shape[0], axis=0
),
train_coeffs,
train_data[:, 0, :],
)
val_coeff_data = (
(T / val_data.shape[1])
train_coeffs = calc_coeffs(train_data, include_time, T)
val_coeffs = calc_coeffs(val_data, include_time, T)
test_coeffs = calc_coeffs(test_data, include_time, T)
train_coeff_data = (
(T / train_data.shape[1])
* jnp.repeat(
jnp.arange(train_data.shape[1])[None, :], train_data.shape[0], axis=0
),
train_coeffs,
train_data[:, 0, :],
)
val_coeff_data = (
(T / val_data.shape[1])
* jnp.repeat(jnp.arange(val_data.shape[1])[None, :], val_data.shape[0], axis=0),
val_coeffs,
val_data[:, 0, :],
)
if idxs is None:
test_coeff_data = (
(T / test_data.shape[1])
* jnp.repeat(
jnp.arange(val_data.shape[1])[None, :], val_data.shape[0], axis=0
jnp.arange(test_data.shape[1])[None, :], test_data.shape[0], axis=0
),
val_coeffs,
val_data[:, 0, :],
test_coeffs,
test_data[:, 0, :],
)
if idxs is None:
test_coeff_data = (
(T / test_data.shape[1])
* jnp.repeat(
jnp.arange(test_data.shape[1])[None, :], test_data.shape[0], axis=0
),
test_coeffs,
test_data[:, 0, :],
)
else:
train_coeff_data = None
val_coeff_data = None
test_coeff_data = None

train_path_data = (
(T / train_data.shape[1])
@@ -226,42 +250,62 @@ def dataset_generator(


def create_uea_dataset(
data_dir, name, use_idxs, use_presplit, stepsize, depth, include_time, T, *, key
data_dir,
name,
use_idxs,
use_presplit,
stepsize,
depth,
include_time,
T,
seed,
*,
key,
):

if use_presplit:
idxs = None
with open(data_dir + f"/processed/UEA/{name}/X_train.pkl", "rb") as f:
with open(data_dir + f"/processed/UEA/{name}/{seed}/X_train.pkl", "rb") as f:
train_data = pickle.load(f)
with open(data_dir + f"/processed/UEA/{name}/y_train.pkl", "rb") as f:
with open(data_dir + f"/processed/UEA/{name}/{seed}/y_train.pkl", "rb") as f:
train_labels = pickle.load(f)
with open(data_dir + f"/processed/UEA/{name}/X_val.pkl", "rb") as f:
with open(data_dir + f"/processed/UEA/{name}/{seed}/X_val.pkl", "rb") as f:
val_data = pickle.load(f)
with open(data_dir + f"/processed/UEA/{name}/y_val.pkl", "rb") as f:
with open(data_dir + f"/processed/UEA/{name}/{seed}/y_val.pkl", "rb") as f:
val_labels = pickle.load(f)
with open(data_dir + f"/processed/UEA/{name}/X_test.pkl", "rb") as f:
with open(data_dir + f"/processed/UEA/{name}/{seed}/X_test.pkl", "rb") as f:
test_data = pickle.load(f)
with open(data_dir + f"/processed/UEA/{name}/y_test.pkl", "rb") as f:
with open(data_dir + f"/processed/UEA/{name}/{seed}/y_test.pkl", "rb") as f:
test_labels = pickle.load(f)
onehot_train_labels = jnp.zeros(
(len(train_labels), len(jnp.unique(train_labels)))
)
onehot_train_labels = onehot_train_labels.at[
jnp.arange(len(train_labels)), train_labels
].set(1)
onehot_val_labels = jnp.zeros((len(val_labels), len(jnp.unique(val_labels))))
onehot_val_labels = onehot_val_labels.at[
jnp.arange(len(val_labels)), val_labels
].set(1)
onehot_test_labels = jnp.zeros((len(test_labels), len(jnp.unique(test_labels))))
onehot_test_labels = onehot_test_labels.at[
jnp.arange(len(test_labels)), test_labels
].set(1)
if include_time:
ts = (T / train_data.shape[1]) * jnp.repeat(
jnp.arange(train_data.shape[1])[None, :], train_data.shape[0], axis=0
)
train_data = jnp.concatenate([ts[:, :, None], train_data[:, :, 1:]], axis=2)
train_data = jnp.concatenate([ts[:, :, None], train_data], axis=2)
ts = (T / val_data.shape[1]) * jnp.repeat(
jnp.arange(val_data.shape[1])[None, :], val_data.shape[0], axis=0
)
val_data = jnp.concatenate([ts[:, :, None], val_data[:, :, 1:]], axis=2)
val_data = jnp.concatenate([ts[:, :, None], val_data], axis=2)
ts = (T / test_data.shape[1]) * jnp.repeat(
jnp.arange(test_data.shape[1])[None, :], test_data.shape[0], axis=0
)
test_data = jnp.concatenate([ts[:, :, None], test_data[:, :, 1:]], axis=2)
else:
train_data = train_data[:, :, 1:]
val_data = val_data[:, :, 1:]
test_data = test_data[:, :, 1:]
test_data = jnp.concatenate([ts[:, :, None], test_data], axis=2)
data = (train_data, val_data, test_data)
onehot_labels = (train_labels, val_labels, test_labels)
onehot_labels = (onehot_train_labels, onehot_val_labels, onehot_test_labels)
else:
with open(data_dir + f"/processed/UEA/{name}/data.pkl", "rb") as f:
data = pickle.load(f)
@@ -410,7 +454,7 @@ def create_toy_dataset(data_dir, stepsize, depth, include_time, T, *, key):
def create_speech_dataset(data_dir, stepsize, depth, include_time, T, *, key):
data = []
labels = []
for i in range(10):
for i in range(2):
data.append(np.load(data_dir + f"/processed/speech/data_{i}.npy"))
labels.append(np.load(data_dir + f"/processed/speech/labels_{i}.npy"))
data = np.concatenate(data)
@@ -425,30 +469,47 @@ def create_speech_dataset(data_dir, stepsize, depth, include_time, T, *, key):
include_time,
T,
inmemory=False,
coeffs_needed=False,
use_presplit=False,
key=key,
)


def create_ppg_dataset(data_dir, stepsize, depth, include_time, T, *, key):
with open(data_dir + "/processed/PPG/X_train.pkl", "rb") as f:
def create_ppg_dataset(
data_dir, use_presplit, stepsize, depth, include_time, T, *, key
):
with open(data_dir + "/processed/PPG/ppg/X_train.pkl", "rb") as f:
train_data = pickle.load(f)
with open(data_dir + "/processed/PPG/y_train.pkl", "rb") as f:
with open(data_dir + "/processed/PPG/ppg/y_train.pkl", "rb") as f:
train_labels = pickle.load(f)
with open(data_dir + "/processed/PPG/X_val.pkl", "rb") as f:
with open(data_dir + "/processed/PPG/ppg/X_val.pkl", "rb") as f:
val_data = pickle.load(f)
with open(data_dir + "/processed/PPG/y_val.pkl", "rb") as f:
with open(data_dir + "/processed/PPG/ppg/y_val.pkl", "rb") as f:
val_labels = pickle.load(f)
with open(data_dir + "/processed/PPG/X_test.pkl", "rb") as f:
with open(data_dir + "/processed/PPG/ppg/X_test.pkl", "rb") as f:
test_data = pickle.load(f)
with open(data_dir + "/processed/PPG/y_test.pkl", "rb") as f:
with open(data_dir + "/processed/PPG/ppg/y_test.pkl", "rb") as f:
test_labels = pickle.load(f)

breakpoint()
if include_time:
ts = (T / train_data.shape[1]) * jnp.repeat(
jnp.arange(train_data.shape[1])[None, :], train_data.shape[0], axis=0
)
train_data = jnp.concatenate([ts[:, :, None], train_data], axis=2)
ts = (T / val_data.shape[1]) * jnp.repeat(
jnp.arange(val_data.shape[1])[None, :], val_data.shape[0], axis=0
)
val_data = jnp.concatenate([ts[:, :, None], val_data], axis=2)
ts = (T / test_data.shape[1]) * jnp.repeat(
jnp.arange(test_data.shape[1])[None, :], test_data.shape[0], axis=0
)
test_data = jnp.concatenate([ts[:, :, None], test_data], axis=2)

data = (train_data, val_data, test_data)
labels = (train_labels, val_labels, test_labels)
if use_presplit:
data = (train_data, val_data, test_data)
labels = (train_labels, val_labels, test_labels)
else:
data = jnp.concatenate((train_data, val_data, test_data), axis=0)
labels = jnp.concatenate((train_labels, val_labels, test_labels), axis=0)

return dataset_generator(
"speech",
Expand All @@ -459,14 +520,23 @@ def create_ppg_dataset(data_dir, stepsize, depth, include_time, T, *, key):
include_time,
T,
inmemory=False,
coeffs_needed=False,
use_presplit=True,
use_presplit=use_presplit,
key=key,
)


def create_dataset(
data_dir, name, use_idxs, use_presplit, stepsize, depth, include_time, T, *, key
data_dir,
name,
use_idxs,
use_presplit,
stepsize,
depth,
include_time,
T,
seed,
*,
key,
):
uea_subfolders = [
f.name for f in os.scandir(data_dir + "/processed/UEA") if f.is_dir()
@@ -484,6 +554,7 @@ def create_dataset(
depth,
include_time,
T,
seed,
key=key,
)
elif name in lra_subfolders:
@@ -505,6 +576,8 @@ def create_dataset(
data_dir, stepsize, depth, include_time, T, key=key
)
elif name == "ppg":
return create_ppg_dataset(data_dir, stepsize, depth, include_time, T, key=key)
return create_ppg_dataset(
data_dir, use_presplit, stepsize, depth, include_time, T, key=key
)
else:
raise ValueError(f"Dataset {name} not found in UEA folder and not toy dataset")
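
The zero_channels_out block added to dataset_generator above zeroes out roughly 15% of the channels in every sample, chosen at random but never channel 0, which appears to be what the "One channel missing" / "Multiple missing channels" commits refer to. A self-contained JAX sketch of the same operation (the 0.15 fraction, the exclusion of channel 0, and the array layout follow the diff; the function name and toy data are illustrative):

import math

import jax.numpy as jnp
import jax.random as jr


def zero_out_channels(data, key, fraction=0.15):
    # data has shape (num_samples, num_timesteps, num_channels); channel 0 is
    # never selected, mirroring jnp.arange(1, ...) in the diff above.
    n_channels = max(1, math.floor(fraction * data.shape[-1]))
    zero_idxs = jr.choice(
        key, jnp.arange(1, data.shape[2]), shape=(len(data), n_channels)
    )
    for i in range(len(data)):
        data = data.at[i, :, zero_idxs[i]].set(0)
    return data


key = jr.PRNGKey(1234)
data = jr.normal(key, (8, 100, 7))  # toy batch: 8 samples, 100 steps, 7 channels
data = zero_out_channels(data, key)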
2 changes: 1 addition & 1 deletion data/download_data.py
@@ -32,7 +32,7 @@ def download_and_unzip(url, save_dir, zipname):


if __name__ == "__main__":
data_dir = "data"
data_dir = "data/math-datasig/shug6778/Log-Neural-CDEs/data"
url = (
"http://www.timeseriesclassification.com/ClassificationDownloads/Archives"
"/Multivariate2018_arff.zip"
1 change: 1 addition & 0 deletions data/process_uea.py
@@ -53,6 +53,7 @@ def convert_data(data):
def convert_all_files(data_dir):
"""Convert UEA files into jax data to be stored in /interim."""
arff_folder = data_dir + "/raw/UEA/Multivariate_arff"

for ds_name in tqdm(
[x for x in os.listdir(arff_folder) if os.path.isdir(arff_folder + "/" + x)]
):
1 change: 1 addition & 0 deletions models/LogNeuralCDEs.py
@@ -54,6 +54,7 @@ def __init__(
hidden_dim * data_dim,
vf_hidden_dim,
vf_num_hidden,
activation=jax.nn.silu,
scale=scale,
key=vf_key,
)
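
The single change to models/LogNeuralCDEs.py passes activation=jax.nn.silu when constructing the vector field network. The vector field class itself is not shown in this diff, so the following is a rough, hypothetical illustration only: in a Neural CDE the vector field maps the hidden state in R^hidden_dim to a matrix of size hidden_dim x data_dim, and an Equinox MLP configured with a SiLU activation could look like this (all sizes below are made-up examples):

import equinox as eqx
import jax
import jax.random as jr

key = jr.PRNGKey(0)
hidden_dim, data_dim = 4, 3
vf_hidden_dim, vf_num_hidden = 32, 2

# Hypothetical stand-in for the vector field: hidden state -> flattened
# (hidden_dim * data_dim) output, with SiLU between hidden layers.
vf = eqx.nn.MLP(
    in_size=hidden_dim,
    out_size=hidden_dim * data_dim,
    width_size=vf_hidden_dim,
    depth=vf_num_hidden,
    activation=jax.nn.silu,
    key=key,
)

y = vf(jr.normal(key, (hidden_dim,)))  # shape (hidden_dim * data_dim,)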