utils.py
import torch
import torch.autograd as autograd
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.nn.utils as utils
from torch.autograd import Variable
import math
import time
import functools
import random
from blackhc.mdp import dsl
from blackhc import mdp
from blackhc.mdp import lp
import numpy as np
from tqdm import tqdm
from matplotlib import pyplot as plt
from operator import itemgetter
from collections import defaultdict
class ReplayMemory:
    """Fixed-capacity circular buffer of (state, action, reward, next_state, policy) transitions."""

    def __init__(self, capacity):
        self.capacity = capacity
        self.buffer = []
        self.position = 0

    def flush_all(self):
        # Empty the buffer and reset the write position.
        self.buffer = []
        self.position = 0

    def push(self, state, action, reward, next_state, policy):
        # Store a single transition, overwriting the oldest entry once the buffer is full.
        if len(self.buffer) < self.capacity:
            self.buffer.append(None)
        self.buffer[self.position] = (state, action, reward, next_state, policy)
        self.position = (self.position + 1) % self.capacity

    def push_batch(self, batch):
        # Store a list of transitions, wrapping around the end of the buffer if necessary.
        if len(self.buffer) < self.capacity:
            append_len = min(self.capacity - len(self.buffer), len(batch))
            self.buffer.extend([None] * append_len)
        if self.position + len(batch) < self.capacity:
            self.buffer[self.position : self.position + len(batch)] = batch
            self.position += len(batch)
        else:
            self.buffer[self.position : len(self.buffer)] = batch[:len(self.buffer) - self.position]
            self.buffer[:len(batch) - len(self.buffer) + self.position] = batch[len(self.buffer) - self.position:]
            self.position = len(batch) - len(self.buffer) + self.position

    def sample(self, batch_size):
        # Sample without replacement; clamp batch_size to the number of stored transitions.
        if batch_size > len(self.buffer):
            batch_size = len(self.buffer)
        batch = random.sample(self.buffer, int(batch_size))
        state, action, reward, next_state, policy = map(np.stack, zip(*batch))
        return state, action, reward, next_state, policy

    def sample_all_batch(self, batch_size):
        # Sample with replacement: indices may repeat, so batch_size can exceed len(self.buffer).
        idxes = np.random.randint(0, len(self.buffer), batch_size)
        batch = list(itemgetter(*idxes)(self.buffer))
        state, action, reward, next_state, policy = map(np.stack, zip(*batch))
        return state, action, reward, next_state, policy

    def return_all(self):
        return self.buffer

    def __len__(self):
        return len(self.buffer)
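
# Illustrative usage of ReplayMemory (a minimal sketch; the capacity and transition
# values below are assumptions for the example, not taken from this file):
#
#   memory = ReplayMemory(capacity=100)
#   memory.push(state=0, action=1, reward=0.5, next_state=2, policy=np.array([0.3, 0.7]))
#   memory.push(state=2, action=0, reward=1.0, next_state=3, policy=np.array([0.8, 0.2]))
#   states, actions, rewards, next_states, policies = memory.sample(batch_size=2)
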
def init_weights(m):
    # Xavier-initialise the weights of nn.Linear modules (intended for use with model.apply(init_weights)).
    if isinstance(m, nn.Linear):
        torch.nn.init.xavier_uniform_(m.weight)
class Network(nn.Module):
    """A single bias-free linear layer followed by a softmax over dim 1, giving one probability distribution per row."""

    def __init__(self, input_layer, output_layer):
        super(Network, self).__init__()
        self.fc1 = nn.Linear(input_layer, output_layer, bias=False)
        self.fc2 = nn.Softmax(dim=1)

    def forward(self, input_):
        x = self.fc1(input_)
        y = self.fc2(x)
        return y
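
# A minimal smoke test for Network and init_weights (an illustrative sketch; the
# layer sizes and batch shape below are assumptions, not values taken from this file):
if __name__ == "__main__":
    net = Network(input_layer=4, output_layer=3)
    net.apply(init_weights)       # Xavier-initialise the linear layer
    batch = torch.ones(2, 4)      # a batch of 2 inputs with 4 features each
    probs = net(batch)
    print(probs.shape)            # torch.Size([2, 3])
    print(probs.sum(dim=1))       # each row sums to 1 because of the softmax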