Skip to content

Commit

Permalink
Update to save classes as a json file
Browse files Browse the repository at this point in the history
  • Loading branch information
jrrodri committed Dec 15, 2023
1 parent 0a6a5c4 commit 2e0b06a
Show file tree
Hide file tree
Showing 5 changed files with 19 additions and 21 deletions.
9 changes: 9 additions & 0 deletions abraia/torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from .multiple import Multiple, tempdir

import onnx
import json
import torch
import torchvision
from torchvision import models, transforms
Expand Down Expand Up @@ -107,6 +108,14 @@ def load_classes(path):
return [line.strip() for line in txt.splitlines()]


def save_json(path, values):
multiple.save_file(path, json.dumps(values))


def load_json(path):
return json.loads(multiple.load_file(path))


transform = transforms.Compose([
transforms.Resize(256),
transforms.CenterCrop(224),
Expand Down
19 changes: 3 additions & 16 deletions notebooks/torch_onnx.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -78,27 +78,14 @@
"source": [
"import onnx\n",
"import torch\n",
"from abraia.torch import load_classes, load_model\n",
"from abraia.multiple import tempdir\n",
"from abraia.torch import load_classes, load_model, export_onnx\n",
"\n",
"dataset = 'hymenoptera_data'\n",
"\n",
"class_names = load_classes(os.path.join(dataset, 'model_ft.txt'))\n",
"model = load_model(os.path.join(dataset, 'model_ft.pt'), class_names)\n",
"# model.eval()\n",
"\n",
"\n",
"def export_onnx(path, model):\n",
" dummy_input = torch.randn(1, 3, 224, 224)\n",
" src = os.path.join(tempdir, path)\n",
" os.makedirs(os.path.dirname(src), exist_ok=True)\n",
" torch.onnx.export(model, dummy_input, src, export_params=True, opset_version=10, do_constant_folding=True, input_names=['input'], output_names=['output'])\n",
" onnx_model = onnx.load(src)\n",
" onnx.checker.check_model(onnx_model)\n",
" return multiple.upload_file(src, path)\n",
"\n",
"\n",
"model_path = export_onnx(os.path.join(dataset, 'model_ft.onnx'), model)"
"model_path = export_onnx(os.path.join(dataset, 'model_ft.onnx'), model)\n",
"# model.eval()"
]
},
{
Expand Down
6 changes: 4 additions & 2 deletions notebooks/torch_training.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@
"source": [
"import torch\n",
"from torchvision import transforms\n",
"from abraia.torch import Dataset, visualize_data, create_model, train_model, visualize_model, save_model, save_classes\n",
"from abraia.torch import Dataset, visualize_data, create_model, train_model, visualize_model, save_model, save_classes, export_onnx, save_json\n",
"\n",
"dataset = 'hymenoptera_data'\n",
"\n",
Expand Down Expand Up @@ -350,7 +350,9 @@
"outputs": [],
"source": [
"save_model(os.path.join(dataset, 'model_ft.pt'), model_ft)\n",
"save_classes(os.path.join(dataset, 'model_ft.txt'), class_names)"
"export_onnx(os.path.join(dataset, 'model_ft.onnx'), model_ft)\n",
"save_classes(os.path.join(dataset, 'model_ft.txt'), class_names)\n",
"save_json(os.path.join(dataset, 'model_ft.json'), class_names)"
]
}
],
Expand Down
4 changes: 2 additions & 2 deletions scripts/abraia
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ def input_files(src):


@click.group('abraia')
@click.version_option('0.13.1')
@click.version_option('0.13.2')
def cli():
"""Abraia CLI tool"""
pass
Expand All @@ -64,7 +64,7 @@ def configure():
@cli.command()
def info():
"""Show user account information"""
click.echo('abraia, version 0.13.1\n')
click.echo('abraia, version 0.13.2\n')
click.echo('Go to [' + click.style('https://abraia.me/console/', fg='green') + '] to see your account information\n')


Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@

setup(
name='abraia',
version='0.13.1',
version='0.13.2',
description='Abraia Multiple SDK',
long_description=long_description,
long_description_content_type='text/markdown',
Expand Down

0 comments on commit 2e0b06a

Please sign in to comment.