-
Notifications
You must be signed in to change notification settings - Fork 6
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Distribute config typings closer to where they're used
- Loading branch information
1 parent
4ecf0af
commit 3919fea
Showing
13 changed files
with
398 additions
and
373 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,32 @@ | ||
"""Common config typings for agents.""" | ||
from dataclasses import dataclass | ||
from typing import Optional | ||
|
||
from .utils.replay_buffer import PriorityConfig | ||
|
||
|
||
@dataclass | ||
class ExperienceConfig: | ||
"""Config for experience collection.""" | ||
|
||
n_steps: int | ||
""" | ||
Number of lookahead steps for n-step returns, or zero to lookahead to the | ||
end of the episode (i.e. Monte Carlo returns). | ||
""" | ||
|
||
discount_factor: float | ||
"""Discount factor for future rewards.""" | ||
|
||
buffer_size: int | ||
"""Size of the replay buffer for storing experience.""" | ||
|
||
priority: Optional[PriorityConfig] = None | ||
"""Config for priority replay.""" | ||
|
||
@classmethod | ||
def from_dict(cls, config: dict): | ||
"""Creates an ExperienceConfig from a JSON dictionary.""" | ||
if config.get("priority", None) is not None: | ||
config["priority"] = PriorityConfig.from_dict(config["priority"]) | ||
return cls(**config) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,16 @@ | ||
"""Common config typings for agent utils.""" | ||
from dataclasses import dataclass | ||
|
||
|
||
@dataclass | ||
class AnnealConfig: | ||
"""Config for annealing a hyperparameter during training.""" | ||
|
||
start: float | ||
"""Starting value.""" | ||
|
||
end: float | ||
"""End value.""" | ||
|
||
steps: int | ||
"""Number of steps to linearly anneal from `start` to `end`.""" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.