@clangop()
def topk(
    a: TensorLike, /, k: int, dim: int | None = None, largest: bool = True, sorted: bool = True, *, out=None
) -> tuple[TensorProxy, TensorProxy]:
    """Returns the ``k`` largest (``largest=True``) or smallest (``largest=False``)
    elements of ``a`` along dimension ``dim``, as a ``(values, indices)`` pair.

    ``dim`` defaults to the last dimension (or 0 for 0-dim tensors), matching
    ``torch.topk``. ``largest`` and ``sorted`` are coerced with ``bool(...)``
    because they may arrive as number proxies rather than plain Python bools.

    NOTE(review): the return annotation was ``(TensorProxy, TensorProxy)`` — a
    tuple literal, not a type (invalid per PEP 484) — fixed to ``tuple[...]``.
    """
    # Default to the last dimension; 0-dim tensors reduce over dim 0.
    if dim is None:
        dim = a.ndim - 1 if a.ndim > 0 else 0
    # Wraps negative dims and raises IndexError for out-of-range dims.
    dim = utils.canonicalize_dim(a.ndim, dim)

    return prims.topk(a, k, dim, bool(largest), bool(sorted), out=out)
def topk_meta(
    a: TensorProxy, /, k: int, dim: int, largest: Number, sorted: Number, *, out: None | TensorProxy
) -> tuple[TensorProxy, TensorProxy]:
    """Meta function for the TOPK prim.

    Validates the arguments and returns ``(values, indices)`` proxies whose
    shape equals ``a.shape`` except that ``shape[dim] == k``; ``indices`` is
    always int64, matching ``torch.topk``.

    ``dim`` is expected to already be canonicalized (non-negative) by the
    clang/executor layer — TODO confirm no caller passes a raw negative dim.

    NOTE(review): the return annotation was the tuple literal
    ``(TensorProxy, TensorProxy)`` — invalid per PEP 484 — fixed to ``tuple[...]``.
    """
    utils.check(
        out is None,
        lambda: "Only `out` which is None is currently supported",
    )

    utils.check_type(a, TensorProxy)
    utils.check_type(k, int)
    utils.check_type(dim, int)
    # largest/sorted may be proxy-wrapped numbers; require a bool payload.
    utils.check(pytype(largest) is bool, lambda: f"Expected {largest=} to be a boolean value")
    utils.check(pytype(sorted) is bool, lambda: f"Expected {sorted=} to be a boolean value")

    # 0 <= k <= size of the reduced dimension; a 0-dim tensor behaves as size 1.
    utils.check(k >= 0 and k <= (a.shape[dim] if a.ndim > 0 else 1), lambda: f"selected index {k=} is out of range")

    new_shape = a.shape
    if a.ndim > 0:
        new_shape = list(new_shape)
        new_shape[dim] = k

    return TensorProxy(like=a, shape=new_shape), TensorProxy(like=a, shape=new_shape, dtype=dtypes.int64)


topk = make_prim(PrimIDs.TOPK, "topk", meta=topk_meta, tags=(OpTags.REDUCTION_OP,))
# NOTE This transform translates number proxies to boolean values
# and handles dim = None
def _topk_transform(
    a: TensorProxy, /, k: int, dim: int | None = None, largest: Number = 1, sorted: Number = 1, *, out=None
):
    # Match torch.topk's default: reduce over the last dimension
    # (0-dim tensors use dim 0).
    if dim is None:
        dim = a.ndim - 1 if a.ndim > 0 else 0

    # largest/sorted may be number proxies; torch.topk wants plain bools.
    return topk(a, k, dim, bool(largest), bool(sorted), out=out)


# Executes the TOPK prim via torch.topk.
_register_implementation(prims.topk, checker=_always_executable, execution_transform=_topk_transform)

# Same transform serves the thunder.torch-level topk symbol.
_register_implementation(ltorch.topk, topk, checker=_always_executable, execution_transform=_topk_transform)
def topk_sample_generator(op, device, dtype, requires_grad, **kwargs):
    """Yields topk samples covering 0-dim, empty-dim, default-dim and negative-dim cases."""
    make = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)

    # Each case is (shape, k) or (shape, k, dim).
    # NOTE: k = 0 is not consistent between the CPU and the CUDA PyTorch implementations,
    # unless shape[dim] == 0
    cases = (
        ((), 1),
        ((), 1, 0),
        ((3, 0), 0),
        ((4, 4), 2, 1),
        ((4, 1, 6), 3, -1),
        ((4, 1, 6), 3),
        ((4, 7, 5, 1), 2, -3),
        ((4, 2, 5, 1), 1),
    )

    for shape, *k_and_dim in cases:
        # Exercise every combination of the two boolean flags.
        for largest in (True, False):
            for sorted in (True, False):
                yield SampleInput(make(shape), *k_and_dim, largest=largest, sorted=sorted)


def topk_error_generator(op, device, **kwargs):
    """Yields inputs that must raise: k out of range, and dim out of range."""
    make = partial(make_tensor, device=device, dtype=torch.float32)

    err_msg = r"selected index .* is out of range"
    # k larger than the dim size, and k > 0 on an empty dim.
    yield (SampleInput(make(3, 2), 3), RuntimeError, err_msg)
    yield (SampleInput(make(3, 0), 1), RuntimeError, err_msg)

    err_msg = "Dimension out of range"
    # dim outside [-ndim, ndim - 1] in both directions.
    yield (SampleInput(make(3, 3), 1, 3), IndexError, err_msg)
    yield (SampleInput(make(3, 3), 1, -3), IndexError, err_msg)


topk_opinfo = OpInfo(
    clang.topk,
    sample_input_generator=topk_sample_generator,
    error_input_generator=topk_error_generator,
    torch_reference=torch.topk,
    dtypes=(datatypes.signedinteger, datatypes.unsignedinteger, datatypes.floating),
)
reduction_ops.append(topk_opinfo)
@torchsymbol(torch.topk, is_method=True)
def topk(
    a: TensorLike, /, k: int, dim: None | int = None, largest: bool = True, sorted: bool = True, *, out=None
) -> tuple[TensorLike, TensorLike]:
    """Returns the ``k`` largest (``largest=True``) or smallest elements of ``a``
    along ``dim`` (defaults to the last dimension), as a ``(values, indices)`` pair.

    Thin wrapper delegating all defaulting and validation to ``clang.topk``.

    NOTE(review): the return annotation was the tuple literal
    ``(TensorLike, TensorLike)`` — invalid per PEP 484 — fixed to ``tuple[...]``.
    """
    return clang.topk(a, k, dim, largest, sorted, out=out)