Complete Type Safety: Eliminating All Pyright Errors #245
Conversation
- Improve type hints in stack_states method
- Add robust handling of log rewards in stacking
- Remove pyright ignore comments
- Use cast for type safety in DiscreteStates
- Improve type conversion and error handling
- Implement type-safe method to generate batch of initial states
- Ensure return type is DiscreteStates with an assertion
- Extends base class method with discrete environment specifics
…tioning

- Implement a generic container for storing and manipulating pairs of states
- Support optional conditioning tensors for intermediary and terminating states
- Provide methods for extending, indexing, and accessing state pairs
- Designed to support flow matching and other algorithms requiring state pair processing
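Based on this description and the `__init__` diff quoted later in the thread, a minimal sketch of the container's shape (conditioning concatenation is omitted from `extend`; details here are illustrative, not the actual PR code):

```python
from typing import Generic, Optional, TypeVar

import torch

from gfn.states import States

StateType = TypeVar("StateType", bound=States)


class StatePairs(Generic[StateType]):
    """Pairs of (intermediary, terminating) states with optional conditioning."""

    def __init__(
        self,
        intermediary_states: StateType,
        terminating_states: StateType,
        intermediary_conditioning: Optional[torch.Tensor] = None,
        terminating_conditioning: Optional[torch.Tensor] = None,
    ) -> None:
        self.intermediary_states = intermediary_states
        self.terminating_states = terminating_states
        self.intermediary_conditioning = intermediary_conditioning
        self.terminating_conditioning = terminating_conditioning

    def extend(self, other: "StatePairs[StateType]") -> None:
        # Concatenate another container's pairs onto this one; States.extend
        # handles the batch-dimension concatenation.
        self.intermediary_states.extend(other.intermediary_states)
        self.terminating_states.extend(other.terminating_states)
```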
- Implement a helper method to compute loss directly from trajectories
- Support different GFlowNet types by handling training sample conversion
- Provide a flexible way to compute loss with optional recalculation of log probabilities
- Enhance loss computation workflow for various GFlowNet implementations
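In outline, such a helper reduces to the following sketch (not the exact signature; the real method must also handle loss methods that do not accept `recalculate_all_logprobs`):

```python
class GFlowNet:
    # ... existing methods (to_training_samples, loss) elided ...

    def loss_from_trajectories(self, env, trajectories, recalculate_all_logprobs=False):
        """Compute the loss directly from sampled trajectories (sketch)."""
        # Each subclass converts trajectories into its own sample type:
        # full trajectories (TB), transitions (DB), or state pairs (FM).
        training_samples = self.to_training_samples(trajectories)
        return self.loss(
            env, training_samples, recalculate_all_logprobs=recalculate_all_logprobs
        )
```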
- Add import for StatePairs from state_pairs module
- Extend container module to include the new StatePairs class
- Modify EnumPreprocessor and OneHotPreprocessor to use DiscreteStates type hint
- Update type annotations for get_states_indices and preprocess methods
- Improve type safety for discrete state preprocessing
- Update `to_non_initial_intermediary_and_terminating_states` method to return a StatePairs instance
- Improve type safety by asserting DiscreteStates type for intermediary and terminating states
- Enhance method documentation to clarify its purpose and usage
- Simplify state pair generation with direct StatePairs constructor
- Update FMGFlowNet to use StatePairs instead of tuple for state handling
- Modify loss method to work with StatePairs container
- Simplify type annotations and state processing logic
- Improve type safety by asserting DiscreteStates types
- Update to_training_samples method to return StatePairs
- Update expected_output_dim methods to use @property decorator
- Remove pyright ignore comments in various modules
- Improve type safety and code clarity in samplers, modules, and utility functions
- Simplify state and action processing in sampling and training methods
- Update type hints in discrete environment and estimator classes
- Implement generic ReplayBuffer with type-safe container handling
- Remove objects_type parameter and use type inference
- Simplify initialization and sampling methods
- Add support for dynamic buffer type detection
- Improve type hints and remove pyright ignore comments
- Update test cases to work with new generic buffer implementation
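A minimal sketch of the inferred-type buffer this commit describes (capacity eviction and edge cases omitted; the import path follows the `containers/base.py` reference later in this thread, and `Container.sample` is quoted there too):

```python
from typing import Generic, Optional, TypeVar

from gfn.containers.base import Container

ContainerType = TypeVar("ContainerType", bound=Container)


class ReplayBuffer(Generic[ContainerType]):
    """Sketch: a buffer whose container type is inferred from the first add()."""

    def __init__(self, capacity: int) -> None:
        self.capacity = capacity  # used for eviction, omitted here for brevity
        self.training_objects: Optional[ContainerType] = None

    def add(self, training_objects: ContainerType) -> None:
        if self.training_objects is None:
            # The first insertion fixes the buffer's container type; no
            # objects_type argument is needed at construction time.
            self.training_objects = training_objects
        else:
            self.training_objects.extend(training_objects)

    def sample(self, n_samples: int) -> ContainerType:
        assert self.training_objects is not None, "The buffer is empty."
        return self.training_objects.sample(n_samples)
```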
```diff
     states = env._step(states, actions)

     # Step 4 fails due to an invalid input action.
     actions = env.actions_from_tensor(format_tensor(failing_actions_list))
     with pytest.raises(NonValidActionsError):
-        states = env._step(states, actions)  # pyright: ignore
+        states = env._step(states, actions)
```
What about exchanging the names of `env.step` and `env._step`? Currently, `env._step` is the one that needs to be called externally (as seen in `samplers.py`), which seems unusual for something with a private-style naming convention. The same goes for `env._backward_step`.
Last year, we used to have `maskless_step` and `step`. We then changed them to `step` and `_step`, respectively. A user that defines their environment needs to define `step` only, and `_step` handles the masking for them.

When I don't work on the codebase for 1-2 months and go back to it, I agree that `_step` is confusing. What do you think of changing it to `safe_step`? (Obviously, in a new environment, the user would still need to write `step` only.)

@josephdviviano, your opinion here would be appreciated too. Thanks
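To make the convention concrete, the division of labor reads roughly like this (a sketch; `is_action_valid` is an illustrative name, not necessarily the library's actual API):

```python
class NonValidActionsError(Exception):
    """Raised when actions are invalid in their source states (as in the tests above)."""


class Env:
    def step(self, states, actions):
        """User-defined: the raw transition dynamics, with no validity checks."""
        raise NotImplementedError

    def _step(self, states, actions):
        """Library-side wrapper: validate actions against masks, then delegate to step()."""
        if not self.is_action_valid(states, actions):
            raise NonValidActionsError("Some actions are not valid in their states.")
        return self.step(states, actions)

    def is_action_valid(self, states, actions) -> bool:
        # Hypothetical helper: check actions against the states' forward masks.
        raise NotImplementedError
```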
I don't really have a strong preference on this; `safe_step` seems to be fine also.
I don't love the name `safe_step`, which implies the existence of `unsafe_step`.

I understand the use of `env._step` to be correct in this case (how it is called by the `Sampler`) - of course this is all subjective, but I'm comfortable with the current naming --

Let me know what you think:
https://claude.ai/share/8e7a4b6a-7347-4b8e-b064-2f510c2a6d3e
One option might be to call this method `env._base_step` -- but I think we should keep the `_`, which denotes to the user of the library "you shouldn't call this method unless you really know what you're doing".
Change `args.replay_capacity` to `args.replay_buffer_size` to align with parameter naming convention
Improve documentation for the `__getitem__` method in StatePairs to clarify batch dimension indexing, and note potential differences in intermediary and terminating states batch shapes
Thank you for your review @hyeok9855. I have addressed all your points, and left a question.
…which to calculate PB. This fixes that
Extend test coverage for hypergrid training by introducing a new parametrized test that checks different loss functions and replay buffer sizes. Also add new configuration options to HypergridArgs and CommonArgs classes to support these variations.
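A sketch of what such a parametrized test could look like (the loss names, buffer sizes, and helper function are illustrative stand-ins, not the actual test code):

```python
import math

import pytest


def run_short_training(loss_name: str, replay_buffer_size: int) -> float:
    # Stand-in for a short training run of the hypergrid example script,
    # configured with the given loss and replay buffer size.
    return 0.0


@pytest.mark.parametrize("loss_name", ["FM", "TB", "DB", "SubTB"])
@pytest.mark.parametrize("replay_buffer_size", [0, 100])
def test_hypergrid_losses_and_replay_buffer(loss_name: str, replay_buffer_size: int):
    final_loss = run_short_training(loss_name, replay_buffer_size)
    assert math.isfinite(final_loss)
```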
…hods

Refactor the States class to remove the _log_rewards attribute and associated methods. Update StatePairs and related classes to handle log rewards more explicitly, including modifications to initialization, concatenation, and indexing methods.
Modify the `initialize` method to simplify type checking and initialization of training objects. Move the initialization logic outside of the condition and ensure the buffer is only initialized when no training objects exist.
I approved with one concern; see below.
I really appreciate this PR, @saleml !!
```diff
@@ -60,7 +60,7 @@ def __init__(
         self.terminating_conditioning = terminating_conditioning

     def __len__(self) -> int:
-        return len(self.intermediary_states) + len(self.terminating_states)
+        return min(len(self.intermediary_states), len(self.terminating_states))
```
This approach feels a bit like a workaround. Using the sum of the two lengths seems more reasonable to me.

It looks like we might need to refactor this a bit more. For example, instead of storing `intermediary_states` and `terminating_states` in separate variables, we could combine them into one and use two sets of indices to track whether a state is intermediate or terminating.

If you agree, I'll go ahead and create an issue for this. Let me know if you have any better suggestions!
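Hypothetically, the combined-storage version could look like this (all names invented for illustration):

```python
import torch


class CombinedStatePairs:
    """Hypothetical refactor: one states container plus a terminating mask."""

    def __init__(self, states, is_terminating: torch.Tensor) -> None:
        self.states = states
        self.is_terminating = is_terminating  # bool tensor of shape (len(states),)

    def __len__(self) -> int:
        # Unambiguous: one batch, one length.
        return len(self.states)

    @property
    def intermediary_states(self):
        return self.states[~self.is_terminating]

    @property
    def terminating_states(self):
        return self.states[self.is_terminating]
```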
I agree it's a bit hacky. This is only because we have to have a `len` function in

```python
def sample(self, n_samples: int) -> Container:
    """Samples a subset of the container."""
    return self[torch.randperm(len(self))[:n_samples]]
```

in `containers/base.py`.
Please go ahead and raise the issue. Thanks for your review.
Partial review of the core features. I didn't check everything. Thanks very much to @hyeok9855 for his thorough review. I do however have some important comments to address.
`.github/workflows/pre-commit.yml` (Outdated)
Not sure I agree with removing this file completely. What I agree with is skipping the tests (but leaving in black, etc.).
This was a duplicate of `ci.py`. Everything is still tested for on GitHub (e.g., this PR).
I suppose this is fine; this was duplicated effort more or less. But I worry that PyTorch Geometric might require us to be much more bound to conda, so I wonder if this is premature.
ditto
`README.md` (Outdated)
Awesome
`pyproject.toml` (Outdated)
```diff
@@ -25,11 +25,12 @@ classifiers = [
 einops = ">=0.6.1"
 numpy = ">=1.21.2"
 python = "^3.10"
-torch = ">=1.9.0"
+torch = "==2.6.0"
```
`>=`, I think.
ok
`pyproject.toml` (Outdated)
```diff
 # dev dependencies.
 black = { version = "24.3", optional = true }
 flake8 = { version = "*", optional = true }
+pyright = { version = "*", optional = true }
```
I think we should pin this to a version, in case the rules change, and it leads to changes being requested across the library.
ok
```python
from typing import ClassVar, List, Sequence
```
I think we rely too much on nonstandard types throughout the code for us to bother imho, but consistency is important.
```python
# We know this is safe because PFBasedGFlowNet's loss accepts these arguments
return self.loss(env, training_samples, recalculate_all_logprobs=True)
```
Can recalculation be automatically configured iff we're doing on-policy training? We once had a flag to this effect. If a warning is thrown here, I'm not sure what the user can actually do about it.
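A sketch of what that automatic configuration could look like, using an invented `on_policy` flag (the library's actual flag, if reinstated, may differ):

```python
class PFBasedGFlowNet:
    # ... existing methods (to_training_samples, loss) elided ...

    def loss_from_trajectories(self, env, trajectories, on_policy: bool = False):
        # On-policy: the log-probs cached during sampling are still valid, so
        # reuse them. Off-policy (e.g., replay buffer samples): recompute them.
        training_samples = self.to_training_samples(trajectories)
        return self.loss(env, training_samples, recalculate_all_logprobs=not on_policy)
```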
`src/gfn/utils/modules.py` (Outdated)
```diff
@@ -34,6 +34,7 @@ def __init__(
         self._output_dim = output_dim

         if trunk is None:
+            hidden_dim = hidden_dim or 256
```
I'm not a fan of this magic number. Init has a default value for `hidden_dim`, so why do we need this raw `256`?
yes true
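The follow-up commit later in the thread ("Enforce hidden_dim requirement in MLP initialization") points at the resolution; a minimal sketch of that check, with the trunk construction simplified:

```python
import torch.nn as nn


class MLP(nn.Module):
    """Sketch of the fix: require hidden_dim explicitly when no trunk is passed."""

    def __init__(
        self,
        input_dim: int,
        output_dim: int,
        hidden_dim: int | None = None,
        trunk: nn.Module | None = None,
    ) -> None:
        super().__init__()
        self._output_dim = output_dim
        if trunk is None:
            # Fail loudly instead of silently falling back to a magic 256.
            if hidden_dim is None:
                raise ValueError("hidden_dim is required when trunk is None.")
            trunk = nn.Sequential(nn.Linear(input_dim, hidden_dim), nn.ReLU())
        self.trunk = trunk
```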
- Relax torch version constraint to >=2.6.0
- Pin pyright version to 1.1.395
- Modify warning message in base.py to provide clearer guidance on log probability recalculation
- Enforce hidden_dim requirement in MLP initialization
No strong opinion here. We can have this discussion outside this PR. Thanks for your review. I have addressed your comments.
Additional comment
```python
# We know this is safe because PFBasedGFlowNet's loss accepts these arguments
return self.loss(env, training_samples, recalculate_all_logprobs=True)
```
@saleml I don't understand what the purpose of this warning is. If we're throwing a warning, this implies the user can actually do something to improve efficiency. Right now, the only thing the user can do is subclass this GFlowNet base class and override this method completely.
I'm going to file an issue about this.
🎯 Complete Type Safety: Eliminating All Pyright Errors

🌟 Major Achievement
- `reportOptionalMemberAccess` and `reportArgumentType` set to "error"

🏗️ Key Architectural Improvements

1. New Type-Safe Containers 📦
- `StatePairs[DiscreteStates]` for robust state pair handling
- Generic `ReplayBuffer[ContainerType]` implementation

2. Enhanced Type Safety in Core Components ⚡
- Removed `pyright: ignore` comments by fixing the underlying issues
- Proper handling of `DiscreteStates` in training examples

3. Configuration & Quality Assurance 🛠️
- Type checking extended to `tutorials/examples/` and `testing/`

💫 Impact & Importance

Why This Matters

Strategic Timing ⏰

🔄 Next Steps
While this PR achieves complete pyright compliance, future improvements could include:

🎓 Technical Details

🚀 Call to Action
This PR represents a crucial milestone in code quality. Merging it now will:

Ready for immediate review and high-priority merge 🔥