-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathevaluate_model.py
More file actions
87 lines (59 loc) · 2.02 KB
/
evaluate_model.py
File metadata and controls
87 lines (59 loc) · 2.02 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset, Dataset, random_split
from torchvision import datasets, transforms
from torchvision.transforms import ToTensor
from torch.optim import Adam
from c4 import Connect4NN, Connect4, MCTSNode, MCTSBatchSelfPlayer, DEVICE
import numpy as np
# NOTE(review): `device` is not used below -- the model is placed on c4.DEVICE.
# `torch.accelerator` only exists in recent PyTorch releases (2.6+).
device = torch.accelerator.current_accelerator().type if torch.accelerator.is_available() else "cpu"
model = Connect4NN()
# Deserialize the weights onto the CPU: a freshly constructed module is on the
# CPU by default, and map_location lets a checkpoint that was saved from a GPU
# load on a machine without one. The model is moved to DEVICE afterwards.
model.load_state_dict(
    torch.load("models/best_model.pth", map_location="cpu", weights_only=True)
)
model.to(DEVICE)
class SinglePlayer:
    """Interactive MCTS engine that plays a Connect 4 game against a human.

    The search tree is rooted at ``self.root`` (an ``MCTSNode``). ``iterate``
    grows the tree by one network-guided expansion; ``automove`` (engine) and
    ``move`` (human) advance the game by re-rooting the tree at a child.
    """

    def __init__(self, model):
        """Create a search tree rooted at an empty board.

        Args:
            model: network called as ``model(batch)`` that returns
                ``(policy_logits, values)``.
        """
        # NOTE(review): the second Connect4 argument (1) is presumably the
        # player to move -- confirm against c4.Connect4.
        empty_board = np.zeros((Connect4.BOARD_HEIGHT, Connect4.BOARD_WIDTH))
        self.root = MCTSNode(None, Connect4(empty_board, 1))
        self.model = model

    def iterate(self):
        """Run one MCTS step: select a node via ``traverse`` and expand it.

        If the node has a tensor representation, the network scores it and the
        node is expanded with the softmaxed policy and the scalar value.
        Otherwise it is expanded without a network call.
        """
        node_to_expand = self.root.traverse()
        tensor_rep = node_to_expand.evaluation_tensor_representation()
        if tensor_rep is None:
            # terminal states can be expanded without evaluation
            node_to_expand.expand()
            return
        self.model.eval()
        with torch.no_grad():
            # Add a leading batch dimension of size 1 before the forward pass.
            policy_logits, values = self.model(tensor_rep.unsqueeze(0).to(DEVICE))
            policy_logits = policy_logits.cpu()
            values = values.cpu()
            policy = torch.softmax(policy_logits[0], 0).numpy()
        node_to_expand.expand(policy, float(values[0]))

    def automove(self):
        """Play the engine's move by re-rooting at child ``pick_child_index(0)``."""
        # NOTE(review): backprop is toggled off on the root being left behind;
        # the exact effect is defined in c4.MCTSNode.toggle_backprop -- confirm.
        self.root.toggle_backprop(False)
        self.root = self.root._children[self.root.pick_child_index(0)]

    def move(self, i):
        """Play the human's move by re-rooting at child ``i`` of the root.

        Raises:
            IndexError: if ``i`` is outside ``[0, len(root._children))``.
                Negative values are rejected rather than wrapping around to
                the last child, as plain list indexing would.
        """
        children = self.root._children
        if not 0 <= i < len(children):
            raise IndexError(f"move {i} is out of range for {len(children)} children")
        self.root = children[i]

    def __str__(self):
        """Render the root node, then the board, then a column-number footer."""
        footer = "|" + "|".join(str(col) for col in range(Connect4.BOARD_WIDTH)) + "|\n"
        return str(self.root) + str(self.root._game) + footer

    def is_terminal(self):
        """Return whether the current root position is terminal."""
        return self.root.is_terminal()
SIMULATIONS_PER_MOVE = 100  # MCTS iterations run before every move (engine or human)


def parse_column(text, width):
    """Return *text* as a column index in ``[0, width)``, or ``None`` if invalid.

    >>> parse_column("3", 7)
    3
    >>> parse_column("-1", 7) is None
    True
    """
    try:
        column = int(text)
    except ValueError:
        return None
    return column if 0 <= column < width else None


def main():
    """Play one interactive game: the engine moves first, then the human, alternating."""
    player = SinglePlayer(model)
    width = Connect4.BOARD_WIDTH
    prompt = f"Make your move! [0-{width - 1}]:"
    while not player.is_terminal():
        print(player)
        for _ in range(SIMULATIONS_PER_MOVE):
            player.iterate()
        print("Automoving.")
        player.automove()
        print(player)
        if player.is_terminal():
            break
        # Keep searching from the position the human faces (presumably so the
        # root's children exist before `move` indexes into them).
        for _ in range(SIMULATIONS_PER_MOVE):
            player.iterate()
        column = parse_column(input(prompt), width)
        while column is None:
            # Re-prompt instead of crashing on non-numeric input or silently
            # accepting an out-of-range / negative column.
            print(f"Please enter a column number from 0 to {width - 1}.")
            column = parse_column(input(prompt), width)
        player.move(column)
        print(player)


if __name__ == "__main__":
    main()
# current = MCTSNode(None, Connect4(np.zeros((6,7)),1))
# current.
# pred = model(c.tensor_representation().unsqueeze_(0))
# print(pred)
# print(next(model.parameters())[0])