Skip to content

RMSNORM Component error due to d_model parameter conflict while Testing Encoder #1

@Ritwika-Das-Gupta

Description

@Ritwika-Das-Gupta

RMS NORM

The RMS normalization (Root Mean Square normalization) is a normalization technique used in neural networks. It normalizes the input by its root mean square. Here's an example of how you can implement RMS normalization in PyTorch:

RMS NORM CODE

import torch
from torch import nn

from pragna.components.module import Module, set_config_and_validate
from pragna.logger.logging import logger

class RMSNorm(Module):

@set_config_and_validate
def __init__(self, d_model, eps=1e-6, weight_bias_scalar=0, dtype=None, device=None, **kwargs):
    super().__init__()

    self.eps = eps
    self.weight = nn.Parameter(
        torch.ones(d_model, dtype=dtype, device=device)
    )

def forward(self, x):

    input_dtype = x.dtype
    variance = x.to(torch.float32).pow(2).mean(-1, keepdim=True)
    x = x * torch.rsqrt(variance + self.eps)

    return ((self.weight.float() + self.config.weight_bias_scalar) * x).to(input_dtype)

def _config_validation(self) -> bool:
    return True

def extra_repr(self) -> str:
    return f'd_model={self.config.d_model}'

def _init_params_fn(self):
    self.weight.data = torch.ones(self.config.d_model, dtype=self.config.dtype, device=self.weight.device)

Explanation

  • Initialization (__init__ method):

    • dim: The dimension of the input tensor.
    • eps: A small epsilon value to avoid division by zero.
    • scale: A learnable parameter that scales the normalized input.
  • Forward pass (forward method):

    • Compute the root mean square (rms) of the input tensor along the last dimension.
    • Normalize the input tensor by dividing it by rms.
    • Scale the normalized input by the learnable parameter scale.

Tensor Passed Through the following components

Running main.py...
Log-Mel spectrogram shape: torch.Size([80, 279])
Log-Mel spectrogram:
tensor([[-0.6515, -0.6505, -0.6515, ..., -0.6515, -0.6515, -0.6515],
[-0.6515, -0.6515, -0.6515, ..., -0.6515, -0.6515, -0.6515],
[-0.6515, -0.6515, -0.6515, ..., -0.6515, -0.6515, -0.6515],
...,
[-0.6515, -0.6515, -0.6515, ..., -0.6515, -0.6515, -0.6515],
[-0.6515, -0.6515, -0.6515, ..., -0.6515, -0.6515, -0.6515],
[-0.6515, -0.6515, -0.6515, ..., -0.6515, -0.6515, -0.6515]])
Original mel spectrogram shape: torch.Size([80, 279])
Conv1d(80, 256, kernel_size=(3,), stride=(1,))
Conv1d(256, 256, kernel_size=(3,), stride=(1,))
Sinusoids()
LayerNorm()

Error:

