-
Notifications
You must be signed in to change notification settings - Fork 0
Description
RMS NORM
The RMS normalization (Root Mean Square normalization) is a normalization technique used in neural networks. It normalizes the input by its root mean square. Here's an example of how you can implement RMS normalization in PyTorch:
RMS NORM CODE
import torch
from torch import nn
from pragna.components.module import Module, set_config_and_validate
from pragna.logger.logging import logger
class RMSNorm(Module):
    """Root-Mean-Square layer normalization.

    Normalizes the last dimension of the input by its RMS (computed in
    float32 for stability) and scales it by a learnable per-feature weight,
    offset by the configured `weight_bias_scalar`.
    """

    @set_config_and_validate
    def __init__(self, d_model, eps=1e-6, weight_bias_scalar=0, dtype=None, device=None, **kwargs):
        super().__init__()
        # eps keeps rsqrt finite when the mean square is zero
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(d_model, dtype=dtype, device=device))

    def forward(self, x):
        orig_dtype = x.dtype
        # mean of squares over the last (feature) dimension, in float32
        mean_sq = x.to(torch.float32).pow(2).mean(-1, keepdim=True)
        normed = x * torch.rsqrt(mean_sq + self.eps)
        scale = self.weight.float() + self.config.weight_bias_scalar
        return (scale * normed).to(orig_dtype)

    def _config_validation(self) -> bool:
        # No constraints beyond what __init__ already enforces.
        return True

    def extra_repr(self) -> str:
        return f'd_model={self.config.d_model}'

    def _init_params_fn(self):
        # Re-initialize the scale to ones, preserving the parameter's device.
        self.weight.data = torch.ones(self.config.d_model, dtype=self.config.dtype, device=self.weight.device)
Explanation
- Initialization (`__init__` method): `d_model` is the dimension of the input tensor's last axis; `eps` is a small epsilon value added before the reciprocal square root to avoid division by zero; `weight` is a learnable parameter that scales the normalized input (offset by `weight_bias_scalar`).
- Forward pass (`forward` method):
  - Compute the mean square of the input tensor along the last dimension (in float32).
  - Normalize the input tensor by multiplying it with the reciprocal root mean square.
  - Scale the normalized input by the learnable `weight` and cast back to the input dtype.
Tensor Passed Through the following components
Running main.py...
Log-Mel spectrogram shape: torch.Size([80, 279])
Log-Mel spectrogram:
tensor([[-0.6515, -0.6505, -0.6515, ..., -0.6515, -0.6515, -0.6515],
[-0.6515, -0.6515, -0.6515, ..., -0.6515, -0.6515, -0.6515],
[-0.6515, -0.6515, -0.6515, ..., -0.6515, -0.6515, -0.6515],
...,
[-0.6515, -0.6515, -0.6515, ..., -0.6515, -0.6515, -0.6515],
[-0.6515, -0.6515, -0.6515, ..., -0.6515, -0.6515, -0.6515],
[-0.6515, -0.6515, -0.6515, ..., -0.6515, -0.6515, -0.6515]])
Original mel spectrogram shape: torch.Size([80, 279])
Conv1d(80, 256, kernel_size=(3,), stride=(1,))
Conv1d(256, 256, kernel_size=(3,), stride=(1,))
Sinusoids()
LayerNorm()
Error:
Traceback (most recent call last):
  File "/home/rit/pragna/pragna/components/module.py", line 28, in __getitem__
    return self.c[item]
           ~~~~~~^^^^^^
