Skip to content
24 changes: 21 additions & 3 deletions test/common_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,15 @@
from torch.testing._comparison import BooleanPair, NonePair, not_close_error_metas, NumberPair, TensorLikePair
from torchvision import io, tv_tensors
from torchvision.transforms._functional_tensor import _max_value as get_max_value
from torchvision.transforms.v2.functional import to_cvcuda_tensor, to_image, to_pil_image
from torchvision.transforms.v2.functional import cvcuda_to_tensor, to_cvcuda_tensor, to_image, to_pil_image
from torchvision.transforms.v2.functional._utils import _import_cvcuda, _is_cvcuda_available
from torchvision.utils import _Image_fromarray


IN_OSS_CI = any(os.getenv(var) == "true" for var in ["CIRCLECI", "GITHUB_ACTIONS"])
IN_RE_WORKER = os.environ.get("INSIDE_RE_WORKER") is not None
IN_FBCODE = os.environ.get("IN_FBCODE_TORCHVISION") == "1"
CVCUDA_AVAILABLE = _is_cvcuda_available()
CUDA_NOT_AVAILABLE_MSG = "CUDA device not available"
MPS_NOT_AVAILABLE_MSG = "MPS device not available"
OSS_CI_GPU_NO_CUDA_MSG = "We're in an OSS GPU machine, and this test doesn't need cuda."
Expand Down Expand Up @@ -275,6 +277,17 @@ def combinations_grid(**kwargs):
return [dict(zip(kwargs.keys(), values)) for values in itertools.product(*kwargs.values())]


def cvcuda_to_pil_compatible_tensor(tensor: "cvcuda.Tensor") -> torch.Tensor:
tensor = cvcuda_to_tensor(tensor)
if tensor.ndim != 4:
raise ValueError(f"CV-CUDA Tensor should be 4 dimensional. Got {tensor.ndim} dimensions.")
if tensor.shape[0] != 1:
raise ValueError(
f"CV-CUDA Tensor should have batch dimension 1 for comparison with PIL.Image.Image. Got {tensor.shape[0]}."
)
return tensor.squeeze(0).cpu()


