Skip to content

Commit

Permalink
Integrate Triton up to 68aa962e67baa191cec5aac173255abdba80db1a
Browse files: browse the repository at this point in the history
  • Loading branch information
Aliia Khasanova authored and Google-ML-Automation committed Oct 16, 2024
1 parent c7b8cd5 commit faa314a
Showing 1 changed file with 5 additions and 6 deletions.
11 changes: 5 additions & 6 deletions jax_triton/triton_lib.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -370,20 +370,20 @@ def get_or_create_triton_kernel(
# We replace array arguments with mock Torch tensors, to allow us to use
# `JITFunction._get_config` to get the specialization_attr.
mock_torch_tensor = types.SimpleNamespace(data_ptr=lambda: 16)
args_for_specialization_attr = [mock_torch_tensor] * len(arg_dtypes)
args_for_specialization_attr = [mock_torch_tensor] * len(fn.params)
backend = backend_init_func(device, compute_capability)
for i, _, v in scalar_args:
args_for_specialization_attr[i] = v
specialization_attr = fn._get_config(*args_for_specialization_attr) # pylint: disable=protected-access

specialization_attr = backend.get_attrs_descriptor(fn.params, args_for_specialization_attr) # pylint: disable=protected-access
constants = {k: v for k, v in metaparams.items()}
constants.update({k: None for _, k, v in scalar_args if v is None})
constants.update({fn.arg_names[i]: 1 for i in specialization_attr.equal_to_1})

# Cache key should contain any parameter that can affect the compiler output.
cache_key = (
fn,
tuple(signature.items()),
tuple(vars(specialization_attr).values()),
tuple(specialization_attr.arg_properties),
tuple(constants.items()),
num_warps,
num_stages,
Expand All @@ -403,7 +403,6 @@ def get_or_create_triton_kernel(
"enable_fp_fusion": enable_fp_fusion,
}

backend = backend_init_func(device, compute_capability)
options = backend.parse_options(opts)

kernel_hash = abs(hash(cache_key))
Expand Down Expand Up @@ -643,7 +642,7 @@ def prune_configs(configs, named_args, **kwargs):
kernel_params.append(
triton_kernel_call_lib.create_array_parameter(
zeroed_params_with_sizes.get(i, 0),
16 if (i in specialization_attr.divisible_by_16) else 0,
16 if (i in specialization_attr.divisibility_16) else 0,
)
)
elif i not in specialization_attr.equal_to_1:
Expand Down

0 comments on commit faa314a

Please sign in to comment.