A very lightweight framework on top of PyTorch with full functionality.

Just one way of doing things means no learning curve.


Installation

  1. pip install --upgrade pip
  2. Install the latest PyTorch and torchvision from the PyTorch website
  3. pip install easytorch

Let's start with something simple like MNIST digit classification:

from easytorch import EasyTorch, ETRunner, ConfusionMatrix, ETMeter
from torchvision import datasets, transforms
import torch.nn.functional as F
import torch
from examples.models import MNISTNet

transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])


class MNISTTrainer(ETRunner):
    def _init_nn_model(self):
        self.nn['model'] = MNISTNet()

    def iteration(self, batch):
        inputs, labels = batch[0].to(self.device['gpu']).float(), batch[1].to(self.device['gpu']).long()

        out = self.nn['model'](inputs)
        loss = F.nll_loss(out, labels)
        _, pred = torch.max(out, 1)

        meter = self.new_meter()
        meter.averages.add(loss.item(), len(inputs))
        meter.metrics['cfm'].add(pred, labels.float())

        return {'loss': loss, 'meter': meter, 'predictions': pred}

    def init_experiment_cache(self):
        self.cache['log_header'] = 'Loss|Accuracy,F1,Precision,Recall'
        self.cache.update(monitor_metric='f1', metric_direction='maximize')

    def new_meter(self):
        return ETMeter(
            cfm=ConfusionMatrix(num_classes=10),
            device=self.device['gpu']
        )


if __name__ == "__main__":
    train_dataset = datasets.MNIST('../data', train=True, download=True, transform=transform)
    val_dataset = datasets.MNIST('../data', train=False, transform=transform)

    dataloader_args = {'train': {'dataset': train_dataset}, 'validation': {'dataset': val_dataset}}
    runner = EasyTorch(phase='train', batch_size=512,
                       epochs=10, gpus=[0], dataloader_args=dataloader_args)
    runner.run(MNISTTrainer)

Run as:

python script.py -ph train -b 512 -e 10 -gpus 0

More than 20 other useful options are available. Check here for the full list.
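As a rough illustration of how the flags in the command above map to keyword arguments, here is a minimal argparse sketch. It is not EasyTorch's own parser: the short flags (-ph, -b, -e, -gpus) come from the command above, while the long option names, the defaults, and the helper names build_parser and parse_cli are assumptions made for this example.

```python
import argparse


def build_parser():
    """Hypothetical parser mirroring the flags used in the command above."""
    parser = argparse.ArgumentParser(description="Sketch of an EasyTorch-style CLI")
    # Short flags are taken from the README; long names and defaults are assumptions.
    parser.add_argument("-ph", "--phase", choices=["train", "test"], required=True)
    parser.add_argument("-b", "--batch-size", type=int, default=32)
    parser.add_argument("-e", "--epochs", type=int, default=1)
    parser.add_argument("-gpus", "--gpus", type=int, nargs="*", default=[])
    return parser


def parse_cli(argv):
    """Parse a list of command-line tokens into a plain dict of keyword arguments."""
    return vars(build_parser().parse_args(argv))


if __name__ == "__main__":
    print(parse_cli(["-ph", "train", "-b", "512", "-e", "10", "-gpus", "0"]))
```

The resulting dict has the same shape as the keyword arguments passed to EasyTorch(...) in the script above (phase, batch_size, epochs, gpus).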


General use case:

1. Define your trainer

from easytorch import ETRunner, Prf1a, ETMeter, AUCROCMetrics


class MyTrainer(ETRunner):

    def _init_nn_model(self):
        self.nn['model'] = NeuralNetModel(out_size=self.conf['num_class'])

    def iteration(self, batch):
        """Handle a single batch. The returned dict must contain 'loss' and 'meter'."""
        meter = self.new_meter()
        ...
        return {'loss': ..., 'meter': ..., 'predictions': ...}

    def new_meter(self):
        return ETMeter(
            num_averages=1,
            prf1a=Prf1a(),
            auc=AUCROCMetrics(),
            device=self.device['gpu']
        )

    def init_cache(self):
        # Plot Loss in one plot, and Accuracy and F1_score in another.
        self.cache['log_header'] = 'Loss|Accuracy,F1_score'

        # Select the best model using the validation set, if present.
        self.cache.update(monitor_metric='f1', metric_direction='maximize')

  • new_meter() returns an ETMeter, which accepts any implementation of easytorch.meter.ETMetrics. Provided ones:
    • easytorch.metrics.Prf1a() for binary classification; computes accuracy, F1, precision, recall, and overlap (IoU).
    • easytorch.metrics.ConfusionMatrix(num_classes=...) for multiclass classification; also computes global accuracy, F1, precision, and recall.
    • easytorch.metrics.AUCROCMetrics() for the binary ROC-AUC score.
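
To make the binary-classification quantities listed above concrete, here is a minimal pure-Python sketch of how accuracy, precision, recall, F1, and overlap (IoU) follow from the true/false positive and negative counts. It is not Prf1a's implementation: binary_scores, safe_div, and the choice to return 0.0 for an empty denominator are hypothetical and used for illustration only.

```python
def safe_div(numerator, denominator):
    """Return 0.0 instead of raising when the denominator is zero (an assumption)."""
    return numerator / denominator if denominator else 0.0


def binary_scores(predictions, labels):
    """Compute standard binary metrics from parallel lists of 0/1 values."""
    pairs = list(zip(predictions, labels))
    tp = sum(1 for p, y in pairs if p == 1 and y == 1)  # true positives
    fp = sum(1 for p, y in pairs if p == 1 and y == 0)  # false positives
    fn = sum(1 for p, y in pairs if p == 0 and y == 1)  # false negatives
    tn = sum(1 for p, y in pairs if p == 0 and y == 0)  # true negatives

    precision = safe_div(tp, tp + fp)
    recall = safe_div(tp, tp + fn)
    return {
        "accuracy": safe_div(tp + tn, tp + fp + fn + tn),
        "precision": precision,
        "recall": recall,
        "f1": safe_div(2 * precision * recall, precision + recall),
        "overlap": safe_div(tp, tp + fp + fn),  # intersection over union
    }


if __name__ == "__main__":
    print(binary_scores([1, 1, 0, 0, 1], [1, 0, 0, 1, 1]))
```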

2. Define specification for your datasets:

  • EasyTorch automatically splits the training data in data_source as specified by split_ratio (-spl or --split-ratio 0.7, 0.15, 0.15 for the train, validation, and test portions), or uses custom splits given as:
    1. Text files:
      • data_source = "/some/path/*.txt", where it looks for 'train.txt', 'validation.txt', and 'test.txt' if phase is train, and only 'test.txt' if phase is test.
    2. JSON files:
      • data_source = "some/path/split.json", where each split key holds a list of files, as in {'train': [], 'validation': [], 'test': []}.
    3. A plain glob such as data_source = "some/path/**/*.txt"; split_ratio must also be provided if phase is train.
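
For the JSON option, a split file with the three keys described above could be produced along these lines. This is a minimal sketch, not EasyTorch's own splitting logic: split_files, the seeded shuffle, and the rounding of partition sizes are assumptions made for illustration.

```python
import json
import random


def split_files(files, ratio=(0.7, 0.15, 0.15), seed=0):
    """Partition file names into train/validation/test lists by ratio."""
    if abs(sum(ratio) - 1.0) > 1e-9:
        raise ValueError("split ratio must sum to 1")

    files = sorted(files)  # fixed starting order so the shuffle is reproducible
    random.Random(seed).shuffle(files)

    n_train = round(len(files) * ratio[0])
    n_val = round(len(files) * ratio[1])
    return {
        "train": files[:n_train],
        "validation": files[n_train:n_train + n_val],
        "test": files[n_train + n_val:],  # whatever remains goes to test
    }


if __name__ == "__main__":
    names = [f"img_{i:03d}.png" for i in range(100)]
    splits = split_files(names)
    # Serialize with the same keys as the split.json layout described above.
    print(json.dumps({key: len(value) for key, value in splits.items()}))
```

Writing json.dumps(splits) to a file gives a document with the 'train', 'validation', and 'test' keys that data_source can then point to.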

Then subclass ETDataset to index and load the files:

from easytorch import ETDataset


class MyDataset(ETDataset):
    def load_index(self, file):
        """(Optional) Load/process something and add it to the disk cache as:
                self.diskcache.add(file, value)
        This method runs in multiple processes by default."""
        self.indices.append([file, 'something_extra'])

    def __getitem__(self, index):
        # Each entry is the [file, extra] pair appended in load_index().
        file, extra = self.indices[index]
        # (Optional) Retrieve the cached value as self.diskcache.get(file).

        image = ...  # TODO: load the file/image.
        label = ...  # TODO: load the corresponding label.

        # Extra preprocessing, if needed.
        # Apply transforms, if needed.

        return image, label

3. Entry point (say main.py)

Run as:

python main.py -ph train -b 512 -e 10 -gpus 0

Arguments can also be passed directly to EasyTorch, as below; these override everything else.

from easytorch import EasyTorch

runner = EasyTorch(phase="train", batch_size=4, epochs=10,
                   gpus=[0], num_channel=1, 
                   num_class=2, data_source="<some_data>/data_split.json")
runner.run(MyTrainer, MyDataset)

All the best! Cheers! 🎉

Cite the following papers if you use this library:

@article{deepdyn_10.3389/fcomp.2020.00035,
	title        = {Dynamic Deep Networks for Retinal Vessel Segmentation},
	author       = {Khanal, Aashis and Estrada, Rolando},
	year         = 2020,
	journal      = {Frontiers in Computer Science},
	volume       = 2,
	pages        = 35,
	doi          = {10.3389/fcomp.2020.00035},
	issn         = {2624-9898}
}

@misc{2202.02382,
	title        = {Fully Automated Tree Topology Estimation and Artery-Vein Classification},
	author       = {Khanal, Aashis and Motevali, Saeid and Estrada, Rolando},
	year         = 2022,
	eprint       = {arXiv:2202.02382}
}

Feature Highlights:

  • A data handle (ETDataHandle) that is always available and decoupled from other modules, enabling easy customization.
    • Use custom and complex data handling mechanisms.
  • Simple, lightweight logger/plotter.
    • Plot: set log_header = 'Loss,F1,Accuracy' to draw all three in the same plot, or set log_header = 'Loss|F1,Accuracy' to draw Loss in one plot and F1 and Accuracy in another.
    • Logs: all arguments and generated data are saved to a logs.json file after the experiment finishes.
  • Gradient accumulation, automatic logging/plotting, and model checkpointing; everything is saved.
  • Multiple metric implementations in easytorch.metrics: Precision, Recall, Accuracy, Overlap, F1, ROC-AUC, and Confusion matrix.
  • For advanced training with multiple networks and complex training steps, click here:
  • Implement custom metrics as here.
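
The log_header convention described in the plotting bullet ('|' separates plots, ',' separates the series within one plot) can be read as a simple two-level split. The sketch below only illustrates that convention and is not the library's parser; parse_log_header is a hypothetical helper.

```python
def parse_log_header(log_header):
    """Group series names by plot: '|' separates plots, ',' separates series."""
    return [
        [name.strip() for name in group.split(",") if name.strip()]
        for group in log_header.split("|")
    ]


if __name__ == "__main__":
    # One plot holding three series.
    print(parse_log_header("Loss,F1,Accuracy"))
    # Two plots: Loss alone, then F1 and Accuracy together.
    print(parse_log_header("Loss|F1,Accuracy"))
```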