-
Notifications
You must be signed in to change notification settings - Fork 4
/
Copy pathAIcandy_LSTM_model_hdbhkibl.py
39 lines (31 loc) · 1.22 KB
/
AIcandy_LSTM_model_hdbhkibl.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
"""
@author: AIcandy
@website: aicandy.vn
"""
import torch
import torch.nn as nn
import numpy as np
class PricePredictionModel(nn.Module):
    """LSTM regressor: stacked LSTM -> dropout -> linear head.

    The LSTM is built with ``batch_first=True``, so inputs are laid out as
    ``(batch, seq_len, input_size)``. Only the LSTM output at the last
    sequence position is passed through dropout and the linear layer, giving
    an output of shape ``(batch, output_size)``.

    Args:
        input_size: Number of features per time step.
        hidden_size: Size of the LSTM hidden state.
        num_layers: Number of stacked LSTM layers.
        output_size: Number of values produced by the final linear layer.
        dropout_rate: Dropout probability applied between stacked LSTM layers
            (when ``num_layers > 1``) and before the linear layer.
    """

    def __init__(
        self,
        input_size: int,
        hidden_size: int,
        num_layers: int,
        output_size: int,
        dropout_rate: float = 0.2,
    ) -> None:
        super().__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers

        # nn.LSTM applies dropout only *between* stacked layers, so with a
        # single layer a non-zero value has no effect and merely triggers a
        # UserWarning. Pass it through only when it can actually apply.
        inter_layer_dropout = dropout_rate if num_layers > 1 else 0.0
        self.lstm = nn.LSTM(
            input_size,
            hidden_size,
            num_layers,
            batch_first=True,
            dropout=inter_layer_dropout,
        )
        self.dropout = nn.Dropout(dropout_rate)
        self.fc = nn.Linear(hidden_size, output_size)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the model on a batch of sequences.

        Args:
            x: Tensor of shape ``(batch, seq_len, input_size)``.

        Returns:
            Tensor of shape ``(batch, output_size)``.
        """
        # Allocate the zero initial states directly on x's device and dtype
        # instead of creating them on the CPU and copying/casting afterwards.
        state_shape = (self.num_layers, x.size(0), self.hidden_size)
        h0 = torch.zeros(state_shape, device=x.device, dtype=x.dtype)
        c0 = torch.zeros(state_shape, device=x.device, dtype=x.dtype)

        out, _ = self.lstm(x, (h0, c0))
        # Keep only the output at the last sequence position.
        out = self.dropout(out[:, -1, :])
        return self.fc(out)
def create_sequences(data, seq_length):
    """Slice ``data`` into overlapping windows paired with the next element.

    For every start index ``s`` in ``range(len(data) - seq_length)`` the
    window is ``data[s:s + seq_length]`` and its target is
    ``data[s + seq_length]``.

    Args:
        data: Object supporting ``len()``, slicing and integer indexing.
        seq_length: Number of consecutive elements in each window.

    Returns:
        A tuple ``(sequences, targets)`` of NumPy arrays built from the
        collected windows and targets. Both are empty when ``data`` has no
        more than ``seq_length`` elements.
    """
    starts = range(len(data) - seq_length)
    windows = [data[start:start + seq_length] for start in starts]
    next_values = [data[start + seq_length] for start in starts]
    return np.array(windows), np.array(next_values)