-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add test and several workflow regarding tests
- smoke test for end to end training test - metric test
- Loading branch information
Showing
7 changed files
with
402 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,16 @@ | ||
name: io | ||
on: [push] | ||
|
||
jobs: | ||
io: | ||
runs-on: ubuntu-latest | ||
steps: | ||
- name: set-var | ||
id: set-var | ||
shell: python | ||
run: | | ||
import os | ||
with open(os.environ['GITHUB_OUTPUT'], 'a') as f: | ||
print(f'VAL=true', file=f) | ||
- name: print value | ||
run: echo "the value is ${{ steps.set-var.outputs.VAL }}" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,11 @@ | ||
import redis | ||
|
||
rc = redis.Redis(host='localhost', port=6379) | ||
|
||
# [[host=127.0.0.1,port=16379,name=127.0.0.1:16379,server_type=primary,redis_connection=Redis<ConnectionPool<Connection<host=127.0.0.1,port=16379,db=0>>>], ... | ||
|
||
rc.set('foo', 'bar') | ||
# True | ||
|
||
rc.get('foo') | ||
# b'bar' |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
WANDB_PROJECT = "mlops-course-001" | ||
ENTITY = None | ||
BDD_CLASSES = {i:c for i,c in enumerate(['background', 'road', 'traffic light', 'traffic sign', 'person', 'vehicle', 'bicycle'])} | ||
RAW_DATA_AT = 'bdd_simple_1k' | ||
PROCESSED_DATA_AT = 'bdd_simple_1k_split' |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,39 @@ | ||
import torch | ||
import utils | ||
from fastai.vision.all import Learner | ||
import math | ||
|
||
#For testing: a fake learner and a metric that isn't an average | ||
class TstLearner(Learner): | ||
def __init__(self,dls=None,model=None,**kwargs): self.pred,self.xb,self.yb = None,None,None | ||
|
||
#Go through a fake cycle with various batch sizes and computes the value of met | ||
def compute_val(met, x1, x2): | ||
met.reset() | ||
vals = [0,6,15,20] | ||
learn = TstLearner() | ||
for i in range(3): | ||
learn.pred,learn.yb = x1[vals[i]:vals[i+1]],(x2[vals[i]:vals[i+1]],) | ||
met.accumulate(learn) | ||
return met.value | ||
|
||
def test_metrics(): | ||
x1a = torch.ones(20,1,1,1) # predicting background pixels | ||
x1b = torch.clone(x1a)*0.3 | ||
x1c = torch.clone(x1a)*0.1 | ||
x1 = torch.cat((x1a,x1b,x1c),dim=1) # Prediction: 20xClass0 | ||
x2 = torch.zeros(20,1,1) # Target: 20xClass0 | ||
|
||
assert compute_val(utils.BackgroundIOU(), x1, x2) == 1. | ||
road_iou = compute_val(utils.RoadIOU(), x1, x2) | ||
assert math.isnan(road_iou) | ||
|
||
|
||
x1b = torch.ones(20,1,1,1) # predicting road pixels | ||
x1a = torch.clone(x1a)*0.3 | ||
x1c = torch.clone(x1a)*0.1 | ||
x1 = torch.cat((x1a,x1b,x1c),dim=1) # Prediction: 20xClass1 | ||
x2 = torch.ones(20,1,1) # Target: 20xClass1 | ||
background_iou = compute_val(utils.BackgroundIOU(), x1, x2) | ||
assert math.isnan(background_iou) | ||
assert compute_val(utils.RoadIOU(), x1, x2) == 1. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,18 @@ | ||
import train | ||
from fastai.vision.all import SimpleNamespace | ||
|
||
def test_train(): | ||
default_config = SimpleNamespace( | ||
framework="fastai", | ||
img_size=30, # small size for the smoke test | ||
batch_size=5, # low bs to fit on CPU if needed | ||
augment=True, # use data augmentation | ||
epochs=1, | ||
lr=2e-3, | ||
pretrained=True, # whether to use pretrained encoder, | ||
mixed_precision=True, # use automatic mixed precision | ||
arch="resnet18", | ||
seed=42, | ||
log_preds=False, | ||
) | ||
train.train(default_config, nrows=20) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,125 @@ | ||
import argparse, os | ||
import wandb | ||
from pathlib import Path | ||
import torchvision.models as tvmodels | ||
import pandas as pd | ||
from fastai.vision.all import * | ||
from fastai.callback.wandb import WandbCallback | ||
|
||
import params | ||
from utils import get_predictions, create_iou_table, MIOU, BackgroundIOU, \ | ||
RoadIOU, TrafficLightIOU, TrafficSignIOU, PersonIOU, VehicleIOU, BicycleIOU, t_or_f | ||
# defaults | ||
default_config = SimpleNamespace( | ||
framework="fastai", | ||
img_size=180, #(180, 320) in 16:9 proportions, | ||
batch_size=8, #8 keep small in Colab to be manageable | ||
augment=True, # use data augmentation | ||
epochs=10, # for brevity, increase for better results :) | ||
lr=2e-3, | ||
pretrained=True, # whether to use pretrained encoder, | ||
mixed_precision=True, # use automatic mixed precision | ||
arch="resnet18", | ||
seed=42, | ||
log_preds=False, | ||
) | ||
|
||
|
||
def parse_args(): | ||
"Overriding default argments" | ||
argparser = argparse.ArgumentParser(description='Process hyper-parameters') | ||
argparser.add_argument('--img_size', type=int, default=default_config.img_size, help='image size') | ||
argparser.add_argument('--batch_size', type=int, default=default_config.batch_size, help='batch size') | ||
argparser.add_argument('--epochs', type=int, default=default_config.epochs, help='number of training epochs') | ||
argparser.add_argument('--lr', type=float, default=default_config.lr, help='learning rate') | ||
argparser.add_argument('--arch', type=str, default=default_config.arch, help='timm backbone architecture') | ||
argparser.add_argument('--augment', type=t_or_f, default=default_config.augment, help='Use image augmentation') | ||
argparser.add_argument('--seed', type=int, default=default_config.seed, help='random seed') | ||
argparser.add_argument('--log_preds', type=t_or_f, default=default_config.log_preds, help='log model predictions') | ||
argparser.add_argument('--pretrained', type=t_or_f, default=default_config.pretrained, help='Use pretrained model') | ||
argparser.add_argument('--mixed_precision', type=t_or_f, default=default_config.mixed_precision, help='use fp16') | ||
args = argparser.parse_args() | ||
vars(default_config).update(vars(args)) | ||
return | ||
|
||
def download_data(): | ||
"Grab dataset from artifact" | ||
processed_data_at = wandb.use_artifact(f'{params.PROCESSED_DATA_AT}:latest') | ||
processed_dataset_dir = Path(processed_data_at.download()) | ||
return processed_dataset_dir | ||
|
||
def label_func(fname): | ||
return (fname.parent.parent/"labels")/f"{fname.stem}_mask.png" | ||
|
||
def get_df(processed_dataset_dir, is_test=False): | ||
df = pd.read_csv(processed_dataset_dir / 'data_split.csv') | ||
|
||
if not is_test: | ||
df = df[df.Stage != 'test'].reset_index(drop=True) | ||
df['is_valid'] = df.Stage == 'valid' | ||
else: | ||
df = df[df.Stage == 'test'].reset_index(drop=True) | ||
|
||
|
||
# assign paths | ||
df["image_fname"] = [processed_dataset_dir/f'images/{f}' for f in df.File_Name.values] | ||
df["label_fname"] = [label_func(f) for f in df.image_fname.values] | ||
return df | ||
|
||
def get_data(df, bs=4, img_size=180, augment=True): | ||
block = DataBlock(blocks=(ImageBlock, MaskBlock(codes=params.BDD_CLASSES)), | ||
get_x=ColReader("image_fname"), | ||
get_y=ColReader("label_fname"), | ||
splitter=ColSplitter(), | ||
item_tfms=Resize((img_size, int(img_size * 16 / 9))), | ||
batch_tfms=aug_transforms() if augment else None, | ||
) | ||
return block.dataloaders(df, bs=bs) | ||
|
||
|
||
def log_predictions(learn): | ||
"Log a Table with model predictions and metrics" | ||
samples, outputs, predictions = get_predictions(learn) | ||
table = create_iou_table(samples, outputs, predictions, params.BDD_CLASSES) | ||
wandb.log({"pred_table":table}) | ||
|
||
def final_metrics(learn): | ||
"Log latest metrics values" | ||
scores = learn.validate() | ||
metric_names = ['final_loss'] + [f'final_{x.name}' for x in learn.metrics] | ||
final_results = {metric_names[i] : scores[i] for i in range(len(scores))} | ||
for k,v in final_results.items(): | ||
wandb.summary[k] = v | ||
|
||
def train(config, nrows=None): | ||
set_seed(config.seed) | ||
run = wandb.init(project=params.WANDB_PROJECT, entity=params.ENTITY, job_type="training", config=config) | ||
|
||
# good practice to inject params using sweeps | ||
config = wandb.config | ||
|
||
# prepare data | ||
processed_dataset_dir = download_data() | ||
proc_df = get_df(processed_dataset_dir).head(nrows) | ||
dls = get_data(proc_df, bs=config.batch_size, img_size=config.img_size, augment=config.augment) | ||
|
||
metrics = [MIOU(), BackgroundIOU(), RoadIOU(), TrafficLightIOU(), | ||
TrafficSignIOU(), PersonIOU(), VehicleIOU(), BicycleIOU()] | ||
|
||
cbs = [WandbCallback(log_preds=False, log_model=True), | ||
SaveModelCallback(fname=f'run-{wandb.run.id}-model', monitor='miou')] | ||
cbs += ([MixedPrecision()] if config.mixed_precision else []) | ||
|
||
learn = unet_learner(dls, arch=getattr(tvmodels, config.arch), pretrained=config.pretrained, | ||
metrics=metrics) | ||
|
||
learn.fit_one_cycle(config.epochs, config.lr, cbs=cbs) | ||
if config.log_preds: | ||
log_predictions(learn) | ||
final_metrics(learn) | ||
|
||
wandb.finish() | ||
|
||
if __name__ == '__main__': | ||
parse_args() | ||
train(default_config) |
Oops, something went wrong.