"""Policy export utilities (TorchScript / ONNX) for skrl torch models.

NOTE(review): reconstructed from a collapsed diff. The original change also
created ``skrl/utils/exporter/torch/__init__.py`` re-exporting
``export_policy_as_jit`` and ``export_policy_as_onnx`` — keep that re-export
in sync with this module.
"""

from __future__ import annotations

from typing import TYPE_CHECKING

import os

import torch


if TYPE_CHECKING:
    from skrl.models.torch import Model


def _resolve_example_inputs(
    exporter: "_TorchPolicyExporter",
    example_inputs: dict[str, torch.Tensor] | None,
    device: str | torch.device,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Return the ``(observations, states)`` example tensors placed on *device*.

    If *example_inputs* is ``None``, zero-filled dummy tensors with batch size 1
    are built from the wrapped policy's ``num_observations`` / ``num_states``
    (0 when the policy does not expose these attributes). When a dictionary is
    given, tensors are looked up by key so the positional order handed to the
    tracer/exporter never depends on the dictionary's insertion order.

    :param exporter: Wrapper module whose policy provides the dummy sizes.
    :param example_inputs: Optional mapping with ``"observations"`` and
        ``"states"`` tensors.
    :param device: Device the returned tensors are moved/created on.
    :raises KeyError: If *example_inputs* is given but lacks the
        ``"observations"`` or ``"states"`` key.
    """
    if example_inputs is None:
        return (
            torch.zeros(1, exporter._num_observations, device=device),
            torch.zeros(1, exporter._num_states, device=device),
        )
    # Index by key (not dict iteration order) so a caller-supplied dict ordered
    # {"states": ..., "observations": ...} cannot silently swap the inputs.
    return (
        example_inputs["observations"].to(device),
        example_inputs["states"].to(device),
    )


def export_policy_as_jit(
    policy: Model,
    observation_preprocessor: torch.nn.Module | None,
    state_preprocessor: torch.nn.Module | None,
    path: str,
    filename: str = "policy.pt",
    example_inputs: dict[str, torch.Tensor] | None = None,
    optimize: bool = True,
    device: str | torch.device = "cpu",
) -> None:
    """Export a policy to a TorchScript (JIT) file.

    The exporter wraps the given policy together with optional observation and
    state preprocessors into a single module exposing
    ``forward(observations, states) -> actions`` for inference.

    :param policy: Policy model to be exported.
    :param observation_preprocessor: Module to preprocess observations, applied before the policy.
    :param state_preprocessor: Module to preprocess states, applied before the policy.
    :param path: Directory where the exported file will be saved.
    :param filename: Output file name. Defaults to ``"policy.pt"``.
    :param example_inputs: Example inputs for tracing, keyed by ``"observations"``
        and ``"states"``. If ``None``, dummy inputs with batch size 1 are used.
    :param optimize: Whether to optimize the traced model for inference. Defaults to ``True``.
    :param device: Device used for export. Defaults to ``"cpu"``.
    """

    exporter = _TorchPolicyExporter(policy, observation_preprocessor, state_preprocessor)
    os.makedirs(path, exist_ok=True)
    full_path = os.path.join(path, filename)

    exporter.to(device)
    # eval mode is also required by optimize_for_inference (which freezes the module)
    exporter.eval()

    # Use tracing for broader TorchScript compatibility with dict-based models
    inputs = _resolve_example_inputs(exporter, example_inputs, device)

    traced = torch.jit.trace(exporter, inputs)
    if optimize:
        traced = torch.jit.optimize_for_inference(traced)
    torch.jit.save(traced, full_path)


def export_policy_as_onnx(
    policy: Model,
    observation_preprocessor: torch.nn.Module | None,
    state_preprocessor: torch.nn.Module | None,
    path: str,
    filename: str = "policy.onnx",
    example_inputs: dict[str, torch.Tensor] | None = None,
    optimize: bool = True,
    dynamo: bool = True,
    opset_version: int = 18,
    verbose: bool = False,
    device: str | torch.device = "cpu",
) -> None:
    """Export a policy to an ONNX file.

    The exporter wraps the given policy together with optional observation and
    state preprocessors into a single module exposing
    ``forward(observations, states) -> actions`` for inference.

    :param policy: Policy model to be exported.
    :param observation_preprocessor: Module to preprocess observations, applied before the policy.
    :param state_preprocessor: Module to preprocess states, applied before the policy.
    :param path: Directory where the exported file will be saved.
    :param filename: Output file name. Defaults to ``"policy.onnx"``.
    :param example_inputs: Example inputs for export, keyed by ``"observations"``
        and ``"states"``. If ``None``, dummy inputs with batch size 1 are used.
    :param optimize: Whether to optimize the exported model for inference. Defaults to ``True``.
    :param dynamo: Whether to use Torch Dynamo for export. Defaults to ``True``.
    :param opset_version: ONNX opset version to use. Defaults to ``18``.
    :param verbose: Whether to print the export graph summary.
    :param device: Device used for export. Defaults to ``"cpu"``.
    """

    exporter = _TorchPolicyExporter(policy, observation_preprocessor, state_preprocessor)
    os.makedirs(path, exist_ok=True)
    full_path = os.path.join(path, filename)

    exporter.to(device)
    exporter.eval()

    inputs = _resolve_example_inputs(exporter, example_inputs, device)

    # NOTE(review): artifacts_dir/report/optimize/verify are parameters of the
    # dynamo-based exporter (recent torch versions) — confirm the minimum
    # supported torch version covers them.
    torch.onnx.export(
        exporter,
        inputs,
        full_path,
        artifacts_dir=path,
        opset_version=opset_version,
        verbose=verbose,
        report=verbose,
        input_names=["observations", "states"],
        output_names=["actions"],
        optimize=optimize,
        verify=True,
        dynamo=dynamo,
    )


class _TorchPolicyExporter(torch.nn.Module):
    """Wrapper that prepares a policy model for export.

    This module exposes a minimal ``forward(observations, states)`` that returns
    actions, handling the internal policy call and input dictionary construction
    expected by policy's ``act()`` method.

    :param policy: A policy model to be wrapped.
    :param observation_preprocessor: Optional preprocessor applied to observations.
    :param state_preprocessor: Optional preprocessor applied to states.
    """

    def __init__(
        self,
        policy: Model,
        observation_preprocessor: torch.nn.Module | None = None,
        state_preprocessor: torch.nn.Module | None = None,
    ) -> None:
        super().__init__()

        self.policy = policy
        # fall back to Identity so forward() needs no None checks (trace-friendly)
        self._observation_preprocessor = (
            observation_preprocessor if observation_preprocessor is not None else torch.nn.Identity()
        )
        self._state_preprocessor = state_preprocessor if state_preprocessor is not None else torch.nn.Identity()

        # sizes used only to build dummy export inputs; 0 when the policy
        # does not expose these attributes
        self._num_observations = getattr(self.policy, "num_observations", 0)
        self._num_states = getattr(self.policy, "num_states", 0)

    @torch.no_grad()
    def forward(self, observations: torch.Tensor, states: torch.Tensor) -> torch.Tensor:
        """Compute actions from observations and states.

        :param observations: Batch of environment observations.
        :param states: Batch of agent states (or zeros if unused).
        :returns: Batch of actions produced by the policy.
        """
        actions, _ = self.policy.act(
            {
                "observations": self._observation_preprocessor(observations),
                "states": self._state_preprocessor(states),
            },
            role="policy",
        )
        return actions