Traceback (most recent call last):
File "/home/rit/pragna/pragna/components/module.py", line 28, in getitem
return self.c[item]
~~~~~~^^^^^^
KeyError: 'd_model'

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
File "/home/rit/pragna/scripts/factory.py", line 4, in
from pragna.factory.model.model_factory import ModelFactory
File "/home/rit/pragna/pragna/factory/model/model_factory.py", line 5, in
from pragna.components.module import Module
File "/home/rit/pragna/pragna/components/init.py", line 5, in
import pragna.components.layer.register
File "/home/rit/pragna/pragna/components/layer/register.py", line 3, in
from pragna.components.layer.encoder import Encoder
File "/home/rit/pragna/pragna/components/layer/encoder.py", line 116, in
encoder_list = [EncoderDecoderBlock(input_layernorm=RMSNorm(256, eps=1e-6, weight_bias_scalar=0),
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/rit/pragna/pragna/components/module.py", line 148, in init
assert self.validate_config(), 'config validation failed'
^^^^^^^^^^^^^^^^^^^^^^
File "/home/rit/pragna/pragna/components/module.py", line 69, in validate_config
return self._config_validation()
^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/rit/pragna/pragna/components/layer/encoder_decoder_block.py", line 53, in _config_validation
return self.config.d_model % self.config.n_heads == 0
^^^^^^^^^^^^^^^^^^^
File "/home/rit/pragna/pragna/components/module.py", line 24, in getattr
return self.getitem(item)
^^^^^^^^^^^^^^^^^^^^^^
File "/home/rit/pragna/pragna/components/module.py", line 33, in getitem
raise Exception(f'{item} config not found')
Exception: d_model config not found

Files To refer:

encoder.py

`
import torch
import torch.nn as nn
import torch.nn.functional as F
from pragna.components.module import Module, set_config_and_validate
from pragna.logger.logging import logger

class Encoder(Module):
@set_config_and_validate
# def init(self, n_mels: int, n_ctx: int, n_state: int, n_head: int, n_layer: int):
def init(self, conv1, conv2, encoder_list, positional_embd, post_layernorm, device=None, **kwargs):
super().init()

    self.conv1 = conv1
    self.conv2 = conv2
    self.encoder_list = encoder_list
    # self.positional_embd = positional_embd
    # self.activation = activation
    self.post_layernorm = post_layernorm
    self.positional_embd = positional_embd
    
    
    # self.conv1 = Conv1d(n_mels, n_state, kernel_size=3, padding=1)
    # self.conv2 = Conv1d(n_state, n_state, kernel_size=3, stride=2, padding=1)
    # self.register_buffer("positional_embedding", sinusoids(n_ctx // 2, n_state))
    # self.blocks: Iterable[ResidualAttentionBlock] = nn.ModuleList(
    #     [ResidualAttentionBlock(n_state=n_state, n_head=n_head, d_model=n_state) for _ in range(n_layer)]
    # )
    # self.ln_post = LayerNorm(n_state)
    print("Executing encoder.py...")

def forward(self, x: torch.Tensor) -> torch.Tensor:
    print("Executing AudioEncoder forward pass... Input shape:", x.shape)

    x = F.gelu(self.conv1(x))
    print("After conv1 shape:", x.shape)

    x = F.gelu(self.conv2(x))
    print("After conv2 shape:", x.shape)

    # x = x.permute(0, 2, 1)  # Shape: (batch_size, n_ctx // 2, n_state)
    # print("After permutation shape:", x.shape)

    # # Ensure positional embedding matches the shape
    
    # x = self.positional_embd(x).to(x.dtype)
    # print("After adding positional embedding shape:", x.shape)

    for encoder in self.encoder_list:
        x = encoder(x)
        print("After block output shape:", x.shape)

    x = self.post_layernorm(x)
    print("Final output shape from AudioEncoder:", x.shape)
    return x
    
def _config_validation(self) -> bool:
    return True

`

from pragna.components.common.AudioProcessor import AudioProcessor
from pragna.components.convolution.conv1d import Conv1d
from pragna.components.layer.encoder_decoder_block import EncoderDecoderBlock
from pragna.components.common.sinusoids import Sinusoids
from pragna.components.layernorm.layernorm import LayerNorm
from pragna.components.layernorm.rmsnorm import RMSNorm
from pragna.components.mlp.ffn import ExpansionFFN
from pragna.components.attention.multi_head_attention import MultiHeadAttention

Example usage

print("Running main.py...")

Initialize AudioProcessor and load/process audio

processor = AudioProcessor()
audio = processor.load_audio("/home/rit/pragna/pragna/components/common/h1.wav")
audio = processor.pad_or_trim(audio)

Compute mel spectrogram

mel_spectrogram = processor.log_mel_spectrogram(audio, n_mels=80, device=torch.device('cpu'))
print("Log-Mel spectrogram shape:", mel_spectrogram.shape)
print("Log-Mel spectrogram:")
print(mel_spectrogram)

Verify and process mel spectrogram

print("Original mel spectrogram shape:", mel_spectrogram.shape)
if len(mel_spectrogram.shape) != 2:
raise ValueError("mel_spectrogram must have the shape (n_mels, length)")

Add a batch dimension

mel_spectrogram = mel_spectrogram.unsqueeze(0) # Shape: (1, n_mels, length)

#Hyper parameters
n_mels, n_state, n_heads, n_layer = 80, 256, 8, 6
n_kv_heads = 4
head_dim = 256 // n_heads
softmax_scale = None
positional_emb = None # You can optionally pass a positional embedding module here
impl = 'torch'
bias = False
causal = True
dtype = torch.float32
device = 'cpu' # Change to 'cuda' if using GPU
n_ctx = mel_spectrogram.shape[2] # Use the length of mel_spectrogram for n_ctx

Instantiate components

conv1 = Conv1d(in_channels=n_mels, out_channels=n_state, kernel_size=3, padding=1)
conv2 = Conv1d(in_channels=n_state, out_channels=n_state, kernel_size=3, stride=2, padding=1)
print(conv1)
print(conv2)
positional_embd = Sinusoids(n_ctx, n_state,10000)
print(positional_embd)
post_layernorm = LayerNorm(normalized_shape=n_state, eps=1e-5, elementwise_affine=True) # Example of using nn.LayerNorm for post_layernorm

Initialize encoder blocks

print(post_layernorm)
encoder_list = [EncoderDecoderBlock(input_layernorm=RMSNorm(256, eps=1e-6, weight_bias_scalar=0),
attention=MultiHeadAttention(256,n_heads=n_heads,n_kv_heads=n_kv_heads,head_dim=head_dim,softmax_scale=softmax_scale,positional_emb=positional_emb,impl=impl,bias=bias,causal=causal,dtype=dtype,device=device),
post_attn_layernorm=LayerNorm(n_state),
cross_attention=MultiHeadAttention(256,n_heads=n_heads,n_kv_heads=n_kv_heads,head_dim=head_dim,softmax_scale=softmax_scale,positional_emb=positional_emb,impl=impl,bias=bias,causal=causal,dtype=dtype,device=device),
ffn=ExpansionFFN(256))]

Instantiate Encoder and process mel spectrogram

encoder = Encoder(conv1, conv2, encoder_list, positional_embd, post_layernorm)
output = encoder(mel_spectrogram)

print("Output shape from Encoder:", output.shape)

Save output if needed

torch.save(output, 'encoder_output.pt')
print("Saved Encoder output to 'encoder_output.pt'")

`

module.py

`
import torch
import torch.nn as nn
from typing import Union
from functools import reduce

class ComponentConfig:
keys_to_ignore_at_inference = []

def __init__(self):
    self.c  = {
        'model_type': 'pragna',
        'quantization_config': None,
        '_name_or_path': 'pragna',
        'tie_weights': None,
    }

def append(self, **kwargs):
    self.c.update(kwargs)

def __getattr__(self, item):
    if item == 'keys':
        return self.c.keys
    return self.__getitem__(item)

def __getitem__(self, item):
    try:
        return self.c[item]
    except KeyError:
        # To support PEFT
        if item == 'pretraining_tp':
            return False
        raise Exception(f'{item} config not found')

def __repr__(self) -> str:
    return str(self.c)

def to_dict(self):
    return self.c

class Module(nn.Module):

def __init__(self, **kwargs):
    super().__init__()

    if getattr(self, 'config', None) is None:
        self.config = ComponentConfig()

def get_num_trainable_params(self) -> int:
    trainable_params = 0
    all_param = 0
    for param in self.parameters():
        all_param += param.numel()
        if param.requires_grad:
            trainable_params += param.numel()
    return trainable_params
    # return sum(p.numel() for p in self.parameters())

def set_config(self, **kwargs):
    self.config.append(**kwargs)

def get_config(self):
    return self.config

def _config_validation(self) -> bool:
    raise NotImplementedError()

def validate_config(self) -> bool:
    return self._config_validation()

def init_params(self):
    

    # DFS to init children parameters
    for m in self.children():
        if hasattr(m, 'init_params'):
            m.init_params()

    # Init own parameters
    self._init_params_fn()

    # print('init for ' + str(self.__class__))

    return self

@torch.no_grad()
def _init_params_fn(self):
    pass

def tie_weights(self, root_module: nn.Module):
    
    # DFS to init children parameters
    for m in self.children():
        if hasattr(m, 'tie_weights'):
            m.tie_weights(root_module)

    # Init own parameters
    self._tie_weights(root_module)

    # print('init for ' + str(self.__class__))

    return self

@torch.no_grad()
def _tie_weights(self, root_module: nn.Module):
    if self.config.tie_weights:
        names = self.config.tie_weights.split('.')
        source_mod = reduce(getattr, names, root_module)

        # Tie the weights
        self.weight = source_mod.weight

        # TODO: Tie the bias as well
        # TODO: Clone the weights for scripts

def set_config_and_validate(func):
# def decorator(func):
import inspect, functools
argspec = inspect.signature(func)
# argnames = argspec.args[1:]
params = list(argspec.parameters.values())

@functools.wraps(func)
def __init__(self, *args, **kwargs):
    
    sig = inspect.signature(super(self.__class__, self).__init__)
    init_args = {k:kwargs[k] for k in sig.parameters.keys() if k in kwargs}
    super(self.__class__, self).__init__(**init_args)

    for i, v in enumerate(params[1:]):
        
        if v.name == 'self':
            continue

        if v.name not in kwargs:
            if i < len(args):
                kwargs[v.name] = args[i]
            else:
                if v.default is inspect._empty:
                    continue
                kwargs[v.name] = v.default
    
    if hasattr(self, 'set_config'):
        self.set_config(**kwargs)
    func(self, **kwargs)

    if hasattr(self, 'validate_config'):
        assert self.validate_config(), 'config validation failed'
        


return __init__
# return decorator

`

Metadata

Metadata

Assignees

No one assigned

    Labels

    bugSomething isn't working

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions