
PyTorch with DirectML Samples

For detailed instructions on getting started with PyTorch with DirectML, see GPU accelerated ML training.

Setup

Follow the steps below to get set up with PyTorch on DirectML.

  1. Download and install Python 3.8 to 3.10.

  2. Clone this repo.

  3. Install torch-directml:

⚠️ Starting with torch-directml 0.1.13.1.*, torch and torchvision are installed automatically as dependencies.

```
pip install torch-directml
```
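  4. Create a DML device and test it:

```
import torch
import torch_directml

dml = torch_directml.device()  # create the DirectML device
x = torch.ones(2, 2).to(dml)   # quick test: place a tensor on the device
print((x + x).cpu())           # compute on DirectML, then copy back to print
```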

⚠️ Device creation changed in torch-directml 0.1.13 compared with previous versions. The torch-directml backend is currently mapped to "PrivateUse1", and the new torch_directml.device() API is a convenient wrapper for creating your tensors on the correct device.
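Before running a sample, it can help to confirm that your environment matches the setup steps above. The helper below is a hypothetical sketch, not part of this repo: it checks the interpreter against the Python 3.8 to 3.10 range from step 1 and uses the standard library's `importlib.metadata` to report whether torch-directml, torch, and torchvision are installed. The function and constant names are illustrative only.

```python
import sys
from importlib import metadata

# Supported range taken from the setup steps (Python 3.8 to 3.10).
MIN_PYTHON = (3, 8)
MAX_PYTHON = (3, 10)

# Distributions the setup expects: torch-directml pulls in torch and torchvision.
EXPECTED_PACKAGES = ("torch-directml", "torch", "torchvision")


def python_supported(version_info=sys.version_info):
    """Return True if major.minor falls within the documented range."""
    major_minor = tuple(version_info[:2])
    return MIN_PYTHON <= major_minor <= MAX_PYTHON


def installed_versions(packages=EXPECTED_PACKAGES):
    """Map each distribution name to its installed version, or None if missing."""
    versions = {}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


if __name__ == "__main__":
    print("Python", sys.version.split()[0], "supported:", python_supported())
    for name, version in installed_versions().items():
        print(f"{name}: {version or 'NOT INSTALLED'}")
```

Running it as a script prints one line for the interpreter and one per package; a `NOT INSTALLED` entry suggests revisiting step 3.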

Samples

The following sample models are included in this repo to help you get started. The samples include both inference and training scripts, so you can either train the models from scratch or use the supplied pre-trained weights.
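Training on DirectML generally follows the standard PyTorch pattern: move the model and data to the device returned by `torch_directml.device()`, then run the usual optimizer loop. The sketch below illustrates that pattern on a hypothetical toy regression task (y = 2x + 1) with a single linear layer; it is not taken from the sample scripts. The CPU fallback when torch-directml is missing is an added assumption so the snippet still runs anywhere PyTorch is installed.

```python
import torch
from torch import nn

# Assumption: fall back to the CPU if torch-directml is not installed, so the
# sketch still runs; with torch-directml present the work runs on DirectML.
try:
    import torch_directml
    device = torch_directml.device()
except ImportError:
    device = torch.device("cpu")


def train_toy_model(steps=200, lr=0.1, seed=0):
    """Fit a hypothetical toy task, y = 2x + 1, with a single linear layer."""
    torch.manual_seed(seed)
    x = torch.linspace(-1.0, 1.0, 64).unsqueeze(1)
    y = 2.0 * x + 1.0
    # Move the data and the model to the same device before training.
    x, y = x.to(device), y.to(device)
    model = nn.Linear(1, 1).to(device)
    optimizer = torch.optim.SGD(model.parameters(), lr=lr)
    loss_fn = nn.MSELoss()

    for _ in range(steps):
        optimizer.zero_grad()
        loss = loss_fn(model(x), y)
        loss.backward()
        optimizer.step()

    # .item() copies each scalar back to the host as a Python float.
    return model.weight.item(), model.bias.item(), loss.item()


if __name__ == "__main__":
    weight, bias, loss = train_toy_model()
    print(f"device={device} weight={weight:.3f} bias={bias:.3f} loss={loss:.2e}")
```

The learned weight and bias should approach 2 and 1; the same device-placement steps apply to larger models such as those in the samples.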

External Links