Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 7 additions & 6 deletions flax/configurations.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,19 +75,20 @@ def __repr__(self):
return f'Config({values_repr}\n)'

@contextmanager
def temp_flip_flag(self, var_name: str, var_value: bool):
def temp_flip_flag(self, var_name: str, var_value: bool, prefix='flax'):
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The new prefix parameter should be documented in the function's docstring. Please add a description for prefix in the Args section. Additionally, the description for var_name is now outdated and should be updated to refer to the generic prefix instead of hardcoding 'flax_'.

"""Context manager to temporarily flip feature flags for test functions.

Args:
var_name: the config variable name (without the 'flax_' prefix)
var_name: the config variable name (without its prefix like 'flax_')
var_value: the boolean value to set var_name to temporarily
prefix: the prefix of the config variable name (default: 'flax')
"""
old_value = getattr(self, f'flax_{var_name}')
old_value = getattr(self, prefix + '_' + var_name)
try:
self.update(f'flax_{var_name}', var_value)
self.update(prefix + '_' + var_name, var_value)
yield
finally:
self.update(f'flax_{var_name}', old_value)
self.update(prefix + '_' + var_name, old_value)
Comment on lines +86 to +91
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

For consistency with the previous implementation and for better readability, it's preferable to use f-strings for constructing the flag name instead of string concatenation. This also makes the code less repetitive if you store the f-string in a variable.

Suggested change
old_value = getattr(self, prefix + '_' + var_name)
try:
self.update(f'flax_{var_name}', var_value)
self.update(prefix + '_' + var_name, var_value)
yield
finally:
self.update(f'flax_{var_name}', old_value)
self.update(prefix + '_' + var_name, old_value)
flag_name = f'{prefix}_{var_name}'
old_value = getattr(self, flag_name)
try:
self.update(flag_name, var_value)
yield
finally:
self.update(flag_name, old_value)



config = Config()
Expand Down Expand Up @@ -307,4 +308,4 @@ def static_int_env(varname: str, default: int | None) -> int | None:
name='nnx_graph_updates',
default=True,
help='Whether graph-mode uses dynamic (True) or simple (False) graph traversal.',
)
)
Loading