Question about the padding handling in train_beauty_sid_rec.py #8

@AgentEvan

Description

A question about the CustomDataCollator class: it sets the labels for the text before "<|im_start|>user" to -100 so those tokens are excluded from training, but the labels for the padding positions are not set to -100, which means the padding tokens are included in the loss computation. Is there a specific reason for this?
By comparison, the DataCollatorForLanguageModeling used in train_beauty_align.py explicitly sets label positions equal to pad_token_id to -100 in its torch_call method.

import torch
from typing import Any, Dict, List


class CustomDataCollator:
    def __init__(self, tokenizer, mlm=False):
        self.tokenizer = tokenizer
        self.mlm = mlm
        
    def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, Any]:
        input_ids = [feature["input_ids"] for feature in features]
        attention_mask = [feature["attention_mask"] for feature in features]

        max_length = max(len(ids) for ids in input_ids)

        padded_input_ids = []
        padded_attention_mask = []
        labels = []

        for i, (ids, mask) in enumerate(zip(input_ids, attention_mask)):
            padding_length = max_length - len(ids)
            padded_ids = ids + [self.tokenizer.pad_token_id] * padding_length
            padded_mask = mask + [0] * padding_length

            label = padded_ids.copy()  # padding positions keep pad_token_id as their label

            text = self.tokenizer.decode(ids, skip_special_tokens=False)
            user_start_pos = text.find("<|im_start|>user")

            if user_start_pos != -1:
                user_start_tokens = self.tokenizer.encode("<|im_start|>user", add_special_tokens=False)

                for j in range(len(ids) - len(user_start_tokens) + 1):
                    if ids[j:j+len(user_start_tokens)] == user_start_tokens:
                        for k in range(j):
                            label[k] = -100
                        break
                else:
                    for k in range(len(label)):
                        label[k] = -100
            else:
                for k in range(len(label)):
                    label[k] = -100

            padded_input_ids.append(padded_ids)
            padded_attention_mask.append(padded_mask)
            labels.append(label)
        
        return {
            "input_ids": torch.tensor(padded_input_ids, dtype=torch.long),
            "attention_mask": torch.tensor(padded_attention_mask, dtype=torch.long),
            "labels": torch.tensor(labels, dtype=torch.long),
        }
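If the intent were to exclude padding from the loss, the collator could mask the padded tail by position. A minimal sketch (hypothetical helper, not from the repo), relying on the fact that PyTorch's cross-entropy loss ignores targets equal to -100 by default:

```python
# Hypothetical helper: pad input_ids with pad_token_id, but give the padded
# tail the label -100 so CrossEntropyLoss (ignore_index=-100) skips it.
def pad_and_mask(ids, max_length, pad_token_id):
    padding_length = max_length - len(ids)
    padded_ids = ids + [pad_token_id] * padding_length
    labels = list(ids) + [-100] * padding_length  # mask only the padded tail
    return padded_ids, labels

padded, labels = pad_and_mask([5, 6, 7], max_length=5, pad_token_id=0)
# padded -> [5, 6, 7, 0, 0]; labels -> [5, 6, 7, -100, -100]
```

Masking by position also avoids touching any genuine pad_token_id token that might appear inside the sequence itself.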

I ran the data-processing part of the source code:
[screenshot: collator output]

How DataCollatorForLanguageModeling handles pad_token_id:
[screenshot: torch_call implementation]
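For reference, the masking in DataCollatorForLanguageModeling works by token value rather than by position: every label equal to pad_token_id becomes -100. A pure-Python sketch of that behavior (a simplification of torch_call, no torch dependency):

```python
# Simplified sketch of DataCollatorForLanguageModeling's label masking:
# labels start as a copy of input_ids, then every position whose token is
# pad_token_id is replaced with -100.
def mask_pad_labels(input_ids, pad_token_id):
    return [tok if tok != pad_token_id else -100 for tok in input_ids]

print(mask_pad_labels([5, 6, 7, 0, 0], pad_token_id=0))
# -> [5, 6, 7, -100, -100]
```

Note that this value-based masking also hits any real pad_token_id occurring inside the text, which is fine only when that id is reserved exclusively for padding.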
