-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmain.py
40 lines (36 loc) · 1.12 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
from pettingzoo.mpe._mpe_utils.simple_env import make_env
import glob
import os
import time
import numpy as np
import time
from eval import evaluate
import supersuit as ss
from stable_baselines3 import PPO
from stable_baselines3.ppo import CnnPolicy, MlpPolicy
from train import train
from sb3_contrib import RecurrentPPO
from stable_baselines3.common.evaluation import evaluate_policy
from stable_baselines3 import DQN
from stable_baselines3.dqn import CnnPolicy, MlpPolicy
from pettingzoo.utils.conversions import (
aec_to_parallel_wrapper,
parallel_to_aec_wrapper,
turn_based_aec_to_parallel_wrapper,
)
from pettingzoo.utils.wrappers import BaseWrapper
import warnings
from pettingzoo.test.api_test import missing_attr_warning
def main():
    """Run one training call followed by one evaluation call on "patrolEnv".

    The same scenario options (cycle limit, discrete actions, and the
    intruder / patroller / obstacle counts) are forwarded as keyword
    arguments to both ``train`` and ``evaluate`` so the two calls are
    configured identically.
    """
    # Scenario options shared by the training and the evaluation call.
    scenario_options = {
        "max_cycles": 120,
        "continuous_actions": False,
        "num_intruders": 1,
        "num_patrollers": 4,
        "num_obstacles": 1,
    }
    env_name = "patrolEnv"

    # NOTE(review): `steps` is passed as the float 1e6 -- confirm that
    # `train` accepts a non-integer step count.
    train(
        env_name,
        steps=1e6,
        seed=16,
        render_mode=None,
        hyperparams=None,
        **scenario_options,
    )

    # Evaluation is requested with render_mode="human" over 20 games.
    evaluate(
        env_name,
        num_games=20,
        render_mode="human",
        **scenario_options,
    )


# Only start the run when this file is executed as a script, so importing
# the module does not trigger it.
if __name__ == "__main__":
    main()