[Tensor Parallel] flatten new_out when updating swap_map (#577)
crcrpar authored Jun 12, 2024
1 parent 6574922 commit c21533c
Showing 2 changed files with 66 additions and 2 deletions.
18 changes: 16 additions & 2 deletions thunder/distributed/tensor_parallel/common.py
@@ -136,12 +136,26 @@ def __call__(self, bsym: BoundSymbol) -> VISIT_TYPE:
 
         new_out = new_bsym.sym(*new_bsym.args, **new_bsym.kwargs)
 
-        var_original_bsym_output = variableify(new_bsym.flat_proxy_outs[0])
         if pre_post_process is not None:
+            from thunder.core import utils
+
+            # This is because the current support coverage is only `Linear` and `Embedding`, which return one tensor.
+            utils.check(
+                len(new_bsym.flat_proxy_outs) == 1,
+                lambda: f"{len(new_bsym.flat_proxy_outs)=} expected to be 1",
+            )
+            var_original_bsym_output = variableify(new_bsym.flat_proxy_outs[0])
             processed_y = pre_post_process.postprocess(new_out, preprocess_artifacts)
             self.swap_map[var_original_bsym_output] = processed_y
         else:
+            from thunder.core.pytree import tree_flatten
+
+            for orig_o, new_o in zip(
+                new_bsym.flat_outs,
+                tree_flatten(new_out)[0],
+            ):
+                if isinstance(orig_o, TensorProxy) and isinstance(new_o, TensorProxy) and orig_o.name != new_o.name:
+                    self.swap_map[variableify(orig_o)] = new_o
 
         return VISIT_TYPE.REPLACE

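The core of the fix above is that the `else` branch now pairs every original output proxy with its replacement by flattening `new_out`, instead of mapping only `flat_proxy_outs[0]` to the whole (possibly nested) result. Below is a minimal standalone sketch of that pairing logic; it uses `torch.utils._pytree.tree_flatten` (which returns `(leaves, spec)`, the same shape the `[0]` indexing of `thunder.core.pytree.tree_flatten` above relies on) and plain strings as stand-ins for `TensorProxy` objects, so all names are illustrative only.

from torch.utils._pytree import tree_flatten

# Stand-ins for the original flat output proxies of the bound symbol.
flat_outs = ["t0", "t1", "t2"]
# The re-executed symbol may return a nested structure of new proxies.
new_out = ("t0_new", ("t1_new", "t2_new"))

swap_map = {}
for orig_o, new_o in zip(flat_outs, tree_flatten(new_out)[0]):
    # The real code additionally checks that both sides are TensorProxy instances.
    if orig_o != new_o:
        swap_map[orig_o] = new_o

print(swap_map)  # {'t0': 't0_new', 't1': 't1_new', 't2': 't2_new'}

The pre/post-process branch, by contrast, keeps the single-output assumption and now guards it with `utils.check`, since the covered modules (`Linear` and `Embedding`) return exactly one tensor.
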
50 changes: 50 additions & 0 deletions thunder/tests/distributed/test_tensor_parallel.py
@@ -249,6 +249,56 @@ def test_parallel_mlp(self, meta_init):
        # - postprocessing of row-wise parallel linear
        self.assertEqual(len(bsyms_of_tp_sync), 2, msg=msg)

    @pytest.mark.skipif(torch.cuda.device_count() < 2, reason="requires at least 2 CUDA devices")
    def test_litgpt_causal_self_attention(self):
        from thunder.tests.litgpt_model import Config
        from thunder.tests.litgpt_model import CausalSelfAttention
        from thunder.tests.make_tensor import make_tensor
        from thunder.distributed.prims import PrimIDs

        device = torch.device(f"cuda:{self.rank}")
        dtype = torch.bfloat16

        batch_size: int = 4  # 4 is chosen arbitrarily.
        config_name: str = "Llama-2-13b-hf"
        config = Config.from_name(config_name)

        x_shape = (batch_size, config.block_size, config.n_embd)
        cos_shape = (config.block_size, config.rope_n_elem)
        sin_shape = (config.block_size, config.rope_n_elem)
        mask = None
        input_pos = None

        attention = CausalSelfAttention(config).to(device=device, dtype=dtype)
        # Temporarily use only torchex due to https://github.com/NVIDIA/Fuser/issues/2390
        tp_attention = thunder.jit(attention, executors=[thunder.executors.get_torch_executor()])
        tp_attention = column_parallel(tp_attention, ["attn"])
        tp_attention = row_parallel(tp_attention, ["proj"])

        x = make_tensor(x_shape, device=device, dtype=dtype, requires_grad=True)
        cos = make_tensor(cos_shape, device=device, dtype=dtype, requires_grad=True)
        sin = make_tensor(sin_shape, device=device, dtype=dtype, requires_grad=True)

        # TODO(crcrpar): add numerical check
        y = tp_attention(x, cos, sin, mask, input_pos)
        tp_syncs = {PrimIDs.SYNCHRONIZE_TENSOR_PARALLEL_INPUT, PrimIDs.SYNCHRONIZE_TENSOR_PARALLEL_OUTPUT}
        fwd_traces_with_tensor_parallel_syncs = list(
            filter(
                lambda trace: any(bsym.sym.id in tp_syncs for bsym in trace.bound_symbols),
                thunder.last_traces(tp_attention),
            )
        )

        last_fwd_trace_with_tp_sync = fwd_traces_with_tensor_parallel_syncs[-1]
        bsyms_of_tp_sync = tuple(
            filter(lambda bsym: bsym.sym.id in tp_syncs, last_fwd_trace_with_tp_sync.bound_symbols)
        )
        msg = f"{bsyms_of_tp_sync=}"
        # TODO(crcrpar): Fix the comm optimization path. Ideally, this would be 2.
        # Note, though, that this class's forward seems to depend on a hyperparameter that could be affected by the tensor parallel transform.
        # ref: https://github.com/Lightning-AI/litgpt/blob/8ca46d2f/litgpt/model.py#L218
        self.assertEqual(len(bsyms_of_tp_sync), 4, msg=msg)


common_utils.instantiate_parametrized_tests(TensorParallelTest)

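For context on the expected synchronization count (the "Ideally, this would be 2" TODO above): with `attn` column-parallel and `proj` row-parallel, one sync belongs to the column-parallel input and one to the row-parallel output, while this test currently observes 4 pending the comm-optimization fix. The snippet below is a minimal, self-contained Megatron-style sketch of that column-parallel then row-parallel pattern written directly with torch.distributed; it is not Thunder's implementation, and the shapes, file name, and launch command (e.g. `torchrun --nproc_per_node=2 tp_sketch.py`) are assumptions for illustration.

import torch
import torch.distributed as dist


def column_then_row_forward(x: torch.Tensor, w_col_shard: torch.Tensor, w_row_shard: torch.Tensor) -> torch.Tensor:
    # Column-parallel layer: the input is replicated, so the forward-side "input sync"
    # is an identity copy (its backward is an all-reduce of the input gradient).
    h_local = x @ w_col_shard          # (..., hidden / world_size)
    # Row-parallel layer: each rank consumes its shard of the hidden features
    # and produces a partial sum of the full output.
    y_partial = h_local @ w_row_shard  # (..., out_features), partial result
    # Forward-side "output sync": sum the partial outputs across ranks.
    dist.all_reduce(y_partial, op=dist.ReduceOp.SUM)
    return y_partial


if __name__ == "__main__":
    dist.init_process_group("gloo")
    rank, world_size = dist.get_rank(), dist.get_world_size()
    torch.manual_seed(0)  # identical full weights and input on every rank, then shard locally
    in_features, hidden, out_features = 8, 8, 8
    w1 = torch.randn(in_features, hidden)
    w2 = torch.randn(hidden, out_features)
    x = torch.randn(4, in_features)
    cols = hidden // world_size
    w1_shard = w1[:, rank * cols : (rank + 1) * cols]  # column shard of the first weight
    w2_shard = w2[rank * cols : (rank + 1) * cols, :]  # row shard of the second weight
    y = column_then_row_forward(x, w1_shard, w2_shard)
    # Every rank ends up with the full output, matching the unsharded computation.
    print(rank, torch.allclose(y, x @ w1 @ w2, atol=1e-5))
    dist.destroy_process_group()

The point the sketch illustrates is that the forward pass needs only one collective (the all-reduce on the row-parallel output); the column-parallel input sync is an identity in forward and only becomes an all-reduce in backward. That matches the earlier MLP test, which expects exactly two tensor-parallel sync bound symbols per column/row pair: the preprocessing of the column-wise parallel linear and the postprocessing of the row-wise parallel linear.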
