This repository has been archived by the owner on Oct 23, 2023. It is now read-only.

Hello World Model Tutorial #180

Open · wants to merge 2 commits into base: main
163 changes: 163 additions & 0 deletions website/docs/tutorials/models/hello-world-model.mdx
@@ -0,0 +1,163 @@
---
id: hello-world-model
sidebar_position: 2
title: Hello World Model
---

In this tutorial, you will create a "Hello World" model. The model will take a
string as input and return a string as output. You will also learn how to export
a model as a TorchScript model that can be loaded with the PlayTorch SDK for
on-device inference.

# Create PyTorch Model

Let's begin by creating a PyTorch model. Here, we are going to create a simple
"Hello World" model using `torch.nn.Module` to represent a neural network (hence
the namespace `nn`).

The model defines a `forward` function with one argument, `name`. This function
performs the model's computation; in later tutorials, for example, it will run
inference on an image.

The model constructor has one argument `prefix`, which will be used in the
`forward` function to prefix the `name` argument.

More details on PyTorch modules at https://pytorch.org/docs/stable/notes/modules.html

```python
import torch
from torch import nn

class Model(nn.Module):
    def __init__(self, prefix: str):
        super().__init__()
        self.prefix = prefix

    def forward(self, name: str) -> str:
        return f"{self.prefix} {name}!"
```

## Create an instance of the model

Next, let's create an instance of the model and perform a computation.

```python
model = Model("Hello")
model("Roman")
```

```python title="Output"
Hello Roman!
```

# Export Model for Mobile

Now that we have a model, let's export it for use on mobile. To do that,
we need to script the model (i.e., create a
[TorchScript](https://pytorch.org/docs/stable/jit.html) representation) as follows:

```python
scripted_model = torch.jit.script(model)
scripted_model("Lindsay")
```

```python title="Output"
Hello Lindsay!
```

:::note
`torch.jit.script` is the recommended way to create a TorchScript model
because it can capture control flow,
but it might fail in some cases. If that happens, we recommend consulting the PyTorch
[TorchScript](https://pytorch.org/docs/stable/jit.html) documentation for solutions.
:::
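If you are curious what the compiler captured, a scripted module exposes the TorchScript source it was compiled to via its `code` attribute. A small self-contained sketch (repeating the `Model` class from above):

```python
import torch
from torch import nn


class Model(nn.Module):
    def __init__(self, prefix: str):
        super().__init__()
        self.prefix = prefix

    def forward(self, name: str) -> str:
        return f"{self.prefix} {name}!"


scripted_model = torch.jit.script(Model("Hello"))

# The scripted module still behaves like the original...
print(scripted_model("Lindsay"))  # Hello Lindsay!

# ...and exposes the TorchScript source generated from forward
print(scripted_model.code)
```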

PyTorch offers the `optimize_for_mobile` utility function to run a list of
optimizations on the model (e.g., Conv2D + BatchNorm fusion, dropout removal).
It's recommended to optimize the model with this utility before exporting it for
mobile.

More details on the `optimize_for_mobile` utility at: https://pytorch.org/docs/stable/mobile_optimizer.html

```python
from torch.utils.mobile_optimizer import optimize_for_mobile

optimized_model = optimize_for_mobile(scripted_model)
optimized_model("Kodo")
```

```python title="Output"
Hello Kodo!
```

Great! Now, let's export the model for mobile. This is done by saving the model
for the lite interpreter. The `_save_for_lite_interpreter` function will create
a `hello_world.ptl` file, which we will be able to load with the PlayTorch SDK.

```python
optimized_model._save_for_lite_interpreter("hello_world.ptl")
```

More details on the lite interpreter at:
https://pytorch.org/tutorials/prototype/lite_interpreter.html
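As a quick sanity check, the saved `.ptl` file can also be loaded back in Python with the lite-interpreter loader. A self-contained sketch; note that `_load_for_lite_interpreter` (in `torch.jit.mobile`) is, like `_save_for_lite_interpreter`, an underscore-prefixed API that may change between PyTorch releases:

```python
import torch
from torch import nn
from torch.jit.mobile import _load_for_lite_interpreter
from torch.utils.mobile_optimizer import optimize_for_mobile


class Model(nn.Module):
    def __init__(self, prefix: str):
        super().__init__()
        self.prefix = prefix

    def forward(self, name: str) -> str:
        return f"{self.prefix} {name}!"


# Script, optimize, and save the model as before
optimized_model = optimize_for_mobile(torch.jit.script(Model("Hello")))
optimized_model._save_for_lite_interpreter("hello_world.ptl")

# Load the .ptl file back and run it to confirm the round trip
reloaded = _load_for_lite_interpreter("hello_world.ptl")
print(reloaded.forward("Kodo"))
```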

# Create Mobile UI and Load Model on Mobile

Next, let's create a PlayTorch Snack by following the link
http://snack.playtorch.dev/. Then, drag and drop the `hello_world.ptl` file onto
the newly created PlayTorch Snack; this will import the model into the Snack.

Replace the source code in `App.js` with the React Native source code below.
It creates a user interface with a text input, a button, and a text element.
When the button is pressed, the app loads the `hello_world.ptl` model and calls
the model's `forward` function with the text input value as its argument.
The returned model output is then displayed below the button.

```jsx
import * as React from 'react';
import { useState } from 'react';
import {
  Button,
  SafeAreaView,
  StyleSheet,
  Text,
  TextInput,
  View,
} from 'react-native';
import { torch, MobileModel } from 'react-native-pytorch-core';

export default function App() {
  const [modelInput, setModelInput] = useState('');
  const [modelOutput, setModelOutput] = useState('');

  async function handleModelInput() {
    const filePath = await MobileModel.download(require('./hello_world.ptl'));
    const model = await torch.jit._loadForMobile(filePath);
    const output = await model.forward(modelInput);
    setModelOutput(output);
  }

  return (
    <SafeAreaView style={StyleSheet.absoluteFill}>
      <View style={styles.container}>
        <TextInput
          value={modelInput}
          onChangeText={setModelInput}
          placeholder="Write your name"
        />
        <Button title="Let's Go" onPress={handleModelInput} />
        <Text>{modelOutput}</Text>
      </View>
    </SafeAreaView>
  );
}

const styles = StyleSheet.create({
  container: {
    flex: 1,
    justifyContent: 'center',
    backgroundColor: '#fff',
    padding: 20,
  },
});
```
76 changes: 76 additions & 0 deletions website/docs/tutorials/models/install-pytorch.mdx
@@ -0,0 +1,76 @@
---
id: install-pytorch
sidebar_position: 1
title: Install PyTorch
---

import Tabs from '@theme/Tabs';
import TabItem from '@theme/TabItem';

In this tutorial, we will learn a quick way to set up PyTorch.

# [OPTIONAL] Set up Python virtual environment

It is recommended to run the Python scripts in a virtual environment. Create and activate one with the following commands:

```shell
python3 -m venv venv
source venv/bin/activate
```
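To confirm the virtual environment is active, check which interpreter the shell now resolves (a quick sanity check; the exact path depends on where you created the venv):

```shell
# Inside the venv, this should print a path like .../venv/bin/python3
command -v python3
python3 -c "import sys; print(sys.prefix)"
```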

# Install `torch` dependency

Next, let's install the PyTorch dependency via the Python package manager `pip`.

<Tabs
  defaultValue="macos"
  values={[
    {label: 'macOS', value: 'macos'},
    {label: 'Linux', value: 'linux'},
    {label: 'Windows', value: 'windows'},
  ]}>
<TabItem value="macos">

```shell
pip install torch==1.12.1
```

</TabItem>
<TabItem value="linux">

```shell
pip install torch==1.12.1+cpu --extra-index-url https://download.pytorch.org/whl/cpu
```

</TabItem>
<TabItem value="windows">

```shell
pip install torch==1.12.1+cpu --extra-index-url https://download.pytorch.org/whl/cpu
```

</TabItem>
</Tabs>

# Test Installation

Open a Python interpreter in the terminal:

```shell
python
```

Then execute the following two lines of code, which will print the PyTorch version:

```python
import torch
print(torch.__version__)
```

```python title="Output"
1.12.1
```

Exit the Python interpreter with `exit()`.

That's it! PyTorch is installed successfully.
4 changes: 4 additions & 0 deletions website/fb/sdoc-cache.json
@@ -0,0 +1,4 @@
{
"snippets": {},
"description": "@generated"
}
11 changes: 10 additions & 1 deletion website/sidebars.js
@@ -35,7 +35,16 @@ module.exports = {
     },
     {
       type: 'category',
-      label: 'Tutorials',
+      label: 'Model Tutorials',
+      collapsed: false,
+      items: [
+        'tutorials/models/install-pytorch',
+        'tutorials/models/hello-world-model',
+      ],
+    },
+    {
+      type: 'category',
+      label: 'Demo Tutorials',
       collapsed: false,
       items: [
         'tutorials/snacks/image-classification',