
Commit 7a7bc8c

Merge branch 'pl' of https://github.com/epfl-dlab/PauseToken into pl
2 parents: 5511be6 + e30c14f

8 files changed: +184 additions, -95 deletions


configs/experiment/train/star_on_policy_pause.yaml

Lines changed: 1 addition & 1 deletion
```diff
@@ -23,7 +23,7 @@ data:
 
 
 trainer:
-  inner_loop_timesteps: 10
+  inner_loop_timesteps: 3
   n_outer_loops: 5
   progress_bar: false
   num_val_samples: 10
```

lm_stable_baselines/buffers/lm_rollout_buffer.py

Lines changed: 35 additions & 31 deletions
```diff
@@ -22,6 +22,7 @@ def __init__(
         filler_token = -100,
         **kwargs
     ):
+        self.filler_token = filler_token
         super().__init__(*args, **kwargs)
         self.set_filler_token(filler_token)
         self.tokenizer = tokenizer
@@ -34,8 +35,8 @@ def set_tokenizer(self, tokenizer):
 
     def reset(self) -> None:
         super().reset()
-        self.observations = np.zeros((self.buffer_size, self.n_envs, *self.obs_shape), dtype=np.long)
-        self.actions = np.zeros((self.buffer_size, self.n_envs, self.action_dim), dtype=np.long)
+        self.observations = np.zeros((self.buffer_size, self.n_envs, *self.obs_shape), dtype=np.long) + self.filler_token
+        self.actions = np.zeros((self.buffer_size, self.n_envs, self.action_dim), dtype=np.long) + self.filler_token
         self.above_threshold_indices = None
         self.data_size = 0
 
@@ -55,42 +56,45 @@ def to_torch(self, array: Union[np.ndarray, torch.Tensor, transformers.BatchEnco
 
 
     def find_where_advantage_exceeds_threshold(self, advantage: np.ndarray) -> None:
+        if self.advantage_threshold is None:
+            self.advantage_threshold = -np.inf
         self.above_threshold_indices = np.where(advantage > self.advantage_threshold)
+        self.remaining_indices = None
         self.data_size = len(self.above_threshold_indices[0])
-
-
+
     def sample_batch(self, batch_size, env: Optional[VecNormalize] = None) -> RolloutBufferSamples:
-        # Get the positions of the allowed indices (where the matrix is 1)
-        allowed_indices = self.above_threshold_indices if self.above_threshold_indices is not None else np.arange(self.buffer_size)
-
-        # Sample randomly from the allowed indices
-        idx = np.random.choice(len(allowed_indices), size=batch_size, replace=True)
-        sampled_positions = (allowed_indices[0][idx], allowed_indices[1][idx])
-
-        obs = self.observations[sampled_positions]
-
-        obs = self.tokenizer(
-            self.tokenizer.batch_decode(
-                remove_filler_tokens(obs[..., 1:].long(), self.filler_token)  # remove the first token (the bos token, tokenizer will re-add it)
-            ),
-            return_tensors="pt", padding=True, truncation=True
-        )
-
-        actions = self.tokenizer(
-            self.tokenizer.batch_decode(
-                remove_filler_tokens(self.actions[sampled_positions], self.filler_token)  # don't remove the first token (since it's an action, it didn't start with a bos token)
-            ),
-            return_tensors="pt", padding=True, truncation=True
-        )["input_ids"][..., 1:]  # remove the first token (the bos token, actions should not have it)
+        # Initialize remaining indices if it's the first pass or if we've exhausted the dataset
+        if self.remaining_indices is None or len(self.remaining_indices[0]) == 0:
+            allowed_indices = self.above_threshold_indices if self.above_threshold_indices is not None else np.arange(self.buffer_size)
+            # Shuffle the allowed indices
+            shuffled_indices = np.random.permutation(np.arange(len(allowed_indices[0])))
+            # Store shuffled indices for further sampling
+            self.remaining_indices = (allowed_indices[0][shuffled_indices], allowed_indices[1][shuffled_indices])
+
+        # Sample from the remaining indices without replacement
+        num_remaining = len(self.remaining_indices[0])
+        num_to_sample = min(batch_size, num_remaining)
+
+        idx = np.arange(num_remaining)[:num_to_sample]
+        sampled_positions = (self.remaining_indices[0][idx], self.remaining_indices[1][idx])
+
+        # Remove the sampled positions from remaining indices
+        self.remaining_indices = (
+            np.delete(self.remaining_indices[0], idx),
+            np.delete(self.remaining_indices[1], idx)
+        )
+
+        return self.sample_indices(sampled_positions)
+
+    def sample_indices(self, idx, padding='right') -> RolloutBufferSamples:
+        assert idx[0].shape == idx[1].shape, "The indices must have the same shape"
         data = (
-            self.observations[sampled_positions],
-            self.actions[sampled_positions],
-            self.values[sampled_positions].flatten(),
-            self.log_probs[sampled_positions].flatten(),
-            self.advantages[sampled_positions].flatten(),
-            self.returns[sampled_positions].flatten(),
+            self.observations[idx],
+            self.actions[idx],
+            self.values[idx].flatten(),
+            self.log_probs[idx].flatten(),
+            self.advantages[idx].flatten(),
+            self.returns[idx].flatten(),
        )
 
         return RolloutBufferSamples(*tuple(map(self.to_torch, data)))
-
```
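
For readers skimming the diff: sample_batch now consumes the eligible (buffer, env) positions without replacement, reshuffling only once the pool is exhausted. Below is a minimal standalone sketch of that pattern; the names (WithoutReplacementSampler, next_batch) are hypothetical and not part of the repo.

```python
import numpy as np

class WithoutReplacementSampler:
    """Shuffle eligible positions once, then consume them batch by batch."""

    def __init__(self, eligible_rows: np.ndarray, eligible_cols: np.ndarray):
        # One random permutation up front; later calls just walk through it
        order = np.random.permutation(len(eligible_rows))
        self.remaining = (eligible_rows[order], eligible_cols[order])

    def next_batch(self, batch_size: int):
        # Take at most batch_size positions from the front of the shuffled pool
        n = min(batch_size, len(self.remaining[0]))
        rows, cols = self.remaining[0][:n], self.remaining[1][:n]
        # Drop the consumed positions so later calls never repeat them
        self.remaining = (self.remaining[0][n:], self.remaining[1][n:])
        return rows, cols

# Usage: positions where the advantage exceeds a threshold of 0.0
adv = np.random.randn(8, 2)
rows, cols = np.where(adv > 0.0)
sampler = WithoutReplacementSampler(rows, cols)
batch = sampler.next_batch(4)  # distinct positions until the pool is exhausted
```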

lm_stable_baselines/environments/language_model_env.py

Lines changed: 28 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@
77
import numpy as np
88
from lm_stable_baselines.utils import remove_filler_tokens
99
import warnings
10+
import torch
11+
from torch import LongTensor
1012
class LanguageModelEnv(Env):
1113
""" Environment for language models. This class is a subclass of gym.Env and is used to handle language model environments.
1214
This environment allows to sample from a dataset and compute rewards based on the model output and the ground truth.
@@ -75,12 +77,22 @@ def reprermute_dataset_id_list(cls):
7577
# TODO: check if this is necessary
7678
# NICKY:
7779
# I don't think we nee this. We want dataset_id_list to be a static variable that is shared across all instances of the class
78-
#   We know which sample to take thanks to LanguageModelEnv.next_idx
80+
# We know which sample to take thanks to LanguageModelEnv.next_idx
7981
# if LanguageModelEnv.n_envs != -1:
8082
# self.dataset_id_list = LanguageModelEnv.dataset_id_list[self.env_idx::LanguageModelEnv.n_envs]
8183
# else:
8284
# self.dataset_id_list = LanguageModelEnv.dataset_id_list
8385

86+
def _step(self, curr_obs, action):
87+
if isinstance(curr_obs, list):
88+
curr_obs.extend(action)
89+
elif isinstance(curr_obs, torch.Tensor):
90+
curr_obs = torch.cat([curr_obs, action], dim = 0)
91+
else:
92+
raise ValueError("curr_obs should be a list or a tensor")
93+
return curr_obs
94+
95+
8496
def step(self, action: np.ndarray) -> Tuple[np.ndarray, float, bool, bool, Dict[str, Any]]:
8597
""" Apply an action to the environment. For a language model it's simply adding the action to the current state
8698
@@ -91,11 +103,21 @@ def step(self, action: np.ndarray) -> Tuple[np.ndarray, float, bool, bool, Dict[
91103
"""
92104

93105
clean_action = remove_filler_tokens(action, self.filler_token).squeeze(-1).tolist()
94-
self.current_state.extend(clean_action)
106+
self.current_state = self._step(self.current_state, clean_action)
95107
observation , reward, terminated, truncated, info = self._get_obs()
96108

97109
return observation, reward, terminated, truncated, info
110+
111+
def next_observation_from_observation_and_action(self, obs: LongTensor, actions: LongTensor) -> List[List[int]]:
112+
#assumption: filler tokens have been removed
113+
unpadded_obs = remove_filler_tokens(obs, self.filler_token)
114+
unpadded_acts = remove_filler_tokens(actions, self.filler_token)
115+
116+
new_observations = [self._step(observation,action) for observation, action in zip(unpadded_obs,unpadded_acts)]
98117

118+
return new_observations
119+
120+
99121
def is_terminated(self, state: List[int]):
100122
""" Check if the state is terminated
101123
@@ -216,6 +238,8 @@ def _get_obs(self):
216238
info = {}
217239
return np.array(self.current_state) , reward, is_terminated, is_truncated, info
218240

241+
242+
219243
def render(self):
220244
""" Render the current state
221245
@@ -230,4 +254,5 @@ def close(self):
230254
This is critical for closing rendering windows, database or HTTP connections.
231255
Calling ``close`` on an already closed environment has no effect and won't raise an error.
232256
"""
233-
pass
257+
pass
258+
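
The new _step helper lets step() and next_observation_from_observation_and_action share one append operation whether the observation is a Python list of token ids or a torch.Tensor. A small illustration of the same dispatch, under a hypothetical name (append_action is not from the repo):

```python
import torch

def append_action(curr_obs, action):
    # List observations are extended in place; tensor observations are concatenated
    if isinstance(curr_obs, list):
        curr_obs.extend(action)
        return curr_obs
    if isinstance(curr_obs, torch.Tensor):
        return torch.cat([curr_obs, action], dim=0)
    raise ValueError("curr_obs should be a list or a tensor")

print(append_action([1, 2, 3], [4, 5]))                               # [1, 2, 3, 4, 5]
print(append_action(torch.tensor([1, 2, 3]), torch.tensor([4, 5])))   # tensor([1, 2, 3, 4, 5])
```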

lm_stable_baselines/environments/vectorized_environments/lm_dummy_vec_enc.py

Lines changed: 2 additions & 0 deletions
```diff
@@ -49,3 +49,5 @@ def set_stage(self, stage: str, **kwargs):
         """
         for env in self.envs:
             env.set_stage(stage, **kwargs)
+
+
```

lm_stable_baselines/policies/llm_base_policy.py

Lines changed: 12 additions & 11 deletions
```diff
@@ -186,29 +186,30 @@ def forward(self, obs: PyTorchObs, labels = None) -> torch.Tensor:
         # feature = self.extract_features(obs)
         # feature = {k: v.to(self.device) for k, v in feature.items()}
         # feature["labels"] = labels.to(self.device) if labels is not None else None
-        actions = self._predict(obs)
-        padded_actions = actions.clone()
-        padded_actions[actions == self.filler_token] = self.tokenizer.pad_token_id
-        logprobs = torch.log_softmax(self.lm(padded_actions).logits, dim = -1)
-        mask = (actions != self.filler_token).float()
-        logprob_actions = torch.gather(logprobs, 2, padded_actions.unsqueeze(-1)).squeeze(-1)
+
+        next_obs, actions, unpadded_actions = self._predict(obs).values()
+
+        logprobs = torch.log_softmax(self.lm(next_obs).logits, dim = -1)[:, (-unpadded_actions.shape[1]-1):-1, ...]
+        mask = (unpadded_actions != self.tokenizer.pad_token_id).float()
+        logprob_actions = torch.gather(logprobs, 2, unpadded_actions.unsqueeze(-1)).squeeze(-1)
         logprobs = (logprob_actions * mask).sum(dim = 1)
-        values = self.predict_values(obs)
+        values = self.predict_values(next_obs)
         return actions, values, logprobs
 
     def predict_values(self, obs: PyTorchObs) -> torch.Tensor:
         raise NotImplementedError
 
     def post_predict(self, inputs: torch.Tensor, outputs: torch.Tensor) -> torch.Tensor:
         # remove the input tokens from the output
-        actions = outputs[:, inputs.shape[-1]:]
+        actions = outputs[:, inputs.shape[-1]:].clone()
+        padded_actions = actions.clone()
         # replace all pad tokens with filler tokens
-        actions[actions == self.tokenizer.pad_token_id] = self.filler_token
+        padded_actions[actions == self.tokenizer.pad_token_id] = self.filler_token
 
         action_space_dim = self.action_space.shape[0]
-        actions = add_filler_tokens(actions, action_space_dim, self.filler_token)
+        padded_actions = add_filler_tokens(padded_actions, action_space_dim, self.filler_token)
 
-        return actions
+        return {'next_observation': outputs, 'actions': padded_actions, 'unpadded_actions': actions}
 
     def pre_predict(self, feature: PyTorchObs) -> PyTorchObs:
         pass
```
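
The rewritten forward scores only the generated span: for T action tokens appended to the prompt, the logits sliced at [:, -T-1:-1, :] are exactly the positions that predict those T tokens, and a pad mask zeroes out padding before summing per-sequence log-probabilities. A toy illustration of the slice-and-gather, with assumed shapes and pad id (random logits stand in for the model's output):

```python
import torch

batch, prompt_len, act_len, vocab = 2, 5, 3, 11
pad_id = 0  # assumed pad token id

next_obs = torch.randint(1, vocab, (batch, prompt_len + act_len))  # prompt + generated tokens
actions = next_obs[:, prompt_len:]                                 # the generated span

logits = torch.randn(batch, prompt_len + act_len, vocab)           # stand-in for lm(next_obs).logits
# positions -T-1 .. -2 are the ones whose next-token distributions cover the T action tokens
logprobs = torch.log_softmax(logits, dim=-1)[:, -act_len - 1:-1, :]

token_logprobs = torch.gather(logprobs, 2, actions.unsqueeze(-1)).squeeze(-1)
mask = (actions != pad_id).float()                # ignore padded action positions
seq_logprob = (token_logprobs * mask).sum(dim=1)  # one log-prob per sequence
print(seq_logprob.shape)                          # torch.Size([2])
```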

lm_stable_baselines/training_algorithms/star_on_policy.py

Lines changed: 20 additions & 6 deletions
```diff
@@ -52,6 +52,15 @@ def predict_values(obs: PyTorchObs) -> torch.Tensor:
         # return -1 for all values
         return torch.ones(obs.shape[0]) * 0
 
+    def process_rollouts(self, data):
+        next_obs = self.env.envs[0].next_observation_from_observation_and_action(data.observations[:,1:], data.actions)
+        # create the next observation by interacting with the environment and then tokenizing to get input_ids + attention mask
+        next_observation = self.policy.tokenizer.pad(
+            {'input_ids': next_obs},
+            return_tensors="pt",
+            padding=True,
+        )
+        return next_observation
 
     def train(self) -> None:
         self.policy.train()
@@ -61,27 +70,32 @@ def train(self) -> None:
 
         self.rollout_buffer.find_where_advantage_exceeds_threshold(self.rollout_buffer.advantages)
         n_batches = self.rollout_buffer.data_size // self.batch_size
-
+
+        self.policy.tokenizer.padding_side = "right"
         for _ in range(n_batches):
 
             self._n_updates += 1
-
             data = self.rollout_buffer.sample_batch(self.batch_size, env=self._vec_normalize_env)
-
+            next_observation = self.process_rollouts(data)
             if self.loss_computed_in_forward_pass:
-                labels = data.next_observations["input_ids"]
+                labels = next_observation["input_ids"]
                 labels_list = list(labels.cpu())
                 collated_labels = self.data_collator(labels_list)
                 labels = collated_labels["labels"]  # check with self.policy.tokenizer.decode(labels[0][labels[0]>0])
             else:
                 labels = None
 
-            output = self.policy(data.next_observations, labels=labels)
+            output = self.policy.lm(input_ids=next_observation['input_ids'].to(self.device),
+                                    attention_mask=next_observation['attention_mask'].to(self.device),
+                                    labels=labels.to(self.device))
 
             if self.loss_computed_in_forward_pass:
                 nll_loss = output.loss
+                # if control token model you can also get these losses:
+                # control_token_loss = output.ctrl_tok_loss
+                # lm_loss = output.lm_loss
             else:
-                nll_loss = self.policy.compute_nll_loss(output.logits, data.next_observations)
+                nll_loss = self.policy.compute_nll_loss(output.logits, labels)
 
             nll_losses.append(nll_loss.item())
```
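
process_rollouts rebuilds next observations as variable-length token id lists and leans on tokenizer.pad to batch them into input_ids plus an attention_mask; train() sets padding_side = "right" beforehand so labels line up for the causal LM loss. A minimal sketch of that padding step (gpt2 is only an example tokenizer, and the token ids are made up):

```python
from transformers import AutoTokenizer

# any tokenizer with a pad token works; gpt2 is just an example and needs one assigned
tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "right"

# variable-length token id sequences, e.g. observation + action per sample
next_obs = [[464, 2068, 7586], [464, 2068]]

batch = tokenizer.pad({"input_ids": next_obs}, return_tensors="pt", padding=True)
print(batch["input_ids"].shape)  # torch.Size([2, 3])
print(batch["attention_mask"])   # zeros where right-padding was added
```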
