Skip to content

Commit 64f51a5

Browse files
committed
Merge remote-tracking branch 'origin/main' into upstream_merge_24_10_28
2 parents ab3f100 + 5974cc3 commit 64f51a5

File tree

2 files changed

+16
-1
lines changed

2 files changed

+16
-1
lines changed

vllm/config.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
get_hf_image_processor_config,
1818
get_hf_text_config)
1919
from vllm.utils import (GiB_bytes, cuda_device_count_stateless, get_cpu_memory,
20-
print_warning_once)
20+
is_mi250, print_warning_once)
2121

2222
if TYPE_CHECKING:
2323
from ray.util.placement_group import PlacementGroup
@@ -953,6 +953,12 @@ def __init__(
953953
self._verify_args()
954954
self.rank: int = 0
955955

956+
if is_mi250() and self.tensor_parallel_size > 1:
957+
self.disable_custom_all_reduce = True
958+
logger.info(
959+
"Disabled the custom all-reduce kernel because it is not "
960+
"working correctly on multi AMD MI250.")
961+
956962
@property
957963
def use_ray(self) -> bool:
958964
return self.distributed_executor_backend == "ray" or (

vllm/utils.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -421,6 +421,15 @@ def reset(self):
421421
self._index = 0
422422

423423

424+
@lru_cache(maxsize=None)
425+
def is_mi250() -> bool:
426+
if not is_hip() or not torch.cuda.is_available():
427+
return False
428+
archName = torch.cuda.get_device_properties('cuda').gcnArchName
429+
return (archName is not None) and \
430+
("gfx90a" in archName)
431+
432+
424433
@lru_cache(maxsize=None)
425434
def get_max_shared_memory_bytes(gpu: int = 0) -> int:
426435
"""Returns the maximum shared memory per thread block in bytes."""

0 commit comments

Comments (0)