-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmodel_utils.py
114 lines (93 loc) · 3.89 KB
/
model_utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
# Copyright 2023 The Distilling-step-by-step authors
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# https://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Dict, List, Optional, Tuple, Union
import pandas as pd
import torch
from torch import nn
from transformers import DataCollatorForSeq2Seq, Seq2SeqTrainer
class TaskPrefixDataCollator(DataCollatorForSeq2Seq):
    """Collator that turns one feature list into two padded seq2seq batches.

    Each input feature carries two target sequences: the task label
    (``labels``/``input_ids``/``attention_mask``) and the rationale
    (``aux_labels``/``expl_input_ids``/``expl_attention_mask``).  This
    collator splits them into a 'pred' view and an 'expl' view and lets the
    parent ``DataCollatorForSeq2Seq`` pad/tensorise each view separately.
    """

    def __call__(self, features, return_tensors=None):
        frame = pd.DataFrame(features)

        # Prediction view: everything except the explanation-only columns.
        pred_records = frame.drop(
            columns=['aux_labels', 'expl_input_ids', 'expl_attention_mask'],
            errors='ignore',
        ).to_dict('records')

        # Explanation view: drop the prediction columns, then rename the
        # explanation columns to the standard names the parent collator
        # expects.
        expl_records = (
            frame
            .drop(
                columns=['labels', 'input_ids', 'attention_mask'],
                errors='ignore',
            )
            .rename(columns={
                'aux_labels': 'labels',
                'expl_input_ids': 'input_ids',
                'expl_attention_mask': 'attention_mask',
            })
            .to_dict('records')
        )

        # Pad each view independently with the stock seq2seq collator.
        return {
            'pred': super().__call__(pred_records, return_tensors),
            'expl': super().__call__(expl_records, return_tensors),
        }
class TaskPrefixTrainer(Seq2SeqTrainer):
    """Seq2seq trainer for distilling step-by-step.

    Optimizes a weighted sum of two objectives computed from the dual batch
    produced by ``TaskPrefixDataCollator``: the label-prediction loss
    (weight ``alpha``) and the rationale-generation loss (weight
    ``1 - alpha``).

    Args:
        alpha: Weight on the prediction loss in ``[0, 1]``; the explanation
            loss receives ``1 - alpha``.
        output_rationale: If True, also run generation on the 'expl' view
            during evaluation (roughly doubles eval generation cost).
        **kwargs: Forwarded unchanged to ``Seq2SeqTrainer``.
    """

    def __init__(self, alpha, output_rationale, **kwargs):
        super().__init__(**kwargs)
        self.alpha = alpha
        self.output_rationale = output_rationale

    def compute_loss(self, model, inputs, return_outputs=False,
                     num_items_in_batch=None):
        """Return the alpha-weighted sum of prediction and explanation loss.

        ``num_items_in_batch`` is accepted (and ignored) for compatibility
        with transformers >= 4.46, where ``Trainer`` forwards it to every
        ``compute_loss`` override; without this parameter the call raises
        ``TypeError`` on those versions.
        """
        pred_outputs = model(**inputs['pred'])
        expl_outputs = model(**inputs['expl'])
        # Weighted combination of the two task losses.
        loss = self.alpha * pred_outputs.loss + \
            (1. - self.alpha) * expl_outputs.loss
        return (loss, {
            'pred': pred_outputs,
            'expl': expl_outputs
        }) if return_outputs else loss

    def prediction_step(
        self,
        model: nn.Module,
        inputs: Dict[str, Union[torch.Tensor, Any]],
        prediction_loss_only: bool,
        ignore_keys: Optional[List[str]] = None
    ) -> Tuple[
        Optional[float],
        Optional[torch.Tensor],
        Optional[torch.Tensor]
    ]:
        """Run an evaluation step on both views of the batch.

        Returns:
            A tuple ``(loss, [pred_tokens, expl_tokens],
            [pred_labels, expl_labels])``.  When ``output_rationale`` is
            False the 'expl' entries simply duplicate the 'pred' outputs
            and the combined loss collapses to the prediction loss.
        """
        gen_kwargs = {"max_new_tokens": 512}
        # Always request full outputs (loss + generated tokens) regardless
        # of `prediction_loss_only`: downstream metrics need the tokens.
        pred_outputs = super().prediction_step(model,
                                               inputs['pred'],
                                               prediction_loss_only=False,
                                               ignore_keys=ignore_keys,
                                               **gen_kwargs
                                               )
        if self.output_rationale:
            expl_outputs = super().prediction_step(model,
                                                   inputs['expl'],
                                                   prediction_loss_only=False,
                                                   ignore_keys=ignore_keys,
                                                   **gen_kwargs
                                                   )
        else:
            expl_outputs = pred_outputs  # placeholder only
        # Same weighting as training; with the placeholder above this is
        # exactly the prediction loss.
        loss = self.alpha * pred_outputs[0] + \
            (1 - self.alpha) * expl_outputs[0]
        return (
            loss,
            [pred_outputs[1], expl_outputs[1]],
            [pred_outputs[2], expl_outputs[2]],
        )