Skip to content

Commit

Permalink
Minimal test for fiboa model fit
Browse files Browse the repository at this point in the history
  • Loading branch information
m-mohr committed Oct 29, 2024
1 parent e094d41 commit 43cbed1
Show file tree
Hide file tree
Showing 2 changed files with 76 additions and 2 deletions.
45 changes: 45 additions & 0 deletions src/tests/data-files/min_config.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
trainer:
max_epochs: 1
log_every_n_steps: 5
accelerator: cpu
default_root_dir: logs/FTW-CI
devices: 2
strategy: ddp_find_unused_parameters_true
callbacks:
- class_path: lightning.pytorch.callbacks.ModelCheckpoint
init_args:
monitor: val_loss
mode: min
save_top_k: 0
save_last: true
filename: "{epoch}-{val_loss:.2f}"
model:
class_path: ftw.trainers.CustomSemanticSegmentationTask
init_args:
class_weights: [0.31,0.68]
loss: "ce"
model: "unet"
backbone: "efficientnet-b3"
weights: true
patch_weights : false
in_channels: 8
num_classes: 2
num_filters: 64
ignore_index: 3
lr: 1e-3
patience: 100
data:
class_path: ftw.datamodules.FTWDataModule
init_args:
batch_size: 32
num_workers: 4
train_countries:
- rwanda
val_countries:
- rwanda
test_countries:
- rwanda
dict_kwargs:
root: data/ftw
load_boundaries: false
seed_everything: 7
33 changes: 31 additions & 2 deletions src/tests/test_model.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
import os

from click.testing import CliRunner

from ftw_cli.cli import model_fit, model_test
from ftw_cli.cli import model_fit, model_test, data_download

CKPT_FILE = "logs/FTW-CI/lightning_logs/version_0/checkpoints/last.ckpt"
CONFIG_FILE = "src/tests/data-files/min_config.yaml"

def test_model_fit():
runner = CliRunner()
Expand All @@ -12,7 +15,19 @@ def test_model_fit():
assert result.exit_code == 0, result.output
assert "Usage: fit [OPTIONS] [CLI_ARGS]..." in result.output

# TODO: Add more tests
# Download required data for the fit command
runner.invoke(data_download, ["--countries=Kenya,Rwanda"])
assert os.path.exists("data/ftw/kenya")
assert os.path.exists("data/ftw/rwanda")
assert os.path.exists(CONFIG_FILE)

# Run minimal fit
result = runner.invoke(model_fit, ["-c", CONFIG_FILE])
assert result.exit_code == 0, result.output
assert "Train countries: ['kenya', 'rwanda']" in result.output
assert "Epoch 0: 100%|" in result.output
assert "`Trainer.fit` stopped: `max_epochs=1` reached." in result.output
assert os.path.exists(CKPT_FILE)

def test_model_test():
runner = CliRunner()
Expand All @@ -22,4 +37,18 @@ def test_model_test():
assert result.exit_code == 0, result.output
assert "Usage: test [OPTIONS] [CLI_ARGS]..." in result.output

# Actually run the test
result = runner.invoke(model_test, [
"--gpu", "0",
"--model", CKPT_FILE,
"--countries", "Kenya", # should be "kenya", but let's test case insensitivity
"--out", "results.csv"
])
assert result.exit_code == 0, result.output
assert "Running test command" in result.output
assert "Created dataloader" in result.output
assert "100%|" in result.output
assert "Object level recall: 0.0000" in result.output
assert os.path.exists("results.csv")

# TODO: Add more tests

0 comments on commit 43cbed1

Please sign in to comment.