Skip to content

Commit

Permalink
Merge pull request #4350 from IvyZX:logical-axis
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 694704983
  • Loading branch information
Flax Authors committed Nov 9, 2024
2 parents 803aca8 + 4d3006a commit d31f290
Show file tree
Hide file tree
Showing 9 changed files with 227 additions and 148 deletions.
80 changes: 80 additions & 0 deletions flax/core/spmd.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
# Copyright 2024 The Flax Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import contextlib
import dataclasses
import threading

from flax.typing import (
LogicalRules,
Sharding,
)

# Dynamic Axis Mapping Context
# ------------------------------------------------------------------------------


@dataclasses.dataclass
class _AxisRules(threading.local):
  """Holder for the currently bound logical-axis-to-mesh-axis rules.

  The class derives from ``threading.local``, so the ``rules`` attribute is
  tracked separately for every thread that touches an instance. A thread that
  has not assigned ``rules`` yet sees the value the instance was constructed
  with (an empty tuple unless another value was passed to the constructor).
  """

  # Active logical axis rules; empty means no binding is configured.
  rules: LogicalRules = dataclasses.field(default=())


# Module-level axis binding context. `_AxisRules` derives from
# `threading.local`, so each thread sees its own `rules` value.
_axis_rules = _AxisRules()


def set_logical_axis_rules(rules: LogicalRules):
  """Installs ``rules`` as the active logical axis to mesh axis binding.

  The value is stored on the module-level ``_axis_rules`` context and stays
  in place until it is overwritten; nothing is restored automatically.

  Args:
    rules: The logical axis rules to make active.
  """
  _axis_rules.rules = rules


def get_logical_axis_rules() -> LogicalRules:
  """Reads the active logical axis to mesh axis binding.

  Returns:
    The ``rules`` value currently held by the module-level ``_axis_rules``
    context.
  """
  active_rules = _axis_rules.rules
  return active_rules


@contextlib.contextmanager
def logical_axis_rules(rules: LogicalRules):
  """Temporarily binds ``rules`` as the logical to mesh axis rules.

  The binding that was active on entry is put back when the ``with`` block
  exits, including when the block raises.

  Args:
    rules: The logical axis rules to use inside the ``with`` block.
  """
  previous = _axis_rules.rules
  try:
    _axis_rules.rules = rules
    yield
  finally:
    # Always restore, so an exception in the body cannot leak the binding.
    _axis_rules.rules = previous


def composite_rules(rule1, rule2):
  """Merges two collections of logical axis rules into a single tuple.

  Either argument may be falsy (``None`` or empty), which is treated as
  "no rules", so a missing side never prevents the other from being used.

  Args:
    rule1: Iterable of ``(alias, value)`` pairs, or a falsy value. If the same
      alias appears more than once here, the last occurrence wins.
    rule2: Iterable of ``(alias, value)`` pairs, or a falsy value. Each entry
      is checked against everything merged so far, including earlier entries
      of ``rule2`` itself.

  Returns:
    A tuple of ``(alias, value)`` pairs with one entry per alias: aliases from
    ``rule1`` first, followed by aliases that only occur in ``rule2``. An
    empty tuple is returned when both inputs are falsy.

  Raises:
    ValueError: If an entry of ``rule2`` maps an already-known alias to a
      different value.
  """
  # `or ()` lets a None/empty side contribute nothing instead of failing
  # when it is iterated.
  merged = {alias: value for alias, value in rule1 or ()}
  for alias, value in rule2 or ():
    if alias in merged and merged[alias] != value:
      raise ValueError(
        f'Inconsistent logical axis annotations for {alias}: '
        f'{merged[alias]} vs {value}'
      )
    merged[alias] = value
  return tuple(merged.items())


def from_sharding_rules(
  sharding: Sharding, sharding_rules: LogicalRules
) -> Sharding:
  """Rewrites the entries of ``sharding`` using ``sharding_rules``.

  Args:
    sharding: Sequence of axis entries to translate.
    sharding_rules: Iterable of ``(alias, on_mesh)`` pairs. If an alias is
      listed more than once, the last pair wins.

  Returns:
    A tuple with one element per entry of ``sharding``. A truthy entry whose
    ``str()`` form equals an alias is replaced by the matching ``on_mesh``
    value. Falsy entries (such as ``None``) and entries without a matching
    alias are passed through unchanged.
  """
  lookup = {}
  for alias, on_mesh in sharding_rules:
    lookup[alias] = on_mesh

  def _resolve(axis):
    # Falsy entries are never looked up, even if their str() form is an alias.
    if not axis:
      return axis
    key = str(axis)
    if key in lookup:
      return lookup[key]
    return axis

  return tuple(_resolve(axis) for axis in sharding)
