
Questions about the code in train_beauty_sid_rec.py #6

@mymymynameisisislry

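Hi, authors, thanks for the great work!
I noticed that when the code processes the data, it masks out part of each sequence when building the labels for prediction: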

```python
from typing import Any, Dict, List

import torch


class CustomDataCollator:
    def __init__(self, tokenizer, mlm=False):
        self.tokenizer = tokenizer
        self.mlm = mlm  # stored but not used below

    def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, Any]:
        input_ids = [feature["input_ids"] for feature in features]
        attention_mask = [feature["attention_mask"] for feature in features]

        # Pad every sequence in the batch to the longest one
        max_length = max(len(ids) for ids in input_ids)

        padded_input_ids = []
        padded_attention_mask = []
        labels = []

        for i, (ids, mask) in enumerate(zip(input_ids, attention_mask)):
            padding_length = max_length - len(ids)
            padded_ids = ids + [self.tokenizer.pad_token_id] * padding_length
            padded_mask = mask + [0] * padding_length

            # Labels start as a copy of the inputs; padding positions keep
            # pad_token_id and are never set to -100 below
            label = padded_ids.copy()

            text = self.tokenizer.decode(ids, skip_special_tokens=False)
            user_start_pos = text.find("<|im_start|>user")

            if user_start_pos != -1:
                user_start_tokens = self.tokenizer.encode("<|im_start|>user", add_special_tokens=False)

                # Find the first token position where "<|im_start|>user" begins
                for j in range(len(ids) - len(user_start_tokens) + 1):
                    if ids[j:j+len(user_start_tokens)] == user_start_tokens:
                        # Mask everything before it (the system prompt) with -100
                        for k in range(j):
                            label[k] = -100
                        break
                else:
                    # for-else: the marker appears in the decoded text but not as
                    # this exact token subsequence, so the whole sequence is masked
                    for k in range(len(label)):
                        label[k] = -100
            else:
                # No user turn at all: mask the whole sequence
                for k in range(len(label)):
                    label[k] = -100

            padded_input_ids.append(padded_ids)
            padded_attention_mask.append(padded_mask)
            labels.append(label)

        return {
            "input_ids": torch.tensor(padded_input_ids, dtype=torch.long),
            "attention_mask": torch.tensor(padded_attention_mask, dtype=torch.long),
            "labels": torch.tensor(labels, dtype=torch.long),
        }
```
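The masking rule in this collator can be isolated from the tokenizer as a minimal sketch. The helper name `mask_before_marker` and the toy token IDs are hypothetical, not from the repo; the helper only mirrors the collator's loop: every position before the first occurrence of the marker token sequence gets -100 (the default `ignore_index` of PyTorch's `CrossEntropyLoss`), and if the marker never occurs the whole sequence is masked.

```python
from typing import List

IGNORE_INDEX = -100  # default ignore_index of torch.nn.CrossEntropyLoss


def mask_before_marker(ids: List[int], marker: List[int]) -> List[int]:
    """Hypothetical helper mirroring the collator: mask every position before
    the first occurrence of `marker`; mask everything if it never occurs."""
    labels = list(ids)
    n = len(marker)
    for j in range(len(ids) - n + 1):
        if ids[j:j + n] == marker:
            for k in range(j):
                labels[k] = IGNORE_INDEX
            return labels
    return [IGNORE_INDEX] * len(ids)


# Toy IDs (made up): 1 = <|im_start|>, 10 = system, 20 = user, 30 = assistant
ids = [1, 10, 11, 1, 20, 21, 22, 1, 30, 31, 32]
user_marker = [1, 20]  # stands in for the tokens of "<|im_start|>user"
print(mask_before_marker(ids, user_marker))
# → [-100, -100, -100, 1, 20, 21, 22, 1, 30, 31, 32]
```

On a toy sequence laid out as system / user / assistant, only the system tokens are masked, so the user turn (including its marker) and the assistant turn both stay in the labels.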

For example, does `user_start_pos = text.find("<|im_start|>user")` mean that the system prompt before "<|im_start|>user" is masked out? If so, why isn't the user prompt masked out as well? In other words, why are both the user and assistant parts used as labels?
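For comparison, here is a minimal sketch of the alternative being asked about: supervising only the assistant replies. It assumes ChatML-style turns and uses made-up token IDs; the function `assistant_only_labels` and its parameters are hypothetical and not part of the repo. Tokens after each assistant marker, up to and including the end-of-turn token, keep their IDs; the system turn, the user turn, and the assistant marker itself are set to -100.

```python
from typing import List

IGNORE_INDEX = -100  # default ignore_index of torch.nn.CrossEntropyLoss


def assistant_only_labels(ids: List[int], assistant_start: List[int], turn_end: int) -> List[int]:
    """Hypothetical alternative: supervise only assistant replies.

    Tokens after each `assistant_start` sequence, up to and including the
    `turn_end` token, keep their IDs; everything else is IGNORE_INDEX.
    """
    labels = [IGNORE_INDEX] * len(ids)
    n = len(assistant_start)
    i = 0
    while i <= len(ids) - n:
        if ids[i:i + n] == assistant_start:
            j = i + n
            # Copy the reply tokens until (and including) the end-of-turn token
            while j < len(ids):
                labels[j] = ids[j]
                if ids[j] == turn_end:
                    break
                j += 1
            i = j + 1  # continue scanning after this reply
        else:
            i += 1
    return labels


# Toy IDs (made up): 1 = <|im_start|>, 2 = <|im_end|>, 10 = system, 20 = user, 30 = assistant
ids = [1, 10, 11, 2, 1, 20, 21, 22, 2, 1, 30, 31, 32, 2]
print(assistant_only_labels(ids, assistant_start=[1, 30], turn_end=2))
# → [-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 31, 32, 2]
```

Compared with the collator quoted above, the user turn here contributes no loss; only the assistant reply and its end-of-turn token are predicted.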
