Skip to content

Commit

Permalink
add true_buffer_state to the init of the BaseOptimizer
Browse files (browse the repository at this point in the history)
  • Loading branch information
lenarttreven committed Mar 7, 2024
1 parent 38fa295 commit cc407c7
Showing 1 changed file with 11 additions and 6 deletions.
17 changes: 11 additions & 6 deletions mbpo/optimizers/base_optimizer.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -21,19 +21,24 @@ def set_system(self, system: System):
self.system = system

@abstractmethod
def act(self, obs: chex.Array, opt_state: OptimizerState[RewardParams, DynamicsParams], evaluate: bool = True) -> \
Tuple[chex.Array, OptimizerState]:
def act(self,
obs: chex.Array,
opt_state: OptimizerState[RewardParams, DynamicsParams],
evaluate: bool = True) -> Tuple[chex.Array, OptimizerState]:
pass

def train(self, opt_state: OptimizerState[RewardParams, DynamicsParams]) \
-> OptimizerTrainingOutPut[RewardParams, DynamicsParams]:
def train(self,
opt_state: OptimizerState[RewardParams, DynamicsParams]) -> OptimizerTrainingOutPut[
RewardParams, DynamicsParams]:
return OptimizerTrainingOutPut(optimizer_state=opt_state)

def init(self,
key: chex.PRNGKey) -> OptimizerState:
key: chex.PRNGKey,
true_buffer_state: ReplayBufferState | None = None) -> OptimizerState:
pass

def dummy_true_buffer_state(self, key: chex.Array) -> ReplayBufferState:
def dummy_true_buffer_state(self,
key: chex.Array) -> ReplayBufferState:
assert self.system is not None, "Base optimizer requires system to be defined."
dummy_transition = Transition(
observation=jnp.zeros(self.system.x_dim),
Expand Down

0 comments on commit cc407c7

Please sign in to comment.