diff --git a/docs/README-airt.md b/docs/README-airt.md
index 655a0301..a8f93e5e 100644
--- a/docs/README-airt.md
+++ b/docs/README-airt.md
@@ -140,7 +140,7 @@ From gamutRF's source directory, and having obtained mini2_snr.pt:
 $ pip3 install torch-model-archiver
 $ mkdir /tmp/model_store
 $ wget https://raw.githubusercontent.com/pytorch/serve/master/examples/object_detector/yolo/yolov8/requirements.txt
-$ torch-model-archiver --force --model-name mini2_snr --version 1.0 --serialized-file /PATH/TO/mini2_snr.pt --handler torchserve/custom_handler.py --export-path /tmp/model_store -r requirements.txt
+$ torch-model-archiver --force --model-name mini2_snr --version 1.0 --serialized-file /PATH/TO/mini2_snr.pt --handler torchserve/custom_handler.py --extra-files torchserve/model_config.json --export-path /tmp/model_store -r requirements.txt
 ```
 
 # start torchserve
diff --git a/gamutrf/grscan.py b/gamutrf/grscan.py
index 2fcfc3b9..23b283f9 100644
--- a/gamutrf/grscan.py
+++ b/gamutrf/grscan.py
@@ -41,7 +41,6 @@ def __init__(
         gps_server="",
         igain=0,
         inference_min_confidence=0.5,
-        inference_nms_threshold=0.5,
         inference_min_db=-200,
         inference_model_server="",
         inference_model_name="",
diff --git a/gamutrf/scan.py b/gamutrf/scan.py
index 610f1f3a..d86cf016 100644
--- a/gamutrf/scan.py
+++ b/gamutrf/scan.py
@@ -221,16 +221,9 @@ def argument_parser():
         "--inference_min_confidence",
         dest="inference_min_confidence",
         type=float,
-        default=0.5,
+        default=0.25,
         help="minimum confidence score to plot",
     )
-    parser.add_argument(
-        "--inference_nms_confidence",
-        dest="inference_nms_threshold",
-        type=float,
-        default=0.5,
-        help="NMS threshold",
-    )
     parser.add_argument(
         "--inference_min_db",
         dest="inference_min_db",
diff --git a/orchestrator.yml b/orchestrator.yml
index 67d83584..1c0881e9 100644
--- a/orchestrator.yml
+++ b/orchestrator.yml
@@ -68,7 +68,7 @@ services:
       - --no-compass
       - --use_external_gps
       - --use_external_heading
-      - --inference_min_confidence=0.8
+      - --inference_min_confidence=0.25
       - --inference_min_db=-80
       - --inference_model_name=mini2_snr
       # - --external_gps_server=1.2.3.4
diff --git a/tests/test_torchserve.sh b/tests/test_torchserve.sh
index ee531bff..f84500ac 100755
--- a/tests/test_torchserve.sh
+++ b/tests/test_torchserve.sh
@@ -4,16 +4,26 @@ set -e
 TMPDIR=/tmp
 sudo apt-get update && sudo apt-get install -y jq wget
 sudo pip3 install torch-model-archiver
-cp torchserve/custom_handler.py $TMPDIR/
+cp torchserve/* $TMPDIR/
 cd $TMPDIR
 # TODO: use gamutRF weights here.
 wget https://github.com/ultralytics/assets/releases/download/v0.0.0/yolov8n.pt
 wget https://raw.githubusercontent.com/pytorch/serve/master/examples/object_detector/yolo/yolov8/requirements.txt
-torch-model-archiver --force --model-name yolov8n --version 1.0 --serialized-file yolov8n.pt --handler custom_handler.py -r requirements.txt
+# TODO: use gamutRF test spectogram image
+wget https://github.com/pytorch/serve/raw/master/examples/object_detector/yolo/yolov8/persons.jpg
+torch-model-archiver --force --model-name yolov8n --version 1.0 --serialized-file yolov8n.pt --handler custom_handler.py --extra-files model_config.json -r requirements.txt
+# -r requirements.txt
 rm -rf model_store && mkdir model_store
 mv yolov8n.mar model_store/
-# TODO: --runtime nvidia is required for Orin
+# TODO: --runtime nvidia is required for Orin, --gpus all for x86
 docker run -v $(pwd)/model_store:/model_store -p 8080:8080 --rm --name testts --entrypoint timeout -d iqtlabs/torchserve 180s /torchserve/torchserve-entrypoint.sh --models yolov8n=yolov8n.mar
-# TODO: use gamutRF test spectogram image
-wget https://github.com/pytorch/serve/raw/master/examples/object_detector/yolo/yolov8/persons.jpg
-wget -q --retry-connrefused --retry-on-host-error --body-file=persons.jpg --method=PUT -O- --header='Content-Type: image/jpg' http://127.0.0.1:8080/predictions/yolov8n | jq
+PRED=$(wget -q --retry-connrefused --retry-on-host-error --body-file=persons.jpg --method=PUT -O- --header='Content-Type: image/jpg' http://127.0.0.1:8080/predictions/yolov8n | jq)
+echo $PRED
+if [ "$PRED" = "" ] ; then
+  echo "error: no response from Torchserve"
+  exit 1
+fi
+if [ "$PRED" = "{}" ] ; then
+  echo "error: no predictions from Torchserve"
+  exit 1
+fi
diff --git a/torchserve/custom_handler.py b/torchserve/custom_handler.py
index ab40dbd7..7cde61d9 100644
--- a/torchserve/custom_handler.py
+++ b/torchserve/custom_handler.py
@@ -1,5 +1,6 @@
 # based on pytorch's yolov8n example.
 
+import json
 from collections import defaultdict
 import os
 
@@ -41,6 +42,9 @@ def initialize(self, context):
         properties = context.system_properties
         self.manifest = context.manifest
         model_dir = properties.get("model_dir")
+        # https://docs.ultralytics.com/modes/predict/#inference-arguments
+        with open("model_config.json", "r") as f:
+            self.model_config = json.load(f)
         self.model_pt_path = None
         if "serializedFile" in self.manifest["model"]:
             serialized_file = self.manifest["model"]["serializedFile"]
@@ -63,6 +67,10 @@ def _load_torchscript_model(self, model_pt_path):
         model.to(self.device)
         return model
 
+    def inference(self, data, *args, **kwargs):
+        kwargs.update(self.model_config)
+        return super().inference(data, *args, **kwargs)
+
     def postprocess(self, res):
         output = []
         for data in res:
diff --git a/torchserve/model_config.json b/torchserve/model_config.json
new file mode 100644
index 00000000..dabb2b05
--- /dev/null
+++ b/torchserve/model_config.json
@@ -0,0 +1,5 @@
+{
+    "conf": 0.25,
+    "iou": 0.7,
+    "half": false
+}