
Complete Type Safety: Eliminating All Pyright Errors #245

Merged
merged 74 commits into master from fixpyright on Mar 3, 2025

Conversation

saleml
Collaborator

@saleml saleml commented Feb 21, 2025

🎯 Complete Type Safety: Eliminating All Pyright Errors

🌟 Major Achievement

  • ZERO pyright errors across the entire codebase! 🎉
  • Stricter type checking with reportOptionalMemberAccess and reportArgumentType set to "error"
  • Critical step towards a production-ready, enterprise-grade codebase

🏗️ Key Architectural Improvements

1. New Type-Safe Containers 📦

  • Introduced StatePairs[DiscreteStates] for robust state pair handling
  • Generic ReplayBuffer[ContainerType] implementation
  • Type-safe container operations with proper generics
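The generic `ReplayBuffer[ContainerType]` idea can be sketched as follows. This is an illustrative minimal version, not the library's actual API: the `SizedExtendable` protocol and the method bodies are assumptions for the example.

```python
from typing import Generic, Optional, Protocol, TypeVar

class SizedExtendable(Protocol):
    """Minimal interface a stored container must satisfy (illustrative)."""
    def extend(self, other) -> None: ...
    def __len__(self) -> int: ...

C = TypeVar("C", bound=SizedExtendable)

class ReplayBuffer(Generic[C]):
    """Sketch of a generic buffer whose container type is inferred
    dynamically from the first insertion."""

    def __init__(self, capacity: int) -> None:
        self.capacity = capacity
        self._storage: Optional[C] = None

    def add(self, objects: C) -> None:
        if self._storage is None:
            self._storage = objects  # dynamic buffer-type detection
        else:
            self._storage.extend(objects)

    def __len__(self) -> int:
        return 0 if self._storage is None else len(self._storage)
```

With a `Protocol` bound on the type variable, pyright can check `extend` and `__len__` calls without any `# pyright: ignore` comments.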

2. Enhanced Type Safety in Core Components ⚡

  • Removed ALL pyright: ignore comments by fixing underlying issues
  • Proper type casting for DiscreteStates in training examples
  • Improved null handling in log probabilities and state operations
  • Better type hints in preprocessors and estimators
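The two narrowing patterns used for `DiscreteStates` can be sketched like this. The classes here are stand-ins for the library's types, purely to illustrate the patterns:

```python
from typing import cast

class States:
    """Stand-in for the library's base States class."""

class DiscreteStates(States):
    """Stand-in for the discrete subclass."""

def env_reset() -> States:
    # A base-typed factory, as a discrete env's reset might be annotated.
    return DiscreteStates()

states = env_reset()

# Pattern 1: an isinstance assertion narrows the type for pyright
# and also catches violations at runtime.
assert isinstance(states, DiscreteStates)

# Pattern 2: cast(), when the invariant is guaranteed by construction
# and no runtime check is wanted.
states2 = cast(DiscreteStates, env_reset())
```

After either line, pyright treats the variable as `DiscreteStates`, so accessing discrete-only attributes no longer triggers `reportOptionalMemberAccess`-style errors.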

3. Configuration & Quality Assurance 🛠️

  • Extended pyright coverage to include tutorials/examples/ and testing/
  • Upgraded pre-commit hooks for stricter type checking
  • More comprehensive type validation across the project

💫 Impact & Importance

Why This Matters

  1. Code Reliability: Eliminates an entire class of runtime errors
  2. Developer Experience: Better IDE support and code navigation
  3. Maintainability: Easier to refactor and extend code with confidence
  4. Documentation: Types serve as living documentation

Strategic Timing ⏰

  • Critical to merge before the next release
  • Should precede the graph PR to ensure type safety foundation
  • Sets the standard for future contributions

🔄 Next Steps

While this PR achieves complete pyright compliance, future improvements could include:

  • Further generic type constraints
  • Additional custom type guards
  • More specific type narrowing
  • Enhanced type documentation

🎓 Technical Details

+ Added reportOptionalMemberAccess = "error"
+ Added reportArgumentType = "error"
- Removed all pyright: ignore comments
+ Extended type checking coverage
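For reference, these settings live in pyright's configuration. A hedged sketch of the relevant `pyproject.toml` section (the `include` paths and surrounding keys are illustrative, not the project's exact file):

```toml
[tool.pyright]
# Stricter diagnostics promoted to hard errors.
reportOptionalMemberAccess = "error"
reportArgumentType = "error"
# Coverage extended beyond the package source (paths illustrative).
include = ["src", "tutorials/examples", "testing"]
```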

🚀 Call to Action

This PR represents a crucial milestone in code quality. Merging it now will:

  1. Set a strong foundation for future development
  2. Prevent type-related technical debt
  3. Ensure the upcoming graph PR maintains type safety standards

Ready for immediate review and high-priority merge 🔥

Salem Lahlou added 30 commits February 21, 2025 14:22
- Improve type hints in stack_states method
- Add robust handling of log rewards in stacking
- Remove pyright ignore comments
- Use cast for type safety in DiscreteStates
- Improve type conversion and error handling
- Implement type-safe method to generate batch of initial states
- Ensure return type is DiscreteStates with an assertion
- Extends base class method with discrete environment specifics
…tioning

- Implement a generic container for storing and manipulating pairs of states
- Support optional conditioning tensors for intermediary and terminating states
- Provide methods for extending, indexing, and accessing state pairs
- Designed to support flow matching and other algorithms requiring state pair processing
- Implement a helper method to compute loss directly from trajectories
- Support different GFlowNet types by handling training sample conversion
- Provide a flexible way to compute loss with optional recalculation of log probabilities
- Enhance loss computation workflow for various GFlowNet implementations
- Add import for StatePairs from state_pairs module
- Extend container module to include the new StatePairs class
- Modify EnumPreprocessor and OneHotPreprocessor to use DiscreteStates type hint
- Update type annotations for get_states_indices and preprocess methods
- Improve type safety for discrete state preprocessing
- Update `to_non_initial_intermediary_and_terminating_states` method to return a StatePairs instance
- Improve type safety by asserting DiscreteStates type for intermediary and terminating states
- Enhance method documentation to clarify its purpose and usage
- Simplify state pair generation with direct StatePairs constructor
- Update FMGFlowNet to use StatePairs instead of tuple for state handling
- Modify loss method to work with StatePairs container
- Simplify type annotations and state processing logic
- Improve type safety by asserting DiscreteStates types
- Update to_training_samples method to return StatePairs
- Update expected_output_dim methods to use the @property decorator
- Remove pyright ignore comments in various modules
- Improve type safety and code clarity in samplers, modules, and utility functions
- Simplify state and action processing in sampling and training methods
- Update type hints in discrete environment and estimator classes
- Implement generic ReplayBuffer with type-safe container handling
- Remove objects_type parameter and use type inference
- Simplify initialization and sampling methods
- Add support for dynamic buffer type detection
- Improve type hints and remove pyright ignore comments
- Update test cases to work with new generic buffer implementation
Comment on lines +90 to +95
    states = env._step(states, actions)

    # Step 4 fails due to an invalid input action.
    actions = env.actions_from_tensor(format_tensor(failing_actions_list))
    with pytest.raises(NonValidActionsError):
-       states = env._step(states, actions)  # pyright: ignore
+       states = env._step(states, actions)
Collaborator

What about exchanging the name of env.step and env._step? Currently, env._step is the one that needs to be called externally (as seen in samplers.py), which seems unusual for something with a private-style naming convention.

The same goes for env._backward_step.

Collaborator Author

Last year, we used to have maskless_step and step.
We then changed them to step and _step respectively.

A user that defines their environment needs to define step only, and _step handles the masking for them.
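That user-facing/internal split can be sketched with a toy environment. The names `step`/`_step` come from this discussion; the integer state and the validity check are illustrative stand-ins for the real mask handling:

```python
class NonValidActionsError(Exception):
    """Stand-in for the library's invalid-action exception."""

class Env:
    """Toy sketch: users override step(); the library (e.g. the Sampler)
    calls _step(), which validates actions before delegating."""

    def is_action_valid(self, state: int, action: int) -> bool:
        # Stand-in for checking actions against masks.
        return action >= 0

    def step(self, state: int, action: int) -> int:
        # User-defined transition, written without any mask handling.
        return state + action

    def _step(self, state: int, action: int) -> int:
        # Library-facing wrapper: validation first, then the user's step().
        if not self.is_action_valid(state, action):
            raise NonValidActionsError(f"invalid action {action}")
        return self.step(state, action)
```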

When I don't work on the codebase for 1-2 months and come back to it, I agree that _step is confusing. What do you think if we change it to safe_step? (Obviously, in a new environment, the user would still only need to write step.)

@josephdviviano , your opinion here would be appreciated too. Thanks

Collaborator

I don't really have a strong preference on this; safe_step seems fine to me as well.

Collaborator

I don't love the name safe_step, which implies the existence of unsafe_step.

I understand the use of env._step to be correct in this case (how it is called by the Sampler) - of course this is all subjective but I'm comfortable with the current naming --

Let me know what you think:

https://claude.ai/share/8e7a4b6a-7347-4b8e-b064-2f510c2a6d3e

Collaborator

One option might be to call this method env._base_step -- but I think we should keep the _ which denotes to the user of the library "you shouldn't call this method unless you really know what you're doing".

hyeok9855 and others added 5 commits February 26, 2025 02:55
Change `args.replay_capacity` to `args.replay_buffer_size` to align with parameter naming convention
Improve documentation for the __getitem__ method in StatePairs to clarify batch dimension indexing and note potential differences in intermediary and terminating states batch shapes
@saleml
Collaborator Author

saleml commented Feb 26, 2025

Thank you for your commit @hyeok9855. I have addressed all your points, and left a question.

Salem Lahlou added 6 commits February 26, 2025 16:17
Extend test coverage for hypergrid training by introducing a new parametrized test that checks different loss functions and replay buffer sizes. Also add new configuration options to HypergridArgs and CommonArgs classes to support these variations.
…hods

Refactor the States class to remove the _log_rewards attribute and associated methods. Update StatePairs and related classes to handle log rewards more explicitly, including modifications to initialization, concatenation, and indexing methods.
Modify the `initialize` method to simplify type checking and initialization of training objects. Move the initialization logic outside of the condition and ensure the buffer is only initialized when no training objects exist.
Collaborator

@hyeok9855 hyeok9855 left a comment


I approved with one concern; see below.

I really appreciate this PR, @saleml !!

@@ -60,7 +60,7 @@ def __init__(
         self.terminating_conditioning = terminating_conditioning

     def __len__(self) -> int:
-        return len(self.intermediary_states) + len(self.terminating_states)
+        return min(len(self.intermediary_states), len(self.terminating_states))
Collaborator

This approach feels a bit like a workaround. Using the sum of the two lengths seems more reasonable to me.

It looks like we might need to refactor this a bit more. For example, instead of storing intermediary_states and terminating_states in separate variables, we could combine them into one and use two sets of indices to track whether a state is intermediate or terminating.

If you agree, I'll go ahead and create an issue for this. Let me know if you have any better suggestions!
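The combined-storage idea could look roughly like the following sketch. This is a hypothetical class written for illustration, not code from the PR; plain lists stand in for state containers:

```python
class CombinedStatePairs:
    """Hypothetical refactor: one shared sequence of states plus a boolean
    mask marking terminating states, so len() is simply the total count."""

    def __init__(self, states, is_terminating):
        assert len(states) == len(is_terminating)
        self.states = list(states)
        self.is_terminating = list(is_terminating)

    def __len__(self) -> int:
        return len(self.states)  # no min/sum ambiguity

    @property
    def intermediary_states(self):
        return [s for s, t in zip(self.states, self.is_terminating) if not t]

    @property
    def terminating_states(self):
        return [s for s, t in zip(self.states, self.is_terminating) if t]
```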

Collaborator Author

I agree it's a bit hacky. This is only because we have to have a len function in

def sample(self, n_samples: int) -> Container:
    """Samples a subset of the container."""
    return self[torch.randperm(len(self))[:n_samples]]

in containers/base.py.
Please go ahead and raise the issue. Thanks for your review.


@saleml saleml removed the request for review from younik February 27, 2025 07:53
Collaborator

@josephdviviano josephdviviano left a comment


Partial review of the core features. I didn't check everything. Thanks very much to @hyeok9855 for his thorough review. I do however have some important comments to address.

Collaborator

Not sure I agree with removing this file completely. What I agree with is skipping the tests (but leaving in black, etc.).

Collaborator Author

This was a duplicate of ci.py. Everything is still tested on GitHub (e.g., in this PR).

Collaborator

I suppose this is fine, this was duplicated effort more or less, but I worry that pytorch geometric might require us to be much more bound to conda, so I wonder if this is premature.

Collaborator Author

ditto

README.md Outdated
Collaborator

Awesome

pyproject.toml Outdated
@@ -25,11 +25,12 @@ classifiers = [
 einops = ">=0.6.1"
 numpy = ">=1.21.2"
 python = "^3.10"
-torch = ">=1.9.0"
+torch = "==2.6.0"
Collaborator

>= I think.

Collaborator Author

ok

pyproject.toml Outdated

# dev dependencies.
black = { version = "24.3", optional = true }
flake8 = { version = "*", optional = true }
pyright = {version = "*", optional = true}
Collaborator

I think we should pin this to a version, in case the rules change, and it leads to changes being requested across the library.

Collaborator Author

ok

Comment on lines +5 to 6
from typing import ClassVar, List, Sequence

Collaborator

I think we rely too much on nonstandard types throughout the code for us to bother imho, but consistency is important.

Comment on lines +94 to +95
# We know this is safe because PFBasedGFlowNet's loss accepts these arguments
return self.loss(env, training_samples, recalculate_all_logprobs=True)
Collaborator

Can recalculation be automatically configured iff we're doing on-policy training? We once had a flag to this effect. If a warning is thrown here, I'm not sure what the user can actually do about it.

@@ -34,6 +34,7 @@ def __init__(
         self._output_dim = output_dim

         if trunk is None:
+            hidden_dim = hidden_dim or 256
Collaborator

I'm not a fan of this magic number. Init has a default value for hidden_dim, so why do we need this raw 256?

Collaborator Author

yes true
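The eventual fix ("Enforce hidden_dim requirement in MLP initialization" in a later commit) could look roughly like this sketch; the class body and error message are illustrative, not the library's exact code:

```python
from typing import Optional

class MLP:
    """Sketch: require hidden_dim explicitly when no trunk is supplied,
    rather than silently falling back to a magic 256."""

    def __init__(
        self,
        input_dim: int,
        output_dim: int,
        hidden_dim: Optional[int] = None,
        trunk=None,
    ) -> None:
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.trunk = trunk
        if trunk is None:
            if hidden_dim is None:
                raise ValueError("hidden_dim is required when trunk is None")
            self.hidden_dim = hidden_dim
```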



- Relax torch version constraint to >=2.6.0
- Pin pyright version to 1.1.395
- Modify warning message in base.py to provide clearer guidance on log probability recalculation
- Enforce hidden_dim requirement in MLP initialization
@saleml
Collaborator Author

saleml commented Feb 28, 2025

One option might be to call this method env._base_step -- but I think we should keep the _ which denotes to the user of the library "you shouldn't call this method unless you really know what you're doing"

No strong opinion here. We can have this discussion outside this PR.

Thanks for your review. I have addressed your comments.

Collaborator

@josephdviviano josephdviviano left a comment


Additional comment

Comment on lines +94 to +95
# We know this is safe because PFBasedGFlowNet's loss accepts these arguments
return self.loss(env, training_samples, recalculate_all_logprobs=True)
Collaborator

@saleml I don't understand what the purpose of this warning is. If we're throwing a warning, this implies the user can actually do something to improve efficiency. Right now, the only thing the user can do is subclass this GFlowNet base class and override this method completely.

I'm going to file an issue about this.

@josephdviviano josephdviviano merged commit 2ef0824 into master Mar 3, 2025
2 checks passed
@josephdviviano josephdviviano deleted the fixpyright branch March 4, 2025 05:16
@saleml saleml mentioned this pull request Mar 6, 2025