class ImagePair(TensorLikePair):
def __init__(
self,
Expand All @@ -287,6 +300,11 @@ def __init__(
if all(isinstance(input, PIL.Image.Image) for input in [actual, expected]):
actual, expected = (to_image(input) for input in [actual, expected])

# handle check for CV-CUDA Tensors
if CVCUDA_AVAILABLE and isinstance(actual, _import_cvcuda().Tensor):
# Use the PIL compatible tensor, so we can always compare with PIL.Image.Image
actual = cvcuda_to_pil_compatible_tensor(actual)

super().__init__(actual, expected, **other_parameters)
self.mae = mae

Expand Down Expand Up @@ -400,8 +418,8 @@ def make_image_pil(*args, **kwargs):
return to_pil_image(make_image(*args, **kwargs))


def make_image_cvcuda(*args, **kwargs):
return to_cvcuda_tensor(make_image(*args, **kwargs))
def make_image_cvcuda(*args, batch_dims=(1,), **kwargs):
return to_cvcuda_tensor(make_image(*args, batch_dims=batch_dims, **kwargs))


def make_keypoints(canvas_size=DEFAULT_SIZE, *, num_points=4, dtype=None, device="cpu"):
Expand Down
79 changes: 63 additions & 16 deletions test/test_transforms_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,11 @@
import torchvision.transforms.v2 as transforms

from common_utils import (
assert_close,
assert_equal,
cache,
cpu_and_cuda,
cvcuda_to_pil_compatible_tensor,
freeze_rng_state,
ignore_jit_no_profile_information_warning,
make_bounding_boxes,
Expand All @@ -41,7 +43,6 @@
)

from torch import nn
from torch.testing import assert_close
from torch.utils._pytree import tree_flatten, tree_map
from torch.utils.data import DataLoader, default_collate
from torchvision import tv_tensors
Expand Down Expand Up @@ -5500,24 +5501,34 @@ def test_kernel_image(self, mean, std, device):

@pytest.mark.parametrize("device", cpu_and_cuda())
def test_kernel_image_inplace(self, device):
input = make_image_tensor(dtype=torch.float32, device=device)
input_version = input._version
inpt = make_image_tensor(dtype=torch.float32, device=device)
input_version = inpt._version

output_out_of_place = F.normalize_image(input, mean=self.MEAN, std=self.STD)
assert output_out_of_place.data_ptr() != input.data_ptr()
assert output_out_of_place is not input
output_out_of_place = F.normalize_image(inpt, mean=self.MEAN, std=self.STD)
assert output_out_of_place.data_ptr() != inpt.data_ptr()
assert output_out_of_place is not inpt

output_inplace = F.normalize_image(input, mean=self.MEAN, std=self.STD, inplace=True)
assert output_inplace.data_ptr() == input.data_ptr()
output_inplace = F.normalize_image(inpt, mean=self.MEAN, std=self.STD, inplace=True)
assert output_inplace.data_ptr() == inpt.data_ptr()
assert output_inplace._version > input_version
assert output_inplace is input
assert output_inplace is inpt

assert_equal(output_inplace, output_out_of_place)

def test_kernel_video(self):
check_kernel(F.normalize_video, make_video(dtype=torch.float32), mean=self.MEAN, std=self.STD)

@pytest.mark.parametrize("make_input", [make_image_tensor, make_image, make_video])
@pytest.mark.parametrize(
"make_input",
[
make_image_tensor,
make_image,
make_video,
pytest.param(
make_image_cvcuda, marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="test requires CVCUDA")
),
],
)
def test_functional(self, make_input):
check_functional(F.normalize, make_input(dtype=torch.float32), mean=self.MEAN, std=self.STD)

Expand All @@ -5527,9 +5538,16 @@ def test_functional(self, make_input):
(F.normalize_image, torch.Tensor),
(F.normalize_image, tv_tensors.Image),
(F.normalize_video, tv_tensors.Video),
pytest.param(
F._misc._normalize_cvcuda,
"cvcuda.Tensor",
marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="test requires CVCUDA"),
),
],
)
def test_functional_signature(self, kernel, input_type):
if input_type == "cvcuda.Tensor":
input_type = _import_cvcuda().Tensor
check_functional_kernel_signature_match(F.normalize, kernel=kernel, input_type=input_type)

def test_functional_error(self):
Expand All @@ -5543,9 +5561,9 @@ def test_functional_error(self):
with pytest.raises(ValueError, match="std evaluated to zero, leading to division by zero"):
F.normalize_image(make_image(dtype=torch.float32), mean=self.MEAN, std=std)

def _sample_input_adapter(self, transform, input, device):
def _sample_input_adapter(self, transform, inpt, device):
adapted_input = {}
for key, value in input.items():
for key, value in inpt.items():
if isinstance(value, PIL.Image.Image):
# normalize doesn't support PIL images
continue
Expand All @@ -5555,7 +5573,17 @@ def _sample_input_adapter(self, transform, input, device):
adapted_input[key] = value
return adapted_input

@pytest.mark.parametrize("make_input", [make_image_tensor, make_image, make_video])
@pytest.mark.parametrize(
"make_input",
[
make_image_tensor,
make_image,
make_video,
pytest.param(
make_image_cvcuda, marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="test requires CVCUDA")
),
],
)
def test_transform(self, make_input):
check_transform(
transforms.Normalize(mean=self.MEAN, std=self.STD),
Expand All @@ -5570,14 +5598,33 @@ def _reference_normalize_image(self, image, *, mean, std):

@pytest.mark.parametrize(("mean", "std"), MEANS_STDS)
@pytest.mark.parametrize("dtype", [torch.float16, torch.float32, torch.float64])
@pytest.mark.parametrize(
"make_input",
[
make_image,
pytest.param(
make_image_cvcuda, marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="test requires CVCUDA")
),
],
)
@pytest.mark.parametrize("fn", [F.normalize, transform_cls_to_functional(transforms.Normalize)])
def test_correctness_image(self, mean, std, dtype, fn):
image = make_image(dtype=dtype)
def test_correctness_image(self, mean, std, dtype, make_input, fn):
if make_input == make_image_cvcuda and dtype != torch.float32:
pytest.skip("CVCUDA only supports float32 for normalize")

