Add Autoconfig, Coordinated_Optimizer and Sharding keras implementations for Tensor Parallel Autosharding #21707
@@ -0,0 +1,260 @@
from keras.src import layers
from keras.src.distribution.tensor_parallel.tensor_layout import LayoutMap
from keras.src.distribution.tensor_parallel.tensor_layout import Split


def analyze_dense_layer(layer):
    """Analyzes a Keras Dense layer to classify its sharding strategy.

    This function inspects the input and output dimensions of a Dense layer
    to determine whether it functions as an expansion layer
    ("up-projection"), a contraction layer ("down-projection"), or neither
    ("dense"). This classification is a heuristic commonly used to apply
    tensor parallelism in Transformer-based models, for example in an MLP
    block where an up-projection is followed by a down-projection.

    The classification uses an `expansion_threshold` of 1.5: a layer whose
    output dimension exceeds 1.5x its input dimension is an up-projection,
    and a layer whose input dimension exceeds 1.5x its output dimension is
    a down-projection.

    Args:
        layer: The Keras `layers.Dense` instance to analyze.

    Returns:
        str: A string classifying the layer as 'up_projection',
            'down_projection', or 'dense'.
    """

    if not isinstance(layer, layers.Dense):
        return "dense"

    input_dim = None
    output_dim = None

    # Prefer the built kernel, which provides both dimensions directly.
    if hasattr(layer, "kernel") and layer.kernel is not None:
        kernel_shape = layer.kernel.shape
        if len(kernel_shape) == 2:
            input_dim = kernel_shape[0]
            output_dim = kernel_shape[1]

    # Fall back to the layer configuration if it has not been built yet.
    if input_dim is None or output_dim is None:
        if hasattr(layer, "units"):
            output_dim = layer.units
        else:
            return "dense"

        if (
            hasattr(layer, "input_shape")
            and layer.input_shape
            and len(layer.input_shape) > 1
        ):
            input_dim = layer.input_shape[-1]
        else:
            return "dense"

    if not input_dim or not output_dim:
        return "dense"

    expansion_threshold = 1.5
    is_expansion = output_dim > input_dim * expansion_threshold
    is_contraction = input_dim > output_dim * expansion_threshold

    if is_expansion:
        return "up_projection"
    elif is_contraction:
        return "down_projection"
    else:
        return "dense"

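As a quick illustration of the threshold heuristic, the sketch below builds two hypothetical Dense layers and checks them against the same 1.5 ratio used by `expansion_threshold`; the layer sizes are illustrative and nothing here is part of the diff:

from keras import layers

# Hypothetical sizes: a 128 -> 512 expansion and a 512 -> 128 contraction,
# built so that `layer.kernel` exists.
up = layers.Dense(512)
up.build((None, 128))
down = layers.Dense(128)
down.build((None, 512))

threshold = 1.5  # mirrors expansion_threshold above
in_dim, out_dim = up.kernel.shape
assert out_dim > in_dim * threshold    # classified as "up_projection"
in_dim, out_dim = down.kernel.shape
assert in_dim > out_dim * threshold    # classified as "down_projection"
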
def _recursive_layer_traversal(
    current_layer,
    prefix,
    device_count,
    state_rules,
    output_rules,
    processed_layers,
):
    """Recursively traverses the model graph to apply sharding rules.

    This function is needed because the Keras `Model.layers` property does
    not recursively expose all sub-layers in every architecture. It applies
    sharding rules based on layer type and the heuristic classification
    above (e.g., up/down projection).

    Split logic:
        - 'up_projection' (or generic 'dense'): column-wise sharding
          (`Split(..., 1, "column")`) of the kernel. The output must be
          gathered back into a full tensor (`gather`).
        - 'down_projection' (or attention output): row-wise sharding
          (`Split(..., 0, "row")`) of the kernel. The partial outputs must
          be summed across devices (`allreduce`).
        - Embedding: column-wise sharding (`Split(..., 1, "column")`) of
          the embedding table.

    Args:
        current_layer: The Keras layer instance currently being inspected.
        prefix: The fully qualified name prefix for the current layer's
            scope.
        device_count: The number of devices (replicas) in the parallelism
            group.
        state_rules: A dictionary that accumulates variable sharding rules
            (`LayoutMap.state_rules`).
        output_rules: A dictionary that accumulates layer output
            communication rules (`LayoutMap.output_rules`).
        processed_layers: A set of layer IDs used to prevent infinite
            recursion in cyclic graph structures.
    """
    if id(current_layer) in processed_layers:
        return
    processed_layers.add(id(current_layer))

    name = current_layer.name
    full_name = f"{prefix}.{name}" if prefix else name
Reviewer comment: Because you will never really recurse, the prefix won't work.

    if isinstance(current_layer, layers.Dense):
        mlp_type = analyze_dense_layer(current_layer)

        if mlp_type == "up_projection":
            # Column-wise sharding for the first MLP layer.
            state_rules[f"{full_name}.kernel"] = Split(
                device_count, 1, "column"
            )
            if current_layer.use_bias:
                state_rules[f"{full_name}.bias"] = Split(
                    device_count, 0, "column"
                )
            # The result needs to be gathered back into a full tensor.
            output_rules[f"{full_name}"] = {0: "gather"}

        elif mlp_type == "down_projection":
            # Row-wise sharding for the second MLP layer (down-projection).
            state_rules[f"{full_name}.kernel"] = Split(device_count, 0, "row")
            # Partial results from different devices need to be summed
            # (all-reduced).
            output_rules[f"{full_name}"] = {0: "allreduce"}

        else:
            # Fallback for generic dense layers (treat as column-wise split).
            state_rules[f"{full_name}.kernel"] = Split(
                device_count, 1, "column"
            )
            if current_layer.use_bias:
                state_rules[f"{full_name}.bias"] = Split(
                    device_count, 0, "column"
                )
            output_rules[f"{full_name}"] = {0: "gather -1"}

    elif isinstance(current_layer, layers.EinsumDense):
        if "attention_output" in full_name:
            # Row-wise sharding for the attention output projection.
            state_rules[f"{full_name}.kernel"] = Split(device_count, 0, "row")
            output_rules[f"{full_name}"] = {0: "allreduce"}
        else:
            # Column-wise sharding for the key/query/value projections.
            state_rules[f"{full_name}.kernel"] = Split(
                device_count, 1, "column"
            )
            if (
                hasattr(current_layer, "bias")
                and current_layer.bias is not None
            ):
                state_rules[f"{full_name}.bias"] = Split(
                    device_count, 0, "column"
                )
            output_rules[f"{full_name}"] = {0: "gather -1"}

    elif isinstance(current_layer, (layers.Embedding,)):
        weight_name = None

        if hasattr(current_layer, "embeddings"):
            weight_name = "embeddings"
        elif hasattr(current_layer, "position_embeddings"):
            weight_name = "position_embeddings"

        if weight_name:
            # Column-wise sharding on the embedding dimension.
            state_rules[f"{full_name}.{weight_name}"] = Split(
                device_count, 1, "column"
            )
            # The output requires no communication.
            output_rules[f"{full_name}"] = {0: "no_comm"}
    elif isinstance(
        current_layer,
        (
            layers.LayerNormalization,
            layers.BatchNormalization,
            layers.GroupNormalization,
        ),
    ):
        # Normalization layers are left unsharded (replicated) on purpose.
        pass

Reviewer comment: What about other layer types?

Author reply: The function is set up to handle only the biggest layers in the model (Dense, Embedding, etc.). These are the only ones large enough to cause memory problems and to need splitting (sharding). The smaller layers are skipped deliberately: normalization layers (like LayerNormalization) have small weights, so leaving them alone avoids extra communication overhead, and layers with no weights (like Dropout or Activation) have nothing to split and simply consume the sharded data produced by the layer before them.

    # Recurse into explicit sub-layer containers.
    if hasattr(current_layer, "layers") and current_layer.layers:
        for sub_layer in current_layer.layers:
            _recursive_layer_traversal(
                sub_layer,
                full_name,
                device_count,
                state_rules,
                output_rules,
                processed_layers,
            )

    # Also recurse into layers stored as plain attributes, or inside list
    # and tuple attributes, which `Model.layers` may not expose.
    for attr_name in dir(current_layer):
        if attr_name.startswith("__") and attr_name.endswith("__"):
            continue
        if hasattr(current_layer, attr_name):
            attr = getattr(current_layer, attr_name)

            if isinstance(attr, layers.Layer) and attr is not current_layer:
                _recursive_layer_traversal(
                    attr,
                    full_name,
                    device_count,
                    state_rules,
                    output_rules,
                    processed_layers,
                )
            elif isinstance(attr, (list, tuple)):
                for item in attr:
                    if isinstance(item, layers.Layer):
                        _recursive_layer_traversal(
                            item,
                            full_name,
                            device_count,
                            state_rules,
                            output_rules,
                            processed_layers,
                        )

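To make the split logic concrete, this is roughly what the traversal accumulates for a hypothetical block containing a 128 -> 512 up-projection Dense named "up" and a 512 -> 128 down-projection Dense named "down", sharded across four devices. The key names and the "mlp" prefix are illustrative, and `Split` refers to the import at the top of this file:

# Sketch of accumulated rules (illustrative, not produced by this diff).
state_rules = {
    "mlp.up.kernel": Split(4, 1, "column"),  # up-projection: column split
    "mlp.up.bias": Split(4, 0, "column"),
    "mlp.down.kernel": Split(4, 0, "row"),   # down-projection: row split
}
output_rules = {
    "mlp.up": {0: "gather"},       # reassemble the column-sharded output
    "mlp.down": {0: "allreduce"},  # sum the row-sharded partial results
}
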
def get_default_config_keras(module, device_ids):
    """Generates a default tensor parallelism sharding configuration.

    This function combines model traversal with heuristic layer analysis to
    automatically generate sharding rules (for variables and layer outputs)
    suited to large Transformer-based language models.

    Args:
        module: The root Keras `Model` or `Layer` instance representing the
            module to be sharded.
        device_ids: A list of device identifiers (e.g., strings) that
            define the parallelism group. The length of this list
            determines the number of shards (`device_count`).

    Returns:
        LayoutMap: An object containing the generated `state_rules`
            (variable sharding) and `output_rules` (layer communication).
    """

    device_count = len(device_ids)
    state_rules = {}
    output_rules = {}

    processed_layers = set()

    _recursive_layer_traversal(
        current_layer=module,
        prefix="",
        device_count=device_count,
        state_rules=state_rules,
        output_rules=output_rules,
        processed_layers=processed_layers,
    )

    return LayoutMap(state_rules=state_rules, output_rules=output_rules)
Reviewer comment: Per my comment below about not needing a recursion, this is not needed.
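For context, a minimal usage sketch of the generated configuration. The `autoconfig` import path below is an assumption based on the PR title (the actual file name is not visible in this excerpt), and the toy model sizes are illustrative:

import keras
from keras import layers

# Assumed module path for the file added in this PR.
from keras.src.distribution.tensor_parallel.autoconfig import (
    get_default_config_keras,
)

# Toy MLP: a 128 -> 512 up-projection followed by a 512 -> 128
# down-projection, so both heuristic branches are exercised.
model = keras.Sequential(
    [
        keras.Input(shape=(128,)),
        layers.Dense(512, activation="relu", name="up"),
        layers.Dense(128, name="down"),
    ]
)

layout_map = get_default_config_keras(model, device_ids=["gpu:0", "gpu:1"])
print(layout_map.state_rules)   # per-variable Split rules
print(layout_map.output_rules)  # per-layer communication rules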