Skip to content

Commit

Permalink
topk: add as a primitive (#88)
Browse files Browse the repository at this point in the history
  • Loading branch information
nikitaved authored Mar 27, 2024
1 parent 483c352 commit b433f6d
Show file tree
Hide file tree
Showing 6 changed files with 107 additions and 2 deletions.
11 changes: 11 additions & 0 deletions thunder/clang/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -1914,3 +1914,14 @@ def argmax(a: TensorProxy, /, dim: int | None = None, keepdim: bool | None = Fal
# Delegates to the shared `_argmaxmin_helper`, selecting `prims.argmin` as the primitive to apply.
# `dim` and `keepdim` are forwarded to the helper unchanged.
@clangop()
def argmin(a: TensorProxy, /, dim: int | None = None, keepdim: bool | None = False):
return _argmaxmin_helper(prims.argmin, a, dim, keepdim)


@clangop()
def topk(
    a: TensorLike, /, k: int, dim: int | None = None, largest: bool = True, sorted: bool = True, *, out=None
) -> tuple[TensorProxy, TensorProxy]:
    """Lowers a top-k selection along one dimension of ``a`` to ``prims.topk``.

    Args:
        a: the input tensor.
        k: the number of elements to select; forwarded to the prim unchanged.
        dim: the dimension to operate on. ``None`` selects the last dimension
            (dimension 0 for a zero-dimensional input). The value is always
            passed through ``utils.canonicalize_dim`` before reaching the prim.
        largest: selection flag, coerced with ``bool()`` before being forwarded.
        sorted: ordering flag, coerced with ``bool()`` before being forwarded.
        out: forwarded to the prim unchanged.

    Returns:
        The pair of tensors produced by ``prims.topk``.
    """
    if dim is None:
        # Default to the last dimension; a zero-dimensional input has no "last"
        # dimension, so dimension 0 is used instead.
        dim = a.ndim - 1 if a.ndim > 0 else 0
    # Explicitly passed dims are canonicalized as well.
    dim = utils.canonicalize_dim(a.ndim, dim)

    # The prim requires genuine booleans, hence the explicit coercion.
    return prims.topk(a, k, dim, bool(largest), bool(sorted), out=out)
28 changes: 28 additions & 0 deletions thunder/core/prims.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,6 +231,7 @@ class PrimIDs(Enum):
VAR_MEAN = auto()
ARGMAX = auto()
ARGMIN = auto()
TOPK = auto()
# Scatter and gather prims (Experimental!)
INDEX_ADD = auto()
INDEX_PUT = auto()
Expand Down Expand Up @@ -3027,6 +3028,33 @@ def scatter_add_meta(a: TensorProxy, /, index: TensorProxy, value: TensorProxy,
scatter_add = make_prim(PrimIDs.SCATTER_ADD, "scatter_add", meta=scatter_add_meta)


def topk_meta(
    a: TensorProxy, /, k: int, dim: int, largest: Number, sorted: Number, *, out: None | TensorProxy
) -> tuple[TensorProxy, TensorProxy]:
    """Meta function for the ``topk`` primitive: validates the inputs and describes the outputs.

    Args:
        a: the input tensor proxy.
        k: the number of elements to select along ``dim``; must be an ``int``
            between 0 and the size of ``a`` along ``dim`` inclusive (a
            zero-dimensional ``a`` is treated as having size 1).
        dim: the dimension to operate on; must be an ``int``.
        largest: must be a Python ``bool``.
        sorted: must be a Python ``bool``.
        out: must be ``None``; no other value is currently supported.

    Returns:
        Two tensor proxies with identical shapes: the shape of ``a`` with the
        size of ``dim`` replaced by ``k`` (unchanged for a zero-dimensional
        ``a``). The first is otherwise like ``a``; the second is like ``a``
        but with dtype ``int64``.

    Raises:
        Whatever ``utils.check`` / ``utils.check_type`` raise when a check
        fails (presumably ``RuntimeError`` for ``utils.check`` -- confirm
        against ``utils``).
    """
    utils.check(
        out is None,
        lambda: "Only `out` which is None is currently supported",
    )

    utils.check_type(a, TensorProxy)
    utils.check_type(k, int)
    utils.check_type(dim, int)
    utils.check(pytype(largest) is bool, lambda: f"Expected {largest=} to be a boolean value")
    utils.check(pytype(sorted) is bool, lambda: f"Expected {sorted=} to be a boolean value")

    # The chained comparison short-circuits, so `a.shape[dim]` is only read when k is non-negative.
    utils.check(0 <= k <= (a.shape[dim] if a.ndim > 0 else 1), lambda: f"selected index {k=} is out of range")

    new_shape = a.shape
    if a.ndim > 0:
        # Copy into a list so the size along `dim` can be replaced without touching `a.shape`.
        new_shape = list(new_shape)
        new_shape[dim] = k

    return TensorProxy(like=a, shape=new_shape), TensorProxy(like=a, shape=new_shape, dtype=dtypes.int64)


# Creates the `topk` primitive under the id `PrimIDs.TOPK`, with `topk_meta` as its meta function
# and the `OpTags.REDUCTION_OP` tag attached.
topk = make_prim(PrimIDs.TOPK, "topk", meta=topk_meta, tags=(OpTags.REDUCTION_OP,))


def transpose_meta(a: TensorProxy, /, permutation: tuple[int, ...]) -> TensorProxy:
utils.check_type(a, TensorProxy)
utils.check_type(permutation, tuple)
Expand Down
14 changes: 14 additions & 0 deletions thunder/executors/torchex.py
Original file line number Diff line number Diff line change
Expand Up @@ -1009,6 +1009,7 @@ def _tril_transform(a: TensorLike, /, diagonal: int = 0, *, fill_value: None | N
# Each name below is bound to the result of `_register_torch_operation` called with the identical name.
var_mean = _register_torch_operation("var_mean")
argmax = _register_torch_operation("argmax")
argmin = _register_torch_operation("argmin")
topk = _register_torch_operation("topk")


# NOTE The following transforms are necessary because thunder uses the parameter name 'dims' while PyTorch
Expand Down Expand Up @@ -1053,6 +1054,17 @@ def _argmin_transform(a: TensorProxy, /, dim: int):
return argmin(a, dim)


# NOTE This transform translates number proxies to boolean values
# and handles dim = None
def _topk_transform(
    a: TensorProxy, /, k: int, dim: int | None = None, largest: Number = 1, sorted: Number = 1, *, out=None
):
    """Execution transform that prepares ``topk`` arguments for the registered torch operation.

    A missing ``dim`` is resolved to the last dimension of ``a`` (dimension 0
    for a zero-dimensional input), and ``largest`` / ``sorted`` are coerced to
    plain booleans before the call is forwarded.
    """
    # max(..., 0) covers the zero-dimensional case, where ndim - 1 would be -1.
    target_dim = dim if dim is not None else max(a.ndim - 1, 0)
    largest_flag = bool(largest)
    sorted_flag = bool(sorted)

    return topk(a, k, target_dim, largest_flag, sorted_flag, out=out)


# Prim-level registrations: each prim is registered with the `_always_executable` checker
# and its own execution transform.
_register_implementation(prims.amax, checker=_always_executable, execution_transform=_amax_prim_transform)
_register_implementation(prims.amin, checker=_always_executable, execution_transform=_amin_prim_transform)
_register_implementation(prims.prod, checker=_always_executable, execution_transform=_prod_prim_transform)
Expand All @@ -1061,6 +1073,7 @@ def _argmin_transform(a: TensorProxy, /, dim: int):
_register_implementation(prims.var_mean, checker=_always_executable, execution_transform=_var_mean_prim_transform)
_register_implementation(prims.argmax, checker=_always_executable, execution_transform=_argmax_transform)
_register_implementation(prims.argmin, checker=_always_executable, execution_transform=_argmin_transform)
# prims.topk is routed through `_topk_transform`, which coerces the flags to bool and resolves dim=None.
_register_implementation(prims.topk, checker=_always_executable, execution_transform=_topk_transform)

# ltorch-level registrations: each symbol is paired directly with its registered torch operation.
_register_implementation(ltorch.amax, amax, checker=_always_executable)
_register_implementation(ltorch.amin, amin, checker=_always_executable)
Expand All @@ -1072,6 +1085,7 @@ def _argmin_transform(a: TensorProxy, /, dim: int):
_register_implementation(ltorch.var_mean, var_mean, checker=_always_executable)
_register_implementation(ltorch.argmax, argmax, checker=_always_executable)
_register_implementation(ltorch.argmin, argmin, checker=_always_executable)
# NOTE(review): unlike the registrations above, ltorch.topk is given both a direct implementation (`topk`)
# and an execution transform (`_topk_transform`) -- confirm that passing both is intended.
_register_implementation(ltorch.topk, topk, checker=_always_executable, execution_transform=_topk_transform)

#
# Scatter and gather operations
Expand Down
45 changes: 45 additions & 0 deletions thunder/tests/opinfos.py
Original file line number Diff line number Diff line change
Expand Up @@ -4697,6 +4697,51 @@ def argmin_argmax_error_generator(op, device, **kwargs):
)
reduction_ops.append(argmin_opinfo)


def topk_sample_generator(op, device, dtype, requires_grad, **kwargs):
    """Yields valid ``topk`` samples, covering every ``largest`` / ``sorted`` combination per case."""
    make = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)

    # Each entry is (shape, k) or (shape, k, dim); entries without dim rely on the default.
    # NOTE: k = 0 is not consistent between the CPU and the CUDA PyTorch implementations,
    # unless shape[dim] == 0
    cases = (
        ((), 1),
        ((), 1, 0),
        ((3, 0), 0),
        ((4, 4), 2, 1),
        ((4, 1, 6), 3, -1),
        ((4, 1, 6), 3),
        ((4, 7, 5, 1), 2, -3),
        ((4, 2, 5, 1), 1),
    )

    flag_values = (True, False)
    for shape, *topk_args in cases:
        # `largest` varies slowest and `sorted` fastest.
        for largest_flag in flag_values:
            for sorted_flag in flag_values:
                # A fresh tensor is created for every yielded sample.
                yield SampleInput(make(shape), *topk_args, largest=largest_flag, sorted=sorted_flag)


def topk_error_generator(op, device, **kwargs):
    """Yields ``(sample, exception type, message pattern)`` triples for invalid ``topk`` calls."""
    make = partial(make_tensor, device=device, dtype=torch.float32)

    k_range_msg = r"selected index .* is out of range"
    dim_range_msg = "Dimension out of range"

    # (tensor shape, positional topk arguments, expected exception, expected message pattern)
    failing_cases = (
        # k is larger than the size of the (default) last dimension
        ((3, 2), (3,), RuntimeError, k_range_msg),
        ((3, 0), (1,), RuntimeError, k_range_msg),
        # dim lies outside the valid range for a two-dimensional tensor
        ((3, 3), (1, 3), IndexError, dim_range_msg),
        ((3, 3), (1, -3), IndexError, dim_range_msg),
    )

    for shape, topk_args, expected_exception, message in failing_cases:
        yield (SampleInput(make(*shape), *topk_args), expected_exception, message)


# OpInfo for `clang.topk`: pairs the sample and error input generators defined above with
# `torch.topk` as the reference implementation.
topk_opinfo = OpInfo(
clang.topk,
sample_input_generator=topk_sample_generator,
error_input_generator=topk_error_generator,
torch_reference=torch.topk,
# Only signed/unsigned integer and floating point dtypes are listed for this op.
dtypes=(datatypes.signedinteger, datatypes.unsignedinteger, datatypes.floating),
)
reduction_ops.append(topk_opinfo)


# Adds every OpInfo collected in `reduction_ops` to `opinfos`.
opinfos.extend(reduction_ops)


Expand Down
4 changes: 2 additions & 2 deletions thunder/tests/test_jit_general.py
Original file line number Diff line number Diff line change
Expand Up @@ -613,7 +613,7 @@ def test_nanogpt():
"falcon-7b-like",
"falcon-40b-like",
"codellama2-like",
pytest.param("mixtral-like", marks=pytest.mark.xfail(raises=NotImplementedError, reason="topk", strict=True)),
pytest.param("mixtral-like", marks=pytest.mark.xfail(raises=TypeError, reason="topk", strict=True)),
),
)
@pytest.mark.parametrize(
Expand Down Expand Up @@ -662,7 +662,7 @@ def test_litgpt_variants(name, device):
"falcon-7b-like",
"falcon-40b-like",
"codellama2-like",
pytest.param("mixtral-like", marks=pytest.mark.xfail(raises=NotImplementedError, reason="topk", strict=True)),
pytest.param("mixtral-like", marks=pytest.mark.xfail(raises=TypeError, reason="topk", strict=True)),
),
)
@pytest.mark.parametrize(
Expand Down
7 changes: 7 additions & 0 deletions thunder/torch/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -1797,6 +1797,13 @@ def argmin(a: TensorLike, /, dim: int | None = None, keepdim: bool | None = Fals
return clang.argmin(a, dim, keepdim)


@torchsymbol(torch.topk, is_method=True)
def topk(
    a: TensorLike, /, k: int, dim: None | int = None, largest: bool = True, sorted: bool = True, *, out=None
) -> tuple[TensorLike, TensorLike]:
    """Symbol registered for ``torch.topk`` (also as a method); delegates to ``clang.topk``.

    Args:
        a: the input tensor.
        k: the number of elements to select.
        dim: the dimension to operate on; ``None`` is forwarded as-is and
            resolved by ``clang.topk``.
        largest: forwarded to ``clang.topk`` unchanged.
        sorted: forwarded to ``clang.topk`` unchanged.
        out: forwarded to ``clang.topk`` unchanged.

    Returns:
        The pair of tensors returned by ``clang.topk``.
    """
    return clang.topk(a, k, dim, largest, sorted, out=out)


#
# Scatter and gather-related operations
#
Expand Down

0 comments on commit b433f6d

Please sign in to comment.