Skip to content

Commit

Permalink
Added xglm transformers implementation
Browse files Browse the repository at this point in the history
Co-authored-by: Negar Foroutan <negar.foroutan@epfl.ch>
  • Loading branch information
AleHD and negar-foroutan committed Sep 5, 2024
1 parent 4759c55 commit 36ca804
Show file tree
Hide file tree
Showing 5 changed files with 1,275 additions and 23 deletions.
2 changes: 0 additions & 2 deletions examples/xglm/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -30,5 +30,3 @@ To save back to huggingface format use
```bash
torchrun examples/xglm/convert_ntmoe2hf.py --checkpoint-path=$SCRATCH/checkpoints/xglm-8x564M --save-path=$SCRATCH/checkpoints/huggingface/xglm-8x564M
```

Make sure to have the [XGLM MOE implementation](https://github.com/negar-foroutan/Multilingual_MoE) installed (e.g. using `PYTHONPATH=/path/to/Multilingual_MoE`).
23 changes: 5 additions & 18 deletions examples/xglm/convert_ntmoe2hf.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,16 +20,11 @@
from examples.xglm.convert_dense2moe import create_nt_moe_model
from examples.xglm.convert_nt2hf import convert_attention
from examples.xglm.convert_utils import convert_generic

from models.xglm_model import XGLMForCausalLM, XGLMDecoderLayer, XGLMmoeConfig, XGLMSparseMoeBlock, XGLMMLP
from models.gating import BasicGate

# TODO: nanotron moe scales down the moe weights but hf doesn't
# TODO: nanotron does not use pdrop in moe.
from examples.xglm.transformers_impl.xglm_model import XGLMForCausalLM, XGLMDecoderLayer, XGLMmoeConfig, XGLMSparseMoeBlock, XGLMMLP
from examples.xglm.transformers_impl.gating import BasicGate


def convert_config(config: GPT3MoEConfig) -> XGLMmoeConfig:
#assert config.moe_num_experts > 1, f"Why are you using a 1-expert moe? lol"
if config.embd_pdrop != config.resid_pdrop:
warnings.warn(
f"nanotron.embd_pdrop = {config.embd_pdrop} does not match with "
Expand Down Expand Up @@ -68,7 +63,6 @@ def convert_config(config: GPT3MoEConfig) -> XGLMmoeConfig:


def convert_mlp(mlp_hf: XGLMMLP, mlp_nt: SparseMLP):
# TODO: mlp_hf has non-zero bias.
convert_generic(mlp_hf.fc1, mlp_nt.w1.module)
convert_generic(mlp_hf.fc2, mlp_nt.w2.module)

Expand All @@ -88,7 +82,6 @@ def convert_ff(ff_hf: XGLMSparseMoeBlock, ff_nt: dMoE):
assert ff_nt.experts.mlp.w2.module.weight.shape == (int_size*len(ff_hf.experts), ff_nt.config.hidden_size)

for i, expert_hf in enumerate(ff_hf.experts):
# TODO: fc1, fc2 has bias
i0 = i*int_size
i1 = (i + 1)*int_size
with torch.no_grad():
Expand All @@ -98,24 +91,19 @@ def convert_ff(ff_hf: XGLMSparseMoeBlock, ff_nt: dMoE):
else:
expert_hf.fc1.weight.copy_(ff_nt.experts.mlp.w1.module.weight.T[i0:i1, :].clone())
expert_hf.fc2.weight.copy_(ff_nt.experts.mlp.w2.module.weight[i0:i1, :].T.clone())
expert_hf.fc1.bias.data.zero_()
expert_hf.fc2.bias.data.zero_()

def convert_decoder(block_hf: XGLMDecoderLayer, block_nt: GPT3MoEBlock):
convert_generic(block_hf.self_attn_layer_norm, block_nt.ln_1)
convert_attention(block_hf.self_attn, block_nt.attn)
convert_generic(block_hf.final_layer_norm, block_nt.ln_2)
# TODO: hf has fc1, fc2 attributes but they are not used, probably should be removed.
#return block_nt.ff
convert_ff(block_hf.block_sparse_moe, block_nt.ff) # REMOVE
convert_ff(block_hf.block_sparse_moe, block_nt.ff)


def convert(model_hf: XGLMForCausalLM, model_nt: GPT3MoEForTraining):
convert_generic(model_hf.model.embed_tokens, model_nt.model.token_embeddings.pp_block.token_embedding)
for layer_hf, layer_nt in tqdm(zip(model_hf.model.layers, model_nt.model.decoder), desc="Converting layers",
total=model_nt.config.num_hidden_layers):
#return convert_decoder(layer_hf, layer_nt.pp_block)
convert_decoder(layer_hf, layer_nt.pp_block) # REMOVE
convert_decoder(layer_hf, layer_nt.pp_block)
convert_generic(model_hf.model.layer_norm, model_nt.model.final_layer_norm.pp_block)
convert_generic(model_hf.lm_head, model_nt.model.lm_head.pp_block)

Expand All @@ -133,8 +121,7 @@ def main(checkpoint_path: Path, save_path: Path, tokenizer_name: Optional[str]):
tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
tokenizer.save_pretrained(save_path)
states = torch.randn(4, 1, 1024)
#return convert(model_hf, model_nt), states.cuda().bfloat16()
convert(model_hf, model_nt), states.cuda().bfloat16() # REMOVE
convert(model_hf, model_nt), states.cuda().bfloat16()
print("Saving...")
model_hf.save_pretrained(save_path)
print(f"Model saved to {save_path}")
Expand Down
5 changes: 2 additions & 3 deletions examples/xglm/tests/test_moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,8 @@

from examples.xglm.convert_ntmoe2hf import convert_config, convert_gate, convert_ff, convert
from examples.xglm.tests.test_implementation import almost_close

from models.xglm_model import XGLMSparseMoeBlock, XGLMForCausalLM
from models.gating import BasicGate
from examples.xglm.transformers_impl.xglm_model import XGLMSparseMoeBlock, XGLMForCausalLM
from examples.xglm.transformers_impl.gating import BasicGate


MAX_SEQUENCE_LENGTH = 2048
Expand Down
149 changes: 149 additions & 0 deletions examples/xglm/transformers_impl/gating.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,149 @@
import torch
import torch.nn as nn
from torch.nn import functional as F
import math

from abc import ABC, abstractmethod


class Gate(ABC):
    """Abstract base for gates that post-process routing scores.

    This is a plain Python object (not an ``nn.Module``); it only records the
    device handed to it and obliges subclasses to provide ``compute``.
    """

    def __init__(self, device):
        """Store ``device`` unchanged on the instance.

        Args:
            device: Kept verbatim as ``self.device``; this base class never
                interprets it.
        """
        super().__init__()
        self.device = device

    @abstractmethod
    def compute(self, x):
        """Compute the output of the gate.

        Every concrete subclass has to override this method; the base
        implementation does nothing and returns ``None``.
        """


def init_x_embeddings(Xs, x_embedding_dim):
    """Create one learnable, normally-initialised embedding per entry of ``Xs``.

    Args:
        Xs: Iterable of identifiers; each one is converted with ``str`` to
            form its key in the returned container.
        x_embedding_dim: Size forwarded as-is to ``torch.empty`` for every
            embedding.

    Returns:
        An ``nn.ParameterDict`` mapping ``str(x)`` to an ``nn.Parameter``
        filled by ``nn.init.normal_`` with its default arguments.
    """
    table = nn.ParameterDict()
    for item in Xs:
        # normal_ fills the fresh buffer in place and hands the same tensor back.
        values = nn.init.normal_(torch.empty(x_embedding_dim))
        table[str(item)] = nn.Parameter(values)
    return table


class BasicGate(nn.Module):
    """Language-agnostic router: a one- or two-layer feed-forward gate.

    The layout is selected by ``config.gate_depth``:

    * ``1`` - a single bias-free ``nn.Linear`` from ``hidden_size`` to
      ``num_local_experts``.
    * ``2`` - ``Linear(hidden_size, ffn_dim)`` -> ReLU ->
      bias-free ``Linear(ffn_dim, num_local_experts)``.

    Args:
        config: Object exposing ``hidden_size``, ``num_local_experts``,
            ``ffn_dim`` and ``gate_depth``.

    Raises:
        ValueError: If ``config.gate_depth`` is neither 1 nor 2.
    """

    def __init__(self, config) -> None:
        super().__init__()

        self.hidden_dim = config.hidden_size
        self.num_experts = config.num_local_experts
        self.ffn_dim = config.ffn_dim
        # nn.ReLU's only constructor argument is ``inplace``; handing it the
        # layer width would silently switch on in-place mode, so pass nothing.
        self.activation = nn.ReLU()

        if config.gate_depth == 1:
            self.gate = nn.Linear(self.hidden_dim, self.num_experts, bias=False)
        elif config.gate_depth == 2:
            self.gate = nn.Sequential(
                nn.Linear(self.hidden_dim, self.ffn_dim),
                self.activation,
                nn.Linear(self.ffn_dim, self.num_experts, bias=False),
            )
        else:
            raise ValueError("Invalid gate_depth!")

    def forward(self, x, lang_name=None):
        """Return the raw per-expert scores for ``x``.

        Args:
            x: Tensor whose last dimension is ``hidden_size``.
            lang_name: Ignored by this gate. It is accepted so the call
                signature matches gates that do use the language, and it
                defaults to ``None`` so callers may omit it.

        Returns:
            ``self.gate(x)``; no softmax or top-k is applied here.
        """
        return self.gate(x)


class LanguageAwareGate(nn.Module):
    """Router that conditions expert scores on a learned language embedding.

    One embedding per entry of ``config.languages`` is created through
    ``init_x_embeddings``. In ``forward`` the embedding selected by
    ``lang_name`` is concatenated to the input along the last dimension, so
    the gate's first layer takes ``hidden_size + language_embedding_dim``
    input features for both supported depths.

    Args:
        config: Object exposing ``hidden_size``, ``num_local_experts``,
            ``ffn_dim``, ``gate_depth``, ``languages`` and
            ``language_embedding_dim`` (``None`` falls back to
            ``hidden_size``).

    Raises:
        ValueError: If ``config.gate_depth`` is neither 1 nor 2.
    """

    def __init__(self, config) -> None:
        super().__init__()

        self.hidden_dim = config.hidden_size
        self.num_experts = config.num_local_experts
        self.ffn_dim = config.ffn_dim
        # nn.ReLU's only constructor argument is ``inplace``; handing it the
        # layer width would silently switch on in-place mode, so pass nothing.
        self.activation = nn.ReLU()
        self.language_embedding_dim = (
            config.language_embedding_dim
            if config.language_embedding_dim is not None
            else config.hidden_size
        )
        self.lang_embeddings = init_x_embeddings(
            config.languages, self.language_embedding_dim
        )

        # forward() feeds [x ; language embedding], so every first layer must
        # accept the concatenated width (the two-layer variant included).
        gate_input_dim = self.hidden_dim + self.language_embedding_dim

        if config.gate_depth == 1:
            self.gate = nn.Linear(gate_input_dim, self.num_experts, bias=False)
        elif config.gate_depth == 2:
            self.gate = nn.Sequential(
                nn.Linear(gate_input_dim, self.ffn_dim),
                self.activation,
                nn.Linear(self.ffn_dim, self.num_experts, bias=False),
            )
        else:
            raise ValueError("Invalid gate_depth!")

    def forward(self, x, lang_name):
        """Score the experts for ``x`` given the language ``lang_name``.

        Args:
            x: Tensor whose last dimension is ``hidden_size``; any number of
                leading dimensions is accepted.
            lang_name: Language identifier; ``str(lang_name)`` must be a key
                of ``self.lang_embeddings`` (a ``KeyError`` is raised
                otherwise).

        Returns:
            ``self.gate`` applied to ``x`` concatenated with the language
            embedding; no softmax or top-k is applied here.
        """
        lang_embedding = self.lang_embeddings[str(lang_name)]
        # Broadcast the 1-D embedding over every leading dimension of x so the
        # concatenation works for (tokens, hidden) and (batch, seq, hidden) alike.
        lang_embedding = lang_embedding.expand(*x.shape[:-1], -1)
        x = torch.cat((x, lang_embedding), dim=-1)
        return self.gate(x)


class TopKGate(Gate):
    """Top-k gate that turns per-expert scores into combine weights.

    NOTE(review): the ``dim=1`` softmax and the ``sum(1)`` reduction suggest
    ``x`` is expected to be 2-D (tokens, experts) - confirm with callers.
    """

    def __init__(self, device, straight_through, k=1):
        """Record the gate settings.

        Args:
            device: Forwarded to ``Gate.__init__`` and stored on the instance.
            straight_through: Truthy to take the straight-through branch of
                the ``k == 1`` path in ``compute``.
            k: Number of experts selected per row.
        """
        super().__init__(device)
        self.k = k
        self.device = device
        self.straight_through = straight_through

    def compute(self, x):
        """Select the top-k experts and build the combine tensor.

        Args:
            x: Router scores with one column per expert on the last dimension.

        Returns:
            ``(combine_tensor, indices, scores)`` when ``k >= 1``. For
            ``k > 1`` the scores are a softmax over the selected raw values;
            for ``k == 1`` the softmax is taken over all experts before the
            selection. ``None`` is returned when ``k`` is neither above 1 nor
            equal to 1.
        """
        num_experts = x.shape[-1]

        if self.k > 1:
            top_scores, indices = torch.topk(x, self.k)
            # Normalise in float32, then cast back to the input dtype.
            top_scores = F.softmax(top_scores, dim=1, dtype=torch.float).type_as(x)
            one_hot = F.one_hot(indices, num_experts)
            selected = one_hot.float().sum(dim=-1)
            combine_tensor = (
                top_scores[..., None, None, None]
                * selected[..., None, None, None]
                * one_hot[..., None, None]
            ).sum(1)
            return combine_tensor, indices, top_scores

        if self.k != 1:
            # No branch handles this k; the method yields None.
            return None

        probs = F.softmax(x, dim=-1)
        top_scores, index = probs.topk(k=self.k, dim=-1)
        if self.straight_through:
            # NOTE(review): this mixes integer expert indices with
            # probabilities and leaves ``index`` as a float tensor, which
            # F.one_hot below appears to reject - confirm the intended
            # straight-through formulation before relying on this branch.
            soft = F.softmax(probs, dim=-1)
            index = (index - soft).detach() + soft
            index = index[:, 0]
        top_scores = top_scores.squeeze(dim=-1)
        index = index.squeeze(dim=-1)

        one_hot = F.one_hot(index, num_experts)
        selected = one_hot.float().sum(dim=-1)
        combine_tensor = (
            top_scores[..., None, None, None]
            * selected[..., None, None, None]
            * one_hot[..., None, None]
        )
        return combine_tensor, index, top_scores
Loading

0 comments on commit 36ca804

Please sign in to comment.