-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #8 from zeroae/f/sagemaker
Add SageMaker Inference Toolkit support
- Loading branch information
Showing
18 changed files
with
411 additions
and
17 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,6 +1,5 @@ | ||
anaconda_upload: false | ||
channels: | ||
- zeroae | ||
- defaults | ||
- conda-forge | ||
show_channel_urls: true | ||
show_channel_urls: true |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,9 @@ | ||
# MacOS | ||
.DS_Store | ||
|
||
# Vagrant | ||
.vagrant | ||
|
||
# Byte-compiled / optimized / DLL files | ||
__pycache__/ | ||
*.py[cod] | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,94 @@ | ||
# -*- mode: ruby -*- | ||
# vi: set ft=ruby : | ||
|
||
# All Vagrant configuration is done below. The "2" in Vagrant.configure | ||
# configures the configuration version (we support older styles for | ||
# backwards compatibility). Please don't change it unless you know what | ||
# you're doing. | ||
Vagrant.configure("2") do |config| | ||
# The most common configuration options are documented and commented below. | ||
# For a complete reference, please see the online documentation at | ||
# https://docs.vagrantup.com. | ||
|
||
# Every Vagrant development environment requires a box. You can search for | ||
# boxes at https://vagrantcloud.com/search. | ||
config.vm.box = "hashicorp/bionic64" | ||
|
||
# Disable automatic box update checking. If you disable this, then | ||
# boxes will only be checked for updates when the user runs | ||
# `vagrant box outdated`. This is not recommended. | ||
# config.vm.box_check_update = false | ||
|
||
# Create a forwarded port mapping which allows access to a specific port | ||
# within the machine from a port on the host machine. In the example below, | ||
# accessing "localhost:8080" will access port 80 on the guest machine. | ||
# NOTE: This will enable public access to the opened port | ||
config.vm.network "forwarded_port", guest: 8888, host: 8888 | ||
|
||
# Create a forwarded port mapping which allows access to a specific port | ||
# within the machine from a port on the host machine and only allow access | ||
# via 127.0.0.1 to disable public access | ||
# config.vm.network "forwarded_port", guest: 80, host: 8080, host_ip: "127.0.0.1" | ||
|
||
# Create a private network, which allows host-only access to the machine | ||
# using a specific IP. | ||
# config.vm.network "private_network", ip: "192.168.33.10" | ||
|
||
# Create a public network, which generally matched to bridged network. | ||
# Bridged networks make the machine appear as another physical device on | ||
# your network. | ||
# config.vm.network "public_network" | ||
|
||
# Share an additional folder to the guest VM. The first argument is | ||
# the path on the host to the actual folder. The second argument is | ||
# the path on the guest to mount the folder. And the optional third | ||
# argument is a set of non-required options. | ||
# config.vm.synced_folder "../data", "/vagrant_data" | ||
|
||
# Provider-specific configuration so you can fine-tune various | ||
# backing providers for Vagrant. These expose provider-specific options. | ||
# Example for VirtualBox: | ||
# | ||
# config.vm.provider "virtualbox" do |vb| | ||
# # Display the VirtualBox GUI when booting the machine | ||
# vb.gui = true | ||
# | ||
# # Customize the amount of memory on the VM: | ||
# vb.memory = "1024" | ||
# end | ||
# | ||
# View the documentation for the provider you are using for more | ||
# information on available options. | ||
config.vm.provider "vmware_desktop" do |v| | ||
v.vmx["memsize"] = "2048" | ||
end | ||
|
||
# Enable provisioning with a shell script. Additional provisioners such as | ||
# Ansible, Chef, Docker, Puppet and Salt are also available. Please see the | ||
# documentation for more information about their specific syntax and use. | ||
config.vm.provision "shell", name: "Install Mambaforge", privileged: false, reset: true, inline: <<-SHELL | ||
MAMBA_FORGE_FILE=Mambaforge-$(uname)-$(uname -m).sh | ||
if ! [ -d ~/mambaforge ]; then | ||
if ! [ -f $MAMBA_FORGE_FILE ]; then | ||
wget -q https://github.com/conda-forge/miniforge/releases/latest/download/$MAMBA_FORGE_FILE | ||
fi | ||
bash $MAMBA_FORGE_FILE -b -u | ||
rm -f $MAMBA_FORGE_FILE | ||
mambaforge/bin/conda init --all | ||
fi | ||
SHELL | ||
|
||
config.vm.provision "shell", name: "Install OS Packages", inline: <<-SHELL | ||
### Add OpenJDK 8 | ||
apt-get update | ||
apt-get --yes install openjdk-8-jre-headless | ||
SHELL | ||
|
||
config.vm.provision "shell", name: "Create Development Environment", privileged: false, inline: <<-SHELL | ||
### Create the DarkNet Environment | ||
source mambaforge/etc/profile.d/conda.sh | ||
mamba env update --name darknet-cpu -f /vagrant/environment.yml | ||
echo 'conda activate darknet-cpu' >> ~/.bashrc | ||
SHELL | ||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
from .default_inference_handler import DefaultDarknetInferenceHandler, Network | ||
from ..py.util import image_to_3darray | ||
|
||
__all__ = ["DefaultDarknetInferenceHandler", "Network", "image_to_3darray"] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,26 @@ | ||
from retrying import retry | ||
from subprocess import CalledProcessError | ||
from sagemaker_inference import model_server | ||
|
||
# TODO: from .classifier import handler_service as classifier_service | ||
from .detector import handler_service as detector_service | ||
|
||
|
||
def _retry_if_error(exception): | ||
return isinstance(exception, CalledProcessError or OSError) | ||
|
||
|
||
@retry(stop_max_delay=1000 * 50, retry_on_exception=_retry_if_error) | ||
def _start_mms(): | ||
# by default the number of workers per model is 1, but we can configure it through the | ||
# environment variable below if desired. | ||
# os.environ['SAGEMAKER_MODEL_SERVER_WORKERS'] = '2' | ||
# TODO: Start Classifier *or* Detector Service | ||
model_server.start_model_server(handler_service=detector_service.__name__) | ||
|
||
|
||
def main(): | ||
_start_mms() | ||
|
||
|
||
main() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
from .handler_service import HandlerService | ||
|
||
__all__ = ["HandlerService"] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,25 @@ | ||
from retrying import retry | ||
from subprocess import CalledProcessError | ||
from sagemaker_inference import model_server | ||
|
||
from . import handler_service as classifier_service | ||
|
||
|
||
def _retry_if_error(exception): | ||
return isinstance(exception, CalledProcessError or OSError) | ||
|
||
|
||
@retry(stop_max_delay=1000 * 50, retry_on_exception=_retry_if_error) | ||
def _start_mms(): | ||
# by default the number of workers per model is 1, but we can configure it through the | ||
# environment variable below if desired. | ||
# os.environ['SAGEMAKER_MODEL_SERVER_WORKERS'] = '2' | ||
# TODO: Start Classifier *or* Detector Service | ||
model_server.start_model_server(handler_service=classifier_service.__name__) | ||
|
||
|
||
def main(): | ||
_start_mms() | ||
|
||
|
||
main() |
36 changes: 36 additions & 0 deletions
36
src/darknet/sagemaker/classifier/default_inference_handler.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,36 @@ | ||
from typing import Tuple, List | ||
|
||
from sagemaker_inference.errors import UnsupportedFormatError | ||
|
||
from .. import DefaultDarknetInferenceHandler, image_to_3darray, Network | ||
|
||
|
||
class DefaultDarknetClassifierInferenceHandler(DefaultDarknetInferenceHandler): | ||
def default_predict_fn(self, data, model: Tuple[Network, List[str]]): | ||
"""A default predict_fn for DarkNet. Calls a model on data deserialized in input_fn. | ||
Args: | ||
data: input data (PIL.Image) for prediction deserialized by input_fn | ||
model: Darknet model loaded in memory by model_fn | ||
Returns: a prediction | ||
""" | ||
network, labels = model | ||
max_labels = data.get("MaxLabels", 5) | ||
# TODO: min_confidence = data.get("MinConfidence", 55) | ||
|
||
if "NDArray" in data: | ||
probabilities = network.predict(data["NDArray"]) | ||
elif "Image" in data: | ||
image, _ = image_to_3darray(data["Image"], network.shape) | ||
probabilities = network.predict_image(image) | ||
else: | ||
raise UnsupportedFormatError("Expected an NDArray or an Image") | ||
|
||
rv = [ | ||
{ | ||
"Name": label, | ||
"Confidence": prob * 100, | ||
} | ||
for label, prob in sorted(zip(labels, probabilities), key=lambda x: x[1], reverse=True) | ||
] | ||
return {"Labels": rv[0:max_labels] if max_labels else rv} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,22 @@ | ||
from sagemaker_inference.default_handler_service import DefaultHandlerService | ||
from sagemaker_inference.transformer import Transformer | ||
|
||
from .default_inference_handler import DefaultDarknetClassifierInferenceHandler | ||
|
||
|
||
class HandlerService(DefaultHandlerService): | ||
"""Handler service that is executed by the model server. | ||
Determines specific default inference handlers to use based on the type MXNet model being used. | ||
This class extends ``DefaultHandlerService``, which define the following: | ||
- The ``handle`` method is invoked for all incoming inference requests to the model server. | ||
- The ``initialize`` method is invoked at model server start up. | ||
Based on: https://github.com/awslabs/mxnet-model-server/blob/master/docs/custom_service.md | ||
""" | ||
|
||
def __init__(self): | ||
self._initialized = False | ||
|
||
transformer = Transformer( | ||
default_inference_handler=DefaultDarknetClassifierInferenceHandler() | ||
) | ||
super(HandlerService, self).__init__(transformer=transformer) |
Oops, something went wrong.