diff --git a/examples/01_resnet-50/benchmark_ait.py b/examples/01_resnet-50/benchmark_ait.py index e9eba1d77..11aefac62 100644 --- a/examples/01_resnet-50/benchmark_ait.py +++ b/examples/01_resnet-50/benchmark_ait.py @@ -45,7 +45,6 @@ def mark_output(y): def compile_module(model_name, batch_size, **kwargs): - if model_name != "resnet50": raise NotImplementedError diff --git a/examples/01_resnet-50/weight_utils.py b/examples/01_resnet-50/weight_utils.py index 40dc53e74..4839b88f5 100644 --- a/examples/01_resnet-50/weight_utils.py +++ b/examples/01_resnet-50/weight_utils.py @@ -17,7 +17,6 @@ Only tested on resnet50 """ - import pickle import re diff --git a/examples/02_detectron2/demo.py b/examples/02_detectron2/demo.py index 749a1eab8..88e1c63d0 100644 --- a/examples/02_detectron2/demo.py +++ b/examples/02_detectron2/demo.py @@ -15,6 +15,7 @@ """ A main inference script for rcnn models """ + import glob import os diff --git a/examples/02_detectron2/modeling/roi_heads/roi_heads.py b/examples/02_detectron2/modeling/roi_heads/roi_heads.py index 587d9601b..cdc5d1685 100644 --- a/examples/02_detectron2/modeling/roi_heads/roi_heads.py +++ b/examples/02_detectron2/modeling/roi_heads/roi_heads.py @@ -59,7 +59,6 @@ def get_shape(self, x): return shape def forward(self, features: Dict[str, Tensor], rois: Tensor, proposals: Tensor): - box_features = [features[f] for f in self.in_features] roi_feat = self.box_head(box_features, rois) detections = self.box_predictor(roi_feat, proposals) diff --git a/examples/02_detectron2/predictor/builtin_meta.py b/examples/02_detectron2/predictor/builtin_meta.py index c09e5a5ba..d50d03e97 100644 --- a/examples/02_detectron2/predictor/builtin_meta.py +++ b/examples/02_detectron2/predictor/builtin_meta.py @@ -25,7 +25,6 @@ COCO model (with correct class names and colors). """ - # All coco categories, together with their nice-looking visualization colors # It's from https://github.com/cocodataset/panopticapi/blob/master/panoptic_coco_categories.json COCO_CATEGORIES = [ diff --git a/examples/04_vit/weight_utils.py b/examples/04_vit/weight_utils.py index 49d3c9eed..3fd5e737b 100644 --- a/examples/04_vit/weight_utils.py +++ b/examples/04_vit/weight_utils.py @@ -12,8 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. # -"""script for converting vit model from timm to ait -""" +"""script for converting vit model from timm to ait""" + import pickle import click diff --git a/examples/05_stable_diffusion/scripts/download_pipeline.py b/examples/05_stable_diffusion/scripts/download_pipeline.py index 006e16531..34c1d8e4f 100644 --- a/examples/05_stable_diffusion/scripts/download_pipeline.py +++ b/examples/05_stable_diffusion/scripts/download_pipeline.py @@ -36,7 +36,6 @@ help="Pipeline files local directory.", ) def download_pipeline_files(model_name, token, save_directory) -> None: - StableDiffusionPipeline.from_pretrained( model_name, revision="fp16", diff --git a/fx2ait/fx2ait/acc_tracer/acc_normalizer.py b/fx2ait/fx2ait/acc_tracer/acc_normalizer.py index 9295dffc6..c394e7570 100644 --- a/fx2ait/fx2ait/acc_tracer/acc_normalizer.py +++ b/fx2ait/fx2ait/acc_tracer/acc_normalizer.py @@ -283,7 +283,10 @@ def move_kwargs_to_acc_out_ty( for kwarg_replacement_tuple in normalization_info.kwargs_to_move_to_acc_out_ty: if len(kwarg_replacement_tuple) == 2: - orig_kwarg_name, tmd_field_name, move_to_qparams = *kwarg_replacement_tuple, False # type: ignore[misc] + orig_kwarg_name, tmd_field_name, move_to_qparams = ( + *kwarg_replacement_tuple, + False, + ) # type: ignore[misc] else: assert len(kwarg_replacement_tuple) == 3 orig_kwarg_name, tmd_field_name, move_to_qparams = kwarg_replacement_tuple # type: ignore[misc] @@ -331,9 +334,7 @@ def get_normalized_kwargs( new_kwargs[new_kwarg_name] = node.args[i] else: # Verify the arg we're trying to normalize was optional. - assert ( - is_optional - ), f"Cannot normalize {orig_kwargs_names} to {new_kwarg_name} for {node.name}" + assert is_optional, f"Cannot normalize {orig_kwargs_names} to {new_kwarg_name} for {node.name}" else: new_kwargs[new_kwarg_name] = node.kwargs[orig_kwargs_name] diff --git a/fx2ait/fx2ait/acc_tracer/acc_tracer.py b/fx2ait/fx2ait/acc_tracer/acc_tracer.py index b5899f89b..f9f6d3881 100644 --- a/fx2ait/fx2ait/acc_tracer/acc_tracer.py +++ b/fx2ait/fx2ait/acc_tracer/acc_tracer.py @@ -462,7 +462,10 @@ def __init__(self, orig): for k, v in orig.__dict__.items(): if k == "_modules": for mod_k, mod_v in v.items(): - if getattr(mod_v, "_base_class_origin", type(mod_v)) in leaf_module_list: # type: ignore[operator] + if ( + getattr(mod_v, "_base_class_origin", type(mod_v)) + in leaf_module_list + ): # type: ignore[operator] _LOGGER.info( f"Skip rewriting leaf module {type(mod_v)}" ) diff --git a/fx2ait/fx2ait/converters/ait_converters.py b/fx2ait/fx2ait/converters/ait_converters.py index 3c0463ad8..8149872ea 100644 --- a/fx2ait/fx2ait/converters/ait_converters.py +++ b/fx2ait/fx2ait/converters/ait_converters.py @@ -920,8 +920,8 @@ def acc_ops_conv_transpose2d( # Grouped conv doesn't currently work on AIT CUDA, manually map groups = kwargs["groups"] assert ( - w_last_dim * groups - ) % 8 == 0, f"cutlass needs weight output channel={w_last_dim*groups} is not divisble by 8! This restriction may be not valid in newer version" + (w_last_dim * groups) % 8 == 0 + ), f"cutlass needs weight output channel={w_last_dim*groups} is not divisble by 8! This restriction may be not valid in newer version" group_size = input_val.shape()[3]._attrs["values"][0] // groups w_group_size = weight.shape()[0]._attrs["values"][0] // groups @@ -1767,7 +1767,7 @@ def acc_ops_to_dtype( input_val = kwargs["input"] def _get_cast_to_dtype_from_kwargs( - kwargs: Dict[str, Argument] + kwargs: Dict[str, Argument], ) -> Optional[torch.dtype]: torch_dtype_to_ait_dtype_str = { torch.float: "float32", diff --git a/fx2ait/fx2ait/extension.py b/fx2ait/fx2ait/extension.py index dc1783067..7db8f975a 100644 --- a/fx2ait/fx2ait/extension.py +++ b/fx2ait/fx2ait/extension.py @@ -27,7 +27,6 @@ def is_oss_ait_model(): def _get_extension_path(lib_name): - lib_dir = os.path.dirname(__file__) loader_details = ( diff --git a/fx2ait/fx2ait/tools/ait_subgraph_rewriter.py b/fx2ait/fx2ait/tools/ait_subgraph_rewriter.py index e3fa593e2..46098b753 100644 --- a/fx2ait/fx2ait/tools/ait_subgraph_rewriter.py +++ b/fx2ait/fx2ait/tools/ait_subgraph_rewriter.py @@ -348,7 +348,6 @@ def _replace_pattern( replacement: Union[Callable, GraphModule], match_filters: List[Callable[["InternalMatch", Graph, Graph], bool]] = None, # type: ignore[name-defined] ) -> List[ReplacedPatterns]: - if match_filters is None: match_filters = [] @@ -392,7 +391,6 @@ def _replace_pattern( match_and_replacements = [] for match in _matches: - # Build connecting between replacement graph's input and original graph input producer node # Initialize `val_map` with mappings from placeholder nodes in @@ -458,7 +456,6 @@ def get_replacement_nodes(curr_node: Node): and node.op != "output" and node.target != torch.ops.aten.sym_size ): - gn = match.nodes_map[node] gm.graph.erase_node(gn) match_and_replacements.append( diff --git a/python/aitemplate/backend/__init__.py b/python/aitemplate/backend/__init__.py index df7240114..b4755a131 100644 --- a/python/aitemplate/backend/__init__.py +++ b/python/aitemplate/backend/__init__.py @@ -15,6 +15,7 @@ """ Backend for AITemplate. """ + from aitemplate.backend import ( # noqa backend_spec, builder, diff --git a/python/aitemplate/backend/codegen.py b/python/aitemplate/backend/codegen.py index cc34dde70..58669e257 100644 --- a/python/aitemplate/backend/codegen.py +++ b/python/aitemplate/backend/codegen.py @@ -651,9 +651,7 @@ def _codegen_output_tensor(self, tensor: Tensor) -> None: elif external_tensor is not None: # Special view cases for outputs; we can hit this case if the output # is a view of a constant, input, or another output. - assert ( - is_view - ), f"orig_tensor is not None, but node {name} is not marked as a view! Node: {tensor}" + assert is_view, f"orig_tensor is not None, but node {name} is not marked as a view! Node: {tensor}" self.set_inputs.append( check_not_null(tensor, output_idx, skip_if_lower_bound_is_zero=True) ) diff --git a/python/aitemplate/backend/common/concatenate_common.py b/python/aitemplate/backend/common/concatenate_common.py index 0e5235a43..e6b086e19 100644 --- a/python/aitemplate/backend/common/concatenate_common.py +++ b/python/aitemplate/backend/common/concatenate_common.py @@ -15,6 +15,7 @@ """ backend concatenate function common templates. """ + from copy import deepcopy from typing import List diff --git a/python/aitemplate/backend/common/split_common.py b/python/aitemplate/backend/common/split_common.py index ded02558f..2ff2c8e2e 100644 --- a/python/aitemplate/backend/common/split_common.py +++ b/python/aitemplate/backend/common/split_common.py @@ -15,6 +15,7 @@ """ Backend-agnostic function templates for split. """ + import jinja2 FUNC_DECL_TEMPLATE = jinja2.Template( diff --git a/python/aitemplate/backend/common/tensor/permute0213_common.py b/python/aitemplate/backend/common/tensor/permute0213_common.py index 1cb6357f5..865a20e52 100644 --- a/python/aitemplate/backend/common/tensor/permute0213_common.py +++ b/python/aitemplate/backend/common/tensor/permute0213_common.py @@ -23,6 +23,7 @@ in the blockIdx.z for the direct kernel launch. The input and output pointers are shifted accordingly in the kernel code. """ + from typing import Any, Dict import jinja2 diff --git a/python/aitemplate/backend/common/tensor/permute021_common.py b/python/aitemplate/backend/common/tensor/permute021_common.py index 426693714..817e55847 100644 --- a/python/aitemplate/backend/common/tensor/permute021_common.py +++ b/python/aitemplate/backend/common/tensor/permute021_common.py @@ -21,6 +21,7 @@ i.e. Output[d0, ..., dn-3, dn-1, dn-2] = Input[d0, ..., dn-3, dn-2, dn-1] """ + from typing import Any, Dict import jinja2 diff --git a/python/aitemplate/backend/common/tensor/permute102_common.py b/python/aitemplate/backend/common/tensor/permute102_common.py index 8cdabb466..6006d40fa 100644 --- a/python/aitemplate/backend/common/tensor/permute102_common.py +++ b/python/aitemplate/backend/common/tensor/permute102_common.py @@ -36,6 +36,7 @@ starting from 17 items, the approach #1 corresponds to the same data movement, just through the SMEM and with more index computation. """ + from typing import Any, Dict import jinja2 diff --git a/python/aitemplate/backend/common/tensor/permute210_common.py b/python/aitemplate/backend/common/tensor/permute210_common.py index 83ba4d61a..87af608f3 100644 --- a/python/aitemplate/backend/common/tensor/permute210_common.py +++ b/python/aitemplate/backend/common/tensor/permute210_common.py @@ -26,6 +26,7 @@ The 4 for thread blocks indicates each thread is responsible of 4 elements. We use TILE_SIZE = 32 for the time being. """ + from typing import Any, Dict import jinja2 diff --git a/python/aitemplate/backend/common/tensor/slice_common.py b/python/aitemplate/backend/common/tensor/slice_common.py index 49961916a..e0a40e868 100644 --- a/python/aitemplate/backend/common/tensor/slice_common.py +++ b/python/aitemplate/backend/common/tensor/slice_common.py @@ -15,6 +15,7 @@ """ Slice backend common implementation. """ + import jinja2 diff --git a/python/aitemplate/backend/common/tensor/slice_reshape_scatter_common.py b/python/aitemplate/backend/common/tensor/slice_reshape_scatter_common.py index 49d574548..d6306a26c 100644 --- a/python/aitemplate/backend/common/tensor/slice_reshape_scatter_common.py +++ b/python/aitemplate/backend/common/tensor/slice_reshape_scatter_common.py @@ -15,6 +15,7 @@ """ Slice reshape backend common implementation. """ + import functools import jinja2 diff --git a/python/aitemplate/backend/common/vision_ops/efficient_nms_kernel.py b/python/aitemplate/backend/common/vision_ops/efficient_nms_kernel.py index c40b01e7c..873adaf0a 100644 --- a/python/aitemplate/backend/common/vision_ops/efficient_nms_kernel.py +++ b/python/aitemplate/backend/common/vision_ops/efficient_nms_kernel.py @@ -15,6 +15,7 @@ """ efficient_nms function gpu kernel. """ + import jinja2 kernel = jinja2.Template( diff --git a/python/aitemplate/backend/common/vision_ops/nms_kernel.py b/python/aitemplate/backend/common/vision_ops/nms_kernel.py index fba468e28..3affce69d 100644 --- a/python/aitemplate/backend/common/vision_ops/nms_kernel.py +++ b/python/aitemplate/backend/common/vision_ops/nms_kernel.py @@ -15,6 +15,7 @@ """ nms kernel template. """ + import jinja2 KERNEL_TEMPLATE = jinja2.Template( diff --git a/python/aitemplate/backend/cuda/__init__.py b/python/aitemplate/backend/cuda/__init__.py index ac88e3679..fb3849a86 100644 --- a/python/aitemplate/backend/cuda/__init__.py +++ b/python/aitemplate/backend/cuda/__init__.py @@ -16,6 +16,7 @@ """ CUDA backend codegen functions. """ + from aitemplate.backend.cuda import ( builder_cmake, cuda_common, diff --git a/python/aitemplate/backend/cuda/attention/__init__.py b/python/aitemplate/backend/cuda/attention/__init__.py index c57effeee..a2e6ee840 100644 --- a/python/aitemplate/backend/cuda/attention/__init__.py +++ b/python/aitemplate/backend/cuda/attention/__init__.py @@ -15,6 +15,7 @@ """ cuda flash_attention module init """ + from aitemplate.backend.cuda.attention import flash_attention, mem_eff_attention __all__ = ["flash_attention", "mem_eff_attention"] diff --git a/python/aitemplate/backend/cuda/attention/flash_attention.py b/python/aitemplate/backend/cuda/attention/flash_attention.py index b53eb419e..539e928c4 100644 --- a/python/aitemplate/backend/cuda/attention/flash_attention.py +++ b/python/aitemplate/backend/cuda/attention/flash_attention.py @@ -15,6 +15,7 @@ """ attention kernel codegen for CUDA. """ + from typing import Any, Dict import jinja2 diff --git a/python/aitemplate/backend/cuda/attention/mem_eff_attention.py b/python/aitemplate/backend/cuda/attention/mem_eff_attention.py index 1d8b8c3e3..192652536 100644 --- a/python/aitemplate/backend/cuda/attention/mem_eff_attention.py +++ b/python/aitemplate/backend/cuda/attention/mem_eff_attention.py @@ -15,6 +15,7 @@ """ Attention kernel codegen for CUDA. """ + from typing import Any, Dict import jinja2 diff --git a/python/aitemplate/backend/cuda/b2b_bmm/classic_b2b_bmm.py b/python/aitemplate/backend/cuda/b2b_bmm/classic_b2b_bmm.py index 21dd56a5c..b2c29cfc5 100644 --- a/python/aitemplate/backend/cuda/b2b_bmm/classic_b2b_bmm.py +++ b/python/aitemplate/backend/cuda/b2b_bmm/classic_b2b_bmm.py @@ -15,6 +15,7 @@ """ classic_b2b_bmm kernel codegen for CUDA. """ + from typing import Any, Dict import jinja2 diff --git a/python/aitemplate/backend/cuda/b2b_bmm/fmha_style_b2b_bmm.py b/python/aitemplate/backend/cuda/b2b_bmm/fmha_style_b2b_bmm.py index 308ddafa9..5729129fc 100644 --- a/python/aitemplate/backend/cuda/b2b_bmm/fmha_style_b2b_bmm.py +++ b/python/aitemplate/backend/cuda/b2b_bmm/fmha_style_b2b_bmm.py @@ -15,6 +15,7 @@ """ fmha_style_b2b_bmm kernel codegen for CUDA. """ + from typing import Any, Dict import jinja2 diff --git a/python/aitemplate/backend/cuda/b2b_bmm/grouped_classic_b2b_bmm.py b/python/aitemplate/backend/cuda/b2b_bmm/grouped_classic_b2b_bmm.py index 36efecf7d..dd8da1df9 100644 --- a/python/aitemplate/backend/cuda/b2b_bmm/grouped_classic_b2b_bmm.py +++ b/python/aitemplate/backend/cuda/b2b_bmm/grouped_classic_b2b_bmm.py @@ -15,6 +15,7 @@ """ classic_b2b_bmm kernel codegen for CUDA. """ + from typing import Any, Dict import jinja2 diff --git a/python/aitemplate/backend/cuda/b2b_bmm/grouped_fmha_style_b2b_bmm.py b/python/aitemplate/backend/cuda/b2b_bmm/grouped_fmha_style_b2b_bmm.py index 2b988c5ab..ee562a67c 100644 --- a/python/aitemplate/backend/cuda/b2b_bmm/grouped_fmha_style_b2b_bmm.py +++ b/python/aitemplate/backend/cuda/b2b_bmm/grouped_fmha_style_b2b_bmm.py @@ -15,6 +15,7 @@ """ grouped_fmha_style_b2b_bmm kernel codegen for CUDA. """ + from typing import Any, Dict import jinja2 diff --git a/python/aitemplate/backend/cuda/common/__init__.py b/python/aitemplate/backend/cuda/common/__init__.py index 4971d840f..d5491adba 100644 --- a/python/aitemplate/backend/cuda/common/__init__.py +++ b/python/aitemplate/backend/cuda/common/__init__.py @@ -16,4 +16,5 @@ """ CUDA Common module init """ + from aitemplate.backend.cuda.common.dummy_op import * diff --git a/python/aitemplate/backend/cuda/conv2d/__init__.py b/python/aitemplate/backend/cuda/conv2d/__init__.py index e18c91cdf..1a143930c 100644 --- a/python/aitemplate/backend/cuda/conv2d/__init__.py +++ b/python/aitemplate/backend/cuda/conv2d/__init__.py @@ -16,6 +16,7 @@ """ cuda conv2d module init """ + from aitemplate.backend.cuda.conv2d import ( conv2d, conv2d_bias, diff --git a/python/aitemplate/backend/cuda/conv2d/common.py b/python/aitemplate/backend/cuda/conv2d/common.py index c6bfbdb8b..d2945cd12 100644 --- a/python/aitemplate/backend/cuda/conv2d/common.py +++ b/python/aitemplate/backend/cuda/conv2d/common.py @@ -15,6 +15,7 @@ """ common template for conv2d """ + import re from collections import OrderedDict diff --git a/python/aitemplate/backend/cuda/conv2d/conv2d.py b/python/aitemplate/backend/cuda/conv2d/conv2d.py index 01bb30105..610d772db 100644 --- a/python/aitemplate/backend/cuda/conv2d/conv2d.py +++ b/python/aitemplate/backend/cuda/conv2d/conv2d.py @@ -15,6 +15,7 @@ """ Codegen for conv2d. """ + from aitemplate.backend import registry from aitemplate.backend.cuda.conv2d import common diff --git a/python/aitemplate/backend/cuda/conv2d/conv2d_bias.py b/python/aitemplate/backend/cuda/conv2d/conv2d_bias.py index adc7b0253..585e82e61 100644 --- a/python/aitemplate/backend/cuda/conv2d/conv2d_bias.py +++ b/python/aitemplate/backend/cuda/conv2d/conv2d_bias.py @@ -15,6 +15,7 @@ """ conv2d bias codegen """ + from aitemplate.backend import registry from aitemplate.backend.cuda.conv2d import common, common_conv2d_bias_activation as cba diff --git a/python/aitemplate/backend/cuda/conv2d/conv2d_bias_add.py b/python/aitemplate/backend/cuda/conv2d/conv2d_bias_add.py index c9762f6ec..e0349bd57 100644 --- a/python/aitemplate/backend/cuda/conv2d/conv2d_bias_add.py +++ b/python/aitemplate/backend/cuda/conv2d/conv2d_bias_add.py @@ -15,6 +15,7 @@ """ conv2d bias add codegen """ + from aitemplate.backend import registry from aitemplate.backend.cuda.conv2d import ( common, diff --git a/python/aitemplate/backend/cuda/conv2d/conv2d_bias_add_hardswish.py b/python/aitemplate/backend/cuda/conv2d/conv2d_bias_add_hardswish.py index defcae4a4..102de9ae4 100644 --- a/python/aitemplate/backend/cuda/conv2d/conv2d_bias_add_hardswish.py +++ b/python/aitemplate/backend/cuda/conv2d/conv2d_bias_add_hardswish.py @@ -15,6 +15,7 @@ """ conv2d bias add hardswish codegen """ + from aitemplate.backend import registry from aitemplate.backend.cuda.conv2d import ( common, diff --git a/python/aitemplate/backend/cuda/conv2d/conv2d_bias_add_relu.py b/python/aitemplate/backend/cuda/conv2d/conv2d_bias_add_relu.py index cbbe02038..379632c22 100644 --- a/python/aitemplate/backend/cuda/conv2d/conv2d_bias_add_relu.py +++ b/python/aitemplate/backend/cuda/conv2d/conv2d_bias_add_relu.py @@ -15,6 +15,7 @@ """ conv2d bias add relu codegen """ + from aitemplate.backend import registry from aitemplate.backend.cuda.conv2d import ( common, diff --git a/python/aitemplate/backend/cuda/conv2d/conv2d_bias_hardswish.py b/python/aitemplate/backend/cuda/conv2d/conv2d_bias_hardswish.py index 9ab085cfd..3896f379c 100644 --- a/python/aitemplate/backend/cuda/conv2d/conv2d_bias_hardswish.py +++ b/python/aitemplate/backend/cuda/conv2d/conv2d_bias_hardswish.py @@ -15,6 +15,7 @@ """ conv2d bias hardswish codegen """ + from aitemplate.backend import registry from aitemplate.backend.cuda.conv2d import common, common_conv2d_bias_activation as cba diff --git a/python/aitemplate/backend/cuda/conv2d/conv2d_bias_relu.py b/python/aitemplate/backend/cuda/conv2d/conv2d_bias_relu.py index 1b3726f66..5adc2a9a8 100644 --- a/python/aitemplate/backend/cuda/conv2d/conv2d_bias_relu.py +++ b/python/aitemplate/backend/cuda/conv2d/conv2d_bias_relu.py @@ -15,6 +15,7 @@ """ conv2d bias relu codegen """ + from aitemplate.backend import registry from aitemplate.backend.cuda.conv2d import common, common_conv2d_bias_activation as cba diff --git a/python/aitemplate/backend/cuda/conv2d/conv2d_depthwise.py b/python/aitemplate/backend/cuda/conv2d/conv2d_depthwise.py index 3546cb823..70c047843 100644 --- a/python/aitemplate/backend/cuda/conv2d/conv2d_depthwise.py +++ b/python/aitemplate/backend/cuda/conv2d/conv2d_depthwise.py @@ -15,6 +15,7 @@ """ Codegen for conv2d_depthwise. """ + from collections import OrderedDict from aitemplate.backend import registry @@ -85,7 +86,6 @@ def f_proc_op_special(op): and op.accumulator_type() == acc_type and op.group_mode == cutlass_lib.library.GroupMode.NoneGroup ): - op = copy.deepcopy(op) # set epilogue epilogue_name = func_attrs["epilogue"] diff --git a/python/aitemplate/backend/cuda/conv2d/transposed_conv2d.py b/python/aitemplate/backend/cuda/conv2d/transposed_conv2d.py index e7186ebce..9e524400f 100644 --- a/python/aitemplate/backend/cuda/conv2d/transposed_conv2d.py +++ b/python/aitemplate/backend/cuda/conv2d/transposed_conv2d.py @@ -15,6 +15,7 @@ """ transposed conv2d op codegen """ + from aitemplate.backend import registry from aitemplate.backend.cuda.conv2d import common, common_transposed_conv2d as ctc diff --git a/python/aitemplate/backend/cuda/conv2d/transposed_conv2d_bias.py b/python/aitemplate/backend/cuda/conv2d/transposed_conv2d_bias.py index 54f298cdc..430e48ebf 100644 --- a/python/aitemplate/backend/cuda/conv2d/transposed_conv2d_bias.py +++ b/python/aitemplate/backend/cuda/conv2d/transposed_conv2d_bias.py @@ -15,6 +15,7 @@ """ transposed conv2d + bias + (relu) codegen """ + from aitemplate.backend import registry from aitemplate.backend.cuda.conv2d import common, common_transposed_conv2d as ctc diff --git a/python/aitemplate/backend/cuda/conv3d/__init__.py b/python/aitemplate/backend/cuda/conv3d/__init__.py index dadb06e64..f5dc5177d 100644 --- a/python/aitemplate/backend/cuda/conv3d/__init__.py +++ b/python/aitemplate/backend/cuda/conv3d/__init__.py @@ -15,6 +15,7 @@ """ CUDA conv3d module init """ + from aitemplate.backend.cuda.conv3d import ( conv3d, conv3d_bias, diff --git a/python/aitemplate/backend/cuda/conv3d/common.py b/python/aitemplate/backend/cuda/conv3d/common.py index 2dd9f9f66..158d08542 100644 --- a/python/aitemplate/backend/cuda/conv3d/common.py +++ b/python/aitemplate/backend/cuda/conv3d/common.py @@ -15,6 +15,7 @@ """ CUDA conv3d common functions """ + import re from hashlib import sha1 from typing import List diff --git a/python/aitemplate/backend/cuda/conv3d/common_bias.py b/python/aitemplate/backend/cuda/conv3d/common_bias.py index 9ecda801b..d3ad47911 100644 --- a/python/aitemplate/backend/cuda/conv3d/common_bias.py +++ b/python/aitemplate/backend/cuda/conv3d/common_bias.py @@ -15,6 +15,7 @@ """ CUDA conv3d common functions """ + import re from hashlib import sha1 from typing import List diff --git a/python/aitemplate/backend/cuda/conv3d/conv3d.py b/python/aitemplate/backend/cuda/conv3d/conv3d.py index 602cc8ef0..cc5fb20a7 100644 --- a/python/aitemplate/backend/cuda/conv3d/conv3d.py +++ b/python/aitemplate/backend/cuda/conv3d/conv3d.py @@ -16,6 +16,7 @@ """ Codegen for conv3d. """ + import jinja2 from aitemplate.backend import registry diff --git a/python/aitemplate/backend/cuda/conv3d/conv3d_bias.py b/python/aitemplate/backend/cuda/conv3d/conv3d_bias.py index a442b472b..3887fe793 100644 --- a/python/aitemplate/backend/cuda/conv3d/conv3d_bias.py +++ b/python/aitemplate/backend/cuda/conv3d/conv3d_bias.py @@ -16,6 +16,7 @@ """ Codegen for conv3d. """ + import jinja2 from aitemplate.backend import registry diff --git a/python/aitemplate/backend/cuda/conv3d/depthwise_conv3d.py b/python/aitemplate/backend/cuda/conv3d/depthwise_conv3d.py index 399c88d79..9cfbd954d 100644 --- a/python/aitemplate/backend/cuda/conv3d/depthwise_conv3d.py +++ b/python/aitemplate/backend/cuda/conv3d/depthwise_conv3d.py @@ -15,6 +15,7 @@ """ Codegen functions for depthwise_conv3d. """ + import jinja2 from aitemplate.backend import registry diff --git a/python/aitemplate/backend/cuda/conv3d/depthwise_conv3d_bias.py b/python/aitemplate/backend/cuda/conv3d/depthwise_conv3d_bias.py index 70f46eff8..54722ccc9 100644 --- a/python/aitemplate/backend/cuda/conv3d/depthwise_conv3d_bias.py +++ b/python/aitemplate/backend/cuda/conv3d/depthwise_conv3d_bias.py @@ -15,6 +15,7 @@ """ Codegen functions for depthwise_conv3d_bias. """ + import jinja2 from aitemplate.backend import registry diff --git a/python/aitemplate/backend/cuda/cuda_common.py b/python/aitemplate/backend/cuda/cuda_common.py index adb09af10..ce3cefd0b 100644 --- a/python/aitemplate/backend/cuda/cuda_common.py +++ b/python/aitemplate/backend/cuda/cuda_common.py @@ -15,6 +15,7 @@ """ CUDA common functions for codegen. """ + from typing import Dict DTYPE_TO_CUDATYPE: Dict[str, str] = { diff --git a/python/aitemplate/backend/cuda/elementwise/__init__.py b/python/aitemplate/backend/cuda/elementwise/__init__.py index 545d10a86..f21be8edc 100644 --- a/python/aitemplate/backend/cuda/elementwise/__init__.py +++ b/python/aitemplate/backend/cuda/elementwise/__init__.py @@ -15,6 +15,7 @@ """ (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. """ + from aitemplate.backend.cuda.elementwise import fused_elementwise, int_elementwise __all__ = ["fused_elementwise", "int_elementwise"] diff --git a/python/aitemplate/backend/cuda/gemm_epilogue_vistor/bmm_rcr_softmax.py b/python/aitemplate/backend/cuda/gemm_epilogue_vistor/bmm_rcr_softmax.py index 570986ce6..94201ac69 100644 --- a/python/aitemplate/backend/cuda/gemm_epilogue_vistor/bmm_rcr_softmax.py +++ b/python/aitemplate/backend/cuda/gemm_epilogue_vistor/bmm_rcr_softmax.py @@ -18,6 +18,7 @@ This is used for `torch.nn.functional.linear` When use for `linear`, need set A->Data, B->Weight """ + import jinja2 from aitemplate.backend import registry diff --git a/python/aitemplate/backend/cuda/gemm_epilogue_vistor/common_softmax.py b/python/aitemplate/backend/cuda/gemm_epilogue_vistor/common_softmax.py index 0b3db0496..4cbaedb73 100644 --- a/python/aitemplate/backend/cuda/gemm_epilogue_vistor/common_softmax.py +++ b/python/aitemplate/backend/cuda/gemm_epilogue_vistor/common_softmax.py @@ -15,6 +15,7 @@ """ Common template for softmax. """ + import os import re from hashlib import sha1 diff --git a/python/aitemplate/backend/cuda/gemm_epilogue_vistor/dual_bmm_rrr_div.py b/python/aitemplate/backend/cuda/gemm_epilogue_vistor/dual_bmm_rrr_div.py index 61b11bd5a..abf7618b2 100644 --- a/python/aitemplate/backend/cuda/gemm_epilogue_vistor/dual_bmm_rrr_div.py +++ b/python/aitemplate/backend/cuda/gemm_epilogue_vistor/dual_bmm_rrr_div.py @@ -17,6 +17,7 @@ C = BMM_RRR(A, B0) / BMM_RRR(A, B1) where A[RowMajor][M, K], B[RowMajor][K, N], B1[RowMajor][K, N] """ + import jinja2 from aitemplate.backend import registry diff --git a/python/aitemplate/backend/cuda/gemm_epilogue_vistor/dual_gemm_rcr_fast_gelu.py b/python/aitemplate/backend/cuda/gemm_epilogue_vistor/dual_gemm_rcr_fast_gelu.py index 283aaab72..9d1e96e0d 100644 --- a/python/aitemplate/backend/cuda/gemm_epilogue_vistor/dual_gemm_rcr_fast_gelu.py +++ b/python/aitemplate/backend/cuda/gemm_epilogue_vistor/dual_gemm_rcr_fast_gelu.py @@ -17,6 +17,7 @@ C = FAST_GELU(GEMM_RCR(A, B)) * GEMM_RCR(A, B1) where A[RowMajor][M, K], B[ColMajor][N, K], B1[ColMajor][N, K] """ + import jinja2 from aitemplate.backend import registry diff --git a/python/aitemplate/backend/cuda/gemm_epilogue_vistor/dual_gemm_rcr_silu.py b/python/aitemplate/backend/cuda/gemm_epilogue_vistor/dual_gemm_rcr_silu.py index a2cd67f60..d65d73320 100644 --- a/python/aitemplate/backend/cuda/gemm_epilogue_vistor/dual_gemm_rcr_silu.py +++ b/python/aitemplate/backend/cuda/gemm_epilogue_vistor/dual_gemm_rcr_silu.py @@ -17,6 +17,7 @@ C = SILU(GEMM_RCR(A, B)) * GEMM_RCR(A, B1) where A[RowMajor][M, K], B[ColMajor][N, K], B1[ColMajor][N, K] """ + import jinja2 from aitemplate.backend import registry diff --git a/python/aitemplate/backend/cuda/gemm_epilogue_vistor/gemm_rcr_bias_softmax.py b/python/aitemplate/backend/cuda/gemm_epilogue_vistor/gemm_rcr_bias_softmax.py index bf3a4d0c0..47beea105 100644 --- a/python/aitemplate/backend/cuda/gemm_epilogue_vistor/gemm_rcr_bias_softmax.py +++ b/python/aitemplate/backend/cuda/gemm_epilogue_vistor/gemm_rcr_bias_softmax.py @@ -18,6 +18,7 @@ This is used for `torch.nn.functional.linear` When use for `linear`, need set A->Data, B->Weight """ + import jinja2 from aitemplate.backend import registry diff --git a/python/aitemplate/backend/cuda/gemm_epilogue_vistor/gemm_rcr_softmax.py b/python/aitemplate/backend/cuda/gemm_epilogue_vistor/gemm_rcr_softmax.py index 35ef2e467..9337e6896 100644 --- a/python/aitemplate/backend/cuda/gemm_epilogue_vistor/gemm_rcr_softmax.py +++ b/python/aitemplate/backend/cuda/gemm_epilogue_vistor/gemm_rcr_softmax.py @@ -18,6 +18,7 @@ This is used for `torch.nn.functional.linear` When use for `linear`, need set A->Data, B->Weight """ + import jinja2 from aitemplate.backend import registry diff --git a/python/aitemplate/backend/cuda/gemm_special/__init__.py b/python/aitemplate/backend/cuda/gemm_special/__init__.py index 00f99c3c7..011cf4123 100644 --- a/python/aitemplate/backend/cuda/gemm_special/__init__.py +++ b/python/aitemplate/backend/cuda/gemm_special/__init__.py @@ -15,6 +15,7 @@ """ special gemm ops """ + from aitemplate.backend.cuda.gemm_special import ( batched_dense_vec_jagged_2d_mul, bmm_rcr_n1, diff --git a/python/aitemplate/backend/cuda/gemm_special/batched_dense_vec_jagged_2d_mul.py b/python/aitemplate/backend/cuda/gemm_special/batched_dense_vec_jagged_2d_mul.py index aa4267dd5..bc614e395 100644 --- a/python/aitemplate/backend/cuda/gemm_special/batched_dense_vec_jagged_2d_mul.py +++ b/python/aitemplate/backend/cuda/gemm_special/batched_dense_vec_jagged_2d_mul.py @@ -15,6 +15,7 @@ """ Define batched_dense_vec_jagged_2d_mul codegen and CUDA kernel """ + from typing import Any, Dict import jinja2 diff --git a/python/aitemplate/backend/cuda/gemm_special/bmm_rrr_k1_tanh.py b/python/aitemplate/backend/cuda/gemm_special/bmm_rrr_k1_tanh.py index 797028e69..30e8d11b4 100644 --- a/python/aitemplate/backend/cuda/gemm_special/bmm_rrr_k1_tanh.py +++ b/python/aitemplate/backend/cuda/gemm_special/bmm_rrr_k1_tanh.py @@ -20,6 +20,7 @@ B[RowMajor]: [B, 1, N] C[RowMajor]: [B, M, N] """ + import jinja2 from aitemplate.backend import registry diff --git a/python/aitemplate/backend/cuda/gemm_universal/bmm_common.py b/python/aitemplate/backend/cuda/gemm_universal/bmm_common.py index bd3affc47..860186c01 100644 --- a/python/aitemplate/backend/cuda/gemm_universal/bmm_common.py +++ b/python/aitemplate/backend/cuda/gemm_universal/bmm_common.py @@ -15,6 +15,7 @@ """ Common functions and templates for bmm-family ops """ + import dataclasses import jinja2 diff --git a/python/aitemplate/backend/cuda/gemm_universal/bmm_permute_common.py b/python/aitemplate/backend/cuda/gemm_universal/bmm_permute_common.py index 7fdfa98b4..01e4d0e82 100644 --- a/python/aitemplate/backend/cuda/gemm_universal/bmm_permute_common.py +++ b/python/aitemplate/backend/cuda/gemm_universal/bmm_permute_common.py @@ -15,6 +15,7 @@ """ Common functions and templates for bmm_permute-family ops """ + from aitemplate.backend.backend_spec import CUDASpec from aitemplate.backend.common import gemm_common diff --git a/python/aitemplate/backend/cuda/gemm_universal/bmm_xxx_add.py b/python/aitemplate/backend/cuda/gemm_universal/bmm_xxx_add.py index a95edbc8f..f3370029d 100644 --- a/python/aitemplate/backend/cuda/gemm_universal/bmm_xxx_add.py +++ b/python/aitemplate/backend/cuda/gemm_universal/bmm_xxx_add.py @@ -23,7 +23,6 @@ "cuda.bmm_rcr_add.func_call". """ - from aitemplate.backend import registry from aitemplate.backend.common import gemm_common from aitemplate.backend.cuda.gemm_universal import bmm_common, common diff --git a/python/aitemplate/backend/cuda/gemm_universal/common_bias_activation.py b/python/aitemplate/backend/cuda/gemm_universal/common_bias_activation.py index e8c7af5b6..e7e71974a 100644 --- a/python/aitemplate/backend/cuda/gemm_universal/common_bias_activation.py +++ b/python/aitemplate/backend/cuda/gemm_universal/common_bias_activation.py @@ -16,6 +16,7 @@ """ Common codegen functions for gemm_bias_activation. """ + import jinja2 from aitemplate.backend.backend_spec import CUDASpec diff --git a/python/aitemplate/backend/cuda/gemm_universal/gemm_rcr.py b/python/aitemplate/backend/cuda/gemm_universal/gemm_rcr.py index 3f1d6aaa1..08177aca4 100644 --- a/python/aitemplate/backend/cuda/gemm_universal/gemm_rcr.py +++ b/python/aitemplate/backend/cuda/gemm_universal/gemm_rcr.py @@ -17,6 +17,7 @@ C = GeMM(A, B) where A[RowMajor][M, K], B[ColMajor][N, K] """ + import jinja2 from aitemplate.backend import registry diff --git a/python/aitemplate/backend/cuda/gemm_universal/gemm_rcr_bias.py b/python/aitemplate/backend/cuda/gemm_universal/gemm_rcr_bias.py index 1464383bc..b2db15adc 100644 --- a/python/aitemplate/backend/cuda/gemm_universal/gemm_rcr_bias.py +++ b/python/aitemplate/backend/cuda/gemm_universal/gemm_rcr_bias.py @@ -17,6 +17,7 @@ C = GeMM(A, B) + bias where A[RowMajor][M, K], B[ColMajor][N, K], bias[RowMajor][N] """ + import jinja2 from aitemplate.backend import registry diff --git a/python/aitemplate/backend/cuda/gemm_universal/gemm_rcr_bias_elementwise.py b/python/aitemplate/backend/cuda/gemm_universal/gemm_rcr_bias_elementwise.py index 01d62de36..839ac6ab1 100644 --- a/python/aitemplate/backend/cuda/gemm_universal/gemm_rcr_bias_elementwise.py +++ b/python/aitemplate/backend/cuda/gemm_universal/gemm_rcr_bias_elementwise.py @@ -16,6 +16,7 @@ GEMM Specialization for C = UnaryOp2(BinaryOp2(BinaryOp1(UnaryOp1(GeMM(A, B) + bias), D1), D2)), """ + from aitemplate.backend import registry from aitemplate.backend.cuda.gemm_universal import common, common_bias_broadcast from aitemplate.backend.cuda.gemm_universal.layout import RCR diff --git a/python/aitemplate/backend/cuda/gemm_universal/gemm_rcr_bias_fast_gelu.py b/python/aitemplate/backend/cuda/gemm_universal/gemm_rcr_bias_fast_gelu.py index e88acc73d..b5e325388 100644 --- a/python/aitemplate/backend/cuda/gemm_universal/gemm_rcr_bias_fast_gelu.py +++ b/python/aitemplate/backend/cuda/gemm_universal/gemm_rcr_bias_fast_gelu.py @@ -16,6 +16,7 @@ GEMM Specialization for C = fast_gelu(GeMM(A, B) + bias) where A[RowMajor][M, K], B[ColMajor][N, K], bias[RowMajor][K], C[RowMajor][M, N] """ + import jinja2 from aitemplate.backend import registry diff --git a/python/aitemplate/backend/cuda/gemm_universal/gemm_rcr_bias_gelu.py b/python/aitemplate/backend/cuda/gemm_universal/gemm_rcr_bias_gelu.py index 4d577b5e9..727d3019d 100644 --- a/python/aitemplate/backend/cuda/gemm_universal/gemm_rcr_bias_gelu.py +++ b/python/aitemplate/backend/cuda/gemm_universal/gemm_rcr_bias_gelu.py @@ -16,6 +16,7 @@ GEMM Specialization for C = gelu(GeMM(A, B) + bias) where A[RowMajor][M, K], B[ColMajor][N, K], bias[RowMajor][K], C[RowMajor][M, N] """ + import jinja2 from aitemplate.backend import registry diff --git a/python/aitemplate/backend/cuda/gemm_universal/gemm_rcr_bias_hardswish.py b/python/aitemplate/backend/cuda/gemm_universal/gemm_rcr_bias_hardswish.py index 3524f0c81..73f52d54e 100644 --- a/python/aitemplate/backend/cuda/gemm_universal/gemm_rcr_bias_hardswish.py +++ b/python/aitemplate/backend/cuda/gemm_universal/gemm_rcr_bias_hardswish.py @@ -16,6 +16,7 @@ GEMM Specialization for C = hard_swish(GeMM(A, B) + bias) where A[RowMajor][M, K], B[ColMajor][N, K], bias[RowMajor][K], C[RowMajor][M, N] """ + import jinja2 from aitemplate.backend import registry diff --git a/python/aitemplate/backend/cuda/gemm_universal/gemm_rcr_bias_sigmoid.py b/python/aitemplate/backend/cuda/gemm_universal/gemm_rcr_bias_sigmoid.py index 1d4c20f23..602e9b999 100644 --- a/python/aitemplate/backend/cuda/gemm_universal/gemm_rcr_bias_sigmoid.py +++ b/python/aitemplate/backend/cuda/gemm_universal/gemm_rcr_bias_sigmoid.py @@ -17,6 +17,7 @@ C = Sigmoid(GeMM(A, B) + bias) where A[RowMajor][M, K], B[ColMajor][N, K], bias[RowMajor][N] """ + import jinja2 from aitemplate.backend import registry diff --git a/python/aitemplate/backend/cuda/gemm_universal/gemm_rcr_bias_swish.py b/python/aitemplate/backend/cuda/gemm_universal/gemm_rcr_bias_swish.py index 91e17d474..2585802ad 100644 --- a/python/aitemplate/backend/cuda/gemm_universal/gemm_rcr_bias_swish.py +++ b/python/aitemplate/backend/cuda/gemm_universal/gemm_rcr_bias_swish.py @@ -17,6 +17,7 @@ C = swish(GeMM(A, B) + bias) where A[RowMajor][M, K], B[ColMajor][N, K], bias[RowMajor][N] """ + import jinja2 from aitemplate.backend import registry diff --git a/python/aitemplate/backend/cuda/gemm_universal/gemm_rcr_bias_tanh.py b/python/aitemplate/backend/cuda/gemm_universal/gemm_rcr_bias_tanh.py index afd0b09b4..752cb91b4 100644 --- a/python/aitemplate/backend/cuda/gemm_universal/gemm_rcr_bias_tanh.py +++ b/python/aitemplate/backend/cuda/gemm_universal/gemm_rcr_bias_tanh.py @@ -17,6 +17,7 @@ C = tanh(GeMM(A, B) + bias) where A[RowMajor][M, K], B[ColMajor][N, K], bias[RowMajor][N] """ + import jinja2 from aitemplate.backend import registry diff --git a/python/aitemplate/backend/cuda/gemm_universal/gemm_rcr_fast_gelu.py b/python/aitemplate/backend/cuda/gemm_universal/gemm_rcr_fast_gelu.py index a0619f56a..01afdb9eb 100644 --- a/python/aitemplate/backend/cuda/gemm_universal/gemm_rcr_fast_gelu.py +++ b/python/aitemplate/backend/cuda/gemm_universal/gemm_rcr_fast_gelu.py @@ -16,6 +16,7 @@ GEMM Specialization for C = fast_gelu(GeMM(A, B) + bias) where A[RowMajor][M, K], B[ColMajor][N, K], bias[RowMajor][K], C[RowMajor][M, N] """ + import jinja2 from aitemplate.backend import registry diff --git a/python/aitemplate/backend/cuda/gemm_universal/gemm_rcr_permute.py b/python/aitemplate/backend/cuda/gemm_universal/gemm_rcr_permute.py index 30740cdf7..45086c94f 100644 --- a/python/aitemplate/backend/cuda/gemm_universal/gemm_rcr_permute.py +++ b/python/aitemplate/backend/cuda/gemm_universal/gemm_rcr_permute.py @@ -17,6 +17,7 @@ C = permute(GeMM(A, B) + bias) where A[RowMajor][M, K], B[ColMajor][N, K], bias[RowMajor][N] """ + import jinja2 from aitemplate.backend import registry diff --git a/python/aitemplate/backend/cuda/gemm_universal/gemm_rcr_permute_elup1.py b/python/aitemplate/backend/cuda/gemm_universal/gemm_rcr_permute_elup1.py index f90882e31..83189fdba 100644 --- a/python/aitemplate/backend/cuda/gemm_universal/gemm_rcr_permute_elup1.py +++ b/python/aitemplate/backend/cuda/gemm_universal/gemm_rcr_permute_elup1.py @@ -17,6 +17,7 @@ C = permute(elu(GeMM(A, B) + bias) + 1.0) where A[RowMajor][M, K], B[ColMajor][N, K], bias[RowMajor][N] """ + import jinja2 from aitemplate.backend import registry diff --git a/python/aitemplate/backend/cuda/gemm_universal/gemm_rrr.py b/python/aitemplate/backend/cuda/gemm_universal/gemm_rrr.py index f4b0a0d07..3f4701b6c 100644 --- a/python/aitemplate/backend/cuda/gemm_universal/gemm_rrr.py +++ b/python/aitemplate/backend/cuda/gemm_universal/gemm_rrr.py @@ -17,6 +17,7 @@ C = GeMM(A, B) where A[RowMajor][M, K], B[RowMajor][K, N] """ + import jinja2 from aitemplate.backend import registry diff --git a/python/aitemplate/backend/cuda/gemm_universal/gemm_rrr_bias.py b/python/aitemplate/backend/cuda/gemm_universal/gemm_rrr_bias.py index 42c675ff0..773c1c0ca 100644 --- a/python/aitemplate/backend/cuda/gemm_universal/gemm_rrr_bias.py +++ b/python/aitemplate/backend/cuda/gemm_universal/gemm_rrr_bias.py @@ -17,6 +17,7 @@ C = GeMM(A, B) + bias where A[RowMajor][M, K], B[ColMajor][K, N], bias[RowMajor][N] """ + import jinja2 from aitemplate.backend import registry diff --git a/python/aitemplate/backend/cuda/gemm_universal/gemm_rrr_permute.py b/python/aitemplate/backend/cuda/gemm_universal/gemm_rrr_permute.py index 2e4d1ed0a..4025cad84 100644 --- a/python/aitemplate/backend/cuda/gemm_universal/gemm_rrr_permute.py +++ b/python/aitemplate/backend/cuda/gemm_universal/gemm_rrr_permute.py @@ -17,6 +17,7 @@ C = permute(GeMM(A, B) + bias) where A[RowMajor][M, K], B[RowMajor][K, N], bias[RowMajor][N] """ + import jinja2 from aitemplate.backend import registry diff --git a/python/aitemplate/backend/cuda/gemm_universal/group_common.py b/python/aitemplate/backend/cuda/gemm_universal/group_common.py index c06445e74..60078597d 100644 --- a/python/aitemplate/backend/cuda/gemm_universal/group_common.py +++ b/python/aitemplate/backend/cuda/gemm_universal/group_common.py @@ -15,6 +15,7 @@ """ Common functions and templates for group-gemm-family kernels """ + import re from hashlib import sha1 from typing import Any, Dict, List diff --git a/python/aitemplate/backend/cuda/gemm_universal/group_common_bias.py b/python/aitemplate/backend/cuda/gemm_universal/group_common_bias.py index 9e57686b4..196fc09d1 100644 --- a/python/aitemplate/backend/cuda/gemm_universal/group_common_bias.py +++ b/python/aitemplate/backend/cuda/gemm_universal/group_common_bias.py @@ -15,6 +15,7 @@ """ Common codegen functions for group_gemm_bias-family kernels. """ + import jinja2 from aitemplate.backend.cuda.gemm_universal import group_common diff --git a/python/aitemplate/backend/cuda/gemm_universal/group_gemm_rcr.py b/python/aitemplate/backend/cuda/gemm_universal/group_gemm_rcr.py index 03acac5df..4a9521788 100644 --- a/python/aitemplate/backend/cuda/gemm_universal/group_gemm_rcr.py +++ b/python/aitemplate/backend/cuda/gemm_universal/group_gemm_rcr.py @@ -15,6 +15,7 @@ """ Codegen functions for group_gemm_rcr. """ + import jinja2 from aitemplate.backend import registry diff --git a/python/aitemplate/backend/cuda/gemm_universal/group_gemm_rcr_bias.py b/python/aitemplate/backend/cuda/gemm_universal/group_gemm_rcr_bias.py index e2a9589bd..7e1374b2f 100644 --- a/python/aitemplate/backend/cuda/gemm_universal/group_gemm_rcr_bias.py +++ b/python/aitemplate/backend/cuda/gemm_universal/group_gemm_rcr_bias.py @@ -15,6 +15,7 @@ """ Codegen functions for group_gemm_rcr_bias. """ + from aitemplate.backend import registry from aitemplate.backend.cuda.gemm_universal import ( common, diff --git a/python/aitemplate/backend/cuda/gemm_universal/group_gemm_rcr_bias_relu.py b/python/aitemplate/backend/cuda/gemm_universal/group_gemm_rcr_bias_relu.py index df4fccb31..5bebf3972 100644 --- a/python/aitemplate/backend/cuda/gemm_universal/group_gemm_rcr_bias_relu.py +++ b/python/aitemplate/backend/cuda/gemm_universal/group_gemm_rcr_bias_relu.py @@ -15,6 +15,7 @@ """ Codegen functions for group_gemm_rcr_bias_relu. """ + from aitemplate.backend import registry from aitemplate.backend.cuda.gemm_universal import ( common, diff --git a/python/aitemplate/backend/cuda/gemm_universal/group_gemm_rcr_bias_sigmoid.py b/python/aitemplate/backend/cuda/gemm_universal/group_gemm_rcr_bias_sigmoid.py index 4e6a6a15f..0756c91ab 100644 --- a/python/aitemplate/backend/cuda/gemm_universal/group_gemm_rcr_bias_sigmoid.py +++ b/python/aitemplate/backend/cuda/gemm_universal/group_gemm_rcr_bias_sigmoid.py @@ -15,6 +15,7 @@ """ Codegen functions for group_gemm_rcr_bias_sigmoid. """ + from aitemplate.backend import registry from aitemplate.backend.cuda.gemm_universal import ( common, diff --git a/python/aitemplate/backend/cuda/gemm_universal/perm021fc_ccr.py b/python/aitemplate/backend/cuda/gemm_universal/perm021fc_ccr.py index e6b51647d..daaf26028 100644 --- a/python/aitemplate/backend/cuda/gemm_universal/perm021fc_ccr.py +++ b/python/aitemplate/backend/cuda/gemm_universal/perm021fc_ccr.py @@ -16,6 +16,7 @@ Codegen functions for perm021fc_ccr, which computes [b, m, n] = bmm([b, k, m], [1, n, k]). """ + from aitemplate.backend import registry from aitemplate.backend.cuda.gemm_universal import bmm_common, common diff --git a/python/aitemplate/backend/cuda/gemm_universal/perm021fc_ccr_bias.py b/python/aitemplate/backend/cuda/gemm_universal/perm021fc_ccr_bias.py index d3946f532..f29d9acfe 100644 --- a/python/aitemplate/backend/cuda/gemm_universal/perm021fc_ccr_bias.py +++ b/python/aitemplate/backend/cuda/gemm_universal/perm021fc_ccr_bias.py @@ -16,6 +16,7 @@ Codegen functions for perm021fc_ccr_bias, which computes [b, m, n] = bmm([b, k, m], [1, n, k]) + bias[n]. """ + from aitemplate.backend import registry from aitemplate.backend.cuda.gemm_universal import ( bmm_common, diff --git a/python/aitemplate/backend/cuda/gemm_universal/perm021fc_ccr_bias_permute.py b/python/aitemplate/backend/cuda/gemm_universal/perm021fc_ccr_bias_permute.py index 4bbc994ed..1d5690756 100644 --- a/python/aitemplate/backend/cuda/gemm_universal/perm021fc_ccr_bias_permute.py +++ b/python/aitemplate/backend/cuda/gemm_universal/perm021fc_ccr_bias_permute.py @@ -16,6 +16,7 @@ Common functions and templates for perm021_ccr_bias_permute, which computes (A.permute(0, 2, 1)[col] @ B[col] + Bias).permute(0, 2, 1) """ + from aitemplate.backend import registry from aitemplate.backend.cuda.gemm_universal import ( diff --git a/python/aitemplate/backend/cuda/gemm_universal/perm021fc_crc.py b/python/aitemplate/backend/cuda/gemm_universal/perm021fc_crc.py index db645f28f..b5b49f5c0 100644 --- a/python/aitemplate/backend/cuda/gemm_universal/perm021fc_crc.py +++ b/python/aitemplate/backend/cuda/gemm_universal/perm021fc_crc.py @@ -16,6 +16,7 @@ Codegen functions for perm021fc_crc, which computes [b, n, m](col) = bmm([1, k, n](col), [b, k, m](row)). """ + from aitemplate.backend import registry from aitemplate.backend.cuda.gemm_universal import bmm_common, common diff --git a/python/aitemplate/backend/cuda/gemm_universal/perm021fc_crc_bias.py b/python/aitemplate/backend/cuda/gemm_universal/perm021fc_crc_bias.py index 3546cea7c..bc57b5b03 100644 --- a/python/aitemplate/backend/cuda/gemm_universal/perm021fc_crc_bias.py +++ b/python/aitemplate/backend/cuda/gemm_universal/perm021fc_crc_bias.py @@ -16,6 +16,7 @@ Codegen functions for perm021fc_crc_bias, which computes [b, n, m](col) = bmm([1, k, n](col), [b, k, m](row)) + bias[n]. """ + from aitemplate.backend import registry from aitemplate.backend.cuda.gemm_universal import ( bmm_common, diff --git a/python/aitemplate/backend/cuda/gemm_universal/perm102_bmm_rcr.py b/python/aitemplate/backend/cuda/gemm_universal/perm102_bmm_rcr.py index 1e0273c00..e09d92ba8 100644 --- a/python/aitemplate/backend/cuda/gemm_universal/perm102_bmm_rcr.py +++ b/python/aitemplate/backend/cuda/gemm_universal/perm102_bmm_rcr.py @@ -16,6 +16,7 @@ Codegen functions for perm102_bmm_rcr, which computes C[m, b, n](row) = bmm(A[m, b, k](row), B[b, n, k](col)) """ + from aitemplate.backend import registry from aitemplate.backend.backend_spec import CUDASpec from aitemplate.backend.cuda.gemm_universal import bmm_common, common diff --git a/python/aitemplate/backend/cuda/gemm_universal/perm102_bmm_rcr_bias.py b/python/aitemplate/backend/cuda/gemm_universal/perm102_bmm_rcr_bias.py index 6634ff80f..067e82560 100644 --- a/python/aitemplate/backend/cuda/gemm_universal/perm102_bmm_rcr_bias.py +++ b/python/aitemplate/backend/cuda/gemm_universal/perm102_bmm_rcr_bias.py @@ -16,6 +16,7 @@ Codegen functions for perm102_bmm_rcr_bias, which computes C[m, b, n](row) = bmm(A[m, b, k](row), B[b, n, k](col)) + bias[n]. """ + from aitemplate.backend import registry from aitemplate.backend.backend_spec import CUDASpec from aitemplate.backend.cuda.gemm_universal import ( diff --git a/python/aitemplate/backend/cuda/gemm_universal/perm102_bmm_rrr.py b/python/aitemplate/backend/cuda/gemm_universal/perm102_bmm_rrr.py index 354a4392b..edcf633a0 100644 --- a/python/aitemplate/backend/cuda/gemm_universal/perm102_bmm_rrr.py +++ b/python/aitemplate/backend/cuda/gemm_universal/perm102_bmm_rrr.py @@ -16,6 +16,7 @@ Codegen functions for perm102_bmm_rrr, which computes C[m, b, n](row) = bmm(A[m, b, k](row), B[b, k, n](row)) """ + from aitemplate.backend import registry from aitemplate.backend.backend_spec import CUDASpec from aitemplate.backend.cuda.gemm_universal import bmm_common, common diff --git a/python/aitemplate/backend/cuda/gemm_universal/perm102_bmm_rrr_bias.py b/python/aitemplate/backend/cuda/gemm_universal/perm102_bmm_rrr_bias.py index de73d4880..e83c8ae60 100644 --- a/python/aitemplate/backend/cuda/gemm_universal/perm102_bmm_rrr_bias.py +++ b/python/aitemplate/backend/cuda/gemm_universal/perm102_bmm_rrr_bias.py @@ -16,6 +16,7 @@ Codegen functions for perm102_bmm_rrr_bias, which computes C[m, b, n](row) = bmm(A[m, b, k](row), B[b, k, n](row)) + bias[n] """ + from aitemplate.backend import registry from aitemplate.backend.backend_spec import CUDASpec from aitemplate.backend.cuda.gemm_universal import ( diff --git a/python/aitemplate/backend/cuda/jagged/__init__.py b/python/aitemplate/backend/cuda/jagged/__init__.py index 550a59a2b..8d2159675 100644 --- a/python/aitemplate/backend/cuda/jagged/__init__.py +++ b/python/aitemplate/backend/cuda/jagged/__init__.py @@ -15,6 +15,7 @@ """ CUDA jagged tensor-specific ops module init """ + from aitemplate.backend.cuda.jagged import ( jagged_lengths_to_offsets, jagged_lengths_to_presences, diff --git a/python/aitemplate/backend/cuda/jagged/jagged_lengths_to_offsets.py b/python/aitemplate/backend/cuda/jagged/jagged_lengths_to_offsets.py index d47e8925e..bc11bfb62 100644 --- a/python/aitemplate/backend/cuda/jagged/jagged_lengths_to_offsets.py +++ b/python/aitemplate/backend/cuda/jagged/jagged_lengths_to_offsets.py @@ -15,6 +15,7 @@ """ Codegen functions for the jagged_lengths_to_offsets op. """ + import jinja2 from aitemplate.backend import registry diff --git a/python/aitemplate/backend/cuda/jagged/jagged_lengths_to_presences.py b/python/aitemplate/backend/cuda/jagged/jagged_lengths_to_presences.py index 91e5f528f..6723133b4 100644 --- a/python/aitemplate/backend/cuda/jagged/jagged_lengths_to_presences.py +++ b/python/aitemplate/backend/cuda/jagged/jagged_lengths_to_presences.py @@ -15,6 +15,7 @@ """ Codegen functions for the jagged_lengths_to_presences op. """ + import jinja2 from aitemplate.backend import registry diff --git a/python/aitemplate/backend/cuda/layernorm_sigmoid_mul/__init__.py b/python/aitemplate/backend/cuda/layernorm_sigmoid_mul/__init__.py index c8fd30caf..4d4c6adc2 100644 --- a/python/aitemplate/backend/cuda/layernorm_sigmoid_mul/__init__.py +++ b/python/aitemplate/backend/cuda/layernorm_sigmoid_mul/__init__.py @@ -15,6 +15,7 @@ """ (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. """ + from aitemplate.backend.cuda.layernorm_sigmoid_mul import ( batch_layernorm_sigmoid_mul, group_layernorm_sigmoid_mul, diff --git a/python/aitemplate/backend/cuda/lib_template.py b/python/aitemplate/backend/cuda/lib_template.py index 56cd310bf..57c8f8194 100644 --- a/python/aitemplate/backend/cuda/lib_template.py +++ b/python/aitemplate/backend/cuda/lib_template.py @@ -15,6 +15,7 @@ """ Common function templates for CUDA codegen. """ + import jinja2 from aitemplate.backend import registry diff --git a/python/aitemplate/backend/cuda/padding/__init__.py b/python/aitemplate/backend/cuda/padding/__init__.py index 807b81bc4..51c512609 100644 --- a/python/aitemplate/backend/cuda/padding/__init__.py +++ b/python/aitemplate/backend/cuda/padding/__init__.py @@ -15,6 +15,7 @@ """ CUDA padding init """ + from aitemplate.backend.cuda.padding import ndhwc3to8, nhwc3to4, nhwc3to8, pad_last_dim __all__ = ["ndhwc3to8", "nhwc3to8", "pad_last_dim", "nhwc3to4"] diff --git a/python/aitemplate/backend/cuda/padding/ndhwc3to8.py b/python/aitemplate/backend/cuda/padding/ndhwc3to8.py index bb03c0b16..25f1dd7ef 100644 --- a/python/aitemplate/backend/cuda/padding/ndhwc3to8.py +++ b/python/aitemplate/backend/cuda/padding/ndhwc3to8.py @@ -15,6 +15,7 @@ """ CUDA codegen for ndhwc3to8 op """ + import jinja2 from aitemplate.backend import registry diff --git a/python/aitemplate/backend/cuda/padding/nhwc3to4.py b/python/aitemplate/backend/cuda/padding/nhwc3to4.py index c07f8bc33..f74c67ef5 100644 --- a/python/aitemplate/backend/cuda/padding/nhwc3to4.py +++ b/python/aitemplate/backend/cuda/padding/nhwc3to4.py @@ -15,6 +15,7 @@ """ CUDA codegen for nhwc3to4 op """ + import jinja2 from aitemplate.backend import registry diff --git a/python/aitemplate/backend/cuda/padding/nhwc3to8.py b/python/aitemplate/backend/cuda/padding/nhwc3to8.py index 0f4e4eb52..bbfb0bd18 100644 --- a/python/aitemplate/backend/cuda/padding/nhwc3to8.py +++ b/python/aitemplate/backend/cuda/padding/nhwc3to8.py @@ -15,6 +15,7 @@ """ CUDA codegen for nhwc3to8 op """ + import jinja2 from aitemplate.backend import registry diff --git a/python/aitemplate/backend/cuda/padding/pad_last_dim.py b/python/aitemplate/backend/cuda/padding/pad_last_dim.py index 70e1e9f5f..aaa29b099 100644 --- a/python/aitemplate/backend/cuda/padding/pad_last_dim.py +++ b/python/aitemplate/backend/cuda/padding/pad_last_dim.py @@ -15,6 +15,7 @@ """ Codegen functions for pad_last_dim. """ + import jinja2 from aitemplate.backend import registry diff --git a/python/aitemplate/backend/cuda/pool2d/__init__.py b/python/aitemplate/backend/cuda/pool2d/__init__.py index 437cf7395..a0d60bcac 100644 --- a/python/aitemplate/backend/cuda/pool2d/__init__.py +++ b/python/aitemplate/backend/cuda/pool2d/__init__.py @@ -15,6 +15,7 @@ """ CUDA pool2d module init """ + from aitemplate.backend.cuda.pool2d import avg_pool2d, max_pool2d __all__ = ["avg_pool2d", "max_pool2d"] diff --git a/python/aitemplate/backend/cuda/pool2d/pool2d.py b/python/aitemplate/backend/cuda/pool2d/pool2d.py index 5c92c55f4..ce91393a6 100644 --- a/python/aitemplate/backend/cuda/pool2d/pool2d.py +++ b/python/aitemplate/backend/cuda/pool2d/pool2d.py @@ -15,6 +15,7 @@ """ CUDA pool2d common functions """ + import jinja2 FUNC_DECL_TEMPLATE = jinja2.Template( diff --git a/python/aitemplate/backend/cuda/reduce/__init__.py b/python/aitemplate/backend/cuda/reduce/__init__.py index 9aa2a5bf2..2131e13ff 100644 --- a/python/aitemplate/backend/cuda/reduce/__init__.py +++ b/python/aitemplate/backend/cuda/reduce/__init__.py @@ -15,6 +15,7 @@ """ CUDA reduce module init """ + from aitemplate.backend.cuda.reduce import ( reduce_3d, reduce_common, diff --git a/python/aitemplate/backend/cuda/reduce/reduce_3d.py b/python/aitemplate/backend/cuda/reduce/reduce_3d.py index 04a81416a..82670ada3 100644 --- a/python/aitemplate/backend/cuda/reduce/reduce_3d.py +++ b/python/aitemplate/backend/cuda/reduce/reduce_3d.py @@ -20,6 +20,7 @@ reduction ops such as reduce_mean and norm where we need to apply a scalar-op to each final element. """ + import bisect import jinja2 diff --git a/python/aitemplate/backend/cuda/reduce/reduce_common.py b/python/aitemplate/backend/cuda/reduce/reduce_common.py index aa059f038..ed96e2da8 100644 --- a/python/aitemplate/backend/cuda/reduce/reduce_common.py +++ b/python/aitemplate/backend/cuda/reduce/reduce_common.py @@ -15,6 +15,7 @@ """ CUDA reduce common functions """ + import jinja2 from aitemplate.backend.backend_spec import CUDASpec diff --git a/python/aitemplate/backend/cuda/reduce/reduce_common_slim_tensor.py b/python/aitemplate/backend/cuda/reduce/reduce_common_slim_tensor.py index 7d5992e01..63ed9c265 100644 --- a/python/aitemplate/backend/cuda/reduce/reduce_common_slim_tensor.py +++ b/python/aitemplate/backend/cuda/reduce/reduce_common_slim_tensor.py @@ -17,7 +17,6 @@ along dim=1 for 3D tensors is supported. """ - from typing import List import jinja2 diff --git a/python/aitemplate/backend/cuda/softmax/__init__.py b/python/aitemplate/backend/cuda/softmax/__init__.py index 615fd1954..89d434aaa 100644 --- a/python/aitemplate/backend/cuda/softmax/__init__.py +++ b/python/aitemplate/backend/cuda/softmax/__init__.py @@ -15,6 +15,7 @@ """ (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. """ + from aitemplate.backend.cuda.softmax import softmax __all__ = ["softmax"] diff --git a/python/aitemplate/backend/cuda/softmax/softmax.py b/python/aitemplate/backend/cuda/softmax/softmax.py index c08939752..bf61a1920 100644 --- a/python/aitemplate/backend/cuda/softmax/softmax.py +++ b/python/aitemplate/backend/cuda/softmax/softmax.py @@ -15,6 +15,7 @@ """ Softmax codegen for CUDA. """ + from __future__ import annotations import math diff --git a/python/aitemplate/backend/cuda/target_def.py b/python/aitemplate/backend/cuda/target_def.py index 1523c748f..9a66d54b3 100644 --- a/python/aitemplate/backend/cuda/target_def.py +++ b/python/aitemplate/backend/cuda/target_def.py @@ -15,6 +15,7 @@ """ CUDA target specialization """ + import json import logging import os @@ -161,15 +162,17 @@ def _build_nvcc_compiler_options(self) -> List[str]: "-DCUTLASS_DEBUG_TRACE_LEVEL=" + environ.get_cutlass_debug_trace_level(), ] if environ.enable_ptxas_info(): - options.extend( - [ - "--keep", # Keep the intermediate files for debugging (including ptx, sass, cubin etc.) - "--ptxas-options=--warn-on-local-memory-usage", # warn us if local memory is used in CUDA Kernels - "--ptxas-options=--warn-on-spills", # warn us if register spilling happens in CUDA Kernels - "--resource-usage", # Report on CUDA resource usage (shared mem, registers etc.) - "--source-in-ptx", - ] - ), # Annotate the ptx file with source information + ( + options.extend( + [ + "--keep", # Keep the intermediate files for debugging (including ptx, sass, cubin etc.) + "--ptxas-options=--warn-on-local-memory-usage", # warn us if local memory is used in CUDA Kernels + "--ptxas-options=--warn-on-spills", # warn us if register spilling happens in CUDA Kernels + "--resource-usage", # Report on CUDA resource usage (shared mem, registers etc.) + "--source-in-ptx", + ] + ), + ) # Annotate the ptx file with source information options.extend(self._get_nvcc_debug_options()) if self._ndebug == 1: options.append("-DNDEBUG") @@ -449,15 +452,17 @@ def _build_compile_options(self): ) ) if environ.enable_ptxas_info(): - options.extend( - [ - "--keep", # Keep the intermediate files for debugging (including ptx, sass, cubin etc.) - "--ptxas-options=--warn-on-local-memory-usage", # warn us if local memory is used in CUDA Kernels - "--ptxas-options=--warn-on-spills", # warn us if register spilling happens in CUDA Kernels - "--resource-usage", # Report on CUDA resource usage (shared mem, registers etc.) - "--source-in-ptx", # Annotate the ptx file with source information - ] - ), + ( + options.extend( + [ + "--keep", # Keep the intermediate files for debugging (including ptx, sass, cubin etc.) + "--ptxas-options=--warn-on-local-memory-usage", # warn us if local memory is used in CUDA Kernels + "--ptxas-options=--warn-on-spills", # warn us if register spilling happens in CUDA Kernels + "--resource-usage", # Report on CUDA resource usage (shared mem, registers etc.) + "--source-in-ptx", # Annotate the ptx file with source information + ] + ), + ) options.extend(self._get_nvcc_debug_options()) if self._ndebug == 1: options.append("-DNDEBUG") diff --git a/python/aitemplate/backend/cuda/tensor/__init__.py b/python/aitemplate/backend/cuda/tensor/__init__.py index 4cef720ad..dbc35f47c 100644 --- a/python/aitemplate/backend/cuda/tensor/__init__.py +++ b/python/aitemplate/backend/cuda/tensor/__init__.py @@ -15,6 +15,7 @@ """ CUDA tensor ops module init """ + from aitemplate.backend.cuda.tensor import ( argmax, batch_gather, diff --git a/python/aitemplate/backend/cuda/tensor/concatenate_tanh.py b/python/aitemplate/backend/cuda/tensor/concatenate_tanh.py index 833bc9c22..9928c14a5 100644 --- a/python/aitemplate/backend/cuda/tensor/concatenate_tanh.py +++ b/python/aitemplate/backend/cuda/tensor/concatenate_tanh.py @@ -15,6 +15,7 @@ """ Codegen functions for concatenate_tanh. """ + import jinja2 from aitemplate.backend import registry diff --git a/python/aitemplate/backend/cuda/tensor/gather.py b/python/aitemplate/backend/cuda/tensor/gather.py index 0841dcb18..4b9c72260 100644 --- a/python/aitemplate/backend/cuda/tensor/gather.py +++ b/python/aitemplate/backend/cuda/tensor/gather.py @@ -15,6 +15,7 @@ """ CUDA gather function """ + import jinja2 from aitemplate.backend import registry diff --git a/python/aitemplate/backend/cuda/tensor/index_select.py b/python/aitemplate/backend/cuda/tensor/index_select.py index 29c026d97..6b84829a9 100644 --- a/python/aitemplate/backend/cuda/tensor/index_select.py +++ b/python/aitemplate/backend/cuda/tensor/index_select.py @@ -40,6 +40,7 @@ Blocks are(N + threads - 1) / threads; """ + import jinja2 from aitemplate.backend import registry diff --git a/python/aitemplate/backend/cuda/tensor/jagged_to_padded_dense.py b/python/aitemplate/backend/cuda/tensor/jagged_to_padded_dense.py index bdafac544..d16c11bc5 100644 --- a/python/aitemplate/backend/cuda/tensor/jagged_to_padded_dense.py +++ b/python/aitemplate/backend/cuda/tensor/jagged_to_padded_dense.py @@ -15,6 +15,7 @@ """ The back-end bindings of the jagged_to_padded_dense op. """ + from typing import Any, Dict import jinja2 diff --git a/python/aitemplate/backend/cuda/tensor/masked_select.py b/python/aitemplate/backend/cuda/tensor/masked_select.py index a09cfa460..66d66f910 100644 --- a/python/aitemplate/backend/cuda/tensor/masked_select.py +++ b/python/aitemplate/backend/cuda/tensor/masked_select.py @@ -15,6 +15,7 @@ """ Define masked_select codegen and CUDA kernel """ + from typing import List import jinja2 diff --git a/python/aitemplate/backend/cuda/tensor/padded_dense_to_jagged.py b/python/aitemplate/backend/cuda/tensor/padded_dense_to_jagged.py index 73d4e9e19..30486b31b 100644 --- a/python/aitemplate/backend/cuda/tensor/padded_dense_to_jagged.py +++ b/python/aitemplate/backend/cuda/tensor/padded_dense_to_jagged.py @@ -15,6 +15,7 @@ """ The back-end bindings of the padded_dense_to_jagged op. """ + from typing import Any, Dict, List import jinja2 diff --git a/python/aitemplate/backend/cuda/tensor/permute.py b/python/aitemplate/backend/cuda/tensor/permute.py index f0ce038d6..6212b9573 100644 --- a/python/aitemplate/backend/cuda/tensor/permute.py +++ b/python/aitemplate/backend/cuda/tensor/permute.py @@ -15,6 +15,7 @@ """ permute for cuda """ + import os from typing import Any, Dict diff --git a/python/aitemplate/backend/cuda/tensor/slice_reshape_scatter.py b/python/aitemplate/backend/cuda/tensor/slice_reshape_scatter.py index c1552ddb9..5d2f31e65 100644 --- a/python/aitemplate/backend/cuda/tensor/slice_reshape_scatter.py +++ b/python/aitemplate/backend/cuda/tensor/slice_reshape_scatter.py @@ -15,6 +15,7 @@ """ Slice reshape scatter CUDA implementation. """ + import jinja2 from aitemplate.backend import registry diff --git a/python/aitemplate/backend/cuda/upsample/__init__.py b/python/aitemplate/backend/cuda/upsample/__init__.py index f7fa8ce45..ab2c12d63 100644 --- a/python/aitemplate/backend/cuda/upsample/__init__.py +++ b/python/aitemplate/backend/cuda/upsample/__init__.py @@ -15,6 +15,7 @@ """ CUDA upsampling module init """ + from aitemplate.backend.cuda.upsample import upsampling2d, upsampling2d_add __all__ = ["upsampling2d", "upsampling2d_add"] diff --git a/python/aitemplate/backend/cuda/utils.py b/python/aitemplate/backend/cuda/utils.py index f2a1c4900..20107a085 100644 --- a/python/aitemplate/backend/cuda/utils.py +++ b/python/aitemplate/backend/cuda/utils.py @@ -15,6 +15,7 @@ """ Util functions for CUDA codegen. """ + import logging from aitemplate.backend import registry diff --git a/python/aitemplate/backend/cuda/view_ops/__init__.py b/python/aitemplate/backend/cuda/view_ops/__init__.py index b2be80a1a..7bb0ffb70 100644 --- a/python/aitemplate/backend/cuda/view_ops/__init__.py +++ b/python/aitemplate/backend/cuda/view_ops/__init__.py @@ -15,6 +15,7 @@ """ CUDA view_ops module init """ + from aitemplate.backend.cuda.view_ops import make_jagged, view_ops __all__ = [ diff --git a/python/aitemplate/backend/cuda/view_ops/make_jagged.py b/python/aitemplate/backend/cuda/view_ops/make_jagged.py index 685d2fb3e..7e9783019 100644 --- a/python/aitemplate/backend/cuda/view_ops/make_jagged.py +++ b/python/aitemplate/backend/cuda/view_ops/make_jagged.py @@ -28,6 +28,7 @@ of the constraints can be checked on the device, in which case an std::runtime_error is thrown on violation. """ + from typing import Set import jinja2 diff --git a/python/aitemplate/backend/cuda/view_ops/view_ops.py b/python/aitemplate/backend/cuda/view_ops/view_ops.py index 63f06765a..13fae9cb1 100644 --- a/python/aitemplate/backend/cuda/view_ops/view_ops.py +++ b/python/aitemplate/backend/cuda/view_ops/view_ops.py @@ -15,6 +15,7 @@ """ Codegen functions for view ops. """ + import jinja2 from aitemplate.backend import registry diff --git a/python/aitemplate/backend/cuda/vision_ops/nms/__init__.py b/python/aitemplate/backend/cuda/vision_ops/nms/__init__.py index 4f47cf2d8..f758bdffe 100644 --- a/python/aitemplate/backend/cuda/vision_ops/nms/__init__.py +++ b/python/aitemplate/backend/cuda/vision_ops/nms/__init__.py @@ -15,6 +15,7 @@ """ (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. """ + from aitemplate.backend.cuda.vision_ops.nms import ( # noqa batched_nms, efficient_nms, diff --git a/python/aitemplate/backend/cuda/vision_ops/roi_ops/__init__.py b/python/aitemplate/backend/cuda/vision_ops/roi_ops/__init__.py index 5959e1a3b..100672686 100644 --- a/python/aitemplate/backend/cuda/vision_ops/roi_ops/__init__.py +++ b/python/aitemplate/backend/cuda/vision_ops/roi_ops/__init__.py @@ -15,6 +15,7 @@ """ CUDA roi_align module init """ + from aitemplate.backend.cuda.vision_ops.roi_ops import multi_level_roi_align, roi_align __all__ = ["roi_align", "multi_level_roi_align"] diff --git a/python/aitemplate/backend/cuda/vision_ops/roi_ops/multi_level_roi_align.py b/python/aitemplate/backend/cuda/vision_ops/roi_ops/multi_level_roi_align.py index 564e01086..49138f5f5 100644 --- a/python/aitemplate/backend/cuda/vision_ops/roi_ops/multi_level_roi_align.py +++ b/python/aitemplate/backend/cuda/vision_ops/roi_ops/multi_level_roi_align.py @@ -15,6 +15,7 @@ """ Codegen functions for multi-level roi align. """ + import jinja2 from aitemplate.backend import registry diff --git a/python/aitemplate/backend/cuda/vision_ops/roi_ops/roi_ops.py b/python/aitemplate/backend/cuda/vision_ops/roi_ops/roi_ops.py index 1754c298c..506f4eb1b 100644 --- a/python/aitemplate/backend/cuda/vision_ops/roi_ops/roi_ops.py +++ b/python/aitemplate/backend/cuda/vision_ops/roi_ops/roi_ops.py @@ -15,6 +15,7 @@ """ Codegen functions for roi ops. """ + import jinja2 FUNC_DECL_TEMPLATE = jinja2.Template( diff --git a/python/aitemplate/backend/main_templates.py b/python/aitemplate/backend/main_templates.py index 61c6b44e5..78b06ad6a 100644 --- a/python/aitemplate/backend/main_templates.py +++ b/python/aitemplate/backend/main_templates.py @@ -15,6 +15,7 @@ """ This file contains class definitions used in the generated main.cu file. """ + import jinja2 MODEL_TEMPLATE = jinja2.Template( diff --git a/python/aitemplate/backend/profiler_cache.py b/python/aitemplate/backend/profiler_cache.py index e2170ba04..de26fbe84 100644 --- a/python/aitemplate/backend/profiler_cache.py +++ b/python/aitemplate/backend/profiler_cache.py @@ -15,6 +15,7 @@ """ SQLite backend for conv/gemm profiling cache """ + import enum import logging import sqlite3 diff --git a/python/aitemplate/backend/profiler_runner.py b/python/aitemplate/backend/profiler_runner.py index 26621a6a9..5b9291c2b 100644 --- a/python/aitemplate/backend/profiler_runner.py +++ b/python/aitemplate/backend/profiler_runner.py @@ -15,6 +15,7 @@ """ A subprocess based multiple GPUs runner for auto-tuning """ + from __future__ import annotations import concurrent.futures diff --git a/python/aitemplate/backend/rocm/__init__.py b/python/aitemplate/backend/rocm/__init__.py index a687a1ee7..34d9dd7ad 100644 --- a/python/aitemplate/backend/rocm/__init__.py +++ b/python/aitemplate/backend/rocm/__init__.py @@ -16,6 +16,7 @@ """ Rocm backend init. """ + from aitemplate.backend.rocm import lib_template, target_def, utils from aitemplate.backend.rocm.attention import * from aitemplate.backend.rocm.common import * diff --git a/python/aitemplate/backend/rocm/attention/mem_eff_attention.py b/python/aitemplate/backend/rocm/attention/mem_eff_attention.py index f902792c1..554a0fc4a 100644 --- a/python/aitemplate/backend/rocm/attention/mem_eff_attention.py +++ b/python/aitemplate/backend/rocm/attention/mem_eff_attention.py @@ -15,6 +15,7 @@ """ attention kernel codegen for ROCM. """ + from typing import Any, Dict import jinja2 diff --git a/python/aitemplate/backend/rocm/common/__init__.py b/python/aitemplate/backend/rocm/common/__init__.py index 3e6e5152f..d96706b71 100644 --- a/python/aitemplate/backend/rocm/common/__init__.py +++ b/python/aitemplate/backend/rocm/common/__init__.py @@ -16,4 +16,5 @@ """ ROCM Common module init """ + from aitemplate.backend.rocm.common.dummy_op import * diff --git a/python/aitemplate/backend/rocm/conv2d/__init__.py b/python/aitemplate/backend/rocm/conv2d/__init__.py index ddcd3131c..c5b6fac6a 100644 --- a/python/aitemplate/backend/rocm/conv2d/__init__.py +++ b/python/aitemplate/backend/rocm/conv2d/__init__.py @@ -15,6 +15,7 @@ """ ROCM conv2d init. """ + from aitemplate.backend.rocm.conv2d import ( conv2d, conv2d_bias, diff --git a/python/aitemplate/backend/rocm/conv2d/common.py b/python/aitemplate/backend/rocm/conv2d/common.py index 0d30e05c8..6178e0f24 100644 --- a/python/aitemplate/backend/rocm/conv2d/common.py +++ b/python/aitemplate/backend/rocm/conv2d/common.py @@ -15,6 +15,7 @@ """ Common ROCM template for conv2d. """ + import os import re from collections import OrderedDict diff --git a/python/aitemplate/backend/rocm/conv2d/conv2d.py b/python/aitemplate/backend/rocm/conv2d/conv2d.py index 8c9df0f5f..db84311ba 100644 --- a/python/aitemplate/backend/rocm/conv2d/conv2d.py +++ b/python/aitemplate/backend/rocm/conv2d/conv2d.py @@ -15,6 +15,7 @@ """ ROCM codegen functions for conv2d. """ + from aitemplate.backend import registry from aitemplate.backend.rocm.conv2d import common diff --git a/python/aitemplate/backend/rocm/conv2d/conv2d_bias.py b/python/aitemplate/backend/rocm/conv2d/conv2d_bias.py index 91506f2f9..4f0a990a0 100644 --- a/python/aitemplate/backend/rocm/conv2d/conv2d_bias.py +++ b/python/aitemplate/backend/rocm/conv2d/conv2d_bias.py @@ -15,6 +15,7 @@ """ ROCM codegen functions for Conv2dBias: conv2d(w, x) + b """ + from aitemplate.backend import registry from aitemplate.backend.rocm.conv2d import common diff --git a/python/aitemplate/backend/rocm/conv2d/conv2d_bias_add.py b/python/aitemplate/backend/rocm/conv2d/conv2d_bias_add.py index 3c8f0e3ba..0f3adcf8f 100644 --- a/python/aitemplate/backend/rocm/conv2d/conv2d_bias_add.py +++ b/python/aitemplate/backend/rocm/conv2d/conv2d_bias_add.py @@ -15,6 +15,7 @@ """ conv2d bias add codegen """ + from ... import registry from . import common diff --git a/python/aitemplate/backend/rocm/conv2d/conv2d_bias_add_relu.py b/python/aitemplate/backend/rocm/conv2d/conv2d_bias_add_relu.py index 190f85694..72311e6a1 100644 --- a/python/aitemplate/backend/rocm/conv2d/conv2d_bias_add_relu.py +++ b/python/aitemplate/backend/rocm/conv2d/conv2d_bias_add_relu.py @@ -15,6 +15,7 @@ """ ROCM codegen functions for conv2d_bias_add_relu. """ + import jinja2 from aitemplate.backend import registry diff --git a/python/aitemplate/backend/rocm/conv2d/conv2d_bias_relu.py b/python/aitemplate/backend/rocm/conv2d/conv2d_bias_relu.py index b33561394..aa6fceb7d 100644 --- a/python/aitemplate/backend/rocm/conv2d/conv2d_bias_relu.py +++ b/python/aitemplate/backend/rocm/conv2d/conv2d_bias_relu.py @@ -15,6 +15,7 @@ """ ROCM codegen functions for conv2d_bias_relu. """ + from aitemplate.backend import registry from aitemplate.backend.rocm.conv2d import common diff --git a/python/aitemplate/backend/rocm/conv2d/conv2d_bias_sigmoid.py b/python/aitemplate/backend/rocm/conv2d/conv2d_bias_sigmoid.py index f43e42317..8dd196460 100644 --- a/python/aitemplate/backend/rocm/conv2d/conv2d_bias_sigmoid.py +++ b/python/aitemplate/backend/rocm/conv2d/conv2d_bias_sigmoid.py @@ -15,6 +15,7 @@ """ ROCM codegen functions for conv2d_bias_sigmoid. """ + import jinja2 from aitemplate.backend import registry diff --git a/python/aitemplate/backend/rocm/conv2d/transposed_conv2d.py b/python/aitemplate/backend/rocm/conv2d/transposed_conv2d.py index 77659bcd3..bed20c4e8 100644 --- a/python/aitemplate/backend/rocm/conv2d/transposed_conv2d.py +++ b/python/aitemplate/backend/rocm/conv2d/transposed_conv2d.py @@ -15,6 +15,7 @@ """ ROCM codegen functions for transposed_conv2d. """ + import jinja2 from aitemplate.backend import registry diff --git a/python/aitemplate/backend/rocm/conv2d/transposed_conv2d_bias_relu.py b/python/aitemplate/backend/rocm/conv2d/transposed_conv2d_bias_relu.py index a6c5a3bd9..7384508ee 100644 --- a/python/aitemplate/backend/rocm/conv2d/transposed_conv2d_bias_relu.py +++ b/python/aitemplate/backend/rocm/conv2d/transposed_conv2d_bias_relu.py @@ -15,6 +15,7 @@ """ ROCM codegen functions for transposed_conv2d_bias_relu. """ + import jinja2 from aitemplate.backend import registry diff --git a/python/aitemplate/backend/rocm/elementwise/__init__.py b/python/aitemplate/backend/rocm/elementwise/__init__.py index 4594bf9ec..82bd7acfb 100644 --- a/python/aitemplate/backend/rocm/elementwise/__init__.py +++ b/python/aitemplate/backend/rocm/elementwise/__init__.py @@ -15,6 +15,7 @@ """ (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. """ + from aitemplate.backend.rocm.elementwise import fused_elementwise __all__ = ["fused_elementwise"] diff --git a/python/aitemplate/backend/rocm/gemm/__init__.py b/python/aitemplate/backend/rocm/gemm/__init__.py index ba4594cd5..81b38dc33 100644 --- a/python/aitemplate/backend/rocm/gemm/__init__.py +++ b/python/aitemplate/backend/rocm/gemm/__init__.py @@ -15,6 +15,7 @@ """ Rocm gemm init. """ + from aitemplate.backend.rocm.gemm import ( # noqa: F401 bmm_ccr, bmm_ccr_add, diff --git a/python/aitemplate/backend/rocm/gemm/bmm_ccr.py b/python/aitemplate/backend/rocm/gemm/bmm_ccr.py index a691c2cda..fc5837e9c 100644 --- a/python/aitemplate/backend/rocm/gemm/bmm_ccr.py +++ b/python/aitemplate/backend/rocm/gemm/bmm_ccr.py @@ -17,6 +17,7 @@ c[b, m, n] = bmm(a[b, k, m], b[b, n, k]) This is used for `ops.bmm_ccr`. """ + import jinja2 from aitemplate.backend import registry diff --git a/python/aitemplate/backend/rocm/gemm/bmm_ccr_add.py b/python/aitemplate/backend/rocm/gemm/bmm_ccr_add.py index 6f81bf58b..60f12a993 100644 --- a/python/aitemplate/backend/rocm/gemm/bmm_ccr_add.py +++ b/python/aitemplate/backend/rocm/gemm/bmm_ccr_add.py @@ -17,6 +17,7 @@ c[b, m, n] = a[b, k, m] * b[b, n, k] This is used for `ops.bmm_ccr_add`. """ + import jinja2 from ... import registry diff --git a/python/aitemplate/backend/rocm/gemm/bmm_common.py b/python/aitemplate/backend/rocm/gemm/bmm_common.py index 5de7014d7..a1063c4a5 100644 --- a/python/aitemplate/backend/rocm/gemm/bmm_common.py +++ b/python/aitemplate/backend/rocm/gemm/bmm_common.py @@ -15,6 +15,7 @@ """ Common template for bmm """ + import jinja2 from aitemplate.backend.rocm.gemm import common diff --git a/python/aitemplate/backend/rocm/gemm/bmm_crr.py b/python/aitemplate/backend/rocm/gemm/bmm_crr.py index 6e842652c..3a8d4eb37 100644 --- a/python/aitemplate/backend/rocm/gemm/bmm_crr.py +++ b/python/aitemplate/backend/rocm/gemm/bmm_crr.py @@ -17,6 +17,7 @@ c[b, m, n] = bmm(a[b, k, m], b[b, k, n]) This is used for `ops.bmm_crr`. """ + import jinja2 from aitemplate.backend import registry diff --git a/python/aitemplate/backend/rocm/gemm/bmm_crr_add.py b/python/aitemplate/backend/rocm/gemm/bmm_crr_add.py index 8cefec56c..8be07a7de 100644 --- a/python/aitemplate/backend/rocm/gemm/bmm_crr_add.py +++ b/python/aitemplate/backend/rocm/gemm/bmm_crr_add.py @@ -17,6 +17,7 @@ c[b, m, n] = a[b, k, m] * b[b, k, n] This is used for `ops.bmm_crr_add`. """ + import jinja2 from ... import registry diff --git a/python/aitemplate/backend/rocm/gemm/bmm_permute_common.py b/python/aitemplate/backend/rocm/gemm/bmm_permute_common.py index b444e34f2..7f1f90b3f 100644 --- a/python/aitemplate/backend/rocm/gemm/bmm_permute_common.py +++ b/python/aitemplate/backend/rocm/gemm/bmm_permute_common.py @@ -15,6 +15,7 @@ """ Common template for bmm """ + import jinja2 EXTRA_HEADER_TEMPLATE = jinja2.Template( diff --git a/python/aitemplate/backend/rocm/gemm/bmm_rcr.py b/python/aitemplate/backend/rocm/gemm/bmm_rcr.py index ee7784d42..ec28e06ea 100644 --- a/python/aitemplate/backend/rocm/gemm/bmm_rcr.py +++ b/python/aitemplate/backend/rocm/gemm/bmm_rcr.py @@ -17,6 +17,7 @@ c[b, m, n] = bmm(a[b, m, k], b[b, n, k]) This is used for `ops.bmm_rcr`. """ + import jinja2 from aitemplate.backend import registry diff --git a/python/aitemplate/backend/rocm/gemm/bmm_rcr_permute.py b/python/aitemplate/backend/rocm/gemm/bmm_rcr_permute.py index 54c73b438..4c6e585b9 100644 --- a/python/aitemplate/backend/rocm/gemm/bmm_rcr_permute.py +++ b/python/aitemplate/backend/rocm/gemm/bmm_rcr_permute.py @@ -17,6 +17,7 @@ c[b, m, n] = bmm(a[b, m, k], b[b, n, k]) This is used for `ops.bmm_rcr`. """ + import jinja2 from aitemplate.backend import registry diff --git a/python/aitemplate/backend/rocm/gemm/bmm_rrr.py b/python/aitemplate/backend/rocm/gemm/bmm_rrr.py index aa3b68752..95a1775e4 100644 --- a/python/aitemplate/backend/rocm/gemm/bmm_rrr.py +++ b/python/aitemplate/backend/rocm/gemm/bmm_rrr.py @@ -17,6 +17,7 @@ c[b, m, n] = a[b, m, k] * b[b, k, n] This is used for `ops.bmm_rrr`. """ + import jinja2 from aitemplate.backend import registry diff --git a/python/aitemplate/backend/rocm/gemm/bmm_rrr_add.py b/python/aitemplate/backend/rocm/gemm/bmm_rrr_add.py index a862a6fb6..6a5391cd3 100644 --- a/python/aitemplate/backend/rocm/gemm/bmm_rrr_add.py +++ b/python/aitemplate/backend/rocm/gemm/bmm_rrr_add.py @@ -17,6 +17,7 @@ c[b, m, n] = a[b, m, k] * b[b, k, n] This is used for `ops.bmm_rrr_add`. """ + import jinja2 from ... import registry diff --git a/python/aitemplate/backend/rocm/gemm/bmm_rrr_permute.py b/python/aitemplate/backend/rocm/gemm/bmm_rrr_permute.py index 6d4fc73fd..d844bc9a9 100644 --- a/python/aitemplate/backend/rocm/gemm/bmm_rrr_permute.py +++ b/python/aitemplate/backend/rocm/gemm/bmm_rrr_permute.py @@ -17,6 +17,7 @@ c[b, m, n] = bmm(a[b, m, k], b[b, n, k]) This is used for `ops.bmm_rrr_permute`. """ + import jinja2 from aitemplate.backend import registry diff --git a/python/aitemplate/backend/rocm/gemm/bmm_softmax_bmm.py b/python/aitemplate/backend/rocm/gemm/bmm_softmax_bmm.py index a54f10dd1..74a3b3441 100644 --- a/python/aitemplate/backend/rocm/gemm/bmm_softmax_bmm.py +++ b/python/aitemplate/backend/rocm/gemm/bmm_softmax_bmm.py @@ -17,6 +17,7 @@ c[b, m, n] = bmm(a[b, m, k], b[b, n, k]) This is used for `ops.bmm_rcr`. """ + import jinja2 from aitemplate.backend import registry diff --git a/python/aitemplate/backend/rocm/gemm/bmm_softmax_bmm_permute.py b/python/aitemplate/backend/rocm/gemm/bmm_softmax_bmm_permute.py index 10337922c..cefb6bbcd 100644 --- a/python/aitemplate/backend/rocm/gemm/bmm_softmax_bmm_permute.py +++ b/python/aitemplate/backend/rocm/gemm/bmm_softmax_bmm_permute.py @@ -22,6 +22,7 @@ This is used for `ops.bmm_softmax_bmm_permute`. """ + import jinja2 from aitemplate.backend import registry diff --git a/python/aitemplate/backend/rocm/gemm/common.py b/python/aitemplate/backend/rocm/gemm/common.py index 453c356e1..ff8a71576 100644 --- a/python/aitemplate/backend/rocm/gemm/common.py +++ b/python/aitemplate/backend/rocm/gemm/common.py @@ -15,6 +15,7 @@ """ Common template for gemm """ + import os import re from collections import OrderedDict diff --git a/python/aitemplate/backend/rocm/gemm/gemm_epilogue.py b/python/aitemplate/backend/rocm/gemm/gemm_epilogue.py index 52edac942..41955e48c 100644 --- a/python/aitemplate/backend/rocm/gemm/gemm_epilogue.py +++ b/python/aitemplate/backend/rocm/gemm/gemm_epilogue.py @@ -15,6 +15,7 @@ """ Templates for different GeMM epilogues. """ + from typing import Dict, List, NamedTuple from aitemplate.compiler.ops.common.epilogue import EpilogueOp diff --git a/python/aitemplate/backend/rocm/gemm/gemm_rcr.py b/python/aitemplate/backend/rocm/gemm/gemm_rcr.py index 20316c33a..7e8a1b603 100644 --- a/python/aitemplate/backend/rocm/gemm/gemm_rcr.py +++ b/python/aitemplate/backend/rocm/gemm/gemm_rcr.py @@ -18,6 +18,7 @@ This is used for `torch.nn.functional.linear(bias=false)` When used for `linear`, need to set A->Data, B->Weight """ + from aitemplate.backend import registry from aitemplate.backend.rocm.gemm import common from aitemplate.backend.rocm.gemm.layout import RCR diff --git a/python/aitemplate/backend/rocm/gemm/gemm_rcr_bias.py b/python/aitemplate/backend/rocm/gemm/gemm_rcr_bias.py index 579d0f395..662a56afa 100644 --- a/python/aitemplate/backend/rocm/gemm/gemm_rcr_bias.py +++ b/python/aitemplate/backend/rocm/gemm/gemm_rcr_bias.py @@ -18,6 +18,7 @@ This is used for `torch.nn.functional.linear` When used for `linear`, need to set A->Data, B->Weight, C->Bias """ + from aitemplate.backend import registry from aitemplate.backend.rocm.gemm import common from aitemplate.backend.rocm.gemm.layout import RCR diff --git a/python/aitemplate/backend/rocm/gemm/gemm_rcr_bias_add.py b/python/aitemplate/backend/rocm/gemm/gemm_rcr_bias_add.py index 2cfd2aabc..8e2fdbec7 100644 --- a/python/aitemplate/backend/rocm/gemm/gemm_rcr_bias_add.py +++ b/python/aitemplate/backend/rocm/gemm/gemm_rcr_bias_add.py @@ -18,6 +18,7 @@ where A[RowMajor][M, K], B[ColMajor][N, K], C[RowMajor][M, N] bias[RowMajor][N], D0[RowMajor][M, N] """ + import jinja2 from aitemplate.backend import registry diff --git a/python/aitemplate/backend/rocm/gemm/gemm_rcr_bias_add_add.py b/python/aitemplate/backend/rocm/gemm/gemm_rcr_bias_add_add.py index 0b4919619..ff6565780 100644 --- a/python/aitemplate/backend/rocm/gemm/gemm_rcr_bias_add_add.py +++ b/python/aitemplate/backend/rocm/gemm/gemm_rcr_bias_add_add.py @@ -18,6 +18,7 @@ where A[RowMajor][M, K], B[ColMajor][N, K], C[RowMajor][M, N] bias[RowMajor][N], D0[RowMajor][M, N], D1[RowMajor][M, N] """ + import jinja2 from aitemplate.backend import registry diff --git a/python/aitemplate/backend/rocm/gemm/gemm_rcr_bias_add_add_relu.py b/python/aitemplate/backend/rocm/gemm/gemm_rcr_bias_add_add_relu.py index afa9723f7..0b88627b9 100644 --- a/python/aitemplate/backend/rocm/gemm/gemm_rcr_bias_add_add_relu.py +++ b/python/aitemplate/backend/rocm/gemm/gemm_rcr_bias_add_add_relu.py @@ -18,6 +18,7 @@ where A[RowMajor][M, K], B[ColMajor][N, K], C[RowMajor][M, N] bias[RowMajor][N], D0[RowMajor][M, N], D1[RowMajor][M, N] """ + import jinja2 from aitemplate.backend import registry diff --git a/python/aitemplate/backend/rocm/gemm/gemm_rcr_bias_add_relu.py b/python/aitemplate/backend/rocm/gemm/gemm_rcr_bias_add_relu.py index 998798618..3f08ed82f 100644 --- a/python/aitemplate/backend/rocm/gemm/gemm_rcr_bias_add_relu.py +++ b/python/aitemplate/backend/rocm/gemm/gemm_rcr_bias_add_relu.py @@ -18,6 +18,7 @@ where A[RowMajor][M, K], B[ColMajor][N, K], C[RowMajor][M, N] bias[RowMajor][N], D0[RowMajor][M, N] """ + import jinja2 from aitemplate.backend import registry diff --git a/python/aitemplate/backend/rocm/gemm/gemm_rcr_bias_fast_gelu.py b/python/aitemplate/backend/rocm/gemm/gemm_rcr_bias_fast_gelu.py index 4822664c3..71e941614 100644 --- a/python/aitemplate/backend/rocm/gemm/gemm_rcr_bias_fast_gelu.py +++ b/python/aitemplate/backend/rocm/gemm/gemm_rcr_bias_fast_gelu.py @@ -18,6 +18,7 @@ This is used for `torch.nn.functional.linear + swish` When used for `linear`, need to set A->Data, B->Weight, C->Bias """ + from aitemplate.backend import registry from aitemplate.backend.rocm.gemm import common from aitemplate.backend.rocm.gemm.layout import RCR diff --git a/python/aitemplate/backend/rocm/gemm/gemm_rcr_bias_mul.py b/python/aitemplate/backend/rocm/gemm/gemm_rcr_bias_mul.py index 9d7108d1e..459a79d5f 100644 --- a/python/aitemplate/backend/rocm/gemm/gemm_rcr_bias_mul.py +++ b/python/aitemplate/backend/rocm/gemm/gemm_rcr_bias_mul.py @@ -18,6 +18,7 @@ where A[RowMajor][M, K], B[ColMajor][N, K], C[RowMajor][M, N] bias[RowMajor][N], D0[RowMajor][M, N] """ + import jinja2 from aitemplate.backend import registry diff --git a/python/aitemplate/backend/rocm/gemm/gemm_rcr_bias_mul_add.py b/python/aitemplate/backend/rocm/gemm/gemm_rcr_bias_mul_add.py index 51fe1c11f..76882d706 100644 --- a/python/aitemplate/backend/rocm/gemm/gemm_rcr_bias_mul_add.py +++ b/python/aitemplate/backend/rocm/gemm/gemm_rcr_bias_mul_add.py @@ -18,6 +18,7 @@ where A[RowMajor][M, K], B[ColMajor][N, K], C[RowMajor][M, N] bias[RowMajor][N], D0[RowMajor][M, N], D1[RowMajor][M, N] """ + import jinja2 from aitemplate.backend import registry diff --git a/python/aitemplate/backend/rocm/gemm/gemm_rcr_bias_mul_tanh.py b/python/aitemplate/backend/rocm/gemm/gemm_rcr_bias_mul_tanh.py index a4e0d1991..f7cfee842 100644 --- a/python/aitemplate/backend/rocm/gemm/gemm_rcr_bias_mul_tanh.py +++ b/python/aitemplate/backend/rocm/gemm/gemm_rcr_bias_mul_tanh.py @@ -18,6 +18,7 @@ where A[RowMajor][M, K], B[ColMajor][N, K], C[RowMajor][M, N] bias[RowMajor][N], D0[RowMajor][M, N] """ + import jinja2 from aitemplate.backend import registry diff --git a/python/aitemplate/backend/rocm/gemm/gemm_rcr_bias_permute.py b/python/aitemplate/backend/rocm/gemm/gemm_rcr_bias_permute.py index 85fab9657..7043e944a 100644 --- a/python/aitemplate/backend/rocm/gemm/gemm_rcr_bias_permute.py +++ b/python/aitemplate/backend/rocm/gemm/gemm_rcr_bias_permute.py @@ -18,6 +18,7 @@ This is used for `torch.nn.functional.linear` When used for `linear`, need to set A->Data, B->Weight, C->Bias """ + import jinja2 from aitemplate.backend import registry diff --git a/python/aitemplate/backend/rocm/gemm/gemm_rcr_bias_permute_m2n3.py b/python/aitemplate/backend/rocm/gemm/gemm_rcr_bias_permute_m2n3.py index 4d8ba2a14..ff5e428fd 100644 --- a/python/aitemplate/backend/rocm/gemm/gemm_rcr_bias_permute_m2n3.py +++ b/python/aitemplate/backend/rocm/gemm/gemm_rcr_bias_permute_m2n3.py @@ -19,6 +19,7 @@ c = c.reshape(M0, M1, N0, N1, N2) output = torch.permute(c, [2, 0, 3, 1, 4]) """ + import jinja2 from aitemplate.backend import registry diff --git a/python/aitemplate/backend/rocm/gemm/gemm_rcr_bias_permute_m3n2.py b/python/aitemplate/backend/rocm/gemm/gemm_rcr_bias_permute_m3n2.py index 07df32276..ad43bf928 100644 --- a/python/aitemplate/backend/rocm/gemm/gemm_rcr_bias_permute_m3n2.py +++ b/python/aitemplate/backend/rocm/gemm/gemm_rcr_bias_permute_m3n2.py @@ -19,6 +19,7 @@ c = c.reshape(M0, M1, M2, N0, N1) output = torch.permute(c, [2, 0, 3, 1, 4]) """ + import jinja2 from aitemplate.backend import registry diff --git a/python/aitemplate/backend/rocm/gemm/gemm_rcr_bias_relu.py b/python/aitemplate/backend/rocm/gemm/gemm_rcr_bias_relu.py index 1e744128e..aa32a1193 100644 --- a/python/aitemplate/backend/rocm/gemm/gemm_rcr_bias_relu.py +++ b/python/aitemplate/backend/rocm/gemm/gemm_rcr_bias_relu.py @@ -18,6 +18,7 @@ This is used for `torch.nn.functional.linear + relu` When used for `linear`, need to set A->Data, B->Weight, C->Bias """ + from aitemplate.backend import registry from aitemplate.backend.rocm.gemm import common from aitemplate.backend.rocm.gemm.layout import RCR diff --git a/python/aitemplate/backend/rocm/gemm/gemm_rcr_bias_sigmoid.py b/python/aitemplate/backend/rocm/gemm/gemm_rcr_bias_sigmoid.py index f0943c5f7..658380b6b 100644 --- a/python/aitemplate/backend/rocm/gemm/gemm_rcr_bias_sigmoid.py +++ b/python/aitemplate/backend/rocm/gemm/gemm_rcr_bias_sigmoid.py @@ -18,6 +18,7 @@ This is used for `torch.nn.functional.linear + sigmoid` When used for `linear`, need to set A->Data, B->Weight, C->Bias """ + import jinja2 from aitemplate.backend import registry diff --git a/python/aitemplate/backend/rocm/gemm/gemm_rcr_bias_sigmoid_mul.py b/python/aitemplate/backend/rocm/gemm/gemm_rcr_bias_sigmoid_mul.py index 49d70d02f..fdc9d72bf 100644 --- a/python/aitemplate/backend/rocm/gemm/gemm_rcr_bias_sigmoid_mul.py +++ b/python/aitemplate/backend/rocm/gemm/gemm_rcr_bias_sigmoid_mul.py @@ -18,6 +18,7 @@ where A[RowMajor][M, K], B[ColMajor][N, K], C[RowMajor][M, N] bias[RowMajor][N], D0[RowMajor][M, N]. """ + import jinja2 from aitemplate.backend import registry diff --git a/python/aitemplate/backend/rocm/gemm/gemm_rcr_bias_sigmoid_mul_tanh.py b/python/aitemplate/backend/rocm/gemm/gemm_rcr_bias_sigmoid_mul_tanh.py index 802bf22b2..e9ca8fa19 100644 --- a/python/aitemplate/backend/rocm/gemm/gemm_rcr_bias_sigmoid_mul_tanh.py +++ b/python/aitemplate/backend/rocm/gemm/gemm_rcr_bias_sigmoid_mul_tanh.py @@ -18,6 +18,7 @@ where A[RowMajor][M, K], B[ColMajor][N, K], C[RowMajor][M, N] bias[RowMajor][N], D0[RowMajor][M, N]. """ + import jinja2 from aitemplate.backend import registry diff --git a/python/aitemplate/backend/rocm/gemm/gemm_rcr_bias_tanh.py b/python/aitemplate/backend/rocm/gemm/gemm_rcr_bias_tanh.py index 804b47505..14980ec25 100644 --- a/python/aitemplate/backend/rocm/gemm/gemm_rcr_bias_tanh.py +++ b/python/aitemplate/backend/rocm/gemm/gemm_rcr_bias_tanh.py @@ -18,6 +18,7 @@ This is used for `torch.nn.functional.linear + tanh` When used for `linear`, need to set A->Data, B->Weight, C->Bias """ + import jinja2 from aitemplate.backend import registry diff --git a/python/aitemplate/backend/rocm/gemm/gemm_rcr_permute_m2n3.py b/python/aitemplate/backend/rocm/gemm/gemm_rcr_permute_m2n3.py index 6661fd1c3..04d6fec38 100644 --- a/python/aitemplate/backend/rocm/gemm/gemm_rcr_permute_m2n3.py +++ b/python/aitemplate/backend/rocm/gemm/gemm_rcr_permute_m2n3.py @@ -19,6 +19,7 @@ c = c.reshape(M0, M1, N0, N1, N2) output = torch.permute(c, [2, 0, 3, 1, 4]) """ + import jinja2 from aitemplate.backend import registry diff --git a/python/aitemplate/backend/rocm/gemm/gemm_rrr.py b/python/aitemplate/backend/rocm/gemm/gemm_rrr.py index c67848258..1511024f9 100644 --- a/python/aitemplate/backend/rocm/gemm/gemm_rrr.py +++ b/python/aitemplate/backend/rocm/gemm/gemm_rrr.py @@ -18,6 +18,7 @@ This is used for `torch.mm` When used for `mm`, need to set A->Data, B->Weight """ + from aitemplate.backend import registry from aitemplate.backend.rocm.gemm import common from aitemplate.backend.rocm.gemm.layout import RRR diff --git a/python/aitemplate/backend/rocm/gemm/gemm_rrr_bias_permute.py b/python/aitemplate/backend/rocm/gemm/gemm_rrr_bias_permute.py index ab34001c8..5e9530d5a 100644 --- a/python/aitemplate/backend/rocm/gemm/gemm_rrr_bias_permute.py +++ b/python/aitemplate/backend/rocm/gemm/gemm_rrr_bias_permute.py @@ -18,6 +18,7 @@ This is used for `torch.nn.functional.linear` When used for `linear`, need to set A->Data, B->Weight, C->Bias """ + from aitemplate.backend import registry from aitemplate.backend.rocm.gemm import common, permute_common from aitemplate.backend.rocm.gemm.layout import RRR diff --git a/python/aitemplate/backend/rocm/lib_template.py b/python/aitemplate/backend/rocm/lib_template.py index 6b97a0d21..8709fc3cf 100644 --- a/python/aitemplate/backend/rocm/lib_template.py +++ b/python/aitemplate/backend/rocm/lib_template.py @@ -15,6 +15,7 @@ """ Common codegen functions for ROCM. """ + import jinja2 from aitemplate.backend import registry diff --git a/python/aitemplate/backend/rocm/normalization/__init__.py b/python/aitemplate/backend/rocm/normalization/__init__.py index 4585e7cee..e479a2fd2 100644 --- a/python/aitemplate/backend/rocm/normalization/__init__.py +++ b/python/aitemplate/backend/rocm/normalization/__init__.py @@ -15,4 +15,5 @@ """ Common modules for backends """ + from aitemplate.backend.rocm.normalization import norm_common, softmax # noqa diff --git a/python/aitemplate/backend/rocm/normalization/groupnorm.py b/python/aitemplate/backend/rocm/normalization/groupnorm.py index a059fac29..f85a05c0d 100644 --- a/python/aitemplate/backend/rocm/normalization/groupnorm.py +++ b/python/aitemplate/backend/rocm/normalization/groupnorm.py @@ -15,6 +15,7 @@ """ Groupnorm codegen for ROCM. """ + from collections import OrderedDict from hashlib import sha1 from typing import Any, Dict diff --git a/python/aitemplate/backend/rocm/normalization/layernorm.py b/python/aitemplate/backend/rocm/normalization/layernorm.py index af3efcf24..a6b19d2e1 100644 --- a/python/aitemplate/backend/rocm/normalization/layernorm.py +++ b/python/aitemplate/backend/rocm/normalization/layernorm.py @@ -15,6 +15,7 @@ """ Layernorm codegen for ROCM. """ + from collections import OrderedDict from hashlib import sha1 from typing import Any, Dict diff --git a/python/aitemplate/backend/rocm/padding/__init__.py b/python/aitemplate/backend/rocm/padding/__init__.py index 455e327d6..c160833fb 100644 --- a/python/aitemplate/backend/rocm/padding/__init__.py +++ b/python/aitemplate/backend/rocm/padding/__init__.py @@ -15,6 +15,7 @@ """ CUDA padding init """ + from . import nhwc3to4, nhwc3to8, pad_last_dim __all__ = ["nhwc3to8", "pad_last_dim", "nhwc3to4"] diff --git a/python/aitemplate/backend/rocm/padding/nhwc3to4.py b/python/aitemplate/backend/rocm/padding/nhwc3to4.py index f652d8b75..bef2c8d09 100644 --- a/python/aitemplate/backend/rocm/padding/nhwc3to4.py +++ b/python/aitemplate/backend/rocm/padding/nhwc3to4.py @@ -15,6 +15,7 @@ """ CUDA codegen for nhwc3to4 op """ + import jinja2 from ... import registry diff --git a/python/aitemplate/backend/rocm/padding/nhwc3to8.py b/python/aitemplate/backend/rocm/padding/nhwc3to8.py index 01e508a2c..684500888 100644 --- a/python/aitemplate/backend/rocm/padding/nhwc3to8.py +++ b/python/aitemplate/backend/rocm/padding/nhwc3to8.py @@ -15,6 +15,7 @@ """ CUDA codegen for nhwc3to8 op """ + import jinja2 from ... import registry diff --git a/python/aitemplate/backend/rocm/padding/pad_last_dim.py b/python/aitemplate/backend/rocm/padding/pad_last_dim.py index 5c5936e77..aa9cda80d 100644 --- a/python/aitemplate/backend/rocm/padding/pad_last_dim.py +++ b/python/aitemplate/backend/rocm/padding/pad_last_dim.py @@ -15,6 +15,7 @@ """ Codegen functions for pad_last_dim. """ + import jinja2 from ... import registry diff --git a/python/aitemplate/backend/rocm/pool2d/__init__.py b/python/aitemplate/backend/rocm/pool2d/__init__.py index 072cfd047..d26bdfa71 100644 --- a/python/aitemplate/backend/rocm/pool2d/__init__.py +++ b/python/aitemplate/backend/rocm/pool2d/__init__.py @@ -15,6 +15,7 @@ """ ROCM pool2d init """ + from aitemplate.backend.rocm.pool2d import avg_pool2d, max_pool2d __all__ = ["avg_pool2d", "max_pool2d"] diff --git a/python/aitemplate/backend/rocm/pool2d/avg_pool2d.py b/python/aitemplate/backend/rocm/pool2d/avg_pool2d.py index cf1fffef8..97ba5fe92 100644 --- a/python/aitemplate/backend/rocm/pool2d/avg_pool2d.py +++ b/python/aitemplate/backend/rocm/pool2d/avg_pool2d.py @@ -15,6 +15,7 @@ """ ROCM avg_pool2d funcs """ + from aitemplate.backend import registry from aitemplate.backend.rocm.pool2d import pool2d diff --git a/python/aitemplate/backend/rocm/pool2d/max_pool2d.py b/python/aitemplate/backend/rocm/pool2d/max_pool2d.py index 9f67236f7..251c1103e 100644 --- a/python/aitemplate/backend/rocm/pool2d/max_pool2d.py +++ b/python/aitemplate/backend/rocm/pool2d/max_pool2d.py @@ -15,6 +15,7 @@ """ ROCM max_pool2d funcs """ + from aitemplate.backend import registry from aitemplate.backend.rocm.pool2d import pool2d diff --git a/python/aitemplate/backend/rocm/pool2d/pool2d.py b/python/aitemplate/backend/rocm/pool2d/pool2d.py index 3885ecc84..c61b7a06f 100644 --- a/python/aitemplate/backend/rocm/pool2d/pool2d.py +++ b/python/aitemplate/backend/rocm/pool2d/pool2d.py @@ -15,6 +15,7 @@ """ ROCM codegen functions for pool2d. """ + from hashlib import sha1 import jinja2 diff --git a/python/aitemplate/backend/rocm/tensor/__init__.py b/python/aitemplate/backend/rocm/tensor/__init__.py index 62e7f3b00..7a3bb68c4 100644 --- a/python/aitemplate/backend/rocm/tensor/__init__.py +++ b/python/aitemplate/backend/rocm/tensor/__init__.py @@ -15,6 +15,7 @@ """ ROCM tensor ops module init """ + from aitemplate.backend.rocm.tensor import ( # noqa argmax, batch_gather, diff --git a/python/aitemplate/backend/rocm/tensor/concatenate_tanh.py b/python/aitemplate/backend/rocm/tensor/concatenate_tanh.py index 4806ca919..4d432c48d 100644 --- a/python/aitemplate/backend/rocm/tensor/concatenate_tanh.py +++ b/python/aitemplate/backend/rocm/tensor/concatenate_tanh.py @@ -15,6 +15,7 @@ """ Concatenate tanh op for ROCM backend. """ + import jinja2 from aitemplate.backend import registry diff --git a/python/aitemplate/backend/rocm/upsample/__init__.py b/python/aitemplate/backend/rocm/upsample/__init__.py index 3d822c1b0..6843fb0d9 100644 --- a/python/aitemplate/backend/rocm/upsample/__init__.py +++ b/python/aitemplate/backend/rocm/upsample/__init__.py @@ -15,6 +15,7 @@ """ ROCM upsampling module init """ + from aitemplate.backend.rocm.upsample import upsampling2d, upsampling2d_add __all__ = ["upsampling2d", "upsampling2d_add"] diff --git a/python/aitemplate/backend/rocm/utils.py b/python/aitemplate/backend/rocm/utils.py index 39cc9e0b4..5ee114f87 100644 --- a/python/aitemplate/backend/rocm/utils.py +++ b/python/aitemplate/backend/rocm/utils.py @@ -15,6 +15,7 @@ """ Util functions for ROCM. """ + import os import pathlib import re diff --git a/python/aitemplate/backend/rocm/view_ops/__init__.py b/python/aitemplate/backend/rocm/view_ops/__init__.py index 505398dde..564b49431 100644 --- a/python/aitemplate/backend/rocm/view_ops/__init__.py +++ b/python/aitemplate/backend/rocm/view_ops/__init__.py @@ -15,6 +15,7 @@ """ ROCM view_ops module init """ + from aitemplate.backend.rocm.view_ops import view_ops __all__ = ["view_ops"] diff --git a/python/aitemplate/backend/rocm/view_ops/view_ops.py b/python/aitemplate/backend/rocm/view_ops/view_ops.py index f41668fea..fd7d497f6 100644 --- a/python/aitemplate/backend/rocm/view_ops/view_ops.py +++ b/python/aitemplate/backend/rocm/view_ops/view_ops.py @@ -15,6 +15,7 @@ """ ROCM codegen functions for view ops. """ + import jinja2 from aitemplate.backend import registry diff --git a/python/aitemplate/backend/rocm/vision_ops/__init__.py b/python/aitemplate/backend/rocm/vision_ops/__init__.py index f46596197..669bb4c19 100644 --- a/python/aitemplate/backend/rocm/vision_ops/__init__.py +++ b/python/aitemplate/backend/rocm/vision_ops/__init__.py @@ -15,6 +15,7 @@ """ (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. """ + from aitemplate.backend.rocm.vision_ops import efficient_nms, nms # noqa from aitemplate.backend.rocm.vision_ops.roi_ops import ( # noqa # noqa multi_level_roi_align, diff --git a/python/aitemplate/backend/rocm/vision_ops/roi_ops/__init__.py b/python/aitemplate/backend/rocm/vision_ops/roi_ops/__init__.py index 8e7fc3709..017499599 100644 --- a/python/aitemplate/backend/rocm/vision_ops/roi_ops/__init__.py +++ b/python/aitemplate/backend/rocm/vision_ops/roi_ops/__init__.py @@ -15,6 +15,7 @@ """ ROCM roi_align module init """ + from aitemplate.backend.rocm.vision_ops.roi_ops import multi_level_roi_align, roi_align __all__ = ["roi_align", "multi_level_roi_align"] diff --git a/python/aitemplate/backend/target.py b/python/aitemplate/backend/target.py index a464fddbd..fc2bea615 100644 --- a/python/aitemplate/backend/target.py +++ b/python/aitemplate/backend/target.py @@ -15,6 +15,7 @@ """ Target object for AITemplate. """ + import logging import os import pathlib diff --git a/python/aitemplate/compiler/base.py b/python/aitemplate/compiler/base.py index 6c3907ed6..42bf59697 100644 --- a/python/aitemplate/compiler/base.py +++ b/python/aitemplate/compiler/base.py @@ -15,6 +15,7 @@ """ Basic data types of AITemplate. """ + from __future__ import annotations import copy diff --git a/python/aitemplate/compiler/compiler.py b/python/aitemplate/compiler/compiler.py index ab2f5f45e..72741e633 100644 --- a/python/aitemplate/compiler/compiler.py +++ b/python/aitemplate/compiler/compiler.py @@ -15,6 +15,7 @@ """ build a test module from a tensor """ + import logging import os from datetime import datetime diff --git a/python/aitemplate/compiler/dtype.py b/python/aitemplate/compiler/dtype.py index 029ea4197..0f3b63776 100644 --- a/python/aitemplate/compiler/dtype.py +++ b/python/aitemplate/compiler/dtype.py @@ -16,7 +16,6 @@ dtype definitions and utility functions of AITemplate """ - _DTYPE2BYTE = { "bool": 1, "float16": 2, diff --git a/python/aitemplate/compiler/model.py b/python/aitemplate/compiler/model.py index 9617dc1dd..ab0fb4d4d 100644 --- a/python/aitemplate/compiler/model.py +++ b/python/aitemplate/compiler/model.py @@ -15,6 +15,7 @@ """ Python bindings to the AIT runtime. """ + import ctypes import enum import logging @@ -161,8 +162,8 @@ def _reshape_tensor(tensor: TorchTensor, shape: List[int]) -> TorchTensor: Reinterpret a blob of contiguous memory as some shape. Used to convert outputs in RunWithTensors. """ - assert tensor.ndim == len( - shape + assert ( + tensor.ndim == len(shape) ), f"Expected output tensor's ndim to match the length of Run()'s return value: {tensor.ndim=} != {len(shape)=}" numel = math.prod(shape) new_tensor = tensor.flatten()[:numel] diff --git a/python/aitemplate/compiler/op_registry.py b/python/aitemplate/compiler/op_registry.py index 7b870c869..e4181038a 100644 --- a/python/aitemplate/compiler/op_registry.py +++ b/python/aitemplate/compiler/op_registry.py @@ -16,6 +16,7 @@ """ Registry for basic operators and math functions. """ + from typing import Callable, Dict # OP_REGISTRY defines a mapping from a FuncEnum name to a function to create this elementwise operator. diff --git a/python/aitemplate/compiler/ops/__init__.py b/python/aitemplate/compiler/ops/__init__.py index 8752001a4..169fcfe20 100644 --- a/python/aitemplate/compiler/ops/__init__.py +++ b/python/aitemplate/compiler/ops/__init__.py @@ -16,6 +16,7 @@ """ AIT operators. """ + from aitemplate.compiler.ops.common import * from aitemplate.compiler.ops.conv import * from aitemplate.compiler.ops.embedding import * diff --git a/python/aitemplate/compiler/ops/attention/__init__.py b/python/aitemplate/compiler/ops/attention/__init__.py index ff60a7246..24c8aea5d 100644 --- a/python/aitemplate/compiler/ops/attention/__init__.py +++ b/python/aitemplate/compiler/ops/attention/__init__.py @@ -15,6 +15,7 @@ """ flash attention module init """ + from aitemplate.compiler.ops.attention.flash_attention import flash_attention from aitemplate.compiler.ops.attention.mem_eff_attention import mem_eff_attention diff --git a/python/aitemplate/compiler/ops/attention/flash_attention.py b/python/aitemplate/compiler/ops/attention/flash_attention.py index 3b0658867..d6108870c 100644 --- a/python/aitemplate/compiler/ops/attention/flash_attention.py +++ b/python/aitemplate/compiler/ops/attention/flash_attention.py @@ -15,6 +15,7 @@ """ Flash attention. """ + import itertools from collections import OrderedDict from typing import List diff --git a/python/aitemplate/compiler/ops/attention/mem_eff_attention.py b/python/aitemplate/compiler/ops/attention/mem_eff_attention.py index 55d2e9e28..b67807a9d 100644 --- a/python/aitemplate/compiler/ops/attention/mem_eff_attention.py +++ b/python/aitemplate/compiler/ops/attention/mem_eff_attention.py @@ -15,6 +15,7 @@ """ Flash attention. """ + import itertools import logging from collections import OrderedDict diff --git a/python/aitemplate/compiler/ops/common/__init__.py b/python/aitemplate/compiler/ops/common/__init__.py index 4e00e86d3..69316539b 100644 --- a/python/aitemplate/compiler/ops/common/__init__.py +++ b/python/aitemplate/compiler/ops/common/__init__.py @@ -16,6 +16,7 @@ """ Common ops. """ + from aitemplate.compiler.ops.common.elementwise import * from aitemplate.compiler.ops.common.int_elementwise import * from aitemplate.compiler.ops.common.epilogue import * diff --git a/python/aitemplate/compiler/ops/common/elementwise.py b/python/aitemplate/compiler/ops/common/elementwise.py index 1881074bf..b9750651a 100644 --- a/python/aitemplate/compiler/ops/common/elementwise.py +++ b/python/aitemplate/compiler/ops/common/elementwise.py @@ -15,6 +15,7 @@ """ Elementwise operator definition, which covers UNARY / Binary / Ternary operators. """ + import functools from typing import Any, List diff --git a/python/aitemplate/compiler/ops/common/fused_elementwise.py b/python/aitemplate/compiler/ops/common/fused_elementwise.py index 3044c63ca..8770e46f4 100644 --- a/python/aitemplate/compiler/ops/common/fused_elementwise.py +++ b/python/aitemplate/compiler/ops/common/fused_elementwise.py @@ -15,6 +15,7 @@ """ Fused elementwise operator definition. """ + from typing import Iterable, List from aitemplate import backend diff --git a/python/aitemplate/compiler/ops/common/int_elementwise.py b/python/aitemplate/compiler/ops/common/int_elementwise.py index c44db8b9c..034bd682a 100644 --- a/python/aitemplate/compiler/ops/common/int_elementwise.py +++ b/python/aitemplate/compiler/ops/common/int_elementwise.py @@ -15,6 +15,7 @@ """ Int elementwise operator definition, for integer calcuation on tensor dimensions. """ + import functools from functools import reduce diff --git a/python/aitemplate/compiler/ops/common/python_ops.py b/python/aitemplate/compiler/ops/common/python_ops.py index 98d8ba13c..35a5f39f5 100644 --- a/python/aitemplate/compiler/ops/common/python_ops.py +++ b/python/aitemplate/compiler/ops/common/python_ops.py @@ -15,6 +15,7 @@ """ Syntax sugar ops to support List/Tuples in the IR. These ops don't generate any code. """ + from typing import Any, List, Tuple, Union from aitemplate.compiler.base import IntImm, IntVar, Operator, Tensor diff --git a/python/aitemplate/compiler/ops/common/view_ops.py b/python/aitemplate/compiler/ops/common/view_ops.py index c5c174629..3bfb836bd 100644 --- a/python/aitemplate/compiler/ops/common/view_ops.py +++ b/python/aitemplate/compiler/ops/common/view_ops.py @@ -332,9 +332,7 @@ def _infer_shapes(self, x: Tensor): else: symbol_names = {s.name for s in dynamic_symbol.free_symbols} unknown_symbols = symbol_names - get_global_symbol_set() - assert ( - not unknown_symbols - ), f"Unable to deduce dynamic symbol, because the following symbols are not in global symbol set: {unknown_symbols}" + assert not unknown_symbols, f"Unable to deduce dynamic symbol, because the following symbols are not in global symbol set: {unknown_symbols}" values = simplify_intvar_values(dynamic_symbol) new_var = IntVar(values, symbolic_value=dynamic_symbol) diff --git a/python/aitemplate/compiler/ops/conv/__init__.py b/python/aitemplate/compiler/ops/conv/__init__.py index 2a49f29c9..d17a760c5 100644 --- a/python/aitemplate/compiler/ops/conv/__init__.py +++ b/python/aitemplate/compiler/ops/conv/__init__.py @@ -16,6 +16,7 @@ """ Conv2d family operators. """ + from aitemplate.compiler.ops.conv.conv2d import conv2d from aitemplate.compiler.ops.conv.conv2d_bias import conv2d_bias from aitemplate.compiler.ops.conv.conv2d_bias_add import conv2d_bias_add diff --git a/python/aitemplate/compiler/ops/conv/cache_entry.py b/python/aitemplate/compiler/ops/conv/cache_entry.py index efe4b58e0..6d60a92a6 100644 --- a/python/aitemplate/compiler/ops/conv/cache_entry.py +++ b/python/aitemplate/compiler/ops/conv/cache_entry.py @@ -15,6 +15,7 @@ """ Cache entry for conv2d. """ + from dataclasses import dataclass # pylint: disable=C0103 diff --git a/python/aitemplate/compiler/ops/conv/common_conv2d_bias_activation.py b/python/aitemplate/compiler/ops/conv/common_conv2d_bias_activation.py index ce2024559..61cd8aa73 100644 --- a/python/aitemplate/compiler/ops/conv/common_conv2d_bias_activation.py +++ b/python/aitemplate/compiler/ops/conv/common_conv2d_bias_activation.py @@ -15,6 +15,7 @@ """ Fused conv2d_bias_activation op. """ + from typing import Tuple from aitemplate.compiler.base import Tensor diff --git a/python/aitemplate/compiler/ops/conv/conv2d.py b/python/aitemplate/compiler/ops/conv/conv2d.py index ab9c6a0c9..a7051b620 100644 --- a/python/aitemplate/compiler/ops/conv/conv2d.py +++ b/python/aitemplate/compiler/ops/conv/conv2d.py @@ -15,6 +15,7 @@ """ Base class for conv2d. """ + import itertools import logging import os diff --git a/python/aitemplate/compiler/ops/conv/conv2d_bias.py b/python/aitemplate/compiler/ops/conv/conv2d_bias.py index 416066f29..162dbd5d7 100644 --- a/python/aitemplate/compiler/ops/conv/conv2d_bias.py +++ b/python/aitemplate/compiler/ops/conv/conv2d_bias.py @@ -15,6 +15,7 @@ """ Conv2d with bias. """ + from aitemplate.compiler.ops.conv.common_conv2d_bias_activation import ( conv2d_bias_activation, ) diff --git a/python/aitemplate/compiler/ops/conv/conv2d_bias_add.py b/python/aitemplate/compiler/ops/conv/conv2d_bias_add.py index 9a1dffafc..c65d63ba0 100644 --- a/python/aitemplate/compiler/ops/conv/conv2d_bias_add.py +++ b/python/aitemplate/compiler/ops/conv/conv2d_bias_add.py @@ -15,6 +15,7 @@ """ fused conv2d_bias_add op """ + from aitemplate.compiler.ops.conv.common_conv2d_bias_add_activation import ( conv2d_bias_add_activation, ) diff --git a/python/aitemplate/compiler/ops/conv/conv2d_bias_add_hardswish.py b/python/aitemplate/compiler/ops/conv/conv2d_bias_add_hardswish.py index 36a59445c..e318b77cc 100644 --- a/python/aitemplate/compiler/ops/conv/conv2d_bias_add_hardswish.py +++ b/python/aitemplate/compiler/ops/conv/conv2d_bias_add_hardswish.py @@ -15,6 +15,7 @@ """ fused conv2d_bias_add_hardswish op, for residual block """ + from aitemplate.compiler.ops.conv.common_conv2d_bias_add_activation import ( conv2d_bias_add_activation, ) diff --git a/python/aitemplate/compiler/ops/conv/conv2d_bias_add_relu.py b/python/aitemplate/compiler/ops/conv/conv2d_bias_add_relu.py index 150e10554..973ca75c2 100644 --- a/python/aitemplate/compiler/ops/conv/conv2d_bias_add_relu.py +++ b/python/aitemplate/compiler/ops/conv/conv2d_bias_add_relu.py @@ -15,6 +15,7 @@ """ fused conv2d_bias_relu_add op, for residual block """ + from aitemplate.compiler.ops.conv.common_conv2d_bias_add_activation import ( conv2d_bias_add_activation, ) diff --git a/python/aitemplate/compiler/ops/conv/conv2d_bias_few_channels.py b/python/aitemplate/compiler/ops/conv/conv2d_bias_few_channels.py index fb34f4625..8c7739225 100644 --- a/python/aitemplate/compiler/ops/conv/conv2d_bias_few_channels.py +++ b/python/aitemplate/compiler/ops/conv/conv2d_bias_few_channels.py @@ -15,6 +15,7 @@ """ Fused conv2d_bias_few_channels op. """ + from aitemplate.compiler.ops.conv.special_conv2d_bias_activation import ( special_conv2d_bias_activation, ) diff --git a/python/aitemplate/compiler/ops/conv/conv2d_bias_hardswish.py b/python/aitemplate/compiler/ops/conv/conv2d_bias_hardswish.py index e6039ade5..df4206fb3 100644 --- a/python/aitemplate/compiler/ops/conv/conv2d_bias_hardswish.py +++ b/python/aitemplate/compiler/ops/conv/conv2d_bias_hardswish.py @@ -15,6 +15,7 @@ """ Fused conv2d_bias_hardswish op. """ + from aitemplate.compiler.ops.conv.common_conv2d_bias_activation import ( conv2d_bias_activation, ) diff --git a/python/aitemplate/compiler/ops/conv/conv2d_bias_hardswish_few_channels.py b/python/aitemplate/compiler/ops/conv/conv2d_bias_hardswish_few_channels.py index ac79c62ac..66308fbbf 100644 --- a/python/aitemplate/compiler/ops/conv/conv2d_bias_hardswish_few_channels.py +++ b/python/aitemplate/compiler/ops/conv/conv2d_bias_hardswish_few_channels.py @@ -15,6 +15,7 @@ """ Fused conv2d_bias_hardswish_few_channels op. """ + from aitemplate.compiler.ops.conv.special_conv2d_bias_activation import ( special_conv2d_bias_activation, ) diff --git a/python/aitemplate/compiler/ops/conv/conv2d_bias_relu.py b/python/aitemplate/compiler/ops/conv/conv2d_bias_relu.py index ab9fdcb94..b8ab4b787 100644 --- a/python/aitemplate/compiler/ops/conv/conv2d_bias_relu.py +++ b/python/aitemplate/compiler/ops/conv/conv2d_bias_relu.py @@ -15,6 +15,7 @@ """ Fused conv2d_bias_relu op. """ + from aitemplate.compiler.ops.conv.common_conv2d_bias_activation import ( conv2d_bias_activation, ) diff --git a/python/aitemplate/compiler/ops/conv/conv2d_bias_relu_few_channels.py b/python/aitemplate/compiler/ops/conv/conv2d_bias_relu_few_channels.py index d915b80fe..ec2cb321a 100644 --- a/python/aitemplate/compiler/ops/conv/conv2d_bias_relu_few_channels.py +++ b/python/aitemplate/compiler/ops/conv/conv2d_bias_relu_few_channels.py @@ -15,6 +15,7 @@ """ Fused conv2d_bias_relu_few_channels op. """ + from aitemplate.compiler.ops.conv.special_conv2d_bias_activation import ( special_conv2d_bias_activation, ) diff --git a/python/aitemplate/compiler/ops/conv/conv2d_bias_sigmoid.py b/python/aitemplate/compiler/ops/conv/conv2d_bias_sigmoid.py index 55e009d91..04b8ffb05 100644 --- a/python/aitemplate/compiler/ops/conv/conv2d_bias_sigmoid.py +++ b/python/aitemplate/compiler/ops/conv/conv2d_bias_sigmoid.py @@ -15,6 +15,7 @@ """ Fused conv2d_bias_sigmoid op. """ + from aitemplate.compiler.ops.conv.common_conv2d_bias_activation import ( conv2d_bias_activation, ) diff --git a/python/aitemplate/compiler/ops/conv/conv2d_depthwise.py b/python/aitemplate/compiler/ops/conv/conv2d_depthwise.py index ca2117e05..04ca0d8ea 100644 --- a/python/aitemplate/compiler/ops/conv/conv2d_depthwise.py +++ b/python/aitemplate/compiler/ops/conv/conv2d_depthwise.py @@ -15,6 +15,7 @@ """ Fused conv2d_depthwise op. """ + from typing import List, Tuple from aitemplate.compiler.base import Tensor diff --git a/python/aitemplate/compiler/ops/conv/conv2d_depthwise_bias.py b/python/aitemplate/compiler/ops/conv/conv2d_depthwise_bias.py index 505f0b976..1d8c87d62 100644 --- a/python/aitemplate/compiler/ops/conv/conv2d_depthwise_bias.py +++ b/python/aitemplate/compiler/ops/conv/conv2d_depthwise_bias.py @@ -15,6 +15,7 @@ """ Fused conv2d_depthwise op. """ + from typing import List, Tuple from aitemplate.compiler.base import Tensor diff --git a/python/aitemplate/compiler/ops/conv/conv3d.py b/python/aitemplate/compiler/ops/conv/conv3d.py index a998d7f77..67e19a13b 100644 --- a/python/aitemplate/compiler/ops/conv/conv3d.py +++ b/python/aitemplate/compiler/ops/conv/conv3d.py @@ -16,6 +16,7 @@ """ Base class for conv3d. """ + import itertools import logging import os diff --git a/python/aitemplate/compiler/ops/conv/conv3d_bias.py b/python/aitemplate/compiler/ops/conv/conv3d_bias.py index 7c5a41362..d4d98d91b 100644 --- a/python/aitemplate/compiler/ops/conv/conv3d_bias.py +++ b/python/aitemplate/compiler/ops/conv/conv3d_bias.py @@ -16,6 +16,7 @@ """ Conv3d with bias. """ + from typing import List from aitemplate.compiler.base import Tensor diff --git a/python/aitemplate/compiler/ops/conv/depthwise_conv3d.py b/python/aitemplate/compiler/ops/conv/depthwise_conv3d.py index fe9b5a3b3..e31008194 100644 --- a/python/aitemplate/compiler/ops/conv/depthwise_conv3d.py +++ b/python/aitemplate/compiler/ops/conv/depthwise_conv3d.py @@ -15,6 +15,7 @@ """ Base class for depthwise_conv3d. """ + import itertools import re from collections import OrderedDict diff --git a/python/aitemplate/compiler/ops/conv/special_conv2d_bias_activation.py b/python/aitemplate/compiler/ops/conv/special_conv2d_bias_activation.py index f0b402820..a57028aa0 100644 --- a/python/aitemplate/compiler/ops/conv/special_conv2d_bias_activation.py +++ b/python/aitemplate/compiler/ops/conv/special_conv2d_bias_activation.py @@ -15,6 +15,7 @@ """ Fused special_conv2d_bias_activation op. """ + from aitemplate.compiler.base import Tensor from aitemplate.compiler.ops.conv.conv2d import conv2d from aitemplate.compiler.ops.padding import nhwc3to4, nhwc3to8 diff --git a/python/aitemplate/compiler/ops/conv/transposed_conv2d_bias_relu.py b/python/aitemplate/compiler/ops/conv/transposed_conv2d_bias_relu.py index 81ea0f61e..504f7e611 100644 --- a/python/aitemplate/compiler/ops/conv/transposed_conv2d_bias_relu.py +++ b/python/aitemplate/compiler/ops/conv/transposed_conv2d_bias_relu.py @@ -15,6 +15,7 @@ """ Fused transposed_conv2d_bias_relu op. """ + from aitemplate.compiler.ops.conv.transposed_conv2d_bias import transposed_conv2d_bias diff --git a/python/aitemplate/compiler/ops/embedding/bert_embeddings.py b/python/aitemplate/compiler/ops/embedding/bert_embeddings.py index 8d8f7f42c..38ee6c23d 100644 --- a/python/aitemplate/compiler/ops/embedding/bert_embeddings.py +++ b/python/aitemplate/compiler/ops/embedding/bert_embeddings.py @@ -15,6 +15,7 @@ """ Operator definition for bert_embeddings. """ + from aitemplate import backend from aitemplate.backend import registry from aitemplate.compiler.base import IntImm, Operator, Tensor @@ -79,10 +80,13 @@ def __call__( "int64", ], f"Expected dtype int/int32/int64 for index, got dtype {dtype_input_ids}" - assert dtype_word_embeddings in [ - "float16", - "float32", - ], f"Expected dtype float16/float32 for embeddings, got dtype {dtype_word_embeddings}" + assert ( + dtype_word_embeddings + in [ + "float16", + "float32", + ] + ), f"Expected dtype float16/float32 for embeddings, got dtype {dtype_word_embeddings}" # expecting all three ids to have the same shapes assert shape_utils.is_same_shape(input_ids.shape(), token_type_ids.shape()), ( diff --git a/python/aitemplate/compiler/ops/gemm_epilogue_vistor/dual_bmm_rrr_div.py b/python/aitemplate/compiler/ops/gemm_epilogue_vistor/dual_bmm_rrr_div.py index 4be9a4585..0520fb9f5 100644 --- a/python/aitemplate/compiler/ops/gemm_epilogue_vistor/dual_bmm_rrr_div.py +++ b/python/aitemplate/compiler/ops/gemm_epilogue_vistor/dual_bmm_rrr_div.py @@ -15,6 +15,7 @@ """ Batch GEMM specialization: BMM_RRR(A, B0) / BMM_RRR(A, B1) """ + from aitemplate.compiler.base import Tensor from aitemplate.compiler.ops.gemm_universal import bmm_rrr from aitemplate.compiler.tensor_accessor import TensorAccessor diff --git a/python/aitemplate/compiler/ops/gemm_epilogue_vistor/dual_gemm_rcr_fast_gelu.py b/python/aitemplate/compiler/ops/gemm_epilogue_vistor/dual_gemm_rcr_fast_gelu.py index 62f8db0eb..43279fffc 100644 --- a/python/aitemplate/compiler/ops/gemm_epilogue_vistor/dual_gemm_rcr_fast_gelu.py +++ b/python/aitemplate/compiler/ops/gemm_epilogue_vistor/dual_gemm_rcr_fast_gelu.py @@ -15,6 +15,7 @@ """ GEMM Specialization: FAST_GELU(GEMM_RCR(A, B)) * GEMM_RCR(A, B1) """ + from aitemplate.compiler.base import Tensor from aitemplate.compiler.ops.gemm_universal.gemm_rcr import gemm_rcr from aitemplate.compiler.tensor_accessor import TensorAccessor diff --git a/python/aitemplate/compiler/ops/gemm_epilogue_vistor/dual_gemm_rcr_silu.py b/python/aitemplate/compiler/ops/gemm_epilogue_vistor/dual_gemm_rcr_silu.py index e847b1acc..3da071e2c 100644 --- a/python/aitemplate/compiler/ops/gemm_epilogue_vistor/dual_gemm_rcr_silu.py +++ b/python/aitemplate/compiler/ops/gemm_epilogue_vistor/dual_gemm_rcr_silu.py @@ -15,6 +15,7 @@ """ GEMM Specialization: SILU(GEMM_RCR(A, B)) * GEMM_RCR(A, B1) """ + from aitemplate.compiler.base import Tensor from aitemplate.compiler.ops.gemm_universal.gemm_rcr import gemm_rcr from aitemplate.compiler.tensor_accessor import TensorAccessor diff --git a/python/aitemplate/compiler/ops/gemm_epilogue_vistor/gemm_rcr_bias_softmax.py b/python/aitemplate/compiler/ops/gemm_epilogue_vistor/gemm_rcr_bias_softmax.py index 76aa740dd..f30e49006 100644 --- a/python/aitemplate/compiler/ops/gemm_epilogue_vistor/gemm_rcr_bias_softmax.py +++ b/python/aitemplate/compiler/ops/gemm_epilogue_vistor/gemm_rcr_bias_softmax.py @@ -15,6 +15,7 @@ """ Operator definition for gemm_rcr_bias_softmax. """ + from aitemplate.compiler.base import Tensor from aitemplate.compiler.ops.gemm_epilogue_vistor.gemm_rcr_softmax import ( gemm_rcr_softmax, diff --git a/python/aitemplate/compiler/ops/gemm_special/__init__.py b/python/aitemplate/compiler/ops/gemm_special/__init__.py index b577f5ae0..ec4ce0584 100644 --- a/python/aitemplate/compiler/ops/gemm_special/__init__.py +++ b/python/aitemplate/compiler/ops/gemm_special/__init__.py @@ -15,6 +15,7 @@ """ special gemm ops """ + from aitemplate.compiler.ops.gemm_special.batched_dense_vec_jagged_2d_mul import ( batched_dense_vec_jagged_2d_mul, ) diff --git a/python/aitemplate/compiler/ops/gemm_special/batched_dense_vec_jagged_2d_mul.py b/python/aitemplate/compiler/ops/gemm_special/batched_dense_vec_jagged_2d_mul.py index 9ea46d170..f666c97a4 100644 --- a/python/aitemplate/compiler/ops/gemm_special/batched_dense_vec_jagged_2d_mul.py +++ b/python/aitemplate/compiler/ops/gemm_special/batched_dense_vec_jagged_2d_mul.py @@ -16,6 +16,7 @@ """ Define batched_dense_vec_jagged_2d_mul op """ + from typing import List from aitemplate.backend import registry diff --git a/python/aitemplate/compiler/ops/gemm_special/bmm_rrr_k1_tanh.py b/python/aitemplate/compiler/ops/gemm_special/bmm_rrr_k1_tanh.py index 951727081..e4f24617d 100644 --- a/python/aitemplate/compiler/ops/gemm_special/bmm_rrr_k1_tanh.py +++ b/python/aitemplate/compiler/ops/gemm_special/bmm_rrr_k1_tanh.py @@ -15,6 +15,7 @@ """ Operator definition for bmm_rrr_k1_tanh. """ + from typing import List from aitemplate.compiler.base import IntVar, Tensor diff --git a/python/aitemplate/compiler/ops/gemm_universal/bmm_softmax_bmm_permute.py b/python/aitemplate/compiler/ops/gemm_universal/bmm_softmax_bmm_permute.py index e2e04bcc0..dcf694556 100644 --- a/python/aitemplate/compiler/ops/gemm_universal/bmm_softmax_bmm_permute.py +++ b/python/aitemplate/compiler/ops/gemm_universal/bmm_softmax_bmm_permute.py @@ -15,6 +15,7 @@ """ BMM_RCR + Softmax + BMM_RRR + Permute Specialization """ + from typing import Tuple from aitemplate.compiler.base import IntImm, Tensor diff --git a/python/aitemplate/compiler/ops/gemm_universal/cache_entry.py b/python/aitemplate/compiler/ops/gemm_universal/cache_entry.py index 6ca3537c7..553cc4ff9 100644 --- a/python/aitemplate/compiler/ops/gemm_universal/cache_entry.py +++ b/python/aitemplate/compiler/ops/gemm_universal/cache_entry.py @@ -15,6 +15,7 @@ """ GEMM profiling cache entries """ + from dataclasses import dataclass diff --git a/python/aitemplate/compiler/ops/gemm_universal/gemm_common.py b/python/aitemplate/compiler/ops/gemm_universal/gemm_common.py index c51aa22ba..1c1cf33ec 100644 --- a/python/aitemplate/compiler/ops/gemm_universal/gemm_common.py +++ b/python/aitemplate/compiler/ops/gemm_universal/gemm_common.py @@ -15,6 +15,7 @@ """ Common functions/classes for GEMM ops """ + import itertools import logging import math diff --git a/python/aitemplate/compiler/ops/gemm_universal/gemm_rcr_bias.py b/python/aitemplate/compiler/ops/gemm_universal/gemm_rcr_bias.py index 9df2c9222..b355d3999 100644 --- a/python/aitemplate/compiler/ops/gemm_universal/gemm_rcr_bias.py +++ b/python/aitemplate/compiler/ops/gemm_universal/gemm_rcr_bias.py @@ -15,6 +15,7 @@ """ GEMM Specialization: GEMM_RCR(A, B) + Bias """ + from aitemplate.compiler.base import IntImm, Tensor from aitemplate.compiler.ops.gemm_universal import gemm_rcr from aitemplate.compiler.tensor_accessor import TensorAccessor diff --git a/python/aitemplate/compiler/ops/gemm_universal/gemm_rcr_bias_fast_gelu.py b/python/aitemplate/compiler/ops/gemm_universal/gemm_rcr_bias_fast_gelu.py index 743ade763..891d38301 100644 --- a/python/aitemplate/compiler/ops/gemm_universal/gemm_rcr_bias_fast_gelu.py +++ b/python/aitemplate/compiler/ops/gemm_universal/gemm_rcr_bias_fast_gelu.py @@ -15,6 +15,7 @@ """ GEMM Specialization: FastGELU(GEMM_RCR(A, B) + Bias) """ + from aitemplate.compiler.ops.gemm_universal import gemm_rcr_bias # pylint: disable=C0103,W0223,W0221 diff --git a/python/aitemplate/compiler/ops/gemm_universal/gemm_rcr_bias_gelu.py b/python/aitemplate/compiler/ops/gemm_universal/gemm_rcr_bias_gelu.py index 34157307f..28ed25b4e 100644 --- a/python/aitemplate/compiler/ops/gemm_universal/gemm_rcr_bias_gelu.py +++ b/python/aitemplate/compiler/ops/gemm_universal/gemm_rcr_bias_gelu.py @@ -15,6 +15,7 @@ """ GEMM Specialization: GELU(GEMM_RCR(A, B) + Bias) """ + from aitemplate.compiler.ops.gemm_universal import gemm_rcr_bias # pylint: disable=C0103,W0223,W0221 diff --git a/python/aitemplate/compiler/ops/gemm_universal/gemm_rcr_bias_hardswish.py b/python/aitemplate/compiler/ops/gemm_universal/gemm_rcr_bias_hardswish.py index b658c243d..26b4bb8cb 100644 --- a/python/aitemplate/compiler/ops/gemm_universal/gemm_rcr_bias_hardswish.py +++ b/python/aitemplate/compiler/ops/gemm_universal/gemm_rcr_bias_hardswish.py @@ -15,6 +15,7 @@ """ GEMM Specialization: HardSwish(GEMM_RCR(A, B) + Bias) """ + from aitemplate.compiler.ops.gemm_universal import gemm_rcr_bias # pylint: disable=C0103,W0223,W0221 diff --git a/python/aitemplate/compiler/ops/gemm_universal/gemm_rcr_bias_relu.py b/python/aitemplate/compiler/ops/gemm_universal/gemm_rcr_bias_relu.py index 6c6307d76..7f2288498 100644 --- a/python/aitemplate/compiler/ops/gemm_universal/gemm_rcr_bias_relu.py +++ b/python/aitemplate/compiler/ops/gemm_universal/gemm_rcr_bias_relu.py @@ -15,6 +15,7 @@ """ GEMM Specialization: ReLU(GEMM_RCR(A, B) + Bias) """ + from aitemplate.compiler.ops.gemm_universal import gemm_rcr_bias # pylint: disable=C0103,W0223,W0221 diff --git a/python/aitemplate/compiler/ops/gemm_universal/gemm_rcr_bias_sigmoid.py b/python/aitemplate/compiler/ops/gemm_universal/gemm_rcr_bias_sigmoid.py index f4f868328..6372e9e92 100644 --- a/python/aitemplate/compiler/ops/gemm_universal/gemm_rcr_bias_sigmoid.py +++ b/python/aitemplate/compiler/ops/gemm_universal/gemm_rcr_bias_sigmoid.py @@ -15,6 +15,7 @@ """ Sigmoid(GEMM_RCR(A, B) + Bias) """ + from aitemplate.compiler.ops.gemm_universal import gemm_rcr_bias # pylint: disable=C0103,W0223,W0221 diff --git a/python/aitemplate/compiler/ops/gemm_universal/gemm_rcr_bias_swish.py b/python/aitemplate/compiler/ops/gemm_universal/gemm_rcr_bias_swish.py index c4138269c..05bd1b348 100644 --- a/python/aitemplate/compiler/ops/gemm_universal/gemm_rcr_bias_swish.py +++ b/python/aitemplate/compiler/ops/gemm_universal/gemm_rcr_bias_swish.py @@ -15,6 +15,7 @@ """ GEMM Specialization: SiLU(GEMM_RCR(A, B) + Bias) """ + from aitemplate.compiler.ops.gemm_universal import gemm_rcr_bias # pylint: disable=C0103,W0223,W0221 diff --git a/python/aitemplate/compiler/ops/gemm_universal/gemm_rcr_bias_tanh.py b/python/aitemplate/compiler/ops/gemm_universal/gemm_rcr_bias_tanh.py index bf3d7ef4a..6ef30e016 100644 --- a/python/aitemplate/compiler/ops/gemm_universal/gemm_rcr_bias_tanh.py +++ b/python/aitemplate/compiler/ops/gemm_universal/gemm_rcr_bias_tanh.py @@ -15,6 +15,7 @@ """ GEMM Specialization: Tanh(GEMM_RCR(A, B) + Bias) """ + from aitemplate.compiler.ops.gemm_universal import gemm_rcr_bias # pylint: disable=C0103,W0223,W0221 diff --git a/python/aitemplate/compiler/ops/gemm_universal/gemm_rcr_fast_gelu.py b/python/aitemplate/compiler/ops/gemm_universal/gemm_rcr_fast_gelu.py index 264d9df5b..d4b116b55 100644 --- a/python/aitemplate/compiler/ops/gemm_universal/gemm_rcr_fast_gelu.py +++ b/python/aitemplate/compiler/ops/gemm_universal/gemm_rcr_fast_gelu.py @@ -15,6 +15,7 @@ """ GEMM Specialization: FastGELU(GEMM_RCR(A, B)) """ + from aitemplate.compiler.ops.gemm_universal import gemm_rcr # pylint: disable=C0103,W0223,W0221 diff --git a/python/aitemplate/compiler/ops/gemm_universal/gemm_rrr_bias.py b/python/aitemplate/compiler/ops/gemm_universal/gemm_rrr_bias.py index a8c052d87..4f7195d89 100644 --- a/python/aitemplate/compiler/ops/gemm_universal/gemm_rrr_bias.py +++ b/python/aitemplate/compiler/ops/gemm_universal/gemm_rrr_bias.py @@ -15,6 +15,7 @@ """ gemm rrr with bias """ + from aitemplate.compiler.base import IntImm, Tensor from aitemplate.compiler.ops.gemm_universal import gemm_rrr from aitemplate.compiler.tensor_accessor import TensorAccessor diff --git a/python/aitemplate/compiler/ops/gemm_universal/group_gemm_rcr.py b/python/aitemplate/compiler/ops/gemm_universal/group_gemm_rcr.py index 39e6bdfe1..4b09452c0 100644 --- a/python/aitemplate/compiler/ops/gemm_universal/group_gemm_rcr.py +++ b/python/aitemplate/compiler/ops/gemm_universal/group_gemm_rcr.py @@ -15,6 +15,7 @@ """ Grouped GEMM Specialization for A[RowMajor], B[ColMajor], C[RowMajor] """ + import logging import re from collections import OrderedDict diff --git a/python/aitemplate/compiler/ops/gemm_universal/group_gemm_rcr_bias.py b/python/aitemplate/compiler/ops/gemm_universal/group_gemm_rcr_bias.py index 2cc5ced97..22e12d10c 100644 --- a/python/aitemplate/compiler/ops/gemm_universal/group_gemm_rcr_bias.py +++ b/python/aitemplate/compiler/ops/gemm_universal/group_gemm_rcr_bias.py @@ -15,6 +15,7 @@ """ Grouped GEMM Specialization: GEMM_RCR(A, B) + Bias """ + from collections import OrderedDict from typing import List diff --git a/python/aitemplate/compiler/ops/gemm_universal/group_gemm_rcr_bias_relu.py b/python/aitemplate/compiler/ops/gemm_universal/group_gemm_rcr_bias_relu.py index e1b62eb81..8e1ea3123 100644 --- a/python/aitemplate/compiler/ops/gemm_universal/group_gemm_rcr_bias_relu.py +++ b/python/aitemplate/compiler/ops/gemm_universal/group_gemm_rcr_bias_relu.py @@ -12,8 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # -"""Grouped GEMM Specialization: ReLU(GEMM_RCR(A, B) + Bias) -""" +"""Grouped GEMM Specialization: ReLU(GEMM_RCR(A, B) + Bias)""" from aitemplate.compiler.ops.gemm_universal import group_gemm_rcr_bias diff --git a/python/aitemplate/compiler/ops/gemm_universal/group_gemm_rcr_bias_sigmoid.py b/python/aitemplate/compiler/ops/gemm_universal/group_gemm_rcr_bias_sigmoid.py index 5098e3285..2a480101e 100644 --- a/python/aitemplate/compiler/ops/gemm_universal/group_gemm_rcr_bias_sigmoid.py +++ b/python/aitemplate/compiler/ops/gemm_universal/group_gemm_rcr_bias_sigmoid.py @@ -12,8 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # -"""Grouped GEMM Specialization: Sigmoid(GEMM_RCR(A, B) + Bias) -""" +"""Grouped GEMM Specialization: Sigmoid(GEMM_RCR(A, B) + Bias)""" from aitemplate.compiler.ops.gemm_universal import group_gemm_rcr_bias diff --git a/python/aitemplate/compiler/ops/gemm_universal/perm021fc_ccr_bias_permute.py b/python/aitemplate/compiler/ops/gemm_universal/perm021fc_ccr_bias_permute.py index 14ea8c5ce..abe763f0c 100644 --- a/python/aitemplate/compiler/ops/gemm_universal/perm021fc_ccr_bias_permute.py +++ b/python/aitemplate/compiler/ops/gemm_universal/perm021fc_ccr_bias_permute.py @@ -15,6 +15,7 @@ """ GEMM Specialization: (A.permute(0, 2, 1)[col] @ B[col] + Bias).permute(0, 2, 1) """ + from aitemplate.compiler.base import Tensor from aitemplate.compiler.ops.common.view_ops import reshape from aitemplate.compiler.ops.gemm_universal.perm021fc_ccr_bias import perm021fc_ccr_bias diff --git a/python/aitemplate/compiler/ops/groupnorm/groupnorm.py b/python/aitemplate/compiler/ops/groupnorm/groupnorm.py index b78c67c89..5f87c7315 100644 --- a/python/aitemplate/compiler/ops/groupnorm/groupnorm.py +++ b/python/aitemplate/compiler/ops/groupnorm/groupnorm.py @@ -15,6 +15,7 @@ """ Operator definition for groupnorm. """ + import itertools import logging import os diff --git a/python/aitemplate/compiler/ops/jagged/jagged_lengths_to_offsets.py b/python/aitemplate/compiler/ops/jagged/jagged_lengths_to_offsets.py index bde349977..12c5b0331 100644 --- a/python/aitemplate/compiler/ops/jagged/jagged_lengths_to_offsets.py +++ b/python/aitemplate/compiler/ops/jagged/jagged_lengths_to_offsets.py @@ -15,6 +15,7 @@ """ Define jagged_lengths_to_offsets op """ + from typing import List from aitemplate.backend import registry diff --git a/python/aitemplate/compiler/ops/jagged/jagged_lengths_to_presences.py b/python/aitemplate/compiler/ops/jagged/jagged_lengths_to_presences.py index 7a372d8d2..f61f0bfb3 100644 --- a/python/aitemplate/compiler/ops/jagged/jagged_lengths_to_presences.py +++ b/python/aitemplate/compiler/ops/jagged/jagged_lengths_to_presences.py @@ -15,6 +15,7 @@ """ Define jagged_lengths_to_presences op """ + from typing import List from aitemplate.backend import registry diff --git a/python/aitemplate/compiler/ops/layernorm/batch_layernorm_sigmoid_mul.py b/python/aitemplate/compiler/ops/layernorm/batch_layernorm_sigmoid_mul.py index 647f25485..ff4abf4d2 100644 --- a/python/aitemplate/compiler/ops/layernorm/batch_layernorm_sigmoid_mul.py +++ b/python/aitemplate/compiler/ops/layernorm/batch_layernorm_sigmoid_mul.py @@ -17,6 +17,7 @@ gamma: [b, n] beta: [b, n] """ + from typing import List from aitemplate.compiler.base import IntImm diff --git a/python/aitemplate/compiler/ops/layernorm/group_layernorm.py b/python/aitemplate/compiler/ops/layernorm/group_layernorm.py index 3dd310a64..3948f8cf1 100644 --- a/python/aitemplate/compiler/ops/layernorm/group_layernorm.py +++ b/python/aitemplate/compiler/ops/layernorm/group_layernorm.py @@ -15,6 +15,7 @@ """ Operator definition for group_layernorm. """ + from typing import Any, List from aitemplate.compiler.base import IntImm, IntVarTensor, Tensor diff --git a/python/aitemplate/compiler/ops/layernorm/group_layernorm_sigmoid_mul.py b/python/aitemplate/compiler/ops/layernorm/group_layernorm_sigmoid_mul.py index 1fb0b85e4..7d4dea4a9 100644 --- a/python/aitemplate/compiler/ops/layernorm/group_layernorm_sigmoid_mul.py +++ b/python/aitemplate/compiler/ops/layernorm/group_layernorm_sigmoid_mul.py @@ -15,6 +15,7 @@ """ Operator definition for group_layernorm_sigmoid_mul. """ + from typing import List from aitemplate.compiler.base import IntImm diff --git a/python/aitemplate/compiler/ops/layernorm/layernorm.py b/python/aitemplate/compiler/ops/layernorm/layernorm.py index 0188feec0..1a0226afc 100644 --- a/python/aitemplate/compiler/ops/layernorm/layernorm.py +++ b/python/aitemplate/compiler/ops/layernorm/layernorm.py @@ -15,6 +15,7 @@ """ Operator definition for layernorm. """ + import logging import os import re diff --git a/python/aitemplate/compiler/ops/layernorm/layernorm_sigmoid_mul.py b/python/aitemplate/compiler/ops/layernorm/layernorm_sigmoid_mul.py index 4d3b198ab..f1875fa91 100644 --- a/python/aitemplate/compiler/ops/layernorm/layernorm_sigmoid_mul.py +++ b/python/aitemplate/compiler/ops/layernorm/layernorm_sigmoid_mul.py @@ -15,6 +15,7 @@ """ Operator definition for layernorm_sigmoid_mul. """ + from aitemplate import backend from aitemplate.backend import registry from aitemplate.compiler.base import Operator diff --git a/python/aitemplate/compiler/ops/padding/__init__.py b/python/aitemplate/compiler/ops/padding/__init__.py index 6448b85a2..e6b6df5f6 100644 --- a/python/aitemplate/compiler/ops/padding/__init__.py +++ b/python/aitemplate/compiler/ops/padding/__init__.py @@ -15,6 +15,7 @@ """ Padding ops module init. """ + from aitemplate.compiler.ops.padding.ndhwc3to8 import ndhwc3to8 from aitemplate.compiler.ops.padding.nhwc3to4 import nhwc3to4 from aitemplate.compiler.ops.padding.nhwc3to8 import nhwc3to8 diff --git a/python/aitemplate/compiler/ops/padding/ndhwc3to8.py b/python/aitemplate/compiler/ops/padding/ndhwc3to8.py index 738d249f8..74f84f22d 100644 --- a/python/aitemplate/compiler/ops/padding/ndhwc3to8.py +++ b/python/aitemplate/compiler/ops/padding/ndhwc3to8.py @@ -15,6 +15,7 @@ """ Common NDHWC3to8 padding op """ + import itertools from typing import List diff --git a/python/aitemplate/compiler/ops/padding/nhwc_pad_common.py b/python/aitemplate/compiler/ops/padding/nhwc_pad_common.py index 96c5eb0be..f175ec7c1 100644 --- a/python/aitemplate/compiler/ops/padding/nhwc_pad_common.py +++ b/python/aitemplate/compiler/ops/padding/nhwc_pad_common.py @@ -15,6 +15,7 @@ """ Common NHWC padding ops """ + import itertools from typing import List diff --git a/python/aitemplate/compiler/ops/padding/pad_last_dim.py b/python/aitemplate/compiler/ops/padding/pad_last_dim.py index 6def61e73..5d10871d6 100644 --- a/python/aitemplate/compiler/ops/padding/pad_last_dim.py +++ b/python/aitemplate/compiler/ops/padding/pad_last_dim.py @@ -15,6 +15,7 @@ """ Pad last dimension. """ + from typing import List import jinja2 diff --git a/python/aitemplate/compiler/ops/pool/__init__.py b/python/aitemplate/compiler/ops/pool/__init__.py index 7cd9df61a..1248f4034 100644 --- a/python/aitemplate/compiler/ops/pool/__init__.py +++ b/python/aitemplate/compiler/ops/pool/__init__.py @@ -15,6 +15,7 @@ """ Pool module init. """ + from aitemplate.compiler.ops.pool.avg_pool2d import avg_pool2d from aitemplate.compiler.ops.pool.max_pool2d import max_pool2d diff --git a/python/aitemplate/compiler/ops/pool/avg_pool2d.py b/python/aitemplate/compiler/ops/pool/avg_pool2d.py index 094968e72..4343bbabe 100644 --- a/python/aitemplate/compiler/ops/pool/avg_pool2d.py +++ b/python/aitemplate/compiler/ops/pool/avg_pool2d.py @@ -15,6 +15,7 @@ """ Avg_pool2d op. """ + from aitemplate.compiler.ops.pool.pool2d import pool2d_base diff --git a/python/aitemplate/compiler/ops/pool/max_pool2d.py b/python/aitemplate/compiler/ops/pool/max_pool2d.py index f95144463..5efd035e5 100644 --- a/python/aitemplate/compiler/ops/pool/max_pool2d.py +++ b/python/aitemplate/compiler/ops/pool/max_pool2d.py @@ -15,6 +15,7 @@ """ Max_pool2d op. """ + from aitemplate.compiler.ops.pool.pool2d import pool2d_base diff --git a/python/aitemplate/compiler/ops/pool/pool2d.py b/python/aitemplate/compiler/ops/pool/pool2d.py index 37bbb9151..3e6b9cd34 100644 --- a/python/aitemplate/compiler/ops/pool/pool2d.py +++ b/python/aitemplate/compiler/ops/pool/pool2d.py @@ -15,6 +15,7 @@ """ Pool2d. """ + import itertools import logging import re diff --git a/python/aitemplate/compiler/ops/reduce/__init__.py b/python/aitemplate/compiler/ops/reduce/__init__.py index 335c329c2..bccbc4e4d 100644 --- a/python/aitemplate/compiler/ops/reduce/__init__.py +++ b/python/aitemplate/compiler/ops/reduce/__init__.py @@ -15,6 +15,7 @@ """ Reduce module init. """ + from aitemplate.compiler.ops.reduce.reduce_max import reduce_max from aitemplate.compiler.ops.reduce.reduce_mean import reduce_mean from aitemplate.compiler.ops.reduce.reduce_min import reduce_min diff --git a/python/aitemplate/compiler/ops/reduce/reduce_common.py b/python/aitemplate/compiler/ops/reduce/reduce_common.py index 6f0e12831..56cce4e0c 100644 --- a/python/aitemplate/compiler/ops/reduce/reduce_common.py +++ b/python/aitemplate/compiler/ops/reduce/reduce_common.py @@ -15,6 +15,7 @@ """ Base operator definition for reduce-family ops. """ + import itertools import logging diff --git a/python/aitemplate/compiler/ops/reduce/reduce_max.py b/python/aitemplate/compiler/ops/reduce/reduce_max.py index 3a58ae3e0..f1bfa706e 100644 --- a/python/aitemplate/compiler/ops/reduce/reduce_max.py +++ b/python/aitemplate/compiler/ops/reduce/reduce_max.py @@ -15,6 +15,7 @@ """ reduce_max op """ + from aitemplate.compiler.ops.reduce.reduce_common import reduce_base # pylint: disable=C0103 diff --git a/python/aitemplate/compiler/ops/reduce/reduce_mean.py b/python/aitemplate/compiler/ops/reduce/reduce_mean.py index 37952f488..d69300490 100644 --- a/python/aitemplate/compiler/ops/reduce/reduce_mean.py +++ b/python/aitemplate/compiler/ops/reduce/reduce_mean.py @@ -15,6 +15,7 @@ """ Reduce_mean op implementation. """ + from aitemplate.compiler.ops.reduce.reduce_common import reduce_base # pylint: disable=C0103 diff --git a/python/aitemplate/compiler/ops/reduce/reduce_min.py b/python/aitemplate/compiler/ops/reduce/reduce_min.py index f1ef6c2e6..1f91084de 100644 --- a/python/aitemplate/compiler/ops/reduce/reduce_min.py +++ b/python/aitemplate/compiler/ops/reduce/reduce_min.py @@ -15,6 +15,7 @@ """ reduce_min op """ + from aitemplate.compiler.ops.reduce.reduce_common import reduce_base # pylint: disable=C0103 diff --git a/python/aitemplate/compiler/ops/reduce/reduce_sum.py b/python/aitemplate/compiler/ops/reduce/reduce_sum.py index 08d538e96..57dc16d3c 100644 --- a/python/aitemplate/compiler/ops/reduce/reduce_sum.py +++ b/python/aitemplate/compiler/ops/reduce/reduce_sum.py @@ -15,6 +15,7 @@ """ reduce_sum op """ + from aitemplate.compiler.ops.reduce.reduce_common import reduce_base # pylint: disable=C0103 diff --git a/python/aitemplate/compiler/ops/reduce/var.py b/python/aitemplate/compiler/ops/reduce/var.py index 91117404d..23a49d083 100644 --- a/python/aitemplate/compiler/ops/reduce/var.py +++ b/python/aitemplate/compiler/ops/reduce/var.py @@ -15,6 +15,7 @@ """ var op implementation """ + from aitemplate.compiler.ops.reduce.reduce_common import reduce_base # pylint: disable=C0103 diff --git a/python/aitemplate/compiler/ops/reduce/vector_norm.py b/python/aitemplate/compiler/ops/reduce/vector_norm.py index 38ea5c367..d6c3c28b7 100644 --- a/python/aitemplate/compiler/ops/reduce/vector_norm.py +++ b/python/aitemplate/compiler/ops/reduce/vector_norm.py @@ -16,6 +16,7 @@ vector_norm op implementation that simulates pytorch's linalg.vector_norm. Currently, we only support L2 norm. """ + from aitemplate.compiler.ops.reduce.reduce_common import reduce_base # pylint: disable=C0103 diff --git a/python/aitemplate/compiler/ops/softmax/__init__.py b/python/aitemplate/compiler/ops/softmax/__init__.py index 61a8ad8ba..c7238b32c 100644 --- a/python/aitemplate/compiler/ops/softmax/__init__.py +++ b/python/aitemplate/compiler/ops/softmax/__init__.py @@ -15,6 +15,7 @@ """ softmax module init """ + from aitemplate.compiler.ops.softmax.softmax import softmax diff --git a/python/aitemplate/compiler/ops/softmax/cache_entry.py b/python/aitemplate/compiler/ops/softmax/cache_entry.py index f2d7dee93..0f8c28d44 100644 --- a/python/aitemplate/compiler/ops/softmax/cache_entry.py +++ b/python/aitemplate/compiler/ops/softmax/cache_entry.py @@ -15,6 +15,7 @@ """ Softmax cache entry. """ + from dataclasses import dataclass # pylint: disable=C0103 diff --git a/python/aitemplate/compiler/ops/softmax/softmax.py b/python/aitemplate/compiler/ops/softmax/softmax.py index 52e1524ee..5cb7b841b 100644 --- a/python/aitemplate/compiler/ops/softmax/softmax.py +++ b/python/aitemplate/compiler/ops/softmax/softmax.py @@ -15,6 +15,7 @@ """ Softmax op implementation """ + import logging import os import re diff --git a/python/aitemplate/compiler/ops/tensor/__init__.py b/python/aitemplate/compiler/ops/tensor/__init__.py index ee031c330..50cee8a86 100644 --- a/python/aitemplate/compiler/ops/tensor/__init__.py +++ b/python/aitemplate/compiler/ops/tensor/__init__.py @@ -16,6 +16,7 @@ """ reduce module init """ + from aitemplate.compiler.ops.tensor.argmax import argmax from aitemplate.compiler.ops.tensor.batch_gather import batch_gather from aitemplate.compiler.ops.tensor.cast import cast diff --git a/python/aitemplate/compiler/ops/tensor/argmax.py b/python/aitemplate/compiler/ops/tensor/argmax.py index e72e79a0c..890bd9c3f 100644 --- a/python/aitemplate/compiler/ops/tensor/argmax.py +++ b/python/aitemplate/compiler/ops/tensor/argmax.py @@ -15,6 +15,7 @@ """ Argmax. """ + import itertools import logging import os diff --git a/python/aitemplate/compiler/ops/tensor/batch_gather.py b/python/aitemplate/compiler/ops/tensor/batch_gather.py index f6586affc..a793aaec7 100644 --- a/python/aitemplate/compiler/ops/tensor/batch_gather.py +++ b/python/aitemplate/compiler/ops/tensor/batch_gather.py @@ -15,6 +15,7 @@ """ Batch_gather. """ + import itertools from collections import OrderedDict from typing import List @@ -73,11 +74,14 @@ def _infer_shapes(self, x: Tensor, indices: Tensor) -> List[IntVar]: def __call__(self, x: Tensor, indices: Tensor) -> Tensor: dtype = indices._attrs["dtype"] - assert dtype in [ - "int", - "int32", - "int64", - ], f"batch_gather(): Expected dtype int/int32/int64 for index, got dtype {dtype}" + assert ( + dtype + in [ + "int", + "int32", + "int64", + ] + ), f"batch_gather(): Expected dtype int/int32/int64 for index, got dtype {dtype}" self._attrs["inputs"] = [x, indices] self._set_depth() self._extract_exec_path(x) diff --git a/python/aitemplate/compiler/ops/tensor/chunk.py b/python/aitemplate/compiler/ops/tensor/chunk.py index 786a62177..9cf038ee9 100644 --- a/python/aitemplate/compiler/ops/tensor/chunk.py +++ b/python/aitemplate/compiler/ops/tensor/chunk.py @@ -15,6 +15,7 @@ """ chunk """ + import math from typing import List diff --git a/python/aitemplate/compiler/ops/tensor/concatenate.py b/python/aitemplate/compiler/ops/tensor/concatenate.py index 27bda6492..9452772bf 100644 --- a/python/aitemplate/compiler/ops/tensor/concatenate.py +++ b/python/aitemplate/compiler/ops/tensor/concatenate.py @@ -15,6 +15,7 @@ """ Concatenate. """ + from copy import deepcopy from functools import reduce from typing import List, Optional, Sequence, Tuple, Union diff --git a/python/aitemplate/compiler/ops/tensor/concatenate_tanh.py b/python/aitemplate/compiler/ops/tensor/concatenate_tanh.py index 7ecaf757f..4a0166905 100644 --- a/python/aitemplate/compiler/ops/tensor/concatenate_tanh.py +++ b/python/aitemplate/compiler/ops/tensor/concatenate_tanh.py @@ -15,6 +15,7 @@ """ Concatenate_tanh """ + from aitemplate.compiler.ops.tensor import concatenate # pylint: disable=C0103 diff --git a/python/aitemplate/compiler/ops/tensor/dynamic_slice.py b/python/aitemplate/compiler/ops/tensor/dynamic_slice.py index 4055a2c5d..a12fb04af 100644 --- a/python/aitemplate/compiler/ops/tensor/dynamic_slice.py +++ b/python/aitemplate/compiler/ops/tensor/dynamic_slice.py @@ -15,6 +15,7 @@ """ Dynamic_slice. """ + from typing import List, Optional, Union import sympy diff --git a/python/aitemplate/compiler/ops/tensor/gather.py b/python/aitemplate/compiler/ops/tensor/gather.py index 4a2b4d131..ee268ca17 100644 --- a/python/aitemplate/compiler/ops/tensor/gather.py +++ b/python/aitemplate/compiler/ops/tensor/gather.py @@ -15,6 +15,7 @@ """ Operator definition for gather. """ + from aitemplate import backend from aitemplate.backend import registry from aitemplate.compiler.base import Operator, Tensor diff --git a/python/aitemplate/compiler/ops/tensor/identity.py b/python/aitemplate/compiler/ops/tensor/identity.py index 262a32523..da362874c 100644 --- a/python/aitemplate/compiler/ops/tensor/identity.py +++ b/python/aitemplate/compiler/ops/tensor/identity.py @@ -15,6 +15,7 @@ """ identity op """ + from typing import List from aitemplate import backend diff --git a/python/aitemplate/compiler/ops/tensor/jagged_to_padded_dense.py b/python/aitemplate/compiler/ops/tensor/jagged_to_padded_dense.py index a9dbf4aa9..5c9961391 100644 --- a/python/aitemplate/compiler/ops/tensor/jagged_to_padded_dense.py +++ b/python/aitemplate/compiler/ops/tensor/jagged_to_padded_dense.py @@ -16,6 +16,7 @@ """ Define jagged_to_padded_dense op """ + import logging from typing import List diff --git a/python/aitemplate/compiler/ops/tensor/masked_select.py b/python/aitemplate/compiler/ops/tensor/masked_select.py index 4eac1764b..015699dbc 100644 --- a/python/aitemplate/compiler/ops/tensor/masked_select.py +++ b/python/aitemplate/compiler/ops/tensor/masked_select.py @@ -15,6 +15,7 @@ """ Define masked_select op """ + import logging from typing import List diff --git a/python/aitemplate/compiler/ops/tensor/padded_dense_to_jagged.py b/python/aitemplate/compiler/ops/tensor/padded_dense_to_jagged.py index 14143d698..3af350069 100644 --- a/python/aitemplate/compiler/ops/tensor/padded_dense_to_jagged.py +++ b/python/aitemplate/compiler/ops/tensor/padded_dense_to_jagged.py @@ -16,6 +16,7 @@ """ The front-end definition of the padded_dense_to_jagged op. """ + from typing import List from aitemplate.backend import registry diff --git a/python/aitemplate/compiler/ops/tensor/permute.py b/python/aitemplate/compiler/ops/tensor/permute.py index c4a9049d6..51faf00d3 100644 --- a/python/aitemplate/compiler/ops/tensor/permute.py +++ b/python/aitemplate/compiler/ops/tensor/permute.py @@ -15,6 +15,7 @@ """ permute op """ + from typing import List, Sequence from aitemplate import backend diff --git a/python/aitemplate/compiler/ops/tensor/permute021.py b/python/aitemplate/compiler/ops/tensor/permute021.py index d775db8bc..47c89a15d 100644 --- a/python/aitemplate/compiler/ops/tensor/permute021.py +++ b/python/aitemplate/compiler/ops/tensor/permute021.py @@ -15,6 +15,7 @@ """ permute(0, 2, 1) op """ + from typing import List from aitemplate import backend diff --git a/python/aitemplate/compiler/ops/tensor/permute0213.py b/python/aitemplate/compiler/ops/tensor/permute0213.py index 3616bdf73..908f7894f 100644 --- a/python/aitemplate/compiler/ops/tensor/permute0213.py +++ b/python/aitemplate/compiler/ops/tensor/permute0213.py @@ -16,6 +16,7 @@ Permute(0, 2, 1, 3) op. Change the dimensions dim1 and dim2 of input 4d tensor. """ + from typing import List from aitemplate import backend diff --git a/python/aitemplate/compiler/ops/tensor/permute102.py b/python/aitemplate/compiler/ops/tensor/permute102.py index f6d0af738..c163a952b 100644 --- a/python/aitemplate/compiler/ops/tensor/permute102.py +++ b/python/aitemplate/compiler/ops/tensor/permute102.py @@ -16,6 +16,7 @@ Permute(1, 0, 2) op. Change the dimension of dim0 and dim1 of input 3d tensor. """ + from typing import List from aitemplate import backend diff --git a/python/aitemplate/compiler/ops/tensor/permute210.py b/python/aitemplate/compiler/ops/tensor/permute210.py index d177073ef..6a7d7f486 100644 --- a/python/aitemplate/compiler/ops/tensor/permute210.py +++ b/python/aitemplate/compiler/ops/tensor/permute210.py @@ -16,6 +16,7 @@ Permute(2, 1, 0) op. Swap the dimension of dim0 and dim2 of input 3d tensor. """ + from typing import List from aitemplate import backend diff --git a/python/aitemplate/compiler/ops/tensor/size.py b/python/aitemplate/compiler/ops/tensor/size.py index 9607db1c5..e8e0eaf5f 100644 --- a/python/aitemplate/compiler/ops/tensor/size.py +++ b/python/aitemplate/compiler/ops/tensor/size.py @@ -15,6 +15,7 @@ """ Op to return the size of a tensor. """ + from typing import List, Union from aitemplate import backend diff --git a/python/aitemplate/compiler/ops/tensor/slice_reshape_scatter.py b/python/aitemplate/compiler/ops/tensor/slice_reshape_scatter.py index a109084d9..afd083ef4 100644 --- a/python/aitemplate/compiler/ops/tensor/slice_reshape_scatter.py +++ b/python/aitemplate/compiler/ops/tensor/slice_reshape_scatter.py @@ -15,6 +15,7 @@ """ Slice_reshape_scatter. """ + from typing import Optional from aitemplate import backend diff --git a/python/aitemplate/compiler/ops/tensor/split.py b/python/aitemplate/compiler/ops/tensor/split.py index 4e75a20c2..797860477 100644 --- a/python/aitemplate/compiler/ops/tensor/split.py +++ b/python/aitemplate/compiler/ops/tensor/split.py @@ -15,6 +15,7 @@ """ Split. """ + from typing import List, Sequence, Union from aitemplate import backend diff --git a/python/aitemplate/compiler/ops/tensor/topk.py b/python/aitemplate/compiler/ops/tensor/topk.py index fb058751a..958caad28 100644 --- a/python/aitemplate/compiler/ops/tensor/topk.py +++ b/python/aitemplate/compiler/ops/tensor/topk.py @@ -15,6 +15,7 @@ """ Topk. """ + import itertools import logging import os diff --git a/python/aitemplate/compiler/ops/tensor/where.py b/python/aitemplate/compiler/ops/tensor/where.py index 4be14790c..844a01ac2 100644 --- a/python/aitemplate/compiler/ops/tensor/where.py +++ b/python/aitemplate/compiler/ops/tensor/where.py @@ -71,8 +71,8 @@ def __call__( if common_dtype is None: common_dtype = normalize_dtype(tensor.dtype()) else: - assert common_dtype == normalize_dtype( - tensor.dtype() + assert ( + common_dtype == normalize_dtype(tensor.dtype()) ), f"Expect tensor of the same dtype, got {common_dtype} and {normalize_dtype(tensor.dtype())}" inputs.append(tensor) diff --git a/python/aitemplate/compiler/ops/upsample/__init__.py b/python/aitemplate/compiler/ops/upsample/__init__.py index 6af4b174d..525c21710 100644 --- a/python/aitemplate/compiler/ops/upsample/__init__.py +++ b/python/aitemplate/compiler/ops/upsample/__init__.py @@ -15,6 +15,7 @@ """ Upsampling module init. """ + from aitemplate.compiler.ops.upsample.upsampling2d import upsampling2d from aitemplate.compiler.ops.upsample.upsampling2d_add import upsampling2d_add diff --git a/python/aitemplate/compiler/ops/upsample/upsampling2d.py b/python/aitemplate/compiler/ops/upsample/upsampling2d.py index e53d4aed0..c61d2f240 100644 --- a/python/aitemplate/compiler/ops/upsample/upsampling2d.py +++ b/python/aitemplate/compiler/ops/upsample/upsampling2d.py @@ -15,6 +15,7 @@ """ Upsampling2d op. """ + from aitemplate.compiler.ops.upsample.upsampling_common import upsampling2d_base diff --git a/python/aitemplate/compiler/ops/upsample/upsampling2d_add.py b/python/aitemplate/compiler/ops/upsample/upsampling2d_add.py index e63c2c560..b4138562b 100644 --- a/python/aitemplate/compiler/ops/upsample/upsampling2d_add.py +++ b/python/aitemplate/compiler/ops/upsample/upsampling2d_add.py @@ -15,6 +15,7 @@ """ Upsampling2d_add op. """ + from typing import List from aitemplate.compiler.base import Tensor diff --git a/python/aitemplate/compiler/ops/upsample/upsampling_common.py b/python/aitemplate/compiler/ops/upsample/upsampling_common.py index edf87bba0..add7a81f2 100644 --- a/python/aitemplate/compiler/ops/upsample/upsampling_common.py +++ b/python/aitemplate/compiler/ops/upsample/upsampling_common.py @@ -15,6 +15,7 @@ """ upsampling2d. """ + import itertools import logging import re diff --git a/python/aitemplate/compiler/ops/vision_ops/__init__.py b/python/aitemplate/compiler/ops/vision_ops/__init__.py index 9f32e81ed..059b9a583 100644 --- a/python/aitemplate/compiler/ops/vision_ops/__init__.py +++ b/python/aitemplate/compiler/ops/vision_ops/__init__.py @@ -15,5 +15,6 @@ """ Vision ops module init. """ + from aitemplate.compiler.ops.vision_ops.nms import * # noqa from aitemplate.compiler.ops.vision_ops.roi_ops import * # noqa diff --git a/python/aitemplate/compiler/ops/vision_ops/nms/__init__.py b/python/aitemplate/compiler/ops/vision_ops/nms/__init__.py index eea6045e6..49c186ca6 100644 --- a/python/aitemplate/compiler/ops/vision_ops/nms/__init__.py +++ b/python/aitemplate/compiler/ops/vision_ops/nms/__init__.py @@ -15,6 +15,7 @@ """ Nms family ops. """ + from aitemplate.compiler.ops.vision_ops.nms.batched_nms import batched_nms from aitemplate.compiler.ops.vision_ops.nms.efficient_nms import efficient_nms from aitemplate.compiler.ops.vision_ops.nms.nms import nms diff --git a/python/aitemplate/compiler/ops/vision_ops/nms/batched_nms.py b/python/aitemplate/compiler/ops/vision_ops/nms/batched_nms.py index 686beea1d..6f5411091 100644 --- a/python/aitemplate/compiler/ops/vision_ops/nms/batched_nms.py +++ b/python/aitemplate/compiler/ops/vision_ops/nms/batched_nms.py @@ -15,6 +15,7 @@ """ Batched nms. """ + import itertools from typing import List diff --git a/python/aitemplate/compiler/ops/vision_ops/nms/efficient_nms.py b/python/aitemplate/compiler/ops/vision_ops/nms/efficient_nms.py index d39872b7c..584ef7269 100644 --- a/python/aitemplate/compiler/ops/vision_ops/nms/efficient_nms.py +++ b/python/aitemplate/compiler/ops/vision_ops/nms/efficient_nms.py @@ -15,6 +15,7 @@ """ Efficient nms. """ + import itertools import logging import os diff --git a/python/aitemplate/compiler/ops/vision_ops/nms/nms.py b/python/aitemplate/compiler/ops/vision_ops/nms/nms.py index beee90245..db7c9d43c 100644 --- a/python/aitemplate/compiler/ops/vision_ops/nms/nms.py +++ b/python/aitemplate/compiler/ops/vision_ops/nms/nms.py @@ -15,6 +15,7 @@ """ Nms. """ + import itertools import logging import os diff --git a/python/aitemplate/compiler/ops/vision_ops/roi_ops/__init__.py b/python/aitemplate/compiler/ops/vision_ops/roi_ops/__init__.py index 19edd785a..c8611e7a4 100644 --- a/python/aitemplate/compiler/ops/vision_ops/roi_ops/__init__.py +++ b/python/aitemplate/compiler/ops/vision_ops/roi_ops/__init__.py @@ -15,6 +15,7 @@ """ Roi-align module init. """ + from aitemplate.compiler.ops.vision_ops.roi_ops.multi_level_roi_align import ( multi_level_roi_align, ) diff --git a/python/aitemplate/compiler/ops/vision_ops/roi_ops/roi_align.py b/python/aitemplate/compiler/ops/vision_ops/roi_ops/roi_align.py index cdcb0fa80..27f43f516 100644 --- a/python/aitemplate/compiler/ops/vision_ops/roi_ops/roi_align.py +++ b/python/aitemplate/compiler/ops/vision_ops/roi_ops/roi_align.py @@ -15,6 +15,7 @@ """ Roi_align. """ + from aitemplate.compiler.ops.vision_ops.roi_ops.roi_ops import roi_ops_base diff --git a/python/aitemplate/compiler/ops/vision_ops/roi_ops/roi_ops.py b/python/aitemplate/compiler/ops/vision_ops/roi_ops/roi_ops.py index 39b2ab046..ad67af3dc 100644 --- a/python/aitemplate/compiler/ops/vision_ops/roi_ops/roi_ops.py +++ b/python/aitemplate/compiler/ops/vision_ops/roi_ops/roi_ops.py @@ -15,6 +15,7 @@ """ Roi. """ + import itertools import logging import re diff --git a/python/aitemplate/compiler/stable_set.py b/python/aitemplate/compiler/stable_set.py index 82f945078..82ab5684f 100644 --- a/python/aitemplate/compiler/stable_set.py +++ b/python/aitemplate/compiler/stable_set.py @@ -19,6 +19,7 @@ potentially make debugging (e.g. comparison with the original graph, comparison between AIT GPU trace and other GPU traces) easier. """ + from collections import abc from typing import Any, Iterable diff --git a/python/aitemplate/compiler/symbolic.py b/python/aitemplate/compiler/symbolic.py index 2edcfb195..7b3a81d81 100644 --- a/python/aitemplate/compiler/symbolic.py +++ b/python/aitemplate/compiler/symbolic.py @@ -35,6 +35,7 @@ For more advanced usage on Sympy, check: https://docs.sympy.org/latest/tutorials/intro-tutorial/intro.html """ + from __future__ import annotations import itertools diff --git a/python/aitemplate/compiler/transform/apply_padding.py b/python/aitemplate/compiler/transform/apply_padding.py index 67a7343bc..24ec1f6e5 100644 --- a/python/aitemplate/compiler/transform/apply_padding.py +++ b/python/aitemplate/compiler/transform/apply_padding.py @@ -15,6 +15,7 @@ """ Applies paddings to gemms based on alignment requirements. """ + import logging from typing import Callable, Dict, List diff --git a/python/aitemplate/compiler/transform/dedup_make_jagged_ops.py b/python/aitemplate/compiler/transform/dedup_make_jagged_ops.py index 4ecea4e8a..4fcaa2ef5 100644 --- a/python/aitemplate/compiler/transform/dedup_make_jagged_ops.py +++ b/python/aitemplate/compiler/transform/dedup_make_jagged_ops.py @@ -15,6 +15,7 @@ """ Deduplicate make_jagged ops in the graph. """ + import logging from dataclasses import dataclass from typing import Dict, List, Set diff --git a/python/aitemplate/compiler/transform/fuse_bmm_permute.py b/python/aitemplate/compiler/transform/fuse_bmm_permute.py index a563a0cab..18018b819 100644 --- a/python/aitemplate/compiler/transform/fuse_bmm_permute.py +++ b/python/aitemplate/compiler/transform/fuse_bmm_permute.py @@ -17,6 +17,7 @@ bmm_xxc + permute021 -> bmm_xxr bmm_xxr + permute021 -> bmm_xxc """ + from typing import List from aitemplate.compiler.base import Tensor diff --git a/python/aitemplate/compiler/transform/fuse_conv_elementwise.py b/python/aitemplate/compiler/transform/fuse_conv_elementwise.py index 9c192bc92..58c7c3645 100644 --- a/python/aitemplate/compiler/transform/fuse_conv_elementwise.py +++ b/python/aitemplate/compiler/transform/fuse_conv_elementwise.py @@ -15,6 +15,7 @@ """ Fuse conv + elementwise ops. """ + from typing import List from aitemplate.compiler.base import Tensor diff --git a/python/aitemplate/compiler/transform/fuse_expand_bmm.py b/python/aitemplate/compiler/transform/fuse_expand_bmm.py index cb7ea8037..29d0c20de 100644 --- a/python/aitemplate/compiler/transform/fuse_expand_bmm.py +++ b/python/aitemplate/compiler/transform/fuse_expand_bmm.py @@ -23,6 +23,7 @@ The basic idea behind the transformation is that we leverage bmm's broadcasting capability to achieve the same functionality as expand. """ + from typing import List from aitemplate.compiler.base import Operator, Tensor diff --git a/python/aitemplate/compiler/transform/fuse_group_ops.py b/python/aitemplate/compiler/transform/fuse_group_ops.py index 63c45fb99..ca7ce4e54 100644 --- a/python/aitemplate/compiler/transform/fuse_group_ops.py +++ b/python/aitemplate/compiler/transform/fuse_group_ops.py @@ -15,6 +15,7 @@ """ Horizontal fusion pass to group ops together. """ + import collections import logging import os diff --git a/python/aitemplate/compiler/transform/fuse_mm_elementwise.py b/python/aitemplate/compiler/transform/fuse_mm_elementwise.py index cbe30d300..aac18214c 100644 --- a/python/aitemplate/compiler/transform/fuse_mm_elementwise.py +++ b/python/aitemplate/compiler/transform/fuse_mm_elementwise.py @@ -15,6 +15,7 @@ """ Fuse GEMM with elementwise operations """ + from typing import List from aitemplate.compiler.base import Tensor diff --git a/python/aitemplate/compiler/transform/fuse_mm_reshape_permute.py b/python/aitemplate/compiler/transform/fuse_mm_reshape_permute.py index d8e9370c5..27f6e54e8 100644 --- a/python/aitemplate/compiler/transform/fuse_mm_reshape_permute.py +++ b/python/aitemplate/compiler/transform/fuse_mm_reshape_permute.py @@ -15,6 +15,7 @@ """ Fuse GEMM + reshape + permute0213 """ + from typing import List, Sequence from aitemplate.compiler.base import IntImm, Operator, Tensor diff --git a/python/aitemplate/compiler/transform/fuse_ops.py b/python/aitemplate/compiler/transform/fuse_ops.py index 733c13a22..e56a15d80 100644 --- a/python/aitemplate/compiler/transform/fuse_ops.py +++ b/python/aitemplate/compiler/transform/fuse_ops.py @@ -15,6 +15,7 @@ """ Perform operator fusions. """ + import collections import itertools import logging diff --git a/python/aitemplate/compiler/transform/fuse_permute_bmm_and_gemm.py b/python/aitemplate/compiler/transform/fuse_permute_bmm_and_gemm.py index b883e3215..860d7bff0 100644 --- a/python/aitemplate/compiler/transform/fuse_permute_bmm_and_gemm.py +++ b/python/aitemplate/compiler/transform/fuse_permute_bmm_and_gemm.py @@ -15,6 +15,7 @@ """ Perform fusions for permute+bmm operators. """ + from typing import Callable, List, Optional, Set, Tuple, Type, Union from aitemplate.compiler import ops diff --git a/python/aitemplate/compiler/transform/fuse_split.py b/python/aitemplate/compiler/transform/fuse_split.py index 91aeac2d8..03d059ef2 100644 --- a/python/aitemplate/compiler/transform/fuse_split.py +++ b/python/aitemplate/compiler/transform/fuse_split.py @@ -15,6 +15,7 @@ """ Perform transformations on ops which support strided inputs / outputs. """ + import logging from typing import List diff --git a/python/aitemplate/compiler/transform/mark_param_tensor.py b/python/aitemplate/compiler/transform/mark_param_tensor.py index 1a6b4dc5c..a0a1acfe7 100644 --- a/python/aitemplate/compiler/transform/mark_param_tensor.py +++ b/python/aitemplate/compiler/transform/mark_param_tensor.py @@ -15,6 +15,7 @@ """ mark tensors which are parameters """ + from typing import List from aitemplate.compiler.base import Tensor diff --git a/python/aitemplate/compiler/transform/memory_planning.py b/python/aitemplate/compiler/transform/memory_planning.py index bbb3d5bed..0b73784b4 100644 --- a/python/aitemplate/compiler/transform/memory_planning.py +++ b/python/aitemplate/compiler/transform/memory_planning.py @@ -15,6 +15,7 @@ """ Graph pass for memory planning. """ + import bisect import logging from collections import defaultdict diff --git a/python/aitemplate/compiler/transform/move_view_ops.py b/python/aitemplate/compiler/transform/move_view_ops.py index e49e1bc07..dca6d2e83 100644 --- a/python/aitemplate/compiler/transform/move_view_ops.py +++ b/python/aitemplate/compiler/transform/move_view_ops.py @@ -16,6 +16,7 @@ This pass move any view op between two concatenate ops to the front of the first concatenate op if possible. """ + import copy from typing import Callable, List, Optional, Tuple diff --git a/python/aitemplate/compiler/transform/name_graph.py b/python/aitemplate/compiler/transform/name_graph.py index b2a13b359..3bf4a904f 100644 --- a/python/aitemplate/compiler/transform/name_graph.py +++ b/python/aitemplate/compiler/transform/name_graph.py @@ -15,6 +15,7 @@ """ Graph pass to assign names to a sorted graph. """ + import logging import re from typing import List diff --git a/python/aitemplate/compiler/transform/optimize_graph.py b/python/aitemplate/compiler/transform/optimize_graph.py index 4e73cc3d6..2ffbeeaee 100644 --- a/python/aitemplate/compiler/transform/optimize_graph.py +++ b/python/aitemplate/compiler/transform/optimize_graph.py @@ -15,6 +15,7 @@ """ Applies graph transformations. """ + from typing import List from aitemplate.compiler.base import Tensor diff --git a/python/aitemplate/compiler/transform/profile.py b/python/aitemplate/compiler/transform/profile.py index 698250f55..285ab231b 100644 --- a/python/aitemplate/compiler/transform/profile.py +++ b/python/aitemplate/compiler/transform/profile.py @@ -15,6 +15,7 @@ """ Graph pass to invoke profiling. """ + import logging import os from collections import OrderedDict diff --git a/python/aitemplate/compiler/transform/profile_dynamic_dim.py b/python/aitemplate/compiler/transform/profile_dynamic_dim.py index ecef60721..b7c4575b6 100644 --- a/python/aitemplate/compiler/transform/profile_dynamic_dim.py +++ b/python/aitemplate/compiler/transform/profile_dynamic_dim.py @@ -15,6 +15,7 @@ """ Graph pass to invoke profiling with dynamic shapes. """ + import logging from collections import OrderedDict from copy import deepcopy diff --git a/python/aitemplate/compiler/transform/refine_graph.py b/python/aitemplate/compiler/transform/refine_graph.py index c270ee94d..ed91d3344 100644 --- a/python/aitemplate/compiler/transform/refine_graph.py +++ b/python/aitemplate/compiler/transform/refine_graph.py @@ -15,6 +15,7 @@ """ Graph pass to dedup operators with same signatures. """ + import logging from typing import List diff --git a/python/aitemplate/compiler/transform/remove_elementwise_no_ops.py b/python/aitemplate/compiler/transform/remove_elementwise_no_ops.py index 334611d2a..d9454a7a5 100644 --- a/python/aitemplate/compiler/transform/remove_elementwise_no_ops.py +++ b/python/aitemplate/compiler/transform/remove_elementwise_no_ops.py @@ -15,6 +15,7 @@ """ Eliminate elementwise no-ops (*/1, +-0) """ + from typing import Callable, Dict, List from aitemplate.compiler.base import Tensor @@ -67,7 +68,6 @@ def remove_elementwise_no_ops( ) -> List[Tensor]: """elementwise no-ops (*/1, +-0)""" for tensor in sorted_graph: - src_ops = tensor._attrs["src_ops"] if len(src_ops) != 1: continue diff --git a/python/aitemplate/compiler/transform/remove_no_ops.py b/python/aitemplate/compiler/transform/remove_no_ops.py index b68c19ad3..e7991f381 100644 --- a/python/aitemplate/compiler/transform/remove_no_ops.py +++ b/python/aitemplate/compiler/transform/remove_no_ops.py @@ -29,6 +29,7 @@ many other unrelated passes use sanitize_sorted_graph. We don't need to call the passes in this file more than once. """ + from typing import List from aitemplate.compiler.base import IntImm, IntVar, JaggedIntVar, Operator, Tensor diff --git a/python/aitemplate/compiler/transform/remove_unused_ops.py b/python/aitemplate/compiler/transform/remove_unused_ops.py index 26675a0b7..1cb77d1b0 100644 --- a/python/aitemplate/compiler/transform/remove_unused_ops.py +++ b/python/aitemplate/compiler/transform/remove_unused_ops.py @@ -15,6 +15,7 @@ """ Remove useless operators from a sorted_graph. """ + from collections import deque from typing import List diff --git a/python/aitemplate/compiler/transform/split_large_concat_ops.py b/python/aitemplate/compiler/transform/split_large_concat_ops.py index c4c4b55c8..2caa3b519 100644 --- a/python/aitemplate/compiler/transform/split_large_concat_ops.py +++ b/python/aitemplate/compiler/transform/split_large_concat_ops.py @@ -17,6 +17,7 @@ concat ops, which share the same inputs with correct input_masks and the same output. """ + import copy import logging diff --git a/python/aitemplate/compiler/transform/split_large_slice_scatter_ops.py b/python/aitemplate/compiler/transform/split_large_slice_scatter_ops.py index f3e1761f5..14bc2dfba 100644 --- a/python/aitemplate/compiler/transform/split_large_slice_scatter_ops.py +++ b/python/aitemplate/compiler/transform/split_large_slice_scatter_ops.py @@ -16,6 +16,7 @@ This transformation splits a slice_scatter or slice_reshape_scatter with a large number of inputs into multiple slice_scatter or slice_reshape_scatter ops. """ + import copy import logging diff --git a/python/aitemplate/compiler/transform/split_large_split_ops.py b/python/aitemplate/compiler/transform/split_large_split_ops.py index beab11ca6..5ac3b53ee 100644 --- a/python/aitemplate/compiler/transform/split_large_split_ops.py +++ b/python/aitemplate/compiler/transform/split_large_split_ops.py @@ -16,6 +16,7 @@ This transformation splits a split with a large number of outputs into multiple splitt ops, which share the same input with correct output_masks. """ + import logging from typing import List diff --git a/python/aitemplate/compiler/transform/toposort.py b/python/aitemplate/compiler/transform/toposort.py index 8fb6411e8..729534792 100644 --- a/python/aitemplate/compiler/transform/toposort.py +++ b/python/aitemplate/compiler/transform/toposort.py @@ -15,6 +15,7 @@ """ Graph pass for topological sort. """ + import heapq from typing import List, Tuple, Union diff --git a/python/aitemplate/compiler/transform/transform_memory_ops.py b/python/aitemplate/compiler/transform/transform_memory_ops.py index 0a42f8d36..9e74c0e58 100644 --- a/python/aitemplate/compiler/transform/transform_memory_ops.py +++ b/python/aitemplate/compiler/transform/transform_memory_ops.py @@ -15,6 +15,7 @@ """ Perform memory operator related transformations. """ + import copy from typing import List diff --git a/python/aitemplate/compiler/transform/transform_merge_slice_ops.py b/python/aitemplate/compiler/transform/transform_merge_slice_ops.py index ec6dbf004..fae7a3c28 100644 --- a/python/aitemplate/compiler/transform/transform_merge_slice_ops.py +++ b/python/aitemplate/compiler/transform/transform_merge_slice_ops.py @@ -15,6 +15,7 @@ """ This file implements a pass that merges consecutive slice ops if possible. """ + from typing import List, Optional from aitemplate.compiler.base import IntImm, IntVar, Operator, Tensor diff --git a/python/aitemplate/compiler/transform/transform_merge_view_ops.py b/python/aitemplate/compiler/transform/transform_merge_view_ops.py index a863df642..704f0ec13 100644 --- a/python/aitemplate/compiler/transform/transform_merge_view_ops.py +++ b/python/aitemplate/compiler/transform/transform_merge_view_ops.py @@ -15,6 +15,7 @@ """ This file implements a pass that merges consecutive view ops if possible. """ + from typing import List, Set from aitemplate.compiler import ops diff --git a/python/aitemplate/compiler/transform/transform_odd_alignment.py b/python/aitemplate/compiler/transform/transform_odd_alignment.py index 77f4de59a..b3213cf34 100644 --- a/python/aitemplate/compiler/transform/transform_odd_alignment.py +++ b/python/aitemplate/compiler/transform/transform_odd_alignment.py @@ -15,6 +15,7 @@ """ Add permute for gemm/bmm if alignment is odd. """ + from math import inf from typing import Dict, List, Tuple diff --git a/python/aitemplate/compiler/transform/transform_permute_to_reshape.py b/python/aitemplate/compiler/transform/transform_permute_to_reshape.py index 1aa21242e..1825f174b 100644 --- a/python/aitemplate/compiler/transform/transform_permute_to_reshape.py +++ b/python/aitemplate/compiler/transform/transform_permute_to_reshape.py @@ -15,6 +15,7 @@ """ Transform permute to reshape wherever applicable. """ + from typing import List from aitemplate.compiler.base import IntImm, Operator, Tensor diff --git a/python/aitemplate/compiler/transform/transform_special_ops.py b/python/aitemplate/compiler/transform/transform_special_ops.py index 84e55bf73..1ec5f2937 100644 --- a/python/aitemplate/compiler/transform/transform_special_ops.py +++ b/python/aitemplate/compiler/transform/transform_special_ops.py @@ -16,6 +16,7 @@ Perform graph transformation specifically for gemm -> gemm_special. Check each transform function summary for specific pattern to be transformed. """ + from typing import Callable, List, Tuple, Type, Union from aitemplate.compiler import ops diff --git a/python/aitemplate/compiler/transform/transform_strided_op_and_view_op.py b/python/aitemplate/compiler/transform/transform_strided_op_and_view_op.py index ec6f533b0..d8db73e88 100644 --- a/python/aitemplate/compiler/transform/transform_strided_op_and_view_op.py +++ b/python/aitemplate/compiler/transform/transform_strided_op_and_view_op.py @@ -88,9 +88,7 @@ def _fuse_strided_op_and_view_op_single_pass( tensor._attrs["src_ops"] = StableSet({src_op}) transform_utils.remove_tensor_from_sorted_graph(view_input_tensor) break - assert ( - found_tensor - ), f"Cannot find view_input_tensor {view_input_tensor} from src_op outputs {src_op._attrs['outputs']}!" + assert found_tensor, f"Cannot find view_input_tensor {view_input_tensor} from src_op outputs {src_op._attrs['outputs']}!" else: if tensor._attrs["is_output"]: continue @@ -115,9 +113,7 @@ def _fuse_strided_op_and_view_op_single_pass( accessor.update_base_tensor_shape(view_input_tensor) dst_op._attrs["inputs"][idx] = view_input_tensor view_input_tensor._attrs["dst_ops"].add(dst_op) - assert ( - found_tensor - ), f"Cannot find tensor {tensor} from dst_op inputs {dst_op._attrs['inputs']}!" + assert found_tensor, f"Cannot find tensor {tensor} from dst_op inputs {dst_op._attrs['inputs']}!" to_be_removed_dst_ops.add(dst_op) tensor._attrs["dst_ops"] = tensor._attrs["dst_ops"] - to_be_removed_dst_ops if len(tensor._attrs["dst_ops"]) == 0: diff --git a/python/aitemplate/compiler/transform/transform_strided_ops.py b/python/aitemplate/compiler/transform/transform_strided_ops.py index 2de95d9be..f16c2a597 100644 --- a/python/aitemplate/compiler/transform/transform_strided_ops.py +++ b/python/aitemplate/compiler/transform/transform_strided_ops.py @@ -15,6 +15,7 @@ """ Perform transformations on ops which support strided inputs / outputs. """ + import functools from typing import List diff --git a/python/aitemplate/compiler/transform/transform_strided_slice.py b/python/aitemplate/compiler/transform/transform_strided_slice.py index 72397c2b8..f096671a4 100644 --- a/python/aitemplate/compiler/transform/transform_strided_slice.py +++ b/python/aitemplate/compiler/transform/transform_strided_slice.py @@ -15,6 +15,7 @@ """ Perform transformations on slice and strided ops. """ + import math from typing import List diff --git a/python/aitemplate/frontend/nn/attention.py b/python/aitemplate/frontend/nn/attention.py index 1f1240762..5a8c1e8ef 100644 --- a/python/aitemplate/frontend/nn/attention.py +++ b/python/aitemplate/frontend/nn/attention.py @@ -15,6 +15,7 @@ """ Frontend for attention module """ + from aitemplate.compiler import ops from aitemplate.compiler.ops import flash_attention from aitemplate.compiler.ops.common.epilogue import FuncEnum diff --git a/python/aitemplate/frontend/nn/batch_norm.py b/python/aitemplate/frontend/nn/batch_norm.py index 823954b4a..b4f97783e 100644 --- a/python/aitemplate/frontend/nn/batch_norm.py +++ b/python/aitemplate/frontend/nn/batch_norm.py @@ -15,6 +15,7 @@ """ Frontend for attention module """ + from aitemplate.compiler.public import elementwise, FuncEnum, permute from aitemplate.frontend.nn.module import Module from aitemplate.frontend.nn.parameter import Parameter diff --git a/python/aitemplate/frontend/nn/conv1d.py b/python/aitemplate/frontend/nn/conv1d.py index 1ce285bd2..655ede7b1 100644 --- a/python/aitemplate/frontend/nn/conv1d.py +++ b/python/aitemplate/frontend/nn/conv1d.py @@ -15,6 +15,7 @@ """ Conv1d Module. """ + from aitemplate.compiler.ops import conv2d, conv2d_bias, squeeze, unsqueeze from aitemplate.frontend import Tensor from aitemplate.frontend.nn.module import Module diff --git a/python/aitemplate/frontend/nn/conv2d/__init__.py b/python/aitemplate/frontend/nn/conv2d/__init__.py index 50d8b0dd2..f5dd3a7e0 100644 --- a/python/aitemplate/frontend/nn/conv2d/__init__.py +++ b/python/aitemplate/frontend/nn/conv2d/__init__.py @@ -16,6 +16,7 @@ """ modules for conv2d """ + from aitemplate.frontend.nn.conv2d.conv2d import Conv2d from aitemplate.frontend.nn.conv2d.conv2d_bias import Conv2dBias from aitemplate.frontend.nn.conv2d.conv2d_bias_add_hardswish import ( diff --git a/python/aitemplate/frontend/nn/conv2d/common_conv2d_bias_act.py b/python/aitemplate/frontend/nn/conv2d/common_conv2d_bias_act.py index 1a57137ad..09f5dc901 100644 --- a/python/aitemplate/frontend/nn/conv2d/common_conv2d_bias_act.py +++ b/python/aitemplate/frontend/nn/conv2d/common_conv2d_bias_act.py @@ -15,6 +15,7 @@ """ common module for conv_bias_act subgraph """ + from aitemplate.compiler import ops from aitemplate.frontend.nn.module import Module from aitemplate.frontend.nn.parameter import Parameter diff --git a/python/aitemplate/frontend/nn/conv2d/common_conv2d_bias_add_act.py b/python/aitemplate/frontend/nn/conv2d/common_conv2d_bias_add_act.py index 687a3e676..c2d01aec9 100644 --- a/python/aitemplate/frontend/nn/conv2d/common_conv2d_bias_add_act.py +++ b/python/aitemplate/frontend/nn/conv2d/common_conv2d_bias_add_act.py @@ -15,6 +15,7 @@ """ common module for conv2d bias act residual add """ + from aitemplate.compiler import ops from aitemplate.frontend.nn.module import Module from aitemplate.frontend.nn.parameter import Parameter diff --git a/python/aitemplate/frontend/nn/conv2d/conv2d.py b/python/aitemplate/frontend/nn/conv2d/conv2d.py index 1b78611cf..c823b19bb 100644 --- a/python/aitemplate/frontend/nn/conv2d/conv2d.py +++ b/python/aitemplate/frontend/nn/conv2d/conv2d.py @@ -15,6 +15,7 @@ """ conv2d Module. """ + from aitemplate.compiler.ops import conv2d from aitemplate.frontend.nn.module import Module from aitemplate.frontend.nn.parameter import Parameter diff --git a/python/aitemplate/frontend/nn/conv2d/conv2d_bias.py b/python/aitemplate/frontend/nn/conv2d/conv2d_bias.py index 2a1e0779e..dd5dcf567 100644 --- a/python/aitemplate/frontend/nn/conv2d/conv2d_bias.py +++ b/python/aitemplate/frontend/nn/conv2d/conv2d_bias.py @@ -15,6 +15,7 @@ """ conv2d bias module """ + from aitemplate.frontend.nn.conv2d.common_conv2d_bias_act import Conv2dBiasAct diff --git a/python/aitemplate/frontend/nn/conv2d/conv2d_bias_add_hardswish.py b/python/aitemplate/frontend/nn/conv2d/conv2d_bias_add_hardswish.py index 046f9b589..e132c639e 100644 --- a/python/aitemplate/frontend/nn/conv2d/conv2d_bias_add_hardswish.py +++ b/python/aitemplate/frontend/nn/conv2d/conv2d_bias_add_hardswish.py @@ -15,6 +15,7 @@ """ conv2d + bias + residual + hardswish """ + from aitemplate.frontend.nn.conv2d.common_conv2d_bias_add_act import Conv2dBiasAddAct diff --git a/python/aitemplate/frontend/nn/conv2d/conv2d_bias_add_relu.py b/python/aitemplate/frontend/nn/conv2d/conv2d_bias_add_relu.py index 99a779ab1..e16f15450 100644 --- a/python/aitemplate/frontend/nn/conv2d/conv2d_bias_add_relu.py +++ b/python/aitemplate/frontend/nn/conv2d/conv2d_bias_add_relu.py @@ -15,6 +15,7 @@ """ General template module for conv2e + bias + residual + relu """ + from aitemplate.frontend.nn.conv2d.common_conv2d_bias_add_act import Conv2dBiasAddAct diff --git a/python/aitemplate/frontend/nn/conv2d/conv2d_bias_few_channels.py b/python/aitemplate/frontend/nn/conv2d/conv2d_bias_few_channels.py index 36cb07963..17219b84f 100644 --- a/python/aitemplate/frontend/nn/conv2d/conv2d_bias_few_channels.py +++ b/python/aitemplate/frontend/nn/conv2d/conv2d_bias_few_channels.py @@ -15,6 +15,7 @@ """ conv2d bias for few channels """ + from aitemplate.frontend.nn.conv2d.special_conv2d_bias_act import SpecialConv2dBiasAct diff --git a/python/aitemplate/frontend/nn/conv2d/conv2d_bias_hardswish.py b/python/aitemplate/frontend/nn/conv2d/conv2d_bias_hardswish.py index 55662e4f6..e5d79e495 100644 --- a/python/aitemplate/frontend/nn/conv2d/conv2d_bias_hardswish.py +++ b/python/aitemplate/frontend/nn/conv2d/conv2d_bias_hardswish.py @@ -15,6 +15,7 @@ """ conv bias hardswish module """ + from aitemplate.frontend.nn.conv2d.common_conv2d_bias_act import Conv2dBiasAct diff --git a/python/aitemplate/frontend/nn/conv2d/conv2d_bias_hardswish_few_channels.py b/python/aitemplate/frontend/nn/conv2d/conv2d_bias_hardswish_few_channels.py index 8cf6c3033..f14b13c79 100644 --- a/python/aitemplate/frontend/nn/conv2d/conv2d_bias_hardswish_few_channels.py +++ b/python/aitemplate/frontend/nn/conv2d/conv2d_bias_hardswish_few_channels.py @@ -15,6 +15,7 @@ """ conv2d bias hardswish module for few channels """ + from aitemplate.frontend.nn.conv2d.special_conv2d_bias_act import SpecialConv2dBiasAct diff --git a/python/aitemplate/frontend/nn/conv2d/conv2d_bias_relu.py b/python/aitemplate/frontend/nn/conv2d/conv2d_bias_relu.py index 25e02abb9..67416700a 100644 --- a/python/aitemplate/frontend/nn/conv2d/conv2d_bias_relu.py +++ b/python/aitemplate/frontend/nn/conv2d/conv2d_bias_relu.py @@ -15,6 +15,7 @@ """ conv2d bias relu module """ + from aitemplate.frontend.nn.conv2d.common_conv2d_bias_act import Conv2dBiasAct diff --git a/python/aitemplate/frontend/nn/conv2d/conv2d_bias_relu_few_channels.py b/python/aitemplate/frontend/nn/conv2d/conv2d_bias_relu_few_channels.py index 56a2eb8fb..3ea2fcbdf 100644 --- a/python/aitemplate/frontend/nn/conv2d/conv2d_bias_relu_few_channels.py +++ b/python/aitemplate/frontend/nn/conv2d/conv2d_bias_relu_few_channels.py @@ -15,6 +15,7 @@ """ conv2d bias relu for few channels """ + from aitemplate.frontend.nn.conv2d.special_conv2d_bias_act import SpecialConv2dBiasAct diff --git a/python/aitemplate/frontend/nn/conv2d/conv2d_bias_sigmoid.py b/python/aitemplate/frontend/nn/conv2d/conv2d_bias_sigmoid.py index 65077f4c4..a0c97f1f1 100644 --- a/python/aitemplate/frontend/nn/conv2d/conv2d_bias_sigmoid.py +++ b/python/aitemplate/frontend/nn/conv2d/conv2d_bias_sigmoid.py @@ -15,6 +15,7 @@ """ conv2d bias sigmoid module """ + from aitemplate.frontend.nn.conv2d.common_conv2d_bias_act import Conv2dBiasAct diff --git a/python/aitemplate/frontend/nn/conv2d/conv2d_depthwise.py b/python/aitemplate/frontend/nn/conv2d/conv2d_depthwise.py index 6968c22e6..d521d04c8 100644 --- a/python/aitemplate/frontend/nn/conv2d/conv2d_depthwise.py +++ b/python/aitemplate/frontend/nn/conv2d/conv2d_depthwise.py @@ -15,6 +15,7 @@ """ conv2d depthwise module """ + from aitemplate.compiler.ops import conv2d_depthwise from aitemplate.frontend.nn.conv2d.conv2d import Conv2d diff --git a/python/aitemplate/frontend/nn/conv2d/conv2d_depthwise_bias.py b/python/aitemplate/frontend/nn/conv2d/conv2d_depthwise_bias.py index 129b491d4..ccd9843ac 100644 --- a/python/aitemplate/frontend/nn/conv2d/conv2d_depthwise_bias.py +++ b/python/aitemplate/frontend/nn/conv2d/conv2d_depthwise_bias.py @@ -15,6 +15,7 @@ """ conv2d depthwise bias module """ + from aitemplate.frontend.nn.conv2d.common_conv2d_bias_act import Conv2dBiasAct diff --git a/python/aitemplate/frontend/nn/conv2d/special_conv2d_bias_act.py b/python/aitemplate/frontend/nn/conv2d/special_conv2d_bias_act.py index d713908f9..b0383c98b 100644 --- a/python/aitemplate/frontend/nn/conv2d/special_conv2d_bias_act.py +++ b/python/aitemplate/frontend/nn/conv2d/special_conv2d_bias_act.py @@ -15,6 +15,7 @@ """ common module for conv_bias_act subgraph """ + from aitemplate.compiler import ops from aitemplate.frontend.nn.module import Module from aitemplate.frontend.nn.parameter import Parameter diff --git a/python/aitemplate/frontend/nn/conv2d/transposed_conv2d_bias_act.py b/python/aitemplate/frontend/nn/conv2d/transposed_conv2d_bias_act.py index 628932729..96041a0d5 100644 --- a/python/aitemplate/frontend/nn/conv2d/transposed_conv2d_bias_act.py +++ b/python/aitemplate/frontend/nn/conv2d/transposed_conv2d_bias_act.py @@ -15,6 +15,7 @@ """ common module for ConvTranspose2d_bias_act subgraph """ + from aitemplate.compiler import ops from aitemplate.frontend.nn.module import Module from aitemplate.frontend.nn.parameter import Parameter diff --git a/python/aitemplate/frontend/nn/conv2d/transposed_conv2d_bias_relu.py b/python/aitemplate/frontend/nn/conv2d/transposed_conv2d_bias_relu.py index 079ed7b57..a88fe7b26 100644 --- a/python/aitemplate/frontend/nn/conv2d/transposed_conv2d_bias_relu.py +++ b/python/aitemplate/frontend/nn/conv2d/transposed_conv2d_bias_relu.py @@ -15,6 +15,7 @@ """ conv2d bias relu module """ + from aitemplate.frontend.nn.conv2d.transposed_conv2d_bias_act import ( ConvTranspose2dBiasAct, ) diff --git a/python/aitemplate/frontend/nn/conv3d.py b/python/aitemplate/frontend/nn/conv3d.py index ea4256e46..743efe4af 100644 --- a/python/aitemplate/frontend/nn/conv3d.py +++ b/python/aitemplate/frontend/nn/conv3d.py @@ -15,6 +15,7 @@ """ conv3d Module. """ + from aitemplate.compiler.ops import conv3d, conv3d_bias, depthwise_conv3d from aitemplate.compiler.ops.padding.ndhwc3to8 import ndhwc3to8 from aitemplate.frontend.nn.module import Module diff --git a/python/aitemplate/frontend/nn/dropout.py b/python/aitemplate/frontend/nn/dropout.py index 91874de30..5dc2d31ec 100644 --- a/python/aitemplate/frontend/nn/dropout.py +++ b/python/aitemplate/frontend/nn/dropout.py @@ -13,6 +13,7 @@ # limitations under the License. # """Dropout/DropPath placeholder""" + from aitemplate.frontend.nn.module import Module # pylint: disable=C0103 diff --git a/python/aitemplate/frontend/nn/dual_gemm.py b/python/aitemplate/frontend/nn/dual_gemm.py index 17c84e5f7..e8147c185 100644 --- a/python/aitemplate/frontend/nn/dual_gemm.py +++ b/python/aitemplate/frontend/nn/dual_gemm.py @@ -15,6 +15,7 @@ """ Frontend for attention module """ + from aitemplate.compiler import ops from aitemplate.frontend.nn.linear import Linear from aitemplate.frontend.nn.module import Module diff --git a/python/aitemplate/frontend/nn/fpn_proposal.py b/python/aitemplate/frontend/nn/fpn_proposal.py index 3f4f12e8f..c0ff37278 100644 --- a/python/aitemplate/frontend/nn/fpn_proposal.py +++ b/python/aitemplate/frontend/nn/fpn_proposal.py @@ -15,6 +15,7 @@ """ FPNProposal module. """ + import numpy as np from aitemplate.compiler import ops diff --git a/python/aitemplate/frontend/nn/group_norm.py b/python/aitemplate/frontend/nn/group_norm.py index 4d93a3d06..91227d821 100644 --- a/python/aitemplate/frontend/nn/group_norm.py +++ b/python/aitemplate/frontend/nn/group_norm.py @@ -15,6 +15,7 @@ """ GroupNorm module """ + from aitemplate.compiler import ops from aitemplate.frontend.nn.module import Module from aitemplate.frontend.nn.parameter import Parameter diff --git a/python/aitemplate/frontend/nn/identity.py b/python/aitemplate/frontend/nn/identity.py index 31d1efb68..95f9207c3 100644 --- a/python/aitemplate/frontend/nn/identity.py +++ b/python/aitemplate/frontend/nn/identity.py @@ -15,6 +15,7 @@ """ Identity module. """ + from aitemplate.frontend.nn.module import Module # pylint: disable=C0103 diff --git a/python/aitemplate/frontend/nn/layer_norm.py b/python/aitemplate/frontend/nn/layer_norm.py index 90331baae..1ec9dc2b9 100644 --- a/python/aitemplate/frontend/nn/layer_norm.py +++ b/python/aitemplate/frontend/nn/layer_norm.py @@ -15,6 +15,7 @@ """ LayerNorm module. """ + from aitemplate.compiler import ops from aitemplate.frontend.nn.module import Module from aitemplate.frontend.nn.parameter import Parameter diff --git a/python/aitemplate/frontend/nn/linear.py b/python/aitemplate/frontend/nn/linear.py index a6a6e1793..5c702044b 100644 --- a/python/aitemplate/frontend/nn/linear.py +++ b/python/aitemplate/frontend/nn/linear.py @@ -15,6 +15,7 @@ """ Linear module. """ + from aitemplate.compiler import ops from aitemplate.frontend.nn.module import Module from aitemplate.frontend.nn.parameter import Parameter diff --git a/python/aitemplate/frontend/nn/padding.py b/python/aitemplate/frontend/nn/padding.py index dfdca6fa9..97056c763 100644 --- a/python/aitemplate/frontend/nn/padding.py +++ b/python/aitemplate/frontend/nn/padding.py @@ -15,6 +15,7 @@ """ Padding related modules. """ + from aitemplate.compiler.ops import ndhwc3to8, nhwc3to8 from aitemplate.frontend.nn.module import Module diff --git a/python/aitemplate/frontend/nn/parameter.py b/python/aitemplate/frontend/nn/parameter.py index 660dd65b6..6489b3e07 100644 --- a/python/aitemplate/frontend/nn/parameter.py +++ b/python/aitemplate/frontend/nn/parameter.py @@ -15,6 +15,7 @@ """ Parameter definition. """ + from aitemplate.compiler.base import Tensor diff --git a/python/aitemplate/frontend/nn/patch_embed.py b/python/aitemplate/frontend/nn/patch_embed.py index 0d2658128..1a8021718 100644 --- a/python/aitemplate/frontend/nn/patch_embed.py +++ b/python/aitemplate/frontend/nn/patch_embed.py @@ -15,6 +15,7 @@ """ patch_embed Module. """ + from typing import Callable, Tuple from aitemplate.compiler import ops diff --git a/python/aitemplate/frontend/nn/pool2d.py b/python/aitemplate/frontend/nn/pool2d.py index a1eb439c2..63951197d 100644 --- a/python/aitemplate/frontend/nn/pool2d.py +++ b/python/aitemplate/frontend/nn/pool2d.py @@ -15,6 +15,7 @@ """ pool2d-family modules. """ + from aitemplate.compiler.ops import avg_pool2d, max_pool2d from aitemplate.frontend.nn.module import Module diff --git a/python/aitemplate/frontend/nn/pool3d.py b/python/aitemplate/frontend/nn/pool3d.py index db663a3f0..d0393c03a 100644 --- a/python/aitemplate/frontend/nn/pool3d.py +++ b/python/aitemplate/frontend/nn/pool3d.py @@ -15,6 +15,7 @@ """ pool3d-family modules. """ + from aitemplate.compiler.ops import max_pool2d from aitemplate.compiler.ops.common import reshape from aitemplate.frontend.nn.module import Module diff --git a/python/aitemplate/frontend/nn/positional_encoding.py b/python/aitemplate/frontend/nn/positional_encoding.py index 830778a0d..58f1e8904 100644 --- a/python/aitemplate/frontend/nn/positional_encoding.py +++ b/python/aitemplate/frontend/nn/positional_encoding.py @@ -15,6 +15,7 @@ """ positional_encoding Modules. """ + import logging from typing import Tuple diff --git a/python/aitemplate/frontend/nn/proposal.py b/python/aitemplate/frontend/nn/proposal.py index 18b53f313..67f796a81 100644 --- a/python/aitemplate/frontend/nn/proposal.py +++ b/python/aitemplate/frontend/nn/proposal.py @@ -15,6 +15,7 @@ """ Proposal module. """ + import math import numpy as np diff --git a/python/aitemplate/frontend/nn/roi_ops.py b/python/aitemplate/frontend/nn/roi_ops.py index a3d17bbf2..4decf9803 100644 --- a/python/aitemplate/frontend/nn/roi_ops.py +++ b/python/aitemplate/frontend/nn/roi_ops.py @@ -15,6 +15,7 @@ """ RoiAlign-family modules. """ + from aitemplate.compiler.ops import multi_level_roi_align, roi_align from aitemplate.frontend.nn.module import Module diff --git a/python/aitemplate/frontend/nn/softmax.py b/python/aitemplate/frontend/nn/softmax.py index 2a8ff0e5a..327d32417 100644 --- a/python/aitemplate/frontend/nn/softmax.py +++ b/python/aitemplate/frontend/nn/softmax.py @@ -15,6 +15,7 @@ """ softmax Module. """ + from typing import Optional from aitemplate.compiler import ops diff --git a/python/aitemplate/frontend/nn/upsample.py b/python/aitemplate/frontend/nn/upsample.py index 619d97236..a593657ad 100644 --- a/python/aitemplate/frontend/nn/upsample.py +++ b/python/aitemplate/frontend/nn/upsample.py @@ -15,6 +15,7 @@ """ Unsampling2d module. """ + from aitemplate.compiler.ops import upsampling2d, upsampling2d_add from aitemplate.frontend.nn.module import Module diff --git a/python/aitemplate/frontend/nn/vanilla_attention.py b/python/aitemplate/frontend/nn/vanilla_attention.py index 7fe7f0377..61507dc7b 100644 --- a/python/aitemplate/frontend/nn/vanilla_attention.py +++ b/python/aitemplate/frontend/nn/vanilla_attention.py @@ -15,6 +15,7 @@ """ Frontend for vanilla attention module """ + from functools import partial from aitemplate.compiler import ops diff --git a/python/aitemplate/frontend/nn/view_ops.py b/python/aitemplate/frontend/nn/view_ops.py index 1406ed9bf..2858fe0e7 100644 --- a/python/aitemplate/frontend/nn/view_ops.py +++ b/python/aitemplate/frontend/nn/view_ops.py @@ -15,6 +15,7 @@ """ View-related modules. """ + from aitemplate.compiler.ops import flatten, reshape from aitemplate.frontend.nn.module import Module diff --git a/python/aitemplate/frontend/parameter.py b/python/aitemplate/frontend/parameter.py index 660dd65b6..6489b3e07 100644 --- a/python/aitemplate/frontend/parameter.py +++ b/python/aitemplate/frontend/parameter.py @@ -15,6 +15,7 @@ """ Parameter definition. """ + from aitemplate.compiler.base import Tensor diff --git a/python/aitemplate/testing/__init__.py b/python/aitemplate/testing/__init__.py index 746641f05..73fc6ec22 100644 --- a/python/aitemplate/testing/__init__.py +++ b/python/aitemplate/testing/__init__.py @@ -15,6 +15,7 @@ """ testing module """ + from aitemplate.testing import benchmark_ait, benchmark_pt from aitemplate.testing.detect_target import detect_target from aitemplate.testing.profile import profile_callable diff --git a/python/aitemplate/testing/benchmark_trt.py b/python/aitemplate/testing/benchmark_trt.py index ebd22e841..ecd3fd000 100644 --- a/python/aitemplate/testing/benchmark_trt.py +++ b/python/aitemplate/testing/benchmark_trt.py @@ -15,6 +15,7 @@ """ helper functions to benchmark fx-trt """ + from aitemplate.testing.benchmark_pt import benchmark_torch_function # usort:skip from torch_tensorrt.fx import lower from torch_tensorrt.fx.utils import LowerPrecision diff --git a/python/aitemplate/testing/detect_target.py b/python/aitemplate/testing/detect_target.py index 20fd561ae..37ec7ec27 100644 --- a/python/aitemplate/testing/detect_target.py +++ b/python/aitemplate/testing/detect_target.py @@ -15,6 +15,7 @@ """ Automatic detect target for testing """ + import logging import os from subprocess import PIPE, Popen diff --git a/python/aitemplate/testing/jagged_utils.py b/python/aitemplate/testing/jagged_utils.py index f8f4cb3d9..c4ea0cad7 100644 --- a/python/aitemplate/testing/jagged_utils.py +++ b/python/aitemplate/testing/jagged_utils.py @@ -370,9 +370,7 @@ def batched_dense_vec_jagged_2d_mul_ref( return torch.matmul( vectors.unsqueeze(dim=2), # [B, H, 1, N] padded_matrices.permute([0, 2, 1, 3]), # [B, H, N, D] - ).squeeze( - dim=2 - ) # [B, H, D] + ).squeeze(dim=2) # [B, H, D] def add_jagged_dense_ref( diff --git a/python/aitemplate/testing/profile.py b/python/aitemplate/testing/profile.py index 5b506def3..4d96374a8 100644 --- a/python/aitemplate/testing/profile.py +++ b/python/aitemplate/testing/profile.py @@ -15,6 +15,7 @@ """ Torch module profiling utility. """ + import logging from operator import itemgetter from typing import Callable, List, Tuple diff --git a/python/aitemplate/testing/test_utils.py b/python/aitemplate/testing/test_utils.py index fcd25bb34..aa3601a95 100644 --- a/python/aitemplate/testing/test_utils.py +++ b/python/aitemplate/testing/test_utils.py @@ -15,6 +15,7 @@ """ Utils for unit tests. """ + import contextlib import itertools import os diff --git a/python/aitemplate/utils/environ.py b/python/aitemplate/utils/environ.py index be21f4004..bf98ba97c 100644 --- a/python/aitemplate/utils/environ.py +++ b/python/aitemplate/utils/environ.py @@ -15,6 +15,7 @@ """ A common place for holding AIT-related env control variables """ + import logging import os from typing import Optional diff --git a/python/aitemplate/utils/io.py b/python/aitemplate/utils/io.py index f6dc0e1f7..b533fe191 100644 --- a/python/aitemplate/utils/io.py +++ b/python/aitemplate/utils/io.py @@ -15,6 +15,7 @@ """ Util functions to handle file or network io """ + import hashlib import logging import os diff --git a/python/aitemplate/utils/markdown_table.py b/python/aitemplate/utils/markdown_table.py index 9bdbb9bea..89a6f8505 100644 --- a/python/aitemplate/utils/markdown_table.py +++ b/python/aitemplate/utils/markdown_table.py @@ -18,6 +18,7 @@ Original Project: https://github.com/hvalev/markdownTable Accessed: Jul 16, 2022 """ + import math diff --git a/python/aitemplate/utils/misc.py b/python/aitemplate/utils/misc.py index 5ad5d26fc..2291eef7b 100644 --- a/python/aitemplate/utils/misc.py +++ b/python/aitemplate/utils/misc.py @@ -15,6 +15,7 @@ """ miscellaneous utilities """ + import hashlib import logging import os diff --git a/python/aitemplate/utils/mk_cutlass_lib/extra_conv_emit.py b/python/aitemplate/utils/mk_cutlass_lib/extra_conv_emit.py index 8732e4ac4..77fbdc75e 100644 --- a/python/aitemplate/utils/mk_cutlass_lib/extra_conv_emit.py +++ b/python/aitemplate/utils/mk_cutlass_lib/extra_conv_emit.py @@ -15,6 +15,7 @@ """ Extra cutlass enum, mainly for epilogue """ + import jinja2 CONV_TEMPLATE = jinja2.Template( diff --git a/python/aitemplate/utils/mk_cutlass_lib/extra_cutlass_generator.py b/python/aitemplate/utils/mk_cutlass_lib/extra_cutlass_generator.py index a684dcc5b..21ee9c435 100644 --- a/python/aitemplate/utils/mk_cutlass_lib/extra_cutlass_generator.py +++ b/python/aitemplate/utils/mk_cutlass_lib/extra_cutlass_generator.py @@ -15,6 +15,7 @@ """ Extra cutlass tiling configs for special problems """ + import jinja2 SRC_TEMPLATE = jinja2.Template( diff --git a/python/aitemplate/utils/mk_cutlass_lib/extra_enum.py b/python/aitemplate/utils/mk_cutlass_lib/extra_enum.py index aa674561b..2d9c36be5 100644 --- a/python/aitemplate/utils/mk_cutlass_lib/extra_enum.py +++ b/python/aitemplate/utils/mk_cutlass_lib/extra_enum.py @@ -15,6 +15,7 @@ """ Extra cutlass enum, mainly for epilogue """ + import jinja2 SRC_TEMPLATE = jinja2.Template( diff --git a/python/aitemplate/utils/mk_cutlass_lib/extra_gemm_emit.py b/python/aitemplate/utils/mk_cutlass_lib/extra_gemm_emit.py index 93ce36765..0da4073d3 100644 --- a/python/aitemplate/utils/mk_cutlass_lib/extra_gemm_emit.py +++ b/python/aitemplate/utils/mk_cutlass_lib/extra_gemm_emit.py @@ -15,6 +15,7 @@ """ Extra cutlass enum, mainly for epilogue """ + import jinja2 diff --git a/python/aitemplate/utils/serialization/serdes_code.py b/python/aitemplate/utils/serialization/serdes_code.py index e2c1329bc..92541f00e 100644 --- a/python/aitemplate/utils/serialization/serdes_code.py +++ b/python/aitemplate/utils/serialization/serdes_code.py @@ -15,6 +15,7 @@ """ Dump/Read sorted_graph to/from python code. """ + import copy import logging import os diff --git a/python/aitemplate/utils/visualization/plot.py b/python/aitemplate/utils/visualization/plot.py index d0c21999a..ee35c3c93 100644 --- a/python/aitemplate/utils/visualization/plot.py +++ b/python/aitemplate/utils/visualization/plot.py @@ -15,6 +15,7 @@ """ Graph visualization tool for AITemplate """ + import json import os diff --git a/python/aitemplate/utils/visualization/pydot.py b/python/aitemplate/utils/visualization/pydot.py index e580fc611..d671d31f6 100644 --- a/python/aitemplate/utils/visualization/pydot.py +++ b/python/aitemplate/utils/visualization/pydot.py @@ -17,6 +17,7 @@ Original Project: https://github.com/pydot/pydot Accessed: Jul 25, 2022 """ + import copy import errno import io diff --git a/tests/lint/check_meta_header.py b/tests/lint/check_meta_header.py index da385fa58..2f7dea7a1 100644 --- a/tests/lint/check_meta_header.py +++ b/tests/lint/check_meta_header.py @@ -12,8 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # -"""Check Python source code contains Meta copyright header -""" +"""Check Python source code contains Meta copyright header""" from __future__ import annotations diff --git a/tests/unittest/compiler/test_fuse_bmm_permute.py b/tests/unittest/compiler/test_fuse_bmm_permute.py index 5360f4082..cdbca31e4 100644 --- a/tests/unittest/compiler/test_fuse_bmm_permute.py +++ b/tests/unittest/compiler/test_fuse_bmm_permute.py @@ -67,7 +67,6 @@ def _test_bmm_permute( orig_layout: str, dtype: str = "float16", ): - is_row_major_a = orig_layout[0] == "r" is_row_major_b = orig_layout[1] == "r" is_row_major_c = orig_layout[2] == "r" diff --git a/tests/unittest/compiler/test_fused_elementwise_complex_dependency.py b/tests/unittest/compiler/test_fused_elementwise_complex_dependency.py index 4ef635e7e..9f3fd1c34 100644 --- a/tests/unittest/compiler/test_fused_elementwise_complex_dependency.py +++ b/tests/unittest/compiler/test_fused_elementwise_complex_dependency.py @@ -15,6 +15,7 @@ """ Unittests for elementwise fusion with complex dependencies. """ + import unittest import torch diff --git a/tests/unittest/compiler/test_fused_elementwise_out_of_order.py b/tests/unittest/compiler/test_fused_elementwise_out_of_order.py index 011226c67..e74229b9c 100644 --- a/tests/unittest/compiler/test_fused_elementwise_out_of_order.py +++ b/tests/unittest/compiler/test_fused_elementwise_out_of_order.py @@ -15,6 +15,7 @@ """ Unittests for elementwise fusion out-of-order cases. """ + import unittest import torch diff --git a/tests/unittest/compiler/test_strided_layernorm.py b/tests/unittest/compiler/test_strided_layernorm.py index 160145f8b..ff1cc394d 100644 --- a/tests/unittest/compiler/test_strided_layernorm.py +++ b/tests/unittest/compiler/test_strided_layernorm.py @@ -158,7 +158,6 @@ def _test_slice_layer_norm( test_name="test_slice_layer_norm", use_welford_algorithm=False, ): - input_rank = 1 + len(input_nonbatch_shape) if 1 == len(start_indices) and len(start_indices) != input_rank: start_indices = [start_indices[0]] * input_rank diff --git a/tests/unittest/ops/test_activation.py b/tests/unittest/ops/test_activation.py index 14e1c0c3b..daaa8e806 100644 --- a/tests/unittest/ops/test_activation.py +++ b/tests/unittest/ops/test_activation.py @@ -15,6 +15,7 @@ """ Unittests for special activation Operator. """ + import unittest import torch diff --git a/tests/unittest/ops/test_argmax.py b/tests/unittest/ops/test_argmax.py index 3aa93c604..446e89078 100644 --- a/tests/unittest/ops/test_argmax.py +++ b/tests/unittest/ops/test_argmax.py @@ -15,6 +15,7 @@ """ Unittests for argmax Operator. """ + import unittest import torch diff --git a/tests/unittest/ops/test_argmax_sm80.py b/tests/unittest/ops/test_argmax_sm80.py index aeb4e498c..d30ddfbdf 100644 --- a/tests/unittest/ops/test_argmax_sm80.py +++ b/tests/unittest/ops/test_argmax_sm80.py @@ -15,6 +15,7 @@ """ Unittests for argmax Operator. """ + import unittest import torch diff --git a/tests/unittest/ops/test_attention.py b/tests/unittest/ops/test_attention.py index 18010ed99..ad62091e4 100644 --- a/tests/unittest/ops/test_attention.py +++ b/tests/unittest/ops/test_attention.py @@ -15,6 +15,7 @@ """ Unittests for flash_attention Operator. """ + import itertools import logging import math diff --git a/tests/unittest/ops/test_b2b_bmm.py b/tests/unittest/ops/test_b2b_bmm.py index c645ab65b..7cd621cff 100644 --- a/tests/unittest/ops/test_b2b_bmm.py +++ b/tests/unittest/ops/test_b2b_bmm.py @@ -15,6 +15,7 @@ """ Unittests for b2b bmm Operators. """ + import itertools import logging import unittest diff --git a/tests/unittest/ops/test_batch_gather.py b/tests/unittest/ops/test_batch_gather.py index a62484130..6fb5632b7 100644 --- a/tests/unittest/ops/test_batch_gather.py +++ b/tests/unittest/ops/test_batch_gather.py @@ -15,6 +15,7 @@ """ Unittests for batch_gather Operator. """ + import unittest import torch @@ -46,7 +47,6 @@ def _test_batch_gather( test_name="gather", dtype="float16", ): - in_shape = shape o_shape = list(in_shape) @@ -133,7 +133,6 @@ def _test_batch_gather_topk( test_name="topk", dtype="float16", ): - m_shape = (N,) + shape n_shape = (topK,) + shape diff --git a/tests/unittest/ops/test_batched_dense_vec_jagged_2d_mul.py b/tests/unittest/ops/test_batched_dense_vec_jagged_2d_mul.py index b07d5c7cd..c5bb0d77f 100644 --- a/tests/unittest/ops/test_batched_dense_vec_jagged_2d_mul.py +++ b/tests/unittest/ops/test_batched_dense_vec_jagged_2d_mul.py @@ -15,6 +15,7 @@ """ Unittests for batched_dense_vec_jagged_2d_mul Operator. """ + import unittest from typing import List diff --git a/tests/unittest/ops/test_efficient_nms.py b/tests/unittest/ops/test_efficient_nms.py index e1e63b3c7..19d8bb595 100644 --- a/tests/unittest/ops/test_efficient_nms.py +++ b/tests/unittest/ops/test_efficient_nms.py @@ -15,6 +15,7 @@ """ Unittests for nms Operator. """ + import os import shutil import unittest diff --git a/tests/unittest/ops/test_fused_elementwise_broadcast.py b/tests/unittest/ops/test_fused_elementwise_broadcast.py index d0432bc46..62e304de6 100644 --- a/tests/unittest/ops/test_fused_elementwise_broadcast.py +++ b/tests/unittest/ops/test_fused_elementwise_broadcast.py @@ -15,6 +15,7 @@ """ Unittests for fused_elementwise broadcast. """ + import itertools import unittest diff --git a/tests/unittest/ops/test_fused_elementwise_with_strided_outputs.py b/tests/unittest/ops/test_fused_elementwise_with_strided_outputs.py index b5398eb38..e430077c8 100644 --- a/tests/unittest/ops/test_fused_elementwise_with_strided_outputs.py +++ b/tests/unittest/ops/test_fused_elementwise_with_strided_outputs.py @@ -15,6 +15,7 @@ """ Unittests for fused_elementwise Operator with strided outputs. """ + import unittest from typing import List diff --git a/tests/unittest/ops/test_grouped_b2b_bmm.py b/tests/unittest/ops/test_grouped_b2b_bmm.py index f552e5937..5f88ca665 100644 --- a/tests/unittest/ops/test_grouped_b2b_bmm.py +++ b/tests/unittest/ops/test_grouped_b2b_bmm.py @@ -15,6 +15,7 @@ """ Unittests for grouped b2b bmm Operators. """ + import itertools import logging import os diff --git a/tests/unittest/ops/test_grouped_classic_b2b_bmm.py b/tests/unittest/ops/test_grouped_classic_b2b_bmm.py index b928ea01d..30e517d0e 100644 --- a/tests/unittest/ops/test_grouped_classic_b2b_bmm.py +++ b/tests/unittest/ops/test_grouped_classic_b2b_bmm.py @@ -15,6 +15,7 @@ """ Unittests for grouped b2b bmm Operators. """ + import logging import os diff --git a/tests/unittest/ops/test_groupnorm.py b/tests/unittest/ops/test_groupnorm.py index fb65ddfd1..754993779 100644 --- a/tests/unittest/ops/test_groupnorm.py +++ b/tests/unittest/ops/test_groupnorm.py @@ -15,6 +15,7 @@ """ Unittests for group norm Operator. """ + import logging import unittest diff --git a/tests/unittest/ops/test_index_select.py b/tests/unittest/ops/test_index_select.py index c404ab9af..54a022a5d 100644 --- a/tests/unittest/ops/test_index_select.py +++ b/tests/unittest/ops/test_index_select.py @@ -15,6 +15,7 @@ """ Unittests for masked_select Operator. """ + import logging import random import unittest @@ -57,7 +58,6 @@ def _test_index_select( benchmark=False, dim_idxs=None, ): - X1 = Tensor( shape=shape if x_shape is None else x_shape, dtype=dtype, diff --git a/tests/unittest/ops/test_layernorm.py b/tests/unittest/ops/test_layernorm.py index 583f55fa6..a79a9a13c 100644 --- a/tests/unittest/ops/test_layernorm.py +++ b/tests/unittest/ops/test_layernorm.py @@ -15,6 +15,7 @@ """ Unittests for LayerNorm Operator. """ + import logging import unittest diff --git a/tests/unittest/ops/test_layernorm_sigmoid_mul.py b/tests/unittest/ops/test_layernorm_sigmoid_mul.py index 0d41ff55d..b0ddd1b08 100644 --- a/tests/unittest/ops/test_layernorm_sigmoid_mul.py +++ b/tests/unittest/ops/test_layernorm_sigmoid_mul.py @@ -15,6 +15,7 @@ """ Unittests for FusedLayernormSigmoidMul Operator. """ + import logging import unittest @@ -130,7 +131,7 @@ def _test_fused_layernorm_sigmoid_mul( inputs["beta"] = x3_pt x6 = torch.empty_like(x6_pt) module.run_with_tensors(inputs, [x6]) - torch.testing.assert_close(x6, x6_pt, atol=atol, rtol=rtol), + (torch.testing.assert_close(x6, x6_pt, atol=atol, rtol=rtol),) def test_fused_layernorm_sigmoid_mul_fp16(self): for eps in (1e-5, 1e-1): diff --git a/tests/unittest/ops/test_masked_select.py b/tests/unittest/ops/test_masked_select.py index 45023df0a..33263726c 100644 --- a/tests/unittest/ops/test_masked_select.py +++ b/tests/unittest/ops/test_masked_select.py @@ -15,6 +15,7 @@ """ Unittests for masked_select Operator. """ + import unittest import torch diff --git a/tests/unittest/ops/test_nms.py b/tests/unittest/ops/test_nms.py index 9300ff1ca..822dfd194 100644 --- a/tests/unittest/ops/test_nms.py +++ b/tests/unittest/ops/test_nms.py @@ -15,6 +15,7 @@ """ Unittests for nms Operator. """ + import unittest from unittest import skipIf diff --git a/tests/unittest/ops/test_perm102_bmm_rcr.py b/tests/unittest/ops/test_perm102_bmm_rcr.py index 39c306704..403e42410 100644 --- a/tests/unittest/ops/test_perm102_bmm_rcr.py +++ b/tests/unittest/ops/test_perm102_bmm_rcr.py @@ -20,7 +20,6 @@ # self._1085_1133, _2905_2929, self._1084_1132) # baddbmm(bias, X, W) """ - import unittest import torch diff --git a/tests/unittest/ops/test_perm102_bmm_rrr.py b/tests/unittest/ops/test_perm102_bmm_rrr.py index e8851b56c..65b358209 100644 --- a/tests/unittest/ops/test_perm102_bmm_rrr.py +++ b/tests/unittest/ops/test_perm102_bmm_rrr.py @@ -20,7 +20,6 @@ # self._1085_1133, _2905_2929, self._1084_1132) # baddbmm(bias, X, W) """ - import unittest import torch diff --git a/tests/unittest/ops/test_softmax.py b/tests/unittest/ops/test_softmax.py index c4fd47eda..7d7216d92 100644 --- a/tests/unittest/ops/test_softmax.py +++ b/tests/unittest/ops/test_softmax.py @@ -15,6 +15,7 @@ """ Unittests for LayerNorm Operator. """ + import json import math import tempfile diff --git a/tests/unittest/ops/test_topk.py b/tests/unittest/ops/test_topk.py index 787aa719c..2d5eedb6b 100644 --- a/tests/unittest/ops/test_topk.py +++ b/tests/unittest/ops/test_topk.py @@ -15,6 +15,7 @@ """ Unittests for topk Operator. """ + import unittest import numpy as np @@ -46,7 +47,6 @@ def _test_topk( copy_op=False, dtype="float16", ): - o_shape = list(shape) o_shape[-1] = topK diff --git a/tests/unittest/ops/test_vanilla_attention.py b/tests/unittest/ops/test_vanilla_attention.py index 63149aa45..a289a81db 100644 --- a/tests/unittest/ops/test_vanilla_attention.py +++ b/tests/unittest/ops/test_vanilla_attention.py @@ -15,6 +15,7 @@ """ Unittests for vanilla_attention. """ + import logging import math import os diff --git a/tests/unittest/test_stable_set.py b/tests/unittest/test_stable_set.py index 3b5b92342..9fd972aeb 100644 --- a/tests/unittest/test_stable_set.py +++ b/tests/unittest/test_stable_set.py @@ -15,6 +15,7 @@ """ Unittests for StableSet. """ + import unittest from aitemplate.compiler.stable_set import StableSet diff --git a/tests/unittest/util/test_debug_utils.py b/tests/unittest/util/test_debug_utils.py index a86059d91..49c6d1eb1 100644 --- a/tests/unittest/util/test_debug_utils.py +++ b/tests/unittest/util/test_debug_utils.py @@ -15,6 +15,7 @@ """ Unittests for debug utils. """ + import numpy as np import pytest diff --git a/tests/unittest/util/test_serdes.py b/tests/unittest/util/test_serdes.py index 284391a72..ab32f972c 100644 --- a/tests/unittest/util/test_serdes.py +++ b/tests/unittest/util/test_serdes.py @@ -15,6 +15,7 @@ """ Unittests for special activation Operator. """ + import logging import unittest