|
1 | 1 | import argparse
|
| 2 | +import importlib |
2 | 3 |
|
3 | 4 |
|
4 | 5 | def run(args: argparse.Namespace) -> None:
|
5 |
| - if args.scenario == 'grid': |
6 |
| - from grid_world.run import ( |
7 |
| - run_bpr_okr, |
8 |
| - run_bpr_plus, |
9 |
| - run_bsi, |
10 |
| - run_bsi_pt, |
11 |
| - run_deep_bpr_plus, |
12 |
| - run_tom, |
13 |
| - ) |
14 |
| - elif args.scenario == 'nav': |
15 |
| - from navigation_game.run import ( |
16 |
| - run_bpr_okr, |
17 |
| - run_bpr_plus, |
18 |
| - run_bsi, |
19 |
| - run_bsi_pt, |
20 |
| - run_deep_bpr_plus, |
21 |
| - run_tom, |
22 |
| - ) |
23 |
| - elif args.scenario == 'soccer': |
24 |
| - from soccer_game.run import ( |
25 |
| - run_bpr_okr, |
26 |
| - run_bpr_plus, |
27 |
| - run_bsi, |
28 |
| - run_bsi_pt, |
29 |
| - run_deep_bpr_plus, |
30 |
| - run_tom, |
31 |
| - ) |
32 |
| - elif args.scenario == 'baseball': |
33 |
| - from baseball_game.run import ( |
34 |
| - run_bpr_okr, |
35 |
| - run_bpr_plus, |
36 |
| - run_bsi, |
37 |
| - run_bsi_pt, |
38 |
| - run_deep_bpr_plus, |
39 |
| - run_tom, |
40 |
| - ) |
| 6 | + scenario_modules = { |
| 7 | + 'grid': 'grid_world.run', |
| 8 | + 'nav': 'navigation_game.run', |
| 9 | + 'soccer': 'soccer_game.run', |
| 10 | + 'baseball': 'baseball_game.run', |
| 11 | + } |
41 | 12 |
|
42 |
| - if args.agent == 'bpr+': |
43 |
| - run_bpr_plus(args) |
44 |
| - elif args.agent == 'deep-bpr+': |
45 |
| - run_deep_bpr_plus(args) |
46 |
| - elif args.agent == 'tom': |
47 |
| - run_tom(args) |
48 |
| - elif args.agent == 'bpr-okr': |
49 |
| - run_bpr_okr(args) |
50 |
| - elif args.agent == 'bsi': |
51 |
| - run_bsi(args) |
52 |
| - elif args.agent == 'bsi-pt': |
53 |
| - run_bsi_pt(args) |
| 13 | + scenario = args.scenario |
| 14 | + if scenario in scenario_modules: |
| 15 | + run_module = importlib.import_module(scenario_modules[scenario]) |
| 16 | + run_bpr_okr = run_module.run_bpr_okr |
| 17 | + run_bpr_plus = run_module.run_bpr_plus |
| 18 | + run_bsi = run_module.run_bsi |
| 19 | + run_bsi_pt = run_module.run_bsi_pt |
| 20 | + run_deep_bpr_plus = run_module.run_deep_bpr_plus |
| 21 | + run_tom = run_module.run_tom |
| 22 | + else: |
| 23 | + raise ValueError(f"Unsupported scenario: {scenario}") |
| 24 | + |
| 25 | + agent_functions = { |
| 26 | + 'bpr+': run_bpr_plus, |
| 27 | + 'deep-bpr+': run_deep_bpr_plus, |
| 28 | + 'tom': run_tom, |
| 29 | + 'bpr-okr': run_bpr_okr, |
| 30 | + 'bsi': run_bsi, |
| 31 | + 'bsi-pt': run_bsi_pt, |
| 32 | + } |
| 33 | + |
| 34 | + if agent in agent_functions: |
| 35 | + agent_functions[agent](args) |
| 36 | + else: |
| 37 | + raise ValueError(f"Unsupported agent type: {agent}") |
54 | 38 |
|
55 | 39 |
|
56 | 40 | def positive_int(value: str) -> int:
|
|
0 commit comments