Skip to content

Commit

Permalink
misc uncomitted changes
Browse files Browse the repository at this point in the history
  • Loading branch information
giorgi-o committed May 3, 2024
1 parent d90486f commit a6c7fec
Show file tree
Hide file tree
Showing 2 changed files with 58 additions and 43 deletions.
94 changes: 56 additions & 38 deletions src/python/src/env.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,12 +27,13 @@ class EnvOpts:
render: bool


VARIANT = 1
AGENT_COUNT = 3
VARIANT = 3
AGENT_COUNT = 5
PASSENGER_COUNT = 30
SPAWN_MORE_PASSENGERS = False
VERBOSE_AND_RENDER = False
DETERMINISTIC = True
SPAWN_MORE_PASSENGERS = True
TOPLEFT2BOTRIGHT = None
VERBOSE_AND_RENDER = True
DETERMINISTIC = False


# None = random
Expand All @@ -47,6 +48,11 @@ def generate_grid_opts(
spawn_more_passengers = SPAWN_MORE_PASSENGERS
if spawn_more_passengers is None:
spawn_more_passengers = random.random() > 0.7

if TOPLEFT2BOTRIGHT is not None:
topleft_2_botright = TOPLEFT2BOTRIGHT
else:
topleft_2_botright = random.random() > 0.7

if spawn_more_passengers:
passenger_spawn_rate = 0.005
Expand All @@ -55,39 +61,42 @@ def generate_grid_opts(

charging_stations_pos = [
rust.CarPosition(
direction=rust.Direction.Up,
road_index=1,
section_index=1,
position_in_section=3,
direction=rust.Direction.Right,
road_index=4,
section_index=5,
position_in_section=2,
),
rust.CarPosition(
direction=rust.Direction.Down,
direction=rust.Direction.Right,
road_index=5,
section_index=3,
position_in_section=3,
section_index=7,
position_in_section=2,
),
rust.CarPosition(
direction=rust.Direction.Left,
road_index=7,
section_index=5,
position_in_section=3,
direction=rust.Direction.Down,
road_index=8,
section_index=6,
position_in_section=2,
),
rust.CarPosition(
direction=rust.Direction.Right,
road_index=9,
section_index=7,
position_in_section=3,
road_index=2,
section_index=9,
position_in_section=2,
),
]

passenger_events = [
rust.PassengerEvent(
start_area=(0, 0, 2, 2),
destination_area=(-2, -2, -0.0, -0.0),
spawn_rate=0.5,
between_ticks=(None, None),
)
]
if topleft_2_botright == True:
passenger_events = [
rust.PassengerEvent(
start_area=(0, 0, 4, 4),
destination_area=(-4, -4, -0.0, -0.0),
spawn_rate=0.0,
between_ticks=(None, None),
)
]
elif topleft_2_botright == False:
passenger_events = []

grid_opts_args = {}
match variant:
Expand All @@ -106,15 +115,14 @@ def generate_grid_opts(
passenger_spawn_rate=passenger_spawn_rate,
max_passengers=30,
agent_car_count=AGENT_COUNT,
npc_car_count=0,
npc_car_count=200,
# passengers_per_car=1,
charging_stations=charging_stations_pos,
charging_station_capacity=1,
# discharge_rate=0.002,
car_radius=3,
passenger_radius=5,
# passenger_events=passenger_events,
passenger_events=[],
passenger_events=passenger_events,
deterministic_mode=DETERMINISTIC,
verbose=VERBOSE_AND_RENDER,
**grid_opts_args,
Expand All @@ -140,7 +148,8 @@ def __init__(

self.car_passenger_slots = 4
self.num_envs = self.grid_opts.agent_car_count
self.width, self.height = rust.grid_dimensions()
# self.width, self.height = rust.grid_dimensions()
self.width, self.height = (15, 10)

self.TICKS_PER_EPISODE = 10000
self.MAX_DISTANCE = 100
Expand Down Expand Up @@ -177,14 +186,17 @@ def car_radius(self) -> int:
@property
def passengers_per_car(self) -> int:
return self.grid_opts.passengers_per_car
# return 4

@property
def charging_station_count(self) -> int:
return len(self.grid_opts.charging_stations)
# return len(self.grid_opts.charging_stations)
return 4

@property
def charging_station_capacity(self) -> int:
return self.grid_opts.charging_station_capacity
# return self.grid_opts.charging_station_capacity
return 1

def register_worker(self, worker: "GridEnvWorker"):
assert len(self.workers) < self.num_envs
Expand Down Expand Up @@ -602,7 +614,10 @@ def _parse_cars(self, cars) -> list[int | float]:

def _parse_charging_station(self, charging_station, pov_car) -> list[int | float]:
slots = [0, 0] * self.charging_station_capacity
for i, car in enumerate(charging_station.cars):
# for i, car in enumerate(charging_station.cars):
if len(charging_station.cars) == 4:
i = 0
car = charging_station.cars[0]
i = i * 2
slots[i] = 1
slots[i + 1] = max(car.battery, 0)
Expand All @@ -615,10 +630,13 @@ def _parse_charging_station(self, charging_station, pov_car) -> list[int | float
def _parse_charging_stations(self, state) -> list[int | float]:
# return [*self._parse_charging_station(cs, state.pov_car) for cs in state.charging_stations]
charging_stations = []
for cs in state.charging_stations:
for i, cs in enumerate(state.charging_stations):
parsed_cs = self._parse_charging_station(cs, state.pov_car)
charging_stations.extend(parsed_cs)

if i + 1 == self.charging_station_count:
break

return charging_stations

def parse_observation(self, state: GridState) -> np.ndarray:
Expand Down Expand Up @@ -726,16 +744,16 @@ def action_mask(self, state: GridState) -> np.ndarray:
offset += 4

if is_in_charging_station:
if current_battery_level < 1.0:
if current_battery_level < 0.98:
# force charge till completion
valid_actions[:] = False
valid_actions[charging_station_offset] = True
else:
if current_battery_level < 1.0:
if current_battery_level < 0.98:
# can go to charging station if it wants
valid_actions[charging_station_offset] = True

if current_battery_level < 1.0:
if current_battery_level < 0.98:
valid_actions[charging_station_offset] = True

assert offset == self.action_count
Expand Down
7 changes: 2 additions & 5 deletions src/python/src/rl.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,8 @@
from env import GridVecEnv
from util import LogStep

# note: "on" = overnight, 25 = date 25/04
# LOAD_POLICY = "rainbow_on_25"
# LOAD_POLICY = "rainbow_on_26"
LOAD_POLICY = "rainbow_on_27/checkpoint_139.pth"
# LOAD_POLICY = None

LOAD_POLICY = None
ZERO_EPSILON = True


Expand Down

0 comments on commit a6c7fec

Please sign in to comment.