diff --git a/torchchat/cli/builder.py b/torchchat/cli/builder.py
index fb2bfb2993..f67cb9d0af 100644
--- a/torchchat/cli/builder.py
+++ b/torchchat/cli/builder.py
@@ -16,12 +16,6 @@
 import torch._inductor.config
 import torch.nn as nn
 
-from torch.distributed.device_mesh import DeviceMesh
-from torch.distributed.elastic.multiprocessing.errors import record
-from torch.distributed.elastic.utils.distributed import get_free_port
-
-from torchchat.distributed import launch_distributed, ParallelDims, parallelize_llama
-
 from torchchat.model import Model, ModelArgs, ModelType
 from torchchat.model_config.model_config import resolve_model_config
 
@@ -464,77 +458,11 @@ def _load_model_default(builder_args: BuilderArgs) -> Model:
     return model
 
 
-def _maybe_init_distributed(
-    builder_args: BuilderArgs,
-) -> Tuple[Optional[DeviceMesh], Optional[ParallelDims]]:
-    """
-    Initialize distributed related setups if the user specified
-    using distributed inference. If not, this is a no-op.
-
-    Args:
-        builder_args (:class:`BuilderArgs`):
-            Command args for model building.
-    Returns:
-        Tuple[Optional[DeviceMesh], Optional[ParallelDims]]:
-            - The first element is an optional DeviceMesh object,
-              which which describes the mesh topology of devices for the DTensor.
-            - The second element is an optional ParallelDims object,
-              which represents the parallel dimensions configuration.
-    """
-    if not builder_args.use_distributed:
-        return None, None
-    dist_config = "llama3_8B.toml"  # TODO - integrate with chat cmd line
-
-    world_mesh, parallel_dims = launch_distributed(dist_config)
-
-    assert (
-        world_mesh is not None and parallel_dims is not None
-    ), f"failed to launch distributed using {dist_config}"
-
-    return world_mesh, parallel_dims
-
-
-def _maybe_parallelize_model(
-    model: nn.Module,
-    builder_args: BuilderArgs,
-    world_mesh: DeviceMesh,
-    parallel_dims: ParallelDims,
-) -> nn.Module:
-    """
-    We parallelize the module and load the distributed checkpoint to the model
-    if the user specifies using distributed inference. If not, this is a no-op.
-
-    Args:
-        model (:class:`nn.Module`):
-            Module to be parallelized.
-        builder_args (:class:`BuilderArgs`):
-            Command args for model building.
-        world_mesh (:class:`DeviceMesh`):
-            Object which describes the mesh topology
-            of devices for the DTensor.
-        parallel_dims (:class:`ParallelDims`):
-            Object which represents the parallel dimensions configuration.
-    Returns:
-        A :class:`nn.Module` object which is parallelized and checkpoint loaded
-        if the user specifies using distributed inference.
-    """
-    if world_mesh is None:
-        return model
-    assert parallel_dims is not None
-    print("Applying model parallel to model ...")
-    parallelize_llama(model, world_mesh, parallel_dims)
-    return load_checkpoints_to_model(model, builder_args, world_mesh)
-
-
 def _load_model(builder_args: BuilderArgs) -> Model:
-    # world_mesh, parallel_dims = _maybe_init_distributed(builder_args)
     if builder_args.gguf_path:
         model = _load_model_gguf(builder_args)
-    # elif builder_args.use_distributed:
-    #     model = _init_model_on_meta_device(builder_args)
     else:
         model = _load_model_default(builder_args)
-    # model = _maybe_parallelize_model(model, builder_args, world_mesh, parallel_dims)
 
     if builder_args.dso_path or builder_args.aoti_package_path:
         # AOTI-compoiled model will load its own weights.
@@ -706,4 +634,4 @@ def tokenizer_setting_to_name(tiktoken: bool, tokenizers: bool) -> str:
         return "TikToken"
     if tokenizers:
         return "Tokenizers"
-    return "SentencePiece"
\ No newline at end of file
+    return "SentencePiece"
diff --git a/torchchat/generate.py b/torchchat/generate.py
index dd423b58a1..fcae18d87a 100644
--- a/torchchat/generate.py
+++ b/torchchat/generate.py
@@ -915,13 +915,6 @@ def chat(
             ]
         )
         if generator_args.compile:
-            if (
-                self.is_speculative and self.builder_args.use_distributed
-            ):  # and ("cuda" in builder_args.device):
-                torch._inductor.config.triton.cudagraph_trees = (
-                    False  # Bug with cudagraph trees in this case
-                )
-
             if self.builder_args.device == "cpu":
                 if generator_args.max_autotune:
                     kwargs = {"mode": "max-autotune"}
@@ -1091,9 +1084,7 @@ def callback(x, *, done_generating=False):
                 torch._inductor.config.profiler_mark_wrapper_call = True
                 torch._inductor.config.cpp.enable_kernel_profile = True
 
-            if (i != generator_args.num_samples - 1 or not self.profile) or (
-                self.builder_args.use_distributed and self.rank != 0
-            ):
+            if i != generator_args.num_samples - 1 or not self.profile:
                 import contextlib
 
                 prof = contextlib.nullcontext()
@@ -1136,10 +1127,7 @@ def callback(x, *, done_generating=False):
                     print(prof.key_averages().table(sort_by="self_cpu_time_total"))
                 else:
                     print(prof.key_averages().table(sort_by="self_cuda_time_total"))
-                if self.builder_args.use_distributed:
-                    prof.export_chrome_trace(f"{self.profile}_rank_{self.rank}.json")
-                else:
-                    prof.export_chrome_trace(f"{self.profile}.json")
+                prof.export_chrome_trace(f"{self.profile}.json")
 
             if start_pos >= max_seq_length:
                 print(