Commit
fixes for reward learning
Shivam Singhal authored and committed on Feb 27, 2024
1 parent 60d895e commit 8c3775b
Showing 4 changed files with 29 additions and 9 deletions.
occupancy_measures/envs/learned_reward_wrapper.py (6 additions, 0 deletions)
@@ -38,6 +38,7 @@ def __init__(self, config: LearnedRewardWrapperConfig):
 base_env_name = config["env"]
 base_env_config = config["env_config"]
 env_creator = _global_registry.get(ENV_CREATOR, base_env_name)
+self.base_env_config = config["env_config"]
 base_env = env_creator(base_env_config)
 if isinstance(base_env, MultiAgentEnv):
     self.base_env = base_env
@@ -135,6 +136,11 @@ def step(self, action_dict):
     SampleBatch.ACTIONS: action_tensor,
 }
 reward = self.reward_fn(input_dict)
+if (
+    "reward_scale" in self.base_env_config
+    and self.base_env_config["reward_scale"] is not None
+):
+    reward *= self.base_env_config["reward_scale"]
 base_reward = {id: reward.item()}
 for info_key in base_infos[id].keys():
     if "proxy" in info_key:
occupancy_measures/experiments/orpo_experiments.py (5 additions, 1 deletion)
@@ -22,6 +22,7 @@
 from ..agents.learned_reward_algorithm import LearnedRewardAlgorithm
 from ..agents.orpo import ORPO, ORPOPolicy
 from ..envs.learned_reward_wrapper import LearnedRewardWrapperConfig
+from ..models.glucose_models import normalize_obs
 from ..models.reward_model import RewardModelConfig
 from ..utils.os_utils import available_cpu_count
 from ..utils.training_utils import ( # convert_to_msgpack_checkpoint,
@@ -447,18 +448,21 @@ def restore_default_params(config=config, env_to_run=env_to_run):
 noise_prob = 0.0
 action_info_key = []
 rew_clip = 50

+obs_normalization_func = None
 if env_to_run == "tomato":
     config.env = "tomato_env_multiagent"
 elif env_to_run == "glucose":
     config.env = "glucose_env_multiagent"
+    obs_normalization_func = normalize_obs

 max_seq_len = 20
 reward_model_width = 32
 reward_model_depth = 2

 custom_model_config: RewardModelConfig = {
     "reward_model_depth": reward_model_depth,
     "reward_model_width": reward_model_width,
+    "normalize_obs": obs_normalization_func,
 }
 model_config = {
     "max_seq_len": max_seq_len,
occupancy_measures/models/glucose_models.py (9 additions, 7 deletions)
@@ -24,6 +24,12 @@ class GlucoseModelConfig(TypedDict, total=False):
     use_subcutaneous_glucose_obs: bool


+def normalize_obs(obs):
+    obs[..., 0] = (obs[..., 0] - 100) / 100
+    obs[..., 1] = obs[..., 1] * 10
+    return obs
+
+
 class GlucoseModel(ModelWithDiscriminator):
     def __init__(self, obs_space, action_space, num_outputs, model_config, name):
         ModelWithDiscriminator.__init__(
@@ -91,14 +97,10 @@ def __init__(self, obs_space, action_space, num_outputs, model_config, name):
 self.discriminator_submodules.append("lstm_discriminator")
 self.discriminator_submodules.append("fc_discriminator")

-def normalize_obs(self, obs):
-    obs[..., 0] = (obs[..., 0] - 100) / 100
-    obs[..., 1] = obs[..., 1] * 10
-
 def forward(self, input_dict, state, seq_lens):
     obs = input_dict[SampleBatch.OBS].permute(0, 2, 1).clone()
     # Normalize observations
-    self.normalize_obs(obs)
+    normalize_obs(obs)
     lstm_out, _ = self.lstm(obs)
     self._backbone_out = lstm_out[:, -1, :]
     # self._backbone_out = self.backbone(obs[:, :, :].mean(axis=1))
@@ -142,7 +144,7 @@ def discriminator(
 normalized_input_dict[SampleBatch.OBS] = (
     input_dict[SampleBatch.OBS].clone().permute(0, 2, 1)
 )
-self.normalize_obs(normalized_input_dict[SampleBatch.OBS])
+normalize_obs(normalized_input_dict[SampleBatch.OBS])
 normalized_input_dict[SampleBatch.OBS] = normalized_input_dict[
     SampleBatch.OBS
 ].permute(0, 2, 1)
@@ -152,7 +154,7 @@
 obs = input_dict[SampleBatch.OBS].clone()
 indices = torch.tensor(tuple(range(*self.history_range)))
 obs = obs[:, :, indices].permute(0, 2, 1)
-self.normalize_obs(obs)
+normalize_obs(obs)
 if self.use_cgm:
     obs = obs[:, :, :1]
 obs_embed, _ = self.lstm_discriminator(obs)
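
Taken together, these hunks promote normalize_obs from a GlucoseModel method to a module-level helper so the same in-place normalization can be reused outside the model class (for example, by the reward model). It shifts and scales what appears to be the glucose channel and rescales the second channel, mutating the tensor and returning it. A small usage sketch with a made-up observation tensor:

import torch

def normalize_obs(obs):  # copied here so the sketch is self-contained
    obs[..., 0] = (obs[..., 0] - 100) / 100
    obs[..., 1] = obs[..., 1] * 10
    return obs

obs = torch.tensor([[120.0, 0.02], [80.0, 0.0]])
normalize_obs(obs)  # edits obs in place and also returns it
print(obs)  # tensor([[ 0.2000,  0.2000], [-0.2000,  0.0000]])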
occupancy_measures/models/reward_model.py (9 additions, 1 deletion)
@@ -1,4 +1,4 @@
-from typing import List, TypedDict
+from typing import Callable, List, Optional, TypedDict

 import torch
 import torch.nn.functional as F
@@ -15,6 +15,7 @@
 class RewardModelConfig(TypedDict, total=False):
     reward_model_width: int
     reward_model_depth: int
+    normalize_obs: Optional[Callable]


 class RewardModel(TorchModelV2, nn.Module):
@@ -29,6 +30,11 @@ def __init__(self, obs_space, action_space, num_outputs, model_config, name):
 )
 self.rew_model_width = custom_model_config.get("reward_model_width", 256)
 self.rew_model_depth = custom_model_config.get("reward_model_depth", 2)
+self.normalize_obs = custom_model_config.get("normalize_obs", None)
+if self.normalize_obs is not None:
+    assert callable(
+        self.normalize_obs
+    ), "Must specify a function for normalizing the observations in-place"

 rew_in_dim = utils.flatdim(action_space) + utils.flatdim(obs_space)

@@ -54,6 +60,8 @@ def value_function(self):

 def learned_reward(self, input_dict):
     obs = input_dict[SampleBatch.OBS].flatten(1)
+    if self.normalize_obs is not None:
+        self.normalize_obs(obs)
     actions = input_dict[SampleBatch.ACTIONS]
     net_input = self._get_concatenated_obs_action(obs, actions)
     predicted_rew = self.fc_reward_net(net_input)
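
With these additions, RewardModel accepts an optional normalize_obs callable through its custom model config and applies it in place to the flattened observations before they are concatenated with the actions. A stripped-down sketch of the same optional-callable pattern (the tiny linear reward net and random inputs are placeholders, not the repository's architecture):

import torch
import torch.nn as nn

class TinyRewardModel(nn.Module):
    def __init__(self, obs_dim, act_dim, normalize_obs=None):
        super().__init__()
        if normalize_obs is not None:
            assert callable(normalize_obs), "normalize_obs must be a callable that edits obs in place"
        self.normalize_obs = normalize_obs
        self.fc_reward_net = nn.Linear(obs_dim + act_dim, 1)

    def learned_reward(self, obs, actions):
        obs = obs.flatten(1)
        if self.normalize_obs is not None:
            self.normalize_obs(obs)  # optional in-place normalization, as in the diff
        return self.fc_reward_net(torch.cat([obs, actions], dim=1))

model = TinyRewardModel(obs_dim=4, act_dim=2, normalize_obs=None)
print(model.learned_reward(torch.randn(3, 4), torch.randn(3, 2)).shape)  # torch.Size([3, 1])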