5. Main components
Before getting into the details of the specification format, it is useful to understand the inner workings of TorchAssistant and glance over its major components and concepts. Many of those (such as loss functions, optimizers, models, etc.) will look familiar, while others will not.
We are going to focus here only on the mechanics of the train.py script and what it does.
The major components are the (neural) batch processor, the processing graph, and the input injector.
A batch processor can be thought of as a computational unit that transforms a batch of incoming data. It can be any object that implements a certain interface, but typically it contains a number of models/neural nets connected in some way.
A processing graph is a graph whose nodes are batch processors. Each leaf node of the graph can participate in metric computation. In fact, different sets of metrics can be applied to different leaf nodes.
The input injector sends prepared pieces of a training batch to the right input nodes of the processing graph. Essentially, it wraps a number of decorated torch DataLoader instances and combines the values they generate on every iteration step.
The input injector does its work by delegating it to the data loaders. Each data loader does its part, then their results are aggregated and returned.
Behind the curtain, each data loader performs the following steps (sketched in code after the example below):
- load the next group of raw examples from a dataset and perform the necessary data transforms
- collate those examples into a batch (usually a tuple of tensors) and perform batch-level preprocessing (such as sequence padding) if necessary
- create a data frame by associating each item of the batch tuple with a name
For example, a batch
([[1, 0], [0, 0]], [[1], [0]])
becomes
{'input_1': [[1, 0], [0, 0]], 'input_2': [[1], [0]]}
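As a rough illustration of these steps, the per-loader logic could look like the snippet below. The helper names (collate, to_data_frame) are made up for illustration; a real data loader would also convert columns to tensors.

```python
# Sketch of what a single data loader iteration conceptually does.
def collate(examples):
    # turn a list of example tuples into a tuple of columns
    return tuple(list(column) for column in zip(*examples))

def to_data_frame(batch, column_names):
    # associate each item of the batch tuple with a name
    return dict(zip(column_names, batch))

examples = [([1, 0], [1]), ([0, 0], [0])]   # raw examples after data transforms
batch = collate(examples)                   # ([[1, 0], [0, 0]], [[1], [0]])
data_frame = to_data_frame(batch, ['input_1', 'input_2'])
# {'input_1': [[1, 0], [0, 0]], 'input_2': [[1], [0]]}
```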
As soon as all data loaders finish their iteration, the input injector aggregates their outputs into a single dictionary. Concretely, it maps each data frame to the right input node of the processing graph, e.g.:
{'graph_input_1': data_frame1, 'graph_input_2': data_frame2}
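Conceptually, the aggregation step could be sketched like this (the function and argument names are illustrative, not TorchAssistant's API):

```python
# Illustrative sketch: iterate several data loaders in lockstep and map each
# produced data frame to the name of its graph input node.
def inject(named_loaders):
    # named_loaders: e.g. {'graph_input_1': loader1, 'graph_input_2': loader2}
    names = list(named_loaders)
    for data_frames in zip(*named_loaders.values()):
        yield dict(zip(names, data_frames))
```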
With that description in mind, a single iteration of a training loop involves the following steps (a rough sketch in plain PyTorch follows the list):
- prepare the next batch of training data with the input injector (as a mapping from graph input nodes to data frames)
- send the batch to the processing graph
- compute predictions for each leaf node
- compute loss and metrics for each node where loss or metrics were specified
- compute the gradients and update the parameters of all models
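In plain PyTorch terms, one iteration is roughly equivalent to the familiar pattern below. This is only a conceptual sketch for a single-model, single-loss pipeline; the model, optimizer and loss function stand in for entities that the specification would normally declare.

```python
import torch

# Conceptual equivalent of a single training iteration (not the framework's code).
model = torch.nn.Linear(2, 1)                      # a "model" node
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
loss_fn = torch.nn.MSELoss()

x = torch.randn(4, 2)                              # batch prepared by the input injector
y = torch.randn(4, 1)

predictions = model(x)                             # forward pass through the processing graph
loss = loss_fn(predictions, y)                     # loss computed on a leaf node
optimizer.zero_grad()
loss.backward()                                    # gradients for all models
optimizer.step()                                   # parameter update
```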
A training epoch consists of running the above steps many times. After the input injector iterates over all of its training batches, the epoch is done. Whenever an epoch finishes, the framework computes loss and metrics separately on the training and validation datasets. Computed metrics are then printed to stdout and appended to a history file in CSV format. Also, at this point the parameters of all models and their optimizers are automatically saved. This allows one to resume training from the last finished epoch if the training script was interrupted.
One final important detail about TorchAssistant is that the training process is split into multiple stages. Although one stage will suffice in most cases, multiple stages give extra flexibility. Different stages can use different processing graphs, datasets, metrics, etc. A stage finishes when its stopping condition, checked after every epoch, returns True.
That was a high-level intro. Now let's discuss in detail the different entities participating in the training process.
First, let's look at the data-related entities. Very briefly, datasets provide individual examples as tuples. A data split can be defined to divide a dataset into training, validation and test datasets. Collators turn collections of examples into batches. By giving names to the columns of batches, the latter become data frames. Finally, the input injector groups data frames coming from different datasets and converts them into data frame dicts. We will discuss it later.
A dataset can be any object that implements the sequence protocol. Concretely, a dataset class needs to implement __len__ and __getitem__. Naturally, any built-in dataset class from the torchvision package is a valid dataset. Moreover, datasets can wrap/decorate other datasets and have associated data transformations/preprocessors.
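For example, a minimal toy dataset satisfying this protocol could look like this (purely illustrative):

```python
class SquaresDataset:
    """Toy dataset: example i is the tuple (i, i squared)."""

    def __init__(self, size=100):
        self.size = size

    def __len__(self):
        return self.size

    def __getitem__(self, index):
        if index >= self.size:
            raise IndexError(index)
        return index, index ** 2
```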
A data split is used to randomly split a given dataset into two or more slices. One can control the relative size of each slice in the split. After creation, one can access the slices of a given dataset via dot notation. For instance, if a dataset called mnist was split into two parts, we can refer to its training slice as mnist.train. Each slice is a special kind of dataset. That means a data slice can be used anywhere a dataset is needed.
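As a loose analogy only (TorchAssistant declares splits in the specification rather than in code), plain PyTorch offers a similar operation:

```python
from torch.utils.data import random_split

# Plain-PyTorch analogy: split a 100-example dataset into 80/20 slices.
dataset = [(i, i ** 2) for i in range(100)]
train_slice, val_slice = random_split(dataset, [80, 20])
```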
A preprocessor is any object which implements a process(value) method. Implementing an optional fit(dataset) method makes a given preprocessor learnable. Whether it is learnable or not, once defined, it can be applied to a given dataset or a dataset slice.
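A minimal learnable preprocessor might look like the sketch below; only the process and fit method names come from the description above, everything else is illustrative.

```python
class NormalizePreprocessor:
    """Scales values into [0, 1] using statistics learned from a dataset."""

    def fit(self, dataset):
        # optional method; implementing it makes the preprocessor learnable
        values = [x for x, _ in dataset]
        self.low, self.high = min(values), max(values)

    def process(self, value):
        return (value - self.low) / (self.high - self.low)
```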
A collator is a callable object which turns a list of tuples (examples) into a tuple of collections (lists, tensors, etc.). It is useful when one needs to apply extra preprocessing to a collection of examples rather than to one example at a time. For instance, in machine translation, datasets often contain sentences of variable length. This makes training with batch size > 1 problematic. One possible solution is to pad the sentences of a collection of examples in a collator.
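For instance, a padding collator could be sketched as follows (standalone Python with made-up names; a real collator would likely also build tensors):

```python
class PaddingCollator:
    """Turns a list of (token_ids, label) examples into a padded batch."""

    def __init__(self, pad_value=0):
        self.pad_value = pad_value

    def __call__(self, examples):
        sequences, labels = zip(*examples)
        max_len = max(len(seq) for seq in sequences)
        padded = [list(seq) + [self.pad_value] * (max_len - len(seq))
                  for seq in sequences]
        return padded, list(labels)

collate = PaddingCollator()
collate([([5, 3], 1), ([7, 1, 4], 0)])   # -> ([[5, 3, 0], [7, 1, 4]], [1, 0])
```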
A data frame is a dict-like data structure which essentially associates names with data collections (lists or tensors). Note that it has nothing to do with the entities called data frames in other libraries such as Pandas. For example, this is a valid data frame:
data_frame = {
    "x1": [1, 2, 3],
    "x2": [10, 20, 30]
}
A data loader corresponds to the PyTorch DataLoader class.
An input injector is an iterator yielding data frames ready to be injected into a processing graph. It is used on every training iteration to provide training examples to learn from and to compute loss/metrics. Its purpose will become clearer when we look at the remaining pieces.
Let's go over the familiar ones first and discuss the entities specific to TorchAssistant later.
A model is a subclass of the torch.nn.Module class.
An optimizer is a built-in optimizer class from the torch.optim module.
A loss is a built-in loss function class (e.g. CrossEntropyLoss) from the torch.nn module.
A metric is a built-in class from the torchmetrics package.
A batch processor performs a particular computation on its inputs and produces some outputs. Typically (but not always), it is a graph of nodes where each node is a model (neural network). This abstraction allows one to create quite sophisticated computational graphs where the output of one or more neural nets becomes an input to others. For example, it makes it easy to create an encoder-decoder RNN architecture.
To specify a batch processor, we need to provide "input_adapter" and "neural_graph". We may also optionally specify "output_adapter" and "device" (CPU or CUDA).
A batch processor takes a data frame (a name->tensor mapping) of inputs and produces a data frame of outputs. Normally, the input data frame contains all the information needed to extract the input tensors expected by the corresponding nodes (neural networks). Input extraction is done by the input adapter: it takes a data frame and constructs from it input tensors for every node in the batch processor. The output adapter can do some extra processing on the resulting data frame, adding tensors to it or excluding some from it.
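As an illustration only (the returned layout below is an assumption, not the documented adapter interface), an input adapter might map an incoming data frame to per-node inputs like this:

```python
# Hypothetical input adapter for a batch processor with two nodes,
# "encoder" and "decoder". The structure of the returned mapping is an
# assumption made purely for illustration.
def input_adapter(data_frame):
    return {
        "encoder": {"x": data_frame["input_sequence"]},
        "decoder": {"y_shifted": data_frame["target_sequence"]},
    }
```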
The graph topology is defined by providing an ordered list of nodes such that a dependent node is listed after the node(s) it depends on. To specify each node, we need to provide (illustrated below):
- the name of a neural network defined earlier
- the names of input tensors used as inputs to the network
- the names of output tensors predicted by the network
- the name of an optimizer defined earlier
A node may have both multiple inputs and multiple outputs. When the neural network of the corresponding node produces a tuple of tensors, each of these tensors gets a name as specified by the array of output tensor names in the node config.
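For illustration only (the actual specification syntax may differ), a node entry carries roughly these four pieces of information:

```python
# Illustrative node description, not necessarily the exact spec syntax.
node = {
    "model": "decoder",                        # name of a network defined earlier
    "inputs": ["encoder_state", "y_shifted"],  # names of input tensors
    "outputs": ["logits", "hidden_state"],     # names of predicted tensors
    "optimizer": "decoder_optimizer",          # name of an optimizer defined earlier
}
```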
The forward pass through the graph respects dependencies. That is, each node runs its computation only after all the nodes it depends on have finished theirs. Outputs from previous nodes can become inputs to dependent ones. Some nodes receive inputs directly from a data frame, others use outputs of previous nodes as their inputs. It is also possible to form inputs by combining tensors from a data frame with tensors from outputs computed by earlier nodes.
When a particular node finishes, its outputs are added to the dictionary of results (for example, if a node defines its outputs as out1, out2, out3, the results dictionary will contain a tensor for each of these names). When the leaf nodes finish, the forward pass through the whole batch processor completes. All computation results are aggregated and saved in a dictionary (including the outputs of intermediate nodes).
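Conceptually, the forward pass over the ordered node list could be sketched like this (illustrative code, not the framework's implementation; each node is assumed to be a dict like the one above plus a "net" callable):

```python
# Run nodes in dependency order, accumulating every computed tensor by name.
def run_batch_processor(nodes, data_frame):
    results = dict(data_frame)                 # start from the injected inputs
    for node in nodes:
        inputs = [results[name] for name in node["inputs"]]
        outputs = node["net"](*inputs)
        if not isinstance(outputs, tuple):     # single-output networks
            outputs = (outputs,)
        results.update(zip(node["outputs"], outputs))
    return results                             # includes intermediate outputs too
```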
A processing graph is a higher-level graph whose nodes are batch processors. Since each batch processor is itself a graph, a processing graph is a graph of graphs. Processing graphs allow one to create even more complex training pipelines, but in simple cases they may be omitted.
A processing graph consists of two types of nodes/vertices: computational nodes (batch processors) and input ports. Input ports serve as entry points supplying input tensors to the graph to carry out a computation. Each port has a name. Before running a computation, each input port binds to a concrete batch of data. The input injector is the entity that generates inputs and sends them to the appropriate input ports.
A pipeline is a self-contained entity that describes the computational graph, the source of training data, and which losses and metrics to compute. In other words, it fully describes a single iteration of a training loop. It is easy to create multiple pipelines, where different pipelines may use different processing graphs, datasets, losses and metrics.
A (training) stage is a self-contained specification of the entire training process. Basically, it specifies which pipelines to use for training and validation as well as a stopping condition. The latter determines when the stage is done. The simplest stopping condition just sets the number of epochs to run. It is possible to create multiple stages, each with different pipelines and stopping conditions.