
Complete upgrade of torchmetrics accuracy (#2025)
Replace all instances of Accuracy with MulticlassAccuracy

---------

Co-authored-by: nik-mosaic <None>
Co-authored-by: Mihir Patel <mihir.v.patel7@gmail.com>
nik-mosaic and mvpatel2000 authored Mar 2, 2023
1 parent 28bf919 commit c6f2c93
Showing 14 changed files with 41 additions and 41 deletions.
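This upgrade tracks the TorchMetrics 0.11 API, which split the generic Accuracy class into task-specific classes such as MulticlassAccuracy, so every metric name referenced in callbacks, Evaluators, docs, and tests is renamed accordingly. A minimal before/after sketch; the num_classes=10 and average='micro' arguments are illustrative assumptions, not values taken from this commit:

    # Old API (torchmetrics < 0.11): one generic wrapper class
    # from torchmetrics import Accuracy
    # metric = Accuracy()

    # New API (torchmetrics >= 0.11): task-specific metric classes
    from torchmetrics.classification import MulticlassAccuracy

    # num_classes and average are illustrative assumptions for this sketch
    metric = MulticlassAccuracy(num_classes=10, average='micro')

    # Composer callbacks and Evaluators now look the metric up by its new class name, e.g.
    # EarlyStopper('MulticlassAccuracy', 'my_evaluator', patience=1)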
4 changes: 2 additions & 2 deletions composer/callbacks/early_stopper.py
@@ -27,11 +27,11 @@ class EarlyStopper(Callback):
>>> from composer import Evaluator, Trainer
>>> from composer.callbacks.early_stopper import EarlyStopper
>>> # constructing trainer object with this callback
>>> early_stopper = EarlyStopper("Accuracy", "my_evaluator", patience=1)
>>> early_stopper = EarlyStopper('MulticlassAccuracy', 'my_evaluator', patience=1)
>>> evaluator = Evaluator(
... dataloader = eval_dataloader,
... label = 'my_evaluator',
... metric_names = ['Accuracy']
... metric_names = ['MulticlassAccuracy']
... )
>>> trainer = Trainer(
... model=model,
6 changes: 3 additions & 3 deletions composer/callbacks/mlperf.py
@@ -82,7 +82,7 @@ class MLPerfCallback(Callback):
callback = MLPerfCallback(
root_folder='/submission',
index=0,
metric_name='Accuracy',
metric_name='MulticlassAccuracy',
metric_label='eval',
target='0.759',
)
@@ -113,7 +113,7 @@ class MLPerfCallback(Callback):
division (str, optional): Division of submission. Currently only ``open`` division supported.
Default: ``'open'``.
metric_name (str, optional): name of the metric to compare against the target.
Default: ``Accuracy``.
Default: ``MulticlassAccuracy``.
metric_label (str, optional): The label name. The metric will be accessed via
``state.eval_metrics[metric_label][metric_name]``.
submitter (str, optional): Submitting organization. Default: ``"MosaicML"``.
@@ -135,7 +135,7 @@ def __init__(
benchmark: str = 'resnet',
target: float = 0.759,
division: str = 'open',
metric_name: str = 'Accuracy',
metric_name: str = 'MulticlassAccuracy',
metric_label: str = 'eval',
submitter: str = 'MosaicML',
system_name: Optional[str] = None,
6 changes: 3 additions & 3 deletions composer/callbacks/threshold_stopper.py
@@ -17,14 +17,14 @@ class ThresholdStopper(Callback):
Example:
.. doctest::
>>> from composer import Evaluator, Trainer
>>> from composer.callbacks.threshold_stopper import ThresholdStopper
>>> from torchmetrics.classification.accuracy import Accuracy
>>> # constructing trainer object with this callback
>>> threshold_stopper = ThresholdStopper("Accuracy", "my_evaluator", 0.7)
>>> threshold_stopper = ThresholdStopper('MulticlassAccuracy', 'my_evaluator', 0.7)
>>> evaluator = Evaluator(
... dataloader = eval_dataloader,
... label = 'my_evaluator',
... metric_names = ['Accuracy']
... metric_names = ['MulticlassAccuracy']
... )
>>> trainer = Trainer(
... model=model,
8 changes: 4 additions & 4 deletions composer/core/evaluator.py
@@ -99,24 +99,24 @@ class Evaluator:
.. doctest::
>>> eval_evaluator = Evaluator(
... label="myEvaluator",
... label='myEvaluator',
... dataloader=eval_dataloader,
... metric_names=['Accuracy']
... metric_names=['MulticlassAccuracy']
... )
>>> trainer = Trainer(
... model=model,
... train_dataloader=train_dataloader,
... eval_dataloader=eval_evaluator,
... optimizers=optimizer,
... max_duration="1ep",
... max_duration='1ep',
... )
Args:
label (str): Name of the Evaluator.
dataloader (DataSpec | Iterable | Dict[str, Any]): Iterable that yields batches, a :class:`.DataSpec`
for evaluation, or a Dict of :class:`.DataSpec` kwargs.
metric_names: The list of metric names to compute.
Each value in this list can be a regex string (e.g. "Accuracy", "f1" for "BinaryF1Score",
Each value in this list can be a regex string (e.g. "MulticlassAccuracy", "f1" for "BinaryF1Score",
"Top-." for "Top-1", "Top-2", etc). Each regex string will be matched against the keys of the dictionary returned
by ``model.get_metrics()``. All matching metrics will be evaluated.
4 changes: 2 additions & 2 deletions composer/core/state.py
@@ -271,8 +271,8 @@ class State(Serializable):
... ...,
... train_dataloader=train_dataloader,
... eval_dataloader=[
... Evaluator(label='eval1', dataloader=eval_1_dl, metric_names=['Accuracy']),
... Evaluator(label='eval2', dataloader=eval_2_dl, metric_names=['Accuracy']),
... Evaluator(label='eval1', dataloader=eval_1_dl, metric_names=['MulticlassAccuracy']),
... Evaluator(label='eval2', dataloader=eval_2_dl, metric_names=['MulticlassAccuracy']),
... ],
... )
>>> trainer.fit()
2 changes: 1 addition & 1 deletion composer/models/bert/model.py
@@ -173,7 +173,7 @@ def create_bert_classification(num_labels: int = 2,
Second, the returned :class:`.ComposerModel`'s train/validation metrics will be :class:`~torchmetrics.MeanSquaredError` and :class:`~torchmetrics.SpearmanCorrCoef`.
For the classification case (when ``num_labels > 1``), the training loss is :class:`~torch.nn.CrossEntropyLoss`, and the train/validation
metrics are :class:`~torchmetrics.Accuracy` and :class:`~torchmetrics.MatthewsCorrCoef`, as well as :class:`.BinaryF1Score` if ``num_labels == 2``.
metrics are :class:`~torchmetrics.MulticlassAccuracy` and :class:`~torchmetrics.MatthewsCorrCoef`, as well as :class:`.BinaryF1Score` if ``num_labels == 2``.
"""
try:
import transformers
2 changes: 1 addition & 1 deletion composer/models/classify_mnist/__init__.py
@@ -10,6 +10,6 @@
_dataset = 'MNIST'
_name = 'SimpleConvNet'
_quality = ''
_metric = 'Accuracy'
_metric = 'MulticlassAccuracy'
_ttt = '?'
_hparams = 'classify_mnist_cpu.yaml'
4 changes: 2 additions & 2 deletions composer/trainer/trainer.py
@@ -2535,13 +2535,13 @@ def eval(
glue_mrpc_task = Evaluator(
label='glue_mrpc',
dataloader=mrpc_dataloader,
metric_names=['BinaryF1Score', 'Accuracy']
metric_names=['BinaryF1Score', 'MulticlassAccuracy']
)
glue_mnli_task = Evaluator(
label='glue_mnli',
dataloader=mnli_dataloader,
metric_names=['Accuracy']
metric_names=['MulticlassAccuracy']
)
trainer = Trainer(
2 changes: 1 addition & 1 deletion docs/source/composer_model.rst
@@ -152,7 +152,7 @@ A full example of a validation implementation would be:
def get_metrics(self, is_train=False):
# defines which metrics to use in each phase of training
return {'Accuracy': self.train_accuracy} if train else {'Accuracy': self.val_accuracy}
return {'MulticlassAccuracy': self.train_accuracy} if train else {'MulticlassAccuracy': self.val_accuracy}
.. note::

18 changes: 9 additions & 9 deletions docs/source/notes/early_stopping.rst
@@ -16,7 +16,7 @@ The :class:`.EarlyStopper` callback stops training if a provided metric does not
from composer.callbacks.early_stopper import EarlyStopper

early_stopper = EarlyStopper(
monitor='Accuracy',
monitor='MulticlassAccuracy',
dataloader_label='train',
patience='50ba',
comp=torch.greater,
@@ -32,7 +32,7 @@ The :class:`.EarlyStopper` callback stops training if a provided metric does not
max_duration="1ep",
)

In the above example, the ``'train'`` label means the callback is tracking the ``Accuracy`` metric for the train_dataloader. The default for the evaluation dataloader is ``eval``.
In the above example, the ``'train'`` label means the callback is tracking the ``MulticlassAccuracy`` metric for the train_dataloader. The default for the evaluation dataloader is ``eval``.

We also set ``patience='50ba'`` and ``min_delta=0.01``, which means that every 50 batches, if the Accuracy does not exceed the best recorded Accuracy by ``0.01``, training is stopped. The ``comp`` argument indicates that 'better' here means higher accuracy. Note that the ``patience`` parameter can take either a time string (see :doc:`Time</trainer/time>`) or an integer which specifies a number of epochs.

@@ -53,8 +53,8 @@ The :class:`.ThresholdStopper`` callback also monitors a specific metric, but ha
from composer.callbacks.threshold_stopper import ThresholdStopper

threshold_stopper = ThresholdStopper(
monitor="Accuracy",
dataloader_label="eval",
monitor='MulticlassAccuracy',
dataloader_label='eval',
threshold=0.8,
)

@@ -64,7 +64,7 @@ The :class:`.ThresholdStopper`` callback also monitors a specific metric, but ha
eval_dataloader=eval_dataloader,
optimizers=optimizer,
callbacks=[threshold_stopper],
max_duration="1ep",
max_duration='1ep',
)

In this example, training will exit when the model's validation accuracy exceeds 0.8. For a full list of arguments, see the documentation for :class:`.ThresholdStopper.`
@@ -76,7 +76,7 @@ When there are multiple datasets and metrics to use for validation and evaluatio

Each Evaluator object is marked with a ``label`` field for logging, and a ``metric_names`` field that accepts a list of metric names. These can be provided to the callbacks above to indicate which metric to monitor.

In the example below, the callback will monitor the `Accuracy` metric in the dataloader marked `eval_dataset1`.`
In the example below, the callback will monitor the `MulticlassAccuracy` metric in the dataloader marked `eval_dataset1`.`

.. testsetup::

@@ -90,17 +90,17 @@ In the example below, the callback will monitor the `Accuracy` metric in the dat
evaluator1 = Evaluator(
label='eval_dataset1',
dataloader=eval_dataloader,
metric_names=['Accuracy']
metric_names=['MulticlassAccuracy']
)

evaluator2 = Evaluator(
label='eval_dataset2',
dataloader=eval_dataloader2,
metric_names=['Accuracy']
metric_names=['MulticlassAccuracy']
)

early_stopper = EarlyStopper(
monitor='Accuracy',
monitor='MulticlassAccuracy',
dataloader_label='eval_dataset1',
patience=1
)
4 changes: 2 additions & 2 deletions docs/source/trainer/evaluation.rst
@@ -61,13 +61,13 @@ can be specified as in the following example:
glue_mrpc_task = Evaluator(
label='glue_mrpc',
dataloader=mrpc_dataloader,
metric_names=['BinaryF1Score', 'Accuracy']
metric_names=['BinaryF1Score', 'MulticlassAccuracy']
)
glue_mnli_task = Evaluator(
label='glue_mnli',
dataloader=mnli_dataloader,
metric_names=['Accuracy']
metric_names=['MulticlassAccuracy']
)
trainer = Trainer(
8 changes: 4 additions & 4 deletions examples/early_stopping.ipynb
@@ -165,7 +165,7 @@
"evaluator = Evaluator(\n",
" dataloader = eval_dataloader,\n",
" label = \"eval\",\n",
" metric_names = ['Accuracy']\n",
" metric_names = ['MulticlassAccuracy']\n",
")"
]
},
@@ -199,7 +199,7 @@
"[time]: https://docs.mosaicml.com/en/stable/api_reference/generated/composer.Time.html#time\n",
"[api]: https://docs.mosaicml.com/en/stable/api_reference/generated/composer.callbacks.EarlyStopper.html\n",
"\n",
"Here, we'll use our callback to track the Accuracy metric over one epoch on the test dataset:"
"Here, we'll use our callback to track the MulticlassAccuracy metric over one epoch on the test dataset:"
]
},
{
@@ -210,7 +210,7 @@
"source": [
"from composer.callbacks import EarlyStopper\n",
"\n",
"early_stopper = EarlyStopper(monitor=\"Accuracy\", dataloader_label=\"eval\", patience=1)"
"early_stopper = EarlyStopper(monitor=\"MulticlassAccuracy\", dataloader_label=\"eval\", patience=1)"
]
},
{
@@ -284,7 +284,7 @@
"source": [
"from composer.callbacks import ThresholdStopper\n",
"\n",
"threshold_stopper = ThresholdStopper(\"Accuracy\", \"eval\", threshold=0.3)\n",
"threshold_stopper = ThresholdStopper(\"MulticlassAccuracy\", \"eval\", threshold=0.3)\n",
"\n",
"# Threshold stopping should stop training before we reach 100 epochs!\n",
"train_epochs = \"100ep\"\n",
6 changes: 3 additions & 3 deletions tests/trainer/test_trainer.py
@@ -300,18 +300,18 @@ def test_schedulers(
Evaluator(
label='eval',
dataloader=_get_classification_dataloader(),
metric_names=['Accuracy'],
metric_names=['MulticlassAccuracy'],
), # an evaluator
[ # multiple evaluators
Evaluator(
label='eval1',
dataloader=_get_classification_dataloader(),
metric_names=['Accuracy'],
metric_names=['MulticlassAccuracy'],
),
Evaluator(
label='eval2',
dataloader=_get_classification_dataloader(),
metric_names=['Accuracy'],
metric_names=['MulticlassAccuracy'],
),
],
],
8 changes: 4 additions & 4 deletions tests/trainer/test_trainer_eval.py
@@ -194,7 +194,7 @@ def test_eval_at_fit_end(eval_interval: Union[str, Time, int], max_duration: str
dataset=eval_dataset,
sampler=dist.get_sampler(eval_dataset),
),
metric_names=['Accuracy'],
metric_names=['MulticlassAccuracy'],
)

evaluator.eval_interval = evaluate_periodically(
@@ -234,7 +234,7 @@ def _get_classification_dataloader():
Evaluator(
label='eval',
dataloader=_get_classification_dataloader(),
metric_names=['Accuracy'],
metric_names=['MulticlassAccuracy'],
),
])
@pytest.mark.parametrize(
@@ -287,7 +287,7 @@ def test_eval_params_evaluator():
dataset=eval_dataset,
sampler=dist.get_sampler(eval_dataset),
),
metric_names=['Accuracy'],
metric_names=['MulticlassAccuracy'],
eval_interval=f'{eval_interval_batches}ba',
subset_num_batches=eval_subset_num_batches,
)
@@ -373,7 +373,7 @@ def test_eval_batch_can_be_modified(add_algorithm: bool):
trainer.eval()


@pytest.mark.parametrize('metric_names', ['Accuracy', ['Accuracy']])
@pytest.mark.parametrize('metric_names', ['MulticlassAccuracy', ['MulticlassAccuracy']])
def test_evaluator_metric_names_string_errors(metric_names):
eval_dataset = RandomClassificationDataset(size=8)
eval_dataloader = DataLoader(eval_dataset, batch_size=4, sampler=dist.get_sampler(eval_dataset))

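For downstream code that subclasses ComposerModel, the rename also changes the keys returned by get_metrics, mirroring the docs/source/composer_model.rst change above. A minimal sketch under assumptions: the TinyClassifier name, the Linear(16, 10) backbone, num_classes=10, and micro averaging are all illustrative choices, not part of this commit.

    import torch
    import torch.nn.functional as F
    from torchmetrics.classification import MulticlassAccuracy
    from composer.models import ComposerModel

    class TinyClassifier(ComposerModel):
        """Hypothetical 10-class model; only here to show the renamed metric keys."""

        def __init__(self):
            super().__init__()
            self.net = torch.nn.Linear(16, 10)
            self.train_accuracy = MulticlassAccuracy(num_classes=10, average='micro')
            self.val_accuracy = MulticlassAccuracy(num_classes=10, average='micro')

        def forward(self, batch):
            inputs, _ = batch
            return self.net(inputs)

        def loss(self, outputs, batch):
            _, targets = batch
            return F.cross_entropy(outputs, targets)

        def eval_forward(self, batch, outputs=None):
            return outputs if outputs is not None else self.forward(batch)

        def update_metric(self, batch, outputs, metric):
            _, targets = batch
            metric.update(outputs, targets)

        def get_metrics(self, is_train=False):
            # The dict keys must match the new torchmetrics class name so that
            # Evaluator(metric_names=['MulticlassAccuracy']) and
            # EarlyStopper('MulticlassAccuracy', ...) resolve correctly.
            return {'MulticlassAccuracy': self.train_accuracy} if is_train else {'MulticlassAccuracy': self.val_accuracy}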