2 changes: 1 addition & 1 deletion cute-kernels
352 changes: 352 additions & 0 deletions examples/diffusion/diffusion-1b-24l.yml
@@ -0,0 +1,352 @@
datasets:
# data_name & data_sampling_ratio are not used by this class_name but must still be passed to avoid errors
- class_name: MegatronDataset
data_name: Megatron
data_sampling_ratio: 1
class_args:
eval_steps: 2
data_cache_path: /proj/checkpoints/shawntan/diffusion/release/data-cache
data_path:
- 1 # mixture ratio
- /proj/checkpoints/shawntan/diffusion/release/data/dclm-dedup-gpt2-tokenized/dclm_00_text # path prefix
split: 100,0,0
sequence_length: 4096 # context length


tokenizer_args:
tokenizer_name: /proj/checkpoints/shawntan/diffusion/release/data/tokenizer

kernel_args:
kernels:
- swiglu_packed_cute
- rmsnorm_cute
- scattermoe
- flash_attention_2

model_args:
model_class: AutoModelForCausalLM
pretrained_config:
initializer_range: 0.1
layer_norm_epsilon: 1e-05
model_type: diffusion
normalization_function: rmsnorm
position_embedding_type: rope
hidden_size: 2048
m_width: 8
m_emb: 12
m_residual: 0.28577380332470415
num_layers: 24
init_method: mup
tie_word_embeddings: true
router_aux_loss_coef: 0.01
bos_token_id: 50256 # ensure these match the tokenizer
eos_token_id: 50256
pad_token_id: 50258
vocab_size: 50259
max_position_embeddings: 4096
sequence_mixer_blocks:
- sequence_mixer_type: softmax_attention
causal: false
num_attention_heads: 16
num_key_value_heads: 16
add_bias: false
attention_multiplier: 0.0078125
- sequence_mixer_type: softmax_attention
causal: false
num_attention_heads: 16
num_key_value_heads: 16
add_bias: false
attention_multiplier: 0.0078125
- sequence_mixer_type: softmax_attention
causal: false
num_attention_heads: 16
num_key_value_heads: 16
add_bias: false
attention_multiplier: 0.0078125
- sequence_mixer_type: softmax_attention
causal: false
num_attention_heads: 16
num_key_value_heads: 16
add_bias: false
attention_multiplier: 0.0078125
- sequence_mixer_type: softmax_attention
causal: false
num_attention_heads: 16
num_key_value_heads: 16
add_bias: false
attention_multiplier: 0.0078125
- sequence_mixer_type: softmax_attention
causal: false
num_attention_heads: 16
num_key_value_heads: 16
add_bias: false
attention_multiplier: 0.0078125
- sequence_mixer_type: softmax_attention
causal: false
num_attention_heads: 16
num_key_value_heads: 16
add_bias: false
attention_multiplier: 0.0078125
- sequence_mixer_type: softmax_attention
causal: false
num_attention_heads: 16
num_key_value_heads: 16
add_bias: false
attention_multiplier: 0.0078125
- sequence_mixer_type: softmax_attention
causal: false
num_attention_heads: 16
num_key_value_heads: 16
add_bias: false
attention_multiplier: 0.0078125
- sequence_mixer_type: softmax_attention
causal: false
num_attention_heads: 16
num_key_value_heads: 16
add_bias: false
attention_multiplier: 0.0078125
- sequence_mixer_type: softmax_attention
causal: false
num_attention_heads: 16
num_key_value_heads: 16
add_bias: false
attention_multiplier: 0.0078125
- sequence_mixer_type: softmax_attention
causal: false
num_attention_heads: 16
num_key_value_heads: 16
add_bias: false
attention_multiplier: 0.0078125
- sequence_mixer_type: softmax_attention
causal: false
num_attention_heads: 16
num_key_value_heads: 16
add_bias: false
attention_multiplier: 0.0078125
- sequence_mixer_type: softmax_attention
causal: false
num_attention_heads: 16
num_key_value_heads: 16
add_bias: false
attention_multiplier: 0.0078125
- sequence_mixer_type: softmax_attention
causal: false
num_attention_heads: 16
num_key_value_heads: 16
add_bias: false
attention_multiplier: 0.0078125
- sequence_mixer_type: softmax_attention
causal: false
num_attention_heads: 16
num_key_value_heads: 16
add_bias: false
attention_multiplier: 0.0078125
- sequence_mixer_type: softmax_attention
causal: false
num_attention_heads: 16
num_key_value_heads: 16
add_bias: false
attention_multiplier: 0.0078125
- sequence_mixer_type: softmax_attention
causal: false
num_attention_heads: 16
num_key_value_heads: 16
add_bias: false
attention_multiplier: 0.0078125
- sequence_mixer_type: softmax_attention
causal: false
num_attention_heads: 16
num_key_value_heads: 16
add_bias: false
attention_multiplier: 0.0078125
- sequence_mixer_type: softmax_attention
causal: false
num_attention_heads: 16
num_key_value_heads: 16
add_bias: false
attention_multiplier: 0.0078125
- sequence_mixer_type: softmax_attention
causal: false
num_attention_heads: 16
num_key_value_heads: 16
add_bias: false
attention_multiplier: 0.0078125
- sequence_mixer_type: softmax_attention
causal: false
num_attention_heads: 16
num_key_value_heads: 16
add_bias: false
attention_multiplier: 0.0078125
- sequence_mixer_type: softmax_attention
causal: false
num_attention_heads: 16
num_key_value_heads: 16
add_bias: false
attention_multiplier: 0.0078125
- sequence_mixer_type: softmax_attention
causal: false
num_attention_heads: 16
num_key_value_heads: 16
add_bias: false
attention_multiplier: 0.0078125
mlp_blocks:
- mlp_type: MLP
activation_function: swiglu
intermediate_size: 4096
add_bias: false
- mlp_type: MLP
activation_function: swiglu
intermediate_size: 4096
add_bias: false
- mlp_type: MLP
activation_function: swiglu
intermediate_size: 4096
add_bias: false
- mlp_type: MLP
activation_function: swiglu
intermediate_size: 4096
add_bias: false
- mlp_type: MLP
activation_function: swiglu
intermediate_size: 4096
add_bias: false
- mlp_type: MLP
activation_function: swiglu
intermediate_size: 4096
add_bias: false
- mlp_type: MLP
activation_function: swiglu
intermediate_size: 4096
add_bias: false
- mlp_type: MLP
activation_function: swiglu
intermediate_size: 4096
add_bias: false
- mlp_type: MLP
activation_function: swiglu
intermediate_size: 4096
add_bias: false
- mlp_type: MLP
activation_function: swiglu
intermediate_size: 4096
add_bias: false
- mlp_type: MLP
activation_function: swiglu
intermediate_size: 4096
add_bias: false
- mlp_type: MLP
activation_function: swiglu
intermediate_size: 4096
add_bias: false
- mlp_type: MLP
activation_function: swiglu
intermediate_size: 4096
add_bias: false
- mlp_type: MLP
activation_function: swiglu
intermediate_size: 4096
add_bias: false
- mlp_type: MLP
activation_function: swiglu
intermediate_size: 4096
add_bias: false
- mlp_type: MLP
activation_function: swiglu
intermediate_size: 4096
add_bias: false
- mlp_type: MLP
activation_function: swiglu
intermediate_size: 4096
add_bias: false
- mlp_type: MLP
activation_function: swiglu
intermediate_size: 4096
add_bias: false
- mlp_type: MLP
activation_function: swiglu
intermediate_size: 4096
add_bias: false
- mlp_type: MLP
activation_function: swiglu
intermediate_size: 4096
add_bias: false
- mlp_type: MLP
activation_function: swiglu
intermediate_size: 4096
add_bias: false
- mlp_type: MLP
activation_function: swiglu
intermediate_size: 4096
add_bias: false
- mlp_type: MLP
activation_function: swiglu
intermediate_size: 4096
add_bias: false
- mlp_type: MLP
activation_function: swiglu
intermediate_size: 4096
add_bias: false


use_padding_free_transformer: true
# efficient_initialization: true
reset_attention_mask: true
reset_position_ids: true

tuning_args:
tuning_method: pretraining_diffusion

save_args:
save_path: /proj/checkpoints/shawntan/diffusion/release/data/diffusion-24l-1b
save_interval: 5000

# TODO: restore from the last checkpoint (uncomment load_args below)
# load_args:
# load_path: /proj/checkpoints/shawntan/diffusion/release/data/diffusion-24l-1b

logging_args:
log_interval: 10
# experiments_tracker_name: wandb
# wandb_args:
# project: diffusion-release
# name: diffusion-1b-24l


training_parameters:
num_training_steps: 75000
eval_interval: 1000000000
micro_batch_size: 2
gradient_accumulation_steps: 4
eval_during_training: false

optimizer_args:
params_group_method: mup
class_name: TorchAdamW
class_args:
lr: 0.01
weight_decay: 0.1
betas:
- 0.9
- 0.95
eps: 1e-10

lr_scheduler_args:
lr_decay_style: power
num_warmup_steps: 5000
num_constant_steps: 0
num_decay_steps: 70000
extra_lr_scheduler_args:
# 4 * global_batch_size (in sequences)
a: 4096
# constant
b: -0.51
# global_batch_size in number of tokens
c: 4194304

mixed_precision_args:
dtype: bf16

distributed_args:
fsdp_algorithm: 2
torch_compile: true
stage: 0
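A few arithmetic notes on the config above: attention_multiplier 0.0078125 is 1/128 = 1/head_dim (16 heads over hidden_size 2048), consistent with the 1/d attention scaling typically used with init_method: mup, and the block stack works out to roughly the 1B in the file name (24 layers x (4*2048^2 attention + 3*2048*4096 SwiGLU) ~= 1.0B, plus ~0.1B of tied embeddings). The power lr_decay_style is driven by the a, b, c values under extra_lr_scheduler_args; the sketch below is one plausible reading of that schedule (linear warmup to the base lr, then a power law in tokens seen). It is an assumption for illustration, not the engine's actual implementation, and it ignores num_constant_steps / num_decay_steps.

# Illustrative sketch only -- assumes the "power" decay style follows a
# Power-scheduler-like rule lr = min(base_lr, a * n**b) with n = tokens seen.
# The training engine's real implementation (and its handling of
# num_decay_steps) may differ.
def power_lr(step, base_lr=0.01, warmup_steps=5000,
             a=4096, b=-0.51, c=4_194_304):
    if step < warmup_steps:                  # linear warmup to the base lr
        return base_lr * step / warmup_steps
    n_tokens = step * c                      # c = global batch size in tokens
    return min(base_lr, a * n_tokens ** b)   # power-law decay in tokens seen

for s in (5_000, 25_000, 75_000):
    print(f"step {s:>6}: lr ~= {power_lr(s):.2e}")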
25 changes: 25 additions & 0 deletions examples/diffusion/diffusion.sh
@@ -0,0 +1,25 @@
#!/bin/bash
set -x
# Source dataset and base tokenizer on the Hugging Face Hub
DATASET="Zyphra/dclm-dedup"
BASE_TOKENIZER="openai-community/gpt2"

# Output locations for the tokenized data and the modified tokenizer
DATA_PATH="../data/"
mkdir -p $DATA_PATH
TRAIN_PATH="$DATA_PATH/dclm-dedup-gpt2-tokenized"
mkdir -p $TRAIN_PATH
TOKENIZER_PATH="$DATA_PATH/tokenizer"
mkdir -p $TOKENIZER_PATH

# Extend the base GPT-2 tokenizer with the extra special tokens expected by the config
python -u examples/diffusion/modify_tokenizer.py --tokenizer $BASE_TOKENIZER --output-path $TOKENIZER_PATH

# Tokenize one 20M-example slice of the train split; increment CHUNK for later slices
CHUNK=0
CHUNK_SIZE=20000000
START_IDX=$(($CHUNK * $CHUNK_SIZE))
END_IDX=$(($START_IDX + $CHUNK_SIZE))
SPLIT="train[$START_IDX:$END_IDX]"

# Tokenized output prefix; diffusion-1b-24l.yml points at this under data_path
OUTPUT_FILE="$TRAIN_PATH/dclm_$(printf '%02d' $CHUNK)"
python -u examples/diffusion/preprocess_data.py \
    --input $DATASET --split $SPLIT \
    --tokenizer $TOKENIZER_PATH \
    --output-prefix $OUTPUT_FILE \
    --workers 128 --chunk-size 8192 --append-eod
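Once modify_tokenizer.py has written the tokenizer, a quick check that it agrees with the ids hard-coded in diffusion-1b-24l.yml can save a failed run. The snippet below is a minimal sketch, assuming the tokenizer sits at the script's default $DATA_PATH location and loads with Hugging Face transformers; the expected ids are copied from the config, not from modify_tokenizer.py itself.

# Sanity check: the saved tokenizer must agree with the ids in the YAML config.
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("../data/tokenizer")  # TOKENIZER_PATH above

# Expected values copied from examples/diffusion/diffusion-1b-24l.yml
assert tok.bos_token_id == 50256, tok.bos_token_id
assert tok.eos_token_id == 50256, tok.eos_token_id
assert tok.pad_token_id == 50258, tok.pad_token_id
assert len(tok) <= 50259, len(tok)  # config vocab_size must cover every token id
print(f"tokenizer OK: {len(tok)} tokens")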