Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Hotfix/randomizer]Fix Randomizer error on CPU #5265

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 36 additions & 6 deletions colossalai/utils/device.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,12 +45,32 @@ def get_current_device() -> torch.device:


def _dispatch_device_func(fn_name: str, *args, **kwargs):
if torch.cuda.is_available():
if "device" in kwargs: # if device is specified, try to use the provided one
device = kwargs["device"]
del kwargs["device"]
if 'cuda' in device and torch.cuda.is_available():
device = "cuda"
elif 'npu' in device and IS_NPU_AVAILABLE:
device = "npu"
else:
device = "cpu"
else: # if device is not specified, device will be automatically detected
if torch.cuda.is_available():
device = "cuda"
elif IS_NPU_AVAILABLE:
device = "npu"
else:
device = "cpu"

if device == "cuda":
return getattr(torch.cuda, fn_name)(*args, **kwargs)
elif IS_NPU_AVAILABLE:
elif device == "npu":
return getattr(torch.npu, fn_name)(*args, **kwargs)
else:
raise RuntimeError("No device available")
else:
try:
return getattr(torch, fn_name)(*args, **kwargs)
except AttributeError:
raise RuntimeError(f"Current device does not support the function: {fn_name}")


# device semantics
Expand Down Expand Up @@ -114,15 +134,25 @@ def utilization(device=None) -> int:


def get_rng_state(device="cuda") -> torch.Tensor:
return _dispatch_device_func("get_rng_state", device)
if torch.cuda.is_available() and device=="cuda":
return _dispatch_device_func("get_rng_state", device="cuda")
elif IS_NPU_AVAILABLE and device=="npu":
return _dispatch_device_func("get_rng_state", device="npu")
else:
return _dispatch_device_func("get_rng_state", device="cpu")


def get_rng_state_all() -> List[torch.Tensor]:
return _dispatch_device_func("get_rng_state_all")


def set_rng_state(new_state: torch.ByteTensor, device="cuda") -> None:
return _dispatch_device_func("set_rng_state", new_state, device)
if torch.cuda.is_available() and device=="cuda":
return _dispatch_device_func("set_rng_state", new_state, device="cuda")
elif IS_NPU_AVAILABLE and device=="npu":
return _dispatch_device_func("set_rng_state", new_state, device="npu")
else:
return _dispatch_device_func("set_rng_state", new_state, device="cpu")


def set_rng_state_all(new_states: List[torch.ByteTensor]) -> None:
Expand Down
Loading