diff --git a/CHANGELOG.md b/CHANGELOG.md index e1bbd55c..d0a4340f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -26,6 +26,8 @@ Keep it human-readable, your future self will thank you! - Feature: Add configurable models [#50](https://github.com/ecmwf/anemoi-training/pulls/50) - Feature: Support training for datasets with missing time steps [#48](https://github.com/ecmwf/anemoi-training/pulls/48) - Long Rollout Plots +- Feat: Anemoi Profiler compatible with mlflow and using Pytorch (Kineto) Profiler for memory report + ### Fixed diff --git a/docs/images/profiler/anemoi_profiler_architecture.png b/docs/images/profiler/anemoi_profiler_architecture.png new file mode 100644 index 00000000..483571d1 Binary files /dev/null and b/docs/images/profiler/anemoi_profiler_architecture.png differ diff --git a/docs/images/profiler/anemoi_profiler_benchmark_profiler.png b/docs/images/profiler/anemoi_profiler_benchmark_profiler.png new file mode 100644 index 00000000..5cc6d7d1 Binary files /dev/null and b/docs/images/profiler/anemoi_profiler_benchmark_profiler.png differ diff --git a/docs/images/profiler/anemoi_profiler_config.png b/docs/images/profiler/anemoi_profiler_config.png new file mode 100644 index 00000000..dd98469b Binary files /dev/null and b/docs/images/profiler/anemoi_profiler_config.png differ diff --git a/docs/images/profiler/anemoi_profiler_high_level.png b/docs/images/profiler/anemoi_profiler_high_level.png new file mode 100644 index 00000000..bd86c4fe Binary files /dev/null and b/docs/images/profiler/anemoi_profiler_high_level.png differ diff --git a/docs/images/profiler/anemoi_profiler_mlflow_integration.png b/docs/images/profiler/anemoi_profiler_mlflow_integration.png new file mode 100644 index 00000000..cbd03d9f Binary files /dev/null and b/docs/images/profiler/anemoi_profiler_mlflow_integration.png differ diff --git a/docs/images/profiler/anemoi_profiler_mlflow_integration_2.png b/docs/images/profiler/anemoi_profiler_mlflow_integration_2.png new file mode 100644 index 00000000..196b8818 Binary files /dev/null and b/docs/images/profiler/anemoi_profiler_mlflow_integration_2.png differ diff --git a/docs/images/profiler/anemoi_profiler_mlflow_integration_3.png b/docs/images/profiler/anemoi_profiler_mlflow_integration_3.png new file mode 100644 index 00000000..d4897502 Binary files /dev/null and b/docs/images/profiler/anemoi_profiler_mlflow_integration_3.png differ diff --git a/docs/images/profiler/anemoi_profiler_speed_report.png b/docs/images/profiler/anemoi_profiler_speed_report.png new file mode 100644 index 00000000..dbec34e4 Binary files /dev/null and b/docs/images/profiler/anemoi_profiler_speed_report.png differ diff --git a/docs/images/profiler/anemoi_profiler_speedreport_diagram.png b/docs/images/profiler/anemoi_profiler_speedreport_diagram.png new file mode 100644 index 00000000..a69324e4 Binary files /dev/null and b/docs/images/profiler/anemoi_profiler_speedreport_diagram.png differ diff --git a/docs/images/profiler/anemoi_profiler_training_rates.png b/docs/images/profiler/anemoi_profiler_training_rates.png new file mode 100644 index 00000000..e26da246 Binary files /dev/null and b/docs/images/profiler/anemoi_profiler_training_rates.png differ diff --git a/docs/images/profiler/anemoi_profiler_validation_rates.png b/docs/images/profiler/anemoi_profiler_validation_rates.png new file mode 100644 index 00000000..aa352cde Binary files /dev/null and b/docs/images/profiler/anemoi_profiler_validation_rates.png differ diff --git a/docs/images/profiler/example_memory_report.png 
b/docs/images/profiler/example_memory_report.png new file mode 100644 index 00000000..0f42ebd0 Binary files /dev/null and b/docs/images/profiler/example_memory_report.png differ diff --git a/docs/images/profiler/example_memory_timeline.png b/docs/images/profiler/example_memory_timeline.png new file mode 100644 index 00000000..93591893 Binary files /dev/null and b/docs/images/profiler/example_memory_timeline.png differ diff --git a/docs/images/profiler/example_model_summary.png b/docs/images/profiler/example_model_summary.png new file mode 100644 index 00000000..498eff30 Binary files /dev/null and b/docs/images/profiler/example_model_summary.png differ diff --git a/docs/images/profiler/example_model_summary_2.png b/docs/images/profiler/example_model_summary_2.png new file mode 100644 index 00000000..c8adc538 Binary files /dev/null and b/docs/images/profiler/example_model_summary_2.png differ diff --git a/docs/images/profiler/example_system_report.png b/docs/images/profiler/example_system_report.png new file mode 100644 index 00000000..f6f002fa Binary files /dev/null and b/docs/images/profiler/example_system_report.png differ diff --git a/docs/images/profiler/example_time_report.png b/docs/images/profiler/example_time_report.png new file mode 100644 index 00000000..b8918a33 Binary files /dev/null and b/docs/images/profiler/example_time_report.png differ diff --git a/docs/images/profiler/idle_time_breakdown.png b/docs/images/profiler/idle_time_breakdown.png new file mode 100644 index 00000000..e183b010 Binary files /dev/null and b/docs/images/profiler/idle_time_breakdown.png differ diff --git a/docs/images/profiler/kernel_breakdown_dfs.png b/docs/images/profiler/kernel_breakdown_dfs.png new file mode 100644 index 00000000..20aee8c7 Binary files /dev/null and b/docs/images/profiler/kernel_breakdown_dfs.png differ diff --git a/docs/images/profiler/kernel_breakdown_plots.png b/docs/images/profiler/kernel_breakdown_plots.png new file mode 100644 index 00000000..e36d59a4 Binary files /dev/null and b/docs/images/profiler/kernel_breakdown_plots.png differ diff --git a/docs/images/profiler/memory_snapshot_diagram.png b/docs/images/profiler/memory_snapshot_diagram.png new file mode 100644 index 00000000..87ca6669 Binary files /dev/null and b/docs/images/profiler/memory_snapshot_diagram.png differ diff --git a/docs/images/profiler/memory_snapshot_output.png b/docs/images/profiler/memory_snapshot_output.png new file mode 100644 index 00000000..b22b9f4f Binary files /dev/null and b/docs/images/profiler/memory_snapshot_output.png differ diff --git a/docs/images/profiler/temporal_breakdown.png b/docs/images/profiler/temporal_breakdown.png new file mode 100644 index 00000000..a3a0370e Binary files /dev/null and b/docs/images/profiler/temporal_breakdown.png differ diff --git a/docs/index.rst b/docs/index.rst index 9ca3bbe5..b2617476 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -43,6 +43,7 @@ This package provides the *Anemoi* training functionality. user-guide/training user-guide/models user-guide/tracking + user-guide/benchmarking user-guide/distributed user-guide/debugging diff --git a/docs/overview.rst b/docs/overview.rst index 11611b6f..268e287c 100644 --- a/docs/overview.rst +++ b/docs/overview.rst @@ -91,6 +91,18 @@ and resolve issues during the training process, including: - Debug configurations for quick error identification - Guidance on isolating and addressing common problems +8. 
Benchmarking and HPC Profiling
+=================================
+
+Anemoi Training offers tools and configurations to support benchmarking
+and High-Performance Computing (HPC) profiling, allowing users to
+optimize training performance. This includes:
+
+- Benchmarking configurations for evaluating training efficiency across
+  different hardware setups.
+- Profiling tools for monitoring resource utilization (CPU, GPU,
+  memory) and identifying performance bottlenecks.
+
 **************************
  Components and Structure
 **************************
diff --git a/docs/user-guide/benchmarking.rst b/docs/user-guide/benchmarking.rst
new file mode 100644
index 00000000..bdbb57d0
--- /dev/null
+++ b/docs/user-guide/benchmarking.rst
@@ -0,0 +1,746 @@
+##############
+ Benchmarking
+##############
+
+***************************************
+ High-level idea of the AnemoiProfiler
+***************************************
+
+Anemoi Training includes a benchmark profiler that provides summary
+logs/statistics about time, speed and hardware (memory, CPU/GPU usage)
+to profile training runs executed with anemoi-training. Apart from those
+reports, it is also possible to generate a model summary and a CUDA
+memory snapshot.
+
+- **Speed Report** - Report with metrics associated with the throughput
+  at training and validation time
+
+- **Time Report** - Report with metrics associated with the time it
+  takes to execute certain steps across the code
+
+- **Memory Report** - Report with metrics associated with GPU and CPU
+  memory allocation, focusing on listing the most memory-intensive
+  operations.
+
+- **System/hardware Report** - Report with aggregated metrics in terms
+  of GPU utilisation & memory usage, CPU usage (system), average disk
+  usage and total execution time
+
+- **Model Summary** - Table summary with information regarding the
+  layers and parameters of the model.
+
+- **Memory (GPU) Snapshot** - Memory snapshot that records the state of
+  allocated CUDA memory at any point in time, and optionally records the
+  history of allocation events that led up to that snapshot.
+
+.. figure:: ../images/profiler/anemoi_profiler_high_level.png
+   :alt: Schematic of the concept behind AnemoiProfiler
+   :align: center
+
+**************
+ How it works
+**************
+
+Conceptual Diagram
+==================
+
+As described in the high-level idea section, the ``AnemoiProfiler``
+includes a series of features and reports to help benchmark the model
+training performance. Anemoi-training uses PyTorch Lightning as its deep
+learning framework, and we have designed the AnemoiProfiler taking
+advantage of this functionality and building on top of it.
+AnemoiProfiler inherits from AnemoiTrainer and generates the different
+reports via 3 main objects:
+
+- ``BenchmarkProfiler``
+- ``ProfilerProgressBar``
+- ``MemorySnapshotRecorder``
+
+Each of these objects is described in more detail in the sections below.
+With the exception of the ``MemorySnapshotRecorder``, all the above
+reports are defined as properties of the AnemoiProfiler. The memory
+snapshot is abstracted as an additional callback that can be switched
+on/off through the config.
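+
+As a rough sketch of the general PyTorch Lightning pattern these three
+objects plug into - a profiler passed to the ``Trainer`` plus plain
+callbacks - the wiring looks conceptually as follows (illustrative only;
+the actual set-up lives inside AnemoiProfiler and may differ):
+
+.. code:: python
+
+   from pytorch_lightning import Trainer
+
+   from anemoi.training.diagnostics.callbacks import MemorySnapshotRecorder
+   from anemoi.training.diagnostics.profilers import BenchmarkProfiler, ProfilerProgressBar
+
+   # `config` is the resolved training config (Hydra/OmegaConf DictConfig).
+   trainer = Trainer(
+       profiler=BenchmarkProfiler(config),  # time/memory/system/model-summary reports
+       callbacks=[
+           ProfilerProgressBar(),  # speed report (it/s from the progress bar)
+           MemorySnapshotRecorder(config),  # optional CUDA memory snapshot
+       ],
+   )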
+
+- Details about the definition of AnemoiProfiler can be found in
+  ``src/anemoi/training/commands/profiler.py``
+
+- Details about the definition of the different classes used by the
+  AnemoiProfiler can be found in:
+  ``src/anemoi/training/diagnostics/profilers.py``
+
+- Details about the definition of the memory snapshot recorder:
+  ``src/anemoi/training/diagnostics/callbacks/__init__.py``
+
+.. figure:: ../images/profiler/anemoi_profiler_architecture.png
+   :alt: Schematic of the AnemoiProfiler architecture
+   :align: center
+
+How to run it
+=============
+
+The profiler has been built on top of the work already done in
+anemoi-training. For that we have defined a new class ``AnemoiProfiler``
+that inherits from ``AnemoiTrainer`` and just adds new features and
+methods relevant to the generation of the reports and the activation of
+the profiling mode. Similarly to how we run ``anemoi-training train`` to
+submit a new training job, we have added a new command to execute a
+profiler job, so we just need to run ``anemoi-training profiler``.
+
+Following the same concept as we have with the train command, the
+profiler command is also controlled via the definition of a config. For
+details about the config and the different fields required please refer
+to the Config section. The full command to execute the profiler is:
+
+.. code:: bash
+
+   anemoi-training profiler --config-name=config.yaml
+
+The profiler requires certain new packages to be installed, and hence
+has a specific section in the ``pyproject.toml``
+(``optional-dependencies.profile``). The first time you'd like to use
+it, you first need to make sure you have the dependencies installed by
+doing:
+
+.. code:: bash
+
+   pip install -e .[profile]
+
+Config
+======
+
+To control the execution of the anemoi benchmark profiler, we have to
+define the following fields in the ``eval_rollout.yaml`` file (inside
+the diagnostics folder) under the ``benchmark_profiler`` key.
+
+As mentioned, the benchmark profiler can generate different reports. For
+each report there is an entry in the config that decides whether or not
+to generate the report (if ``enabled: True`` the report is generated, if
+``enabled: False`` the report is skipped). Some reports have additional
+keys:
+
+- For the **time report**, we can also control the length/verbosity of
+  the report. If ``verbose: False``, the report focuses on a concise
+  subset of actions, while if ``verbose: True`` it includes the full
+  list of profiled actions. See the Time Report section for more
+  information.
+
+- In the case of the **memory report**, aside from the summary
+  statistics the MemoryProfiler can also provide some additional
+  insights that include memory traces and a memory timeline; those can
+  be switched on via the ``extra_plots`` entry. Additional config
+  entries ``warmup``, ``steps`` and ``trace_rank0_only`` provide more
+  control regarding the generation of the memory timeline and traces
+  and are explained in the memory profiler section.
+
+- For the **(memory) snapshot**, we can also control the number of
+  recorded steps and warmup steps via the ``steps`` and ``warmup``
+  entries. See the MemorySnapshotRecorder section for more information.
+
+.. figure:: ../images/profiler/anemoi_profiler_config.png
+   :alt: AnemoiProfiler Config Settings
+   :align: center
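+
+For reference, the ``benchmark_profiler`` block shipped with the default
+``config/diagnostics/eval_rollout.yaml`` looks as follows (the default
+values shown here may change between versions):
+
+.. code:: yaml
+
+   benchmark_profiler:
+     memory:
+       enabled: True
+       steps: 5
+       warmup: 2
+       extra_plots: False
+       trace_rank0_only: False
+     time:
+       enabled: True
+       verbose: False
+     speed:
+       enabled: True
+     system:
+       enabled: True
+     model_summary:
+       enabled: True
+     snapshot:
+       enabled: True
+       steps: 4
+       warmup: 0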
+
+**Note** - Anemoi Training also provides some functionality for quick
+troubleshooting using just the PytorchProfiler. To know more about this
+you can check the Troubleshooting section. This functionality is
+activated by setting ``profiler: True`` in the diagnostics config. **When
+using the benchmark profiler it's not necessary to set this flag**,
+since the benchmark profiler will automatically activate the
+PytorchProfiler when enabling the memory profiler. When running
+``anemoi-training profiler`` it's then **recommended** to set
+``profiler: False`` in the diagnostics config to avoid any conflicts.
+
+BenchmarkProfiler
+=================
+
+The ``BenchmarkProfiler`` is the object in charge of generating the
+memory report, time report, model summary and the system report. As the
+diagram indicates, this class inherits from the Pytorch Lightning base
+``Profiler`` class. Pytorch Lightning already provides built-in
+functionality that can be easily integrated with the Pytorch Lightning
+Trainer to profile the code. In particular, it provides access to some
+profilers
+(https://pytorch-lightning.readthedocs.io/en/1.5.10/advanced/profiler.html)
+that track performance across the training cycle in terms of execution
+time (``Simple`` and ``Advanced`` Profilers) and in terms of CPU and GPU
+usage (``Pytorch Profiler``). We have designed the Benchmark Profiler
+taking advantage of that functionality and have extended it so it also
+provides a system report and model summary. The diagram below
+illustrates this. As can be seen, the MemoryProfiler inherits from the
+PytorchProfiler and generates the Memory Report as its main output, and
+the TimeProfiler inherits from the SimpleProfiler and generates the Time
+Report as its output.
+
+.. figure:: ../images/profiler/anemoi_profiler_benchmark_profiler.png
+   :alt: Schematic of the BenchmarkProfiler architecture
+   :align: center
+
+In the diagram, orange boxes indicate outputs and dotted boxes refer to
+parent classes. ``get_memory_profiler_df``, ``get_time_profiler_df``,
+``get_model_summary``, and ``get_system_profiler_df`` are the main
+function interfaces of the BenchmarkProfiler.
+
+Time Report
+-----------
+
+For the time report of our Benchmark Profiler we have decided to use the
+``Simple Profiler``. This profiler provides support to profile
+callbacks, DataHooks and ModelHooks in the training and validation
+loops. By default, the SimpleProfiler will record and output time
+estimates for any of the callbacks, DataHooks and ModelHooks that
+AnemoiTraining defines. To see this full report, one just needs to set
+``verbose: True`` in the config.
+However, since this might be quite extensive, there is an option to
+generate a shorter and more concise version of the time report with
+``verbose: False``, so that it focuses on the callbacks and hooks coming
+from 3 main categories:
+
+- ``LightningDataModule (AnemoiDatasetDataModule)``
+- ``LightningModule (GraphForecaster)``
+- ``ParallelisationStrategy (DDPGroupStrategy)``
+
+Aside from these 3 categories, the report also includes:
+
+- the execution time for the training_epoch (and training_batch)
+
+  - ``run_training_epoch/run_training_batch`` → Time it takes to
+    execute the 'training_step' per batch and per epoch (check
+    https://github.com/Lightning-AI/pytorch-lightning/blob/master/src/lightning/pytorch/loops/fit_loop.py
+    and
+    https://github.com/Lightning-AI/pytorch-lightning/blob/master/src/lightning/pytorch/loops/training_epoch_loop.py
+    for reference)
+
+- the time it takes the training dataloader and validation dataloader
+  to fetch one batch:
+
+  - `[_TrainingEpochLoop].train_dataloader_next `_
+  - `[_EvaluationLoop].val_next `_
+
+- For the callbacks, the ``SimpleProfiler`` provides time estimates of
+  all the different steps defined for each class, so for simplicity the
+  report just aggregates all those times into a single quantity (see
+  the example of the ``AnemoiCheckpoint`` callback below)
+
+Below you can find an example of the report the ``Time Profiler`` issues
+after its execution.
+
+.. figure:: ../images/profiler/example_time_report.png
+   :alt: AnemoiProfiler Time Report
+   :align: center
+
+Note the above example corresponds to the time report generated when
+verbose is set to False according to the config settings. If verbose is
+set to True, then there is no filtering applied to the actions profiled,
+and the time report will include many more entries.
+
+System Report
+-------------
+
+This report provides a table with summary metrics in terms of GPU
+utilisation & memory usage, CPU usage (system), average disk usage and
+total execution time. For now the System profiler relies on the metrics
+tracked by MlFlow, which is the tool we use to track our ML experiments.
+If you run the profiler without MlFlow, it would still be possible to
+generate all the other reports, but the code will indicate that the
+system report can't be generated.
+
+When running anemoi-training with MlFlow activated, this tool also
+tracks a set of system metrics and logs them into the UI. MlFlow does
+this through the `SystemMetricsMonitor `_. For more information you can
+check their docs -
+https://mlflow.org/docs/latest/system-metrics/index.html
+
+In this report we simply take the average of those metrics; in the case
+of those associated with the GPUs we also include metrics per GPU
+device.
+
+Below you can find an example of the ``System Report``
+
+.. figure:: ../images/profiler/example_system_report.png
+   :alt: AnemoiProfiler System Report
+   :align: center
+   :width: 300px
+
+Memory Profiler
+---------------
+
+As we mentioned above, PTL provides functionality to profile the code.
+In particular one can use the PyTorch profiler to measure the time and
+memory consumption of the model’s operators
+(https://pytorch.org/tutorials/recipes/recipes/profiler_recipe.html).
+The report includes GPU/CPU utilisation, memory usage, and execution
+time for the different operations within the model. So far we have
+configured it so that the report includes the top 20 operators with the
+largest GPU utilisation (note this can be adapted and we are keen to get
+feedback).
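+
+For reference, this report is essentially what the underlying
+``torch.profiler`` API produces. A minimal standalone sketch (outside
+anemoi-training, using an illustrative toy model rather than the actual
+forecaster) that prints a similar top-20 table sorted by self CUDA
+memory usage:
+
+.. code:: python
+
+   import torch
+   from torch import nn
+   from torch.profiler import ProfilerActivity, profile
+
+   # Illustrative toy model and input; anemoi-training profiles the real model instead.
+   model = nn.Sequential(nn.Linear(1024, 1024), nn.ReLU(), nn.Linear(1024, 1024)).cuda()
+   inputs = torch.randn(64, 1024, device="cuda")
+
+   with profile(
+       activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
+       profile_memory=True,  # record tensor memory allocations/deallocations
+       record_shapes=True,  # attribute time/memory to operator input shapes
+   ) as prof:
+       model(inputs).sum().backward()
+
+   # Top 20 operators by self CUDA memory usage, similar to the report above
+   print(prof.key_averages().table(sort_by="self_cuda_memory_usage", row_limit=20))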
+
+Below you can find an example of the report generated by the ``Memory
+Profiler``:
+
+.. figure:: ../images/profiler/example_memory_report.png
+   :alt: AnemoiProfiler Memory Report
+   :align: center
+
+Note the difference between self cpu time and cpu time - operators can
+call other operators, self cpu time excludes time spent in children
+operator calls, while total cpu time includes it. Similarly the profiler
+can also show the amount of memory (used by the model’s tensors) that
+was allocated (or released - negative deallocation) during the execution
+of the model’s operators. In the example, ‘self’ memory corresponds to
+the memory allocated (released) by the operator, excluding the children
+calls to the other operators.
+
+To use this functionality, one just needs to specify the following
+entries in the config (Benchmark Profiler section):
+
+.. code:: yaml
+
+   memory:
+     enabled: True
+     steps: 6
+     warmup: 2
+     extra_plots: False
+     trace_rank0_only: True
+
+The enabled flag will trigger the generation of the report shown above.
+Tracing all of the execution can be slow and result in very large trace
+files. To avoid this, we have some optional arguments that are passed to
+the profiler scheduler:
+
+- warming up (``warmup=2`` steps), during this phase the profiler starts
+  tracing, but the results are discarded; this phase is used to discard
+  the samples obtained by the profiler at the beginning of the trace
+  since they are usually skewed by an extra overhead;
+
+- active tracing (``active=6`` steps), during this phase the profiler
+  traces and records data;
+
+**Note** if you use ``limit_batches`` in the dataloader, the number of
+batches selected should be greater than the sum of warmup and steps. If
+not, the profiler will not be able to generate the report.
+
+It's also possible to generate additional products/reports with the
+memory profiler: the memory timeline and the memory traces. Those take
+more time to generate and hence it is possible to choose if we want them
+(``extra_plots: True``) or not (``extra_plots: False``). For details
+about those exact plots please check the section below about **Memory
+Profiler Extras**. If using multiple GPUs, the output of the memory
+traces will be significantly larger. Since usually there are certain
+operations that just happen on rank 0, we might just be interested in
+the outputs coming from this device. It's then possible to generate
+traces and results just from rank 0 by setting ``trace_rank0_only`` to
+True. Note that if we just have one device this flag doesn't make any
+difference; it's only relevant when we have more than one.
+
+**Note Memory Profiler - Patch**
+
+We identified a bug in the PytorchProfiler and we are awaiting the fix
+(see `PR `_) to be
+included as part of the next Pytorch release (so far it's just included
+in the nightly version). To avoid hitting the error, in the current
+AnemoiProfiler we have introduced a patch (see the ``PatchedProfile``
+class in the ``profilers.py`` script). This patch will be removed from
+the codebase as soon as we have a new Pytorch official release that
+includes the fix.
+
+**Memory Profiler Extras - Memory Traces & Memory Timeline**
+
+**Memory Timeline**
+
+The PytorchProfiler automatically generates categories based on the
+graph of tensor operations recorded during profiling; it's possible to
+visualise these categories and their evolution across the execution
+using the ``export_memory_timeline`` method.
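+
+A minimal sketch of that call, reusing a ``torch.profiler.profile``
+object such as the ``prof`` from the sketch above (a detailed timeline
+additionally requires the profile to be run with ``with_stack=True``):
+
+.. code:: python
+
+   # Write an HTML memory timeline for the first CUDA device;
+   # a ".json" suffix would export the raw timeline data instead.
+   prof.export_memory_timeline("memory_timeline.html", device="cuda:0")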
+
+You can find an example of the memory timeline plot below (this is an
+example from https://pytorch.org/blog/understanding-gpu-memory-1/ ). The
+exported timeline plot is in html format.
+
+.. figure:: ../images/profiler/example_memory_timeline.png
+   :alt: Example of PytorchProfiler's Memory Timeline
+   :align: center
+
+**Memory Traces**
+
+The PytorchProfiler enables recording of stack traces associated with
+memory allocations, and results can be output as a .json trace file.
+The PyTorch Profiler leverages the ``Kineto`` library to collect GPU
+traces. Kineto is the subsystem within the Profiler that interfaces with
+CUPTI. GPU kernels execute asynchronously, and GPU-side support is
+needed to create the trace. NVIDIA provides this visibility via the
+CUPTI library.
+
+The `Kineto `_ project enables:
+
+- Performance observability and diagnostics across common ML bottleneck
+  components.
+- Actionable recommendations for common issues.
+- Integration of external system-level profiling tools.
+- Integration with popular visualization platforms and analysis
+  pipelines.
+
+Since these trace files are complex and challenging to interpret, it's
+very useful to have other supporting packages to analyse them. Holistic
+Trace Analysis (HTA) is an open source performance analysis and
+visualization Python library for PyTorch users. The HTA package provides
+the following features:
+
+- **Temporal Breakdown** - Breakdown of time taken by the GPUs in terms
+  of time spent in computation, communication, memory events, and idle
+  time across all ranks.
+
+- **Kernel Breakdown** - Finds kernels with the longest duration on
+  each rank.
+
+- **Kernel Duration Distribution** - Distribution of average time taken
+  by longest kernels across different ranks.
+
+- **Idle Time Breakdown** - Breakdown of GPU idle time into waiting for
+  the host, waiting for another kernel or attribution to an unknown
+  cause.
+
+- **Communication Computation Overlap** - Calculate the percentage of
+  time when communication overlaps computation.
+
+- **Frequent CUDA Kernel Patterns** - Find the CUDA kernels most
+  frequently launched by any given PyTorch or user defined operator.
+
+- **CUDA Kernel Launch Statistics** - Distributions of GPU kernels with
+  very small duration, large duration, and excessive launch time.
+
+- **Augmented Counters (Queue length, Memory bandwidth)** - Augmented
+  trace files which provide insights into memory bandwidth utilized and
+  number of outstanding operations on each CUDA stream.
+
+- **Trace Comparison** - A trace comparison tool to identify and
+  visualize the differences between traces.
+
+- **CUPTI Counter Analysis** - An experimental API to get GPU
+  performance counters. By attributing performance measurements from
+  kernels to PyTorch operators, roofline analysis can be performed and
+  kernels can be optimized.
+
+To be able to load the traces and explore them using HTA, one can set up
+a jupyter notebook and run:
+
+.. code:: python
+
+   from hta.trace_analysis import TraceAnalysis
+   from pathlib import Path
+   from hydra import initialize, compose
+   from omegaconf import OmegaConf
+
+   base_path = Path.cwd().parent
+   with initialize(version_base=None, config_path="./"):
+       cfg = compose(config_name="config.yaml")
+       OmegaConf.resolve(cfg)
+
+   # Run anemoi-training profiler to generate the traces and get the run_id
+   run_id = "b0cc5f6fa6c0476aa1264ad7aacafb4d/"
+   tracepath = cfg.hardware.paths.profiler + run_id
+   analyzer = TraceAnalysis(trace_dir=tracepath)
+
+   # Temporal Breakdown
+   time_df = analyzer.get_temporal_breakdown()
+
+The function returns a dataframe containing the temporal breakdown for
+each rank. See figure below.
+
+.. figure:: ../images/profiler/temporal_breakdown.png
+   :alt: Temporal Breakdown HTA Example
+   :align: center
+
+The idle time breakdown can be generated as follows:
+
+.. code:: python
+
+   # Idle Time Breakdown
+   idle_time_df_r0 = analyzer.get_idle_time_breakdown()
+
+The function returns a dataframe containing the idle breakdown for each
+rank. See figure below.
+
+.. figure:: ../images/profiler/idle_time_breakdown.png
+   :alt: Idle Time Breakdown HTA Example
+   :align: center
+
+Additionally, we can also look at the kernel breakdown feature, which
+breaks down the time spent for each kernel type, i.e. communication
+(COMM), computation (COMP), and memory (MEM), across all ranks and
+presents the proportion of time spent in each category. The percentage
+of time spent in each category is shown as a pie chart.
+
+.. code:: python
+
+   # Kernel Breakdown
+   # NCCL changed their kernel naming convention so HTA v2.0 doesn't recognise communication kernels
+   # This can be fixed by editing one line of hta/utils/util.py, see https://github.com/facebookresearch/HolisticTraceAnalysis/pull/123
+
+   # see https://github.com/facebookresearch/HolisticTraceAnalysis/blob/main/examples/kernel_breakdown_demo.ipynb
+   kernel_type_metrics_df, kernel_metrics_df = analyzer.get_gpu_kernel_breakdown(
+       num_kernels=5, include_memory_kernels=True, visualize=True
+   )
+
+The first dataframe returned by the function contains the raw values
+used to generate the pie chart. The second dataframe returned by
+``get_gpu_kernel_breakdown`` contains duration summary statistics for
+each kernel. In particular, this includes the count, min, max, average,
+standard deviation, sum and kernel type for each kernel on each rank.
+
+.. figure:: ../images/profiler/kernel_breakdown_dfs.png
+   :alt: Kernel Breakdown HTA - Dataframes Example
+   :align: center
+
+Using this data HTA creates many visualizations to identify performance
+bottlenecks.
+
+- **Pie charts** of the top kernels for each kernel type for each rank.
+- **Bar graphs** of the average duration across all ranks for each of
+  the top kernels and for each kernel type.
+
+.. figure:: ../images/profiler/kernel_breakdown_plots.png
+   :alt: Kernel Breakdown HTA - Plots Example
+   :align: center
+
+For more examples using HTA you can check
+https://github.com/facebookresearch/HolisticTraceAnalysis/tree/main/examples
+and the package docs https://hta.readthedocs.io/en/latest/. Additionally
+we recommend this blog from Pytorch:
+https://pytorch.org/blog/trace-analysis-for-masses/
+
+Model Summary
+-------------
+
+While the ``ModelSummary`` does not fall within the category of any
+report associated with computational performance, there is usually a
+connection between the size of the model and its demand for
+computational resources.
+The ``ModelSummary`` provides a summary table breaking down the model
+architecture and the number of trainable parameters per layer. The
+functionality used to create this summary relies on
+https://github.com/TylerYep/torchinfo, and for the exact details one can
+check the function ``get_model_summary`` defined as part of the
+``BenchmarkProfiler`` class. Below you can find an example of the Model
+Summary produced. Note that due to the size of the summary, the
+screenshot below is truncated.
+
+.. figure:: ../images/profiler/example_model_summary.png
+   :alt: Example of AnemoiProfiler's Model Summary - Part I
+   :align: center
+
+.. figure:: ../images/profiler/example_model_summary_2.png
+   :alt: Example of AnemoiProfiler's Model Summary - Part II
+   :align: center
+
+ProfilerProgressBar
+===================
+
+**Speed Report**
+
+While time and speed are related, we wanted to have a separate ``Speed
+Report`` that just focuses on the metrics associated with the throughput
+of the training and validation loops. To get those metrics we take
+advantage of the iterations per second reported by the
+``TQDMProgressBar``, which can be easily integrated when running a model
+with PTL. As indicated in the diagram below, the ProfilerProgressBar
+inherits from ``TQDMProgressBar`` and generates the Speed Report as its
+main output.
+
+The progress bar measures the iterations per second (``it/s``) by
+computing the elapsed time at the start and end of each training and
+validation iteration (where iteration in this case refers to the number
+of batches in each epoch). The report provides an aggregated throughput
+by taking the average across all epochs. Since this metric can be
+sensitive to the number of samples per batch, the report includes a
+``throughput_per_sample`` where we simply normalise the aggregated
+metrics taking into account the batch size used for training and
+validation. In the case of the dataloader(s) throughput, this refers to
+the performance of the dataloader in terms of fetching and collating a
+batch, and again, since this metric can be influenced by the selected
+batch size, we also provide a normalised dataloader throughput.
+
+.. figure:: ../images/profiler/anemoi_profiler_speedreport_diagram.png
+   :alt: AnemoiProfiler's Speed Report Architecture
+   :align: center
+   :width: 200px
+
+Note, this is not just the ``training_step`` as recorded in the 'Time
+Profiler Report': it also includes all the callbacks/hooks that are
+executed during each training/validation iteration. Since most of our
+callbacks are related to sanity and validation plots carried out during
+the validation, we should expect lower throughputs compared to training.
+
+Below you can find an example of the report generated by the ``Speed
+Profiler``:
+
+.. figure:: ../images/profiler/anemoi_profiler_speed_report.png
+   :alt: Example of AnemoiProfiler's Speed Report
+   :align: center
+   :width: 300px
+
+**Note** - CUDA and CPU total time are just time metrics (in seconds)
+computed by the Memory Profiler. For now we have decided to integrate
+and display them as part of the Speed Report, but we can revisit that
+decision based on user feedback.
+
+MemorySnapshotRecorder
+======================
+
+With the latest PyTorch versions (PyTorch 2.1 or higher), the library
+introduces new features to analyse the GPU memory footprint (see
+https://pytorch.org/docs/stable/torch_cuda_memory.html#generating-a-snapshot).
+The AnemoiProfiler integrates these new features through a custom
+callback, ``MemorySnapshotRecorder``.
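+
+Under the hood the callback relies on the ``torch.cuda.memory`` snapshot
+API. A minimal standalone sketch of that API (illustrative; in
+anemoi-training the ``MemorySnapshotRecorder`` callback drives these
+calls for you):
+
+.. code:: python
+
+   import torch
+
+   # Start recording the history of CUDA allocation events
+   torch.cuda.memory._record_memory_history()
+
+   # ... run a few training/validation steps here ...
+
+   # Dump the snapshot; drag and drop the file onto https://pytorch.org/memory_viz
+   torch.cuda.memory._dump_snapshot("memory_snapshot.pickle")
+
+   # Stop recording allocation history
+   torch.cuda.memory._record_memory_history(enabled=None)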
+
+The memory snapshot generated is a pickle file that records the state of
+allocated CUDA memory at any point in time, and optionally records the
+history of allocation events that led up to that snapshot. Captured
+memory snapshots will show memory events including allocations, frees
+and OOMs, along with their stack traces. The generated snapshots can
+then be dragged and dropped onto the interactive viewer hosted at
+pytorch.org/memory_viz which can be used to explore the snapshot. To
+activate this callback, one just needs to specify the following entries
+in the config (Benchmark Profiler section):
+
+.. code:: yaml
+
+   snapshot:
+     enabled: True
+     steps: 6
+     warmup: 2
+
+If we don't want to generate a snapshot we simply set the ``enabled``
+flag to False. If we enable the snapshot recorder, then we need to
+define the number of steps we want to record. Note that a bigger number
+of steps will generate a heavier file that might then take longer to
+render on the website (pytorch.org/memory_viz).
+
+The callback is currently defined to start tracking the CUDA memory at
+the start of the training batch whose global step matches the number of
+warmup steps, and to stop at the end of the training batch whose global
+step matches the total number of steps (steps + warmup) defined. Note
+that if warmup is null then no warmup steps are considered, and the
+recording will start as soon as training starts.
+
+.. figure:: ../images/profiler/memory_snapshot_diagram.png
+   :alt: AnemoiProfiler's MemorySnapshotRecorder Architecture
+   :align: center
+   :width: 200px
+
+In the example below you can see how a ``memory snapshot`` for 6 steps
+looks:
+
+.. figure:: ../images/profiler/memory_snapshot_output.png
+   :alt: Example of AnemoiProfiler's Memory Snapshot
+   :align: center
+
+********************
+ Mlflow Integration
+********************
+
+If using MlFlow to track your run, then all the reports generated by the
+profiler will also be logged into Mlflow. For now, speed, time, memory
+and system reports are logged to mlflow both as json and csv files. We
+hope to receive feedback about this, so in the future we can choose
+between the two formats. The additional outputs generated by the memory
+profiler (memory timeline and traces) aren't tracked as part of mlflow
+due to the large size of those files.
+
+.. figure:: ../images/profiler/anemoi_profiler_mlflow_integration.png
+   :alt: AnemoiProfiler - Mlflow integration
+   :align: center
+
+One of the advantages of logging the reports as JSONs is that those
+files can be logged as ``table artifacts`` and then compared across
+different runs through the Evaluation tab. Below you can see an example
+where we are comparing the system report metrics and speed metrics for
+two different runs.
+
+.. figure:: ../images/profiler/anemoi_profiler_mlflow_integration_2.png
+   :alt: AnemoiProfiler - Example Table Evaluation
+   :align: center
+
+Speed report - train/validation rates
+=====================================
+
+When using MlFlow, there are two additional metrics that can be
+explored:
+
+- ``training_rate`` - the iterations per second (it/s) recorded by the
+  ``ProfilerProgressBar`` across the training cycle. While the
+  SpeedReport provides the averaged throughput
+  ``training_avg_throughput``, the rate allows one to see the evolution
+  of the throughput over time.
+
+- ``validation_rate`` - the iterations per second (it/s) recorded by
+  the ``ProfilerProgressBar`` across the validation cycle.
+  While the SpeedReport provides the averaged throughput
+  ``validation_avg_throughput``, the rate allows one to see the
+  evolution of the throughput over time.
+
+**Note** - to get those metrics it is necessary to enable the
+``SpeedProfiler``. Below you can find an example of how the
+``training_rate`` and ``validation_rate`` look for two different runs.
+
+.. figure:: ../images/profiler/anemoi_profiler_training_rates.png
+   :alt: Example of AnemoiProfiler's Training Rates
+   :align: center
+
+.. figure:: ../images/profiler/anemoi_profiler_validation_rates.png
+   :alt: Example of AnemoiProfiler's Validation Rates
+   :align: center
+
+****************************
+ Limitations & Improvements
+****************************
+
+Limitations
+============
+
+- General challenge for AI code benchmarking results → Noise coming
+  from hardware and AI stochastic behaviour
+
+- ``SpeedReport`` → Robustness of the metrics (val/train rates and
+  throughput)
+
+- ``TimeProfiler`` → Ability to profile just part of the code (so far
+  the SimpleProfiler just records 'pre-defined' hardcoded actions
+  according to the PROFILER_ACTIONS defined in the codebase, and as
+  mentioned above those actions need to be a DataHook, ModelHook or
+  Callback)
+
+- ``TimeProfiler`` → Limitations to time asynchronous parts of the code
+
+- ``MemoryProfiler`` → Report requires a good understanding of the
+  model's operators as reported by the PyTorch profiler
+
+- ``SpeedReport`` → Train/val rates categorisation
+
+Improvements
+==============
+
+- https://pytorch.org/tutorials/recipes/recipes/benchmark.html
+
+- Decorator style to do partial profiling -
+  https://github.com/pythonprofilers/memory_profiler or
+  https://github.com/pyutils/line_profiler
+
+- Defining a decorator or wrapper for the ``TimeProfiler`` could be
+  helpful to provide more control and access to time profiling other
+  parts of the codebase
+
+- Asynchronous code profiling -> https://github.com/sumerc/yappi
+
+- Performance benchmarking and integration with CI/CD - possibility to
+  run the profiler for different code releases as part of github
+  actions
+
+- Energy reports
+
+- Better compatibility with other hardware (AMD GPUs, IPUs, etc.) - the
+  system metrics monitor might not work out of the box with hardware
+  other than Nvidia, since the library it uses to record the GPU
+  metrics is pynvml. We could extend the functionality to be able to
+  profile other hardware like AMD GPUs or Graphcore IPUs
+
+- Support other components of Anemoi like ``anemoi-inference``
diff --git a/pyproject.toml b/pyproject.toml index 2baeea22..9725bb40 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -76,6 +76,13 @@ optional-dependencies.docs = [ "sphinx-argparse", "sphinx-rtd-theme", ] +optional-dependencies.profile = [ + "holistictraceanalysis>=0.2", + "pandas>=1.3.2", + "rich>=13.6", + "tabulate>=0.9", +] + optional-dependencies.tests = [ "hypothesis", "pytest", "pytest-mock" ] urls.Changelog = "https://github.com/ecmwf/anemoi-training/CHANGELOG.md" diff --git a/src/anemoi/training/commands/profiler.py b/src/anemoi/training/commands/profiler.py new file mode 100644 index 00000000..53523f41 --- /dev/null +++ b/src/anemoi/training/commands/profiler.py @@ -0,0 +1,47 @@ +# (C) Copyright 2024 ECMWF. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
+# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + +from __future__ import annotations + +import logging +import sys +from typing import TYPE_CHECKING + +from anemoi.training.commands import Command + +if TYPE_CHECKING: + import argparse + +LOGGER = logging.getLogger(__name__) + + +class Profile(Command): + """Commands to profile Anemoi models.""" + + accept_unknown_args = True + + @staticmethod + def add_arguments(parser: argparse.ArgumentParser) -> argparse.ArgumentParser: + return parser + + @staticmethod + def run(args: list[str], unknown_args: list[str] | None = None) -> None: + del args + + if unknown_args is not None: + sys.argv = [sys.argv[0], *unknown_args] + else: + sys.argv = [sys.argv[0]] + + LOGGER.info("Running anemoi profiling command with overrides: %s", sys.argv[1:]) + from anemoi.training.train.profiler import main as anemoi_profile + + anemoi_profile() + + +command = Profile diff --git a/src/anemoi/training/config/diagnostics/eval_rollout.yaml b/src/anemoi/training/config/diagnostics/eval_rollout.yaml index 50e9a647..8d67b50d 100644 --- a/src/anemoi/training/config/diagnostics/eval_rollout.yaml +++ b/src/anemoi/training/config/diagnostics/eval_rollout.yaml @@ -57,6 +57,28 @@ debug: # remember to also activate the tensorboard logger (below) profiler: False +# Use anemoi-profile to profile the training process +benchmark_profiler: + memory: + enabled: True + steps: 5 # wait warmup steps and then do steps (too many steps would lead to a big file) + warmup: 2 + extra_plots: False + trace_rank0_only: False #set to true and it will profile rank 0 only. Reads SLURM_PROC_ID so won't work when not running via Slurm + time: + enabled: True + verbose: False #If true, output every action the profiler caputres, otherwise output a subset defined in PROFILER_ACTIONS at the top of aifs/diagnostics/profiler.py + speed: + enabled: True + system: + enabled: True + model_summary: + enabled: True + snapshot: + enabled: True + steps: 4 # wait warmup steps and then do steps + warmup: 0 + checkpoint: every_n_minutes: save_frequency: 30 # Approximate, as this is checked at the end of training steps diff --git a/src/anemoi/training/config/training/default.yaml b/src/anemoi/training/config/training/default.yaml index 870eeb7a..acd6654b 100644 --- a/src/anemoi/training/config/training/default.yaml +++ b/src/anemoi/training/config/training/default.yaml @@ -19,6 +19,8 @@ multistep_input: 2 # the effective batch size becomes num-devices * batch_size * k accum_grad_batches: 1 +num_sanity_val_steps: 6 + # clipp gradients, 0 : don't clip, default algorithm: norm, alternative: value gradient_clip: val: 32. 
diff --git a/src/anemoi/training/diagnostics/callbacks/__init__.py b/src/anemoi/training/diagnostics/callbacks/__init__.py index f2195b5f..58fbebcd 100644 --- a/src/anemoi/training/diagnostics/callbacks/__init__.py +++ b/src/anemoi/training/diagnostics/callbacks/__init__.py @@ -37,6 +37,7 @@ from pytorch_lightning.callbacks import Callback from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint from pytorch_lightning.utilities import rank_zero_only +from pytorch_lightning.utilities.types import STEP_OUTPUT from anemoi.training.diagnostics.plots import init_plot_settings from anemoi.training.diagnostics.plots import plot_graph_features @@ -870,6 +871,71 @@ def on_load_checkpoint( pl_module.hparams["metadata"]["parent_uuid"] = checkpoint["hyper_parameters"]["metadata"]["uuid"] +class MemorySnapshotRecorder(Callback): + """Record memory snapshot using torch.cuda._record_memory_history().""" + + def __init__(self, config): + super().__init__() + self.config = config + self.dirpath = Path(self.config.hardware.paths.profiler) + + self.warmup = self.config.diagnostics.benchmark_profiler.snapshot.warmup + if not self.warmup: + self.warmup = 0 + self.num_steps = ( + self.config.diagnostics.benchmark_profiler.snapshot.steps + self.warmup + ) # be consistent with profiler scheduler + self.status = False + + assert ( + self.num_steps % self.config.dataloader.batch_size.training == 0 + ), "Snapshot steps is not a multiple of batch size" + assert ( + self.warmup % self.config.dataloader.batch_size.training == 0 + ), "Snapshot Warmup steps is not a multiple of batch size" + + @rank_zero_only + def _start_snapshot_recording(self): + LOGGER.info("Starting snapshot record_memory_history") + torch.cuda.memory._record_memory_history() + self.status = True + + @rank_zero_only + def _save_snapshot(self): + self.memory_snapshot_fname = self.dirpath / "memory_snapshot.pickle" + try: + LOGGER.info("Saving memory snapshot to %s", self.memory_snapshot_fname) + torch.cuda.memory._dump_snapshot(f"{self.memory_snapshot_fname}") + except Exception as e: + LOGGER.error(f"Failed to capture memory snapshot {e}") + + @rank_zero_only + def stop_record_memory_history(self) -> None: + LOGGER.info("Stopping snapshot record_memory_history") + torch.cuda.memory._record_memory_history(enabled=None) + + def on_train_batch_start( + self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", batch: Any, batch_idx: int + ) -> None: + if trainer.global_step == self.warmup: + self._start_snapshot_recording() + + def on_train_batch_end( + self, + trainer: "pl.Trainer", + pl_module: "pl.LightningModule", + outputs: STEP_OUTPUT, + batch: Any, + batch_idx: int, + ) -> None: + if trainer.global_step == self.num_steps: + if self.status is True: + self._save_snapshot() + self.stop_record_memory_history() + else: + LOGGER.info("Snapshot recording was not started so no snapshot was saved") + + class AnemoiCheckpoint(ModelCheckpoint): """A checkpoint callback that saves the model after every validation epoch.""" diff --git a/src/anemoi/training/diagnostics/profilers.py b/src/anemoi/training/diagnostics/profilers.py new file mode 100644 index 00000000..4261fe83 --- /dev/null +++ b/src/anemoi/training/diagnostics/profilers.py @@ -0,0 +1,698 @@ +# (C) Copyright 2024 European Centre for Medium-Range Weather Forecasts. +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. 
+# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. +from __future__ import annotations + +import json +import logging +import re +from pathlib import Path +from typing import TYPE_CHECKING +from typing import Any + +import numpy as np +import pandas as pd +import torch +import torch.distributed as dist +from pytorch_lightning.callbacks import TQDMProgressBar +from pytorch_lightning.profilers import Profiler +from pytorch_lightning.profilers import PyTorchProfiler +from pytorch_lightning.profilers import SimpleProfiler +from pytorch_lightning.utilities import rank_zero_only + +if TYPE_CHECKING: + import importlib + + import pytorch_lightning as pl + from omegaconf import DictConfig + from pytorch_lightning.utilities.types import STEP_OUTPUT + + from anemoi.training.train.forecaster import GraphForecaster + + if importlib.util.find_spec("ipywidgets") is not None: + from tqdm.auto import tqdm as _tqdm + else: + from tqdm import tqdm as _tqdm + +from torch.profiler import profile + +from anemoi.training.diagnostics.mlflow.logger import AnemoiMLflowLogger + +LOGGER = logging.getLogger(__name__) + + +def check_torch_version() -> bool: + torch_version = torch.__version__ + version_nums = torch_version.split(".") + major_version = int(version_nums[0]) + minor_version = int(version_nums[1]) + if major_version == 2 and minor_version >= 1: + return True + LOGGER.error("Memory snapshot is only supported for torch >= 2.1") + return False + + +def convert_to_seconds(time_str: str) -> float: + import re + + pattern = r"(\d+(\.\d+)?)\s*([a-zA-Z]+)" + # Use regex to find matches + match = re.match(pattern, time_str) + + # Check if match is found + if match: + # Extract numeric part and unit part + numeric_part = float(match.group(1)) + unit = match.group(3) + + # Convert the unit to seconds + if unit == "s": + return numeric_part + if unit == "ds": + return numeric_part / 10 # Convert decaseconds to seconds + if unit == "cs": + return numeric_part / 100 # Convert centiseconds to seconds + if unit == "ms": + return numeric_part / 1000 # Convert milliseconds to seconds + error_msg = ( + "Invalid unit. Supported units are: 's' (seconds)'" + "'ds' (decaseconds), 'cs' (centiseconds) and 'ms' (miliseconds) .", + ) + raise ValueError(error_msg) + error_msg = "Invalid time format. The time should be in the format: 'numeric_part unit'. 
For example: '10 ms'" + raise ValueError(error_msg) + + +PROFILER_ACTIONS = [ + r"\[Strategy]\w+\.batch_to_device", + r"\[Strategy]\w+\.backward", + r"\[Strategy]\w+\.training_step", + r"\[Strategy]\w+\.validation_step", + r"\[Strategy]\w+\.batch_to_device", + "run_training_epoch", + "run_training_batch", + r"\[_EvaluationLoop\]\.\w+", + r"\[_TrainingEpochLoop\]\.\w+", + r"\[LightningDataModule]\w+\.train_dataloader", + r"\[LightningDataModule]\w+\.val_dataloader", + r"\[LightningDataModule]\w+\.state_dict", + r"\[LightningDataModule]\w+\.setup", + r"\[LightningDataModule]\w+\.prepare_data", + r"\[LightningDataModule]\w+\.teardown", + r"\[LightningModule]\w+\.optimizer_step", + r"\[LightningModule]\w+\.configure_gradient_clipping", + r"\[LightningModule]\w+\.on_validation_model_eval", + r"\[LightningModule]\w+\.optimizer_zero_grad", + r"\[LightningModule]\w+\.transfer_batch_to_device", + r"\[LightningModule]\w+\.on_validation_model_train", + r"\[LightningModule]\w+\.configure_optimizers", + r"\[LightningModule]\w+\.lr_scheduler_step", + r"\[LightningModule]\w+\.configure_sharded_model", + r"\[LightningModule]\w+\.setup", + r"\[LightningModule]\w+\.prepare_data", + r"\[Callback\](.*Plot*)", + r"\[Callback\](.*Checkpoint*)", +] + +GPU_METRICS_DICT = { + "GPU device utilization (%)": "gpu", + "GPU memory use (%)": "memory", + "GPU memory allocated (%)": "memoryAllocated", + "GPU memory allocated (GB)": "memoryAllocatedBytes", +} + + +class WandBSystemSummarizer: + """Summarize System Metrics provided by W&B logger.""" + + def __init__(self, wandb_logger: pl.loggers.WandbLogger): + + run_dict = wandb_logger._wandb_init + self.run_id_path = f"{run_dict['entity']}/{run_dict['project']}/{run_dict['id']}" + + def get_wandb_metrics(self) -> (pd.DataFrame, dict): + """Fetches system metrics and metadata from a W&B run.""" + import wandb + + run = wandb.Api().run(self.run_id_path) + system_metrics = run.history(stream="events") + metadata_dict = run.metadata + system_metrics = system_metrics.dropna() + return system_metrics, metadata_dict + + def summarize_gpu_metrics(self, df: pd.DataFrame) -> dict[str, float]: + """Given the System Metrics DataFrame, summarized the GPU metrics. + + - gpu.{gpu_index}.memory - GPU memory utilization in percent for each GPU + - gpu.{gpu_index}.memoryAllocated - GPU memory allocated as % of the total available memory for each GPU + - gpu.{gpu_index}.memoryAllocatedBytes - GPU memory allocated in bytes for each GPU + - gpu.{gpu_index}.gpu - GPU utilization in percent for each GPU + """ + average_metric = {} + col_names = df.columns + for gpu_metric_name, gpu_metric in GPU_METRICS_DICT.items(): + pattern = rf"system.gpu.\d.{gpu_metric}$" + sub_gpu_cols = [string for string in col_names if re.match(pattern, string)] + metrics_per_gpu = df[sub_gpu_cols].mean(axis=0) + if gpu_metric == "memoryAllocatedBytes": + metrics_per_gpu = metrics_per_gpu * 1e-9 + average_metric[gpu_metric_name] = metrics_per_gpu.mean() + # Just add metrics per gpu to the report if we have more than 1 GPU + if metrics_per_gpu.shape[0] > 1: + metrics_per_gpu.index = [" " + index for index in metrics_per_gpu.index] + average_metric.update(dict(metrics_per_gpu)) + return average_metric + + def summarize_system_metrics(self) -> dict[str, float]: + r"""Summarizes the System metrics from a W&B run. + + Some of the metrics included are: + - cpu.{}.cpu_percent - CPU usage of the system on a per-core basis. 
+ - system.memory - Represents the total system memory usage as a percentage of the total available memory. + - system.cpu - Percentage of CPU usage by the process, normalized by the number of available CPUs + - system.disk.\\.usageGB - (Represents the total system disk usage in gigabytes (GB)) + - system.proc.memory.percent - Indicates the memory usage of the process as a % of the total available memory + + More information about W&B system metrics can be found here: + https://docs.wandb.ai/guides/app/features/system-metrics + """ + system_metrics_df, metadata_dict = self.get_wandb_metrics(self.run_id_path) + + col_names = system_metrics_df.columns + system_metrics = {} + + n_cpus = metadata_dict["cpu_count"] + cpu_cols = list(filter(lambda k: "cpu." in k, col_names)) + system_metrics["avg CPU usage (%)"] = (system_metrics_df[cpu_cols].sum(axis=1) / n_cpus).mean() + + system_metrics_gpu = self.summarize_gpu_metrics(system_metrics_df) + system_metrics.update(system_metrics_gpu) + + system_metrics["avg Memory usage (%)"] = system_metrics_df["system.memory"].mean() + system_metrics["avg Disk usage (GB)"] = system_metrics_df["system.disk.\\.usageGB"].mean() + system_metrics["avg Disk usage (%)"] = system_metrics_df["system.disk.\\.usagePercent"].mean() + + system_metrics["execution time (sec)"] = system_metrics_df["_runtime"].iloc[-1] # in seconds + return system_metrics + + +class MLFlowSystemSummarizer: + """Summarize System Metrics provided by MlFlow logger.""" + + def __init__(self, mlflow_logger: pl.loggers.MLFlowLogger): + self.run_id = mlflow_logger.run_id + self.mlflow_client = mlflow_logger._mlflow_client + + @property + def system_metrics(self) -> list[str]: + run = self.mlflow_client.get_run(self.run_id) + return [metric for metric in run.data.metrics if "system" in metric] + + def _clean_metric_name(self, metric_name: str) -> str: + return ( + metric_name.replace("system.", "avg ") + .replace("_", " ") + .replace("megabytes", "MB") + .replace("percentage", "%") + ) + + def _get_mean(self, pattern: str, df: pd.DataFrame) -> float: + # Filter rows containing the pattern in the 'metric' column + filtered_rows = df[df["metric"].str.contains(pattern)] + return filtered_rows.loc[:, "value"].astype(np.float32).mean() + + def _extract_gpu_metrics(self, df: pd.DataFrame) -> pd.DataFrame: + # Define the pattern you want to search for + pattern = r"gpu\s\d+\s+utilization" + df.loc[len(df.index)] = ["avg GPU utilization (%)", self._get_mean(pattern, df)] + + pattern = r"gpu\s\d+\s+memory\s+usage\s+%" + df.loc[len(df.index)] = ["avg GPU memory usage %", self._get_mean(pattern, df)] + + pattern = r"gpu\s\d+\s+memory\s+usage\s+MB" + df.loc[len(df.index)] = ["avg GPU memory usage MB", self._get_mean(pattern, df)] + + return df + + def summarize_mlflow_system_metrics(self) -> pd.DataFrame: + rows = [] + for metric in self.system_metrics: + metric = self.mlflow_client.get_metric_history(self.run_id, metric) + avg_value = sum(m.value for m in metric) / len(metric) + metric_name = self._clean_metric_name(metric[0].key) + rows.append({"metric": metric_name, "value": f"{avg_value:.2f}"}) + return self._extract_gpu_metrics(pd.DataFrame(rows)) + + +class DummyProfiler(Profiler): + """Placeholder profiler.""" + + def __init__(self): + super().__init__() + + def start(self, *args, **kwargs) -> None: + pass + + def stop(self, *args, **kwargs) -> None: + pass + + +class PatchedProfile(profile): + + def _get_distributed_info(self) -> dict[str, str]: + dist_info = super()._get_distributed_info() + return 
json.dumps(dist_info, default=str) + + +class BenchmarkProfiler(Profiler): + """Custom PyTorch Lightning profiler for benchmarking.""" + + def __init__(self, config: DictConfig) -> None: + super().__init__(config) + + self.config = config + self.warmup = self.config.diagnostics.benchmark_profiler.memory.warmup + if not self.warmup: + self.warmup = 0 + self.num_steps = self.config.diagnostics.benchmark_profiler.memory.steps + + if self.config.diagnostics.benchmark_profiler.memory.extra_plots: + assert ( + self.num_steps <= self.config.training.num_sanity_val_steps + ), "Sanity steps should be less than snapshot steps, to avoid memory issues" + + self.dirpath = None + self.create_output_path() + # the profilers need to be initialised before the setup method because + # actions like configuring callbacks would trigger the profiler + self.memory_profiler = DummyProfiler # dummy profiler to be used as placeholder + self.time_profiler = DummyProfiler # dummy profiler to be used as placeholder + + @rank_zero_only + def create_output_path(self) -> None: + self.dirpath = Path(self.config.hardware.paths.profiler) + self.dirpath.mkdir(parents=True, exist_ok=True) + + def broadcast_profiler_path(self, string_var: str, src_rank: int) -> str: + from lightning_fabric.utilities.distributed import group as _group + + string_var = [string_var] + dist.broadcast_object_list(string_var, src_rank, group=_group.WORLD) + return string_var[0] + + def setup(self, stage: str, local_rank: int | None = None, log_dir: str | None = None) -> None: + del log_dir + # THE STRATEGY IS ALREADY INITIALISED AND TORCH DISTRIBUTED IS ACTIVE + # we need to broadcast the profiler path to all ranks to save the memory traces + self.dirpath = Path(self.broadcast_profiler_path(str(self.dirpath), 0)) + self._stage = stage + self._local_rank = local_rank + self._create_time_profilers() + self._create_memory_profilers() + + def _create_time_profilers(self) -> None: + """Creates profilers for time and memory measurements.""" + if self.config.diagnostics.benchmark_profiler.time.enabled: + self.time_profiler = SimpleProfiler( + dirpath=self.dirpath, + ) + + def _create_memory_profilers(self) -> None: + if self.config.diagnostics.benchmark_profiler.memory.enabled: + import os + + def trace_handler(dir_name: str, stage: str | None = None) -> callable: + + def handler_fn(prof: pl.profilers.Profiler) -> None: + import socket + import time + + worker_name = f"{socket.gethostname()}_{os.getpid()}" + file_name = str(dir_name / f"{worker_name}.{stage}.{time.time_ns()}.pt.trace.json") + prof.export_chrome_trace(file_name) + + return handler_fn + + global_rank = int(os.environ.get("SLURM_PROCID", "0")) # WON'T WORK WHEN RUNNING WITHOUT SLURM + if not (self.config.diagnostics.benchmark_profiler.memory.trace_rank0_only and global_rank != 0): + from pytorch_lightning.profilers.pytorch import _KINETO_AVAILABLE + + assert ( + _KINETO_AVAILABLE + ), "Kineto is not available. 
Please ensure Kineto is available to use the memory profiler" + + torch.profiler.profile = ( + PatchedProfile # patch the profile(KinetoProfile) object to serialise the distributed info + ) + self.memory_profiler = PyTorchProfiler( + with_stack=True, + emit_nvtx=False, + profile_memory=True, + export_to_chrome=True, + record_shapes=True, + group_by_input_shapes=True, + dirpath=self.dirpath, + on_trace_ready=trace_handler(self.dirpath), + schedule=torch.profiler.schedule( + wait=0, + warmup=self.warmup, + active=self.num_steps, + repeat=1, + skip_first=self.config.training.num_sanity_val_steps, + ), + ) + self.time_rows_dict = None # updated if we create a memory profile report + + def start(self, action_name: str) -> None: + """Starts recording for a specific action. + + Parameters + ---------- + action_name : str + Name of the action. + """ + self.time_profiler.start(action_name) + self.memory_profiler.start(action_name) + + def stop(self, action_name: str) -> None: + """Stops recording for a specific action. + + Parameters + ---------- + action_name : str + Name of the action. + """ + self.time_profiler.stop(action_name) + self.memory_profiler.stop(action_name) + + def _trim_time_report(self, recorded_actions: dict) -> dict[str, float]: + all_actions_names = recorded_actions.keys() + df = pd.DataFrame({"Strings": all_actions_names}) + combined_pattern = "|".join(PROFILER_ACTIONS) + filtered_df = df[df["Strings"].str.contains(combined_pattern, regex=True, na=False)] + trimmed_actions_names = filtered_df["Strings"].tolist() + return {key: recorded_actions[key] for key in trimmed_actions_names} + + def get_time_profiler_df(self, precision: int = 5) -> pd.DataFrame: + """Retrieves a DataFrame with time profiling information. + + Parameters + ---------- + precision : int + Precision for rounding, by default 5 + + Returns + ------- + pd.DataFrame + DataFrame with time profiling information.
+ """ + if self.config.diagnostics.benchmark_profiler.time.verbose is False: + self.time_profiler.recorded_durations = self._trim_time_report( + recorded_actions=self.time_profiler.recorded_durations, + ) + time_df = pd.DataFrame(self.time_profiler.recorded_durations.items()) + time_df[2] = time_df[1].apply(len) + time_df[3] = time_df[1].apply(np.mean) + time_df[1] = time_df[1].apply(sum) + time_df.columns = ["name", "total_time", "n_calls", "avg_time"] + + def replace_function(value: str) -> str: + # Replace 'apple' with 'fruit' + return re.sub(r"\{.*?\}", "", value) # Remove anything between braces + + time_df["name"] = time_df["name"].apply(replace_function) + pattern = r"\[(.*?)\]|(.*)" + time_df["category"] = time_df["name"].str.extract(pattern, expand=False)[0].fillna(time_df["name"]) + + pattern = re.compile(r"\[Callback\](.*?)\.") + # Apply the regular expression to the column + callbacks_subcategories = "*Callback_" + time_df[time_df["category"] == "Callback"]["name"].str.extract(pattern) + indexer = time_df[time_df["category"] == "Callback"].index + time_df.loc[indexer, "category"] = callbacks_subcategories[0].tolist() + + # Check if 'Callback' is present in the 'category' column + time_df["is_callback"] = time_df["category"].str.contains("Callback", case=False) + + # Group by the 'is_callback' column and apply groupby operation only on rows with 'Callback' in 'category' + grouped_data = ( + time_df[time_df["is_callback"]] + .groupby("category") + .agg({"n_calls": "sum", "avg_time": "sum", "total_time": "sum"}) + .reset_index() + ) + grouped_data["name"] = grouped_data["category"] + + time_df = pd.concat([time_df[~time_df["is_callback"]], grouped_data]) + time_df = time_df.drop("is_callback", axis=1) + time_df = time_df.round(precision) + time_df = time_df.sort_values(by="category", ascending=False) + + self.time_report_fname = self.dirpath / "time_profiler.csv" + self._save_report(time_df, self.time_report_fname) + return time_df + + @staticmethod + def to_df(sample_dict: dict[str, float], precision: str = ".5") -> pd.DataFrame: + df = pd.DataFrame(sample_dict.items()) + df.columns = ["metric", "value"] + df.value = df.value.apply(lambda x: f"%{precision}f" % x) + return df + + @rank_zero_only + def get_system_profiler_df(self, logger_name: str, logger: pl.loggers.Logger) -> pd.DataFrame: + if logger_name == "wandb": + system_metrics_df = self.to_df(WandBSystemSummarizer(logger).summarize_system_metrics()) + elif logger_name == "mlflow": + system_metrics_df = MLFlowSystemSummarizer(logger).summarize_mlflow_system_metrics() + elif logger_name == "tensorboard": + LOGGER.info("No system profiler data available for Tensorboard") + system_metrics_df = None + + self.system_report_fname = self.dirpath / "system_profiler.csv" + self._save_report(system_metrics_df, self.system_report_fname) + return system_metrics_df + + def _save_report(self, df: pd.DataFrame, fname: Path) -> None: + df.to_csv(fname) + + def _save_model_summary(self, model_summary: str, fname: Path) -> None: + with fname.open("w") as f: + f.write(model_summary) + f.close() + + def get_model_summary(self, model: GraphForecaster, example_input_array: np.ndarray) -> str: + + from torchinfo import summary + + # when using flash attention model, we need to convert the input and model to float16 and cuda + # since FlashAttention only supports fp16 and bf16 data type + example_input_array = example_input_array.to(dtype=torch.float16) + example_input_array = example_input_array.to("cuda") + model.half() + model = 
model.to("cuda") + + summary_str = str( + summary( + model, + input_data=example_input_array, + depth=20, + col_width=16, + col_names=["trainable", "input_size", "output_size", "num_params", "params_percent", "mult_adds"], + row_settings=["var_names"], + verbose=0, + ), + ) + self.model_summary_fname = self.dirpath / "model_summary.txt" + self._save_model_summary(summary_str, self.model_summary_fname) + return summary_str + + @rank_zero_only + def get_speed_profiler_df(self, progressbar: _tqdm) -> pd.DataFrame: + """Computes the speed metrics based on training and validation rates.""" + speed_metrics = {} + + batch_size_tr = self.config.dataloader.batch_size.training + batch_size_val = self.config.dataloader.batch_size.validation + + training_rates_array = np.array(progressbar.training_rates) + speed_metrics["training_avg_throughput"] = training_rates_array.mean() + speed_metrics["training_avg_throughput_per_sample"] = training_rates_array.mean() / batch_size_tr + + validation_rates_array = np.array(progressbar.validation_rates) + speed_metrics["validation_avg_throughput"] = validation_rates_array.mean() + speed_metrics["validation_avg_throughput_per_sample"] = validation_rates_array.mean() / batch_size_val + + # Calculate per_sample metrics + speed_metrics["avg_training_dataloader_throughput"] = ( + 1 / np.array(self.time_profiler.recorded_durations["[_TrainingEpochLoop].train_dataloader_next"]).mean() + ) + speed_metrics["avg_training_dataloader_throughput_per_sample"] = ( + speed_metrics["avg_training_dataloader_throughput"] / batch_size_tr + ) + + speed_metrics["avg_validation_dataloader_throughput"] = ( + 1 / np.array(self.time_profiler.recorded_durations["[_EvaluationLoop].val_next"]).mean() + ) + speed_metrics["avg_validation_dataloader_throughput_per_sample"] = ( + speed_metrics["avg_validation_dataloader_throughput"] / batch_size_val + ) + + if self.time_rows_dict: + speed_metrics.update(self.time_rows_dict) + + speed_profile_df = self.to_df(speed_metrics) + + self.speed_report_fname = self.dirpath / "speed_profiler.csv" + self._save_report(speed_profile_df, self.speed_report_fname) + + return speed_profile_df + + def _save_extra_plots(self) -> None: + if check_torch_version(): + # !it's available for torch >= 2.1 + from torch.cuda._memory_viz import profile_plot + + self.memory_trace_fname = Path(self.dirpath, "memory_trace.html") + with self.memory_trace_fname.open("w") as f: + f.write(profile_plot(self.memory_profiler.profiler)) + + # !it's available for torch >= 2.1 + self.memory_timeline_fname = str(Path(self.dirpath, "memory_timelines.html")) + self.memory_profiler.profiler.export_memory_timeline(self.memory_timeline_fname) + + @rank_zero_only + def get_memory_profiler_df(self) -> pd.DataFrame: + """Retrieves the memory profiler data as a DataFrame. + + Aggregates the results coming from multiple nodes/processes. + + Returns + ------- + pd.DataFrame + Memory profiler data. 
+ """ + if self.config.diagnostics.benchmark_profiler.memory.extra_plots: + self._save_extra_plots() + + self.memory_profiler._delete_profilers() + + if not self.memory_profiler.function_events: + return "" + + data = self.memory_profiler.function_events.key_averages( + group_by_input_shapes=self.memory_profiler._group_by_input_shapes, + ) + table = data.table( + sort_by=self.memory_profiler._sort_by_key, + row_limit=self.memory_profiler._row_limit, + **self.memory_profiler._table_kwargs, + ) # this is a string + + from io import StringIO + + table_main_body = table.split("\n")[:-3] # Remove the last rows + columns = [ + "Name", + "Self CPU %", + "Self CPU", + "CPU total %", + "CPU total", + "CPU time avg", + "Self CUDA", + "Self CUDA %", + "CUDA total", + "CUDA time avg", + "CPU Mem", + "Self CPU Mem", + "CUDA Mem", + "Self CUDA Mem", + "# of Calls", + "Input Shapes", + ] + table_main_body = "\n".join(table_main_body) + memory_df = pd.read_fwf(StringIO(table_main_body), names=columns, skiprows=2) + flag = ["--" not in row for row in memory_df["Name"]] + memory_df = memory_df[flag] + time_rows = [row for row in table.split("\n")[-3:] if row != ""] + if time_rows: + time_rows_dict = {} + for row in time_rows: + key, val = row.split(":") + val = convert_to_seconds(val.strip()) + time_rows_dict[key] = val + self.time_rows_dict = time_rows_dict + + memory_df = memory_df[~memory_df["Name"].isin(time_rows)] + + self.memory_report_fname = self.dirpath / "memory_profiler.csv" + self._save_report(memory_df, self.memory_report_fname) + return memory_df + + +class ProfilerProgressBar(TQDMProgressBar): + """Custom PyTorch Lightning progress bar with profiling functionality. + + Attributes + ---------- + validation_rates : list[float] + List to store validation rates (it/s). + training_rates : list[float] + List to store training rates (it/s). + """ + + def __init__(self): + super().__init__() + self.validation_rates = [] + self.training_rates = [] + + def _extract_rate(self, pbar: _tqdm) -> float: + """Extracts the iteration rate from the progress bar. + + Parameters + ---------- + pbar : tqdm + The progress bar. + + Returns + ------- + float + The iteration rate. 
+ """ + return (pbar.format_dict["n"] - pbar.format_dict["initial"]) / pbar.format_dict["elapsed"] + + def on_train_batch_end( + self, + trainer: pl.Trainer, + pl_module: pl.LightningModule, + outputs: STEP_OUTPUT, + batch: Any, + batch_idx: int, + ) -> None: + """Appends the rate from the progress bar to the list of 'training_rates'.""" + batch_idx + 1 + super().on_train_batch_end(trainer, pl_module, outputs, batch, batch_idx) + if self.train_progress_bar.format_dict["n"] != 0: + self.training_rates.append(self._extract_rate(self.train_progress_bar)) + for logger in self.trainer.loggers: + if isinstance(logger, AnemoiMLflowLogger): + logger.log_metrics({"training_rate": self.training_rates[-1]}, step=trainer.global_step) + + def on_validation_batch_end( + self, + trainer: pl.Trainer, + pl_module: pl.LightningModule, + outputs: STEP_OUTPUT, + batch: Any, + batch_idx: int, + dataloader_idx: int = 0, + ) -> None: + """Append rate from the progress bar to the list of 'validation_rates'.""" + super().on_validation_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx) + if self.val_progress_bar.format_dict["n"] != 0: + self.validation_rates.append(self._extract_rate(self.val_progress_bar)) + for logger in self.trainer.loggers: + if isinstance(logger, AnemoiMLflowLogger): + logger.log_metrics({"validation_rate": self.validation_rates[-1]}, step=trainer.global_step) diff --git a/src/anemoi/training/train/profiler.py b/src/anemoi/training/train/profiler.py new file mode 100644 index 00000000..43182e01 --- /dev/null +++ b/src/anemoi/training/train/profiler.py @@ -0,0 +1,346 @@ +# (C) Copyright 2024 ECMWF. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. 
+# +from __future__ import annotations + +import logging +import os +import warnings +from datetime import datetime +from datetime import timezone +from functools import cached_property +from pathlib import Path +from typing import TYPE_CHECKING + +import hydra +import pandas as pd +from pytorch_lightning.utilities import rank_zero_only +from rich.console import Console + +if TYPE_CHECKING: + from anemoi.training.data.datamodule import AnemoiDatasetsDataModule + from pytorch_lightning.loggers.logger import Logger + from omegaconf import DictConfig + import pytorch_lightning as pl + +from anemoi.training.diagnostics.profilers import BenchmarkProfiler +from anemoi.training.diagnostics.profilers import ProfilerProgressBar +from anemoi.training.train.train import AnemoiTrainer + +LOGGER = logging.getLogger(__name__) +console = Console(record=True, width=200) + + +class AnemoiProfiler(AnemoiTrainer): + """Profiling for Anemoi.""" + + def __init__(self, config: DictConfig) -> None: + super().__init__(config) + + def print_report(self, title: str, dataframe: pd.DataFrame, color: str = "white", emoji: str = "") -> None: + if title == "Model Summary": + console.print(f"[bold {color}]{title}[/bold {color}]", f":{emoji}:") + console.print(dataframe, end="\n\n") + else: + console.print(f"[bold {color}]{title}[/bold {color}]", f":{emoji}:") + console.print(dataframe.to_markdown(headers="keys", tablefmt="psql"), end="\n\n") + + @staticmethod + def print_title() -> None: + console.print("[bold magenta] Benchmark Profiler Summary [/bold magenta]!", ":book:") + + @staticmethod + def print_metadata() -> None: + console.print(f"[bold blue] SLURM NODE(s) {os.environ['HOST']} [/bold blue]!") + console.print(f"[bold blue] SLURM JOB ID {os.environ['SLURM_JOB_ID']} [/bold blue]!") + console.print(f"[bold blue] TIMESTAMP {datetime.now(timezone.utc).strftime('%d/%m/%Y %H:%M:%S')} [/bold blue]!") + + @rank_zero_only + def print_benchmark_profiler_report( + self, + speed_metrics_df: pd.DataFrame | None = None, + time_metrics_df: pd.DataFrame | None = None, + memory_metrics_df: pd.DataFrame | None = None, + system_metrics_df: pd.DataFrame | None = None, + model_summary: str | None = None, + ) -> None: + self.print_title() + self.print_metadata() + + if time_metrics_df is not None: + warnings.warn( + "INFO: Time Report metrics represent single-node metrics (not multi-node aggregated metrics)", + ) + warnings.warn("INFO: Metrics with a * symbol, represent the value after aggregating all steps") + self.print_report("Time Profiling", time_metrics_df, color="green", emoji="alarm_clock") + + if speed_metrics_df is not None: + warnings.warn( + "INFO: Speed Report metrics are single-node metrics (not multi-node aggregated metrics)", + ) + self.print_report("Speed Profiling", speed_metrics_df, color="yellow", emoji="racing_car") + + if memory_metrics_df is not None: + warnings.warn("INFO: Memory Report metrics represent metrics aggregated across all nodes") + self.print_report("Memory Profiling", memory_metrics_df, color="purple", emoji="floppy_disk") + + if system_metrics_df is not None: + self.print_report("System Profiling", system_metrics_df, color="Red", emoji="desktop_computer") + + if model_summary is not None: + self.print_report("Model Summary", model_summary, color="Orange", emoji="robot") + + @staticmethod + def write_benchmark_profiler_report() -> None: + console.save_html("report.html") + + @staticmethod + def to_df(sample_dict: dict[str, float], precision: str = ".5") -> pd.DataFrame: + df = 
pd.DataFrame(sample_dict.items()) + df.columns = ["metric", "value"] + df.value = df.value.apply(lambda x: f"%{precision}f" % x) + return df + + @cached_property + def speed_profile(self) -> None: + """Speed profiler Report. + + Get speed metrics from Progress Bar for training and validation. + """ + if self.config.diagnostics.benchmark_profiler.speed.enabled: + # Find the first ProfilerProgressBar callback. + for callback in self.callbacks: + if isinstance(callback, ProfilerProgressBar): + return self.profiler.get_speed_profiler_df(callback) + else: + error_msg = "No ProfilerProgressBar callback found." + raise ValueError(error_msg) + else: + return None + + def _get_logger(self) -> dict[str, Logger]: + if (self.config.diagnostics.log.wandb.enabled) and (not self.config.diagnostics.log.wandb.offline): + logger_info = {"logger_name": "wandb", "logger": self.wandb_logger} + elif self.config.diagnostics.log.tensorboard.enabled: + logger_info = {"logger_name": "tensorboard", "logger": self.tensorboard_logger} + elif self.config.diagnostics.log.mlflow.enabled: + logger_info = {"logger_name": "mlflow", "logger": self.mlflow_logger} + else: + LOGGER.warning("No logger enabled for system profiler") + logger_info = None + return logger_info + + @cached_property + def system_profile(self) -> None: + """System Profiler Report.""" + if self.config.diagnostics.benchmark_profiler.system.enabled: + logger_info = self._get_logger() + if logger_info: + return self.profiler.get_system_profiler_df( + logger_name=logger_info["logger_name"], + logger=logger_info["logger"], + ) + LOGGER.warning("System Profiler Report is not available") + return None + return None + + @cached_property + def memory_profile(self) -> None: + """Memory Profiler Report.""" + if self.config.diagnostics.benchmark_profiler.memory.enabled: + return self.profiler.get_memory_profiler_df() + return None + + @cached_property + def time_profile(self) -> None: + """Time Profiler Report.""" + if self.config.diagnostics.benchmark_profiler.time.enabled: + return self.profiler.get_time_profiler_df() + return None + + @cached_property + def model_summary(self) -> str: + if self.config.diagnostics.benchmark_profiler.model_summary.enabled: + if self.config.hardware.num_gpus_per_model > 1: + LOGGER.warning("Model Summary is not supported when using model sharding") + self.config.diagnostics.benchmark_profiler.model_summary.enabled = False + return None + model = self.model + example_input_array = self.example_input_array + return self.profiler.get_model_summary(model=model, example_input_array=example_input_array) + return None + + @rank_zero_only + def export_to_logger(self) -> None: + if (self.config.diagnostics.log.wandb.enabled) and (not self.config.diagnostics.log.wandb.offline): + self.to_wandb() + + elif self.config.diagnostics.log.mlflow.enabled: + self.to_mlflow() + + @rank_zero_only + def report(self) -> str: + """Print report to console.""" + LOGGER.info("Generating Profiler reports") + self.print_benchmark_profiler_report( + memory_metrics_df=self.memory_profile, + time_metrics_df=self.time_profile, + speed_metrics_df=self.speed_profile, # speed profile needs to be generated after time and memory reports + system_metrics_df=self.system_profile, + model_summary=self.model_summary, + ) + + def _get_extra_files(self) -> None: + extra_files = [] + extra_files.extend(self.profiler.dirpath.glob("*.pickle")) + # These trace files are too big to push to MLFlow so + # we won't push them as artifacts 
extra_files.extend(self.profiler.dirpath.glob("*.json")) + return extra_files + + def _log_reports_to_mlflow(self, run_id: str, data: pd.DataFrame, artifact_file: str, report_fname: str) -> None: + self.mlflow_logger.experiment.log_table( + run_id=run_id, + data=data, + artifact_file=artifact_file, + ) + + self.mlflow_logger.experiment.log_artifact(run_id, report_fname) + + @rank_zero_only + def to_mlflow(self) -> None: + """Log report into MLFlow.""" + LOGGER.info("Logging Profiler report to MLflow") + self.write_benchmark_profiler_report() + # see https://stackoverflow.com/questions/71151054 on logging a table of metrics to MLflow + + run_id = self.mlflow_logger.run_id + if self.config.diagnostics.benchmark_profiler.system.enabled: + self._log_reports_to_mlflow( + run_id=run_id, + data=self.system_profile, + artifact_file="system_metrics_report.json", + report_fname=self.profiler.system_report_fname, + ) + + if self.config.diagnostics.benchmark_profiler.time.enabled: + self._log_reports_to_mlflow( + run_id=run_id, + data=self.time_profile, + artifact_file="time_metrics_reports.json", + report_fname=self.profiler.time_report_fname, + ) + + if self.config.diagnostics.benchmark_profiler.speed.enabled: + self._log_reports_to_mlflow( + run_id=run_id, + data=self.speed_profile, + artifact_file="speed_metrics_reports.json", + report_fname=self.profiler.speed_report_fname, + ) + + if self.config.diagnostics.benchmark_profiler.memory.enabled: + self._log_reports_to_mlflow( + run_id=run_id, + data=self.memory_profile, + artifact_file="memory_metrics_reports.json", + report_fname=self.profiler.memory_report_fname, + ) + + extra_files = self._get_extra_files() + for file in extra_files: + artifact_path = self.profiler.dirpath / file + if artifact_path.is_file(): + self.mlflow_logger.experiment.log_artifact(run_id, artifact_path) + + if self.config.diagnostics.benchmark_profiler.model_summary.enabled: + self.mlflow_logger.experiment.log_artifact(run_id, self.profiler.model_summary_fname) + + @rank_zero_only + def to_wandb(self) -> None: + """Log report into W&B.""" + LOGGER.info("Logging Profiler report to W&B") + self.write_benchmark_profiler_report() + import wandb + from pytorch_lightning.loggers.wandb import WandbLogger + + logger = WandbLogger( + project=self.run_dict["project"], + entity=self.run_dict["entity"], + id=self.run_dict["id"], + offline=self.config.diagnostics.log.wandb.offline, + resume=self.run_dict["id"], + ) + + logger.experiment.log({"speed_metrics_report": wandb.Table(dataframe=self.speed_profile)}) + logger.experiment.log({"system_metrics_report": wandb.Table(dataframe=self.system_profile)}) + logger.experiment.log({"time_metrics_report": wandb.Table(dataframe=self.time_profile)}) + logger.experiment.log({"memory_metrics_report": wandb.Table(dataframe=self.memory_profile)}) + logger.experiment.log({"model_summary_report": wandb.Table(dataframe=self.model_summary)}) + with Path("report.html").open("r") as f: + logger.experiment.log({"reports_benchmark_profiler": wandb.Html(f)}) + logger.experiment.finish() + + @cached_property + def callbacks(self) -> list[pl.callbacks.Callback]: + callbacks = super().callbacks + callbacks.append(ProfilerProgressBar()) + if self.config.diagnostics.benchmark_profiler.snapshot.enabled: + from anemoi.training.diagnostics.callbacks import MemorySnapshotRecorder + from anemoi.training.diagnostics.profilers import check_torch_version + + available = check_torch_version() + + if available: # if torch is below 2.1.0, the callback will not
be added + callbacks.append(MemorySnapshotRecorder(self.config)) + return callbacks + + @cached_property + def datamodule(self) -> AnemoiDatasetsDataModule: + datamodule = super().datamodule + # to generate a model summary with shapes we need a sample input array + batch = next(iter(datamodule.train_dataloader())) + self.example_input_array = batch[ + :, + 0 : self.config.training.multistep_input, + ..., + self.data_indices.data.input.full, + ] + return datamodule + + @cached_property + def profiler(self) -> BenchmarkProfiler: + return BenchmarkProfiler(self.config) + + def _update_paths(self) -> None: + """Update the paths in the configuration.""" + super()._update_paths() + + if self.run_id: # when using mlflow only rank0 will have a run_id except when resuming runs + # Multi-gpu new runs or forked runs - only rank 0 + # Multi-gpu resumed runs - all ranks + self.config.hardware.paths.profiler = Path(self.config.hardware.paths.profiler, self.run_id) + elif self.config.training.fork_run_id: + parent_run = self.config.training.fork_run_id + self.config.hardware.paths.profiler = Path(self.config.hardware.paths.profiler, parent_run) + LOGGER.info("Profiler path: %s", self.config.hardware.paths.profiler) + + def _close_logger(self) -> None: + if (self.config.diagnostics.log.wandb.enabled) and (not self.config.diagnostics.log.wandb.offline): + # We need to close the W&B logger to be able to read the System Metrics + self.wandb_logger.experiment.finish() + + def profile(self) -> None: + """Profile the model.""" + self.train() + self.report() + self.export_to_logger() + + +@hydra.main(version_base=None, config_path="../config", config_name="config") +def main(config: DictConfig) -> None: + AnemoiProfiler(config).profile() diff --git a/src/anemoi/training/train/train.py b/src/anemoi/training/train/train.py index f48b9467..5d752783 100644 --- a/src/anemoi/training/train/train.py +++ b/src/anemoi/training/train/train.py @@ -328,7 +328,7 @@ def train(self) -> None: # run a fixed no of batches per epoch (helpful when debugging) limit_train_batches=self.config.dataloader.limit_batches.training, limit_val_batches=self.config.dataloader.limit_batches.validation, - num_sanity_val_steps=4, + num_sanity_val_steps=self.config.training.num_sanity_val_steps, accumulate_grad_batches=self.config.training.accum_grad_batches, gradient_clip_val=self.config.training.gradient_clip.val, gradient_clip_algorithm=self.config.training.gradient_clip.algorithm,
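# ------------------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of the patch above). It shows one
# way to drive the new AnemoiProfiler programmatically; it assumes a fully composed
# anemoi-training config is already available as `config.yaml` (a hypothetical path),
# and the override values below are examples rather than defaults introduced here.
from omegaconf import OmegaConf

from anemoi.training.train.profiler import AnemoiProfiler

cfg = OmegaConf.load("config.yaml")  # hypothetical: a fully composed training config

# Enable the individual reports read by BenchmarkProfiler / AnemoiProfiler.
OmegaConf.update(cfg, "diagnostics.benchmark_profiler.time.enabled", True)
OmegaConf.update(cfg, "diagnostics.benchmark_profiler.speed.enabled", True)
OmegaConf.update(cfg, "diagnostics.benchmark_profiler.system.enabled", True)
OmegaConf.update(cfg, "diagnostics.benchmark_profiler.memory.enabled", True)
OmegaConf.update(cfg, "diagnostics.benchmark_profiler.memory.warmup", 2)
OmegaConf.update(cfg, "diagnostics.benchmark_profiler.memory.steps", 4)
# When memory.extra_plots is enabled, memory.steps must not exceed
# training.num_sanity_val_steps (asserted in BenchmarkProfiler.__init__).
OmegaConf.update(cfg, "training.num_sanity_val_steps", 4)

# Runs training, prints the report tables to the console and exports them to the
# active logger (MLflow or W&B), mirroring AnemoiProfiler.profile() defined above.
AnemoiProfiler(cfg).profile()
# ------------------------------------------------------------------------------------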