Skip to content

Commit f259f29

Browse files
committed
refactor: remove repetitive imports in run.py
1 parent 6e4c82b commit f259f29

File tree

1 file changed

+32
-48
lines changed

1 file changed

+32
-48
lines changed

src/run.py

Lines changed: 32 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -1,56 +1,40 @@
11
import argparse
2+
import importlib
23

34

45
def run(args: argparse.Namespace) -> None:
5-
if args.scenario == 'grid':
6-
from grid_world.run import (
7-
run_bpr_okr,
8-
run_bpr_plus,
9-
run_bsi,
10-
run_bsi_pt,
11-
run_deep_bpr_plus,
12-
run_tom,
13-
)
14-
elif args.scenario == 'nav':
15-
from navigation_game.run import (
16-
run_bpr_okr,
17-
run_bpr_plus,
18-
run_bsi,
19-
run_bsi_pt,
20-
run_deep_bpr_plus,
21-
run_tom,
22-
)
23-
elif args.scenario == 'soccer':
24-
from soccer_game.run import (
25-
run_bpr_okr,
26-
run_bpr_plus,
27-
run_bsi,
28-
run_bsi_pt,
29-
run_deep_bpr_plus,
30-
run_tom,
31-
)
32-
elif args.scenario == 'baseball':
33-
from baseball_game.run import (
34-
run_bpr_okr,
35-
run_bpr_plus,
36-
run_bsi,
37-
run_bsi_pt,
38-
run_deep_bpr_plus,
39-
run_tom,
40-
)
6+
scenario_modules = {
7+
'grid': 'grid_world.run',
8+
'nav': 'navigation_game.run',
9+
'soccer': 'soccer_game.run',
10+
'baseball': 'baseball_game.run',
11+
}
4112

42-
if args.agent == 'bpr+':
43-
run_bpr_plus(args)
44-
elif args.agent == 'deep-bpr+':
45-
run_deep_bpr_plus(args)
46-
elif args.agent == 'tom':
47-
run_tom(args)
48-
elif args.agent == 'bpr-okr':
49-
run_bpr_okr(args)
50-
elif args.agent == 'bsi':
51-
run_bsi(args)
52-
elif args.agent == 'bsi-pt':
53-
run_bsi_pt(args)
13+
scenario = args.scenario
14+
if scenario in scenario_modules:
15+
run_module = importlib.import_module(scenario_modules[scenario])
16+
run_bpr_okr = run_module.run_bpr_okr
17+
run_bpr_plus = run_module.run_bpr_plus
18+
run_bsi = run_module.run_bsi
19+
run_bsi_pt = run_module.run_bsi_pt
20+
run_deep_bpr_plus = run_module.run_deep_bpr_plus
21+
run_tom = run_module.run_tom
22+
else:
23+
raise ValueError(f"Unsupported scenario: {scenario}")
24+
25+
agent_functions = {
26+
'bpr+': run_bpr_plus,
27+
'deep-bpr+': run_deep_bpr_plus,
28+
'tom': run_tom,
29+
'bpr-okr': run_bpr_okr,
30+
'bsi': run_bsi,
31+
'bsi-pt': run_bsi_pt,
32+
}
33+
34+
if agent in agent_functions:
35+
agent_functions[agent](args)
36+
else:
37+
raise ValueError(f"Unsupported agent type: {agent}")
5438

5539

5640
def positive_int(value: str) -> int:

0 commit comments

Comments
 (0)