From 11407f332012d8cbfc90b5cee4426f03f60beca7 Mon Sep 17 00:00:00 2001
From: Jadeiin <92222981+Jadeiin@users.noreply.github.com>
Date: Sat, 24 Jan 2026 21:46:01 +0800
Subject: [PATCH 1/2] Add initial PyTorch policy export functionality with JIT and ONNX support

---
 skrl/utils/exporter/__init__.py       |   0
 skrl/utils/exporter/torch/__init__.py |   1 +
 skrl/utils/exporter/torch/exporter.py | 168 ++++++++++++++++++++++++++
 3 files changed, 169 insertions(+)
 create mode 100644 skrl/utils/exporter/__init__.py
 create mode 100644 skrl/utils/exporter/torch/__init__.py
 create mode 100644 skrl/utils/exporter/torch/exporter.py

diff --git a/skrl/utils/exporter/__init__.py b/skrl/utils/exporter/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/skrl/utils/exporter/torch/__init__.py b/skrl/utils/exporter/torch/__init__.py
new file mode 100644
index 00000000..f8d34f55
--- /dev/null
+++ b/skrl/utils/exporter/torch/__init__.py
@@ -0,0 +1 @@
+from skrl.utils.exporter.torch.exporter import export_policy_as_jit, export_policy_as_onnx
diff --git a/skrl/utils/exporter/torch/exporter.py b/skrl/utils/exporter/torch/exporter.py
new file mode 100644
index 00000000..26e04da8
--- /dev/null
+++ b/skrl/utils/exporter/torch/exporter.py
@@ -0,0 +1,168 @@
+from __future__ import annotations
+
+from typing import TYPE_CHECKING
+
+import os
+
+import torch
+
+
+if TYPE_CHECKING:
+    from skrl.models.torch import Model
+
+
+def export_policy_as_jit(
+    policy: Model,
+    observation_preprocessor: torch.nn.Module | None,
+    state_preprocessor: torch.nn.Module | None,
+    path: str,
+    filename: str = "policy.pt",
+    example_inputs: dict[str, torch.Tensor] | None = None,
+    optimize: bool = True,
+    device: str | torch.device = "cpu",
+) -> None:
+    """Export a policy to a Torch JIT file.
+
+    This exporter is designed for skrl models during evaluation. It wraps the given
+    policy together with an optional observation preprocessor and produces a single
+    module with a simple `forward(obs)` -> `actions` interface.
+
+    Limitations:
+    - Torch-only base
+    - Non-recurrent policies (RNN/LSTM/GRU export is out of scope here)
+
+    Args:
+        policy: A skrl policy model (torch.nn.Module) implementing `act(inputs)`.
+        observation_preprocessor: Optional module to preprocess observations.
+        path: Directory to save the file to.
+        filename: Output file name, defaults to "policy.pt".
+        example_inputs: Example inputs for tracing. If None, dummy inputs with batch size 1 will be used.
+        optimize: Whether to optimize the traced model for inference, defaults to True.
+        device: Device to use for export, defaults to "cpu".
+ """ + + exporter = _TorchPolicyExporter(policy, observation_preprocessor, state_preprocessor) + os.makedirs(path, exist_ok=True) + full_path = os.path.join(path, filename) + + exporter.to(device) + exporter.eval() + + # Use tracing for broader TorchScript compatibility with dict-based models + if example_inputs is not None: + example_inputs = {k: v.to(device) for k, v in example_inputs.items()} + else: + example_inputs = { + "observations": torch.zeros(1, exporter._num_observations, device=device), + "states": torch.zeros(1, exporter._num_states, device=device), + } + + traced = torch.jit.trace(exporter, tuple(example_inputs.values())) + if optimize: + traced = torch.jit.optimize_for_inference(traced) + torch.jit.save(traced, full_path) + + +def export_policy_as_onnx( + policy: Model, + observation_preprocessor: torch.nn.Module | None, + state_preprocessor: torch.nn.Module | None, + path: str, + filename: str = "policy.onnx", + example_inputs: dict[str, torch.Tensor] | None = None, + optimize: bool = True, + dynamo: bool = True, + opset_version: int = 18, + verbose: bool = False, + device: str | torch.device = "cpu", +) -> None: + """Export a policy to an ONNX file. + + This exporter is designed for skrl models during evaluation. It wraps the given + policy together with an optional observation preprocessor and produces a single + ONNX graph with `obs` input and `actions` output. + + Limitations: + - Torch-only base + - Non-recurrent policies (RNN/LSTM/GRU export is out of scope here) + + Args: + policy: A skrl policy model (torch.nn.Module) implementing `act(inputs)`. + observation_preprocessor: Optional module to preprocess observations. + path: Directory to save the file to. + filename: Output file name, defaults to "policy.onnx". + example_inputs: Example inputs for tracing. If None, dummy inputs with batch size 1 will be used. + optimize: Whether to optimize the model for inference, defaults to True. + dynamo: Whether to use Torch Dynamo for export, defaults to True. + opset_version: ONNX opset version to use, defaults to 18. + verbose: Whether to print the model export graph summary. + device: Device to use for export, defaults to "cpu". + """ + + exporter = _TorchPolicyExporter(policy, observation_preprocessor, state_preprocessor) + os.makedirs(path, exist_ok=True) + full_path = os.path.join(path, filename) + + exporter.to(device) + exporter.eval() + + if example_inputs is not None: + example_inputs = {k: v.to(device) for k, v in example_inputs.items()} + else: + example_inputs = { + "observations": torch.zeros(1, exporter._num_observations, device=device), + "states": torch.zeros(1, exporter._num_states, device=device), + } + + torch.onnx.export( + exporter, + tuple(example_inputs.values()), + full_path, + artifacts_dir=path, + opset_version=opset_version, + verbose=verbose, + report=verbose, + input_names=["observations", "states"], + output_names=["actions"], + optimize=optimize, + verify=True, + dynamo=dynamo, + ) + + +class _TorchPolicyExporter(torch.nn.Module): + """Wrap a skrl policy and optional observation preprocessor for export. + + The wrapper exposes a minimal `forward(obs)` that returns actions, handling the + internal policy call and dict construction expected by skrl models. 
+ """ + + def __init__( + self, + policy: Model, + observation_preprocessor: torch.nn.Module | None = None, + state_preprocessor: torch.nn.Module | None = None, + ) -> None: + super().__init__() + # keep given instances to preserve any registered buffers/state; move to CPU on export + self.policy = policy + self._observation_preprocessor = ( + observation_preprocessor if observation_preprocessor is not None else torch.nn.Identity() + ) + self._state_preprocessor = state_preprocessor if state_preprocessor is not None else torch.nn.Identity() + + # skrl `Model` exposes `num_observations` (0 if `observation_space` is None) + # fall back to attempting to infer input size from first linear layer if necessary + self._num_observations = getattr(self.policy, "num_observations", 0) + self._num_states = getattr(self.policy, "num_states", 0) + + @torch.no_grad() + def forward(self, observations: torch.Tensor, states: torch.Tensor) -> torch.Tensor: + actions, _ = self.policy.act( + { + "observations": self._observation_preprocessor(observations), + "states": self._state_preprocessor(states), + }, + role="policy", + ) + return actions From 2e6a0c02480e40fb79da8103c3105d7287266c1a Mon Sep 17 00:00:00 2001 From: Jadeiin <92222981+Jadeiin@users.noreply.github.com> Date: Sat, 24 Jan 2026 21:51:35 +0800 Subject: [PATCH 2/2] Clarify export function documentation to specify state preprocessors in JIT and ONNX exports --- skrl/utils/exporter/torch/exporter.py | 87 ++++++++++++++------------- 1 file changed, 44 insertions(+), 43 deletions(-) diff --git a/skrl/utils/exporter/torch/exporter.py b/skrl/utils/exporter/torch/exporter.py index 26e04da8..09e04937 100644 --- a/skrl/utils/exporter/torch/exporter.py +++ b/skrl/utils/exporter/torch/exporter.py @@ -21,24 +21,20 @@ def export_policy_as_jit( optimize: bool = True, device: str | torch.device = "cpu", ) -> None: - """Export a policy to a Torch JIT file. - - This exporter is designed for skrl models during evaluation. It wraps the given - policy together with an optional observation preprocessor and produces a single - module with a simple `forward(obs)` -> `actions` interface. - - Limitations: - - Torch-only base - - Non-recurrent policies (RNN/LSTM/GRU export is out of scope here) - - Args: - policy: A skrl policy model (torch.nn.Module) implementing `act(inputs)`. - observation_preprocessor: Optional module to preprocess observations. - path: Directory to save the file to. - filename: Output file name, defaults to "policy.pt". - example_inputs: Example inputs for tracing. If None, dummy inputs with batch size 1 will be used. - optimize: Whether to optimize the traced model for inference, defaults to True. - device: Device to use for export, defaults to "cpu". + """Export a policy to a TorchScript (JIT) file. + + The exporter wraps the given policy together with optional observation and + state preprocessors into a single module exposing + ``forward(observations, states) -> actions`` for inference. + + :param policy: Policy model to be exported. + :param observation_preprocessor: Module to preprocess observations, applied before the policy. + :param state_preprocessor: Module to preprocess states, applied before the policy. + :param path: Directory where the exported file will be saved. + :param filename: Output file name. Defaults to ``"policy.pt"``. + :param example_inputs: Example inputs for tracing. If ``None``, dummy inputs with batch size 1 are used. + :param optimize: Whether to optimize the traced model for inference. Defaults to ``True``. 
+    :param device: Device used for export. Defaults to ``"cpu"``.
     """
 
     exporter = _TorchPolicyExporter(policy, observation_preprocessor, state_preprocessor)
