
Commit

fix ema (#54)
* fix ema

* fix ema
guolinke authored Jun 23, 2024
1 parent 9486f7c commit 89fcb4b
Showing 3 changed files with 38 additions and 52 deletions.
40 changes: 5 additions & 35 deletions unicore/distributed/utils.py
@@ -106,12 +106,9 @@ def _infer_single_node_init(args):
args.distributed_init_method = "tcp://localhost:{port}".format(port=port)



def distributed_init(args):
if torch.distributed.is_available() and torch.distributed.is_initialized():
warnings.warn(
"Distributed is already initialized, cannot initialize twice!"
)
warnings.warn("Distributed is already initialized, cannot initialize twice!")
else:
logger.info(
"distributed init (rank {}): {}".format(
@@ -144,7 +141,6 @@ def distributed_init(args):
else:
logging.getLogger().setLevel(logging.WARNING)


return args.distributed_rank
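
Note on this hunk: distributed_init only warns and skips re-initialization when a process group already exists. A minimal sketch of that guard-and-init pattern (hypothetical helper, assuming the NCCL backend and an init method such as the tcp://localhost URL produced by _infer_single_node_init above):

import torch.distributed as dist

def init_once(init_method, rank, world_size, backend="nccl"):
    # Initialize the default process group at most once; a second
    # init_process_group call on an initialized group would raise.
    if dist.is_available() and dist.is_initialized():
        return dist.get_rank()
    dist.init_process_group(
        backend=backend,          # "nccl" for GPU training, "gloo" for CPU
        init_method=init_method,  # e.g. "tcp://localhost:12345"
        rank=rank,
        world_size=world_size,
    )
    return dist.get_rank()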


@@ -187,31 +183,12 @@ def call_main(args, main, **kwargs):
join=True,
)
else:
distributed_main(int(os.environ['LOCAL_RANK']), main, args, kwargs)
distributed_main(int(os.environ["LOCAL_RANK"]), main, args, kwargs)
else:
# single GPU main
main(args, **kwargs)


def new_groups(grouped_ranks: List[List[int]]):
groups = [dist.new_group(g) for g in grouped_ranks]
my_group_idx = _find_my_group_index(grouped_ranks)
return groups[my_group_idx]


def _find_my_group_index(grouped_ranks):
my_rank = get_global_rank()
for i, group in enumerate(grouped_ranks):
if my_rank in group:
return i
raise RuntimeError


def _find_my_group(grouped_ranks):
index = _find_my_group_index(grouped_ranks)
return grouped_ranks[index]


def get_rank(group):
return dist.get_rank(group=group)

@@ -224,14 +201,7 @@ def get_world_size(group):


def get_global_group():
if torch.distributed.is_initialized():
if not hasattr(get_global_group, "_global_group"):
# ideally we could use torch.distributed.group.WORLD, but it seems
# to cause random NCCL hangs in some cases
get_global_group._global_group = dist.new_group()
return get_global_group._global_group
else:
return None
return None
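
With this change get_global_group simply returns None; torch.distributed collectives treat group=None as the default (WORLD) process group, so the previously cached dist.new_group() is unnecessary. A short illustrative sketch (hypothetical helper, not part of the repo):

import torch
import torch.distributed as dist

def all_reduce_sum(t: torch.Tensor) -> torch.Tensor:
    # group=None selects the default process group created by
    # init_process_group; no extra new_group() call is needed.
    if dist.is_initialized():
        dist.all_reduce(t, op=dist.ReduceOp.SUM, group=None)
    return t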


def get_global_rank():
@@ -329,8 +299,8 @@ def all_gather_list(data, group=None, max_size=16384):
):
all_gather_list._buffer = torch.tensor(
data=[0] * buffer_size, # Initialize with zeros
dtype=torch.uint8, # Byte tensor corresponds to uint8
device='cuda' # Specify the device as CUDA
dtype=torch.uint8, # Byte tensor corresponds to uint8
device="cuda", # Specify the device as CUDA
)
all_gather_list._cpu_buffer = torch.ByteTensor(max_size).pin_memory()
buffer = all_gather_list._buffer
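
These buffers back all_gather_list, which pickles arbitrary Python data into a fixed-size byte tensor, gathers one such tensor per rank, and unpickles each result; the pinned CPU buffer speeds up the host-to-device copy. A rough sketch of the idea, simplified relative to the real implementation (hypothetical function name, 4-byte length header in place of the real header format):

import pickle
import torch
import torch.distributed as dist

def naive_all_gather_object(obj, max_size=16384):
    # Serialize into a fixed-size CUDA byte buffer with a length header.
    payload = pickle.dumps(obj)
    assert len(payload) + 4 <= max_size, "object too large for the buffer"
    buf = torch.zeros(max_size, dtype=torch.uint8, device="cuda")
    header = len(payload).to_bytes(4, "big")
    buf[: len(payload) + 4] = torch.tensor(
        list(header + payload), dtype=torch.uint8, device="cuda"
    )
    # Gather one buffer per rank, then strip each header and unpickle.
    out = [torch.zeros_like(buf) for _ in range(dist.get_world_size())]
    dist.all_gather(out, buf)
    results = []
    for t in out:
        raw = bytes(t.cpu().tolist())
        n = int.from_bytes(raw[:4], "big")
        results.append(pickle.loads(raw[4 : 4 + n]))
    return results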
12 changes: 7 additions & 5 deletions unicore/trainer.py
@@ -88,9 +88,10 @@ def __init__(self, args, task, model, loss):

if self.cuda and self.data_parallel_world_size > 1:
self._grad_norm_buf = torch.tensor(
data=[0.0] * self.data_parallel_world_size, # Initialize with zeros or appropriate values
data=[0.0]
* self.data_parallel_world_size, # Initialize with zeros or appropriate values
dtype=torch.double, # Set the desired data type
device='cuda'
device="cuda",
)
else:
self._grad_norm_buf = None
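
_grad_norm_buf gives every data-parallel worker one slot for its local gradient norm so the trainer can gather the values and verify that all workers computed consistent gradients. A simplified sketch of such a check (hypothetical helper; the actual trainer logic differs in details):

import torch
import torch.distributed as dist

def check_grad_norm_consistency(local_norm, rank, world_size, tol=1e-6):
    # Each rank writes its norm into its own slot; summing the zero-filled
    # buffers across ranks leaves every slot populated on every rank.
    buf = torch.zeros(world_size, dtype=torch.double, device="cuda")
    buf[rank] = local_norm
    dist.all_reduce(buf)
    max_diff = (buf - buf[0]).abs().max()
    if bool(max_diff > tol * (1.0 + buf[0].abs())):
        raise RuntimeError(
            "gradient norms differ across workers: {}".format(buf.tolist())
        )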
@@ -114,11 +115,12 @@ def __init__(self, args, task, model, loss):
if args.validate_with_ema:
assert args.ema_decay > 0, "valid with ema must with ema_decay > 0"

model = self.model
if args.ema_decay > 0 and (
self.data_parallel_rank == 0 or args.validate_with_ema
):

assert (self.args.fp16 or self.args.bf16), "ema must with fp16 or bf16"
assert self.args.fp16 or self.args.bf16, "ema must with fp16 or bf16"
self.ema = ExponentialMovingAverageModel(
model,
args.ema_decay,
@@ -518,8 +520,8 @@ def get_valid_iterator(
data_buffer_size=self.args.data_buffer_size,
disable_iterator_cache=disable_iterator_cache,
)
# Using training data for dummy batch. If the following line is enabled, the dummy batch will be from validation data,
# and cause OOM error for some corner case during training. So disable it.
# Using training data for dummy batch. If the following line is enabled, the dummy batch will be from validation data,
# and cause OOM error for some corner case during training. So disable it.
# self.reset_dummy_batch(batch_iterator.first_batch)
return batch_iterator

38 changes: 26 additions & 12 deletions unicore/utils.py
@@ -10,6 +10,7 @@
import os
import sys
import warnings
from copy import deepcopy
from functools import partial
from typing import List, Callable, Any, Dict
import torch
@@ -18,13 +19,15 @@

try:
import unicore_fused_multi_tensor

HAS_MULTI_TENSOR = True
except:
print("fused_multi_tensor is not installed corrected")
HAS_MULTI_TENSOR = False

try:
import unicore_fused_rounding

HAS_FUSED_ROUNDING = True
except:
print("fused_rounding is not installed corrected")
@@ -36,6 +39,7 @@

logger = logging.getLogger(__name__)


def apply_to_sample(f, sample):
if hasattr(sample, "__len__") and len(sample) == 0:
return {}
@@ -79,6 +83,7 @@ def _move_to_cpu(tensor):

return apply_to_sample(_move_to_cpu, sample)


def multi_tensor_total_norm(grads, chunk_size=2048 * 32) -> torch.Tensor:
per_device_grads = {}
norms = []
@@ -94,15 +99,14 @@ def multi_tensor_total_norm(grads, chunk_size=2048 * 32) -> torch.Tensor:
for dtype in per_device_grads[device].keys():
cur_grads = per_device_grads[device][dtype]
if HAS_MULTI_TENSOR and device.type == "cuda":
norm = unicore_fused_multi_tensor.l2norm(
chunk_size, [cur_grads]
)
norm = unicore_fused_multi_tensor.l2norm(chunk_size, [cur_grads])
norms.append(norm)
else:
norms += [torch.norm(g, p=2, dtype=torch.float32) for g in cur_grads]
total_norm = torch.norm(torch.stack(norms), p=2, dtype=torch.float32)
return total_norm
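
For reference, multi_tensor_total_norm computes the global gradient norm sqrt(sum_i ||g_i||^2), using the fused unicore_fused_multi_tensor kernel per device/dtype bucket when available. A plain-PyTorch equivalent of the same quantity (hypothetical helper name):

import torch

def total_l2_norm(grads):
    # Stack every per-tensor L2 norm and take the L2 norm of that vector,
    # i.e. sqrt(sum_i ||g_i||_2^2), accumulated in fp32 for stability.
    norms = [torch.norm(g, p=2, dtype=torch.float32) for g in grads]
    return torch.norm(torch.stack(norms), p=2, dtype=torch.float32)

clip_grad_norm_ (continued below) uses this total to scale the gradients down when they exceed max_norm.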


@torch.no_grad()
def clip_grad_norm_(params, max_norm, aggregate_norm_fn=None) -> torch.Tensor:
if isinstance(params, torch.Tensor):
@@ -135,7 +139,9 @@ def import_user_module(args):
module_path = getattr(args, "user_dir", None)
if module_path is not None:
module_path = os.path.abspath(args.user_dir)
if not os.path.exists(module_path) and not os.path.isfile(os.path.dirname(module_path)):
if not os.path.exists(module_path) and not os.path.isfile(
os.path.dirname(module_path)
):
unicore_rel_path = os.path.join(os.path.dirname(__file__), args.user_dir)
if os.path.exists(unicore_rel_path):
module_path = unicore_rel_path
@@ -164,8 +170,9 @@ def import_user_module(args):
"something unique and try again.".format(module_path, module_name)
)


def get_activation_fn(activation: str) -> Callable:
""" Returns the activation function corresponding to `activation` """
"""Returns the activation function corresponding to `activation`"""

if activation == "relu":
return F.relu
@@ -216,8 +223,10 @@ def torch_seed(seed, *addl_seeds):
if seed is None:
yield
return

def check_seed(s):
assert type(s) == int or type(s) == np.int32 or type(s) == np.int64

check_seed(seed)
if len(addl_seeds) > 0:
for s in addl_seeds:
@@ -366,9 +375,7 @@ def batched_gather(data, inds, dim=0, num_batch_dims=0):
r = r.view(*(*((1,) * i), -1, *((1,) * (len(inds.shape) - i - 1))))
ranges.append(r)

remaining_dims = [
slice(None) for _ in range(len(data.shape) - num_batch_dims)
]
remaining_dims = [slice(None) for _ in range(len(data.shape) - num_batch_dims)]
remaining_dims[dim - num_batch_dims if dim >= 0 else dim] = inds
ranges.extend(remaining_dims)
return data[ranges]
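
batched_gather builds an advanced-indexing tuple so that inds selects entries along dim independently for each leading batch dimension. A small usage illustration with made-up shapes (assuming batched_gather is importable from unicore.utils):

import torch
from unicore.utils import batched_gather

# data: (batch, num_items, feature); pick 2 items per batch element along dim=1.
data = torch.arange(2 * 4 * 3, dtype=torch.float32).view(2, 4, 3)
inds = torch.tensor([[0, 3], [1, 2]])  # (batch, num_picked)
picked = batched_gather(data, inds, dim=1, num_batch_dims=1)
print(picked.shape)  # torch.Size([2, 2, 3]); picked[b, k] == data[b, inds[b, k]]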
@@ -408,7 +415,9 @@ def fp32_to_bf16_sr(t, o):
if HAS_FUSED_ROUNDING and t.device.type == "cuda":
unicore_fused_rounding.fp32_to_bf16_sr(t, o)
else:
r = (torch.rand(size=t.size(), device=t.device, dtype=torch.float32) - 0.5) / 256
r = (
torch.rand(size=t.size(), device=t.device, dtype=torch.float32) - 0.5
) / 256
m, e = torch.frexp(t)
t = t + torch.ldexp(r, e)
o.data.copy_(t.bfloat16())
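
fp32_to_bf16_sr performs stochastic rounding to bfloat16: uniform noise of about one bf16 ulp, scaled to each value's exponent via frexp/ldexp, is added before the cast, so the rounding error averages out instead of always snapping to the nearest representable value. A hedged numerical check of that property, mirroring the non-fused fallback above (hypothetical script, not part of the repo):

import torch

def fp32_to_bf16_sr_ref(t):
    # Same idea as the fallback path: dither by a fraction of the value's
    # magnitude (via the exponent from frexp), then cast to bfloat16.
    r = (torch.rand_like(t) - 0.5) / 256
    m, e = torch.frexp(t)
    return (t + torch.ldexp(r, e)).bfloat16()

x = torch.full((100000,), 1.003)
det = x.bfloat16().float().mean()            # deterministic cast: every element becomes 1.0
sto = fp32_to_bf16_sr_ref(x).float().mean()  # stochastic: mean stays close to 1.003
print(det.item(), sto.item())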
@@ -428,11 +437,16 @@ def set_jit_fusion_options():
def validate_with_ema(trainer, ema=False):
if not ema:
yield
return
return
_wrapped_model = trainer._wrapped_model
trainer._wrapped_model = trainer.ema.model_ema
trainer._wrapped_model = deepcopy(trainer.ema.model_ema)
if trainer.args.fp16:
trainer._wrapped_model.half()
elif trainer.args.bf16:
trainer._wrapped_model.bfloat16()

try:
yield
finally:
del trainer._wrapped_model
trainer._wrapped_model = _wrapped_model
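
This is the core of the EMA fix: the EMA weights, which ExponentialMovingAverageModel keeps in full precision, are now deep-copied and cast to the training precision (fp16 or bf16) before being swapped in as trainer._wrapped_model, and the temporary copy is deleted afterwards so the stored EMA state is never mutated by validation. A minimal usage sketch, assuming validate_with_ema is used as a context manager and the validation call and names below are illustrative (not taken from this repo):

from unicore.utils import validate_with_ema

# Evaluate against the EMA weights; the original wrapped model is restored
# automatically when the block exits.
with validate_with_ema(trainer, ema=getattr(trainer.args, "validate_with_ema", False)):
    valid_losses = validate(args, trainer, task, epoch_itr, valid_subsets)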
