5 changes: 4 additions & 1 deletion mbridge/models/qwen3_vl/utils.py
@@ -5,7 +5,10 @@
 from megatron.core import parallel_state
 from megatron.core.transformer.module import MegatronModule
 from megatron.core.transformer.spec_utils import ModuleSpec, build_module
-from megatron.core.utils import get_tensor_model_parallel_group_if_none
+try:
+    from megatron.core.utils import get_tensor_model_parallel_group_if_none
+except ImportError:
+    from mbridge.utils.megatron import get_tensor_model_parallel_group_if_none
 from torch import nn
 from torch.nn import functional as F

5 changes: 4 additions & 1 deletion mbridge/models/qwen3_vl/vision_model.py
@@ -9,7 +9,10 @@
 from megatron.core.transformer.enums import ModelType
 from megatron.core.transformer.module import MegatronModule
 from megatron.core.transformer.spec_utils import ModuleSpec, build_module
-from megatron.core.utils import get_tensor_model_parallel_group_if_none
+try:
+    from megatron.core.utils import get_tensor_model_parallel_group_if_none
+except ImportError:
+    from mbridge.utils.megatron import get_tensor_model_parallel_group_if_none
 from torch import nn
 from torch.nn import functional as F

31 changes: 31 additions & 0 deletions mbridge/utils/megatron.py
@@ -0,0 +1,31 @@
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.

"""Utility functions used throughout Megatron core"""

import warnings
import torch
from megatron.core import parallel_state

def get_tensor_model_parallel_group_if_none(tp_group, is_expert=False, check_initialized=True):
    """Issue a deprecation warning if tp_group is None and return the default tp group."""
    # TODO(zijiey): remove this function later.
    if not torch.distributed.is_initialized():
        return None

    if tp_group is None:
        if torch.distributed.is_initialized() and torch.distributed.get_rank() == 0:
            warnings.warn(
                "Warning: tp_group is None, using default tp group. "
                "Passing tp_group will be mandatory soon",
                DeprecationWarning,
                stacklevel=2,
            )
        if is_expert:
            tp_group = parallel_state.get_expert_tensor_parallel_group(
                check_initialized=check_initialized
            )
        else:
            tp_group = parallel_state.get_tensor_model_parallel_group(
                check_initialized=check_initialized
            )
    return tp_group
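
For context, here is a minimal sketch (not part of this PR) of how the vendored helper is picked up and used: try the upstream Megatron-Core import first, fall back to the copy in mbridge.utils.megatron, then resolve an optional tp_group argument to the default tensor-model-parallel group. The ExampleParallelLayer class below is hypothetical and for illustration only; the PR's actual call sites are the two qwen3_vl modules above.

try:
    from megatron.core.utils import get_tensor_model_parallel_group_if_none
except ImportError:
    # Some Megatron-Core releases may not ship this helper; use the vendored copy.
    from mbridge.utils.megatron import get_tensor_model_parallel_group_if_none


class ExampleParallelLayer:
    """Hypothetical layer that accepts an optional tensor-parallel process group."""

    def __init__(self, tp_group=None):
        # Resolves to parallel_state.get_tensor_model_parallel_group() when
        # tp_group is None (emitting a DeprecationWarning on rank 0).
        self.tp_group = get_tensor_model_parallel_group_if_none(tp_group)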