@@ -78,25 +74,21 @@ def export_policy_as_onnx(
 ) -> None:
     """Export a policy to an ONNX file.
 
-    This exporter is designed for skrl models during evaluation. It wraps the given
-    policy together with an optional observation preprocessor and produces a single
-    ONNX graph with `obs` input and `actions` output.
-
-    Limitations:
-    - Torch-only base
-    - Non-recurrent policies (RNN/LSTM/GRU export is out of scope here)
-
-    Args:
-        policy: A skrl policy model (torch.nn.Module) implementing `act(inputs)`.
-        observation_preprocessor: Optional module to preprocess observations.
-        path: Directory to save the file to.
-        filename: Output file name, defaults to "policy.onnx".
-        example_inputs: Example inputs for tracing. If None, dummy inputs with batch size 1 will be used.
-        optimize: Whether to optimize the model for inference, defaults to True.
-        dynamo: Whether to use Torch Dynamo for export, defaults to True.
-        opset_version: ONNX opset version to use, defaults to 18.
-        verbose: Whether to print the model export graph summary.
-        device: Device to use for export, defaults to "cpu".
+    The exporter wraps the given policy together with optional observation and
+    state preprocessors into a single module exposing
+    ``forward(observations, states) -> actions`` for inference.
+
+    :param policy: Policy model to be exported.
+    :param observation_preprocessor: Module to preprocess observations, applied before the policy.
+    :param state_preprocessor: Module to preprocess states, applied before the policy.
+    :param path: Directory where the exported file will be saved.
+    :param filename: Output file name. Defaults to ``"policy.onnx"``.
+    :param example_inputs: Example inputs for export. If ``None``, dummy inputs with batch size 1 are used.
+    :param optimize: Whether to optimize the exported model for inference. Defaults to ``True``.
+    :param dynamo: Whether to use Torch Dynamo for export. Defaults to ``True``.
+    :param opset_version: ONNX opset version to use. Defaults to ``18``.
+    :param verbose: Whether to print the export graph summary.
+    :param device: Device used for export. Defaults to ``"cpu"``.
     """
 
     exporter = _TorchPolicyExporter(policy, observation_preprocessor, state_preprocessor)
