
Proposed fix for slow per-epoch validation during model training #137

Open
bigdata0 opened this issue Oct 17, 2024 · 1 comment

Comments

@bigdata0
Hello author, setting the test-set loader to batch_size=1 during training severely slows training down.
When the total sample count N is divisible by batch_size, the mean of the per-batch loss means equals the loss mean computed directly over all samples (equivalent to batch_size=1). Each batch mean averages over the samples in its own batch, and the overall mean averages over all samples; when the batches divide evenly, averaging the batch means loses nothing.
However, when N is not divisible by batch_size and drop_last=False is set, the smaller final batch makes the average of batch means diverge from the overall mean. In that case the whole-dataset loss mean and the per-batch loss mean no longer agree.
In your code, both batch_size=1 and drop_last=True are set. These two settings have no reason to coexist: drop_last discards the samples that do not fill a complete batch, while batch_size=1 guarantees the loss is computed over every available test sample. Dropping part of the samples in your project appears to be a leftover from earlier work.
Therefore, to keep every test sample in the evaluation while setting drop_last=False, you can also raise batch_size to speed up training: modify the vali function in exp_long_term_forecasting.py so that its loss computation matches the flow in the test function, which reshapes preds and trues and computes the loss mean over all samples (equivalent to batch_size=1). This speeds up validation and still evaluates on all test samples.
If drop_last=True is kept, you can simply raise batch_size, which greatly speeds up training without affecting the final result.
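A small standalone demonstration of the divisibility point above (not from the repository; the error values are made up): when N divides evenly by batch_size, the mean of per-batch means equals the true per-sample mean, but an undersized final batch (drop_last=False) gets over-weighted by a naive average of batch means.

```python
import numpy as np

# Hypothetical per-sample losses, N=10
errors = np.arange(10, dtype=float)  # overall mean is 4.5

def mean_of_batch_means(errs, batch_size):
    """Average the per-batch mean losses, as a naive validation loop would."""
    batches = [errs[i:i + batch_size] for i in range(0, len(errs), batch_size)]
    return np.mean([b.mean() for b in batches])

overall = errors.mean()
print(mean_of_batch_means(errors, 2), overall)  # 10 % 2 == 0: means agree
print(mean_of_batch_means(errors, 4), overall)  # last batch has only 2 samples: biased
```

With batch_size=2 the batch means average back to exactly 4.5; with batch_size=4 the 2-sample tail batch is weighted like a full batch and the result drifts away from the true mean.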

@bigdata0 (Author)

# Assumes the module-level imports already present in exp_long_term_forecasting.py:
#   import numpy as np
#   import torch
def vali(self, vali_data, vali_loader, criterion):
    preds = []
    trues = []
    self.model.eval()
    with torch.no_grad():
        for i, (batch_x, batch_y, batch_x_mark, batch_y_mark) in enumerate(vali_loader):
            batch_x = batch_x.float().to(self.device)
            batch_y = batch_y.float()

            batch_x_mark = batch_x_mark.float().to(self.device)
            batch_y_mark = batch_y_mark.float().to(self.device)

            # decoder input: zero out the prediction window, keep the label window
            dec_inp = torch.zeros_like(batch_y[:, -self.args.pred_len:, :]).float()
            dec_inp = torch.cat([batch_y[:, :self.args.label_len, :], dec_inp], dim=1).float().to(self.device)
            # encoder - decoder
            if self.args.use_amp:
                with torch.cuda.amp.autocast():
                    if self.args.output_attention:
                        outputs = self.model(batch_x, batch_x_mark, dec_inp, batch_y_mark)[0]
                    else:
                        outputs = self.model(batch_x, batch_x_mark, dec_inp, batch_y_mark)
            else:
                if self.args.output_attention:
                    outputs = self.model(batch_x, batch_x_mark, dec_inp, batch_y_mark)[0]
                else:
                    outputs = self.model(batch_x, batch_x_mark, dec_inp, batch_y_mark)
            f_dim = -1 if self.args.features == 'MS' else 0
            outputs = outputs[:, -self.args.pred_len:, f_dim:]
            batch_y = batch_y[:, -self.args.pred_len:, f_dim:].to(self.device)

            preds.append(outputs.detach().cpu().numpy())
            trues.append(batch_y.detach().cpu().numpy())

        # With drop_last=False the final batch may be smaller; set it aside
        # before stacking the equal-sized batches into one array.
        last_pred = np.empty((0, preds[0].shape[-2], preds[0].shape[-1]))
        last_true = np.empty((0, preds[0].shape[-2], preds[0].shape[-1]))
        if preds[-1].shape != preds[0].shape:
            last_pred = preds[-1]
            last_true = trues[-1]
            preds = preds[:-1]
            trues = trues[:-1]
        preds = np.array(preds)
        trues = np.array(trues)
        # Flatten (num_batches, batch_size, pred_len, channels) into
        # (num_samples, pred_len, channels), then append the short tail batch.
        preds = preds.reshape(-1, preds.shape[-2], preds.shape[-1])
        trues = trues.reshape(-1, trues.shape[-2], trues.shape[-1])
        preds = np.concatenate((preds, np.array(last_pred)), axis=0)
        trues = np.concatenate((trues, np.array(last_true)), axis=0)
        # One criterion call over all samples reproduces the batch_size=1 mean loss.
        total_loss = criterion(torch.from_numpy(preds), torch.from_numpy(trues)).item()
    self.model.train()
    return total_loss
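A quick sanity check of the equivalence the rewrite relies on, assuming criterion is nn.MSELoss with its default reduction='mean' (as in this project's setup) and using made-up tensor shapes: one loss call over all concatenated samples matches the average of per-sample (batch_size=1) losses.

```python
import numpy as np
import torch
import torch.nn as nn

criterion = nn.MSELoss()  # default reduction='mean'

# Hypothetical validation outputs: (num_samples, pred_len, channels)
torch.manual_seed(0)
preds = torch.randn(8, 24, 3)
trues = torch.randn(8, 24, 3)

# One call over everything vs. averaging per-sample losses
one_shot = criterion(preds, trues).item()
per_sample = np.mean([criterion(preds[i], trues[i]).item() for i in range(8)])
print(one_shot, per_sample)  # agree up to float rounding
```

The two agree because every sample contributes the same number of elements (pred_len × channels), so the overall element-wise mean is exactly the mean of the per-sample means.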
