diff --git a/agent/SARL/encoder/encoder.py b/agent/SARL/encoder/encoder.py index fba2eb53..6b17df6c 100644 --- a/agent/SARL/encoder/encoder.py +++ b/agent/SARL/encoder/encoder.py @@ -186,9 +186,9 @@ def __init__(self, args): self.test_label_list, self.test_df_list = prepart_m_lstm_data( self.test_data, self.num_day, self.technical_indicator) self.train_dataset = m_lstm_dataset(self.train_df_list, - self.train_label_list) + self.train_label_list,self.num_day) self.valid_dataset = m_lstm_dataset(self.valid_df_list, - self.valid_label_list) + self.valid_label_list,self.num_day) train_dataloader = DataLoader(self.train_dataset, batch_size=self.batch_size, shuffle=True) @@ -219,4 +219,4 @@ def set_seed(self): if __name__ == "__main__": args = parser.parse_args() - a = encoder(args) \ No newline at end of file + a = encoder(args)