Torchtitan changes to integrate into Verl #2333
Conversation
```python
) -> int:
    # Skip initialization if already initialized
    if torch.distributed.is_initialized():
        return torch.distributed.get_world_size()
```
If we go down this path, a lot of the settings in this function / config won't take effect. Shall we add a warning for users?
Makes sense, will add a warning. But I think this also means the user has initialized the distributed env somewhere else with their own settings.
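For reference, a minimal sketch of what that warning could look like (the function name and the `comm_config` argument are illustrative, not torchtitan's exact API):

```python
import logging

import torch

logger = logging.getLogger(__name__)


def maybe_init_distributed(comm_config) -> int:
    # If the caller (e.g. verl) already initialized the process group with
    # its own settings, skip torchtitan's setup and tell the user that the
    # settings below will have no effect.
    if torch.distributed.is_initialized():
        logger.warning(
            "torch.distributed is already initialized; skipping torchtitan's "
            "distributed init. Settings in %s will not take effect.",
            comm_config,
        )
        return torch.distributed.get_world_size()
    torch.distributed.init_process_group(backend="nccl")
    return torch.distributed.get_world_size()
```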
```diff
- # linear warmup
- # 0-indexed step, hence + 1 adjustments
- current_step += 1
+ # linear warmup (0-indexed to match FSDP/HuggingFace)
```
We spent a lot of time converging to the current behavior.
This change will likely break unit tests and user code (a silent change).
Let's be careful about this change.
cc @wwwjn
#1284 has more context on why we converged to the current behavior (red line); see that PR's description. Can you explain, in a similar graph, how removing this current_step += 1 would affect the shape of the learning rate schedule?
Thanks! Turns out I don't need this change and can still achieve numeric parity; it was added during my debugging journey.
Sorry, I take that back: the change seems necessary. Only with this change does the LR schedule exactly match verl's FSDP implementation, and therefore the loss exactly matches. Otherwise torchtitan's LR is always one step ahead of the FSDP LR (FSDP uses a 0-based index but titan uses a 1-based index).
Wondering what the reason was for this +1 adjustment in the first place?
Synced with @wwwjn offline: moving from 1-indexed to 0-indexed likely won't change the shape of the LR schedule, it only shifts it by 1 step. But I will do more thorough testing to confirm the new LR schedule.
I understand that an LR schedule mismatch would cause a loss curve mismatch, and that aligning with verl's FSDP would show stronger numerical alignment.
What I don't understand is why you think verl is the gold standard. Could we either
- not change either side and bear with this difference, or
- change verl's LR scheduling?
If not, could you give me an argument that it is us who needs to change? Thanks!
Yes, this is a valid question; apologies for not doing enough research on this. I think we should aim for the exact same LR schedule and loss (although the loss difference is not large; it's still the same decreasing trend).
The pink curve is with titan's original LR schedule.
I asked an agent to do some research (https://docs.google.com/document/d/1YiFUvIa_JqTYpBd2Xj7ReH3Bw6wS07nKldycBX--uVE/edit?usp=sharing) and most frameworks use a 0-based index (current_step starts from 0). It would be great if titan also switched to 0-based, if it doesn't cause other issues. But I am also fine if we want to preserve the difference.
Having learning rate 0 at the first update sounds ... pointless?
We don't need to align with other frameworks if we don't understand their rationale. +1 that we don't want any step with lr = 0 (it basically wastes the forward and backward computation). Would it be possible to make the "offset by 1" adjustment in verl?
verl has this init_lr_ratio, so its LR schedule is 0-indexed with a configurable initial value. If init_lr_ratio is not set (the default case), then the first step's LR will be 0.
Megatron and DeepSpeed also adopt a similar 0-based index with a configurable min_lr (defaulting to 0).
We could also keep the LR schedules different, as an offset of 1 step likely won't affect model quality? wdyt?
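To make the two conventions concrete, here is a small illustrative sketch (not torchtitan's or verl's actual scheduler code) of the linear warmup factor under each indexing, with verl's configurable init_lr_ratio included:

```python
def warmup_factor_1_indexed(current_step: int, warmup_steps: int) -> float:
    # torchtitan today: the 0-indexed step is shifted by +1, so the very
    # first update already uses 1/warmup_steps of the max LR.
    current_step += 1
    return min(current_step / warmup_steps, 1.0)


def warmup_factor_0_indexed(
    current_step: int, warmup_steps: int, init_lr_ratio: float = 0.0
) -> float:
    # verl / Megatron / DeepSpeed style: 0-indexed with a configurable
    # starting ratio; with the default init_lr_ratio=0, step 0 runs at
    # LR 0 (the "wasted" forward/backward discussed above).
    if current_step >= warmup_steps:
        return 1.0
    return init_lr_ratio + (1.0 - init_lr_ratio) * current_step / warmup_steps


# With warmup_steps=8:
#   1-indexed: 1/8, 2/8, ..., 8/8 at steps 0..7
#   0-indexed: 0,   1/8, ..., 7/8 at steps 0..7 -- same shape, shifted by one
```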
torchtitan/train.py (Outdated)
```python
    # The returned loss here is local SUM loss / global_valid_tokens
    return loss


def forward_step(
```
We might eventually need this as we build RL ourselves, but for now, can we put this in the verl engine?
torchtitan/models/attention.py (Outdated)
```python
    return (position_diff != 1).cumsum(-1)  # [batch, seq]


def create_sdpa_document_causal_mask(positions: torch.Tensor) -> torch.Tensor:
```
Yes, @fegin suggested supporting the document mask only for flex and varlen attention; I will update the PR.
+1
```python
    ) -> AttentionMasksType:
        match self.model_args.attn_type:
            case "sdpa":
                assert extra_inputs is not None and "positions" in extra_inputs
```
positions won't always be available -- wouldn't this break torchtitan in general?
But anyway, we probably don't want to do SDPA with positions; see the other comment.
torchtitan/models/attention.py (Outdated)
```python
) -> _mask_mod_signature:
    """Creates a document mask from position_ids for flex attention.

    Detects document boundaries where position_ids reset (diff != 1).
```
Let's be careful here.
Previously in torchtitan, position_ids was used to index the RoPE cache (for CP to work with a sharded sequence and a replicated RoPE cache), so it was decoupled from the attention block mask decision (which is determined by eos_id in the inputs today).
From what you are trying to do, it sounds like verl / HF couple position_ids with mask creation? IIUC it only works for block-causal masking, assuming the block boundary is given by position_ids. In particular, if the masking is more complicated than block-causal, e.g. in multimodal training where attention is bidirectional within an image, it can't be expressed using position_ids?
So now there will be two ways to create block-causal masking, one by eos_id and the other by position_ids, is that correct?
cc @fegin do you think it's fine to create another model_args.attn_mask_type for this?
Verl's FSDP engine uses a transformers model, whose forward takes in position_ids and generates a block-causal mask based on them. Therefore, apart from get_document_mask_mod, which generates a block_mask from eos_id, I added another get_document_mask_mod_from_positions.
Should we make this generic over a <separator>, where the separator can be eos_id or position_ids?
@drisspg position_ids is not a separator. You need a different function to translate it to a mask mod.
Yeah, fair. One could also probably just have a base doc_mask(, segments) that does the check for equal segments, but I don't think it adds too much value, since you still need the find_packed_sequence_indices.
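For concreteness, here is a hedged reconstruction of a positions-based document mask mod, pieced together from the diff fragments above (the PR's actual implementation may differ):

```python
import torch


def get_document_mask_mod_from_positions(positions: torch.Tensor):
    # positions: [batch, seq]. In a packed batch, a new document starts
    # wherever the position values stop increasing by exactly 1.
    position_diff = torch.diff(positions, dim=-1, prepend=positions[:, :1] - 1)
    sequence_indices = (position_diff != 1).cumsum(-1)  # [batch, seq]

    def document_causal_mask(b, h, q_idx, kv_idx):
        # Attend only within the same document, and only causally.
        same_doc = sequence_indices[b, q_idx] == sequence_indices[b, kv_idx]
        return same_doc & (q_idx >= kv_idx)

    return document_causal_mask
```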
```python
3.0 / 8.0,  # Step 7: 3/8 of max LR
2.0 / 8.0,  # Step 8: 1/4 of max LR
1.0 / 8.0,  # Step 9: 1/8 of max LR
0.0,        # Step 0: 0% of max LR (warmup start)
```
+1, this step is not updating weights at all; the forward and backward computation are wasted.
```python
                input_ids=input_batch, eos_id=tokenizer.eos_id
            )
        )
    case "position_block_causal":
```
So in "position_block_causal", the EOS id is not used to create the block mask, and the user might accidentally use both without knowing which one is actually taking effect?
If they both refer to "block_causal", one possible approach is to let get_document_mask_mod take both the EOS id and positions, and add a warning when both are not None, letting the user know which field we are using to create the "block_causal" mask.
I'm curious to know more about the actual data format (the Trajectory) that the generator passes to the trainer. For the following fields:
- Prompt + Completion: does this field have EOS in it? If so, can we use the EOS id instead of position_id?
And for the trajectory, do they do padding or packing? If padding, would the position_id field also be padded? If yes, does the mask_mod algorithm handle the padded field correctly?
The user would have to specify either block_causal (with eos id) or position_block_causal (using positions); they can't use both at the same time.
> Prompt + Completion: does this field have EOS in it? If so, can we use the EOS id instead of position_id?

I checked: the input prompt + completion does have EOS in it, but it doesn't correspond to positions. It looks like there is one more newline token after the EOS for each sample.
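As a sketch of the single-entry-point alternative suggested earlier in this thread (illustrative only; the PR ultimately keeps block_causal and position_block_causal as separate attn mask types, and get_document_mask_mod_from_eos is a hypothetical name):

```python
import logging
from typing import Optional

import torch

logger = logging.getLogger(__name__)


def get_document_mask_mod(
    input_ids: Optional[torch.Tensor] = None,
    eos_id: Optional[int] = None,
    positions: Optional[torch.Tensor] = None,
):
    # Warn when both sources of boundary information are supplied, so the
    # user knows which one actually takes effect.
    if positions is not None and eos_id is not None:
        logger.warning(
            "Both eos_id and positions were provided; building the "
            "block-causal mask from positions and ignoring eos_id."
        )
    if positions is not None:
        return get_document_mask_mod_from_positions(positions)
    assert input_ids is not None and eos_id is not None
    return get_document_mask_mod_from_eos(input_ids, eos_id)  # hypothetical
```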
torchtitan/models/attention.py (Outdated)
```python
    cumulative_mask = torch.cumsum(torch.where(eos_mask, 1, 0), dim=1)
    sequence_indices = torch.zeros_like(cumulative_mask, dtype=torch.int32)
    sequence_indices[:, 1:] = cumulative_mask[:, :-1]
    if positions is not None:
```
Would this algorithm work for inference-time continuous batching, with mixed prefill and decode requests? E.g., the positions could be [4, 5, 6, 0, 1, 2, 3, 7, 8, 9, 10].
And does it handle padding correctly?
Yes, it detects sequence boundaries through position_diff != 1, so it will give [0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2].
It will not handle padding though (each padding token is marked as a separate sequence), as we don't expect any padding and use packing.
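A quick check of that claim with the boundary-detection rule above (plain PyTorch, assuming the position_diff != 1 convention):

```python
import torch

positions = torch.tensor([[4, 5, 6, 0, 1, 2, 3, 7, 8, 9, 10]])
# Prepend positions[:, :1] - 1 so the first token never counts as a boundary.
diff = torch.diff(positions, dim=-1, prepend=positions[:, :1] - 1)
print((diff != 1).cumsum(-1))
# tensor([[0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2]])
```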
```python
        tokenizer=self.tokenizer,
        job_config=job_config,
    )
    if self.train_spec.build_dataloader_fn is not None
```
According to https://github.com/pytorch/torchtitan/blob/main/torchtitan/protocols/train_spec.py#L51, it can't be None.
I'm OK with the type change, but you'll need to assert not None in torchtitan before it's used?
Since we are using verl's dataloader, I don't want to initialize titan's dataloader (otherwise I'd hit an error loading the c4 dataset). I did a hack here: https://github.com/verl-project/verl/pull/5051/changes#diff-f658afe18d14b480f4067f7544fbdb0ef6962a20ef3b5f5d0c709ae31e91809dR101
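For reference, a minimal sketch of the guard being discussed, assuming build_dataloader_fn becomes Optional (the names here are illustrative, not the actual torchtitan code):

```python
from typing import Callable, Optional


def maybe_build_dataloader(
    build_dataloader_fn: Optional[Callable], **dataloader_kwargs
):
    # When an external trainer (e.g. verl) supplies its own dataloader,
    # build_dataloader_fn is None and torchtitan builds nothing.
    if build_dataloader_fn is None:
        return None
    return build_dataloader_fn(**dataloader_kwargs)


def run_training(dataloader):
    # Assert before use, so a None dataloader fails loudly inside torchtitan
    # rather than with an opaque iteration error.
    assert dataloader is not None, (
        "build_dataloader_fn was None and no external dataloader was provided"
    )
    for batch in dataloader:
        ...
```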
Goal: This PR makes the changes needed to integrate Torchtitan as a trainer into Verl: verl-project/verl#5051
Major changes:
- Change the LR schedule to be 0-indexed instead of 1-indexed, to align with Verl's FSDP util ==> we decided not to change Titan's LR scheduler behavior; see more analysis in https://docs.google.com/document/d/1YiFUvIa_JqTYpBd2Xj7ReH3Bw6wS07nKldycBX--uVE/edit?usp=sharing
- Add a position_block_causal attn mask type, which creates a block-causal mask based on position_id for both varlen and flex attention (transformers reference) ==> this is added in Verl's Torchtitan Engine code instead
Todos:
- pp_schedule.eval() does the microbatch split for us, as it takes in the whole batch. However, in verl we split the batch into microbatches before PP, and we'd love to pass a list of pre-split microbatches to the PP schedule. (Thanks for @H-Huang's help.)