Commit 1290a5b: Merge branch 'master' into minor_fix_version2

tjruwase authored Feb 19, 2025
2 parents 7d66c86 + 33dd2e2
Showing 9 changed files with 69 additions and 29 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/nv-ds-chat.yml
@@ -43,7 +43,7 @@ jobs:
       - name: Install deepspeed
         run: |
-          pip install transformers
+          pip install transformers==4.48.3
           pip install .[dev]
           ds_report
5 changes: 5 additions & 0 deletions deepspeed/runtime/engine.py
@@ -983,6 +983,9 @@ def zero_quantized_gradients(self):
     def zeropp_loco_param(self):
         return self._config.zero_config.zeropp_loco_param

+    def zero_log_trace_cache_warnings(self):
+        return self._config.zero_config.log_trace_cache_warnings
+
     def dump_state(self):
         return self._config.dump_state

@@ -1692,6 +1695,7 @@ def _configure_zero_optimizer(self, optimizer):
                 zero_quantized_weights=self.zero_quantized_weights(),
                 zero_quantized_nontrainable_weights=self.zero_quantized_nontrainable_weights(),
                 zero_module_granularity_threshold=self.zero_module_granularity_threshold(),
+                log_trace_cache_warnings=self.zero_log_trace_cache_warnings(),
             )
         else:
             log_dist(
@@ -1740,6 +1744,7 @@ def _configure_zero_optimizer(self, optimizer):
                 zero_quantized_nontrainable_weights=self.zero_quantized_nontrainable_weights(),
                 zero_module_granularity_threshold=self.zero_module_granularity_threshold(),
                 zeropp_loco_param=self.zeropp_loco_param(),
+                log_trace_cache_warnings=self.zero_log_trace_cache_warnings(),
             )

         else:
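The engine-side change follows DeepSpeed's usual config-plumbing pattern: a thin accessor reads the parsed ZeRO config, and `_configure_zero_optimizer` forwards the value as a keyword argument when it constructs the stage-3 optimizer. A minimal sketch of that pattern; `ZeroConfig`, `Config`, and `Engine` here are illustrative stand-ins, not the real DeepSpeed classes:

```python
# Minimal sketch of the accessor/forwarding pattern used above; the
# class names are illustrative stand-ins, not DeepSpeed's real API.
from dataclasses import dataclass, field


@dataclass
class ZeroConfig:
    log_trace_cache_warnings: bool = False


@dataclass
class Config:
    zero_config: ZeroConfig = field(default_factory=ZeroConfig)


class Engine:
    def __init__(self, config: Config):
        self._config = config

    def zero_log_trace_cache_warnings(self) -> bool:
        # Mirrors the accessor added to engine.py above.
        return self._config.zero_config.log_trace_cache_warnings

    def _configure_zero_optimizer(self) -> dict:
        # Forwarded as a keyword argument, like the two call sites above.
        return {"log_trace_cache_warnings": self.zero_log_trace_cache_warnings()}


engine = Engine(Config(ZeroConfig(log_trace_cache_warnings=True)))
assert engine._configure_zero_optimizer()["log_trace_cache_warnings"] is True
```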
6 changes: 6 additions & 0 deletions deepspeed/runtime/zero/config.py
@@ -45,6 +45,7 @@
     "memory_efficient_linear": [true|false],
     "override_module_apply": [true|false],
     "zeropp_loco_param": {...},
+    "log_trace_cache_warnings": [true|false],
     }
 }
"""
@@ -340,6 +341,11 @@ class DeepSpeedZeroConfig(DeepSpeedConfigModel):
     Override nn.Module apply function, for Stage 3.
     """

+    log_trace_cache_warnings: bool = False
+    """
+    Whether to log warnings from trace cache, such as invalidation events.
+    """

     # Validators
     @model_validator(mode="after")
     def overlap_comm_valid(self):
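Since `DeepSpeedZeroConfig` is a pydantic model, the new field gets a default and type checking for free: an ill-typed value fails at config-parse time rather than deep inside the optimizer. A rough sketch of that behavior, assuming pydantic v2 and using a simplified stand-in model rather than the real `DeepSpeedConfigModel`:

```python
# Sketch of the new boolean field on a pydantic model, assuming pydantic v2.
# ZeroConfigSketch is a simplified stand-in for DeepSpeedZeroConfig.
from pydantic import BaseModel, ValidationError


class ZeroConfigSketch(BaseModel):
    log_trace_cache_warnings: bool = False
    """Whether to log warnings from trace cache, such as invalidation events."""


print(ZeroConfigSketch().log_trace_cache_warnings)  # False (the default)
print(ZeroConfigSketch(log_trace_cache_warnings=True).log_trace_cache_warnings)  # True

try:
    ZeroConfigSketch(log_trace_cache_warnings="not-a-bool")
except ValidationError as e:
    print("rejected:", e.error_count(), "error(s)")
```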
6 changes: 5 additions & 1 deletion deepspeed/runtime/zero/parameter_offload.py
@@ -103,6 +103,7 @@ def __init__(
         zero_quantized_weights=False,
         zero_quantized_nontrainable_weights=False,
         zero_module_granularity_threshold=0,
+        log_trace_cache_warnings=False,
     ):

         see_memory_usage("DeepSpeedZeRoOffload initialize [begin]", force=True)
@@ -118,6 +119,7 @@ def __init__(
         self.zero_param_parallel_group = zero_param_parallel_group
         self.zero_quantized_weights = zero_quantized_weights
         self.zero_quantized_nontrainable_weights = zero_quantized_nontrainable_weights
+        self.log_trace_cache_warnings = log_trace_cache_warnings

         if offload_param_config is not None and offload_param_config.device != OffloadDeviceEnum.none:
             self.offload_device = offload_param_config.device
@@ -165,7 +167,9 @@ def __init__(
             timers=self.timers,
             zero_quantized_weights=self.zero_quantized_weights,
             zero_quantized_nontrainable_weights=self.zero_quantized_nontrainable_weights,
-            fast_sharding_for_leaf_module=self.fast_sharding_for_leaf_module)
+            fast_sharding_for_leaf_module=self.fast_sharding_for_leaf_module,
+            log_trace_cache_warnings=self.log_trace_cache_warnings,
+        )

         self.forward_hooks = []
         self.backward_hooks = []
32 changes: 19 additions & 13 deletions deepspeed/runtime/zero/partitioned_param_coordinator.py
@@ -76,17 +76,20 @@ class __ParamInTrace:
         param: Parameter
         step_id_last_used_at: int

-    def __init__(self,
-                 prefetch_bucket_sz: int,
-                 max_reuse_distance_in_numel: int,
-                 max_available_parameters_in_numel: int,
-                 allgather_stream: get_accelerator().Stream,
-                 inflight_param_registry: InflightParamRegistry,
-                 prefetch_nvme: bool = False,
-                 timers=None,
-                 zero_quantized_weights=False,
-                 zero_quantized_nontrainable_weights=False,
-                 fast_sharding_for_leaf_module=False) -> None:
+    def __init__(
+        self,
+        prefetch_bucket_sz: int,
+        max_reuse_distance_in_numel: int,
+        max_available_parameters_in_numel: int,
+        allgather_stream: get_accelerator().Stream,
+        inflight_param_registry: InflightParamRegistry,
+        prefetch_nvme: bool = False,
+        timers=None,
+        zero_quantized_weights=False,
+        zero_quantized_nontrainable_weights=False,
+        fast_sharding_for_leaf_module=False,
+        log_trace_cache_warnings=False,
+    ) -> None:
         # mapping of param -> handle for each param that is currently in flight
         self.__inflight_param_registry = inflight_param_registry
         # keeps track of the number of submodules invoked so far.
@@ -129,6 +132,9 @@ def __init__(self,
         self.__max_ongoing_fetch_events: int = 2
         self.__profiler = PartitionedParameterProfiler(timers if ENABLE_PROFILER else None)

+        # Whether to log trace cache warnings, e.g. invalidation events
+        self.__log_trace_cache_warnings = log_trace_cache_warnings
+
         # whether to enable fast fetch for the z3 leaf module.
         # this will improve fetch speed but will not break down leaf module parameters to alleviate memory pressure.
         self.fast_sharding_for_leaf_module = fast_sharding_for_leaf_module
@@ -177,7 +183,7 @@ def trace_prologue(self, sub_module: Module) -> None:
             print_rank_0(
                 f"Invalidate trace cache @ step {self.__step_id} and module {sub_module.ds_id}: "
                 f"cache has only {len(self.__submodule_order)} modules",
-                force=True)
+                force=self.__log_trace_cache_warnings)
             self._invalidate_trace()
             return

@@ -186,7 +192,7 @@ def trace_prologue(self, sub_module: Module) -> None:
             print_rank_0(
                 f"Invalidate trace cache @ step {self.__step_id}: "
                 f"expected module {expected_module_id}, but got module {sub_module.ds_id}",
-                force=True)
+                force=self.__log_trace_cache_warnings)
             self._invalidate_trace()

     @compiler.disable
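The `force=` change is what the new flag actually controls: `print_rank_0(..., force=...)` only emits when `force` is true, so trace-cache invalidation messages now stay silent by default. A simplified, self-contained sketch of the gated module-order check; the real coordinator tracks far more state (in-flight handles, prefetching, per-parameter step ids):

```python
# Simplified sketch of trace-cache validation with gated warnings.
# The real PartitionedParameterCoordinator tracks handles, steps, and
# prefetch state; this only models the module-order check shown above.
from typing import List


class TraceCacheSketch:
    def __init__(self, log_trace_cache_warnings: bool = False) -> None:
        self._log_warnings = log_trace_cache_warnings
        self._recorded_order: List[int] = []  # module ids seen during the recorded iteration
        self._step_id = 0
        self._complete = False

    def _warn(self, msg: str) -> None:
        if self._log_warnings:  # mirrors force=self.__log_trace_cache_warnings
            print(f"[trace cache] {msg}")

    def trace_prologue(self, module_id: int) -> None:
        if not self._complete:
            self._recorded_order.append(module_id)  # still recording the trace
            return
        if self._step_id >= len(self._recorded_order):
            self._warn(f"invalidate @ step {self._step_id}: cache has only "
                       f"{len(self._recorded_order)} modules")
            self._invalidate()
            return
        expected = self._recorded_order[self._step_id]
        if expected != module_id:
            self._warn(f"invalidate @ step {self._step_id}: expected module "
                       f"{expected}, got {module_id}")
            self._invalidate()
            return
        self._step_id += 1

    def _invalidate(self) -> None:
        self._recorded_order, self._step_id, self._complete = [], 0, False


cache = TraceCacheSketch(log_trace_cache_warnings=True)
for mid in (0, 1, 2):
    cache.trace_prologue(mid)  # record the first iteration
cache._complete, cache._step_id = True, 0
cache.trace_prologue(0)
cache.trace_prologue(2)  # order changed -> warning + invalidation
```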
9 changes: 7 additions & 2 deletions deepspeed/runtime/zero/stage3.py
@@ -160,6 +160,7 @@ def __init__(
         zero_quantized_nontrainable_weights=False,
         zero_module_granularity_threshold=0,
         zeropp_loco_param=None,
+        log_trace_cache_warnings=False,
     ):
         see_memory_usage("Stage 3 initialize beginning", force=True)

@@ -231,7 +232,9 @@ def __init__(
             zero_param_parallel_group=zero_param_parallel_group,
             zero_quantized_weights=zero_quantized_weights,
             zero_quantized_nontrainable_weights=zero_quantized_nontrainable_weights,
-            zero_module_granularity_threshold=zero_module_granularity_threshold)
+            zero_module_granularity_threshold=zero_module_granularity_threshold,
+            log_trace_cache_warnings=log_trace_cache_warnings,
+        )

         self.persistent_parameters = self.parameter_offload.persistent_parameters
         self._configure_offloading(offload_optimizer_config, offload_param_config)
@@ -465,6 +468,7 @@ def initialize_ds_offload(
         zero_quantized_weights,
         zero_quantized_nontrainable_weights,
         zero_module_granularity_threshold,
+        log_trace_cache_warnings,
     ):
         return DeepSpeedZeRoOffload(module=module,
                                     timers=timers,
@@ -481,7 +485,8 @@ def initialize_ds_offload(
                                     zero_param_parallel_group=zero_param_parallel_group,
                                     zero_quantized_weights=zero_quantized_weights,
                                     zero_quantized_nontrainable_weights=zero_quantized_nontrainable_weights,
-                                    zero_module_granularity_threshold=zero_module_granularity_threshold)
+                                    zero_module_granularity_threshold=zero_module_granularity_threshold,
+                                    log_trace_cache_warnings=log_trace_cache_warnings)

     def _get_trainable_parameter_groups(self):
         param_groups = []
17 changes: 12 additions & 5 deletions docs/_pages/config-json.md
@@ -371,11 +371,12 @@ Enabling and configuring ZeRO memory optimizations
     "sub_group_size" : 1e12,
     "elastic_checkpoint" : [true|false],
     "stage3_gather_16bit_weights_on_model_save": [true|false],
-    "ignore_unused_parameters": [true|false]
-    "round_robin_gradients": [true|false]
-    "zero_hpz_partition_size": 1
-    "zero_quantized_weights": [true|false]
-    "zero_quantized_gradients": [true|false]
+    "ignore_unused_parameters": [true|false],
+    "round_robin_gradients": [true|false],
+    "zero_hpz_partition_size": 1,
+    "zero_quantized_weights": [true|false],
+    "zero_quantized_gradients": [true|false],
+    "log_trace_cache_warnings": [true|false],
     }
```

@@ -512,6 +513,12 @@ Enabling and configuring ZeRO memory optimizations
 | ----------- | ------- |
 | Boolean indicating whether to enable communication efficient quantized gradients of ZeRO++. | `False` |

+<i>**log_trace_cache_warnings**</i>: [boolean]
+
+| Description | Default |
+| ----------- | ------- |
+| Log warnings from trace cache optimization of parameter sharding, such as cache invalidation events. | `False` |

***cpu_offload***: [boolean]

**Deprecated:** **cpu_offload** is deprecated and will be removed in future, please use `offload_optimizer` instead.
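With the docs entry above, enabling the option from a user config is a one-line addition under `zero_optimization`. A minimal example; the surrounding keys are illustrative, and only `log_trace_cache_warnings` is the new option:

```python
# Hypothetical minimal ZeRO stage-3 config; the other values are
# illustrative defaults, not recommendations.
ds_config = {
    "train_batch_size": 8,
    "zero_optimization": {
        "stage": 3,
        "log_trace_cache_warnings": True,  # surface trace-cache invalidation messages
    },
}

# Typical use (model defined elsewhere):
# import deepspeed
# engine, optimizer, _, _ = deepspeed.initialize(
#     model=model, model_parameters=model.parameters(), config=ds_config)
```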
17 changes: 12 additions & 5 deletions op_builder/builder.py
@@ -61,13 +61,20 @@ def installed_cuda_version(name=""):

def get_default_compute_capabilities():
    compute_caps = DEFAULT_COMPUTE_CAPABILITIES
    # Update compute capability according to: https://en.wikipedia.org/wiki/CUDA#GPUs_supported
    import torch.utils.cpp_extension
-   if torch.utils.cpp_extension.CUDA_HOME is not None and installed_cuda_version()[0] >= 11:
-       if installed_cuda_version()[0] == 11 and installed_cuda_version()[1] == 0:
-           # Special treatment of CUDA 11.0 because compute_86 is not supported.
-           compute_caps += ";8.0"
-       else:
+   if torch.utils.cpp_extension.CUDA_HOME is not None:
+       if installed_cuda_version()[0] == 11:
+           if installed_cuda_version()[1] >= 0:
+               compute_caps += ";8.0"
+           if installed_cuda_version()[1] >= 1:
+               compute_caps += ";8.6"
+           if installed_cuda_version()[1] >= 8:
+               compute_caps += ";9.0"
+       elif installed_cuda_version()[0] == 12:
+           compute_caps += ";8.0;8.6;9.0"
+           if installed_cuda_version()[1] >= 8:
+               compute_caps += ";10.0;12.0"
    return compute_caps


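The rewritten branching maps the installed CUDA toolkit to default compute capabilities minor version by minor version. Below is a standalone re-implementation of the displayed logic for sanity-checking; the `"6.0;6.1;7.0"` base is an assumption standing in for `DEFAULT_COMPUTE_CAPABILITIES`, and the real builder also requires `torch.utils.cpp_extension.CUDA_HOME` to be set:

```python
# Standalone re-implementation of the new branching for sanity checks.
# The "6.0;6.1;7.0" base is an assumption standing in for
# DEFAULT_COMPUTE_CAPABILITIES; the real code also checks CUDA_HOME.
def default_compute_caps(cuda_major: int, cuda_minor: int,
                         base: str = "6.0;6.1;7.0") -> str:
    caps = base
    if cuda_major == 11:
        if cuda_minor >= 0:
            caps += ";8.0"        # A100-class Ampere, supported from CUDA 11.0
        if cuda_minor >= 1:
            caps += ";8.6"        # consumer Ampere, supported from CUDA 11.1
        if cuda_minor >= 8:
            caps += ";9.0"        # Hopper, supported from CUDA 11.8
    elif cuda_major == 12:
        caps += ";8.0;8.6;9.0"
        if cuda_minor >= 8:
            caps += ";10.0;12.0"  # Blackwell, supported from CUDA 12.8
    return caps


print(default_compute_caps(11, 0))  # ...;8.0
print(default_compute_caps(11, 8))  # ...;8.0;8.6;9.0
print(default_compute_caps(12, 8))  # ...;8.0;8.6;9.0;10.0;12.0
```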
4 changes: 2 additions & 2 deletions setup.py
@@ -102,8 +102,8 @@ def get_env_if_set(key, default: typing.Any = ""):
cupy = None
if is_rocm_pytorch:
    rocm_major, rocm_minor = rocm_version
-   # XXX cupy support for rocm 5 is not available yet.
-   if rocm_major <= 4:
+   # cupy support for rocm>5.0 is not available yet.
+   if (rocm_major == 5 and rocm_minor == 0) or rocm_major <= 4:
        cupy = f"cupy-rocm-{rocm_major}-{rocm_minor}"
else:
    cuda_major_ver, cuda_minor_ver = installed_cuda_version()
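The setup.py fix widens the ROCm guard to match what cupy actually ships: prebuilt wheels exist for ROCm 4.x and exactly 5.0, but nothing newer. A simplified sketch of the resulting wheel-name selection; the CUDA-side `cupy-cudaXXx` naming is an assumption for illustration, not a quote of setup.py:

```python
# Simplified sketch of the cupy wheel selection after the fix.
# The ROCm branch mirrors setup.py above; the CUDA branch's naming
# (cupy-cuda11x / cupy-cuda12x) is an assumption for illustration.
from typing import Optional


def pick_cupy_wheel(is_rocm: bool, major: int, minor: int) -> Optional[str]:
    if is_rocm:
        # Wheels exist up through ROCm 5.0; nothing newer yet.
        if (major == 5 and minor == 0) or major <= 4:
            return f"cupy-rocm-{major}-{minor}"
        return None
    return f"cupy-cuda{major}x"


print(pick_cupy_wheel(True, 5, 0))    # cupy-rocm-5-0
print(pick_cupy_wheel(True, 5, 6))    # None (no wheel yet)
print(pick_cupy_wheel(False, 12, 1))  # cupy-cuda12x
```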