image = make_input(dtype=dtype)

actual = fn(image, mean=mean, std=std)

if make_input == make_image_cvcuda:
image = cvcuda_to_pil_compatible_tensor(image)

expected = self._reference_normalize_image(image, mean=mean, std=std)

assert_equal(actual, expected)
if make_input == make_image_cvcuda:
assert_close(actual, expected, rtol=0, atol=1e-6)
else:
assert_equal(actual, expected)


class TestClampBoundingBoxes:
Expand Down
3 changes: 3 additions & 0 deletions torchvision/transforms/v2/_misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
get_bounding_boxes,
get_keypoints,
has_any,
is_cvcuda_tensor,
is_pure_tensor,
)

Expand Down Expand Up @@ -160,6 +161,8 @@ class Normalize(Transform):

_v1_transform_cls = _transforms.Normalize

_transformed_types = Transform._transformed_types + (is_cvcuda_tensor,)

def __init__(self, mean: Sequence[float], std: Sequence[float], inplace: bool = False):
super().__init__()
self.mean = list(mean)
Expand Down
4 changes: 2 additions & 2 deletions torchvision/transforms/v2/_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from torch import nn
from torch.utils._pytree import tree_flatten, tree_unflatten
from torchvision import tv_tensors
from torchvision.transforms.v2._utils import check_type, has_any, is_pure_tensor
from torchvision.transforms.v2._utils import check_type, has_any, is_cvcuda_tensor, is_pure_tensor
from torchvision.utils import _log_api_usage_once

from .functional._utils import _get_kernel
Expand All @@ -23,7 +23,7 @@ class Transform(nn.Module):

# Class attribute defining transformed types. Other types are passed-through without any transformation
# We support both Types and callables that are able to do further checks on the type of the input.
_transformed_types: tuple[type | Callable[[Any], bool], ...] = (torch.Tensor, PIL.Image.Image)
_transformed_types: tuple[type | Callable[[Any], bool], ...] = (torch.Tensor, PIL.Image.Image, is_cvcuda_tensor)

def __init__(self) -> None:
super().__init__()
Expand Down
5 changes: 3 additions & 2 deletions torchvision/transforms/v2/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from torchvision._utils import sequence_to_str

from torchvision.transforms.transforms import _check_sequence_input, _setup_angle, _setup_size # noqa: F401
from torchvision.transforms.v2.functional import get_dimensions, get_size, is_pure_tensor
from torchvision.transforms.v2.functional import get_dimensions, get_size, is_cvcuda_tensor, is_pure_tensor
from torchvision.transforms.v2.functional._utils import _FillType, _FillTypeJIT


Expand Down Expand Up @@ -182,7 +182,7 @@ def query_chw(flat_inputs: list[Any]) -> tuple[int, int, int]:
chws = {
tuple(get_dimensions(inpt))
for inpt in flat_inputs
if check_type(inpt, (is_pure_tensor, tv_tensors.Image, PIL.Image.Image, tv_tensors.Video))
if check_type(inpt, (is_pure_tensor, tv_tensors.Image, PIL.Image.Image, tv_tensors.Video, is_cvcuda_tensor))
}
if not chws:
raise TypeError("No image or video was found in the sample")
Expand All @@ -207,6 +207,7 @@ def query_size(flat_inputs: list[Any]) -> tuple[int, int]:
tv_tensors.Mask,
tv_tensors.BoundingBoxes,
tv_tensors.KeyPoints,
is_cvcuda_tensor,
),
)
}
Expand Down
2 changes: 1 addition & 1 deletion torchvision/transforms/v2/functional/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from torchvision.transforms import InterpolationMode # usort: skip

