Skip to content

Commit

Permalink
Remove torch_inference notebook
Browse files Browse the repository at this point in the history
  • Loading branch information
jrrodri committed Aug 31, 2024
1 parent 09a2223 commit 27b62d4
Show file tree
Hide file tree
Showing 4 changed files with 45 additions and 158 deletions.
8 changes: 0 additions & 8 deletions abraia/torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,14 +90,6 @@ def export_onnx(path, model, device='cpu'):
multiple.upload_file(src, path)


def save_json(path, values):
multiple.save_json(path, values)


def load_json(path):
return multiple.load_json(path)


transform = transforms.Compose([
transforms.Resize(256),
transforms.CenterCrop(224),
Expand Down
94 changes: 0 additions & 94 deletions notebooks/torch_inference.ipynb

This file was deleted.

49 changes: 14 additions & 35 deletions notebooks/torch_onnx.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
"cells": [
{
"cell_type": "code",
"execution_count": 27,
"execution_count": null,
"metadata": {
"cellView": "form",
"id": "-cp253OYk0zk"
Expand All @@ -11,8 +11,8 @@
"source": [
"%%capture\n",
"#@markdown Start of notebook\n",
"!python -m pip install onnx onnxruntime\n",
"!python -m pip install abraia\n",
"!python -m pip install onnxruntime\n",
"\n",
"import os\n",
"if not os.getenv('ABRAIA_ID') and not os.getenv('ABRAIA_KEY'):\n",
Expand All @@ -26,26 +26,6 @@
"multiple = Multiple()"
]
},
{
"cell_type": "code",
"execution_count": 28,
"metadata": {
"id": "dCkPRZpkJ3Rv"
},
"outputs": [],
"source": [
"# import torch\n",
"# import torchvision\n",
"\n",
"# dummy_input = torch.randn(1, 3, 224, 224)\n",
"# model = torchvision.models.mobilenet_v2(pretrained=True)\n",
"# model.eval()\n",
"\n",
"# torch.onnx.export(model, dummy_input, \"model.onnx\", verbose=True, input_names=['input'], output_names=['output'])\n",
"\n",
"# multiple.upload_file(\"model.onnx\", \"camera/model.onnx\")"
]
},
{
"cell_type": "code",
"execution_count": 29,
Expand Down Expand Up @@ -76,16 +56,12 @@
}
],
"source": [
"import onnx\n",
"import torch\n",
"from abraia.torch import load_json, load_model, export_onnx\n",
"\n",
"dataset = 'hymenoptera_data'\n",
"dataset = 'screws'\n",
"model_name = 'model_ft'\n",
"\n",
"class_names = load_json(os.path.join(dataset, 'model_ft.json'))\n",
"model = load_model(os.path.join(dataset, 'model_ft.pt'), class_names)\n",
"model_path = export_onnx(os.path.join(dataset, 'model_ft.onnx'), model)\n",
"# model.eval()"
"class_names = multiple.load_json(f\"{dataset}/{model_name}.json\")\n",
"model_src = multiple.cache_file(f\"{dataset}/{model_name}.onnx\")\n",
"print(model_src)"
]
},
{
Expand Down Expand Up @@ -119,7 +95,7 @@
"from PIL import Image\n",
"\n",
"\n",
"ort_session = ort.InferenceSession(f\"/tmp/{model_path}\", providers=['CPUExecutionProvider'])\n",
"ort_session = ort.InferenceSession(model_src, providers=['CPUExecutionProvider'])\n",
"\n",
"\n",
"def resize(img, size):\n",
Expand Down Expand Up @@ -147,7 +123,7 @@
"\n",
"\n",
"def predict(src):\n",
" img = Image.open(src)\n",
" img = Image.open(src).convert('RGB')\n",
" input = preprocess(img)\n",
" outputs = ort_session.run(None, {\"input\": input})\n",
" a = np.argsort(-outputs[0].flatten())\n",
Expand All @@ -157,8 +133,11 @@
" return results\n",
"\n",
"\n",
"filename = 'dog.jpg'\n",
"multiple.download_file(os.path.join(dataset, filename), filename)\n",
"files = multiple.list_files(f\"{dataset}/*.png\")[0]\n",
"filename = files[0]['name']\n",
"path = files[0]['path']\n",
"\n",
"multiple.download_file(path, filename)\n",
"predict(filename)"
]
}
Expand Down
52 changes: 31 additions & 21 deletions notebooks/torch_training.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -60,9 +60,9 @@
"source": [
"import torch\n",
"from torchvision import transforms\n",
"from abraia.torch import Dataset, visualize_data, create_model, train_model, visualize_model, save_model, export_onnx, save_json\n",
"from abraia.torch import Dataset, visualize_data, create_model, train_model, visualize_model, export_onnx\n",
"\n",
"dataset = 'hymenoptera_data'\n",
"dataset = 'screws'\n",
"\n",
"# Data augmentation and normalization for training\n",
"# Just normalization for validation\n",
Expand Down Expand Up @@ -259,6 +259,7 @@
],
"source": [
"class_names = image_datasets['train'].classes\n",
"model_name = 'model_ft'\n",
"\n",
"model_conv = create_model(class_names)\n",
"model_ft = train_model(model_conv, dataloaders, num_epochs=25)"
Expand Down Expand Up @@ -343,27 +344,28 @@
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {
"id": "VhYTUkQyrVGL"
},
"outputs": [],
"source": [
"save_model(os.path.join(dataset, 'model_ft.pt'), model_ft)\n",
"export_onnx(os.path.join(dataset, 'model_ft.onnx'), model_ft)\n",
"save_json(os.path.join(dataset, 'model_ft.json'), class_names)"
]
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 1,
"metadata": {},
"outputs": [],
"outputs": [
{
"ename": "NameError",
"evalue": "name 'dataset' is not defined",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)",
"Cell \u001b[0;32mIn[1], line 6\u001b[0m\n\u001b[1;32m 2\u001b[0m export_onnx(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mdataset\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m/\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mmodel_name\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m.onnx\u001b[39m\u001b[38;5;124m\"\u001b[39m, model_ft)\n\u001b[1;32m 3\u001b[0m multiple\u001b[38;5;241m.\u001b[39msave_json(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mdataset\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m/\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mmodel_name\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m.json\u001b[39m\u001b[38;5;124m\"\u001b[39m, class_names)\n\u001b[0;32m----> 6\u001b[0m save_model(\u001b[43mdataset\u001b[49m, model_name, class_names)\n",
"\u001b[0;31mNameError\u001b[0m: name 'dataset' is not defined"
]
}
],
"source": [
"%%capture\n",
"#@markdown End of notebook\n",
"%notebook -e training.ipynb\n",
"multiple.upload_file('training.ipynb', f\"{dataset}/\")"
"def save_model(dataset, model_name, class_names):\n",
" export_onnx(f\"{dataset}/{model_name}.onnx\", model_ft)\n",
" multiple.save_json(f\"{dataset}/{model_name}.json\", class_names)\n",
"\n",
"\n",
"save_model(dataset, model_name, class_names)"
]
}
],
Expand All @@ -377,7 +379,15 @@
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.12"
}
},
Expand Down

0 comments on commit 27b62d4

Please sign in to comment.