Merge pull request #1002 from anarkiwi/modelconf
Add optional configuration for Torchserve model.
anarkiwi authored Nov 29, 2023
2 parents 4226426 + 9017f1f commit 40d1c3d
Showing 7 changed files with 32 additions and 17 deletions.
2 changes: 1 addition & 1 deletion docs/README-airt.md
@@ -140,7 +140,7 @@ From gamutRF's source directory, and having obtained mini2_snr.pt:
$ pip3 install torch-model-archiver
$ mkdir /tmp/model_store
$ wget https://raw.githubusercontent.com/pytorch/serve/master/examples/object_detector/yolo/yolov8/requirements.txt
- $ torch-model-archiver --force --model-name mini2_snr --version 1.0 --serialized-file /PATH/TO/mini2_snr.pt --handler torchserve/custom_handler.py --export-path /tmp/model_store -r requirements.txt
+ $ torch-model-archiver --force --model-name mini2_snr --version 1.0 --serialized-file /PATH/TO/mini2_snr.pt --handler torchserve/custom_handler.py --extra-files torchserve/model_config.json --export-path /tmp/model_store -r requirements.txt
```

# start torchserve
1 change: 0 additions & 1 deletion gamutrf/grscan.py
@@ -41,7 +41,6 @@ def __init__(
gps_server="",
igain=0,
inference_min_confidence=0.5,
- inference_nms_threshold=0.5,
inference_min_db=-200,
inference_model_server="",
inference_model_name="",
9 changes: 1 addition & 8 deletions gamutrf/scan.py
@@ -221,16 +221,9 @@ def argument_parser():
"--inference_min_confidence",
dest="inference_min_confidence",
type=float,
- default=0.5,
+ default=0.25,
help="minimum confidence score to plot",
)
- parser.add_argument(
- "--inference_nms_confidence",
- dest="inference_nms_threshold",
- type=float,
- default=0.5,
- help="NMS threshold",
- )
parser.add_argument(
"--inference_min_db",
dest="inference_min_db",
2 changes: 1 addition & 1 deletion orchestrator.yml
@@ -68,7 +68,7 @@ services:
- --no-compass
- --use_external_gps
- --use_external_heading
- - --inference_min_confidence=0.8
+ - --inference_min_confidence=0.25
- --inference_min_db=-80
- --inference_model_name=mini2_snr
# - --external_gps_server=1.2.3.4
22 changes: 16 additions & 6 deletions tests/test_torchserve.sh
@@ -4,16 +4,26 @@ set -e
TMPDIR=/tmp
sudo apt-get update && sudo apt-get install -y jq wget
sudo pip3 install torch-model-archiver
- cp torchserve/custom_handler.py $TMPDIR/
+ cp torchserve/* $TMPDIR/
cd $TMPDIR
# TODO: use gamutRF weights here.
wget https://github.com/ultralytics/assets/releases/download/v0.0.0/yolov8n.pt
wget https://raw.githubusercontent.com/pytorch/serve/master/examples/object_detector/yolo/yolov8/requirements.txt
- torch-model-archiver --force --model-name yolov8n --version 1.0 --serialized-file yolov8n.pt --handler custom_handler.py -r requirements.txt
- # TODO: use gamutRF test spectogram image
- wget https://github.com/pytorch/serve/raw/master/examples/object_detector/yolo/yolov8/persons.jpg
+ torch-model-archiver --force --model-name yolov8n --version 1.0 --serialized-file yolov8n.pt --handler custom_handler.py --extra-files model_config.json -r requirements.txt
+ # -r requirements.txt
rm -rf model_store && mkdir model_store
mv yolov8n.mar model_store/
- # TODO: --runtime nvidia is required for Orin
+ # TODO: --runtime nvidia is required for Orin, --gpus all for x86
docker run -v $(pwd)/model_store:/model_store -p 8080:8080 --rm --name testts --entrypoint timeout -d iqtlabs/torchserve 180s /torchserve/torchserve-entrypoint.sh --models yolov8n=yolov8n.mar
+ # TODO: use gamutRF test spectogram image
+ wget https://github.com/pytorch/serve/raw/master/examples/object_detector/yolo/yolov8/persons.jpg
- wget -q --retry-connrefused --retry-on-host-error --body-file=persons.jpg --method=PUT -O- --header='Content-Type: image/jpg' http://127.0.0.1:8080/predictions/yolov8n | jq
+ PRED=$(wget -q --retry-connrefused --retry-on-host-error --body-file=persons.jpg --method=PUT -O- --header='Content-Type: image/jpg' http://127.0.0.1:8080/predictions/yolov8n | jq)
+ echo $PRED
+ if [ "$PRED" = "" ] ; then
+ echo "error: no response from Torchserve"
+ exit 1
+ fi
+ if [ "$PRED" = "{}" ] ; then
+ echo "error: no predictions from Torchserve"
+ exit 1
+ fi
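For interactive debugging, the same prediction request can be issued from Python instead of wget. This is only a sketch: it assumes the third-party requests package is installed and that the container started above is still listening on port 8080.

```python
# Rough Python equivalent of the wget PUT calls above (assumes `pip install requests`).
import requests

with open("persons.jpg", "rb") as f:
    resp = requests.put(
        "http://127.0.0.1:8080/predictions/yolov8n",
        data=f.read(),
        headers={"Content-Type": "image/jpg"},
        timeout=30,
    )
resp.raise_for_status()
print(resp.json())  # an empty dict means Torchserve answered but returned no detections
```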
8 changes: 8 additions & 0 deletions torchserve/custom_handler.py
@@ -1,5 +1,6 @@
# based on pytorch's yolov8n example.

+ import json
from collections import defaultdict
import os

@@ -41,6 +42,9 @@ def initialize(self, context):
properties = context.system_properties
self.manifest = context.manifest
model_dir = properties.get("model_dir")
+ # https://docs.ultralytics.com/modes/predict/#inference-arguments
+ with open("model_config.json", "r") as f:
+     self.model_config = json.load(f)
self.model_pt_path = None
if "serializedFile" in self.manifest["model"]:
serialized_file = self.manifest["model"]["serializedFile"]
@@ -63,6 +67,10 @@ def _load_torchscript_model(self, model_pt_path):
model.to(self.device)
return model

+ def inference(self, data, *args, **kwargs):
+     kwargs.update(self.model_config)
+     return super().inference(data, *args, **kwargs)

def postprocess(self, res):
output = []
for data in res:
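The two handler hunks above form one small pattern: read optional per-model settings once in initialize(), then fold them into every inference() call. A condensed, runnable sketch of that pattern follows; the stub BaseHandler here is only a stand-in for the pytorch/serve yolov8 handler that the real file extends.

```python
# Condensed sketch of the config-merge pattern added above.
# The stub BaseHandler stands in for the real torchserve/yolov8 handler.
import json


class BaseHandler:
    def inference(self, data, *args, **kwargs):
        # The real base class forwards these kwargs to the Ultralytics model.
        return {"received_kwargs": kwargs}


class ConfigurableHandler(BaseHandler):
    def initialize(self):
        # model_config.json ships inside the .mar archive via --extra-files.
        with open("model_config.json") as f:
            self.model_config = json.load(f)

    def inference(self, data, *args, **kwargs):
        kwargs.update(self.model_config)  # model-level settings take precedence
        return super().inference(data, *args, **kwargs)


if __name__ == "__main__":
    handler = ConfigurableHandler()
    handler.initialize()
    print(handler.inference(data=None))  # e.g. {'received_kwargs': {'conf': 0.25, ...}}
```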
5 changes: 5 additions & 0 deletions torchserve/model_config.json
@@ -0,0 +1,5 @@
+ {
+     "conf": 0.25,
+     "iou": 0.7,
+     "half": false
+ }
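These keys map to Ultralytics inference arguments (see the docs link in the handler above): conf is the minimum detection confidence, iou is the IoU threshold used for non-maximum suppression (which effectively takes over from the removed --inference_nms_confidence option), and half toggles FP16 inference. As a sanity check outside Torchserve, the same file can be fed straight to an Ultralytics model; this sketch assumes the yolov8n.pt weights and persons.jpg image downloaded by tests/test_torchserve.sh are in the current directory.

```python
# Sketch: apply model_config.json directly with Ultralytics, bypassing Torchserve.
# Assumes yolov8n.pt and persons.jpg from tests/test_torchserve.sh are in the cwd.
import json

from ultralytics import YOLO

with open("model_config.json") as f:
    cfg = json.load(f)  # {"conf": 0.25, "iou": 0.7, "half": False}

model = YOLO("yolov8n.pt")
results = model.predict(source="persons.jpg", **cfg)
for result in results:
    print(result.boxes)  # detections surviving the conf filter and NMS at iou=0.7
```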
