diff --git a/src/llmcompressor/modifiers/transform/spinquant/base.py b/src/llmcompressor/modifiers/transform/spinquant/base.py
index 618c9e4fc..11a97da21 100644
--- a/src/llmcompressor/modifiers/transform/spinquant/base.py
+++ b/src/llmcompressor/modifiers/transform/spinquant/base.py
@@ -119,16 +119,17 @@ def on_initialize(self, state: State, **kwargs) -> bool:
 
         self.mappings = infer_mapping_from_model(state.model)
         self.norm_mappings = infer_norm_mapping_from_model(state.model)
+        head_dim = self._infer_head_dim(state.model)
 
         config_groups = {}
         if SpinquantRotation.R1 in self.rotations:
             config_groups["R1"] = self._create_r1_scheme()
 
         if SpinquantRotation.R2 in self.rotations:
-            config_groups["R2"] = self._create_r2_scheme(state.model)
+            config_groups["R2"] = self._create_r2_scheme(head_dim)
 
         if SpinquantRotation.R3 in self.rotations:
-            config_groups["R3"] = self._create_r3_scheme()
+            config_groups["R3"] = self._create_r3_scheme(head_dim)
 
         if SpinquantRotation.R4 in self.rotations:
             config_groups["R4"] = self._create_r4_scheme()
@@ -209,16 +210,7 @@ def _create_r1_scheme(self) -> TransformScheme:
             ],
         )
 
-    def _create_r2_scheme(self, model: PreTrainedModel) -> TransformScheme:
-        config = model.config
-
-        if hasattr(config, "head_dim"):
-            head_dim = config.head_dim
-        elif hasattr(config, "hidden_size") and hasattr(config, "num_attention_heads"):
-            head_dim = config.hidden_size // config.num_attention_heads
-        else:
-            raise NotImplementedError()
-
+    def _create_r2_scheme(self, head_dim: int) -> TransformScheme:
         return TransformScheme(
             type=self.transform_type,
             randomize=self.randomize,
@@ -235,9 +227,23 @@ def _create_r2_scheme(self, model: PreTrainedModel) -> TransformScheme:
             ],
         )
 
-    def _create_r3_scheme(self) -> TransformScheme:
-        raise NotImplementedError(
-            "SpinQuant R3 rotations will be added in a future release"
+    def _create_r3_scheme(self, head_dim: int) -> TransformScheme:
+        return TransformScheme(
+            type=self.transform_type,
+            randomize=self.randomize,
+            requires_grad=self.learnable,
+            precision=self.precision,
+            head_dim=head_dim,
+            apply=[
+                TransformArgs(
+                    targets=[self.mappings.attn],
+                    location="q_attn",
+                ),
+                TransformArgs(
+                    targets=[self.mappings.attn],
+                    location="k_cache",
+                ),
+            ],
         )
 
     def _create_r4_scheme(self) -> TransformScheme:
@@ -258,3 +264,13 @@ def _create_r4_scheme(self) -> TransformScheme:
                 ),
             ],
         )
+
+    def _infer_head_dim(self, model: PreTrainedModel) -> int:
+        config = model.config
+
+        if hasattr(config, "head_dim"):
+            return config.head_dim
+        elif hasattr(config, "hidden_size") and hasattr(config, "num_attention_heads"):
+            return config.hidden_size // config.num_attention_heads
+        else:
+            raise NotImplementedError()
diff --git a/src/llmcompressor/modifiers/transform/spinquant/mappings.py b/src/llmcompressor/modifiers/transform/spinquant/mappings.py
index 514d1f109..2d2bb3cba 100644
--- a/src/llmcompressor/modifiers/transform/spinquant/mappings.py
+++ b/src/llmcompressor/modifiers/transform/spinquant/mappings.py
@@ -29,6 +29,7 @@ class SpinQuantMapping(BaseModel):
 
     embedding: str
 
+    attn: str
     attn_q: str
     attn_k: str
     attn_v: str
@@ -50,6 +51,7 @@ def cast_to_list(cls, value):
 
 _default_mappings = SpinQuantMapping(
     embedding="re:.*embed_tokens$",
+    attn="re:.*self_attn$",
     attn_q="re:.*q_proj$",
     attn_k="re:.*k_proj$",
     attn_v="re:.*v_proj$",
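
For reference, a minimal standalone sketch of the head-dim inference that this change factors out of `_create_r2_scheme` into the shared `_infer_head_dim` helper (now reused by both the R2 and R3 schemes). The `infer_head_dim` function and the `SimpleNamespace` config below are illustrative only, assuming a Hugging Face-style config object with `head_dim` or `hidden_size`/`num_attention_heads` attributes:

```python
from types import SimpleNamespace


def infer_head_dim(config) -> int:
    """Prefer an explicit head_dim; otherwise derive it from
    hidden_size // num_attention_heads, mirroring the helper in this diff."""
    if hasattr(config, "head_dim"):
        return config.head_dim
    elif hasattr(config, "hidden_size") and hasattr(config, "num_attention_heads"):
        return config.hidden_size // config.num_attention_heads
    else:
        raise NotImplementedError()


# Illustrative config without an explicit head_dim (Llama-2-7B-like shapes)
config = SimpleNamespace(hidden_size=4096, num_attention_heads=32)
assert infer_head_dim(config) == 128
```

Computing `head_dim` once in `on_initialize` also lets the new R3 scheme attach per-head rotations to the attention module itself (the new `attn` mapping, `re:.*self_attn$`) at the `q_attn` and `k_cache` locations without re-deriving the value per scheme.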