[MAMBA-2, GDN] fix initialization and make a separate class for decay gate logic #373
Merged
Commits (25, all by mayank31398):
- `86db0ea` add GDN efficient init
- `db411cf` fix mamba2 init
- `06ce95f` pass args
- `3a21404` hidden_states -> x
- `b6dd81b` hidden_states -> x
- `8dc2c79` hidden_states -> x
- `3bec1f0` hidden_states -> x
- `52b14a3` use gate for GDN
- `b062d86` use gate for mamba2
- `87e7a44` use gate for mamba2
- `d1050e5` fix tests
- `750f9c0` fix tests
- `8e5903e` fix tests
- `cc76e24` fix tests
- `60e092a` fix tests
- `0685373` merge
- `740a636` merge
- `43e099c` merge
- `d1fd4fb` merge
- `913e08d` merge
- `037185f` merge
- `f031286` merge
- `8eb840a` drop post_init()
- `6271d83` drop buffer
- `84286a0` fix init check
The diff adds a new file containing the `SoftplusDecayGate` module:

```python
# **************************************************
# Copyright (c) 2025, Mayank Mishra
# **************************************************

from __future__ import annotations

import math

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributed._tensor.api import DTensor
from torch.distributed._tensor.placement_types import Replicate

from ...dtensors import tensor_to_dtensor
from ..parameter import (
    mark_parameter_as_initialized,
    mark_parameter_as_mup_learning_rate,
    mark_parameter_as_no_weight_decay,
)
from .linear import ParameterizedLinear


class SoftplusDecayGate(nn.Module):
    def __init__(
        self,
        hidden_size: int | None,
        output_size: int,
        std: float | None,
        has_projection: bool = False,
        A_init_min: float = 0,
        A_init_max: float = 16,
        dt_init_min: float = 1e-3,
        dt_init_max: float = 0.1,
        dt_init_floor: float = 1e-4,
    ) -> SoftplusDecayGate:
        super().__init__()

        self.output_size = output_size
        self.has_projection = has_projection

        if has_projection:
            # optional input projection from hidden_size to the gate size
            self.proj = ParameterizedLinear(hidden_size, self.output_size, std=std)
            mark_parameter_as_mup_learning_rate(self.proj.weight)
        else:
            assert hidden_size is None

        # A_log and dt_bias are kept in float32 and excluded from weight decay
        self.A_log = nn.Parameter(torch.empty(self.output_size, dtype=torch.float32))
        mark_parameter_as_no_weight_decay(self.A_log)

        self.dt_bias = nn.Parameter(torch.empty(self.output_size, dtype=torch.float32))
        mark_parameter_as_no_weight_decay(self.dt_bias)

        assert A_init_min >= 0
        assert A_init_max >= A_init_min

        self.A_init_min = A_init_min
        self.A_init_max = A_init_max

        assert dt_init_min > 0
        assert dt_init_max >= dt_init_min

        self.dt_init_min = dt_init_min
        self.dt_init_max = dt_init_max
        self.dt_init_floor = dt_init_floor

        self.reset_parameters()

    def forward(
        self, x: torch.Tensor, final_exponential: bool, output_dtype: torch.dtype = torch.float32
    ) -> torch.Tensor:
        if self.has_projection:
            x = self.proj(x)

        # decay exponent -exp(A_log) * softplus(x + dt_bias), computed in float32
        x = x.float()
        x = x + self.dt_bias
        x = F.softplus(x)
        x = -self.A_log.float().exp() * x

        if final_exponential:
            x = torch.exp(x)

        x = x.to(output_dtype)

        return x

    @torch.no_grad()
    def reset_parameters(self) -> None:
        # A ~ U[A_init_min, A_init_max], stored as log(A)
        A = torch.empty(self.output_size, dtype=torch.float32).uniform_(self.A_init_min, self.A_init_max)

        if isinstance(self.A_log, DTensor):
            A = tensor_to_dtensor(
                tensor=A,
                device_mesh=self.A_log.device_mesh,
                current_placement=[Replicate()] * len(self.A_log.placements),
                desired_placement=self.A_log.placements,
            )

        self.A_log.copy_(torch.log(A))

        # dt is log-uniform in [dt_init_min, dt_init_max], floored at dt_init_floor
        dt = torch.exp(
            torch.rand(self.output_size) * (math.log(self.dt_init_max) - math.log(self.dt_init_min))
            + math.log(self.dt_init_min)
        )
        dt = torch.clamp(dt, min=self.dt_init_floor)
        # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759
        inv_dt = dt + torch.log(-torch.expm1(-dt))

        if isinstance(self.dt_bias, DTensor):
            inv_dt = tensor_to_dtensor(
                tensor=inv_dt,
                device_mesh=self.dt_bias.device_mesh,
                current_placement=[Replicate()] * len(self.dt_bias.placements),
                desired_placement=self.dt_bias.placements,
            )

        self.dt_bias.copy_(inv_dt)

        mark_parameter_as_initialized(self.A_log)
        mark_parameter_as_initialized(self.dt_bias)
```
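For orientation, here is a minimal standalone sketch of the decay math that `SoftplusDecayGate` implements, using plain PyTorch only; the repo-internal `ParameterizedLinear` projection and parameter-marking helpers are omitted, and the head count, shapes, and variable names are illustrative assumptions rather than values taken from this PR:

```python
# Standalone sketch of the SoftplusDecayGate decay computation (plain torch only).
import math

import torch
import torch.nn.functional as F

torch.manual_seed(0)

num_heads = 8                 # assumed gate width, for illustration only
batch_size, seq_len = 2, 4

# Initialization as in reset_parameters():
A = torch.empty(num_heads).uniform_(0.0, 16.0)        # A ~ U[A_init_min, A_init_max]
A_log = A.log()

dt = torch.exp(torch.rand(num_heads) * (math.log(0.1) - math.log(1e-3)) + math.log(1e-3))
dt = dt.clamp(min=1e-4)                               # dt_init_floor
dt_bias = dt + torch.log(-torch.expm1(-dt))           # inverse softplus: softplus(dt_bias) == dt

# Forward pass with final_exponential=True:
x = torch.randn(batch_size, seq_len, num_heads)       # gate inputs (post-projection)
decay = torch.exp(-A_log.exp() * F.softplus(x.float() + dt_bias))

print(decay.shape)                                    # torch.Size([2, 4, 8])
print(float(decay.min()), float(decay.max()))         # values lie in (0, 1]
```

With `final_exponential=True` the gate output lies in (0, 1] and acts as a multiplicative decay factor; with `final_exponential=False` the module instead returns the negative decay exponent `-exp(A_log) * softplus(x + dt_bias)`.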