-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathutils.py
75 lines (64 loc) · 2.77 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
import muzero
import os
import time
import gcloud
from glob import glob
class Utils:
    """Interactive menus for choosing and loading saved models.

    NOTE(review): ``load_model_menu`` and ``cloud_load_model_menu`` call
    ``.load_model`` on their first argument (``self`` / ``muzero``), so they
    appear to be meant for an object that provides ``load_model`` (presumably
    a MuZero instance, e.g. via inheritance) — confirm against callers.
    """

    @staticmethod
    def _prompt_for_choice(options):
        """Print *options* as a numbered menu and return the chosen index.

        Re-prompts until the user types one of the listed numbers.

        Raises:
            ValueError: if *options* is empty (no valid answer would exist,
                so prompting would never terminate).
        """
        if not options:
            raise ValueError("no options to choose from")
        print()
        for index, option in enumerate(options):
            print(f"{index}. {option}")
        valid_inputs = {str(index) for index in range(len(options))}
        choice = input("Enter a number to choose a model to load: ")
        while choice not in valid_inputs:
            choice = input("Invalid input, enter a number listed above: ")
        return int(choice)

    def load_model_menu(self, game_name):
        """Let the user pick a local run under ``results/<game_name>/`` and load it.

        Run folders are listed in reverse lexicographic order, followed by a
        final "Specify paths manually" entry. For that entry, the user types a
        checkpoint path and a replay-buffer path. Each prompt repeats until an
        existing file is given, or the user presses ENTER to leave it empty.
        The resulting paths are passed to ``self.load_model``.
        """
        # The manual option is prepended before reversing, so it ends up last.
        options = ["Specify paths manually"] + sorted(glob(f"results/{game_name}/*/"))
        options.reverse()
        choice = Utils._prompt_for_choice(options)

        if choice == len(options) - 1:
            # Manual path option: empty input (ENTER) is accepted as "none".
            checkpoint_path = input(
                "Enter a path to the model.checkpoint, or ENTER if none: "
            )
            while checkpoint_path and not os.path.isfile(checkpoint_path):
                checkpoint_path = input("Invalid checkpoint path. Try again: ")
            replay_buffer_path = input(
                "Enter a path to the replay_buffer.pkl, or ENTER if none: "
            )
            while replay_buffer_path and not os.path.isfile(replay_buffer_path):
                replay_buffer_path = input("Invalid replay buffer path. Try again: ")
        else:
            # glob's "*/" pattern yields folder paths with a trailing separator;
            # os.path.join does not add a second one.
            run_folder = options[choice]
            checkpoint_path = os.path.join(run_folder, "model.checkpoint")
            replay_buffer_path = os.path.join(run_folder, "replay_buffer.pkl")

        self.load_model(checkpoint_path=checkpoint_path, replay_buffer_path=replay_buffer_path)

    def cloud_load_model_menu(muzero, game_name):
        """Let the user pick a cloud-stored run for *game_name*, download and load it.

        Run ids come from ``gcloud.get_blobs_list(game_name)`` and are listed in
        reverse first-seen order. The chosen run is fetched with
        ``gcloud.download_blob``, and its ``model.checkpoint`` and
        ``replay_buffer.pkl`` are passed to ``muzero.load_model``.

        Returns:
            The local folder name returned by ``gcloud.download_blob``, or
            ``None`` if no cloud runs were found.
        """
        options = Utils.get_run_ids(gcloud.get_blobs_list(game_name))
        # get_run_ids returns a list: an empty one means there is nothing to
        # choose from, and prompting would otherwise loop forever.
        if not options:
            print("No cloud models found")
            time.sleep(3)  # presumably keeps the message visible before returning
            return None
        options.reverse()
        choice = Utils._prompt_for_choice(options)

        destination_folder_name = gcloud.download_blob(game_name, options[choice])
        checkpoint_path = os.path.join(destination_folder_name, "model.checkpoint")
        replay_buffer_path = os.path.join(destination_folder_name, "replay_buffer.pkl")
        muzero.load_model(
            checkpoint_path=checkpoint_path, replay_buffer_path=replay_buffer_path
        )
        return destination_folder_name

    @staticmethod
    def get_run_ids(blobs):
        """Return the distinct run ids found in *blobs*, in first-seen order.

        The run id is the third ``/``-separated component of each blob's
        ``name`` attribute (for a name like ``"a/b/<run_id>/file"`` it is
        ``"<run_id>"``).

        Raises:
            IndexError: if a blob name has fewer than three components.
        """
        # dict preserves insertion order, giving an O(n) order-preserving dedupe.
        return list(dict.fromkeys(blob.name.split("/")[2] for blob in blobs))