Skip to content

Commit

Permalink
fix mx4 illegal memory access (#3509)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #3509

X-link: facebookresearch/FBGEMM#593

when calaculting num_thread and group_per_thread to distribute work, rounding gets accumulated and effectively expand the input space.

for example (the new UT), when input tensor is (1, 2^31 - 8),
```
a.numel: 2147483640
num_threads: 46341
groups_per_thread: 1449
num_groups: 67108864
num_threads * groups_per_threads= 67148109 > num_groups
```

in kernel, when we try to access memory, input_start = num_threads * groups_per_threads * pid, so when pid is large, we end up visiting data outside the input

Reviewed By: jwfromm

Differential Revision: D67369392

fbshipit-source-id: 62c28fe3a94911a10921e233ff5ae42097e9dbb4
  • Loading branch information
Jingyuan Fan authored and facebook-github-bot committed Dec 19, 2024
1 parent 804a499 commit cc1bad1
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 1 deletion.
4 changes: 3 additions & 1 deletion fbgemm_gpu/fbgemm_gpu/triton/quantize.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,7 @@ def _kernel_quantize_mx4(
# When theres no padding we can simplify indexing.
else:
padded_input_offset = input_offset

# Load a block of values.
a = tl.load(
A + padded_input_offset,
Expand Down Expand Up @@ -434,7 +435,8 @@ def triton_quantize_mx4(
rand_bits = None

# Check if we need to use int64 for indexing.
use_int64 = a.numel() > 2**31 - 1
use_int64 = num_threads * groups_per_thread * group_size > 2**31 - 1

# Invoke triton quantization kernel over rows.
grid = (num_threads,)
_kernel_quantize_mx4[grid](
Expand Down
15 changes: 15 additions & 0 deletions fbgemm_gpu/test/quantize/mx4_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -304,6 +304,21 @@ def test_mx4_index_overflow(self) -> None:
# We just need to check that everything ran without an illegal memory access.
assert mx_dequantized[0] == 0

# pyre-fixme[56]:
@unittest.skipIf(
not (
torch.cuda.is_available() and torch.cuda.mem_get_info()[0] / (1024**3) >= 32
),
"Test requires a gpu with at least 32GB of memory.",
)
def test_mx4_index_overflow_large_input(self) -> None:
"""Tests that mx4 quantization kernels can handle inputs that would overflow int32 indices."""
large_input = torch.zeros((1, 2**31 - 2**3), dtype=torch.float32).to("cuda")
mx_quantized = fp32_to_mx4(large_input, 32)
mx_dequantized = mx4_to_fp32(mx_quantized, 32)
# We just need to check that everything ran without an illegal memory access.
assert mx_dequantized[0][0] == 0


if __name__ == "__main__":
unittest.main()

0 comments on commit cc1bad1

Please sign in to comment.