Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
checkpoints

# IDE
.idea/
.vscode/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint as checkpoint
from timm.models.layers import DropPath, to_2tuple, trunc_normal_
from timm.layers import DropPath, to_2tuple, trunc_normal_

from groundingdino.util.misc import NestedTensor

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ ms_deform_attn_forward(
const at::Tensor &attn_weight,
const int im2col_step)
{
if (value.type().is_cuda())
if (value.is_cuda())
{
#ifdef WITH_CUDA
return ms_deform_attn_cuda_forward(
Expand All @@ -49,7 +49,7 @@ ms_deform_attn_backward(
const at::Tensor &grad_output,
const int im2col_step)
{
if (value.type().is_cuda())
if (value.is_cuda())
{
#ifdef WITH_CUDA
return ms_deform_attn_cuda_backward(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ at::Tensor ms_deform_attn_cuda_forward(
AT_ASSERTM(sampling_loc.is_contiguous(), "sampling_loc tensor has to be contiguous");
AT_ASSERTM(attn_weight.is_contiguous(), "attn_weight tensor has to be contiguous");

AT_ASSERTM(value.type().is_cuda(), "value must be a CUDA tensor");
AT_ASSERTM(value.is_cuda(), "value must be a CUDA tensor");
AT_ASSERTM(spatial_shapes.type().is_cuda(), "spatial_shapes must be a CUDA tensor");
AT_ASSERTM(level_start_index.type().is_cuda(), "level_start_index must be a CUDA tensor");
AT_ASSERTM(sampling_loc.type().is_cuda(), "sampling_loc must be a CUDA tensor");
Expand Down Expand Up @@ -62,15 +62,15 @@ at::Tensor ms_deform_attn_cuda_forward(
for (int n = 0; n < batch/im2col_step_; ++n)
{
auto columns = output_n.select(0, n);
AT_DISPATCH_FLOATING_TYPES(value.type(), "ms_deform_attn_forward_cuda", ([&] {
AT_DISPATCH_FLOATING_TYPES(value.scalar_type(), "ms_deform_attn_forward_cuda", ([&] {
ms_deformable_im2col_cuda(at::cuda::getCurrentCUDAStream(),
value.data<scalar_t>() + n * im2col_step_ * per_value_size,
spatial_shapes.data<int64_t>(),
level_start_index.data<int64_t>(),
sampling_loc.data<scalar_t>() + n * im2col_step_ * per_sample_loc_size,
attn_weight.data<scalar_t>() + n * im2col_step_ * per_attn_weight_size,
value.data_ptr<scalar_t>() + n * im2col_step_ * per_value_size,
spatial_shapes.data_ptr<int64_t>(),
level_start_index.data_ptr<int64_t>(),
sampling_loc.data_ptr<scalar_t>() + n * im2col_step_ * per_sample_loc_size,
attn_weight.data_ptr<scalar_t>() + n * im2col_step_ * per_attn_weight_size,
batch_n, spatial_size, num_heads, channels, num_levels, num_query, num_point,
columns.data<scalar_t>());
columns.data_ptr<scalar_t>());

}));
}
Expand Down Expand Up @@ -98,7 +98,7 @@ std::vector<at::Tensor> ms_deform_attn_cuda_backward(
AT_ASSERTM(attn_weight.is_contiguous(), "attn_weight tensor has to be contiguous");
AT_ASSERTM(grad_output.is_contiguous(), "grad_output tensor has to be contiguous");

AT_ASSERTM(value.type().is_cuda(), "value must be a CUDA tensor");
AT_ASSERTM(value.is_cuda(), "value must be a CUDA tensor");
AT_ASSERTM(spatial_shapes.type().is_cuda(), "spatial_shapes must be a CUDA tensor");
AT_ASSERTM(level_start_index.type().is_cuda(), "level_start_index must be a CUDA tensor");
AT_ASSERTM(sampling_loc.type().is_cuda(), "sampling_loc must be a CUDA tensor");
Expand Down Expand Up @@ -132,18 +132,18 @@ std::vector<at::Tensor> ms_deform_attn_cuda_backward(
for (int n = 0; n < batch/im2col_step_; ++n)
{
auto grad_output_g = grad_output_n.select(0, n);
AT_DISPATCH_FLOATING_TYPES(value.type(), "ms_deform_attn_backward_cuda", ([&] {
AT_DISPATCH_FLOATING_TYPES(value.scalar_type(), "ms_deform_attn_backward_cuda", ([&] {
ms_deformable_col2im_cuda(at::cuda::getCurrentCUDAStream(),
grad_output_g.data<scalar_t>(),
value.data<scalar_t>() + n * im2col_step_ * per_value_size,
spatial_shapes.data<int64_t>(),
level_start_index.data<int64_t>(),
sampling_loc.data<scalar_t>() + n * im2col_step_ * per_sample_loc_size,
attn_weight.data<scalar_t>() + n * im2col_step_ * per_attn_weight_size,
grad_output_g.data_ptr<scalar_t>(),
value.data_ptr<scalar_t>() + n * im2col_step_ * per_value_size,
spatial_shapes.data_ptr<int64_t>(),
level_start_index.data_ptr<int64_t>(),
sampling_loc.data_ptr<scalar_t>() + n * im2col_step_ * per_sample_loc_size,
attn_weight.data_ptr<scalar_t>() + n * im2col_step_ * per_attn_weight_size,
batch_n, spatial_size, num_heads, channels, num_levels, num_query, num_point,
grad_value.data<scalar_t>() + n * im2col_step_ * per_value_size,
grad_sampling_loc.data<scalar_t>() + n * im2col_step_ * per_sample_loc_size,
grad_attn_weight.data<scalar_t>() + n * im2col_step_ * per_attn_weight_size);
grad_value.data_ptr<scalar_t>() + n * im2col_step_ * per_value_size,
grad_sampling_loc.data_ptr<scalar_t>() + n * im2col_step_ * per_sample_loc_size,
grad_attn_weight.data_ptr<scalar_t>() + n * im2col_step_ * per_attn_weight_size);

}));
}
Expand Down
2 changes: 1 addition & 1 deletion groundingdino/models/GroundingDINO/fuse_modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from timm.models.layers import DropPath
from timm.layers import DropPath


class FeatureResizer(nn.Module):
Expand Down
1 change: 1 addition & 0 deletions groundingdino/version.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
__version__ = '0.1.0'
3 changes: 3 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
[build-system]
requires = ["setuptools", "wheel", "torch", "torchvision", "ninja"]
build-backend = "setuptools.build_meta"
31 changes: 28 additions & 3 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,14 +31,32 @@ def install_torch():
try:
import torch
except ImportError:
subprocess.check_call([sys.executable, "-m", "pip", "install", "torch"])
# Try uv first, then pip
try:
import subprocess
subprocess.check_call([sys.executable, "-m", "uv", "pip", "install", "torch"])
except (subprocess.CalledProcessError, ImportError):
try:
subprocess.check_call([sys.executable, "-m", "pip", "install", "torch"])
except Exception:
# If we can't install, torch should be available via build dependencies
pass

# Call the function to ensure torch is installed
install_torch()

import torch
try:
import torch
from torch.utils.cpp_extension import CUDA_HOME, CppExtension, CUDAExtension
except ImportError:
# During build isolation, torch might not be available yet
# This will be installed via setup_requires or build dependencies
torch = None
CUDA_HOME = None
CppExtension = None
CUDAExtension = None

from setuptools import find_packages, setup
from torch.utils.cpp_extension import CUDA_HOME, CppExtension, CUDAExtension

# groundingdino version info
version = "0.1.0"
Expand Down Expand Up @@ -200,6 +218,13 @@ def gen_packages_items():
license = f.read()

write_version_file()

# Set compiler to g++ to match PyTorch's build (PyTorch was built with g++)
import os
if "CXX" not in os.environ:
os.environ["CXX"] = "g++"
if "CC" not in os.environ:
os.environ["CC"] = "gcc"

setup(
name="groundingdino",
Expand Down
220 changes: 220 additions & 0 deletions setup.py.bak
Original file line number Diff line number Diff line change
@@ -0,0 +1,220 @@
# coding=utf-8
# Copyright 2022 The IDEA Authors. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ------------------------------------------------------------------------------------------------
# Modified from
# https://github.com/fundamentalvision/Deformable-DETR/blob/main/models/ops/setup.py
# https://github.com/facebookresearch/detectron2/blob/main/setup.py
# https://github.com/open-mmlab/mmdetection/blob/master/setup.py
# https://github.com/Oneflow-Inc/libai/blob/main/setup.py
# ------------------------------------------------------------------------------------------------

import glob
import os
import subprocess

import subprocess
import sys

def install_torch():
try:
import torch
except ImportError:
subprocess.check_call([sys.executable, "-m", "pip", "install", "torch"])

# Call the function to ensure torch is installed
install_torch()

import torch
from setuptools import find_packages, setup
from torch.utils.cpp_extension import CUDA_HOME, CppExtension, CUDAExtension

# groundingdino version info
version = "0.1.0"
package_name = "groundingdino"
cwd = os.path.dirname(os.path.abspath(__file__))


sha = "Unknown"
try:
sha = subprocess.check_output(["git", "rev-parse", "HEAD"], cwd=cwd).decode("ascii").strip()
except Exception:
pass


def write_version_file():
version_path = os.path.join(cwd, "groundingdino", "version.py")
with open(version_path, "w") as f:
f.write(f"__version__ = '{version}'\n")
# f.write(f"git_version = {repr(sha)}\n")


requirements = ["torch", "torchvision"]

torch_ver = [int(x) for x in torch.__version__.split(".")[:2]]


def get_extensions():
this_dir = os.path.dirname(os.path.abspath(__file__))
extensions_dir = os.path.join(this_dir, "groundingdino", "models", "GroundingDINO", "csrc")

main_source = os.path.join(extensions_dir, "vision.cpp")
sources = glob.glob(os.path.join(extensions_dir, "**", "*.cpp"))
source_cuda = glob.glob(os.path.join(extensions_dir, "**", "*.cu")) + glob.glob(
os.path.join(extensions_dir, "*.cu")
)

sources = [main_source] + sources

extension = CppExtension

extra_compile_args = {"cxx": []}
define_macros = []

if CUDA_HOME is not None and (torch.cuda.is_available() or "TORCH_CUDA_ARCH_LIST" in os.environ):
print("Compiling with CUDA")
extension = CUDAExtension
sources += source_cuda
define_macros += [("WITH_CUDA", None)]
extra_compile_args["nvcc"] = [
"-DCUDA_HAS_FP16=1",
"-D__CUDA_NO_HALF_OPERATORS__",
"-D__CUDA_NO_HALF_CONVERSIONS__",
"-D__CUDA_NO_HALF2_OPERATORS__",
]
else:
print("Compiling without CUDA")
define_macros += [("WITH_HIP", None)]
extra_compile_args["nvcc"] = []
return None

sources = [os.path.join(extensions_dir, s) for s in sources]
include_dirs = [extensions_dir]

ext_modules = [
extension(
"groundingdino._C",
sources,
include_dirs=include_dirs,
define_macros=define_macros,
extra_compile_args=extra_compile_args,
)
]

return ext_modules


def parse_requirements(fname="requirements.txt", with_version=True):
"""Parse the package dependencies listed in a requirements file but strips
specific versioning information.

Args:
fname (str): path to requirements file
with_version (bool, default=False): if True include version specs

Returns:
List[str]: list of requirements items

CommandLine:
python -c "import setup; print(setup.parse_requirements())"
"""
import re
import sys
from os.path import exists

require_fpath = fname

def parse_line(line):
"""Parse information from a line in a requirements text file."""
if line.startswith("-r "):
# Allow specifying requirements in other files
target = line.split(" ")[1]
for info in parse_require_file(target):
yield info
else:
info = {"line": line}
if line.startswith("-e "):
info["package"] = line.split("#egg=")[1]
elif "@git+" in line:
info["package"] = line
else:
# Remove versioning from the package
pat = "(" + "|".join([">=", "==", ">"]) + ")"
parts = re.split(pat, line, maxsplit=1)
parts = [p.strip() for p in parts]

info["package"] = parts[0]
if len(parts) > 1:
op, rest = parts[1:]
if ";" in rest:
# Handle platform specific dependencies
# http://setuptools.readthedocs.io/en/latest/setuptools.html#declaring-platform-specific-dependencies
version, platform_deps = map(str.strip, rest.split(";"))
info["platform_deps"] = platform_deps
else:
version = rest # NOQA
info["version"] = (op, version)
yield info

def parse_require_file(fpath):
with open(fpath, "r") as f:
for line in f.readlines():
line = line.strip()
if line and not line.startswith("#"):
for info in parse_line(line):
yield info

def gen_packages_items():
if exists(require_fpath):
for info in parse_require_file(require_fpath):
parts = [info["package"]]
if with_version and "version" in info:
parts.extend(info["version"])
if not sys.version.startswith("3.4"):
# apparently package_deps are broken in 3.4
platform_deps = info.get("platform_deps")
if platform_deps is not None:
parts.append(";" + platform_deps)
item = "".join(parts)
yield item

packages = list(gen_packages_items())
return packages


if __name__ == "__main__":
print(f"Building wheel {package_name}-{version}")

with open("LICENSE", "r", encoding="utf-8") as f:
license = f.read()

write_version_file()

setup(
name="groundingdino",
version="0.1.0",
author="International Digital Economy Academy, Shilong Liu",
url="https://github.com/IDEA-Research/GroundingDINO",
description="open-set object detector",
license=license,
install_requires=parse_requirements("requirements.txt"),
packages=find_packages(
exclude=(
"configs",
"tests",
)
),
ext_modules=get_extensions(),
cmdclass={"build_ext": torch.utils.cpp_extension.BuildExtension},
)
Loading