KeyError: 'd_model'

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/home/rit/pragna/scripts/factory.py", line 4, in <module>
    from pragna.factory.model.model_factory import ModelFactory
  File "/home/rit/pragna/pragna/factory/model/model_factory.py", line 5, in <module>
    from pragna.components.module import Module
  File "/home/rit/pragna/pragna/components/__init__.py", line 5, in <module>
    import pragna.components.layer.register
  File "/home/rit/pragna/pragna/components/layer/register.py", line 3, in <module>
    from pragna.components.layer.encoder import Encoder
  File "/home/rit/pragna/pragna/components/layer/encoder.py", line 116, in <module>
    encoder_list = [EncoderDecoderBlock(input_layernorm=RMSNorm(256, eps=1e-6, weight_bias_scalar=0),
                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/rit/pragna/pragna/components/module.py", line 148, in __init__
    assert self.validate_config(), 'config validation failed'
           ^^^^^^^^^^^^^^^^^^^^^^
  File "/home/rit/pragna/pragna/components/module.py", line 69, in validate_config
    return self._config_validation()
           ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/rit/pragna/pragna/components/layer/encoder_decoder_block.py", line 53, in _config_validation
    return self.config.d_model % self.config.n_heads == 0
           ^^^^^^^^^^^^^^^^^^^
  File "/home/rit/pragna/pragna/components/module.py", line 24, in __getattr__
    return self.__getitem__(item)
           ^^^^^^^^^^^^^^^^^^^^^^
  File "/home/rit/pragna/pragna/components/module.py", line 33, in __getitem__
    raise Exception(f'{item} config not found')
Exception: d_model config not found
Files To refer:
encoder.py
`
import torch
import torch.nn as nn
import torch.nn.functional as F
from pragna.components.module import Module, set_config_and_validate
from pragna.logger.logging import logger
class Encoder(Module):
    """Audio encoder: two GELU-activated Conv1d layers followed by a stack of
    injected encoder blocks and a post layer norm.

    FIX(review): the pasted source read `def init` / `super().init()` —
    markdown rendering stripped the dunder underscores, so the constructor
    was never called as `__init__`. Restored to the dunder form. Dead
    commented-out code removed.
    """

    @set_config_and_validate
    def __init__(self, conv1, conv2, encoder_list, positional_embd, post_layernorm, device=None, **kwargs):
        super().__init__()
        self.conv1 = conv1
        self.conv2 = conv2
        self.encoder_list = encoder_list
        self.post_layernorm = post_layernorm
        # NOTE(review): stored but never applied in forward() — the
        # positional-embedding addition was commented out; confirm intent.
        self.positional_embd = positional_embd
        print("Executing encoder.py...")

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run input x through conv front-end, encoder blocks, and final norm."""
        print("Executing AudioEncoder forward pass... Input shape:", x.shape)
        x = F.gelu(self.conv1(x))
        print("After conv1 shape:", x.shape)
        x = F.gelu(self.conv2(x))
        print("After conv2 shape:", x.shape)
        for encoder in self.encoder_list:
            x = encoder(x)
            print("After block output shape:", x.shape)
        x = self.post_layernorm(x)
        print("Final output shape from AudioEncoder:", x.shape)
        return x

    def _config_validation(self) -> bool:
        # No constraints on the injected sub-modules.
        return True
`
# main.py — example driver for the audio Encoder.
# FIX(review): the pasted script lost its comment markers (bare headings such
# as "Example usage" are a SyntaxError) and its `import torch` line, which the
# script uses (torch.device / torch.float32 / torch.save). Both restored.
import torch

from pragna.components.common.AudioProcessor import AudioProcessor
from pragna.components.convolution.conv1d import Conv1d
from pragna.components.layer.encoder_decoder_block import EncoderDecoderBlock
from pragna.components.common.sinusoids import Sinusoids
from pragna.components.layernorm.layernorm import LayerNorm
from pragna.components.layernorm.rmsnorm import RMSNorm
from pragna.components.mlp.ffn import ExpansionFFN
from pragna.components.attention.multi_head_attention import MultiHeadAttention

# Example usage
print("Running main.py...")

# Initialize AudioProcessor and load/process audio
processor = AudioProcessor()
audio = processor.load_audio("/home/rit/pragna/pragna/components/common/h1.wav")
audio = processor.pad_or_trim(audio)

# Compute mel spectrogram
mel_spectrogram = processor.log_mel_spectrogram(audio, n_mels=80, device=torch.device('cpu'))
print("Log-Mel spectrogram shape:", mel_spectrogram.shape)
print("Log-Mel spectrogram:")
print(mel_spectrogram)

# Verify and process mel spectrogram
print("Original mel spectrogram shape:", mel_spectrogram.shape)
if len(mel_spectrogram.shape) != 2:
    raise ValueError("mel_spectrogram must have the shape (n_mels, length)")

# Add a batch dimension
mel_spectrogram = mel_spectrogram.unsqueeze(0)  # Shape: (1, n_mels, length)

# Hyper parameters
n_mels, n_state, n_heads, n_layer = 80, 256, 8, 6
n_kv_heads = 4
head_dim = 256 // n_heads
softmax_scale = None
positional_emb = None  # You can optionally pass a positional embedding module here
impl = 'torch'
bias = False
causal = True
dtype = torch.float32
device = 'cpu'  # Change to 'cuda' if using GPU
n_ctx = mel_spectrogram.shape[2]  # Use the length of mel_spectrogram for n_ctx

# Instantiate components
conv1 = Conv1d(in_channels=n_mels, out_channels=n_state, kernel_size=3, padding=1)
conv2 = Conv1d(in_channels=n_state, out_channels=n_state, kernel_size=3, stride=2, padding=1)
print(conv1)
print(conv2)
positional_embd = Sinusoids(n_ctx, n_state, 10000)
print(positional_embd)
post_layernorm = LayerNorm(normalized_shape=n_state, eps=1e-5, elementwise_affine=True)  # Example of using nn.LayerNorm for post_layernorm

# Initialize encoder blocks
print(post_layernorm)
# FIX(review): the reported crash ("Exception: d_model config not found") comes
# from EncoderDecoderBlock._config_validation reading self.config.d_model and
# self.config.n_heads, which were never passed here. The @set_config_and_validate
# decorator stores all kwargs in the config, so pass them explicitly.
# NOTE(review): n_layer = 6 is declared but only ONE block is built — confirm intent.
encoder_list = [EncoderDecoderBlock(d_model=n_state, n_heads=n_heads,
                                    input_layernorm=RMSNorm(256, eps=1e-6, weight_bias_scalar=0),
                                    attention=MultiHeadAttention(256, n_heads=n_heads, n_kv_heads=n_kv_heads, head_dim=head_dim, softmax_scale=softmax_scale, positional_emb=positional_emb, impl=impl, bias=bias, causal=causal, dtype=dtype, device=device),
                                    post_attn_layernorm=LayerNorm(n_state),
                                    cross_attention=MultiHeadAttention(256, n_heads=n_heads, n_kv_heads=n_kv_heads, head_dim=head_dim, softmax_scale=softmax_scale, positional_emb=positional_emb, impl=impl, bias=bias, causal=causal, dtype=dtype, device=device),
                                    ffn=ExpansionFFN(256))]

# Instantiate Encoder and process mel spectrogram
encoder = Encoder(conv1, conv2, encoder_list, positional_embd, post_layernorm)
output = encoder(mel_spectrogram)
print("Output shape from Encoder:", output.shape)

# Save output if needed
torch.save(output, 'encoder_output.pt')
print("Saved Encoder output to 'encoder_output.pt'")
`
module.py
`
import torch
import torch.nn as nn
from typing import Union
from functools import reduce
class ComponentConfig:
    """Dict-backed configuration container with attribute-style access.

    Missing keys raise from both access styles; `pretraining_tp` is
    special-cased to False for PEFT compatibility.
    """

    # NOTE(review): class-level mutable list shared by every instance
    # (HuggingFace-style convention) — never mutate it per instance.
    keys_to_ignore_at_inference = []

    def __init__(self):
        # Default config entries shared by all pragna components.
        self.c = {
            'model_type': 'pragna',
            'quantization_config': None,
            '_name_or_path': 'pragna',
            'tie_weights': None,
        }

    def append(self, **kwargs):
        """Merge the given keyword arguments into the config dict."""
        self.c.update(kwargs)

    def __getattr__(self, item):
        # Only invoked for names not found through normal attribute lookup.
        if item == 'keys':
            return self.c.keys
        try:
            return self.__getitem__(item)
        except Exception as err:
            # FIX: __getattr__ must raise AttributeError (not a bare Exception)
            # so hasattr(), getattr(..., default), copy and pickle work; the
            # original bare Exception broke all of those protocols.
            raise AttributeError(f'{item} config not found') from err

    def __getitem__(self, item):
        try:
            return self.c[item]
        except KeyError:
            # To support PEFT
            if item == 'pretraining_tp':
                return False
            raise Exception(f'{item} config not found')

    def __repr__(self) -> str:
        return str(self.c)

    def to_dict(self):
        """Return the underlying dict (NOTE: not a copy — mutations leak)."""
        return self.c
class Module(nn.Module):
    """Base class for pragna components: an nn.Module that carries a
    ComponentConfig and supports recursive param init and weight tying."""

    def __init__(self, **kwargs):
        super().__init__()
        # Attach a fresh config unless a subclass already created one.
        if getattr(self, 'config', None) is None:
            self.config = ComponentConfig()

    def get_num_trainable_params(self) -> int:
        """Return the total element count of parameters with requires_grad=True.

        FIX: the original also accumulated an all-parameter count (`all_param`)
        that was computed and then discarded — dead code removed.
        """
        return sum(p.numel() for p in self.parameters() if p.requires_grad)

    def set_config(self, **kwargs):
        """Merge kwargs into this component's config."""
        self.config.append(**kwargs)

    def get_config(self):
        return self.config

    def _config_validation(self) -> bool:
        # Subclasses must implement their own validation.
        raise NotImplementedError()

    def validate_config(self) -> bool:
        return self._config_validation()

    def init_params(self):
        """Depth-first init: children first, then this module's own params."""
        for m in self.children():
            if hasattr(m, 'init_params'):
                m.init_params()
        self._init_params_fn()
        return self

    @torch.no_grad()
    def _init_params_fn(self):
        # Default: no parameters of our own to initialize.
        pass

    def tie_weights(self, root_module: nn.Module):
        """Depth-first weight tying: children first, then this module."""
        for m in self.children():
            if hasattr(m, 'tie_weights'):
                m.tie_weights(root_module)
        self._tie_weights(root_module)
        return self

    @torch.no_grad()
    def _tie_weights(self, root_module: nn.Module):
        # config.tie_weights, when set, is a dotted path resolved against
        # root_module (e.g. 'decoder.embed') whose weight we share.
        if self.config.tie_weights:
            names = self.config.tie_weights.split('.')
            source_mod = reduce(getattr, names, root_module)
            # Tie the weights
            self.weight = source_mod.weight
            # TODO: Tie the bias as well
            # TODO: Clone the weights for scripts
def set_config_and_validate(func):
    """Decorator for component ``__init__`` methods.

    Wraps ``func`` so that, on construction: (1) the parent class ``__init__``
    is called with only the kwargs its signature accepts, (2) every declared
    parameter value — positional, keyword, or default — is folded into kwargs
    and stored in the component's config via ``set_config``, (3) the original
    ``func`` runs, and (4) ``validate_config()`` is asserted.
    """
    # def decorator(func):
    import inspect, functools
    argspec = inspect.signature(func)
    # argnames = argspec.args[1:]
    # Declared parameters of the wrapped __init__, including 'self' at index 0.
    params = list(argspec.parameters.values())

    @functools.wraps(func)
    def __init__(self, *args, **kwargs):
        # NOTE(review): super(self.__class__, ...) recurses infinitely if a
        # decorated class is subclassed further — confirm components are leaves.
        sig = inspect.signature(super(self.__class__, self).__init__)
        # Forward only the kwargs the parent __init__ actually declares.
        init_args = {k:kwargs[k] for k in sig.parameters.keys() if k in kwargs}
        super(self.__class__, self).__init__(**init_args)
        # Fold positional args and declared defaults into kwargs so the full
        # parameter set reaches the config. params[1:] skips 'self', so index i
        # lines up with positional args[i].
        for i, v in enumerate(params[1:]):
            if v.name == 'self':
                continue
            if v.name not in kwargs:
                if i < len(args):
                    kwargs[v.name] = args[i]
                else:
                    # Required parameter with no value supplied: skip — func
                    # itself will raise the missing-argument TypeError.
                    if v.default is inspect._empty:
                        continue
                    kwargs[v.name] = v.default
        if hasattr(self, 'set_config'):
            self.set_config(**kwargs)
        # Run the original __init__ with the fully-populated kwargs.
        func(self, **kwargs)
        if hasattr(self, 'validate_config'):
            # NOTE(review): assert is stripped under `python -O`; an explicit
            # raise would be safer for real validation.
            assert self.validate_config(), 'config validation failed'
    return __init__
    # return decorator
`