from ._utils import is_pure_tensor, register_kernel # usort: skip
from ._utils import is_pure_tensor, register_kernel, is_cvcuda_tensor # usort: skip

from ._meta import (
clamp_bounding_boxes,
Expand Down
49 changes: 47 additions & 2 deletions torchvision/transforms/v2/functional/_misc.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import math
from typing import Optional
from typing import Optional, TYPE_CHECKING

import PIL.Image
import torch
Expand All @@ -13,7 +13,14 @@

from ._meta import _convert_bounding_box_format

from ._utils import _get_kernel, _register_kernel_internal, is_pure_tensor
from ._utils import _get_kernel, _import_cvcuda, _is_cvcuda_available, _register_kernel_internal, is_pure_tensor

CVCUDA_AVAILABLE = _is_cvcuda_available()

if TYPE_CHECKING:
import cvcuda # type: ignore[import-not-found]
if CVCUDA_AVAILABLE:
cvcuda = _import_cvcuda() # noqa: F811


def normalize(
Expand Down Expand Up @@ -72,6 +79,44 @@ def normalize_video(video: torch.Tensor, mean: list[float], std: list[float], in
return normalize_image(video, mean, std, inplace=inplace)


def _normalize_cvcuda(
image: "cvcuda.Tensor",
mean: list[float],
std: list[float],
inplace: bool = False,
) -> "cvcuda.Tensor":
cvcuda = _import_cvcuda()
if inplace:
raise ValueError("Inplace normalization is not supported for CVCUDA.")

# CV-CUDA supports signed int and float tensors
# torchvision only supports uint and float, right now CV-CUDA doesnt expose float16, so only check 32
# in the future add float16 once exposed in CV-CUDA
if not (image.dtype == cvcuda.Type.F32):
raise ValueError(f"Input tensor should be a float tensor. Got {image.dtype}.")

channels = image.shape[3]
if isinstance(mean, float | int):
mean = [mean] * channels
elif len(mean) != channels:
raise ValueError(f"Mean should have {channels} elements. Got {len(mean)}.")
if isinstance(std, float | int):
std = [std] * channels
elif len(std) != channels:
raise ValueError(f"Std should have {channels} elements. Got {len(std)}.")

mt = torch.as_tensor(mean, dtype=torch.float32).reshape(1, 1, 1, channels).cuda()
st = torch.as_tensor(std, dtype=torch.float32).reshape(1, 1, 1, channels).cuda()
mean_cv = cvcuda.as_tensor(mt, cvcuda.TensorLayout.NHWC)
std_cv = cvcuda.as_tensor(st, cvcuda.TensorLayout.NHWC)

return cvcuda.normalize(image, base=mean_cv, scale=std_cv, flags=cvcuda.NormalizeFlags.SCALE_IS_STDDEV)


if CVCUDA_AVAILABLE:
_register_kernel_internal(normalize, _import_cvcuda().Tensor)(_normalize_cvcuda)


def gaussian_blur(inpt: torch.Tensor, kernel_size: list[int], sigma: Optional[list[float]] = None) -> torch.Tensor:
"""See :class:`~torchvision.transforms.v2.GaussianBlur` for details."""
if torch.jit.is_scripting():
Expand Down
7 changes: 7 additions & 0 deletions torchvision/transforms/v2/functional/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,3 +169,10 @@ def _is_cvcuda_available():
return True
except ImportError:
return False


def is_cvcuda_tensor(inpt: Any) -> bool:
if _is_cvcuda_available():
cvcuda = _import_cvcuda()
return isinstance(inpt, cvcuda.Tensor)
return False