-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathdiscriminator_net.py
108 lines (88 loc) · 3.67 KB
/
discriminator_net.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
import torch.types
from models.mlp_network import MLPNetwork
class DiscrNetBase:
    """Abstract interface for adversarial-IL discriminator networks.

    Concrete discriminators (e.g. GAIL / AIRL variants) must override every
    method below. The base implementations raise ``NotImplementedError`` so
    that a missing override fails loudly at call time instead of silently
    returning ``None``.
    """

    def __init__(self):
        # Every discriminator starts on CPU; subclasses move their
        # sub-networks (and update this field) via to().
        self.device = torch.device('cpu')

    def parameters(self, recursive=True):
        """Return the list of trainable parameters of the network(s)."""
        raise NotImplementedError

    def run_logits(self, *args, **kwargs):
        """Compute discriminator logits.

        Signature is intentionally loose: concrete subclasses take
        different inputs (a single embedding for GAIL; embeddings,
        dones and log-probs for AIRL).
        """
        raise NotImplementedError

    def to(self, device):
        """Move the underlying network(s) to *device*."""
        raise NotImplementedError

    def save(self, path, version):
        """Persist the underlying network(s) under *path* / *version*."""
        raise NotImplementedError

    def load(self, path, version, **kwargs):
        """Restore the underlying network(s) from *path* / *version*."""
        raise NotImplementedError

    def state_dict(self):
        """Return a serializable dict of the network(s)' weights."""
        raise NotImplementedError

    def load_state_dict(self, state_dict):
        """Load weights previously produced by :meth:`state_dict`."""
        raise NotImplementedError
class GAILDiscrNet(DiscrNetBase):
    """GAIL discriminator: a single MLP mapping an input embedding to one logit."""

    def __init__(self, input_size, hidden_size,
                 activations, name_suffix=''):
        """Build the discriminator MLP.

        Args:
            input_size: dimensionality of the input embedding.
            hidden_size: hidden-layer sizes passed through to MLPNetwork.
            activations: activation spec passed through to MLPNetwork.
            name_suffix: suffix appended to the network's name (used by
                MLPNetwork for e.g. checkpoint naming).
        """
        super(GAILDiscrNet, self).__init__()
        self.name_suffix = name_suffix
        # Output dimension is fixed at 1: a single discriminator logit.
        self.model = MLPNetwork(input_size, 1, hidden_size, activations, name_suffix)

    def parameters(self, recursive=True):
        """Return all trainable parameters of the underlying MLP as a list."""
        return list(self.model.parameters(True))

    def run_logits(self, x):
        """Return the raw (pre-sigmoid) discriminator logit for input *x*."""
        return self.model.meta_forward(x)

    def to(self, device):
        """Move the MLP to *device*, skipping redundant transfers."""
        if self.device != device:
            self.device = device
            self.model.to(device)

    def save(self, path, version):
        """Save the underlying MLP under *path* / *version*."""
        self.model.save(path, version)

    def load(self, path, version, **kwargs):
        """Load the underlying MLP from *path* / *version*."""
        self.model.load(path, version, **kwargs)

    def state_dict(self):
        """Return the MLP's weights keyed under 'discriminator'."""
        return dict(
            discriminator=self.model.state_dict()
        )

    def load_state_dict(self, state_dict):
        """Load weights from *state_dict*; warn if the expected key is absent."""
        if 'discriminator' in state_dict:
            self.model.load_state_dict(state_dict['discriminator'])
        else:
            # Best-effort load: keep current weights but tell the user.
            print('discriminator is not found in the state dict')
class AIRLDiscrNet(DiscrNetBase):
    """AIRL discriminator with a disentangled reward head and value (shaping) head.

    Two MLPs are maintained: ``model_r`` estimates the reward r(s) and
    ``model_V`` estimates the shaping value V(s). The discriminator logit is
    ``f - log_pi`` where ``f = r(s) + gamma * (1 - done) * V(s') - V(s)``.
    """

    def __init__(self, input_size_r, input_size_V, hidden_size_r, hidden_size_V, activations_r, activations_V,
                 gamma, name_suffix=''):
        """Build the reward and value MLPs.

        Args:
            input_size_r / input_size_V: input dims of the reward / value nets.
            hidden_size_r / hidden_size_V: hidden-layer specs for each net.
            activations_r / activations_V: activation specs for each net.
            gamma: discount factor used in the advantage-style combination.
            name_suffix: base suffix; '_r' / '_V' are appended per network.
        """
        super(AIRLDiscrNet, self).__init__()
        self.name_suffix_r = name_suffix + '_r'
        self.name_suffix_V = name_suffix + '_V'
        # Both heads emit a single scalar per input.
        self.model_r = MLPNetwork(input_size_r, 1, hidden_size_r, activations_r, self.name_suffix_r)
        self.model_V = MLPNetwork(input_size_V, 1, hidden_size_V, activations_V, self.name_suffix_V)
        self.gamma = gamma

    def parameters(self, recursive=True):
        """Return trainable parameters of both heads as a single list."""
        return list(self.model_r.parameters(True)) + list(self.model_V.parameters(True))

    def run_logits(self, embed_for_r, embed_for_V, next_embed_for_V, dones, log_pi_s):
        """Return AIRL logits f(s, s') - log pi(a|s).

        Args:
            embed_for_r: embedding fed to the reward head.
            embed_for_V: embedding of the current state for the value head.
            next_embed_for_V: embedding of the next state for the value head.
            dones: episode-termination indicator (1 cuts off bootstrapping).
            log_pi_s: log-probability of the taken action under the policy.
        """
        r_s = self.model_r.meta_forward(embed_for_r)
        V_s = self.model_V.meta_forward(embed_for_V)
        V_s_next = self.model_V.meta_forward(next_embed_for_V)
        # Advantage-style combination; (1 - dones) zeroes the bootstrap term
        # at episode ends.
        f = r_s + self.gamma * (1 - dones) * V_s_next - V_s
        return f - log_pi_s

    def to(self, device):
        """Move both heads to *device*, skipping redundant transfers."""
        if self.device != device:
            self.device = device
            self.model_r.to(device)
            self.model_V.to(device)

    def save(self, path, version):
        """Save both heads under *path* / *version*."""
        self.model_r.save(path, version)
        self.model_V.save(path, version)

    def load(self, path, version, **kwargs):
        """Load both heads from *path* / *version*."""
        self.model_r.load(path, version, **kwargs)
        self.model_V.load(path, version, **kwargs)

    def state_dict(self):
        """Return both heads' weights under 'discriminator_r' / 'discriminator_V'."""
        return dict(
            discriminator_r=self.model_r.state_dict(),
            discriminator_V=self.model_V.state_dict(),
        )

    def load_state_dict(self, state_dict):
        """Load weights from *state_dict*; warn per missing key."""
        if 'discriminator_r' in state_dict:
            self.model_r.load_state_dict(state_dict['discriminator_r'])
        else:
            # Best-effort load: keep current reward-head weights.
            print('discriminator_r is not found in the state dict')
        if 'discriminator_V' in state_dict:
            self.model_V.load_state_dict(state_dict['discriminator_V'])
        else:
            # Best-effort load: keep current value-head weights.
            print('discriminator_V is not found in the state dict')