@@ -131,10 +123,15 @@
 
 
 class _TorchPolicyExporter(torch.nn.Module):
-    """Wrap a skrl policy and optional observation preprocessor for export.
+    """Wrapper that prepares a policy model for export.
 
-    The wrapper exposes a minimal `forward(obs)` that returns actions, handling the
-    internal policy call and dict construction expected by skrl models.
+    This module exposes a minimal ``forward(observations, states)`` that returns
+    actions, handling the internal policy call and input dictionary construction
+    expected by the policy's ``act()`` method.
+
+    :param policy: A policy model to be wrapped.
+    :param observation_preprocessor: Optional preprocessor applied to observations.
+    :param state_preprocessor: Optional preprocessor applied to states.
""" def __init__( @@ -144,20 +141,24 @@ def __init__( state_preprocessor: torch.nn.Module | None = None, ) -> None: super().__init__() - # keep given instances to preserve any registered buffers/state; move to CPU on export + self.policy = policy self._observation_preprocessor = ( observation_preprocessor if observation_preprocessor is not None else torch.nn.Identity() ) self._state_preprocessor = state_preprocessor if state_preprocessor is not None else torch.nn.Identity() - # skrl `Model` exposes `num_observations` (0 if `observation_space` is None) - # fall back to attempting to infer input size from first linear layer if necessary self._num_observations = getattr(self.policy, "num_observations", 0) self._num_states = getattr(self.policy, "num_states", 0) @torch.no_grad() def forward(self, observations: torch.Tensor, states: torch.Tensor) -> torch.Tensor: + """Compute actions from observations and states. + + :param observations: Batch of environment observations. + :param states: Batch of agent states (or zeros if unused). + :returns: Batch of actions produced by the policy. + """ actions, _ = self.policy.act( { "observations": self._observation_preprocessor(observations),