diff --git a/.gitignore b/.gitignore
index fdc0e71e..e4655b73 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,3 +1,5 @@
+checkpoints
+
 # IDE
 .idea/
 .vscode/
diff --git a/groundingdino/models/GroundingDINO/backbone/swin_transformer.py b/groundingdino/models/GroundingDINO/backbone/swin_transformer.py
index 1c66194d..e335ae06 100644
--- a/groundingdino/models/GroundingDINO/backbone/swin_transformer.py
+++ b/groundingdino/models/GroundingDINO/backbone/swin_transformer.py
@@ -16,7 +16,7 @@
 import torch.nn as nn
 import torch.nn.functional as F
 import torch.utils.checkpoint as checkpoint
-from timm.models.layers import DropPath, to_2tuple, trunc_normal_
+from timm.layers import DropPath, to_2tuple, trunc_normal_
 
 from groundingdino.util.misc import NestedTensor
 
diff --git a/groundingdino/models/GroundingDINO/csrc/MsDeformAttn/ms_deform_attn.h b/groundingdino/models/GroundingDINO/csrc/MsDeformAttn/ms_deform_attn.h
index c7408eba..20402614 100644
--- a/groundingdino/models/GroundingDINO/csrc/MsDeformAttn/ms_deform_attn.h
+++ b/groundingdino/models/GroundingDINO/csrc/MsDeformAttn/ms_deform_attn.h
@@ -27,7 +27,7 @@ ms_deform_attn_forward(
     const at::Tensor &attn_weight,
     const int im2col_step)
 {
-    if (value.type().is_cuda())
+    if (value.is_cuda())
     {
 #ifdef WITH_CUDA
         return ms_deform_attn_cuda_forward(
@@ -49,7 +49,7 @@ ms_deform_attn_backward(
     const at::Tensor &grad_output,
     const int im2col_step)
 {
-    if (value.type().is_cuda())
+    if (value.is_cuda())
     {
 #ifdef WITH_CUDA
         return ms_deform_attn_cuda_backward(
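Note on the two API migrations above: newer timm (0.9+) relocated `DropPath`, `to_2tuple`, and `trunc_normal_` from `timm.models.layers` to `timm.layers`, and current PyTorch drops the deprecated `Tensor::type()` accessor in favor of `Tensor::is_cuda()`. If older timm releases still need to be supported, a guarded import is a common alternative (a sketch, not part of this diff):

```python
# Compatibility sketch (not in this diff): prefer the timm >= 0.9 location,
# fall back to the pre-0.9 module path on older installs.
try:
    from timm.layers import DropPath, to_2tuple, trunc_normal_
except ImportError:
    from timm.models.layers import DropPath, to_2tuple, trunc_normal_
```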
diff --git a/groundingdino/models/GroundingDINO/csrc/MsDeformAttn/ms_deform_attn_cuda.cu b/groundingdino/models/GroundingDINO/csrc/MsDeformAttn/ms_deform_attn_cuda.cu
index d04fae8a..9a59f907 100644
--- a/groundingdino/models/GroundingDINO/csrc/MsDeformAttn/ms_deform_attn_cuda.cu
+++ b/groundingdino/models/GroundingDINO/csrc/MsDeformAttn/ms_deform_attn_cuda.cu
@@ -32,7 +32,7 @@ at::Tensor ms_deform_attn_cuda_forward(
     AT_ASSERTM(sampling_loc.is_contiguous(), "sampling_loc tensor has to be contiguous");
     AT_ASSERTM(attn_weight.is_contiguous(), "attn_weight tensor has to be contiguous");
 
-    AT_ASSERTM(value.type().is_cuda(), "value must be a CUDA tensor");
+    AT_ASSERTM(value.is_cuda(), "value must be a CUDA tensor");
     AT_ASSERTM(spatial_shapes.type().is_cuda(), "spatial_shapes must be a CUDA tensor");
     AT_ASSERTM(level_start_index.type().is_cuda(), "level_start_index must be a CUDA tensor");
     AT_ASSERTM(sampling_loc.type().is_cuda(), "sampling_loc must be a CUDA tensor");
@@ -62,15 +62,15 @@ at::Tensor ms_deform_attn_cuda_forward(
     for (int n = 0; n < batch/im2col_step_; ++n)
     {
         auto columns = output_n.select(0, n);
-        AT_DISPATCH_FLOATING_TYPES(value.type(), "ms_deform_attn_forward_cuda", ([&] {
+        AT_DISPATCH_FLOATING_TYPES(value.scalar_type(), "ms_deform_attn_forward_cuda", ([&] {
             ms_deformable_im2col_cuda(at::cuda::getCurrentCUDAStream(),
-                value.data<scalar_t>() + n * im2col_step_ * per_value_size,
-                spatial_shapes.data<int64_t>(),
-                level_start_index.data<int64_t>(),
-                sampling_loc.data<scalar_t>() + n * im2col_step_ * per_sample_loc_size,
-                attn_weight.data<scalar_t>() + n * im2col_step_ * per_attn_weight_size,
+                value.data_ptr<scalar_t>() + n * im2col_step_ * per_value_size,
+                spatial_shapes.data_ptr<int64_t>(),
+                level_start_index.data_ptr<int64_t>(),
+                sampling_loc.data_ptr<scalar_t>() + n * im2col_step_ * per_sample_loc_size,
+                attn_weight.data_ptr<scalar_t>() + n * im2col_step_ * per_attn_weight_size,
                 batch_n, spatial_size, num_heads, channels, num_levels, num_query, num_point,
-                columns.data<scalar_t>());
+                columns.data_ptr<scalar_t>());
         }));
     }
 
@@ -98,7 +98,7 @@ std::vector<at::Tensor> ms_deform_attn_cuda_backward(
     AT_ASSERTM(attn_weight.is_contiguous(), "attn_weight tensor has to be contiguous");
     AT_ASSERTM(grad_output.is_contiguous(), "grad_output tensor has to be contiguous");
 
-    AT_ASSERTM(value.type().is_cuda(), "value must be a CUDA tensor");
+    AT_ASSERTM(value.is_cuda(), "value must be a CUDA tensor");
     AT_ASSERTM(spatial_shapes.type().is_cuda(), "spatial_shapes must be a CUDA tensor");
     AT_ASSERTM(level_start_index.type().is_cuda(), "level_start_index must be a CUDA tensor");
     AT_ASSERTM(sampling_loc.type().is_cuda(), "sampling_loc must be a CUDA tensor");
@@ -132,18 +132,18 @@ std::vector<at::Tensor> ms_deform_attn_cuda_backward(
     for (int n = 0; n < batch/im2col_step_; ++n)
     {
         auto grad_output_g = grad_output_n.select(0, n);
-        AT_DISPATCH_FLOATING_TYPES(value.type(), "ms_deform_attn_backward_cuda", ([&] {
+        AT_DISPATCH_FLOATING_TYPES(value.scalar_type(), "ms_deform_attn_backward_cuda", ([&] {
             ms_deformable_col2im_cuda(at::cuda::getCurrentCUDAStream(),
-                grad_output_g.data<scalar_t>(),
-                value.data<scalar_t>() + n * im2col_step_ * per_value_size,
-                spatial_shapes.data<int64_t>(),
-                level_start_index.data<int64_t>(),
-                sampling_loc.data<scalar_t>() + n * im2col_step_ * per_sample_loc_size,
-                attn_weight.data<scalar_t>() + n * im2col_step_ * per_attn_weight_size,
+                grad_output_g.data_ptr<scalar_t>(),
+                value.data_ptr<scalar_t>() + n * im2col_step_ * per_value_size,
+                spatial_shapes.data_ptr<int64_t>(),
+                level_start_index.data_ptr<int64_t>(),
+                sampling_loc.data_ptr<scalar_t>() + n * im2col_step_ * per_sample_loc_size,
+                attn_weight.data_ptr<scalar_t>() + n * im2col_step_ * per_attn_weight_size,
                 batch_n, spatial_size, num_heads, channels, num_levels, num_query, num_point,
-                grad_value.data<scalar_t>() + n * im2col_step_ * per_value_size,
-                grad_sampling_loc.data<scalar_t>() + n * im2col_step_ * per_sample_loc_size,
-                grad_attn_weight.data<scalar_t>() + n * im2col_step_ * per_attn_weight_size);
+                grad_value.data_ptr<scalar_t>() + n * im2col_step_ * per_value_size,
+                grad_sampling_loc.data_ptr<scalar_t>() + n * im2col_step_ * per_sample_loc_size,
+                grad_attn_weight.data_ptr<scalar_t>() + n * im2col_step_ * per_attn_weight_size);
         }));
     }
 
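The `.cu` hunks complete the same migration inside the kernels: `value.type()` becomes `value.scalar_type()` in the `AT_DISPATCH_FLOATING_TYPES` macro, and the long-deprecated `Tensor::data<T>()` becomes `Tensor::data_ptr<T>()`, the accessor current PyTorch headers still provide. After rebuilding, a quick load check confirms the extension compiled and links (a minimal sketch, assuming an editable install):

```python
# Minimal post-build check: groundingdino._C is the extension module name
# declared in setup.py; this import fails if compilation or linking broke.
from groundingdino import _C  # noqa: F401
print("groundingdino._C loaded")
```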
diff --git a/groundingdino/models/GroundingDINO/fuse_modules.py b/groundingdino/models/GroundingDINO/fuse_modules.py
index 2753b3dd..a5d428ca 100644
--- a/groundingdino/models/GroundingDINO/fuse_modules.py
+++ b/groundingdino/models/GroundingDINO/fuse_modules.py
@@ -8,7 +8,7 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from timm.models.layers import DropPath
+from timm.layers import DropPath
 
 
 class FeatureResizer(nn.Module):
diff --git a/groundingdino/version.py b/groundingdino/version.py
new file mode 100644
index 00000000..b794fd40
--- /dev/null
+++ b/groundingdino/version.py
@@ -0,0 +1 @@
+__version__ = '0.1.0'
diff --git a/pyproject.toml b/pyproject.toml
new file mode 100644
index 00000000..b652decd
--- /dev/null
+++ b/pyproject.toml
@@ -0,0 +1,3 @@
+[build-system]
+requires = ["setuptools", "wheel", "torch", "torchvision", "ninja"]
+build-backend = "setuptools.build_meta"
diff --git a/setup.py b/setup.py
index 275b6fc3..b09c896d 100644
--- a/setup.py
+++ b/setup.py
@@ -31,14 +31,32 @@ def install_torch():
     try:
         import torch
     except ImportError:
-        subprocess.check_call([sys.executable, "-m", "pip", "install", "torch"])
+        # Try uv first, then pip
+        try:
+            import subprocess
+            subprocess.check_call([sys.executable, "-m", "uv", "pip", "install", "torch"])
+        except (subprocess.CalledProcessError, ImportError):
+            try:
+                subprocess.check_call([sys.executable, "-m", "pip", "install", "torch"])
+            except Exception:
+                # If we can't install, torch should be available via build dependencies
+                pass
 
 # Call the function to ensure torch is installed
 install_torch()
 
-import torch
+try:
+    import torch
+    from torch.utils.cpp_extension import CUDA_HOME, CppExtension, CUDAExtension
+except ImportError:
+    # During build isolation, torch might not be available yet
+    # This will be installed via setup_requires or build dependencies
+    torch = None
+    CUDA_HOME = None
+    CppExtension = None
+    CUDAExtension = None
+
 from setuptools import find_packages, setup
-from torch.utils.cpp_extension import CUDA_HOME, CppExtension, CUDAExtension
 
 # groundingdino version info
 version = "0.1.0"
@@ -200,6 +218,13 @@ def gen_packages_items():
         license = f.read()
 
     write_version_file()
+
+    # Set compiler to g++ to match PyTorch's build (PyTorch was built with g++)
+    import os
+    if "CXX" not in os.environ:
+        os.environ["CXX"] = "g++"
+    if "CC" not in os.environ:
+        os.environ["CC"] = "gcc"
 
     setup(
         name="groundingdino",
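The new `pyproject.toml` lists `torch` among the PEP 517 build requirements, so pip installs it into the isolated build environment before `setup.py` runs; the try/except around the torch import is a fallback for direct `python setup.py` invocations, where isolation does not apply. A pre-build check along the same lines (a hypothetical helper mirroring the guard, not part of the diff):

```python
# Hypothetical pre-build check mirroring setup.py's import guard.
try:
    from torch.utils.cpp_extension import CUDA_HOME
    print("CUDA toolkit:", CUDA_HOME or "not found -> CPU-only build")
except ImportError:
    print("torch not importable; relying on pyproject.toml build requirements")
```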
diff --git a/setup.py.bak b/setup.py.bak
new file mode 100644
index 00000000..275b6fc3
--- /dev/null
+++ b/setup.py.bak
@@ -0,0 +1,220 @@
+# coding=utf-8
+# Copyright 2022 The IDEA Authors. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ------------------------------------------------------------------------------------------------
+# Modified from
+# https://github.com/fundamentalvision/Deformable-DETR/blob/main/models/ops/setup.py
+# https://github.com/facebookresearch/detectron2/blob/main/setup.py
+# https://github.com/open-mmlab/mmdetection/blob/master/setup.py
+# https://github.com/Oneflow-Inc/libai/blob/main/setup.py
+# ------------------------------------------------------------------------------------------------
+
+import glob
+import os
+import subprocess
+
+import subprocess
+import sys
+
+def install_torch():
+    try:
+        import torch
+    except ImportError:
+        subprocess.check_call([sys.executable, "-m", "pip", "install", "torch"])
+
+# Call the function to ensure torch is installed
+install_torch()
+
+import torch
+from setuptools import find_packages, setup
+from torch.utils.cpp_extension import CUDA_HOME, CppExtension, CUDAExtension
+
+# groundingdino version info
+version = "0.1.0"
+package_name = "groundingdino"
+cwd = os.path.dirname(os.path.abspath(__file__))
+
+
+sha = "Unknown"
+try:
+    sha = subprocess.check_output(["git", "rev-parse", "HEAD"], cwd=cwd).decode("ascii").strip()
+except Exception:
+    pass
+
+
+def write_version_file():
+    version_path = os.path.join(cwd, "groundingdino", "version.py")
+    with open(version_path, "w") as f:
+        f.write(f"__version__ = '{version}'\n")
+        # f.write(f"git_version = {repr(sha)}\n")
+
+
+requirements = ["torch", "torchvision"]
+
+torch_ver = [int(x) for x in torch.__version__.split(".")[:2]]
+
+
+def get_extensions():
+    this_dir = os.path.dirname(os.path.abspath(__file__))
+    extensions_dir = os.path.join(this_dir, "groundingdino", "models", "GroundingDINO", "csrc")
+
+    main_source = os.path.join(extensions_dir, "vision.cpp")
+    sources = glob.glob(os.path.join(extensions_dir, "**", "*.cpp"))
+    source_cuda = glob.glob(os.path.join(extensions_dir, "**", "*.cu")) + glob.glob(
+        os.path.join(extensions_dir, "*.cu")
+    )
+
+    sources = [main_source] + sources
+
+    extension = CppExtension
+
+    extra_compile_args = {"cxx": []}
+    define_macros = []
+
+    if CUDA_HOME is not None and (torch.cuda.is_available() or "TORCH_CUDA_ARCH_LIST" in os.environ):
+        print("Compiling with CUDA")
+        extension = CUDAExtension
+        sources += source_cuda
+        define_macros += [("WITH_CUDA", None)]
+        extra_compile_args["nvcc"] = [
+            "-DCUDA_HAS_FP16=1",
+            "-D__CUDA_NO_HALF_OPERATORS__",
+            "-D__CUDA_NO_HALF_CONVERSIONS__",
+            "-D__CUDA_NO_HALF2_OPERATORS__",
+        ]
+    else:
+        print("Compiling without CUDA")
+        define_macros += [("WITH_HIP", None)]
+        extra_compile_args["nvcc"] = []
+        return None
+
+    sources = [os.path.join(extensions_dir, s) for s in sources]
+    include_dirs = [extensions_dir]
+
+    ext_modules = [
+        extension(
+            "groundingdino._C",
+            sources,
+            include_dirs=include_dirs,
+            define_macros=define_macros,
+            extra_compile_args=extra_compile_args,
+        )
+    ]
+
+    return ext_modules
+
+
+def parse_requirements(fname="requirements.txt", with_version=True):
+    """Parse the package dependencies listed in a requirements file but strips
+    specific versioning information.
+
+    Args:
+        fname (str): path to requirements file
+        with_version (bool, default=False): if True include version specs
+
+    Returns:
+        List[str]: list of requirements items
+
+    CommandLine:
+        python -c "import setup; print(setup.parse_requirements())"
+    """
+    import re
+    import sys
+    from os.path import exists
+
+    require_fpath = fname
+
+    def parse_line(line):
+        """Parse information from a line in a requirements text file."""
+        if line.startswith("-r "):
+            # Allow specifying requirements in other files
+            target = line.split(" ")[1]
+            for info in parse_require_file(target):
+                yield info
+        else:
+            info = {"line": line}
+            if line.startswith("-e "):
+                info["package"] = line.split("#egg=")[1]
+            elif "@git+" in line:
+                info["package"] = line
+            else:
+                # Remove versioning from the package
+                pat = "(" + "|".join([">=", "==", ">"]) + ")"
+                parts = re.split(pat, line, maxsplit=1)
+                parts = [p.strip() for p in parts]
+
+                info["package"] = parts[0]
+                if len(parts) > 1:
+                    op, rest = parts[1:]
+                    if ";" in rest:
+                        # Handle platform specific dependencies
+                        # http://setuptools.readthedocs.io/en/latest/setuptools.html#declaring-platform-specific-dependencies
+                        version, platform_deps = map(str.strip, rest.split(";"))
+                        info["platform_deps"] = platform_deps
+                    else:
+                        version = rest  # NOQA
+                    info["version"] = (op, version)
+            yield info
+
+    def parse_require_file(fpath):
+        with open(fpath, "r") as f:
+            for line in f.readlines():
+                line = line.strip()
+                if line and not line.startswith("#"):
+                    for info in parse_line(line):
+                        yield info
+
+    def gen_packages_items():
+        if exists(require_fpath):
+            for info in parse_require_file(require_fpath):
+                parts = [info["package"]]
+                if with_version and "version" in info:
+                    parts.extend(info["version"])
+                if not sys.version.startswith("3.4"):
+                    # apparently package_deps are broken in 3.4
+                    platform_deps = info.get("platform_deps")
+                    if platform_deps is not None:
+                        parts.append(";" + platform_deps)
+                item = "".join(parts)
+                yield item
+
+    packages = list(gen_packages_items())
+    return packages
+
+
+if __name__ == "__main__":
+    print(f"Building wheel {package_name}-{version}")
+
+    with open("LICENSE", "r", encoding="utf-8") as f:
+        license = f.read()
+
+    write_version_file()
+
+    setup(
+        name="groundingdino",
+        version="0.1.0",
+        author="International Digital Economy Academy, Shilong Liu",
+        url="https://github.com/IDEA-Research/GroundingDINO",
+        description="open-set object detector",
+        license=license,
+        install_requires=parse_requirements("requirements.txt"),
+        packages=find_packages(
+            exclude=(
+                "configs",
+                "tests",
+            )
+        ),
+        ext_modules=get_extensions(),
+        cmdclass={"build_ext": torch.utils.cpp_extension.BuildExtension},
+    )
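`setup.py.bak` preserves the original pip-only installer verbatim. Its `parse_requirements()` (unchanged in the live `setup.py`) actually keeps version pins, despite what its docstring says, because `with_version` defaults to `True`; the splitting step can be exercised in isolation:

```python
import re

# The version-splitting step from parse_requirements(), run on one sample line.
# The capture group keeps the operator, so pins survive into install_requires.
pat = "(" + "|".join([">=", "==", ">"]) + ")"
print(re.split(pat, "torch>=1.7.0", maxsplit=1))  # ['torch', '>=', '1.7.0']
```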
url="https://github.com/IDEA-Research/GroundingDINO", + description="open-set object detector", + license=license, + install_requires=parse_requirements("requirements.txt"), + packages=find_packages( + exclude=( + "configs", + "tests", + ) + ), + ext_modules=get_extensions(), + cmdclass={"build_ext": torch.utils.cpp_extension.BuildExtension}, + ) diff --git a/test.ipynb b/test.ipynb index 9138092a..38cbea30 100644 --- a/test.ipynb +++ b/test.ipynb @@ -91,7 +91,7 @@ ], "metadata": { "kernelspec": { - "display_name": "base", + "display_name": "Python 3", "language": "python", "name": "python3" }, @@ -105,7 +105,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.8.10" + "version": "3.10.16" }, "orig_nbformat": 4 },
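With `groundingdino/version.py` now checked in, the version is importable even before `write_version_file()` regenerates it at build time (a quick sanity check):

```python
# Sanity check: groundingdino/version.py is added by this diff.
from groundingdino.version import __version__
print(__version__)  # '0.1.0'
```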