8 changes: 5 additions & 3 deletions flax/linen/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,11 @@
unbox as unbox,
with_partitioning as with_partitioning,
)
from flax.core.spmd import (
get_logical_axis_rules as get_logical_axis_rules,
logical_axis_rules as logical_axis_rules,
set_logical_axis_rules as set_logical_axis_rules,
)
from .activation import (
PReLU as PReLU,
celu as celu,
Expand Down Expand Up @@ -130,12 +135,9 @@
)
from .spmd import (
LogicallyPartitioned as LogicallyPartitioned,
get_logical_axis_rules as get_logical_axis_rules,
logical_axis_rules as logical_axis_rules,
logical_to_mesh,
logical_to_mesh_axes,
logical_to_mesh_sharding,
set_logical_axis_rules as set_logical_axis_rules,
with_logical_constraint,
with_logical_partitioning as with_logical_partitioning,
)
Expand Down
8 changes: 3 additions & 5 deletions flax/linen/partitioning.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,16 +42,14 @@
CollectionFilter as CollectionFilter,
PRNGSequenceFilter as PRNGSequenceFilter,
)
from flax.linen.spmd import _axis_rules # pylint: disable=unused-import
from flax.linen.spmd import _AxisRules # pylint: disable=unused-import
from flax.core.spmd import logical_axis_rules as axis_rules # pylint: disable=unused-import
from flax.core.spmd import set_logical_axis_rules as set_axis_rules # pylint: disable=unused-import
from flax.core.spmd import get_logical_axis_rules as get_axis_rules # pylint: disable=unused-import
from flax.linen.spmd import _is_logical_spec
from flax.linen.spmd import _with_sharding_constraint # pylint: disable=unused-import
from flax.linen.spmd import get_logical_axis_rules as get_axis_rules # pylint: disable=unused-import
from flax.linen.spmd import logical_axis_rules as axis_rules # pylint: disable=unused-import
from flax.linen.spmd import logical_to_mesh # pylint: disable=unused-import
from flax.linen.spmd import logical_to_mesh_axes # pylint: disable=unused-import
from flax.linen.spmd import RulesFallback
from flax.linen.spmd import set_logical_axis_rules as set_axis_rules # pylint: disable=unused-import
from flax.linen.spmd import with_logical_constraint as with_sharding_constraint
from flax.traverse_util import flatten_dict
from flax.traverse_util import unflatten_dict
Expand Down
45 changes: 5 additions & 40 deletions flax/linen/spmd.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,9 @@
"""

import collections
import contextlib
import dataclasses
import enum
import functools
import threading
from typing import Any
from collections.abc import Callable, Sequence

Expand All @@ -39,6 +37,9 @@

from flax import struct
from flax.core import meta
from flax.core.spmd import (
get_logical_axis_rules,
)
from flax.typing import (
Array,
LogicalNames,
Expand All @@ -49,42 +50,6 @@
)


# Dynamic Axis Mapping Context
# ------------------------------------------------------------------------------


@dataclasses.dataclass
class _AxisRules(threading.local):
  """Holder for the currently bound logical-axis-to-mesh-axis rules.

  Because the class derives from ``threading.local``, the ``rules`` attribute
  is tracked separately for every thread that touches an instance.
  """

  # Active logical axis rules; empty means no binding is configured.
  rules: LogicalRules = dataclasses.field(default=())


# Module-level axis binding context. `_AxisRules` derives from
# `threading.local`, so each thread sees its own `rules` value.
_axis_rules = _AxisRules()


def set_logical_axis_rules(rules: LogicalRules):
  """Installs ``rules`` as the active logical axis to mesh axis binding.

  Args:
    rules: The logical axis rules to store on the module-level
      ``_axis_rules`` context.
  """
  _axis_rules.rules = rules


def get_logical_axis_rules() -> LogicalRules:
  """Reads the active logical axis to mesh axis binding.

  Returns:
    The ``rules`` value currently held by the module-level ``_axis_rules``
    context.
  """
  active_rules = _axis_rules.rules
  return active_rules


@contextlib.contextmanager
def logical_axis_rules(rules: LogicalRules):
  """Temporarily binds ``rules`` as the logical to mesh axis rules.

  The binding that was active on entry is put back when the ``with`` block
  exits, including when the block raises.

  Args:
    rules: The logical axis rules to use inside the ``with`` block.
  """
  previous = _axis_rules.rules
  try:
    _axis_rules.rules = rules
    yield
  finally:
    # Always restore, so an exception in the body cannot leak the binding.
    _axis_rules.rules = previous


class _UnassignedAxis:
"""Sentinel class for unassigned logical axis name."""

Expand Down Expand Up @@ -115,7 +80,7 @@ def _logical_to_mesh_axes(
if array_dim_names is None:
return None
if rules is None:
rules = _axis_rules.rules
rules = get_logical_axis_rules()
axis_name_counts = collections.Counter(array_dim_names)
dups = tuple(
k for k, v in axis_name_counts.items() if v > 1 and k is not None
Expand Down Expand Up @@ -292,7 +257,7 @@ def with_logical_constraint(
"""Version of jit's with_sharding_constraint that uses logical axis names."""
# If no axis binding is set, this is a no-op.
if rules is None:
rules = _axis_rules.rules
rules = get_logical_axis_rules()
if not rules or logical_axis_resources is None:
return x
# Translate logical names to mesh assignments.
Expand Down
Loading

0 comments on commit d31f290

Please sign in to comment.