Skip to content

Commit

Permalink
halfway point
Browse files Browse the repository at this point in the history
  • Loading branch information
Jackmin801 committed Oct 3, 2024
1 parent 996ab21 commit be63b4f
Show file tree
Hide file tree
Showing 2 changed files with 107 additions and 4 deletions.
43 changes: 39 additions & 4 deletions src/zeroband/collectives.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,12 +40,17 @@ def ring_allreduce(
op: dist.ReduceOp = dist.ReduceOp.SUM,
group: Optional[dist.ProcessGroup] = None,
transfer_dtype: Optional[torch.dtype] = None,
quantization_func: Optional[Callable] = None,
) -> None:
"""
Perform all-reduce on a tensor using ring algorithm.
The accumulation will be done in-place on the input tensor.
The transfers will be done using the specified transfer_dtype.
"""
if quantization_func is not None:
if transfer_dtype is not None:
raise ValueError("Quantization and transfer_dtype cannot be used together")
transfer_dtype = tensor.dtype
if transfer_dtype is None:
transfer_dtype = tensor.dtype
if group is None:
Expand All @@ -64,8 +69,16 @@ def ring_allreduce(

# Temporary buffers for transferring data
num_buffers = BUFFER_COUNT * world_size
send_buffer = [torch.empty_like(chunks[0], dtype=transfer_dtype) for _ in range(BUFFER_COUNT)]
recv_buffer = [torch.empty_like(chunks[0], dtype=transfer_dtype) for _ in range(BUFFER_COUNT)]
if quantization_func is not None:
recv_buffer = [torch.empty_like(chunks[0], dtype=torch.uint8) for _ in range(BUFFER_COUNT)]
send_buffer = [None for _ in range(BUFFER_COUNT)]
send_lookup_buffer = [None for _ in range(BUFFER_COUNT)]
recv_lookup_buffer = [torch.empty(256, dtype=chunks[0].dtype) for _ in range(BUFFER_COUNT)]
send_lookup_work = [None for _ in range(BUFFER_COUNT)]
recv_lookup_work = [None for _ in range(BUFFER_COUNT)]
else:
recv_buffer = [torch.empty_like(chunks[0], dtype=transfer_dtype) for _ in range(BUFFER_COUNT)]
send_buffer = [torch.empty_like(chunks[0], dtype=transfer_dtype) for _ in range(BUFFER_COUNT)]
send_work = [None] * BUFFER_COUNT
recv_work = [None] * BUFFER_COUNT

Expand All @@ -77,11 +90,30 @@ def ring_allreduce(
if send_work[step % BUFFER_COUNT] is not None:
send_work[step % BUFFER_COUNT].wait()
recv_work[step % BUFFER_COUNT].wait()
chunks[send_chunk].add_(recv_buffer[step % BUFFER_COUNT])
if quantization_func is not None:
send_lookup_work[step % BUFFER_COUNT].wait()
recv_lookup_work[step % BUFFER_COUNT].wait()
# print(recv_lookup_buffer[step % BUFFER_COUNT][recv_buffer[step % BUFFER_COUNT].long()])
chunks[send_chunk].add_(
recv_lookup_buffer[step % BUFFER_COUNT][recv_buffer[step % BUFFER_COUNT].long()]
)
else:
chunks[send_chunk].add_(recv_buffer[step % BUFFER_COUNT])

if step <= (world_size - 1) * BUFFER_COUNT:
# Send and receive
send_buffer[step % BUFFER_COUNT].copy_(chunks[send_chunk])
if quantization_func is not None:
send_buffer[step % BUFFER_COUNT], send_lookup_buffer[step % BUFFER_COUNT] = quantization_func(
chunks[send_chunk]
)
send_lookup_work[step % BUFFER_COUNT] = dist.isend(
send_lookup_buffer[step % BUFFER_COUNT], dst=send_rank, group=group, tag=step + 1000
)
recv_lookup_work[step % BUFFER_COUNT] = dist.irecv(
recv_lookup_buffer[step % BUFFER_COUNT], src=recv_rank, group=group, tag=step + 1000
)
else:
send_buffer[step % BUFFER_COUNT].copy_(chunks[send_chunk])
send_work[step % BUFFER_COUNT] = dist.isend(
send_buffer[step % BUFFER_COUNT], dst=send_rank, group=group, tag=step
)
Expand All @@ -93,6 +125,9 @@ def ring_allreduce(
for i in range(BUFFER_COUNT):
chunks[i + rank * BUFFER_COUNT].divide_(world_size)

if quantization_func is not None:
send_lookup_work = [None for _ in range(BUFFER_COUNT)]
recv_lookup_work = [None for _ in range(BUFFER_COUNT)]
send_work = [None] * BUFFER_COUNT
recv_work = [None] * BUFFER_COUNT
for step in range(1, world_size * BUFFER_COUNT + 1):
Expand Down
68 changes: 68 additions & 0 deletions src/zeroband/compression.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
import torch
import numpy as np
from typing import Tuple
import math
from concurrent.futures import ThreadPoolExecutor
import os

# Width, in standard deviations, of the value range spanned by the uniform quantization grid:
# uniform_8bit_quantize uses a bin width of RANGE_IN_SIGMAS * std / n_bins.
RANGE_IN_SIGMAS: int = 6
# Shared thread pool on which quantile_qq_approximation submits its per-chunk np.quantile jobs.
# The pool size is read from the QUANTIZATION_THREADS environment variable (default: 128).
EXECUTOR = ThreadPoolExecutor(max_workers=int(os.environ.get("QUANTIZATION_THREADS", 128)))


def average_buckets(tensor: torch.Tensor, quant_weight: torch.Tensor, n_bins: int) -> torch.Tensor:
    """Return the average of the values of ``tensor`` that fall into each bucket.

    Args:
        tensor: The original (unquantized) values.
        quant_weight: Bucket index of every element of ``tensor``; must hold
            integers in ``[0, n_bins)`` and have as many elements as ``tensor``.
        n_bins: Number of buckets.

    Returns:
        A 1-D tensor of length ``n_bins`` whose ``i``-th entry is the mean of
        the elements assigned to bucket ``i`` (``0`` for empty buckets). For
        floating-point input it has the dtype and device of ``tensor``.
    """
    values = tensor.flatten()
    # scatter_add_ needs int64 indices; convert once and reuse for bincount.
    indices = quant_weight.flatten().long()
    # The accumulator must share dtype (and device) with the scattered values:
    # a default float32 buffer makes scatter_add_ fail for e.g. float64 input.
    bin_sums = torch.zeros(n_bins, dtype=values.dtype, device=values.device)
    bin_sums.scatter_add_(0, indices, values)
    # Clamp counts to at least 1 so empty buckets give 0 instead of 0 / 0.
    bin_counts = torch.bincount(indices, minlength=n_bins).clamp_min_(1)
    return bin_sums / bin_counts


def get_chunk_size(num_elements: int, min_chunk_size: int) -> int:
    """Pick a chunk size of at least ``min_chunk_size`` that balances the chunks.

    When ``num_elements`` does not exceed ``min_chunk_size`` the minimum is
    returned unchanged. Otherwise the elements left over after cutting whole
    ``min_chunk_size`` chunks are spread (rounding up) over those chunks.
    """
    if num_elements <= min_chunk_size:
        return min_chunk_size
    whole_chunks, remainder = divmod(num_elements, min_chunk_size)
    # Ceiling of remainder / whole_chunks, written with floor division.
    extra_per_chunk = (remainder - 1) // whole_chunks + 1
    return min_chunk_size + extra_per_chunk


def quantile_qq_approximation(array: np.ndarray, n_quantiles: int, min_chunk_size: int = 10**5) -> np.ndarray:
    """Estimate uniform quantiles of ``array`` with a quantile-of-quantiles scheme.

    The flattened data is cut into chunks, ``n_quantiles`` evenly spaced
    quantiles are computed for every chunk on the shared ``EXECUTOR`` thread
    pool, and the same quantiles are then taken over all per-chunk results.

    Args:
        array: Input data; it is flattened before processing.
        n_quantiles: Number of evenly spaced quantile levels in ``[0, 1]``.
        min_chunk_size: Lower bound on the number of elements per chunk.

    Returns:
        Array with ``n_quantiles`` approximate quantile values.
    """
    # A Fortran-ordered array is transposed first so the flattening below
    # happens on a C-ordered view.
    if array.data.f_contiguous and not array.data.c_contiguous:
        array = array.T
    flat = np.ascontiguousarray(array.reshape(-1))
    total = len(flat)

    levels = np.linspace(0.0, 1.0, num=n_quantiles, dtype=flat.dtype)
    chunk_size = get_chunk_size(total, min_chunk_size)
    num_chunks = (total - 1) // chunk_size + 1
    per_chunk_quantiles = np.empty((num_chunks, len(levels)), dtype=flat.dtype)

    # Each job writes its result straight into its own row of the output table.
    futures = [
        EXECUTOR.submit(
            np.quantile,
            flat[row * chunk_size : (row + 1) * chunk_size],
            levels,
            out=per_chunk_quantiles[row],
        )
        for row in range(num_chunks)
    ]
    # Wait for every job; .result() re-raises any exception from a worker.
    for future in futures:
        future.result()

    return np.quantile(per_chunk_quantiles, levels)


# Number of quantization buckets: 2**8 = 256, one per value of an 8-bit code.
n_bins = 2**8


def uniform_8bit_quantize(tensor: torch.Tensor, inplace: bool = True) -> Tuple[torch.Tensor, torch.Tensor]:
    """Quantize ``tensor`` onto a uniform grid of ``n_bins`` buckets.

    The bucket width is ``RANGE_IN_SIGMAS * std / n_bins`` with
    ``std = ||tensor|| / sqrt(numel - 1)``, and the zero point sits in the
    middle of the grid (``n_bins // 2``). The tensor is not mean-centered.

    Args:
        tensor: Values to quantize.
        inplace: Accepted for interface compatibility; currently unused.

    Returns:
        ``(quantized, lookup)``: the integer bucket code of every element and
        the per-bucket average of the original values.
    """
    zero_point = n_bins // 2
    # No mean shift is applied, so the spread is measured around zero.
    sigma = tensor.norm() / math.sqrt(tensor.numel() - 1)
    bin_width = RANGE_IN_SIGMAS * sigma / n_bins
    q_tensor = torch.quantize_per_tensor(tensor, bin_width, zero_point, torch.quint8)
    codes = q_tensor.int_repr()
    return codes, average_buckets(tensor, codes, n_bins)


def quantile_8bit_quantize(tensor: torch.Tensor, inplace: bool = True) -> Tuple[torch.Tensor, torch.Tensor]:
    """Quantize ``tensor`` into ``n_bins`` buckets bounded by approximate quantiles.

    Args:
        tensor: Values to quantize; converted with ``tensor.numpy()`` to
            estimate the bucket borders.
        inplace: Accepted for interface compatibility; currently unused.

    Returns:
        ``(quantized, lookup)``: ``quantized`` holds the bucket code of every
        element as ``torch.uint8``, and ``lookup`` holds the per-bucket average
        of the original values.
    """
    # n_bins + 1 quantile levels include the 0% and 100% points; dropping
    # them leaves the n_bins - 1 inner borders between buckets.
    borders = torch.as_tensor(quantile_qq_approximation(tensor.numpy(), n_bins + 1)[1:-1])
    indices = torch.clamp_(torch.bucketize(tensor, borders), 0, n_bins - 1)
    lookup = average_buckets(tensor, indices, n_bins)
    # bucketize yields int64 indices. The codes are clamped to [0, n_bins - 1],
    # so they are returned as 8-bit codes like uniform_8bit_quantize does,
    # instead of shipping 8 bytes per element.
    return indices.to(torch.uint8), lookup

0 comments on commit be63b4f

Please sign in to comment.