Skip to content

Commit

Permalink
feat: RND-114: Add SAM2 integration for Video Object Tracking (#596)
Browse files Browse the repository at this point in the history
Co-authored-by: nik <nik@heartex.net>
Co-authored-by: Micaela Kaplan <kaplan.micaela@gmail.com>
  • Loading branch information
3 people committed Aug 15, 2024
1 parent a9cf8e7 commit d803e87
Show file tree
Hide file tree
Showing 11 changed files with 677 additions and 0 deletions.
59 changes: 59 additions & 0 deletions label_studio_ml/examples/segment_anything_2_video/Dockerfile
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
FROM pytorch/pytorch:2.1.2-cuda12.1-cudnn8-runtime
ARG DEBIAN_FRONTEND=noninteractive
ARG TEST_ENV

WORKDIR /app

RUN conda update conda -y

RUN --mount=type=cache,target="/var/cache/apt",sharing=locked \
--mount=type=cache,target="/var/lib/apt/lists",sharing=locked \
apt-get -y update \
&& apt-get install -y git \
&& apt-get install -y wget \
&& apt-get install -y g++ freeglut3-dev build-essential libx11-dev \
libxmu-dev libxi-dev libglu1-mesa libglu1-mesa-dev libfreeimage-dev \
&& apt-get -y install ffmpeg libsm6 libxext6 libffi-dev python3-dev python3-pip gcc

ENV PYTHONUNBUFFERED=1 \
PYTHONDONTWRITEBYTECODE=1 \
PIP_CACHE_DIR=/.cache \
PORT=9090 \
WORKERS=2 \
THREADS=4 \
CUDA_HOME=/usr/local/cuda \
SEGMENT_ANYTHING_2_REPO_PATH=/segment-anything-2

RUN conda install -c "nvidia/label/cuda-12.1.1" cuda -y
ENV CUDA_HOME=/opt/conda \
TORCH_CUDA_ARCH_LIST="6.0;6.1;7.0;7.5;8.0;8.6+PTX;8.9;9.0"

# install base requirements
COPY requirements-base.txt .
RUN --mount=type=cache,target=${PIP_CACHE_DIR},sharing=locked \
pip install -r requirements-base.txt

COPY requirements.txt .
RUN --mount=type=cache,target=${PIP_CACHE_DIR},sharing=locked \
pip3 install -r requirements.txt

# install segment-anything-2
RUN cd / && git clone --depth 1 --branch main --single-branch https://github.com/facebookresearch/segment-anything-2.git
WORKDIR /segment-anything-2
RUN --mount=type=cache,target=${PIP_CACHE_DIR},sharing=locked \
pip3 install -e .
RUN cd checkpoints && ./download_ckpts.sh

WORKDIR /app

# install test requirements if needed
COPY requirements-test.txt .
# build only when TEST_ENV="true"
RUN --mount=type=cache,target=${PIP_CACHE_DIR},sharing=locked \
if [ "$TEST_ENV" = "true" ]; then \
pip3 install -r requirements-test.txt; \
fi

COPY . ./

CMD ["/app/start.sh"]
63 changes: 63 additions & 0 deletions label_studio_ml/examples/segment_anything_2_video/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
This guide describes the simplest way to start using **SegmentAnything 2** with Label Studio.

This repository is specifically for working with object tracking in videos. For working with images,
see the [segment_anything_2_image repository](https://github.com/HumanSignal/label-studio-ml-backend/tree/master/label_studio_ml/examples/segment_anything_2_image)

![sam2](./Sam2Video.gif)

## Running from source

1. To run the ML backend without Docker, you have to clone the repository and install all dependencies using pip:

```bash
git clone https://github.com/HumanSignal/label-studio-ml-backend.git
cd label-studio-ml-backend
pip install -e .
cd label_studio_ml/examples/segment_anything_2_video
pip install -r requirements.txt
```

2. Download [`segment-anything-2` repo](https://github.com/facebookresearch/segment-anything-2) into the root directory. Install SegmentAnything model and download checkpoints using [the official Meta documentation](https://github.com/facebookresearch/segment-anything-2?tab=readme-ov-file#installation). Make sure that you complete the steps for downloadingn the checkpoint files!

3. Export the following environment variables (fill them in with your credentials!):
- LABEL_STUDIO_URL: the http:// or https:// link to your label studio instance (include the prefix!)
- LABEL_STUDIO_API_KEY: your api key for label studio, available in your profile.

4. Then you can start the ML backend on the default port `9090`:

```bash
cd ../
label-studio-ml start ./segment_anything_2_video
```
Note that if you're running in a cloud server, you'll need to run on an exposed port. To change the port, add `-p <port number>` to the end of the start command above.
5. Connect running ML backend server to Label Studio: go to your project `Settings -> Machine Learning -> Add Model` and specify `http://localhost:9090` as a URL. Read more in the official [Label Studio documentation](https://labelstud.io/guide/ml#Connect-the-model-to-Label-Studio).
Again, if you're running in the cloud, you'll need to replace this localhost location with whatever the external ip address is of your container, along with the exposed port.

# Labeling Config
For your project, you can use any labeling config with video properties. Here's a basic one to get you started!

<View>
<Labels name="videoLabels" toName="video" allowEmpty="true">



<Label value="Player" background="#11A39E"/><Label value="Ball" background="#D4380D"/></Labels>

<!-- Please specify FPS carefully, it will be used for all project videos -->
<Video name="video" value="$video" framerate="25.0"/>
<VideoRectangle name="box" toName="video" smart="true"/>
</View><!--{
"video": "/static/samples/opossum_snow.mp4"
}-->


# Known limitiations
- As of 8/11/2024, SAM2 only runs on GPU servers.
- Currently, we only support the tracking of one object in video, although SAM2 can support multiple.
- Currently, we do not support video segmentation.
- No Docker support

If you want to contribute to this repository to help with some of these limitations, you can submit a PR.
# Customization

The ML backend can be customized by adding your own models and logic inside the `./segment_anything_2_video` directory.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
121 changes: 121 additions & 0 deletions label_studio_ml/examples/segment_anything_2_video/_wsgi.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
import os
import argparse
import json
import logging
import logging.config

logging.config.dictConfig({
"version": 1,
"formatters": {
"standard": {
"format": "[%(asctime)s] [%(levelname)s] [%(name)s::%(funcName)s::%(lineno)d] %(message)s"
}
},
"handlers": {
"console": {
"class": "logging.StreamHandler",
"level": os.getenv('LOG_LEVEL'),
"stream": "ext://sys.stdout",
"formatter": "standard"
}
},
"root": {
"level": os.getenv('LOG_LEVEL'),
"handlers": [
"console"
],
"propagate": True
}
})

from label_studio_ml.api import init_app
from model import NewModel


_DEFAULT_CONFIG_PATH = os.path.join(os.path.dirname(__file__), 'config.json')


def get_kwargs_from_config(config_path=_DEFAULT_CONFIG_PATH):
if not os.path.exists(config_path):
return dict()
with open(config_path) as f:
config = json.load(f)
assert isinstance(config, dict)
return config


if __name__ == "__main__":
parser = argparse.ArgumentParser(description='Label studio')
parser.add_argument(
'-p', '--port', dest='port', type=int, default=9090,
help='Server port')
parser.add_argument(
'--host', dest='host', type=str, default='0.0.0.0',
help='Server host')
parser.add_argument(
'--kwargs', '--with', dest='kwargs', metavar='KEY=VAL', nargs='+', type=lambda kv: kv.split('='),
help='Additional LabelStudioMLBase model initialization kwargs')
parser.add_argument(
'-d', '--debug', dest='debug', action='store_true',
help='Switch debug mode')
parser.add_argument(
'--log-level', dest='log_level', choices=['DEBUG', 'INFO', 'WARNING', 'ERROR'], default=None,
help='Logging level')
parser.add_argument(
'--model-dir', dest='model_dir', default=os.path.dirname(__file__),
help='Directory where models are stored (relative to the project directory)')
parser.add_argument(
'--check', dest='check', action='store_true',
help='Validate model instance before launching server')
parser.add_argument('--basic-auth-user',
default=os.environ.get('ML_SERVER_BASIC_AUTH_USER', None),
help='Basic auth user')

parser.add_argument('--basic-auth-pass',
default=os.environ.get('ML_SERVER_BASIC_AUTH_PASS', None),
help='Basic auth pass')

args = parser.parse_args()

# setup logging level
if args.log_level:
logging.root.setLevel(args.log_level)

def isfloat(value):
try:
float(value)
return True
except ValueError:
return False

def parse_kwargs():
param = dict()
for k, v in args.kwargs:
if v.isdigit():
param[k] = int(v)
elif v == 'True' or v == 'true':
param[k] = True
elif v == 'False' or v == 'false':
param[k] = False
elif isfloat(v):
param[k] = float(v)
else:
param[k] = v
return param

kwargs = get_kwargs_from_config()

if args.kwargs:
kwargs.update(parse_kwargs())

if args.check:
print('Check "' + NewModel.__name__ + '" instance creation..')
model = NewModel(**kwargs)

app = init_app(model_class=NewModel, basic_auth_user=args.basic_auth_user, basic_auth_pass=args.basic_auth_pass)

app.run(host=args.host, port=args.port, debug=args.debug)

else:
# for uWSGI use
app = init_app(model_class=NewModel)
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
version: "3.8"

services:
segment_anything_2_video:
container_name: segment_anything_2_video
image: humansignal/segment_anything_2_video:v0
build:
context: .
args:
TEST_ENV: ${TEST_ENV}
environment:
# specify these parameters if you want to use basic auth for the model server
- BASIC_AUTH_USER=
- BASIC_AUTH_PASS=
# set the log level for the model server
- LOG_LEVEL=DEBUG
# any other parameters that you want to pass to the model server
- ANY=PARAMETER
# specify the number of workers and threads for the model server
- WORKERS=1
- THREADS=8
# specify the model directory (likely you don't need to change this)
- MODEL_DIR=/data/models
# specify device
- DEVICE=cuda # or 'cpu' (coming soon)
# SAM2 model config
- MODEL_CONFIG=sam2_hiera_l.yaml
# SAM2 checkpoint
- MODEL_CHECKPOINT=sam2_hiera_large.pt

# Specify the Label Studio URL and API key to access
# uploaded, local storage and cloud storage files.
# Do not use 'localhost' as it does not work within Docker containers.
# Use prefix 'http://' or 'https://' for the URL always.
# Determine the actual IP using 'ifconfig' (Linux/Mac) or 'ipconfig' (Windows).
- LABEL_STUDIO_URL=
- LABEL_STUDIO_API_KEY=
ports:
- "9090:9090"
volumes:
- "./data/server:/data"
Loading

0 comments on commit d803e87

Please sign in to comment.