diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..81b19d3 --- /dev/null +++ b/.gitignore @@ -0,0 +1,197 @@ +# Prerequisites +*.d + +# Compiled Object files +*.slo +*.lo +*.o +*.obj + +# Precompiled Headers +*.gch +*.pch + +# Compiled Dynamic libraries +*.so +*.dylib +*.dll + +# Fortran module files +*.mod +*.smod + +# Compiled Static libraries +*.lai +*.la +*.a +*.lib + +# Executables +*.exe +*.out +*.app + +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ +.pytest_cache/ +cover/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +.pybuilder/ +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# pyenv +# For a library or package, you might want to ignore these files since the code is +# intended to run in multiple environments; otherwise, check them in: +# .python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. +# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. +#Pipfile.lock + +# poetry +# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. +# This is especially recommended for binary packages to ensure reproducibility, and is more +# commonly ignored for libraries. +# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control +#poetry.lock + +# PEP 582; used by e.g. github.com/David-OConnor/pyflow +__pypackages__/ + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ + +# pytype static type analyzer +.pytype/ + +# Cython debug symbols +cython_debug/ + +# PyCharm +# JetBrains specific template is maintainted in a separate JetBrains.gitignore that can +# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore +# and can be added to the global gitignore or merged into this file. For a more nuclear +# option (not recommended) you can uncomment the following to ignore the entire idea folder. +.idea/ + +# pycharm line profiler result +**/*.pclprof + +outs +build-release + +# nsight compute report +*.ncu-rep + +# cache for kernel binary +.hidet_cache/ diff --git a/.gitmodules b/.gitmodules new file mode 100644 index 0000000..3159b95 --- /dev/null +++ b/.gitmodules @@ -0,0 +1,6 @@ +[submodule "3rdparty/cutlass"] + path = 3rdparty/cutlass + url = git@github.com:NVIDIA/cutlass.git +[submodule "3rdparty/tvm"] + path = 3rdparty/tvm + url = git@github.com:apache/tvm.git diff --git a/3rdparty/tvm b/3rdparty/tvm new file mode 160000 index 0000000..c07a463 --- /dev/null +++ b/3rdparty/tvm @@ -0,0 +1 @@ +Subproject commit c07a46327c86fc541297ebb985cc9c1dcef5a0db diff --git a/CMakeLists.txt b/CMakeLists.txt new file mode 100644 index 0000000..24ac0cd --- /dev/null +++ b/CMakeLists.txt @@ -0,0 +1,65 @@ +cmake_minimum_required(VERSION 3.19) + +project(hidet C CXX CUDA) + +# common configs +set(CMAKE_C_COMPILER_LAUNCHER ccache) +set(CMAKE_CXX_COMPILER_LAUNCHER ccache) +set(CMAKE_CUDA_COMPILER_LAUNCHER ccache) + +# submodules +include(cmake/TVM.cmake) + +# config hidet +if(EXISTS "${CMAKE_BINARY_DIR}/config.cmake") + include(${CMAKE_BINARY_DIR}/config.cmake) +else() + include(${CMAKE_SOURCE_DIR}/config.cmake) +endif() + +set(CMAKE_BUILD_TYPE ${HIDET_BUILD_TYPE}) +message(STATUS "Build type: ${CMAKE_BUILD_TYPE}") + +# add runtime target +add_library(hidet_runtime SHARED + src/hidet/runtime/cuda_context.cpp + ) +target_include_directories(hidet_runtime PRIVATE + ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES} + ${CMAKE_SOURCE_DIR}/include + /usr/include + ) +set_target_properties(hidet_runtime PROPERTIES + CUDA_RUNTIME_LIBRARY SHARED + CUDA_ARCHITECTURES ${HIDET_CUDA_ARCH} + LIBRARY_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/lib + ) + +# add main target +add_library(hidet SHARED + src/hidet/packedfunc.cpp + src/hidet/logging.cpp + src/hidet/cuda_api.cpp + src/hidet/cuda_kernels.cu + ) + +target_include_directories(hidet PRIVATE + ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES} + ${CMAKE_SOURCE_DIR}/include + ) + +target_link_directories(hidet PRIVATE ${CMAKE_CUDA_IMPLICIT_LINK_DIRECTORIES}) + +target_link_libraries(hidet cudart cublas curand) +target_link_libraries(hidet "-Wl,--no-as-needed" hidet_runtime) + +set_target_properties(hidet PROPERTIES + CUDA_RUNTIME_LIBRARY SHARED + CUDA_ARCHITECTURES ${HIDET_CUDA_ARCH} + LIBRARY_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/lib + ) + +# add -lineinfo option to nvcc, allowing us to get the source code from binary +# do not influence optimization, can be used in nsight compute profiling +target_compile_options(hidet PRIVATE $<$:-lineinfo>) + diff --git a/README.md b/README.md new file mode 100644 index 0000000..36cd419 --- /dev/null +++ b/README.md @@ -0,0 +1,146 @@ +# ASPLOS 2023 Artifact Evaluation + +This repository contains the artifacts for the paper + +"Hidet: Task Mapping Programming Paradigm for Deep Learning Tensor Programs". + +## Installation + +### Requirements + +We did experiment on the following hardware platform + +- CPU: Intel Core i9-12900K +- GPU: NVIDIA GeForce RTX 3090 (one with 420 Watt TDP) +- Memory: 64 GiB + +Other workstation equipped with a modern NVIDIA GPU should also be able to run the experiments. + +On the software side, we require the following software to be installed + +- cmake 3.19+ +- llvm (required by TVM, we used llvm-10) +- ccache (used to accelerate duplicated compilation) + +### NVIDIA CUDA Toolkit + +Please follow https://developer.nvidia.com/cuda-downloads guide to install the CUDA toolkit. + +We used NVIDIA Driver 510.73.08 and CUDA 11.6 for our experiments. The newer versions of CUDA should also work. + +Please run the following commands to check whether the NVIDIA Driver and CUDA toolkit are installed correctly. + +```bash +nvidia-smi +nvcc --version +``` + +### Install Hidet and baselines + +```bash +# clone hidet repository +git clone git@github.com:yaoyaoding/hidet +cd hidet +git checkout artifact +git submodule init +git submodule update --recursive --init + +# install the dependencies of hidet and the baselines (e.g., TensorRT, PyTorch, Onnx Runtime) +# the versions of baselines are specified in requirements.txt file. +pip3 install -r requirements.txt + +# build hidet and tvm +mkdir build +cd build +cmake .. +make -j8 +cd .. + +# there should be four dynamic libraries in the build/lib directory: +# libtvm.so, libtvm_runtime.so, libhidet.so, libhidet_runtime.so +ls build/lib + +# set the environment variables, it is recommended to set them in your .bashrc file +export HIDET_HOME=`pwd` +export PYTHONPATH=$PYTHONPATH:$HIDET_HOME/python:$HIDET_HOME/artifacts +export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:$HIDET_HOME/build/lib + +# test whether hidet has been installed successfully +python3 -c "import hidet" +python3 -c "import artifact" +``` + +# Run the experiments + +This artifact contains all the experiments in the evaluation section of the paper: + +- Experiment 0 (section 6.1): End to end performance comparison +- Experiment 1 (section 6.2.1): Schedule space comparison +- Experiment 2 (section 6.2.2): Performance sensitivity over input sizes +- Experiment 3 (section 6.2.3): Evaluation on different batch sizes +- Experiment 4 (section 6.2.4): Post-scheduling fusion evaluation +- Experiment 5 (section 6.2.5): Comparison with TensorRT + +The 6 experiments are organized in the `artifacts` directory. Each experiment corresponds to a directory with a `main.py` script. +Directly run the `main.py` script to launch corresponding experiments. We will automatically cache the optimized operator and models in `.hidet_cache` directory, thus you can stop and restart the experiments at any time. The second run of the experiments will be much faster. + +It will take tens of hours to finish all experiments. Most of the time is spent by autotvm and ansor schedulers. If you want to first skip the experiments related to autotvm and ansor, you can comment the line `--exec ansor` and `--exec autotvm` in each `main.py` script. + +```bash +cd artifacts + +# the following environment variables allow TVM to use all the cores of the machine +# if your machine has a different number of threads other than 24, change the value of TVM_NUM_THREADS accordingly +export TVM_BIND_THREADS=0 +export TVM_NUM_THREADS=24 + +python3 0_end_to_end/main.py +python3 1_latency_distribution/main.py +python3 2_input_sensitivity/main.py +python3 3_batch_size/main.py +python3 4_prolouge_epilogue_fusion/main.py +python3 5_tensorrt/main.py +``` +(we store above instructions in `run.sh`, you can run `bash run.sh` to run all experiments). + +Each script would have outputs like +```text + BatchSize Model Executor Config Space Latency Std Error + 1 resnet50 hidet sp2_simt_f32_f32_pk_default 2 1.184 0.000 0.000 + + BatchSize Model Executor Config Space Latency Std Error + 1 inception_v3 hidet sp2_simt_f32_f32_pk_default 2 1.722 0.005 0.000 + + BatchSize Model Executor Config Space Latency Std Error + 1 mobilenet_v2 hidet sp2_simt_f32_f32_pk_default 2 0.337 0.001 0.000 + + BatchSize Model Executor Config Space Latency Std Error + 1 bert hidet sp2_simt_f32_f32_pk_default 2 2.378 0.029 0.000 + + BatchSize Model Executor Config Space Latency Std Error + 1 gpt2 hidet sp2_simt_f32_f32_pk_default 2 2.608 0.054 0.000 +``` + +You may also see output log like +```text +Compiling task avg_pool2d_rearrange_rearrange_reshape_rearrange... +Compiling task matmul_reshape... +100%|█████████████████████████████████████████████████████████| 177/177 [00:45<00:00, 3.86it/s] +``` +This indicates that hidet is compiling the kernel for each operator. The progress bar indicates hidet is tuning a kernel. The compilation and tuning results will be cached in `hidet/.hidet_cache` directory. The subsequent runs will reuse the cached results. Feel free to ignore these logs. + +The 8 columns in the output correspond to +- batch size, +- model name, +- executor, + - PyTorch: torch + - Onnx Runtime: ort + - AutoTVM: autotvm + - Ansor: ansor + - TensorRT: tensorrt + - Hidet: hidet +- config for executor, +- search space of hidet (please ignore this column for other executor), +- end to end latency in milliseconds, +- standard deviation of latency, +- and the output error (see `hidet.utils.py.error_tolerance` for the definition) compared with onnx runtime with cpu backend. We compared the output to make sure the inference results are correct for each executor. diff --git a/artifacts/0_end2end/main.py b/artifacts/0_end2end/main.py new file mode 100644 index 0000000..a58ecd8 --- /dev/null +++ b/artifacts/0_end2end/main.py @@ -0,0 +1,24 @@ +from artifact import bench +import hidet + + +def main(): + for executor in [ + '--exec torch', + '--exec ort', + '--exec autotvm', + '--exec ansor', + '--exec hidet', + ]: + for model in [ + '--model resnet50', + '--model inception_v3', + '--model mobilenet_v2', + '--model bert', + '--model gpt2' + ]: + bench('{} {}'.format(executor, model)) + + +if __name__ == '__main__': + main() diff --git a/artifacts/0_end2end/output.txt b/artifacts/0_end2end/output.txt new file mode 100644 index 0000000..6d03076 --- /dev/null +++ b/artifacts/0_end2end/output.txt @@ -0,0 +1 @@ +args Namespace(ansor_trial=80, autotvm_trial=100, bert_hidden_size=768, bert_seq_length=128, bert_vocab_size=30522, bs=1, disable_graph_cache=False, exec='hidet', hidet_space=2, mma='simt', model='resnet50', number=10, ort_provider='cuda', out_dir='./results/', parallel_k='default', precision='f32', reduce_precision='f32', repeat=10, trt_fp16=False, trt_tf32=False, warmup=10) time stamp 1665522974.8277862 diff --git a/artifacts/1_latency_distribution/main.py b/artifacts/1_latency_distribution/main.py new file mode 100644 index 0000000..7938d83 --- /dev/null +++ b/artifacts/1_latency_distribution/main.py @@ -0,0 +1,19 @@ +from artifact import bench + + +def main(): + for executor in [ + '--exec hidet', + '--exec ansor', + '--exec autotvm', + ]: + for model in [ + # a conv-bn-relu subgraph in resnet50 with conv2d: + # input: [1, 256, 28, 28], weight: [256, 256, 3, 3], padding: 1, stride: 2 + '--model op_resnet50_conv_2', + ]: + bench('{} {}'.format(executor, model)) + + +if __name__ == '__main__': + main() diff --git a/artifacts/2_input_sensitivity/main.py b/artifacts/2_input_sensitivity/main.py new file mode 100644 index 0000000..3b5cd79 --- /dev/null +++ b/artifacts/2_input_sensitivity/main.py @@ -0,0 +1,29 @@ +from artifact import bench + + +def main(): + for executor in [ + '--exec ansor', + '--exec autotvm', + '--exec hidet', + ]: + for model in [ + '--model op_matmul_nn_4', # 2048x2048x2048 + '--model op_matmul_nn_5', # 2047x2047x2047 + '--model op_matmul_nn_6', # 2046x2046x2046 + '--model op_matmul_nn_7', # 2045x2045x2045 + '--model op_matmul_nn_8', # 2044x2044x2044 + '--model op_matmul_nn_9', # 2043x2043x2043 + '--model op_matmul_nn_10', # 2042x2042x2042 + '--model op_matmul_nn_11', # 2041x2041x2041 + ]: + if '11' in model and ('ansor' in executor or 'autotvm' in executor): + # both schedulers failed to find a schedule for this input 2041x2041x2041, skip + # for autotvm, it will fall back to a default schedule. + # for ansor, it will fall into a dead loop. + continue + bench('{} {}'.format(executor, model)) + + +if __name__ == '__main__': + main() diff --git a/artifacts/3_batch_size/main.py b/artifacts/3_batch_size/main.py new file mode 100644 index 0000000..716d1ad --- /dev/null +++ b/artifacts/3_batch_size/main.py @@ -0,0 +1,22 @@ +from artifact import bench + + +def main(): + for executor in [ + '--exec torch', + '--exec ort', + '--exec autotvm', + '--exec ansor', + '--exec hidet', + ]: + for bs in [ + '--bs 1', + '--bs 4', + '--bs 8', + ]: + model = '--model resnet50' + bench('{} {} {}'.format(executor, bs, model)) + + +if __name__ == '__main__': + main() diff --git a/artifacts/4_prologue_epilogue_fusion/main.py b/artifacts/4_prologue_epilogue_fusion/main.py new file mode 100644 index 0000000..e88084d --- /dev/null +++ b/artifacts/4_prologue_epilogue_fusion/main.py @@ -0,0 +1,39 @@ +from artifact import bench + + +def main(): + for executor in [ + '--exec ort', + '--exec ansor', + '--exec hidet', + ]: + for model in [ + '--model resnet50_conv_0', + '--model resnet50_conv_1', + '--model resnet50_conv_2', + '--model resnet50_conv_3', + '--model resnet50_conv_4', + '--model resnet50_conv_5', + '--model resnet50_conv_6', + '--model resnet50_conv_7', + '--model resnet50_conv_8', + '--model resnet50_conv_9', + '--model resnet50_conv_10', + '--model resnet50_conv_11', + '--model resnet50_conv_12', + '--model resnet50_conv_13', + '--model resnet50_conv_14', + '--model resnet50_conv_15', + '--model resnet50_conv_16', + '--model resnet50_conv_17', + '--model resnet50_conv_18', + '--model resnet50_conv_19', + '--model resnet50_conv_20', + '--model resnet50_conv_21', + '--model resnet50_conv_22', + ]: + bench('{} {}'.format(executor, model)) + + +if __name__ == '__main__': + main() diff --git a/artifacts/5_tensorrt/main.py b/artifacts/5_tensorrt/main.py new file mode 100644 index 0000000..621aad2 --- /dev/null +++ b/artifacts/5_tensorrt/main.py @@ -0,0 +1,20 @@ +from artifact import bench + + +def main(): + for executor in [ + '--exec trt', + '--exec hidet', + ]: + for model in [ + '--model resnet50', + '--model inception_v3', + '--model mobilenet_v2', + '--model bert', + '--model gpt2' + ]: + bench('{} {}'.format(executor, model)) + + +if __name__ == '__main__': + main() diff --git a/artifacts/artifact/__init__.py b/artifacts/artifact/__init__.py new file mode 100644 index 0000000..92d3760 --- /dev/null +++ b/artifacts/artifact/__init__.py @@ -0,0 +1 @@ +from .bench import bench diff --git a/artifacts/artifact/bench.py b/artifacts/artifact/bench.py new file mode 100644 index 0000000..e14e7b1 --- /dev/null +++ b/artifacts/artifact/bench.py @@ -0,0 +1,383 @@ +from typing import List, Optional, Tuple, Union +from functools import lru_cache +import json +from tabulate import tabulate +import os +import time +import numpy as np +import argparse +import hidet as hi +import hidet +from hidet import Tensor +from hidet.utils import cuda, nvtx_annotate, hidet_cache_file, error_tolerance +from hidet.utils.git_utils import get_repo_sha, get_repo_commit_date + + +import os +os.environ["PATH"] = os.environ["PATH"]+":/usr/local/cuda/bin/" + + +class BenchResult: + def __init__(self, latencies: List[float] = None, outputs: List[Tensor] = None, configs: str = None): + self.latencies = latencies + self.outputs: Optional[List[Tensor]] = outputs + self.configs = configs + + +short2long = { + 'f16': 'float16', + 'f32': 'float32', + 'bf16': 'bfloat16' +} + + +def environment_info(args) -> str: + return str(tabulate( + headers=[ + 'Name', 'Value' + ], + tabular_data=[ + ['Commit', get_repo_sha()], + ['GPU', cuda.query_device_name()], + ['Arch', cuda.query_arch()], + ['Compute Capacity', cuda.query_compute_capability()], + ['Current SM Clock (MHz)', cuda.query_gpu_current_clock()], + ['Current Memory Clock (MHz)', cuda.query_memory_current_clock()], + ['Warmup/Number/Repeat', '{} / {} / {}'.format(args.warmup, args.number, args.repeat)] + ] + )) + + +@lru_cache() +def get_onnx_model(name: str, batch_size: int) -> Tuple[str, List[str], List[hidet.Tensor]]: + from hidet.testing import onnx_models + return onnx_models.get_onnx_model(name, batch_size) + + +def run_with_onnx(model_path: str, input_names: List[str], input_tensors: List[hidet.Tensor]) -> List[np.ndarray]: + import onnxruntime + onnx_session = onnxruntime.InferenceSession(model_path, providers=['CPUExecutionProvider']) # use cpu executor for high accuracy + onnx_outputs = onnx_session.run(None, input_feed={name: tensor.numpy() for name, tensor in zip(input_names, input_tensors)}) + return onnx_outputs + + +def benchmark_run(run_func, warmup, number, repeat) -> List[float]: + results = [] + with nvtx_annotate('warmup'): + for i in range(warmup): + run_func() + cuda.device_synchronize() + for i in range(repeat): + with nvtx_annotate(f'repeat {i}'): + cuda.device_synchronize() + start_time = time.time() + for j in range(number): + run_func() + cuda.device_synchronize() + end_time = time.time() + results.append((end_time - start_time) * 1000 / number) + return results + + +def bench_torch(args, out_dir) -> BenchResult: + from hidet.testing.torch_models.all import get_torch_model + result = BenchResult() + model, input_dict = get_torch_model(args.model, batch_size=args.bs) + + def run_func(): + model(**input_dict) + + result.latencies = benchmark_run(run_func, warmup=args.warmup, number=args.number, repeat=args.repeat) + result.configs = 'fp32' + result.outputs = None + return result + + +def bench_hidet(args, out_dir) -> BenchResult: + result = BenchResult() + # print('args', args, 'time stamp', time.time()) + + # configs + result.configs = 'sp{}_{}_{}_{}_pk_{}'.format(args.hidet_space, args.mma, args.precision, args.reduce_precision, args.parallel_k) + + # latencies and outputs + graph_path = hidet_cache_file( + 'hidet_graph', + args.model, + 'bs_{}_{}'.format(args.bs, result.configs), + 'graph.pickle' + ) + onnx_path, input_names, input_tensors = get_onnx_model(name=args.model, batch_size=args.bs) + + hidet.space_level(args.hidet_space) + + if os.path.exists(graph_path) and not args.disable_graph_cache: + graph = hidet.load_graph(graph_path) + else: + if args.disable_graph_cache: + print('disabled graph cache, rebuilding...') + t1 = time.time() + model = hidet.tos.frontend.onnx_utils.from_onnx(onnx_path) + symbol_inputs = [hi.symbol_like(data) for data in input_tensors] + outputs = model(*symbol_inputs) + graph: hi.FlowGraph = hi.trace_from(outputs, inputs=symbol_inputs) + with hidet.tos.PassContext() as ctx: + ctx.save_graph_instrument(out_dir=os.path.join(out_dir, 'ir')) + ctx.set_precision(short2long[args.precision]) + ctx.set_reduce_precision(short2long[args.reduce_precision]) + ctx.set_mma(args.mma) + if args.parallel_k == 'disabled': + ctx.set_parallel_k(disabled=True) + elif args.parallel_k == 'default': + ctx.set_parallel_k(default=True) + elif args.parallel_k == 'search': + ctx.set_parallel_k(search=True) + else: + ctx.set_parallel_k(nparts=int(args.parallel_k)) + + graph = hi.tos.transforms.optimize(graph) + + hidet.save_graph(graph, graph_path + '.tmp') + os.rename(graph_path + '.tmp', graph_path) + + graph(*input_tensors) + t2 = time.time() + with open(os.path.join(os.path.dirname(graph_path), 'tuning_time.txt'), 'w') as f: + f.write(str((t2 - t1) / 60.0) + ' minutes') + + cuda_graph = graph.cuda_graph() + result.outputs = cuda_graph.run_with_inputs(input_tensors) + result.latencies = benchmark_run(lambda: cuda_graph.run(), args.warmup, args.number, args.repeat) + + return result + + +def bench_trt(args, out_dir) -> BenchResult: + from hidet.utils.tensorrt_utils import create_engine_from_onnx, engine_benchmark, engine_inspect, engine_profiler, engine_inference + result = BenchResult() + + # configs + configs = [] + if args.trt_fp16: + configs.append('fp16') + if args.trt_tf32: + configs.append('tf32') + if len(configs) == 0: + configs.append('fp32') + result.configs = '_'.join(configs) + + # latencies + onnx_path, input_names, input_tensors = get_onnx_model(name=args.model, batch_size=args.bs) + engine = create_engine_from_onnx( + onnx_model_path=onnx_path, + workspace_bytes=512 << 20, # 512 MiB + input_shapes={name: tensor.shape for name, tensor in zip(input_names, input_tensors)}, + use_tf32=args.trt_tf32, + use_fp16=args.trt_fp16 + ) + dummy_inputs_dict = {name: tensor for name, tensor in zip(input_names, input_tensors)} + result.latencies = engine_benchmark( + engine=engine, + dummy_inputs=dummy_inputs_dict, + warmup=args.warmup, number=args.number, repeat=args.repeat + ) + + # outputs + result.outputs = list(engine_inference(engine, inputs=dummy_inputs_dict).values()) + + # extra information + with open(os.path.join(out_dir, 'engine_inspect.json'), 'w') as f: + json.dump(engine_inspect(engine), f, indent=2) + with open(os.path.join(out_dir, 'engine_trace.json'), 'w') as f: + json.dump(engine_profiler(engine, dummy_inputs_dict), f, indent=2) + + return result + + +def bench_ort(args, out_dir) -> BenchResult: + from hidet.utils.ort_utils import create_ort_session, ort_benchmark, ort_inference + result = BenchResult() + + # configs + result.configs = 'provider_{}'.format(args.ort_provider) + provider = { + 'cuda': 'CUDAExecutionProvider', + 'trt': 'TensorrtExecutionProvider' + }[args.ort_provider] + + # latencies + onnx_path, input_names, input_tensors = get_onnx_model(name=args.model, batch_size=args.bs) + session = create_ort_session(onnx_path, provider=provider) + inputs = {name: tensor for name, tensor in zip(input_names, input_tensors)} + result.latencies = ort_benchmark( + session, + dummy_inputs=inputs, + warmup=args.warmup, number=args.number, repeat=args.repeat + ) + + # outputs + result.outputs = list(ort_inference(session, inputs=inputs).values()) + + return result + + +def bench_tvm(args, out_dir) -> BenchResult: + from hidet.utils.tvm_utils import tvm_graph_module_from_onnx, tvm_benchmark, tvm_inference + result = BenchResult() + + # configs + if args.exec == 'autotvm': + trial = args.autotvm_trial + elif args.exec == 'ansor': + trial = args.ansor_trial + else: + trial = 1 + result.configs = 'trial_{}'.format(trial) + + # latencies + onnx_path, input_names, input_tensors = get_onnx_model(name=args.model, batch_size=args.bs) + gmod = tvm_graph_module_from_onnx( + onnx_model_path=onnx_path, + input_shapes={ + name: tensor.shape for name, tensor in zip(input_names, input_tensors) + }, + tune_autotvm=(args.exec == 'autotvm'), + tune_ansor=(args.exec == 'ansor'), + tune_trial_per_task=trial + ) + inputs = {name: tensor for name, tensor in zip(input_names, input_tensors)} + result.latencies = tvm_benchmark( + gmod, + dummy_inputs=inputs, + warmup=args.warmup, number=args.number, repeat=args.repeat + ) + + # outputs + result.outputs = tvm_inference(gmod, inputs) + + return result + + +def bench(command_line_args: Optional[str] = None): + if command_line_args: + args = parser.parse_args(command_line_args.strip().split()) + else: + args = parser.parse_args() + # output dir + out_dir = os.path.join(args.out_dir, + '{}_{}'.format(get_repo_commit_date(), get_repo_sha(short=True)), + cuda.query_device_name(short=True), + 'models') + exec_name = 'bs{}_{}_{}_{}_{}'.format(args.bs, args.model, args.exec, args.precision, args.reduce_precision) + if args.exec == 'hidet': + exec_name += '_space{}_pk_{}'.format(args.hidet_space, args.parallel_k) + elif args.exec in ['autotvm', 'ansor']: + trial = args.autotvm_trial if args.exec == 'autotvm' else args.ansor_trial + exec_name += '_trial{}'.format(trial) + out_dir = os.path.join(out_dir, exec_name) + os.makedirs(out_dir, exist_ok=True) + + # bench + bench_dict = { + 'hidet': bench_hidet, + 'trt': bench_trt, + 'ort': bench_ort, + 'autotvm': bench_tvm, + 'ansor': bench_tvm, + 'tvm': bench_tvm, + 'torch': bench_torch, + } + bench_func = bench_dict[args.exec] + with nvtx_annotate(message=args.exec): + with hidet.utils.py.Timer() as timer: + result: BenchResult = bench_func(args, out_dir) + + # error tolerance + onnx_path, input_names, input_tensors = get_onnx_model(name=args.model, batch_size=args.bs) + onnx_outputs = run_with_onnx(model_path=onnx_path, input_names=input_names, input_tensors=input_tensors) + et = -1.0 + if result.outputs is not None: + for baseline_output, onnx_output in zip(result.outputs, onnx_outputs): + et = max(et, error_tolerance(baseline_output.numpy(), onnx_output)) + + # write results + with open(os.path.join(out_dir, 'env.txt'), 'w') as f: + f.write(environment_info(args)) + with open(os.path.join(out_dir, 'raw.json'), 'w') as f: + raw = { + 'latency': result.latencies, + 'bench_time': timer.elapsed_seconds() + } + json.dump(raw, f, indent=2) + with open(os.path.join(out_dir, 'args.json'), 'w') as f: + json.dump(args.__dict__, f, indent=2) + with open(os.path.join(out_dir, 'summary.txt'), 'w') as f: + # model mode space median std + head = '{:>10} {:>20} {:>12} {:>40} {:>10} {:>10} {:>10} {:>10}\n'.format( + 'BatchSize', 'Model', 'Executor', 'Config', 'Space', 'Latency', 'Std', 'Error' + ) + summary = '{:>10} {:>20} {:>12} {:>40} {:10} {:10.3f} {:10.3f} {:10.3f}\n'.format( + args.bs, + args.model, + args.exec, + result.configs, + args.hidet_space, + float(np.median(result.latencies)), + float(np.std(result.latencies)), + et + ) + print(head + summary) + f.write(head + summary) + + +parser = argparse.ArgumentParser(description='Hidet model benchmark script.') + +# ====== + +# general parameters +parser.add_argument('--model', type=str, + # choices=['resnet50', 'inception_v3', 'mobilenet_v2', 'bert', 'bart'], + required=True, + help='The model to benchmark.') +parser.add_argument('--exec', type=str, choices=['hidet', 'trt', 'ort', 'tvm', 'autotvm', 'ansor', 'tf', 'tf_xla', 'torch'], required=True, + help='Executor.') +parser.add_argument('--out_dir', type=str, default='./results/', + help='Output directory.') +parser.add_argument('--warmup', type=int, default=10, help='Number of warmups.') +parser.add_argument('--number', type=int, default=10, help='Number of runs per repeat.') +parser.add_argument('--repeat', type=int, default=10, help='Number of repeats.') + +# ====== + +# executor parameters +# hidet executor parameters +parser.add_argument('--precision', choices=['f16', 'bf16', 'f32'], default='f32') +parser.add_argument('--reduce_precision', choices=['f16', 'f32'], default='f32') +parser.add_argument('--mma', choices=['simt', 'wmma', 'mma'], default='simt') +parser.add_argument('--hidet_space', type=int, choices=[0, 1, 2], default=2, help='The space level of each operator in the model. Large space level means longer compilation time and better performance.') +parser.add_argument('--parallel_k', choices=['disabled', 'default', 'search', '2', '4', '6', '8'], default='default') +parser.add_argument('--disable-graph-cache', action='store_true') + +# tvm number of trial per task +parser.add_argument('--ansor_trial', type=int, default=800, help='Number of trial per task in autotvm and ansor, default 800.') +parser.add_argument('--autotvm_trial', type=int, default=1000, help='Number of trial per task in autotvm and ansor, default 800.') + +# tensorrt configs +parser.add_argument('--trt_tf32', action='store_true') +parser.add_argument('--trt_fp16', action='store_true') + +# onnx runtime configs +parser.add_argument('--ort_provider', choices=['cuda', 'trt'], default='cuda') + +# ====== + +# model agnostic parameters +parser.add_argument('--bs', type=int, default=1, help='Batch size.') + +# model specific parameters +# bert +parser.add_argument('--bert_seq_length', type=int, default=128, help='Sequence length of bert input.') +parser.add_argument('--bert_hidden_size', type=int, default=768, help='Hidden size of bert.') +parser.add_argument('--bert_vocab_size', type=int, default=30522, help='Vocabulary size of bert.') + + diff --git a/artifacts/run.sh b/artifacts/run.sh new file mode 100644 index 0000000..b623254 --- /dev/null +++ b/artifacts/run.sh @@ -0,0 +1,13 @@ +#!/bin/bash + +# the following environment variables allow TVM to use all the cores of the machine +# if your machine has a different number of threads other than 24, change the value of TVM_NUM_THREADS accordingly +export TVM_BIND_THREADS=0 +export TVM_NUM_THREADS=24 + +python ./0_end2end/main.py +python ./1_latency_distribution/main.py +python ./2_input_sensitivity/main.py +python ./3_batch_size/main.py +python ./4_prologue_epilogue_fusion/main.py +python ./5_tensorrt/main.py diff --git a/cmake/TVM.cmake b/cmake/TVM.cmake new file mode 100644 index 0000000..f6890c5 --- /dev/null +++ b/cmake/TVM.cmake @@ -0,0 +1,20 @@ +set(USE_CUDA ON) +set(USE_LLVM ON) +if(NOT HIDET_CUDNN_PATH STREQUAL "Auto") + set(USE_CUDNN ${HIDET_CUDNN_PATH}) +else() + set(USE_CUDNN ON) +endif() +set(USE_CUBLAS ON) +set(USE_CCACHE ON) +set(USE_CUTLASS ON) +message(STATUS "Config TVM...") +list(APPEND CMAKE_MESSAGE_INDENT " ") +add_subdirectory(3rdparty/tvm) +list(POP_BACK CMAKE_MESSAGE_INDENT) +set_target_properties(tvm tvm_runtime PROPERTIES + COMPILE_FLAGS -Wno-unused-command-line-argument + LIBRARY_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/lib + ARCHIVE_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/lib + RUNTIME_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/bin + ) diff --git a/config.cmake b/config.cmake new file mode 100644 index 0000000..027d0e8 --- /dev/null +++ b/config.cmake @@ -0,0 +1,20 @@ +# Set the compute capacity of target GPU. +# A lower number is compatible with newer GPUs. Compute capacity of typical NVIDIA GPUs: +# - 60: P100 +# - 70: V100 +# - 80: A100 +# - 61: RTX 10 Series +# - 75: RTX 20 Series +# - 86: RTX 30 Series +set(HIDET_CUDA_ARCH 60) + +## Set the cudnn directory +## - Auto: Detect automatically +## - Path-to-cuDNN: Use the given path to search cudnn library. Can be used to specify +## different version of cuDNN. +#set(HIDET_CUDNN_PATH Auto) + +# Set build type +# - Debug +# - Release +set(HIDET_BUILD_TYPE Release) diff --git a/include/hidet/common.h b/include/hidet/common.h new file mode 100644 index 0000000..e05aab3 --- /dev/null +++ b/include/hidet/common.h @@ -0,0 +1,15 @@ +#pragma once +#include +#ifdef assert +#undef assert +#endif +#define assert(x) if(!(x)){ \ + std::cerr << __FILE__ << ":" << __LINE__ << ": " \ + << #x << "failed" << std::endl; \ + exit(-1); \ +} + +#ifndef DLL +#define DLL extern "C" __attribute__((visibility("default"))) +#endif + diff --git a/include/hidet/cuda_utils.h b/include/hidet/cuda_utils.h new file mode 100644 index 0000000..e4a4b94 --- /dev/null +++ b/include/hidet/cuda_utils.h @@ -0,0 +1,39 @@ +#pragma once + +#include +#include +//#include +#include +#include +#include + + +#define CUDA_CALL(func) { \ + cudaError_t e = (func); \ + if(e != cudaSuccess) { \ + cudaGetLastError(); \ + throw HidetException(__FILE__, __LINE__, cudaGetErrorString(e)); \ + }} + +#define CUBLAS_CALL(func) { \ + cublasStatus_t e = (func); \ + if(e != CUBLAS_STATUS_SUCCESS) { \ + throw HidetException(__FILE__, __LINE__, \ + std::string("cutlass error with code") + std::to_string(e)); \ + }} + +//#define CUDNN_CALL(func) { \ +// cudnnStatus_t _status = (func); \ +// if(_status != CUDNN_STATUS_SUCCESS) { \ +// throw HidetException(__FILE__, __LINE__, \ +// cudnnGetErrorString(_status)); \ +// }} + +#define CURAND_CALL(func) { \ + curandStatus_t e = (func); \ + if(e != CURAND_STATUS_SUCCESS) { \ + throw HidetException(__FILE__, __LINE__, \ + std::string("curand error with code") + std::to_string(e)); \ + }} + + diff --git a/include/hidet/logging.h b/include/hidet/logging.h new file mode 100644 index 0000000..d3c3a6a --- /dev/null +++ b/include/hidet/logging.h @@ -0,0 +1,36 @@ +#pragma once +#include +#include +#include + +struct ErrorState { + bool has_error; + std::string error_msg; + + static ErrorState *global(); +}; + +struct HidetException: std::exception { + std::string file; + int line; + std::string msg; + + HidetException(std::string file, int line, std::string msg):file(file), line(line), msg(msg){} + + const char * what() const noexcept override { + static std::string what_msg; + what_msg = this->file + ":" + std::to_string(this->line) + " " + this->msg; + return what_msg.c_str(); + } +}; + +DLL void hidet_set_last_error(const char *msg); + +DLL const char * hidet_get_last_error(); + +#define API_BEGIN() try { +/*body*/ +#define API_END(ret) } catch (const HidetException& e) { \ + hidet_set_last_error(e.what()); \ + return ret; \ + } diff --git a/include/hidet/packedfunc.h b/include/hidet/packedfunc.h new file mode 100644 index 0000000..e3b9018 --- /dev/null +++ b/include/hidet/packedfunc.h @@ -0,0 +1,24 @@ +#pragma once + +#ifndef DLL +#define DLL extern "C" __attribute__((visibility("default"))) +#endif + +enum ArgType { + INT32 = 1, + FLOAT32 = 2, + POINTER = 3, +}; + +typedef void (*PackedFunc_t)(int num_args, int *arg_types, void** args); + +struct PackedFunc { + int num_args; + int* arg_types; + void** func_pointer; +}; + +#define INT_ARG(p) (*(int*)(p)) +#define FLOAT_ARG(p) (*(float*)(p)) + + diff --git a/include/hidet/runtime.h b/include/hidet/runtime.h new file mode 100644 index 0000000..640b429 --- /dev/null +++ b/include/hidet/runtime.h @@ -0,0 +1,13 @@ +#pragma once +#include +#include + + +struct CudaContext { + cudaStream_t stream; + static CudaContext* global(); +}; + +DLL void set_cuda_stream(cudaStream_t stream); + +DLL cudaStream_t get_cuda_stream(); diff --git a/python/hidet/__init__.py b/python/hidet/__init__.py new file mode 100644 index 0000000..b339492 --- /dev/null +++ b/python/hidet/__init__.py @@ -0,0 +1,22 @@ +import sys +from . import ir +from . import backend +from . import utils +from . import tos +from . import runtime +from . import driver +from . import testing + +from .ir import Task, save_task, load_task + +from .tos import Tensor, Operator, Module, FlowGraph + +from .tos import ops +from .tos import empty, randn, zeros, ones, full, symbol, array, empty_like, randn_like, zeros_like, ones_like, symbol_like, full_like +from .tos import space_level, get_space_level +from .tos import trace_from, load_graph, save_graph +from .tos import jit + +from .utils import hidet_set_cache_root as set_cache_root + +sys.setrecursionlimit(10000) diff --git a/python/hidet/backend/__init__.py b/python/hidet/backend/__init__.py new file mode 100644 index 0000000..b22b353 --- /dev/null +++ b/python/hidet/backend/__init__.py @@ -0,0 +1,2 @@ +from .codegen import codegen +from .build import compile_source, load_task_func, BuildInstance, batch_build_ir_modules, load_lib_func diff --git a/python/hidet/backend/build.py b/python/hidet/backend/build.py new file mode 100644 index 0000000..65aea37 --- /dev/null +++ b/python/hidet/backend/build.py @@ -0,0 +1,271 @@ +from __future__ import annotations +from typing import List, Optional, Dict +import contextlib +import psutil +import multiprocessing +from tqdm import tqdm +import ctypes +import os +import subprocess +import tempfile +from subprocess import PIPE + +from hidet.libinfo import get_include_dir +from hidet.ir.func import IRModule +from hidet.ir.type import FuncType +from hidet.ir.task import Task +from hidet.transforms import PassContext, lower +from hidet.runtime import CompiledFunction +from hidet.ffi import PackedFunc +from hidet.ffi.ffi import library_paths +from hidet.utils import cuda, Timer +from hidet.backend import codegen + + +dlclose = ctypes.CDLL(None).dlclose +dlclose.argtypes = [ctypes.c_void_p] +dlclose.rettype = ctypes.c_int + + +class LoadedSharedLibrary: + loaded_cdll_libraries: Dict[str, ctypes.CDLL] = {} + reference_count: Dict[str, int] = {} + + def __init__(self, lib_path: str): + self.lib_path: str = lib_path + if lib_path in self.loaded_cdll_libraries: + self.cdll: ctypes.CDLL = self.loaded_cdll_libraries[lib_path] + self.reference_count[lib_path] += 1 + else: + cdll = ctypes.CDLL(lib_path) + self.cdll: ctypes.CDLL = cdll + self.loaded_cdll_libraries[lib_path] = cdll + self.reference_count[lib_path] = 1 + + def __getitem__(self, item): + ret = self.cdll[item] + ret._lib = self + return ret + + def __getattr__(self, item): + return self[item] + + def __del__(self): + self.reference_count[self.lib_path] -= 1 + if self.reference_count[self.lib_path] == 0: + del self.loaded_cdll_libraries[self.lib_path] + del self.reference_count[self.lib_path] + dlclose(self.cdll._handle) + + +def compile_source(src_path: str, out_lib_path: str, keep_ptx=False) -> None: + """ + Compile the source code in 'src_path' file and output the library to 'out_lib_path'. + + Parameters + ---------- + src_path: str + The path to source code. + out_lib_path: str + The path to output library. + keep_ptx: bool, default False + Whether to keep the ptx code in the same directory of output library. + """ + src_path = os.path.abspath(src_path) + out_lib_path = os.path.abspath(out_lib_path) + cc = cuda.query_compute_capability() + + # dir contains the runtime header file 'hidet/runtime.h' + include_dirs = [get_include_dir()] + # dir contains the runtime library 'libhidet_runtime.so' + library_dirs = [os.path.dirname(library_paths['hidet_runtime'])] + + cc_code = '{}{}'.format(cc[0], cc[1]) + command = [ + 'nvcc', + *['-I{}'.format(include_dir) for include_dir in include_dirs], + *['-L{}'.format(library_dir) for library_dir in library_dirs], + '-keep' if keep_ptx else '', + '-gencode', f'arch=compute_{cc_code},code=sm_{cc_code}', + '--ptxas-options=-v', + '--compiler-options', "'-fPIC'", + '-lineinfo', + '-lhidet_runtime', + '--shared', src_path, + '-o', out_lib_path, + ] + + try: + with tempfile.TemporaryDirectory() as working_dir: + result = subprocess.run(" ".join(command).split(), stderr=PIPE, stdout=PIPE, cwd=working_dir) + if result.returncode: + message = '' + if result.stdout: + message += result.stdout.decode() + '\n' + if result.stderr: + message += result.stderr.decode() + if keep_ptx and os.path.exists(os.path.join(working_dir, os.path.basename(src_path).replace('.cu', '.ptx'))): + out_lib_dir = os.path.dirname(out_lib_path) + ptx_name = os.path.basename(src_path).replace('.cu', '.ptx') + ptx_path = os.path.join(working_dir, ptx_name) + target_ptx_path = os.path.join(out_lib_dir, ptx_name) + os.rename(ptx_path, target_ptx_path) + raise Exception('Failed to compile file "{}":\n\n{}'.format(src_path, message)) + out_lib_dir = os.path.dirname(out_lib_path) + if keep_ptx: + ptx_name = os.path.basename(src_path).replace('.cu', '.ptx') + ptx_path = os.path.join(working_dir, ptx_name) + target_ptx_path = os.path.join(out_lib_dir, ptx_name) + os.rename(ptx_path, target_ptx_path) + with open(os.path.join(out_lib_dir, 'nvcc_log.txt'), 'w') as f: + f.write('Command: {}\n'.format(" ".join(result.args))) + f.write(result.stdout.decode('utf-8')) + f.write(result.stderr.decode('utf-8')) + except subprocess.CalledProcessError as e: + print(' '.join(command)) + print(e.stderr.decode('utf-8')) + raise e + + +def load_task_func(lib_path: str, task) -> CompiledFunction: + """ + Load task's entry function from dynamic linked library. + + Parameters + ---------- + lib_path: str + The dynamic library path. + task: hidet.tos.task.Task + The task that corresponds to the dynamic library. + + Returns + ------- + ret: CompiledFunction + The loaded function that can be directly called in python. + """ + try: + # lib = ctypes.CDLL(lib_path) + lib = LoadedSharedLibrary(lib_path) + except OSError as e: + print("Removed the file '{}'".format(lib_path)) + os.remove(lib_path) + raise e + func_name = 'hidet_{}'.format(task.name) + param_types = [param.data_type for param in task.parameters] + packed_func = PackedFunc(param_types=param_types, c_func_pointer=lib[func_name]) + return CompiledFunction(name=task.name, packed_func=packed_func) + + +def load_lib_func(lib_path: str, func_name: str, func_type: FuncType) -> CompiledFunction: + try: + # lib = ctypes.CDLL(lib_path) + lib = LoadedSharedLibrary(lib_path) + except OSError as e: + print("Removed the file '{}'".format(lib_path)) + os.remove(lib_path) + raise e + func_name = 'hidet_{}'.format(func_name) + param_types = [param_type for param_type in func_type.param_types] + packed_func = PackedFunc(param_types=param_types, c_func_pointer=lib[func_name]) + return CompiledFunction(name=func_name, packed_func=packed_func) + + +class BuildInstance: + def __init__(self, ir_module, output_dir, keep_ir=False, nvcc_keep=True, verbose=True): + """ + The build instance. + + Parameters + ---------- + ir_module: IRModule + The ir module to build. + output_dir: str + The output directory for this build. + keep_ir: bool + Whether to keep the ir when lowering. If True, the ir will be stored in '{output_dir}/ir'. Default False. + nvcc_keep: bool + Whether to keep the ptx code in the same directory of output library., Default: True + verbose: bool + Whether to + verbose: bool + Reserved. + """ + self.ir_module = ir_module + self.output_dir = output_dir + self.keep_ir = keep_ir + self.nvcc_keep = nvcc_keep + self.verbose = verbose + + +def build_ir_module_job(build_instance: BuildInstance) -> Optional[str]: + """ + Build an ir module in build instance. + + Parameters + ---------- + build_instance: BuildInstance + The build instance to build. + + Returns + ------- + lib_path: str + The path to the built dynamic linked library. + """ + from hidet.transforms.instruments import SaveIRInstrument + instruments = [] + os.makedirs(build_instance.output_dir, exist_ok=True) + if build_instance.keep_ir: + instruments.append(SaveIRInstrument(out_dir=os.path.join(build_instance.output_dir, 'ir'))) + with PassContext(instruments=instruments): + ir_module = lower(build_instance.ir_module) + src_path = os.path.join(build_instance.output_dir, 'source.cu') + lib_path = os.path.join(build_instance.output_dir, 'lib.so') + codegen(ir_module, src_out_path=src_path) + try: + compile_source(src_path, lib_path) + except subprocess.CalledProcessError: + print('Compilation failed for an instance') + return None + return lib_path + + +def batch_build_ir_modules(build_instances, parallel=True, verbose=False) -> List[Optional[CompiledFunction]]: + """ + Build a batch of ir modules. + + Parameters + ---------- + build_instances: List[BuildInstance] + The batch of build instances to build. + + parallel: bool + Whether build in parallel. Default True. + + verbose: bool + Whether show the progress and summary. Default False. + + Returns + ------- + funcs: List[Optional[CompiledFunction]] + The compiled functions, in the same order as build_instances. + When the build for a build instance failed, None for that instance is returned. + """ + with Timer() as timer: + lib_paths = [] + if parallel: + # Set the affinity of current process. Some package such as numpy will change affinity of current process, + # which might limit the parallelism of compilation. + os.sched_setaffinity(0, range(os.cpu_count())) + mem_for_worker = 1.5 * 1024 * 1024 * 1024 # 1.5 GiB + num_workers = min(max(int(psutil.virtual_memory().available // mem_for_worker), 1), psutil.cpu_count()) + with multiprocessing.Pool(processes=num_workers) as pool: + for lib_path in tqdm(pool.imap(build_ir_module_job, build_instances), total=len(build_instances), disable=not verbose): + lib_paths.append(lib_path) + else: + lib_paths = map(build_ir_module_job, build_instances) + assert len(lib_paths) == len(build_instances) + funcs = [load_task_func(lib_path, instance.ir_module.task) if lib_path else None for lib_path, instance in zip(lib_paths, build_instances)] + if verbose: + print('Batch build {} modules within {:.3f} seconds, on average {:.1f} seconds per module.'.format( + len(build_instances), timer.elapsed_seconds(), timer.elapsed_seconds() / len(build_instances))) + return funcs diff --git a/python/hidet/backend/codegen.py b/python/hidet/backend/codegen.py new file mode 100644 index 0000000..bec1f50 --- /dev/null +++ b/python/hidet/backend/codegen.py @@ -0,0 +1,515 @@ +from hidet.ir.dialects.pattern import AnyExpr +from hidet.ir.func import * +from hidet.ir.stmt import * +from hidet.ir.expr import * +from hidet.ir.dialects.compute import TensorNode, ScalarNode +from hidet.ir.functors import StmtExprFunctor, TypeFunctor, TypeInfer +from hidet.ir.dialects.lowlevel import VoidType, PointerType, Dereference, Address, ReferenceType, Reference, TensorPointerType +from hidet.utils.doc import Doc, NewLine, Text, doc_join +from hidet.ir.utils.call_graph import CallGraph +from hidet.utils.namer import Namer +from hidet.ir.primitives import is_primitive_function, lookup_primitive_function + + +class Codegen(StmtExprFunctor, TypeFunctor): + def __init__(self): + super().__init__() + self.func_name_map = {} + self.ir_module: Optional[IRModule] = None + self.namer = Namer() + self.type_infer = TypeInfer() + + @staticmethod + def canonize_funcname(name: str): + return 'hidet_' + name.replace('.', '_') + + def param_declare(self, v: Var): + v_type = v.type + name_doc = self(v) + if isinstance(v_type, ScalarType): + dtype_doc = self(v_type) + return dtype_doc + ' ' + name_doc + elif isinstance(v_type, PointerType): + if len(v_type.specifiers) > 0: + attr_doc = doc_join([self(attr) for attr in v_type.specifiers], sep=' ') + ' ' + else: + attr_doc = Doc() + dtype = v_type.base_type + base_type_doc = self(dtype) + if v_type.use_bracket: + return attr_doc + base_type_doc + ' ' + name_doc + '[]' + else: + return attr_doc + base_type_doc + ' *' + ' __restrict__ ' + name_doc + elif isinstance(v_type, TensorPointerType): + dtype = v_type.tensor_type.scalar_type + base_type_doc = self(dtype) + return base_type_doc + ' *' + ' __restrict__ ' + name_doc + elif isinstance(v_type, ReferenceType): + if isinstance(v_type.base_type, ScalarType): + base_type_doc = self(v_type.base_type) + return base_type_doc + ' &' + name_doc + else: + raise NotImplementedError() + elif isinstance(v_type, TensorType): + dtype = v_type.scalar_type + base_type_doc = self(dtype) + return base_type_doc + ' *' + ' __restrict__ ' + name_doc + # dtype_doc = self(v_type.scalar_type) + # name_doc = self(v) + # shape_doc = Doc() + # for s in v_type.shape: + # shape_doc += '[' + self(s) + ']' + # return dtype_doc + ' ' + '__restrict__' + name_doc + shape_doc + else: + raise ValueError() + + def local_var_declare(self, v: Var): + v_type = v.type + if isinstance(v_type, ScalarType): + dtype_doc = self(v_type) + name_doc = self(v) + return dtype_doc + ' ' + name_doc + elif isinstance(v_type, TensorType): + if v_type.scope.name == 'shared': + scope_doc = '__shared__ ' + else: + scope_doc = '' + dtype_doc = self(v_type.scalar_type) + name_doc = self(v) + shape_doc = Doc() + for s in v_type.shape: + shape_doc += '[' + self(s) + ']' + return scope_doc + dtype_doc + ' ' + name_doc + shape_doc + elif isinstance(v_type, PointerType): + if len(v_type.specifiers) > 0: + attr_doc = doc_join([self(attr) for attr in v_type.specifiers], sep=' ') + ' ' + else: + attr_doc = Doc() + base_type_doc = self(v_type.base_type) + name_doc = self(v) + if v_type.use_bracket: + return attr_doc + base_type_doc + ' ' + name_doc + '[]' + else: + return attr_doc + base_type_doc + ' *' + name_doc + elif isinstance(v_type, TensorPointerType): + dtype_doc = self(v_type.tensor_type.scalar_type) + name_doc = self(v) + return dtype_doc + ' *' + name_doc + else: + assert False + + def __call__(self, node) -> Doc: + return self.visit(node) + + def visit(self, node): + if isinstance(node, IRModule): + return self.visit_IRModule(node) + elif isinstance(node, Function): + return self.visit_Function(node) + elif isinstance(node, (Stmt, Expr)): + return StmtExprFunctor.visit(self, node) + elif isinstance(node, TypeNode): + return TypeFunctor.visit(self, node) + elif isinstance(node, (tuple, list)): + return doc_join([self(v) for v in node], ', ') + elif isinstance(node, (int, float, bool)): + return self(convert(node)) + elif isinstance(node, str): + return Text(node) + elif isinstance(node, Doc): + return node + else: + raise ValueError(type(node)) + + def visit_IRModule(self, module: IRModule) -> Doc: + self.ir_module = module + doc = Doc() + # todo: only add necessary headers + # doc += Text('#include ') + NewLine() + # doc += Text('#include ') + NewLine() + # doc += Text('#include ') + NewLine() + doc += Text('#include ') + NewLine() + doc += Text('#include ') + NewLine() + doc += Text('#include ') + NewLine() + + # nvcc use float to 'store' tfloat32 data + doc += Text('typedef float tfloat32_t;') + NewLine() + + # According to here: https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#wmma-altfp + # there should be a function called '__float_to_tf32' in cuda c to convert float to tfloat32, + # but I did not find such a function. By looking at cutlass's implementation of converting + # float to tfloat32, it seems that we do not need to do anything to convert. Put a definition + # here in case nvidia add the definition in the future. + doc += Text('#define __float_to_tf32(x) (x)') + NewLine() + + doc += '/*' + NewLine() + doc += str(module.task) + NewLine() + doc += '*/' + NewLine() + doc += Text('extern "C" {') + NewLine() + + call_graph = CallGraph(module) + for node in call_graph.reversed_order: + doc += self(node.func) + NewLine() + + doc += NewLine() + '}' + return doc + + def visit_Function(self, func: Function) -> Doc: + self.namer.clear() + + doc = NewLine() + + # ret + if func.kind == 'cuda_kernel': + doc += '__global__' + elif func.kind == 'cuda_device': + doc += '__device__ __forceinline__' + elif func.kind == 'packed_func' or func.kind == 'host_kernel': + doc += '__host__' + + doc += ' ' + self(func.ret_type) + # doc += ' void' + + # launch bound for grid worker + if func.kind == 'cuda_kernel': + block_dim = func.attrs['cuda_block_dim'] + if 'cuda_min_blocks' in func.attrs: + min_blocks = func.attrs['cuda_min_blocks'] + doc += f' __launch_bounds__({block_dim}, {min_blocks})' + else: + doc += f' __launch_bounds__({block_dim})' + + # func name + canonized_func_name = self.canonize_funcname(func.name) + doc += ' ' + canonized_func_name + self.func_name_map[func.name] = canonized_func_name + + # parameters + doc += '(' + param_docs = [] + for i in range(len(func.params)): + param = func.params[i] + param_docs.append(self.param_declare(param)) + doc += doc_join(param_docs, Text(', ')) + doc += ') {' + + # comments + label = func.get_attr('label') + if label: + doc += (NewLine() + '// label: {}'.format(label)).indent() + + # const locals + for const_local_var, const_local_value in func.local_const_vars: + doc += (NewLine() + self.local_var_declare(const_local_var) + ' = ' + self(const_local_value) + ';').indent() + + # locals + for local_var in func.local_vars: + doc += (NewLine() + self.local_var_declare(local_var) + ';').indent() + # doc += (NewLine() + self(local_var.type) + ' ' + self(local_var) + ';').indent() + + # body + doc += self(func.body).indent() + + doc += NewLine() + '}' + + return doc + + def visit_Add(self, e: Add): + return Text('(') + self(e.a) + ' + ' + self(e.b) + ')' + + def visit_Sub(self, e: Sub): + return Text('(') + self(e.a) + ' - ' + self(e.b) + ')' + + def visit_Multiply(self, e: Multiply): + return Text('(') + self(e.a) + ' * ' + self(e.b) + ')' + + def visit_Div(self, e: Div): + return Text('(') + self(e.a) + ' / ' + self(e.b) + ')' + + def visit_Mod(self, e: Mod): + return Text('(') + self(e.a) + ' % ' + self(e.b) + ')' + + def visit_FloorDiv(self, e: FloorDiv): + return Text('(') + self(e.a) + ' / ' + self(e.b) + ')' + + def visit_LessThan(self, e: LessThan): + return Text('(') + self(e.a) + ' < ' + self(e.b) + ')' + + def visit_Neg(self, e: Neg): + return '(-' + self(e.a) + ')' + + def visit_LessEqual(self, e: LessThan): + return Text('(') + self(e.a) + ' <= ' + self(e.b) + ')' + + def visit_Equal(self, e: Equal): + return Text('(') + self(e.a) + ' == ' + self(e.b) + ')' + + def visit_And(self, e: And): + return Text('(') + self(e.a) + ' && ' + self(e.b) + ')' + + def visit_Or(self, e: Or): + return Text('(') + self(e.a) + ' || ' + self(e.b) + ')' + + def visit_Not(self, e: Not): + return Text('!') + self(e.a) + + def visit_BitwiseAnd(self, e: BitwiseAnd): + return '(' + self(e.a) + ' & ' + self(e.b) + ')' + + def visit_BitwiseOr(self, e: BitwiseOr): + return '(' + self(e.a) + ' | ' + self(e.b) + ')' + + def visit_BitwiseNot(self, e: BitwiseNot): + return '(~' + self(e.base) + ')' + + def visit_LeftShift(self, e: LeftShift): + return '(' + self(e.base) + ' << ' + self(e.cnt) + ')' + + def visit_RightShift(self, e: RightShift): + return '(' + self(e.base) + ' >> ' + self(e.cnt) + ')' + + def visit_TensorElement(self, e: TensorElement): + return self(e.base) + doc_join(['[' + self(idx) + ']' for idx in e.indices], '') + + def visit_IfThenElse(self, e: IfThenElse): + return '(' + self(e.cond) + ' ? ' + self(e.then_expr) + ' : ' + self(e.else_expr) + ')' + + def visit_Cast(self, e: Cast): + return Text('(') + self.visit(e.target_type) + ')' + self(e.expr) + + def visit_Address(self, e: Address): + return Text('&') + self.visit(e.expr) + + def visit_Reference(self, e: Reference): + raise NotImplementedError() + + def visit_Dereference(self, e: Dereference): + return Text('*') + self(e.expr) + + def visit_Call(self, e: Call): + func_name = e.func_var.hint + # func_name = func_name.replace('.', '_') + if '.' in func_name: + target, func_name = func_name.split('.') + if func_name in self.ir_module.functions: + # first check whether callee is in current ir module + # because ir module functions will cover primitive functions + func = self.ir_module.lookup(func_name) + else: + key = e.func_var.hint + if not is_primitive_function(key): + raise ValueError("Callee {} not found in current ir module, and it is not primitive function.".format(key)) + entry = lookup_primitive_function(key) + if entry.function is not None: + raise ValueError("Please use import_primitive_functions pass to import primitive function first: {}, functions in current module:\n{}.".format(entry.name, list(self.ir_module.functions.keys()))) + if entry.generic: + raise ValueError("Please use resolve_generic_primitive_function pass to lower the generic primitive function {}.".format(entry.name)) + # system-provided function, do not canonize the func name + return entry.name + (Text('(') + doc_join([self(arg) for arg in e.args], Text(', ')) + ')') + func_name = Text(self.canonize_funcname(func_name)) + if func.kind == 'cuda_kernel': + def dim3_str(dims): + if isinstance(dims, (int, Expr)): + return self(dims) + else: + return Text('dim3(') + self(dims) + ')' + + configs = [ + dim3_str(func.attrs['cuda_grid_dim']), # grid dimension + dim3_str(func.attrs['cuda_block_dim']), # block dimension + func.attrs.get('cuda_smem_bytes', 0), # dynamic shared memory size + Text('get_cuda_stream()') # cuda stream (get_cuda_stream() function is defined in hidet/runtime.h) + ] + launch_config = Text('<<<') + doc_join([self(v) for v in configs], sep=', ') + Text('>>>') + else: + launch_config = [] + param_doc = Text('(') + doc_join([self(arg) for arg in e.args], Text(', ')) + ')' + return func_name + launch_config + param_doc + + def visit_Let(self, e: Let): + raise ValueError("please run 'expand_let_expr' pass before codegen") + + def visit_Var(self, e: Var): + cast2int = { + 'threadIdx.x', + 'blockIdx.x' + } + name = self.namer.get_name(e) + if name in cast2int: + return Text(f'(int){name}') + else: + return Text(name) + + @staticmethod + def scalar_literal(value, dtype: str): + if dtype == 'bool': + return Text('true') if value else Text('false') + elif dtype == 'float32': + return Text(f'{value}f') + elif dtype == 'int32': + return Text(f'{value}') + elif dtype == 'float16': + return Text('half({})'.format(value)) + elif dtype == 'int64': + return Text('{}'.format(value)) + elif dtype == 'bfloat16': + return Text('__float2bfloat16({})'.format(value)) + elif dtype == 'tfloat32': + return Text('__float_to_tf32({})'.format(value)) + elif dtype == 'uint32': + assert value >= 0 + return Text('{}u'.format(value)) + else: + raise NotImplementedError('Cannot recognize scalar literal {} with dtype {}'.format(value, dtype)) + + def visit_Constant(self, e: Constant): + if e.is_scalar(): + return self.scalar_literal(e.value, e.data_type.name) + else: + assert isinstance(e.data_type, TensorType) + dtype = e.data_type.scalar_type.name + items = [self.scalar_literal(v, dtype) for v in np.array(e.value).flatten()] + return '{' + doc_join(items, ', ') + '}' + + def visit_EvaluateStmt(self, stmt: EvaluateStmt): + return NewLine() + self(stmt.expr) + ';' + + def visit_BufferStoreStmt(self, stmt: BufferStoreStmt): + doc = NewLine() + doc += self(stmt.buf) + for idx in stmt.indices: + doc += '[' + self(idx) + ']' + doc += Text(' = ') + self(stmt.value) + ';' + return doc + + def visit_AssignStmt(self, stmt: AssignStmt): + return NewLine() + self(stmt.var) + ' = ' + self(stmt.value) + ';' + + def visit_LetStmt(self, stmt: LetStmt): + doc = Doc() + for bind_var, bind_value in zip(stmt.bind_vars, stmt.bind_values): + doc += NewLine() + self(bind_var.type) + ' ' + self(bind_var) + ' = ' + self(bind_value) + ';' + doc += self(stmt.body) + return doc + + def visit_ForStmt(self, stmt: ForStmt): + v = stmt.loop_var + init_doc = self(v.type) + ' ' + self(v) + ' = ' + self(convert(0)) + cond_doc = self(v < stmt.extent) + update_doc = self(v) + ' = ' + self(v + 1) + doc = Text('') + if stmt.unroll is not None: + if isinstance(stmt.unroll, bool): + if stmt.unroll: + doc += NewLine() + '#pragma unroll' # complete unroll + else: + doc += NewLine() + '#pragma unroll 1' # prevent from unrolling + else: + assert isinstance(stmt.unroll, int) + doc += NewLine() + '#pragma unroll {}'.format(stmt.unroll) + doc += NewLine() + Text('for (') + init_doc + '; ' + cond_doc + '; ' + update_doc + ') ' + body_doc = self(stmt.body) + doc += Text('{') + body_doc.indent() + NewLine() + Text('} ') + return doc + + def visit_IfStmt(self, stmt: IfStmt): + cond_doc = self(stmt.cond) + if not (len(cond_doc.docs) > 0 and isinstance(cond_doc.docs[0], str) and cond_doc.docs[0].startswith('(')): + cond_doc = Text('(') + cond_doc + ')' + doc = NewLine() + Text('if ') + cond_doc + ' ' + doc += Text('{') + self(stmt.then_body).indent() + NewLine() + Text('} ') + if stmt.else_body: + doc += Text('else ') + doc += Text('{') + self(stmt.else_body).indent() + NewLine() + Text('} ') + return doc + + def visit_ReturnStmt(self, stmt: ReturnStmt): + doc = Doc() + doc += NewLine() + 'return' + if stmt.ret_value is not None: + doc += ' ' + self(stmt.ret_value) + doc += ';' + return doc + + def visit_AssertStmt(self, stmt: AssertStmt): + return NewLine() + Text('assert(((void)"') + stmt.msg + '", ' + self(stmt.cond) + '));' + + def visit_AsmStmt(self, stmt: AsmStmt): + volatile_doc = 'volatile ' if stmt.is_volatile else '' + template_doc = f'"{Text(stmt.template_string)}"' + output_docs = [] + for label, expr in zip(stmt.output_labels, stmt.output_exprs): + output_docs.append(Text(f'"{label}"') + '(' + self(expr) + ')') + input_docs = [] + for label, expr in zip(stmt.input_labels, stmt.input_exprs): + input_docs.append(Text(f'"{label}"') + '(' + self(expr) + ')') + return NewLine() + 'asm ' + volatile_doc + '(' + template_doc + ' : ' + doc_join(output_docs, ', ') + ' : ' + doc_join(input_docs, ', ') + ');' + + def visit_BlackBoxStmt(self, stmt: BlackBoxStmt): + expr_docs = [str(self(e)) for e in stmt.exprs] + stmt_string: str = stmt.template_string.format(*expr_docs) + lines = stmt_string.split('\n') + doc = Text('') + for line in lines: + doc += NewLine() + line + return doc + + def visit_SeqStmt(self, stmt: SeqStmt): + doc = Doc() + for idx, s in enumerate(stmt.seq): + doc += self(s) + return doc + + def visit_ScalarType(self, t: ScalarType): + scalar_type_map = { + 'bool': 'bool', + 'uint8': 'uint8_t', + 'uint32': 'uint32_t', + 'int32': 'int32_t', + 'int64': 'int64_t', + + 'float16': 'half', + 'float32': 'float', + 'bfloat16': 'nv_bfloat16', + 'tfloat32': 'tfloat32_t', + } + return Text(scalar_type_map[t.name]) + + def visit_TensorType(self, t: TensorType): + return Text('TensorType(') + self(t.scalar_type) + ', [' + doc_join([self(s) for s in t.shape], ", ") + '], ' + t.scope.name + ')' + + def visit_PointerType(self, t: PointerType): + return self(t.base_type) + Text('*') + + def visit_TensorPointerType(self, t: TensorPointerType): + return self(t.tensor_type.scalar_type) + Text('*') + + def visit_ReferenceType(self, t: ReferenceType): + raise ValueError() + + def visit_VoidType(self, t: VoidType): + return Text('void') + + # the following expressions should not remain to codegen + def visit_TensorSlice(self, e: TensorSlice): + raise ValueError() + + def visit_ScalarNode(self, e: ScalarNode): + raise ValueError() + + def visit_TensorNode(self, e: TensorNode): + raise ValueError() + + def visit_AnyExpr(self, e: AnyExpr): + raise ValueError() + + +def codegen(ir_module: IRModule, src_out_path: Optional[str] = None) -> Optional[str]: + gen = Codegen() + doc = gen(ir_module) + code = str(doc) + if src_out_path is not None: + with open(src_out_path, 'w') as f: + f.write(code) + else: + return code diff --git a/python/hidet/baselines/__init__.py b/python/hidet/baselines/__init__.py new file mode 100644 index 0000000..a83e35d --- /dev/null +++ b/python/hidet/baselines/__init__.py @@ -0,0 +1 @@ +from . import matmul diff --git a/python/hidet/baselines/conv2d.py b/python/hidet/baselines/conv2d.py new file mode 100644 index 0000000..ca7ae6d --- /dev/null +++ b/python/hidet/baselines/conv2d.py @@ -0,0 +1,129 @@ +import numpy as np +from hidet.ir.type import scalar_type +from hidet.ir.dialects.lowlevel import pointer_type +from hidet.ffi import PackedFunc, _LIB + +cudnn_math_mode_dict = { + 'default': 0, + 'tensor_core': 1, + 'tensor_core_allow_conversion': 2, + 'fma': 3 +} +cudnn_algo_dict = { + 'auto': -1, + 'implicit_gemm': 0, + 'implicit_precomp_gemm': 1, + 'gemm': 2, + 'direct': 3, + 'fft': 4, + 'fft_tiling': 5, + 'winograd': 6, + 'winograd_nofused': 7 +} + + +def conv2d_cudnn_available(math_mode: str = 'default', algo: str = 'auto') -> PackedFunc: + return PackedFunc( + param_types=[ + scalar_type('int32'), # 0: batch_size + scalar_type('int32'), # 1: in_channels + scalar_type('int32'), # 2: height + scalar_type('int32'), # 3: width + scalar_type('int32'), # 4: out_channels + scalar_type('int32'), # 5: kernel_h + scalar_type('int32'), # 6: kernel_w + scalar_type('int32'), # 7: padding_h + scalar_type('int32'), # 8: padding_w + scalar_type('int32'), # 9: stride_h + scalar_type('int32'), # 10: stride_w + scalar_type('int32'), # 11: math_mode + scalar_type('int32'), # 12: algo + ], + ret_type=bool, + c_func_pointer=_LIB.Conv2DCudnnAvailable, + default_args={ + 11: cudnn_math_mode_dict[math_mode.lower()], + 12: cudnn_algo_dict[algo.lower()], + } + ) + + +def conv2d_cudnn(math_mode: str = 'default', algo: str = 'auto') -> PackedFunc: + return PackedFunc( + param_types=[ + scalar_type('int32'), # 0: batch_size + scalar_type('int32'), # 1: in_channels + scalar_type('int32'), # 2: height + scalar_type('int32'), # 3: width + scalar_type('int32'), # 4: out_channels + scalar_type('int32'), # 5: kernel_h + scalar_type('int32'), # 6: kernel_w + scalar_type('int32'), # 7: padding_h + scalar_type('int32'), # 8: padding_w + scalar_type('int32'), # 9: stride_h + scalar_type('int32'), # 10: stride_w + scalar_type('int32'), # 11: math_mode + scalar_type('int32'), # 12: algo + pointer_type(scalar_type('float32')), # 13: x + pointer_type(scalar_type('float32')), # 14: w + pointer_type(scalar_type('float32')), # 15: y + ], + c_func_pointer=_LIB.Conv2dCudnn, + default_args={ + 11: cudnn_math_mode_dict[math_mode.lower()], + 12: cudnn_algo_dict[algo.lower()], + } + ) + + +def conv2d_reference() -> PackedFunc: + return PackedFunc( + param_types=[ + scalar_type('int32'), # 0: batch_size + scalar_type('int32'), # 1: in_channels + scalar_type('int32'), # 2: height + scalar_type('int32'), # 3: width + scalar_type('int32'), # 4: out_channels + scalar_type('int32'), # 5: kernel_h + scalar_type('int32'), # 6: kernel_w + scalar_type('int32'), # 7: padding_h + scalar_type('int32'), # 8: padding_w + scalar_type('int32'), # 9: stride_h + scalar_type('int32'), # 10: stride_w + pointer_type(scalar_type('float32')), # 11: x + pointer_type(scalar_type('float32')), # 12: w + pointer_type(scalar_type('float32')), # 13: y + ], + c_func_pointer=_LIB.Conv2dReference + ) + + +def conv2d_implicit_gemm_reference() -> PackedFunc: + return PackedFunc( + param_types=[ + scalar_type('int32'), # 0: batch_size + scalar_type('int32'), # 1: in_channels + scalar_type('int32'), # 2: height + scalar_type('int32'), # 3: width + scalar_type('int32'), # 4: out_channels + scalar_type('int32'), # 5: kernel_h + scalar_type('int32'), # 6: kernel_w + scalar_type('int32'), # 7: padding_h + scalar_type('int32'), # 8: padding_w + scalar_type('int32'), # 9: stride_h + scalar_type('int32'), # 10: stride_w + pointer_type(scalar_type('float32')), # 11: x + pointer_type(scalar_type('float32')), # 12: w + pointer_type(scalar_type('float32')), # 13: y + ], + c_func_pointer=_LIB.Conv2dImplicitGemmReference + ) + + +def conv2d_torch(batch_size, in_channels, height, width, out_channels, kernel_h, kernel_w, padding_h, padding_w, stride_h, stride_w, x: np.ndarray, w: np.ndarray, y: np.ndarray = None): + import torch.nn.functional + y_torch = torch.nn.functional.conv2d(input=torch.from_numpy(x).cuda(), weight=torch.from_numpy(w).cuda(), bias=None, stride=(stride_h, stride_w), padding=(padding_h, padding_w)) + if y is None: + return y_torch.cpu().numpy() + else: + np.copyto(y, y_torch.cpu().numpy()) diff --git a/python/hidet/baselines/matmul.py b/python/hidet/baselines/matmul.py new file mode 100644 index 0000000..95f44e5 --- /dev/null +++ b/python/hidet/baselines/matmul.py @@ -0,0 +1,86 @@ +from hidet.ir.type import scalar_type +from hidet.ir.dialects.lowlevel import pointer_type +from hidet.ffi import PackedFunc, _LIB + + +def matmul_opt() -> PackedFunc: + return PackedFunc( + param_types=[ + scalar_type('int32'), # N + scalar_type('int32'), # M + scalar_type('int32'), # K + pointer_type(scalar_type('float32')), # A + pointer_type(scalar_type('float32')), # B + pointer_type(scalar_type('float32')), # C + ], + c_func_pointer=_LIB.MatmulOpt + ) + + +def matmul_ref() -> PackedFunc: + return PackedFunc( + param_types=[ + scalar_type('int32'), # N + scalar_type('int32'), # M + scalar_type('int32'), # K + pointer_type(scalar_type('float32')), # A + pointer_type(scalar_type('float32')), # B + pointer_type(scalar_type('float32')), # C + ], + c_func_pointer=_LIB.MatmulReference + ) + + +def matmul_ref_1d() -> PackedFunc: + return PackedFunc( + param_types=[ + scalar_type('int32'), # N + scalar_type('int32'), # M + scalar_type('int32'), # K + pointer_type(scalar_type('float32')), # A + pointer_type(scalar_type('float32')), # B + pointer_type(scalar_type('float32')), # C + ], + c_func_pointer=_LIB.MatmulReference1D + ) + + +def matmul_cublas() -> PackedFunc: + return PackedFunc( + param_types=[ + scalar_type('int32'), # N + scalar_type('int32'), # M + scalar_type('int32'), # K + pointer_type(scalar_type('float32')), # A + pointer_type(scalar_type('float32')), # B + pointer_type(scalar_type('float32')), # C + ], + c_func_pointer=_LIB.MatmulCublas + ) + +def matmul_cublas_tensorcore() -> PackedFunc: + return PackedFunc( + param_types=[ + scalar_type('int32'), # N + scalar_type('int32'), # M + scalar_type('int32'), # K + pointer_type(scalar_type('float32')), # A + pointer_type(scalar_type('float32')), # B + pointer_type(scalar_type('float32')), # C + ], + c_func_pointer=_LIB.MatmulCublasTc + ) + + +def matmul_cutlass() -> PackedFunc: + return PackedFunc( + param_types=[ + scalar_type('int32'), # N + scalar_type('int32'), # M + scalar_type('int32'), # K + pointer_type(scalar_type('float32')), # A + pointer_type(scalar_type('float32')), # B + pointer_type(scalar_type('float32')), # C + ], + c_func_pointer=_LIB.MatmulCutlass + ) diff --git a/python/hidet/baselines/pool2d.py b/python/hidet/baselines/pool2d.py new file mode 100644 index 0000000..d9e01c6 --- /dev/null +++ b/python/hidet/baselines/pool2d.py @@ -0,0 +1,47 @@ +import numpy as np +from hidet.ir.type import scalar_type +from hidet.ir.dialects.lowlevel import pointer_type +from hidet.ffi import PackedFunc, _LIB + +cudnn_pooling_mode_dict = { + 'max': 0, + 'avg_include_pad': 1, + 'avg': 2, + 'max_deterministic': 3 +} + + +def pool2d_cudnn(pooling_mode='max') -> PackedFunc: + """ + pooling_mode: 'max', 'avg', 'avg_include_pad', 'max_deterministic' + """ + return PackedFunc( + param_types=[ + scalar_type('int32'), # 0: batch_size + scalar_type('int32'), # 1: in_channels + scalar_type('int32'), # 2: height + scalar_type('int32'), # 3: width + scalar_type('int32'), # 4: kernel_h + scalar_type('int32'), # 5: kernel_w + scalar_type('int32'), # 6: padding_h + scalar_type('int32'), # 7: padding_w + scalar_type('int32'), # 8: stride_h + scalar_type('int32'), # 9: stride_w + scalar_type('int32'), # 10: mode + pointer_type(scalar_type('float32')), # 11: x + pointer_type(scalar_type('float32')), # 12: y + ], + c_func_pointer=_LIB.Pool2dCudnn, + default_args={ + 10: cudnn_pooling_mode_dict[pooling_mode.lower()], + } + ) + + +def max_pool2d_cudnn() -> PackedFunc: + return pool2d_cudnn(pooling_mode='max') + + +def avg_pool2d_cudnn() -> PackedFunc: + return pool2d_cudnn(pooling_mode='avg') + diff --git a/python/hidet/baselines/softmax.py b/python/hidet/baselines/softmax.py new file mode 100644 index 0000000..cff2137 --- /dev/null +++ b/python/hidet/baselines/softmax.py @@ -0,0 +1,18 @@ +import numpy as np +from hidet.ir.type import scalar_type +from hidet.ir.dialects.lowlevel import pointer_type +from hidet.ffi import PackedFunc, _LIB + + +def softmax_cudnn() -> PackedFunc: + return PackedFunc( + param_types=[ + scalar_type('int32'), # n + scalar_type('int32'), # c + scalar_type('int32'), # h + scalar_type('int32'), # w + pointer_type(scalar_type('float32')), # x + pointer_type(scalar_type('float32')), # y + ], + c_func_pointer=_LIB.SoftmaxCudnn + ) diff --git a/python/hidet/driver.py b/python/hidet/driver.py new file mode 100644 index 0000000..f2e39cf --- /dev/null +++ b/python/hidet/driver.py @@ -0,0 +1,109 @@ +from typing import List +import os +import multiprocessing +import logging +from hashlib import sha256 +from hidet.transforms import lower, PassContext, SaveIRInstrument, ProfileInstrument +from hidet.backend import codegen, compile_source, load_task_func, load_lib_func +from hidet.utils import COLORS, hidet_cache_dir +from hidet.utils.py import cyan, green +from hidet.ir.task import Task, TaskContext +from hidet.ir.func import IRModule +from hidet.ir.type import FuncType + +logger = logging.Logger(__name__) +logger.setLevel(logging.INFO) +logger.addHandler(logging.StreamHandler()) + +cache_disabled = False + + +def disable_cache(disable: bool = False): + global cache_disabled + cache_disabled = not disable + + +def build_task(task: Task, space_level, use_cache=True, cache_dir=None, load=True): + # resolve task dir + if cache_dir is None: + cache_dir = os.path.join(hidet_cache_dir(), 'ops') + config_str = 'space_{}'.format(space_level) + task_string = str(task) + task_hash = sha256(task_string.encode()).hexdigest()[:16] + task_dir = os.path.join(cache_dir, config_str, task.name, task_hash) + src_path = os.path.join(task_dir, 'source.cu') + lib_path = os.path.join(task_dir, 'lib.so') + + # use previously generated library when available + if not cache_disabled and use_cache and os.path.exists(lib_path): + logger.debug("Load cached task binary {} from path: \n{}".format(green(task.name), cyan(lib_path))) + if not load: + return None + return load_task_func(lib_path, task) + + logger.info("Compiling task {}{}{}...".format(COLORS.OKGREEN, task.name, COLORS.ENDC)) + # print(task) + # exit(0) + + # build from scratch + os.makedirs(task_dir, exist_ok=True) + # write task + with open(os.path.join(task_dir, 'task.txt'), 'w') as f: + f.write(task_string) + # implement task + with TaskContext(space_level=space_level, resolve_out_dir=task_dir): + ir_module = task.implement(target='cuda') + # lower ir module + with PassContext(instruments=[ + # SaveIRInstrument(out_dir=os.path.join('./outs/ir', task.name, task_hash)), + # ProfileInstrument(log_file=os.path.join('./outs/ir', task.name, task_hash, 'lower_time.txt')) + ]): + ir_module = lower(ir_module) + # code generation + codegen(ir_module, src_out_path=src_path) + # compile source code + compile_source(src_path, out_lib_path=lib_path, keep_ptx=False) + # load function + if not load: + return None + return load_task_func(lib_path, task) + + +def _build_task_job(args): + task, space_level, use_cache, cache_dir, load = args + build_task(task, space_level, use_cache, cache_dir, load) + + +def build_batch_task(tasks: List[Task], space_level: int, parallel=True, use_cache=True, cache_dir=None): + if parallel and len(tasks) > 1: + with multiprocessing.Pool() as pool: + pool.map(_build_task_job, [(task, space_level, use_cache, cache_dir, False) for task in tasks]) + else: + map(_build_task_job, [(task, space_level, use_cache, cache_dir, False) for task in tasks]) + + +def build_ir_module(ir_module: IRModule, func_name: str, keep_ptx=False, working_dir='./outs'): + module_string = str(ir_module) + module_hash = sha256(module_string.encode()).hexdigest()[:16] + working_dir = os.path.join(working_dir, 'ir_module', module_hash) + src_path = os.path.join(working_dir, 'source.cu') + lib_path = os.path.join(working_dir, 'lib.so') + + # lower ir module + with PassContext(instruments=[ + SaveIRInstrument(out_dir=working_dir), + ProfileInstrument(log_file=os.path.join(working_dir, 'lower_time.txt')) + ]): + ir_module = lower(ir_module) + # code generation + codegen(ir_module, src_out_path=src_path) + # compile source code + compile_source(src_path, out_lib_path=lib_path, keep_ptx=keep_ptx) + func = ir_module.lookup(func_name + '_grid') + return load_lib_func(lib_path, func_name, func_type=FuncType.from_func(func)) + + +if __name__ == '__main__': + print(sha256('abc'.encode()).hexdigest()) + print(hex(hash('abc'))) + print(type(hex(1))) diff --git a/python/hidet/ffi/__init__.py b/python/hidet/ffi/__init__.py new file mode 100644 index 0000000..90aa6f1 --- /dev/null +++ b/python/hidet/ffi/__init__.py @@ -0,0 +1,7 @@ +from .ffi import _LIB +from .packedfunc import PackedFunc +from .packedfunc import ArgType + +from .cuda_api import cuda +from .runtime_api import runtime_api +from .cuda_kernels import cuda_kernels diff --git a/python/hidet/ffi/cuda_api.py b/python/hidet/ffi/cuda_api.py new file mode 100644 index 0000000..15b7062 --- /dev/null +++ b/python/hidet/ffi/cuda_api.py @@ -0,0 +1,182 @@ +from typing import Tuple +from ctypes import c_uint64, c_uint32, c_float, c_uint8, byref, POINTER, c_char_p +from hidet.ffi.ffi import get_func + + +class CudaAPI: + # memory related apis + _mem_info = get_func('hidet_cuda_mem_info', [POINTER(c_uint64), POINTER(c_uint64)], None) + _malloc_async = get_func('hidet_cuda_malloc_async', [c_uint64], c_uint64) + _malloc_host = get_func('hidet_cuda_malloc_host', [c_uint64], c_uint64) + _free_async = get_func('hidet_cuda_free_async', [c_uint64], None) + _free_host = get_func('hidet_cuda_free_host', [c_uint64], None) + _memset_async = get_func('hidet_cuda_memset_async', [c_uint64, c_uint64, c_uint8], None) + _memcpy_async = get_func('hidet_cuda_memcpy_async', [c_uint64, c_uint64, c_uint64, c_uint32], None) + _mem_pool_trim_to = get_func('hidet_cuda_mem_pool_trim_to', [c_uint64], None) + # device control + _device_synchronize = get_func('hidet_cuda_device_synchronize', [], None) + # stream and event + _stream_create = get_func('hidet_cuda_stream_create', [], c_uint64) + _stream_destroy = get_func('hidet_cuda_stream_destroy', [c_uint64], None) + _stream_synchronize = get_func('hidet_cuda_stream_synchronize', [c_uint64], None) + _event_create = get_func('hidet_cuda_event_create', [], c_uint64) + _event_destroy = get_func('hidet_cuda_event_destroy', [c_uint64], None) + _event_elapsed_time = get_func('hidet_cuda_event_elapsed_time', [c_uint64, c_uint64], c_float) + _event_record = get_func('hidet_cuda_event_record', [c_uint64, c_uint64], None) + # cuda graph + _graph_create = get_func('hidet_cuda_graph_create', [], c_uint64) + _graph_destroy = get_func('hidet_cuda_graph_destroy', [c_uint64], None) + _stream_begin_capture = get_func('hidet_cuda_stream_begin_capture', [c_uint64], None) + _stream_end_capture = get_func('hidet_cuda_stream_end_capture', [c_uint64], c_uint64) + _graph_instantiate = get_func('hidet_cuda_graph_instantiate', [c_uint64], c_uint64) + _graph_exec_launch = get_func('hidet_cuda_graph_exec_launch', [c_uint64, c_uint64], None) + _graph_exec_destroy = get_func('hidet_cuda_graph_exec_destroy', [c_uint64], None) + # profiler control + _profiler_start = get_func('hidet_cuda_profiler_start', [], None) + _profiler_stop = get_func('hidet_cuda_profiler_stop', [], None) + # random number generation + _generate_uniform = get_func('hidet_curand_generate_uniform', [c_uint64, c_uint64], None) + _generate_normal = get_func('hidet_curand_generate_normal', [c_uint64, c_uint64, c_float, c_float], None) + # get device property + _device_property = get_func('hidet_cuda_get_device_property', [c_uint64, c_char_p], c_uint64) + + @classmethod + def mem_info(cls) -> Tuple[int, int]: + free_bytes = c_uint64(0) + total_bytes = c_uint64(0) + cls._mem_info(byref(free_bytes), byref(total_bytes)) + return free_bytes.value, total_bytes.value + + @classmethod + def malloc_async(cls, num_bytes: int) -> int: + return cls._malloc_async(num_bytes) + + @classmethod + def malloc_host(cls, num_bytes: int) -> int: + return cls._malloc_host(num_bytes) + + @classmethod + def free_async(cls, addr: int) -> None: + return cls._free_async(addr) + + @classmethod + def free_host(cls, addr: int) -> None: + return cls._free_host(addr) + + @classmethod + def memset_async(cls, addr: int, num_bytes: int, value: int) -> None: + return cls._memset_async(addr, num_bytes, value) + + HostToHost = 0 + HostToDevice = 1 + DeviceToHost = 2 + DeviceToDevice = 3 + + @classmethod + def memcpy_async(cls, src_addr: int, dst_addr: int, num_bytes: int, kind: int) -> None: + assert 0 <= kind <= 3 + cls._memcpy_async(src_addr, dst_addr, num_bytes, kind) + if kind != cls.DeviceToDevice: + cls.device_synchronize() + + @classmethod + def mem_pool_trim_to(cls, min_bytes_to_keep: int) -> None: + cls._mem_pool_trim_to(min_bytes_to_keep) + + @classmethod + def device_synchronize(cls) -> None: + return cls._device_synchronize() + + @classmethod + def generate_uniform(cls, addr: int, num_elements: int) -> None: + return cls._generate_uniform(addr, num_elements) + + @classmethod + def generate_normal(cls, addr: int, num_elements: int, mean: float, stddev: float) -> None: + return cls._generate_normal(addr, num_elements, mean, stddev) + + @classmethod + def create_stream(cls) -> int: + return cls._stream_create() + + @classmethod + def destroy_stream(cls, stream_handle: int) -> int: + return cls._stream_destroy(stream_handle) + + @classmethod + def stream_synchronize(cls, stream_handle: int) -> None: + return cls._stream_synchronize(stream_handle) + + @classmethod + def create_event(cls) -> int: + return cls._event_create() + + @classmethod + def destroy_event(cls, event_handle: int) -> None: + return cls._event_destroy(event_handle) + + @classmethod + def event_elapsed_time(cls, start_event_handle: int, end_event_handle: int) -> float: + return cls._event_elapsed_time(start_event_handle, end_event_handle) + + @classmethod + def event_record(cls, event_handle: int, stream_handle: int) -> None: + return cls._event_record(event_handle, stream_handle) + + @staticmethod + def create_graph() -> int: + return CudaAPI._graph_create() + + @staticmethod + def destroy_graph(graph_handle: int) -> None: + return CudaAPI._graph_destroy(graph_handle) + + @staticmethod + def stream_begin_capture(stream_handle: int) -> None: + return CudaAPI._stream_begin_capture(stream_handle) + + @staticmethod + def stream_end_capture(stream_handle: int) -> int: + # return the cuda graph handle captured in this stream + return CudaAPI._stream_end_capture(stream_handle) + + @staticmethod + def instantiate_graph(graph_handle: int) -> int: + return CudaAPI._graph_instantiate(graph_handle) + + @staticmethod + def launch_graph_exec(graph_exec_handle: int, stream_handle: int) -> None: + return CudaAPI._graph_exec_launch(graph_exec_handle, stream_handle) + + @staticmethod + def destroy_graph_exec(graph_exec_handle: int) -> None: + return CudaAPI._graph_exec_destroy(graph_exec_handle) + + @staticmethod + def start_profiler() -> None: + CudaAPI._profiler_start() + CudaAPI.device_synchronize() + + @staticmethod + def stop_profiler() -> None: + CudaAPI._profiler_stop() + CudaAPI.device_synchronize() + + PropertyMultiProcessorCount = 'multiProcessorCount' + PropertyMajor = 'major' + PropertyMinor = 'minor' + + @staticmethod + def device_property(name: str, device_id: int = 0) -> int: + return CudaAPI._device_property(device_id, name.encode('utf-8')) + + @staticmethod + def compute_capability() -> Tuple[int, int]: + return (CudaAPI.device_property(CudaAPI.PropertyMajor), + CudaAPI.device_property(CudaAPI.PropertyMinor)) + + +cuda = CudaAPI() + +if __name__ == '__main__': + print(cuda.device_property(cuda.PropertyMultiProcessorCount)) diff --git a/python/hidet/ffi/cuda_kernels.py b/python/hidet/ffi/cuda_kernels.py new file mode 100644 index 0000000..96afae4 --- /dev/null +++ b/python/hidet/ffi/cuda_kernels.py @@ -0,0 +1,23 @@ +from ctypes import c_uint64, c_uint32, c_float, c_uint8, c_int32, c_int64 +from hidet.ffi.ffi import get_func + + +class CudaKernels: + # kernels + _fill_value_int32 = get_func('hidet_cuda_fill_value_int32', [c_uint64, c_uint64, c_int32], None) + _fill_value_int64 = get_func('hidet_cuda_fill_value_int64', [c_uint64, c_uint64, c_int64], None) + _fill_value_float32 = get_func('hidet_cuda_fill_value_float32', [c_uint64, c_uint64, c_float], None) + + @classmethod + def fill_value(cls, addr: int, num_elements: int, value, dtype: str) -> None: + if dtype == 'float32': + return cls._fill_value_float32(addr, num_elements, value) + elif dtype == 'int64': + return cls._fill_value_int64(addr, num_elements, value) + elif dtype == 'int32': + return cls._fill_value_int32(addr, num_elements, value) + else: + raise NotImplementedError('Currently do not support fill value with dtype "{}"'.format(dtype)) + + +cuda_kernels = CudaKernels() diff --git a/python/hidet/ffi/ffi.py b/python/hidet/ffi/ffi.py new file mode 100644 index 0000000..d79ab90 --- /dev/null +++ b/python/hidet/ffi/ffi.py @@ -0,0 +1,65 @@ +from typing import List, Dict +import os +import os.path +from typing import Optional +import ctypes +from hidet.libinfo import get_library_search_dirs + +_LIB: Optional[ctypes.CDLL] = None + + +library_paths: Dict[str, Optional[str]] = { + 'hidet': None, + 'hidet_runtime': None +} + + +def load_library(): + global _LIB + if _LIB: + return + library_dirs = get_library_search_dirs() + for library_dir in library_dirs: + libhidet_path = os.path.join(library_dir, 'libhidet.so') + libhidet_runtime_path = os.path.join(library_dir, 'libhidet_runtime.so') + if not os.path.exists(libhidet_path) or not os.path.exists(libhidet_runtime_path): + continue + _LIB = ctypes.cdll.LoadLibrary(libhidet_path) + _LIB_RUNTIME = ctypes.cdll.LoadLibrary(libhidet_runtime_path) + library_paths['hidet'] = libhidet_path + library_paths['hidet_runtime'] = libhidet_runtime_path + if _LIB is None: + raise OSError('Can not find library in the following directory: \n' + '\n'.join(library_dirs)) + + +def get_last_error() -> Optional[str]: + func = _LIB['hidet_get_last_error'] + func.restype = ctypes.c_char_p + ret = func() + if isinstance(ret, bytes): + return ret.decode('utf-8') + else: + return None + + +class BackendException(Exception): + pass + + +def get_func(func_name, arg_types: List, restype): + func = _LIB[func_name] + func.argtypes = arg_types + func.restype = restype + + def func_with_check(*args): + ret = func(*args) + status = get_last_error() + if status is not None: + msg = 'Calling {} with arguments {} failed. error:\n{}'.format(func_name, args, status) + raise BackendException(msg) + return ret + + return func_with_check + + +load_library() diff --git a/python/hidet/ffi/packedfunc.py b/python/hidet/ffi/packedfunc.py new file mode 100644 index 0000000..8a99687 --- /dev/null +++ b/python/hidet/ffi/packedfunc.py @@ -0,0 +1,116 @@ +from typing import Dict, Sequence, Union, Type +import ctypes + +from .ffi import _LIB +from ctypes import c_int32, c_void_p, pointer, c_float, cast, c_bool +from ctypes import POINTER, Structure +from hidet.ir.type import TypeNode, ScalarType, TensorType +from hidet.ir.dialects.lowlevel import PointerType, TensorPointerType + +c_int32_p = POINTER(c_int32) +c_float_p = POINTER(c_float) + + +class ArgType: + INT32 = 1 + FLOAT32 = 2 + POINTER = 3 + + +class CPackedFunc(Structure): + _fields_ = [("num_args", c_int32), + ("arg_types", c_int32_p), + ("func_pointer", c_void_p)] + + +class PackedFunc: + def __init__(self, param_types, c_func_pointer, ret_type=None, default_args: Dict[int, object] = None): + self.param_types = param_types + self.ret_type = ret_type + self.c_func_pointer = c_func_pointer + self.default_args = default_args if default_args is not None else {} + + type_codes = [self._type_code(param_type) for param_type in self.param_types] + if self.ret_type: + type_codes.append(self._type_code(self.ret_type)) + n = len(type_codes) + num_args = c_int32(n) + arg_types = cast(pointer((c_int32 * n)(*type_codes)), POINTER(c_int32)) + func_pointer = cast(self.c_func_pointer, c_void_p) + self.c_packed_func = CPackedFunc(num_args, arg_types, func_pointer) + + def _convert_arg(self, param_type, arg: Union[int, float, 'Tensor']): + """ + convert arg to a c_void_p + """ + from hidet.tos.tensor import Tensor + if isinstance(arg, int): + assert isinstance(param_type, ScalarType) + if param_type.name == 'int32': + return cast(pointer(c_int32(arg)), c_void_p) + elif isinstance(arg, float): + if param_type.name == 'float32': + return cast(pointer(c_float(arg)), c_void_p) + elif isinstance(arg, Tensor): + return cast(arg.storage.addr, c_void_p) + raise NotImplementedError("Call PackedFunc with argument type: '{}' has not been implemented yet.".format(type(arg))) + + def _type_code(self, param_type: Union[Type[Union[bool, int, TypeNode]]]): + type_map = { + 'bool': c_int32(1), + 'int32': c_int32(1), + 'float32': c_int32(2), + 'pointer': c_int32(3) + } + if param_type is bool or param_type is int: + type_name = 'int32' + elif isinstance(param_type, ScalarType): + type_name = param_type.name + elif isinstance(param_type, (PointerType, TensorType, TensorPointerType)): + type_name = 'pointer' + else: + raise NotImplementedError(param_type) + return type_map[type_name] + + def _apply_default_args(self, orig_args): + n = len(orig_args) + len(self.default_args) + args = [] + orig_args = list(reversed(orig_args)) + for i in range(n): + if i in self.default_args: + args.append(self.default_args[i]) + else: + args.append(orig_args.pop()) + return args + + def _convert_args(self, args: Sequence): + args = self._apply_default_args(args) + assert len(args) == len(self.param_types) + converted_args = [self._convert_arg(param_type, arg) for param_type, arg in zip(self.param_types, args)] + if self.ret_type is not None: + if self.ret_type is bool: + ret_arg = c_int32() + else: + raise NotImplementedError("Currently do not support return type '{}' in packed function.".format(self.ret_type)) + converted_args.append(cast(pointer(ret_arg), c_void_p)) + else: + ret_arg = None + p_args = cast(pointer((ctypes.c_void_p * len(converted_args))(*converted_args)), c_void_p) + return p_args, ret_arg + + def __call__(self, *args): + p_args, ret_arg = self._convert_args(args) + _LIB.CallPackedFunc(self.c_packed_func, p_args) + if ret_arg is not None: + if issubclass(self.ret_type, bool): + return bool(ret_arg.value) + else: + raise NotImplementedError() + else: + return None + + def profile(self, *args, warmup: int = 1, number: int = 1, repeat: int = 10): + results = (c_float * repeat)() + p_args, ret_arg = self._convert_args(args) + _LIB.ProfilePackedFunc(self.c_packed_func, p_args, warmup, number, repeat, cast(pointer(results), c_float_p)) + return [float(v) / number for v in results] diff --git a/python/hidet/ffi/runtime_api.py b/python/hidet/ffi/runtime_api.py new file mode 100644 index 0000000..2a2e439 --- /dev/null +++ b/python/hidet/ffi/runtime_api.py @@ -0,0 +1,20 @@ +from ctypes import c_void_p +from .ffi import get_func + + +class RuntimeAPI: + _set_current_stream = get_func('set_cuda_stream', [c_void_p], None) + _get_current_stream = get_func('get_cuda_stream', [], c_void_p) + + @staticmethod + def set_current_stream(stream_handle: int) -> None: + RuntimeAPI._set_current_stream(c_void_p(stream_handle)) + + @staticmethod + def get_current_stream() -> int: + p = RuntimeAPI._get_current_stream() + return p.value + + +runtime_api = RuntimeAPI() + diff --git a/python/hidet/ir/__init__.py b/python/hidet/ir/__init__.py new file mode 100644 index 0000000..150efc7 --- /dev/null +++ b/python/hidet/ir/__init__.py @@ -0,0 +1,33 @@ +from . import type +from . import expr +from . import stmt +from . import func +from . import functors +from . import builders +from . import primitives +from . import layout +from . import task + +from .func import IRModule, Function +from .type import TypeNode, TensorType, ScalarType, FuncType +from .type import scalar_type, tensor_type + +from .expr import Expr, Var, Constant +from .expr import BinaryOp, Condition, LessThan, Equal, Add, Sub, Multiply, Div, Mod, FloorDiv, Let, Cast +from .expr import var, scalar_var, tensor_var, is_one, is_zero, convert + +from .stmt import Stmt, EvaluateStmt, BufferStoreStmt, AssignStmt, ForStmt, IfStmt, AssertStmt, SeqStmt, LetStmt + +from .dialects.compute import TensorNode, ScalarNode +from .dialects.lowlevel import VoidType, PointerType, Dereference + +from .builders import FunctionBuilder, StmtBuilder + +from .task import Task, save_task, load_task + + +# primitives +from .primitives import max, min, exp, pow + +# utils +from .utils import index_serialize, index_deserialize diff --git a/python/hidet/ir/analyzers/__init__.py b/python/hidet/ir/analyzers/__init__.py new file mode 100644 index 0000000..96213c2 --- /dev/null +++ b/python/hidet/ir/analyzers/__init__.py @@ -0,0 +1,3 @@ +from . import bound_analyzer + +from .bound_analyzer import BoundInfo, BoundAnalyzer, infer_bound diff --git a/python/hidet/ir/analyzers/bound_analyzer.py b/python/hidet/ir/analyzers/bound_analyzer.py new file mode 100644 index 0000000..a74c57d --- /dev/null +++ b/python/hidet/ir/analyzers/bound_analyzer.py @@ -0,0 +1,256 @@ +from typing import Optional, List, Set, Dict, Union, Mapping, Sequence, Tuple +import itertools +import operator +from collections import defaultdict + +from hidet.ir.expr import Expr, Var, Add, Sub, Multiply, FloorDiv, Mod, Constant, Div +from hidet.ir.func import Function +from hidet.ir.functors import FuncStmtExprVisitor +from hidet.ir.stmt import Stmt, ForStmt, LetStmt + + +# from hidet.ir.task import Grid, ThreadBlock, Warp, Thread, Host + + +class BoundInfo: + _max_num_candidates = 1024 + _max_compute_iters = 128 * 128 + + def __init__(self, value=None, candidates=None, min_value=None, max_value=None): + # three level of bound information: + # 1. know its value + # 2. know its candidates + # 3. know it min_value and/or max_value + # the specific one will hide the loose one, e.g., candidates are ignored when value is present. + self.value: Optional[int] = None + self.candidates: Optional[Set[int]] = None + self.min_value: Optional[int] = None + self.max_value: Optional[int] = None + if value is not None: + self.value = value + elif candidates: + if len(candidates) == 1: + self.value = candidates[0] + else: + self.candidates = candidates + elif min_value is not None and max_value is not None: + if min_value == max_value: + self.value = min_value + elif max_value - min_value <= BoundInfo._max_num_candidates: + self.candidates = set(range(min_value, max_value + 1)) + else: + self.min_value = min_value + self.max_value = max_value + else: + self.min_value = min_value + self.max_value = max_value + + @staticmethod + def combine(lhs: 'BoundInfo', rhs: 'BoundInfo', op) -> 'BoundInfo': + if not lhs.has_determent_range() or not rhs.has_determent_range(): + return BoundInfo() + if lhs.is_empty_set() or rhs.is_empty_set(): + return BoundInfo(candidates={}) + + lhs_candidates = lhs.candidate_set() + rhs_candidates = rhs.candidate_set() + if lhs_candidates and rhs_candidates and len(lhs_candidates) * len(rhs_candidates) <= BoundInfo._max_compute_iters: + candidates = set() + for lv in lhs_candidates: + for rv in rhs_candidates: + candidates.add(op(lv, rv)) + if len(candidates) == 1: + return BoundInfo(value=candidates.pop()) + else: + return BoundInfo(candidates=candidates) + else: + # fall back to use min/max value of lhs_candidates/rhs_candidates to infer + lhs_candidates = [lhs.possible_min_value(), lhs.possible_max_value()] + rhs_candidates = [rhs.possible_min_value(), rhs.possible_max_value()] + if op in [operator.add, operator.sub, operator.mul]: + candidates = [op(a, b) for a, b in itertools.product([min(lhs_candidates), max(lhs_candidates)], + [min(rhs_candidates), max(rhs_candidates)])] + return BoundInfo(min_value=min(candidates), max_value=max(candidates)) + elif op is operator.floordiv: + if all(v > 0 for v in rhs_candidates): + return BoundInfo(min_value=min(lhs_candidates) // max(rhs_candidates), + max_value=max(lhs_candidates) // min(rhs_candidates)) + else: + return BoundInfo() + elif op is operator.mod: + if rhs.possible_max_value() is not None: + return BoundInfo(min_value=0, max_value=rhs.possible_max_value() - 1) + else: + return BoundInfo() + else: + raise NotImplementedError() + + def candidate_set(self): + if self.value is not None: + return [self.value] + elif self.candidates is not None: + return self.candidates + else: + return None + + def is_empty_set(self): + return self.candidates is not None and len(self.candidates) == 0 + + def has_determent_range(self) -> bool: + if self.value is not None: + return True + if self.candidates is not None: + return True + if self.min_value is not None and self.max_value is not None: + return True + return False + + def possible_max_value(self): + if self.value is not None: + return self.value + elif self.candidates: + return max(self.candidates) + elif self.max_value is not None: + return self.max_value + else: + return None + + def possible_min_value(self): + if self.value is not None: + return self.value + elif self.candidates: + return min(self.candidates) + elif self.min_value is not None: + return self.min_value + else: + return None + + def is_one(self): + return self.value == 1 + + def is_zero(self): + return self.value == 0 + + def __add__(self, other): + return self.combine(self, other, operator.add) + + def __sub__(self, other): + return self.combine(self, other, operator.sub) + + def __mul__(self, other): + return self.combine(self, other, operator.mul) + + def __floordiv__(self, other): + return self.combine(self, other, operator.floordiv) + + def __mod__(self, other): + return self.combine(self, other, operator.mod) + + def __lt__(self, other): + lhs_max = self.possible_max_value() + rhs_min = other.possible_min_value() + return lhs_max is not None and rhs_min is not None and lhs_max < rhs_min + + def __le__(self, other): + lhs_max = self.possible_max_value() + rhs_min = other.possible_min_value() + return lhs_max is not None and rhs_min is not None and lhs_max <= rhs_min + + def __str__(self): + if self.value is not None: + return str(self.value) + elif self.candidates is not None: + return str(len(self.candidates)) + ': ' + str(self.candidates) + elif self.min_value or self.max_value: + return f'[{self.min_value}:{self.max_value}]' + else: + return 'Any' + + +def normalize_launch_dims(dims: Union[int, Sequence[int]]) -> Sequence[int]: + if isinstance(dims, int): + return [dims, dims, dims] + else: + dims = list(dims) + while len(dims) < 3: + dims = dims + [1] + return dims + + +class BoundAnalyzer(FuncStmtExprVisitor): + # we only infer bound based on variables from LetStmt and ForStmt, and the constants. + # so the local variable with AssignStmt is not used infer bound. + op_dict = { + Add: operator.add, + Sub: operator.sub, + Multiply: operator.mul, + FloorDiv: operator.floordiv, + Mod: operator.mod, + Div: operator.floordiv, # for the node with BoundInfo, we are sure they are integers + } + + def __init__(self, var2bound: Dict[Expr, BoundInfo] = None): + # please give the bound of external variable such as threadIdx.x using var2bound parameter + super().__init__() + self.bound: Dict[Expr, BoundInfo] = defaultdict(BoundInfo) + if var2bound: + self.bound.update(var2bound) + + def visit_Function(self, func: Function): + # note: we use the vars in func.extern_vars instead of hidet.ir.primitives.thread_idx for multiprocessing + extern_var_map = {var.name: var for var in func.extern_vars} + if func.kind == 'cuda_kernel': + block_dims = normalize_launch_dims(func.attrs['cuda_block_dim']) + grid_dims = normalize_launch_dims(func.attrs['cuda_grid_dim']) + for block_dim, suffix in zip(block_dims, ['x', 'y', 'z']): + self.bound[extern_var_map['threadIdx.{}'.format(suffix)]] = BoundInfo(min_value=0, max_value=int(block_dim) - 1) + for grid_dim, suffix in zip(grid_dims, ['x', 'y', 'z']): + self.bound[extern_var_map['blockIdx.{}'.format(suffix)]] = BoundInfo(min_value=0, max_value=int(grid_dim) - 1) + self.visit(func.body) + + def combine(self, e: Union[Add, Sub, Multiply, FloorDiv, Mod, Div]): + self.visit(e.a) + self.visit(e.b) + self.bound[e] = BoundAnalyzer.op_dict[e.__class__](self.bound[e.a], self.bound[e.b]) + + def visit_Add(self, e: Add): + self.combine(e) + + def visit_Sub(self, e: Sub): + self.combine(e) + + def visit_Multiply(self, e: Multiply): + self.combine(e) + + def visit_Div(self, e: Div): + self.combine(e) + + def visit_FloorDiv(self, e: FloorDiv): + self.combine(e) + + def visit_Mod(self, e: Mod): + self.combine(e) + + def visit_LetStmt(self, stmt: LetStmt): + for bind_var, bind_value in zip(stmt.bind_vars, stmt.bind_values): + self.visit(bind_value) + self.bound[bind_var] = self.bound[bind_value] + self.visit(stmt.body) + + def visit_ForStmt(self, stmt: ForStmt): + self.visit(stmt.extent) + max_val = self.bound[stmt.extent].possible_max_value() + if max_val is not None: + max_val -= 1 + self.bound[stmt.loop_var] = BoundInfo(min_value=0, max_value=max_val) + self.visit(stmt.body) + + def visit_Constant(self, e: Constant): + if e.is_scalar() and e.data_type.name == 'int32': + self.bound[e] = BoundInfo(value=e.value) + + +def infer_bound(node: Union[Function, Stmt, Expr], var2bound: Optional[Mapping[Var, BoundInfo]] = None) -> Dict[Expr, BoundInfo]: + visitor = BoundAnalyzer(var2bound) + visitor.visit(node) + return visitor.bound diff --git a/python/hidet/ir/builders/__init__.py b/python/hidet/ir/builders/__init__.py new file mode 100644 index 0000000..c907351 --- /dev/null +++ b/python/hidet/ir/builders/__init__.py @@ -0,0 +1,5 @@ +from . import func_builder +from . import stmt_builder + +from .func_builder import FunctionBuilder +from .stmt_builder import StmtBuilder diff --git a/python/hidet/ir/builders/func_builder.py b/python/hidet/ir/builders/func_builder.py new file mode 100644 index 0000000..0b62593 --- /dev/null +++ b/python/hidet/ir/builders/func_builder.py @@ -0,0 +1,72 @@ +from typing import List, Dict, Optional + +from hidet.ir.dialects.lowlevel import VoidType +from hidet.ir.expr import Var +from hidet.ir.func import Function +from hidet.ir.stmt import Stmt + +from .stmt_builder import StmtBuilder + + +class FunctionBuilder(StmtBuilder): + def __init__(self, name: str, kind: str, label: str = "", ret_type=VoidType(), grid_dim=None, block_dim=None, dynamic_smem_bytes=None, min_blocks=None, attrs=None): + super().__init__() + self.name = name + self.kind = kind + self.params: List[Var] = [] + self.ret_type = ret_type + self.local_vars = [] + self.func: Optional[Function] = None + self.body: Optional[Stmt] = None + self.extern_vars = [] + self.attrs: Dict[str] = attrs if attrs else {} + self.label = label + + if grid_dim: + self.attrs['cuda_grid_dim'] = grid_dim + if block_dim: + self.attrs['cuda_block_dim'] = block_dim + if dynamic_smem_bytes: + self.attrs['cuda_dynamic_smem_bytes'] = dynamic_smem_bytes + if min_blocks: + self.attrs['cuda_min_blocks'] = min_blocks + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + if exc_type is None: + self.finish_func() + + def extend_params(self, params: List[Var]): + self.params.extend(params) + + def extend_extern_vars(self, extern_vars: List[Var]): + self.extern_vars.extend(extern_vars) + + def extend_local_vars(self, local_vars: List[Var]): + assert isinstance(local_vars, (tuple, list)) + self.local_vars.extend(local_vars) + + def extend_attrs(self, new_attrs: Dict[str, object]): + self.attrs.update(new_attrs) + + def set_body(self, body: Stmt): + self.body = body + + def finish_func(self): + from hidet.ir.primitives.cuda.vars import block_idx, thread_idx + assert self.func is None + if 'label' not in self.attrs: + self.attrs['label'] = self.label + if self.kind in ['cuda_kernel', 'cuda_device']: + self.extend_extern_vars([block_idx(dim) for dim in ['x', 'y', 'z']]) + self.extend_extern_vars([thread_idx(dim) for dim in ['x', 'y', 'z']]) + if self.body is None: + self.body = self.finish() + self.func = Function(self.name, kind=self.kind, params=self.params, body=self.body, ret_type=self.ret_type, local_vars=self.local_vars, + local_const_vars=[], extern_vars=self.extern_vars, attrs=self.attrs) + + def get(self) -> Function: + assert self.func.body is not None + return self.func diff --git a/python/hidet/ir/builders/stmt_builder.py b/python/hidet/ir/builders/stmt_builder.py new file mode 100644 index 0000000..7e42ec9 --- /dev/null +++ b/python/hidet/ir/builders/stmt_builder.py @@ -0,0 +1,101 @@ +from typing import Union, Optional, Sequence, List + +from hidet.ir.stmt import Stmt, ForStmt, IfStmt, EvaluateStmt, SeqStmt, LetStmt +from hidet.ir.type import TypeNode, scalar_type, ScalarType +from hidet.ir.expr import Expr, Var, var, convert +from hidet.ir.layout import TaskLayout, TaskLayoutExpander + +ScopedStmt = Union[IfStmt, ForStmt, LetStmt] + + +class StmtScope: + def __init__(self, sb: 'StmtBuilder', stmts: Union[Sequence[ScopedStmt], ScopedStmt], ret=None): + if isinstance(stmts, (IfStmt, ForStmt, LetStmt)): + stmts = [stmts] + self.sb = sb + self.stmts = stmts + self.ret = ret + + def __enter__(self): + for stmt in self.stmts: + self.sb.enter_body(stmt) + return self.ret + + def __exit__(self, exc_type, exc_val, exc_tb): + for _ in self.stmts: + self.sb.exit_body() + + +class StmtBuilder: + def __init__(self): + self.scope_stack = [[]] + + def __iadd__(self, other: Union[Stmt, Expr]): + assert isinstance(other, (Stmt, Expr)) + self.append(other) + return self + + def let(self, v: Union[str, Var], value: Union[int, Expr]) -> StmtScope: + if isinstance(v, str): + v = var(v) + return StmtScope(self, stmts=LetStmt(v, value), ret=v) + + def lets(self, bind_vars: Sequence[Union[str, Var]], values: Sequence[Union[int, Expr]]) -> StmtScope: + assert len(bind_vars) == len(values) + bind_vars = [var(v) if isinstance(v, str) else v for v in bind_vars] + bind_values = [convert(value) for value in values] + seq_let_stmt = LetStmt(bind_vars, bind_values, body=1) + return StmtScope(self, stmts=seq_let_stmt, ret=bind_vars) + + def for_loop(self, v: Union[str, Var], extent: Union[int, Expr], unroll: Optional[bool] = None) -> StmtScope: + if isinstance(v, str): + v = var(v) + return StmtScope(self, stmts=ForStmt(v, extent, unroll), ret=v) + + def if_then(self, cond: Union[bool, Expr]) -> StmtScope: + return StmtScope(self, stmts=[IfStmt(cond)], ret=None) + + def otherwise(self) -> StmtScope: + assert len(self.scope_stack[-1]) > 0 + if_stmt = self.scope_stack[-1].pop() + assert isinstance(if_stmt, IfStmt) + assert if_stmt.then_body is not None + assert if_stmt.else_body is None + return StmtScope(self, stmts=if_stmt, ret=None) + + def for_task(self, worker_index: Expr, task_layout: TaskLayout): + expander = TaskLayoutExpander() + fields = expander.expand(worker_index, task_layout) + return StmtScope(self, stmts=expander.stmts, ret=fields) + + def append(self, stmt: Union[Stmt, Expr]): + if stmt is None: + return + if not isinstance(stmt, Stmt): + assert isinstance(stmt, Expr) + stmt = EvaluateStmt(stmt) + self.scope_stack[-1].append(stmt) + + def enter_body(self, stmt: Union[IfStmt, ForStmt, LetStmt]): + self.scope_stack[-1].append(stmt) + self.scope_stack.append([]) + + def exit_body(self): + body = SeqStmt(self.scope_stack.pop()) + assert len(self.scope_stack) > 0 + last_stmt = self.scope_stack[-1][-1] + if isinstance(last_stmt, (ForStmt, LetStmt)): + assert last_stmt.body is None or last_stmt.body == 1 + last_stmt.body = body + elif isinstance(last_stmt, IfStmt): + if last_stmt.then_body is None: + last_stmt.then_body = body + else: + assert last_stmt.else_body is None + last_stmt.else_body = body + else: + assert False + + def finish(self): + assert len(self.scope_stack) == 1 + return SeqStmt(self.scope_stack[0]) diff --git a/python/hidet/ir/dialects/__init__.py b/python/hidet/ir/dialects/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/python/hidet/ir/dialects/compute.py b/python/hidet/ir/dialects/compute.py new file mode 100644 index 0000000..31e7689 --- /dev/null +++ b/python/hidet/ir/dialects/compute.py @@ -0,0 +1,135 @@ +from typing import Union, Sequence, Tuple, Optional, List, Dict, Any +from hidet.ir.type import ScalarType, TensorType, Scope, tensor_type, scalar_type +from hidet.ir.expr import Expr, Constant, convert, Var, var, And, if_then_else +from hidet.utils.info import float_type_min_value + + +class ComputeNode(Expr): + def __init__(self, name): + self.name: Optional[str] = name + + +class ScalarNode(ComputeNode): + def __init__(self, name, data_type, reduce_compute=None): + super().__init__(name) + self.data_type: ScalarType = data_type + self.reduce_compute: Optional[ReduceCompute] = reduce_compute + + def is_input(self) -> bool: + return self.reduce_compute is None + + +class TensorNode(ComputeNode): + def __init__(self, name, data_type, grid_compute=None): + super().__init__(name) + self.data_type: TensorType = data_type + self.grid_compute: Optional[GridCompute] = grid_compute + + def is_input(self) -> bool: + return self.grid_compute is None + + def const_shape(self) -> List[int]: + return self.data_type.const_shape() + + def protect_read(self, indices, default_value=0.0) -> Expr: + conds = [] + assert len(indices) == len(self.data_type.shape) + for index, extent in zip(indices, self.data_type.shape): + conds.append(0 <= index) + conds.append(index < extent) + return if_then_else(And.join(*conds), self.__getitem__(indices), default_value) + + +class GridCompute: + def __init__(self, shape, axes, value): + from hidet.ir.functors import collect, simplify + self.input_tensors: List[TensorNode] = collect(value, TensorNode, stop_when_found=True) + self.input_scalars: List[ScalarNode] = collect(value, ScalarNode, stop_when_found=True) + self.shape: Tuple[Expr] = convert(shape) + self.axes: Tuple[Var] = convert(axes) + self.value: Expr = simplify(value) + + +class ReduceCompute: + def __init__(self, shape, axes, value, reduce_type, accumulate_dtype: str = 'float32'): + from hidet.ir.functors import collect, simplify + self.input_tensors: List[TensorNode] = collect(value, TensorNode, stop_when_found=True) + self.input_scalars: List[ScalarNode] = collect(value, ScalarNode, stop_when_found=True) + self.shape: Tuple[Expr] = convert(shape) + self.axes: Tuple[Var] = convert(axes) + self.value: Expr = simplify(value) + self.reduce_type: str = reduce_type + self.accumulate_dtype = accumulate_dtype + assert reduce_type in ['max', 'avg', 'sum'] + + @staticmethod + def init_const(reduce_type: str, data_type: Union[ScalarType, str]): + init_dict = { + 'sum': Constant(0.0, data_type), + 'avg': Constant(0.0, data_type), + 'max': Constant(float_type_min_value(), data_type) + } + return init_dict[reduce_type] + + @staticmethod + def combine(reduce_type: str, lhs: Expr, rhs: Expr): + from hidet.ir import primitives + func_dict = { + 'sum': lambda a, b: a + b, + 'avg': lambda a, b: a + b, + 'max': lambda a, b: primitives.max(a, b) + } + return func_dict[reduce_type](lhs, rhs) + + @staticmethod + def finalize(reduce_type: str, acc: Expr, size: Expr): + func_dict = { + 'sum': lambda acc, size: acc, + 'avg': lambda acc, size: acc / size, + 'max': lambda acc, size: acc + } + return func_dict[reduce_type](acc, size) + + def const_shape(self) -> List[int]: + return [int(v) for v in self.shape] + + +def scalar_input(name, dtype): + if isinstance(dtype, str): + dtype = ScalarType(dtype) + else: + assert isinstance(dtype, ScalarType) + return ScalarNode(name, dtype, reduce_compute=None) + + +def tensor_input(name, base_type, shape, scope=None, layout=None): + data_type = tensor_type(scope, base_type, shape, layout) + return TensorNode(name, data_type, grid_compute=None) + + +def reduce(shape: Sequence[Union[int, Expr]], fcompute, reduce_type: str, accumulate_dtype: str = 'float32') -> ScalarNode: + from hidet.ir.functors import infer_type + shape = convert(shape) + axes = [var() for _ in shape] + value = convert(fcompute(*axes)) + return ScalarNode( + name=f'acc_{reduce_type}', + data_type=infer_type(value), + reduce_compute=ReduceCompute(shape, axes, value, reduce_type, accumulate_dtype) + ) + + +def compute(name, shape, fcompute, scope=None, layout=None) -> TensorNode: + from hidet.ir.functors import infer_type + shape = convert(shape) + axes = [var() for _ in shape] + value = convert(fcompute(*axes)) + if scope is None: + # todo: automatic determine scope by checking inputs' scope + scope = 'global' + return TensorNode( + name=name, + data_type=tensor_type(scope, dtype=infer_type(value), shape=shape, layout=layout), + grid_compute=GridCompute(shape, axes, value) + ) + diff --git a/python/hidet/ir/dialects/lowlevel.py b/python/hidet/ir/dialects/lowlevel.py new file mode 100644 index 0000000..39b8d0b --- /dev/null +++ b/python/hidet/ir/dialects/lowlevel.py @@ -0,0 +1,62 @@ +from typing import Optional, Union, Sequence +from hidet.ir.type import TypeNode, ScalarType, TensorType, Scope, Int, tensor_type +from hidet.ir.expr import Expr, TensorElement, Var, Constant +from hidet.ir.layout import DataLayout + + +class VoidType(TypeNode): + pass + + +class PointerType(TypeNode): + def __init__(self, base_type, specifiers: Optional[Sequence[str]] = None, use_bracket: bool = False): + super().__init__() + self.base_type = base_type + self.specifiers = list(specifiers) if specifiers else [] + self.use_bracket = use_bracket + + +class ReferenceType(TypeNode): + def __init__(self, base_type): + super().__init__() + self.base_type = base_type + + +class TensorPointerType(TypeNode): + def __init__(self, + scope: Optional[Union[Scope, str]] = None, + dtype: Optional[Union[ScalarType, str]] = None, + shape: Optional[Sequence[Int]] = None, + layout: Optional[Union[Sequence[Int], DataLayout]] = None): + self.tensor_type: TensorType = tensor_type(scope, dtype, shape, layout) + + +# +# Moved to hidet.ir.expr +# +# class Cast(Expr): +# def __init__(self, expr, target_type): +# self.expr = expr +# if isinstance(target_type, str): +# target_type = ScalarType(target_type) +# self.target_type = target_type + + +class Dereference(Expr): + def __init__(self, expr): + self.expr = expr + + +class Address(Expr): + def __init__(self, expr): + self.expr = expr + + +class Reference(Expr): + def __init__(self, expr): + assert isinstance(expr, (TensorElement, Var)), "only l-value can be referenced." + self.expr = expr + + +def pointer_type(base_type): + return PointerType(base_type) diff --git a/python/hidet/ir/dialects/pattern.py b/python/hidet/ir/dialects/pattern.py new file mode 100644 index 0000000..afa1261 --- /dev/null +++ b/python/hidet/ir/dialects/pattern.py @@ -0,0 +1,560 @@ +import contextlib +import traceback +from typing import Type, Tuple, Any, ContextManager, Callable +from contextlib import ExitStack +from hidet.ir.type import * +from hidet.ir.expr import * +from hidet.ir.dialects.compute import * +from hidet.ir.dialects.lowlevel import * +from hidet.ir.task import * +from hidet.ir.layout import StridesLayout + + +class PatternNode(Node): + # A pattern can match a series of exprs/types/other node objects + pass + + +class StringPattern(PatternNode): + pass + + +class TypePattern(TypeNode, PatternNode): + pass + + +class ScalarTypePattern(TypePattern): + def __init__(self, allowed_types=None): + self.allowed_types: Optional[List[str]] = allowed_types + + +class TensorTypePattern(TypePattern): + def __init__(self, scope=None, scalar_type=None, rank=None, shape=None, layout=None, allow_dynamic_size=False): + self.rank: Optional[int] = rank + self.scope: Optional[Union[Scope, List[Scope]]] = scope + self.scalar_type: Optional[Union[ScalarType, ScalarTypePattern]] = scalar_type + self.shape: Optional[List[Expr]] = shape + self.layout: Optional[DataLayout] = layout + self.allow_dynamic_size = allow_dynamic_size + + +class ExprPattern(Expr, PatternNode): + pass + + +class AnyExpr(ExprPattern): + def __init__(self, cls=None, exclude_cls=None): + self.cls: Optional[Type[Expr]] = cls + self.exclude_cls: Optional[Type[Expr]] = exclude_cls + + +class UnionPattern(PatternNode): + def __init__(self, patterns): + self.patterns: List[Node] = patterns + + +class OptionalPattern(PatternNode): + def __init__(self, pattern): + self.pattern = pattern + + +# class ReduceComputePattern(ExprPattern): +# def __init__(self, allow_dynamic_axis=True): +# self.allow_dynamic_axis = allow_dynamic_axis +# +# +# class TensorComputePattern(ExprPattern): +# def __init__(self, rank=None, allow_dynamic_axis=True, value=None): +# self.rank = rank +# self.allow_dynamic_axis = allow_dynamic_axis +# self.value = value +# +# +# class ScalarExprPattern(ExprPattern): +# def __init__(self, base_pattern=None, exclude_vars=()): +# self.base_pattern: Optional[Expr] = base_pattern +# self.exclude_vars: Sequence[Union[Var, TensorInput, ScalarInput]] = exclude_vars +# +# +# class TaskPattern(PatternNode): +# def __init__(self, compute_pattern=None, required_params=None, allow_extra_params=True, allow_tensor_extra_params=True, worker=None): +# self.compute_pattern: Optional[Expr] = compute_pattern +# self.required_params: Optional[List[ComputeNode]] = required_params +# self.extra_params: Expr = Expr() # as a handle to reference the unmatched params +# self.allow_extra_params: bool = allow_extra_params +# self.allow_tensor_extra_params: bool = allow_tensor_extra_params +# self.worker: Optional[Worker] = worker +# +# +# class TensorAccessPattern(ExprPattern): +# def __init__(self, base_pattern=None, indices_pattern=None): +# self.base_pattern: Optional[Expr] = base_pattern +# self.indices_pattern: Optional[List[Expr]] = indices_pattern + + +class NotMatchedError(Exception): + def __init__(self, pattern, target, message=""): + super().__init__(message) + self.pattern = pattern + self.target = target + + +Matchable = Optional[Union[Node, tuple]] + + +class MatchContext: + def __init__(self, matcher: 'PatternMatcher', pattern: Node, target: Node): + self.matcher = matcher + self.matched = matcher.matched + self.dispatch = matcher.dispatch_table() + self.pattern: Matchable = pattern + self.target: Matchable = target + + def __enter__(self): + if self.pattern is None: + # None in pattern matches anything + return + if self.target is None: + if isinstance(self.pattern, OptionalPattern): + self.matched[self.pattern] = None + return + else: + raise NotMatchedError(self.pattern, self.target, 'Expect non-None target') + assert not isinstance(self.target, list), self.target + if self.pattern in self.matched: + if self.matched[self.pattern] is not self.target: + # we think the constant with the same value as the same object + lhs, rhs = self.matched[self.pattern], self.target + if isinstance(lhs, Constant) and isinstance(rhs, Constant): + if lhs.value == rhs.value: + return + if isinstance(lhs, tuple) and isinstance(rhs, tuple): + # something like (None, None) with the same hash + if lhs is not rhs: + return + raise NotMatchedError(self.pattern, self.target, 'Can not match a pattern to two different targets') + else: + return + + if not isinstance(self.pattern, PatternNode): + # put all pattern class that allow to accept other classes + PatternMatcher.check_type(self.pattern, self.target) + try: + self.matched[self.pattern] = self.target + # noinspection PyArgumentList + self.dispatch[self.pattern.__class__](self.matcher, self.pattern, self.target) + except NotMatchedError as e: + # error from current + del self.matched[self.pattern] + raise e + + def __exit__(self, exc_type, exc_val, exc_tb): + if exc_type == NotMatchedError: + # error from inner + # delete the matched target for pattern + # do not return True, propagate the exception: + # 1. it can be caught by pattern like UnionPattern to try other target, + # 2. or, it will be caught by the of PatternMatcher.__call__, indicating failure of matching. + if self.pattern is not None: + del self.matched[self.pattern] + + +class PatternMatcher: + """ + invariant: every time when we enter match(...) + 0 the self.matched[...] stored the matched patterns and targets, and ongoing matching must follow them + 1 if successful, all sub-expressions of pattern have been set in self.matched[...] + 2 if failed, it acted like we have not call this function (we treat self.matched[v] = None and v not in self.matched as the same state) + """ + _dispatch_table: Dict[Type[Node], Callable[[Node, Node], None]] = None + + @staticmethod + def dispatch_table(): + if PatternMatcher._dispatch_table is None: + PatternMatcher._dispatch_table = { + # string + StringPattern: PatternMatcher.match_StringPattern, + # expr + Add: PatternMatcher.match_CommutativeBinary, + Sub: PatternMatcher.match_Binary, + Multiply: PatternMatcher.match_CommutativeBinary, + Div: PatternMatcher.match_Binary, + Mod: PatternMatcher.match_Binary, + FloorDiv: PatternMatcher.match_Binary, + LessThan: PatternMatcher.match_Binary, + Equal: PatternMatcher.match_Binary, + LessEqual: PatternMatcher.match_Binary, + TensorElement: PatternMatcher.match_TensorElement, + IfThenElse: PatternMatcher.match_IfThenElse, + Call: PatternMatcher.match_Call, + Var: PatternMatcher.match_Var, + Constant: PatternMatcher.match_Constant, + And: PatternMatcher.match_CommutativeBinary, + Or: PatternMatcher.match_CommutativeBinary, + # compute dialect expr + # ScalarInput: PatternMatcher.match_ScalarInput, + # TensorInput: PatternMatcher.match_TensorInput, + # TensorCompute: PatternMatcher.match_TensorCompute, + # ReduceCompute: PatternMatcher.match_ReduceCompute, + # CustomCompute: PatternMatcher.match_CustomCompute, + # type + ScalarType: PatternMatcher.match_ScalarType, + TensorType: PatternMatcher.match_TensorType, + # scope + Scope: PatternMatcher.match_Scope, + # layout + # TaskLayout: PatternMatcher.always_match, + DataLayout: PatternMatcher.match_DataLayout, + StridesLayout: PatternMatcher.match_StridesLayout, + # worker + # Host: PatternMatcher.always_match, + # Grid: PatternMatcher.match_Grid, + # ThreadBlock: PatternMatcher.match_ThreadBlock, + # Warp: PatternMatcher.match_Warp, + # Thread: PatternMatcher.always_match, + # patterns + # TaskPattern: PatternMatcher.match_TaskPattern, + AnyExpr: PatternMatcher.match_AnyPattern, + # ReduceComputePattern: PatternMatcher.match_ReduceComputePattern, + # TensorComputePattern: PatternMatcher.match_TensorComputePattern, + # ScalarExprPattern: PatternMatcher.match_ScalarExprPattern, + UnionPattern: PatternMatcher.match_UnionPattern, + OptionalPattern: PatternMatcher.match_OptionalPattern, + ScalarTypePattern: PatternMatcher.match_ScalarTypePattern, + TensorTypePattern: PatternMatcher.match_TensorTypePattern, + # TensorAccessPattern: PatternMatcher.match_TensorAccessPattern, + # python containers and types + str: PatternMatcher.match_String, + list: PatternMatcher.match_Sequence, + tuple: PatternMatcher.match_Sequence + } + return PatternMatcher._dispatch_table + + def __init__(self): + self.matched: Dict[Matchable, Optional[Any]] = {} + + def __call__(self, pattern, target): + self.matched.clear() + try: + with self.match(pattern, target): + pass + return self.matched, "Matched" + except NotMatchedError as e: + return None, str(e) + # return None, str(traceback.format_exc()) + + def match(self, pattern: Optional[Union[Node, Sequence]], target: Optional[Union[Node, Sequence]]) -> ContextManager: + return MatchContext(self, pattern, target) + + @staticmethod + def check_type(pattern, target, expect_target_type=None): + if expect_target_type is None: + expect_target_type = pattern.__class__ + if not isinstance(target, expect_target_type): + raise NotMatchedError(pattern, target, "Pattern expect target with type {}, but got type {}".format(expect_target_type, type(target))) + + def check_cond(self, pattern, target, cond, message=""): + if not cond: + raise NotMatchedError(pattern, target, message) + + def always_match(self, pattern, target): + pass + + def match_StringPattern(self, pattern: StringPattern, target: Any): + if not isinstance(target, str): + raise NotMatchedError(pattern, target) + + def match_CommutativeBinary(self, pattern: BinaryOp, target: BinaryOp): + # return self.match_Binary(pattern, target) + try: + with ExitStack() as stack: + stack.enter_context(self.match(pattern.a, target.a)) + stack.enter_context(self.match(pattern.b, target.b)) + except NotMatchedError: + pass + else: + return + try: + with ExitStack() as stack: + stack.enter_context(self.match(pattern.a, target.b)) + stack.enter_context(self.match(pattern.b, target.a)) + except NotMatchedError: + pass + else: + return + + raise NotMatchedError(pattern, target, "Commutative binary op has not matched") + + def match_Binary(self, pattern: BinaryOp, target: BinaryOp): + with ExitStack() as stack: + stack.enter_context(self.match(pattern.a, target.a)) + stack.enter_context(self.match(pattern.b, target.b)) + + def match_TensorElement(self, pattern: TensorElement, target: TensorElement): + with ExitStack() as stack: + stack.enter_context(self.match(pattern.base, target.base)) + stack.enter_context(self.match(pattern.indices, target.indices)) + + def match_IfThenElse(self, pattern: IfThenElse, target: IfThenElse): + with ExitStack() as stack: + stack.enter_context(self.match(pattern.cond, target.cond)) + stack.enter_context(self.match(pattern.then_expr, target.then_expr)) + stack.enter_context(self.match(pattern.else_expr, target.else_expr)) + + def match_Call(self, pattern: Call, target: Call): + with ExitStack() as stack: + stack.enter_context(self.match(pattern.func_var, target.func_var)) + stack.enter_context(self.match(pattern.args, target.args)) + + def match_Var(self, pattern: Var, target: Var): + if isinstance(pattern.type, FuncType): + return + with self.match(pattern.type, target.type): + pass + + def match_Constant(self, pattern: Constant, target: Constant): + with self.match(pattern.data_type, target.data_type): + pass + if pattern.value is None: + # None matches any const value + return + if pattern.value != target.value: + raise NotMatchedError(pattern, target) + + # def match_ScalarInput(self, pattern: ScalarInput, target: ScalarInput): + # with ExitStack() as stack: + # stack.enter_context(self.match(pattern.data_type, target.data_type)) + # + # def match_TensorInput(self, pattern: TensorInput, target: TensorInput): + # with ExitStack() as stack: + # stack.enter_context(self.match(pattern.data_type, target.data_type)) + # + # def match_TensorCompute(self, pattern: TensorCompute, target: TensorCompute): + # with ExitStack() as stack: + # stack.enter_context(self.match(pattern.shape, target.shape)) + # stack.enter_context(self.match(pattern.axes, target.axes)) + # stack.enter_context(self.match(pattern.value, target.value)) + # stack.enter_context(self.match(pattern.data_type, target.data_type)) + # + # def match_ReduceCompute(self, pattern: ReduceCompute, target: ReduceCompute): + # with ExitStack() as stack: + # stack.enter_context(self.match(pattern.axes, target.axes)) + # stack.enter_context(self.match(pattern.shape, target.shape)) + # stack.enter_context(self.match(pattern.value, target.value)) + # stack.enter_context(self.match(pattern.reduce_type, target.reduce_type)) + # stack.enter_context(self.match(pattern.data_type, target.data_type)) + # + # def match_CustomCompute(self, pattern: CustomCompute, target: CustomCompute): + # with ExitStack() as stack: + # stack.enter_context(self.match(pattern.identifier, target.identifier)) + # stack.enter_context(self.match(pattern.data_type, target.data_type)) + # stack.enter_context(self.match(pattern.params, target.params)) + # for key, value in pattern.attributes.items(): + # if key not in target.attributes: + # raise NotMatchedError(pattern, target, 'key {} not found in target CustomCompute'.format(key)) + # stack.enter_context(self.match(value, target.attributes[key])) + + def match_DataLayout(self, pattern, target): + if isinstance(target, (StridesLayout, DataLayout)): + pass + else: + raise NotMatchedError(pattern, target) + + def match_StridesLayout(self, pattern: StridesLayout, target: StridesLayout): + pass + + def match_AnyPattern(self, pattern: AnyExpr, target: Expr): + # if pattern.type is None, match any expr, otherwise match any expr with specific type + if pattern.cls and not isinstance(target, pattern.cls): + raise NotMatchedError(pattern, target) + if pattern.exclude_cls and isinstance(target, pattern.exclude_cls): + raise NotMatchedError(pattern, target) + + def match_UnionPattern(self, pattern: UnionPattern, target: Node): + for p in pattern.patterns: + success = True + try: + with self.match(p, target): + pass + except NotMatchedError: + success = False + if success: + return + raise NotMatchedError(pattern, target) + + def match_OptionalPattern(self, pattern: OptionalPattern, target: Node): + if target is None: + return + else: + with self.match(pattern.pattern, target): + pass + + def match_Scope(self, pattern: Scope, target: Scope): + if pattern.name is not None and (pattern.name is None or pattern.name != target.name): + raise NotMatchedError(pattern, target) + + # def match_ReduceComputePattern(self, pattern: ReduceComputePattern, target: ReduceCompute): + # self.check_type(pattern, target, ReduceCompute) + # if not pattern.allow_dynamic_axis and any(not isinstance(v, Constant) for v in target.shape): + # raise NotMatchedError(pattern, target, "does not allow dynamic axis in reduce") + # + # def match_TensorComputePattern(self, pattern: TensorComputePattern, target: TensorCompute): + # self.check_type(pattern, target, TensorCompute) + # if pattern.rank is not None and len(target.shape) != pattern.rank: + # raise NotMatchedError(pattern, target, "rank does not match") + # if not pattern.allow_dynamic_axis and any(not isinstance(v, Constant) for v in target.shape): + # raise NotMatchedError(pattern, target, "does not allow dynamic axis") + # with self.match(pattern.value, target.value): + # pass + # + # def match_ScalarExprPattern(self, pattern: ScalarExprPattern, target: Expr): + # from hidet.ir.functors import collect + # if len(pattern.exclude_vars) > 0: + # if len(pattern.exclude_vars) > 0: + # matched_exclude_vars = [self.matched[v] for v in pattern.exclude_vars if v in self.matched] + # included_vars = collect(target, Var) + # for included_var in included_vars: + # if included_var in matched_exclude_vars: + # raise NotMatchedError(pattern, target, "excluded var occurred in target") + # if pattern.base_pattern is not None: + # for sub_expr in collect(target, Expr): + # try: + # with self.match(pattern.base_pattern, sub_expr): + # return + # except NotMatchedError: + # continue + # raise NotMatchedError(pattern, target, "can not find a sub expression to match base pattern") + + def match_ScalarType(self, pattern: ScalarType, target: ScalarType): + if pattern.name: + if pattern.name != target.name: + raise NotMatchedError(pattern, target) + + def match_TensorType(self, pattern: TensorType, target: TensorType): + with ExitStack() as stack: + stack.enter_context(self.match(pattern.scalar_type, target.scalar_type)) + stack.enter_context(self.match(pattern.shape, target.shape)) + stack.enter_context(self.match(pattern.layout, target.layout)) + stack.enter_context(self.match(pattern.scope, target.scope)) + + def match_ScalarTypePattern(self, pattern: ScalarTypePattern, target: ScalarType): + self.check_type(pattern, target, ScalarType) + if pattern.allowed_types is not None and target.name not in pattern.allowed_types: + raise NotMatchedError(pattern, target) + + def match_TensorTypePattern(self, pattern: TensorTypePattern, target: TensorType): + self.check_type(pattern, target, TensorType) + with ExitStack() as stack: + if pattern.rank is not None and len(target.shape) != pattern.rank: + raise NotMatchedError(pattern, target) + if pattern.scope is not None: + if isinstance(pattern.scope, Scope) and pattern.scope.name != target.scope.name: + raise NotMatchedError(pattern, target) + if isinstance(pattern.scope, list) and target.scope.name not in [s.name for s in pattern.scope]: + raise NotMatchedError(pattern, target) + stack.enter_context(self.match(pattern.scalar_type, target.scalar_type)) + stack.enter_context(self.match(pattern.shape, target.shape)) + stack.enter_context(self.match(pattern.layout, target.layout)) + if not pattern.allow_dynamic_size and any(not isinstance(s, Constant) for s in target.shape): + raise NotMatchedError(pattern, target) + + # def match_TensorAccessPattern(self, pattern: TensorAccessPattern, target: Expr): + # self.check_type(pattern, target, TensorElement) + # assert isinstance(target, TensorElement) + # self.check_cond(pattern, target, + # cond=isinstance(target.base, (TensorInput, TensorCompute)), + # message='TensorAccessPattern expect the target is a ' + # 'tensor access based on TensorInput or TensorCompute') + # with ExitStack() as stack: + # if pattern.base_pattern is not None: + # stack.enter_context(self.match(pattern.base_pattern, target.base)) + # if pattern.indices_pattern is not None: + # self.check_cond(pattern, target, cond=len(pattern.indices_pattern) == len(target.indices), + # message='TensorAccessPattern expect pattern and target have the same number of indices') + # for index_pattern, index_expr in zip(pattern.indices_pattern, target.indices): + # stack.enter_context(self.match(index_pattern, index_expr)) + # + # def match_Grid(self, pattern: Grid, target: Grid): + # with ExitStack() as stack: + # stack.enter_context(self.match(pattern.grid_dim, target.grid_dim)) + # stack.enter_context(self.match(pattern.block_dim, target.block_dim)) + # + # def match_ThreadBlock(self, pattern: ThreadBlock, target: ThreadBlock): + # with ExitStack() as stack: + # stack.enter_context(self.match(pattern.block_dim, target.block_dim)) + # stack.enter_context(self.match(pattern.task_layout, target.task_layout)) + # + # def match_Warp(self, pattern: Warp, target: Warp): + # with ExitStack() as stack: + # stack.enter_context(self.match(pattern.task_layout, target.task_layout)) + # + # def match_TaskPattern(self, pattern: TaskPattern, target: Task): + # self.check_type(pattern, target, Task) + # with ExitStack() as stack: + # stack.enter_context(self.match(pattern.compute_pattern, target.compute)) + # matched_params = [self.matched[param] for param in pattern.required_params] if pattern.required_params else [] + # extra_params = [param for param in target.params if param not in matched_params] + # if not pattern.allow_extra_params and len(extra_params) > 0: + # raise NotMatchedError(pattern, target, "do not allow extra param(s)") + # if not pattern.allow_tensor_extra_params and any(isinstance(p, TensorInput) for p in extra_params): + # raise NotMatchedError(pattern, target, "do not allow extra tensor param(s)") + # stack.enter_context(self.match(pattern.worker, target.worker)) + + def match_Sequence(self, pattern: Sequence, target: Sequence): + with ExitStack() as stack: + if len(pattern) != len(target): + raise NotMatchedError(pattern, target, "length does not match") + for a, b in zip(pattern, target): + stack.enter_context(self.match(a, b)) + + def match_String(self, pattern: str, target: str): + if pattern != target: + raise NotMatchedError(pattern, target) + + +# def compute_pattern(name, shape, fcompute, accumulate=None, scope=None, layout=None): +# shape = convert(shape) +# axes = [var() for _ in shape] +# value = convert(fcompute(*axes)) +# data_type = TensorType(scope, dtype=ScalarType('float32'), shape=shape, layout=layout) +# return TensorCompute(name, shape, axes, value, data_type, accumulate) +# + +def reduce_pattern(shape: Sequence[Union[int, Expr]], fcompute, reduce_type: str): + shape = convert(shape) + axes = [var() for _ in shape] + value = convert(fcompute(*axes)) + return ReduceCompute(value, shape, axes, reduce_type, scalar_type('float32')) + + +def any_const_int(): + return Constant(None, ScalarType('int32')) + + +def any_const_ints(num=1): + return [any_const_int() for _ in range(num)] + + +def any_const(): + return AnyExpr(Constant) + +# +# def any_scalar_expr(base_pattern: Optional[Expr] = None, +# exclude_vars: Sequence[Union[Var, TensorInput, ScalarInput]] = ()): +# return ScalarExprPattern(base_pattern=base_pattern, exclude_vars=exclude_vars) + + +# def any_tensor_input() -> TensorInput: +# return TensorInput(None, None) + + +def int_vars(names): + return [var(name, dtype='int32') for name in names] + + +def match(pattern: Node, target: Node) -> Tuple[Optional[Dict[Node, Any]], str]: + """ + :return: match, report + """ + matcher = PatternMatcher() + return matcher(pattern, target) diff --git a/python/hidet/ir/expr.py b/python/hidet/ir/expr.py new file mode 100644 index 0000000..ddef34a --- /dev/null +++ b/python/hidet/ir/expr.py @@ -0,0 +1,494 @@ +import string +import numpy as np +from typing import Optional, Union, Sequence, List, Tuple +from .node import Node +from .type import TypeNode, TensorType, TensorType, ScalarType, Scope, tensor_type, scalar_type +from .layout import DataLayout + +PyScalar = Union[int, float] + + +class Expr(Node): + def __neg__(self): + return Neg(self) + + def __add__(self, other): + return Add(self, other) + + def __radd__(self, other): + return Add(other, self) + + def __sub__(self, other): + return Sub(self, other) + + def __rsub__(self, other): + return Sub(other, self) + + def __mul__(self, other): + return Multiply(self, other) + + def __rmul__(self, other): + return Multiply(other, self) + + def __truediv__(self, other): + return Div(self, other) + + def __rtruediv__(self, other): + return Div(other, self) + + def __floordiv__(self, other): + return Div(self, other) + + def __rfloordiv__(self, other): + return Div(other, self) + + def __mod__(self, other): + return Mod(self, other) + + def __rmod__(self, other): + return Mod(other, self) + + def __lt__(self, other): + return LessThan(self, other) + + def __le__(self, other): + return LessEqual(self, other) + + # + # for performance, we should use Equal(e1, e2) to represent equivalence expression + # + # def __eq__(self, other): + # return Equal(self, other) + # + # def __hash__(self): + # return id(self) + # + + def __ge__(self, other): + return LessEqual(other, self) + + def __invert__(self): + from hidet.ir.dialects.lowlevel import Address + return Address(self) + + def __getitem__(self, items): + if not isinstance(items, (tuple, list)): + items = [items] + indices = [] + starts = [] + ends = [] + for item in items: + if isinstance(item, slice): + indices.append(None) + starts.append(item.start) + ends.append(item.stop) + assert item.step is None, "do not support step slice" + else: + indices.append(item) + starts.append(None) + ends.append(None) + rank = tensor_rank(self) + if len(items) < rank or any(i is None for i in indices): + while len(indices) < rank: + indices.append(None) + starts.append(None) + ends.append(None) + return TensorSlice(base=self, indices=indices, starts=starts, ends=ends) + else: + return TensorElement(base=self, indices=indices) + + def __int__(self): + assert isinstance(self, Constant), 'Expect a Constant, got {} with type {}'.format(self, type(self)) + return int(self) + + def __float__(self): + assert isinstance(self, Constant), 'Expect a Constant, got {} with type {}'.format(self, type(self)) + return float(self) + + def __str__(self): + from hidet.ir.functors import astext + # return str(astext(self)) + ' at {}'.format(hex(id(self))) + return str(astext(self)) + + def equals(self, other): + return Equal(self, other) + + def is_const(self): + return isinstance(self, Constant) + + def const(self) -> 'Constant': + assert isinstance(self, Constant) + return self + + +class BinaryOp(Expr): + def __init__(self, a, b): + self.a = convert(a) + self.b = convert(b) + + +class UnaryOp(Expr): + def __init__(self, a): + self.a = convert(a) + + +def convert(obj: Optional[Union[Expr, PyScalar, tuple, Sequence]], dtype: Optional[Union[str, ScalarType]] = None) -> Optional[Union[Expr, tuple]]: + if isinstance(obj, Expr): + return obj + + if dtype is not None: + if isinstance(obj, (bool, int, float)): + return Constant(obj, dtype) + else: + raise ValueError('Can not convert {} to {}.'.format(obj, dtype)) + + if isinstance(obj, bool): + return Constant(obj, ScalarType('bool')) + elif isinstance(obj, int): + return Constant(obj, ScalarType('int32')) + elif isinstance(obj, float): + return Constant(obj, ScalarType('float32')) + elif isinstance(obj, (tuple, list)): + return tuple([convert(v) for v in obj]) + elif obj is None: + return None + else: + raise NotImplementedError(type(obj)) + + +class Condition(Expr): + pass + + +class LessThan(Condition, BinaryOp): + def __init__(self, a, b): + super().__init__(a, b) + + +class LessEqual(Condition, BinaryOp): + def __init__(self, a, b): + super().__init__(a, b) + + +class Equal(Condition, BinaryOp): + def __init__(self, a, b): + super().__init__(a, b) + + def __bool__(self): + r = object.__eq__(self.a, self.b) + if r is NotImplemented: + return False + else: + return True + + +class And(Condition, BinaryOp): + def __init__(self, a, b): + super().__init__(a, b) + + @staticmethod + def join(*conds): + cond = convert(True) + for c in conds: + cond = And(cond, convert(c)) + return cond + + @staticmethod + def join_list(conds: Sequence[Condition]): + return And.join(*conds) + + +class Or(Condition, BinaryOp): + def __init__(self, a, b): + super().__init__(a, b) + + @staticmethod + def join(*conds): + cond = convert(False) + for c in conds: + cond = Or(cond, convert(c)) + return cond + + +class Not(Condition, UnaryOp): + def __init__(self, a): + super().__init__(a) + + +class Neg(UnaryOp): + def __init__(self, a): + super().__init__(a) + + +class Add(BinaryOp): + def __init__(self, a, b): + super().__init__(a, b) + + +class Sub(BinaryOp): + def __init__(self, a, b): + super().__init__(a, b) + + +class Multiply(BinaryOp): + def __init__(self, a, b): + super().__init__(a, b) + + +class Div(BinaryOp): + def __init__(self, a, b): + super().__init__(a, b) + + +class FloorDiv(BinaryOp): + def __init__(self, a, b): + super().__init__(a, b) + + +class Mod(BinaryOp): + def __init__(self, a, b): + super().__init__(a, b) + + +class BitwiseNot(Expr): + def __init__(self, base): + super().__init__() + self.base = base + + +class BitwiseAnd(BinaryOp): + def __init__(self, a, b): + super().__init__(a, b) + + +class BitwiseOr(BinaryOp): + def __init__(self, a, b): + super().__init__(a, b) + + @staticmethod + def join_list(lst): + if len(lst) == 0: + return convert(0) + else: + current = lst[0] + for v in lst[1:]: + current = BitwiseOr(current, v) + return current + + +class LeftShift(Expr): + def __init__(self, base, cnt): + super().__init__() + self.base = convert(base) + self.cnt = convert(cnt) + + +class RightShift(Expr): + def __init__(self, base, cnt): + super().__init__() + self.base = base + self.cnt = cnt + + +class TensorElement(Expr): + def __init__(self, base, indices): + self.base = base + self.indices = convert(indices) + + +class TensorSlice(Expr): + def __init__(self, base, indices, starts, ends): + # a[3, 4:, :5, :] will be represented by + # base: a + # indices: [3, None, None, None] + # starts: [None, 4, None, None] + # ends: [None, None, 5, None] + self.base = base + self.indices: Tuple = convert(indices) + self.starts: Tuple = convert(starts) + self.ends: Tuple = convert(ends) + if self.base is not None: + assert len(self.indices) == tensor_rank(base) + + +class Call(Expr): + def __init__(self, func_var, args): + self.func_var: Var = func_var + self.args = convert(args) + + +class Let(Expr): + def __init__(self, var, value, body): + self.var = var + self.value = convert(value) + self.body = convert(body) + + +class Cast(Expr): + def __init__(self, expr, target_type): + self.expr = expr + if isinstance(target_type, str): + target_type = ScalarType(target_type) + self.target_type: TypeNode = target_type + + +class Constant(Expr): + def __init__(self, value=None, data_type=None): + if data_type and isinstance(data_type, str): + data_type = ScalarType(data_type) + self.value: Optional[np.ndarray, float, int] = value + self.data_type: Optional[Union[ScalarType, TensorType]] = data_type + + def is_scalar(self) -> bool: + return self.data_type and isinstance(self.data_type, ScalarType) + + def is_tensor(self) -> bool: + return self.data_type and isinstance(self.data_type, TensorType) + + def __int__(self): + return int(self.value) + + def __float__(self): + return float(self.value) + + def array(self) -> np.ndarray: + return self.value + + +class IfThenElse(Expr): + def __init__(self, cond: Union[Expr, PyScalar], then_expr: Union[Expr, PyScalar], else_expr: Union[Expr, PyScalar]): + self.cond = convert(cond) + self.then_expr = convert(then_expr) + self.else_expr = convert(else_expr) + + +class Var(Expr): + id_clock = 0 + + def __init__(self, hint: Optional[str], type: TypeNode, name: Optional[str] = None): + """ + A variable may have a hint, name, and id. + + Hint is used to determine the name in codegen. Different vars may have the + same hint. If two vars have the same hint such as 'x', the final name would be like 'x1', 'x2'. + + OUTDATED: + Name is the determined name in the final code. Used by primitive varaibles such as 'threadIdx.x'. No variable should have + a same name as primitive objects (including primitive variables and primitive functions). + + Id is used to track the allocation of Var object in python, which is only used to help us to distinguish different Var + in python debugger. + """ + from hidet.ir.dialects.lowlevel import TensorPointerType + self.hint = hint + self.name = name + self.type: Union[TypeNode, TensorType, TensorPointerType] = type + self.id = self.new_id() + + @staticmethod + def new_id(): + Var.id_clock += 1 + return Var.id_clock + + @staticmethod + def reset_id_counter(): + Var.id_clock = 0 + + +def var(hint: str = None, dtype='int32'): + if isinstance(hint, str): + assert set(hint) <= set(string.ascii_letters + '_.' + string.digits) + return Var(hint, ScalarType(dtype)) + + +def scalar_var(hint: str, dtype: Union[str, ScalarType] = 'float32') -> Var: + dtype = dtype if isinstance(dtype, ScalarType) else scalar_type(dtype) + return Var(hint, dtype) + + +def tensor_var(hint: str, shape, scope: str = 'global', dtype: Union[str, ScalarType] = 'float32', layout=None) -> Var: + return Var(hint, tensor_type(scope, dtype, shape, layout)) + + +def is_one(v: Expr) -> bool: + return isinstance(v, Constant) and v.value == 1 + + +def is_zero(v: Expr) -> bool: + return isinstance(v, Constant) and v.value == 0 + + +def is_true(v: Expr) -> bool: + return isinstance(v, Constant) and v.data_type.name == 'bool' and v.value is True + + +def is_false(v: Expr) -> bool: + return isinstance(v, Constant) and v.data_type.name == 'bool' and v.value is False + + +def is_const_int(v: Expr) -> bool: + return isinstance(v, Constant) and v.data_type.name == 'int32' + + +def if_then_else(cond: Union[Expr, PyScalar], then_expr: Union[Expr, PyScalar], else_expr: Union[Expr, PyScalar]) -> IfThenElse: + return IfThenElse(convert(cond), convert(then_expr), convert(else_expr)) + + +def is_tensor(v: Expr) -> bool: + from hidet.ir.dialects.lowlevel import TensorPointerType + if not isinstance(v, Var): + return False + return isinstance(v.type, (TensorType, TensorPointerType)) + + +def get_tensor_layout(v: Expr): + from hidet.ir.dialects.lowlevel import TensorPointerType + assert isinstance(v, Var) and isinstance(v.type, (TensorType, TensorPointerType)) + return v.type.layout if isinstance(v.type, TensorType) else v.type.tensor_type.layout + + +def tensor_rank(v: Expr) -> int: + from hidet.ir.dialects.compute import TensorNode + from hidet.ir.dialects.lowlevel import TensorPointerType, PointerType + if isinstance(v, Var): + if isinstance(v.type, TensorType): + return len(v.type.shape) + elif isinstance(v.type, TensorPointerType): + return len(v.type.tensor_type.shape) + elif isinstance(v.type, PointerType): + return 1 + else: + raise ValueError(v) + elif isinstance(v, TensorSlice): + return sum([1 if i is None else 0 for i in v.indices]) + elif isinstance(v, TensorNode): + return len(v.data_type.shape) + elif isinstance(v, Constant) and isinstance(v.data_type, TensorType): + return len(v.data_type.shape) + elif isinstance(v, Cast) and isinstance(v.target_type, PointerType): + return 1 + else: + raise ValueError('Can not infer the tensor rank of "{}"'.format(v)) + + +def cast(v: Expr, dtype): + return Cast(v, dtype) + + +def const_tensor(value: np.ndarray, data_type=None) -> Constant: + if data_type is None: + data_type = tensor_type( + scope='host', + dtype=ScalarType.from_numpy_dtype(value.dtype), + shape=list(value.shape) + ) + return Constant(value=value, data_type=data_type) + + +def const_like(value: Union[float, int], e: Expr) -> Constant: + from hidet.ir.functors import infer_type + dtype = infer_type(e) + if isinstance(dtype, ScalarType): + return Constant(value=value, data_type=dtype) + else: + raise ValueError('Expect a scalar type, but got {}'.format(dtype)) diff --git a/python/hidet/ir/func.py b/python/hidet/ir/func.py new file mode 100644 index 0000000..6945dd8 --- /dev/null +++ b/python/hidet/ir/func.py @@ -0,0 +1,116 @@ +from typing import Dict, List, Union, Optional, Tuple +from hidet.ir.node import Node +from hidet.ir.type import TypeNode, FuncType +from hidet.ir.expr import Var, Constant +from hidet.ir.stmt import Stmt + + +class Function(Node): + valid_attrs = [ + 'kind', + 'packed_func', + 'label', + 'kind', + 'cuda_grid_dim', + 'cuda_block_dim', + 'cuda_dynamic_smem_bytes', + 'cuda_min_blocks' + ] + """ + Valid Attrs: + 'kind': str, candidates: 'cuda_device', 'cuda_kernel', 'host_kernel', 'packed_func' + the kind of this function. + - 'cuda_device': this is a cuda device function, can only be called by cuda function + - 'cuda_kernel': this is a cuda kernel function + - 'host_kernel': this is a cpu kernel function + - 'packed_func': this is a packed function that wraps kernel function(s) + 'cuda_grid_dim': Union[int, List[int]] + the grid dimension in cuda launch configuration + 'cuda_block_dim': Union[int, List[int]] + the block dimension in cuda launch configuration + 'cuda_dynamic_smem_bytes': int + the dynamic shared memory in cuda launch configuration + 'cuda_min_blocks': int + the minimal number of thread blocks in launch bound of cuda kernel function + 'packed_func': Function + the target function that this packed_func has packed. valid when attrs['kind'] == 'packed_func' + 'label': str + the label of this function when it is in a function group + """ + + def __init__(self, name: str, params, body, ret_type, kind: str, local_vars, local_const_vars=None, extern_vars=None, attrs=None): + self.name = name.replace('.', '_') + self.kind = kind + assert isinstance(kind, str) and kind in ['cuda_device', 'cuda_kernel', 'host_kernel', 'packed_func'] + self.params: List[Var] = params + self.body: Stmt = body + self.ret_type: TypeNode = ret_type + self.local_vars: List[Var] = local_vars + self.local_const_vars: List[Tuple[Var, Constant]] = local_const_vars if local_const_vars else [] + self.extern_vars: List[Var] = extern_vars if extern_vars else [] + self.attrs = attrs if attrs else {} + + assert all(attr in self.valid_attrs for attr in self.attrs) + + def annotate(self, attr_name, attr_value, update=False): + assert attr_name in self.valid_attrs + if attr_name in self.attrs and not update: + raise AttributeError(f'{attr_name} has existed') + self.attrs[attr_name] = attr_value + + def get_attr(self, attr_name, default=None): + if attr_name in self.attrs: + return self.attrs[attr_name] + return default + + +class IRModule(Node): + def __init__(self, funcs=None, task=None, global_vars=None): + from hidet.ir.task import Task + if funcs: + assert isinstance(funcs, dict) + # assert task is not None, 'Please specify the task' + self.task: Optional[Task] = task + self.functions: Dict[str, Function] = funcs if funcs else {} + self.global_vars: Dict[str, Var] = global_vars if global_vars else {} + + def include(self, module, skip_duplicated=True): + for name, func in module.functions.items(): + if name in self.functions: + if skip_duplicated: + continue + else: + raise ValueError('Function {} has already existed in module while include another module.'.format(name)) + else: + self.functions[name] = func + + for name, var in module.global_vars.items(): + self.global_vars[name] = var + + def lookup(self, name_or_var: Union[str, Var]): + if isinstance(name_or_var, Var): + name = name_or_var.hint + else: + name = name_or_var + if name not in self.functions: + raise KeyError('Function {} does not exist in module, existed functions: \n{}.'.format(name, list(self.functions.keys()))) + return self.functions[name] + + def lookup_var(self, name): + assert name in self.functions, (name, self.functions.keys()) + if name not in self.global_vars: + func = self.functions[name] + if isinstance(func, Function): + self.global_vars[name] = Var(name, FuncType.from_func(func)) + else: + raise ValueError() + + return self.global_vars[name] + + def add(self, name, func: Function): + if name in self.functions: + raise ValueError('Function {} has already existed in module.'.format(name)) + else: + self.functions[name] = func + + diff --git a/python/hidet/ir/functors/__init__.py b/python/hidet/ir/functors/__init__.py new file mode 100644 index 0000000..59b296f --- /dev/null +++ b/python/hidet/ir/functors/__init__.py @@ -0,0 +1,11 @@ +from .base import NodeFunctor +from .base import ExprFunctor, ExprVisitor, ExprRewriter +from .base import StmtFunctor, StmtVisitor, StmtRewriter +from .base import StmtExprFunctor, StmtExprVisitor, StmtExprRewriter, TypeFunctor, FuncStmtExprRewriter, FuncStmtExprVisitor +from .base import same_list +from .type_infer import infer_type, TypeInfer +from .util_functors import rewrite, collect, collect_free_vars, clone +from .printer import astext +from .simplifier import simplify, simplify_to_int +from .hasher import ExprHash +from .compute_inliner import inline_compute diff --git a/python/hidet/ir/functors/base.py b/python/hidet/ir/functors/base.py new file mode 100644 index 0000000..bae91ef --- /dev/null +++ b/python/hidet/ir/functors/base.py @@ -0,0 +1,791 @@ +from abc import ABC +from typing import Mapping + +from hidet.ir.dialects.pattern import * +from hidet.ir.func import * +from hidet.ir.stmt import * + + +class NodeFunctor: + def __init__(self, use_memo=True): + self.memo = {} if use_memo else None + if not hasattr(self.__class__, 'dispatch_table'): + self.setup_dispatch_table() + + def __call__(self, node: Any): + return self.visit(node) + + def visit(self, node: Union[Node, tuple, list]): + key = id(node) if isinstance(node, list) else node + if self.memo is not None and key in self.memo: + return self.memo[key] + if isinstance(node, Node): + idx = node.class_index() if node is not None else 0 + # noinspection PyUnresolvedReferences + dispatch_table = self.__class__.dispatch_table + if idx >= len(dispatch_table): + raise NotImplementedError('Does not implement dispatch function in "{}" for node "{}"'.format(type(self).__qualname__, type(node).__qualname__)) + ret = dispatch_table[idx](self, node) + elif isinstance(node, tuple): + ret = tuple(self.visit(v) for v in node) + elif isinstance(node, list): + ret = [self.visit(v) for v in node] + else: + raise NotImplementedError("Can not dispatch object with type {}".format(type(node))) + if self.memo is not None: + self.memo[key] = ret + return ret + + @staticmethod + def get_dispatch_mapping(cls) -> Mapping[Type[Node], Any]: + return {} + + @classmethod + def setup_dispatch_table(cls: Type[Node]): + cls_stack: List[type] = [cls] + mapping = {} + while len(cls_stack) > 0: + cur_cls = cls_stack.pop() + if hasattr(cur_cls, 'get_dispatch_mapping'): + cur_mapping = cur_cls.get_dispatch_mapping(cls) + for k, v in cur_mapping.items(): + if k not in mapping: + mapping[k] = v + cls_stack.extend(cur_cls.__bases__) + setattr(cls, 'dispatch_table', Node.dispatch_table(mapping)) + + +class ExprFunctor(NodeFunctor): + @staticmethod + def get_dispatch_mapping(cls) -> Mapping[Type[Node], Any]: + return { + Add: cls.visit_Add, + Sub: cls.visit_Sub, + Multiply: cls.visit_Multiply, + Div: cls.visit_Div, + Mod: cls.visit_Mod, + FloorDiv: cls.visit_FloorDiv, + Neg: cls.visit_Neg, + LessThan: cls.visit_LessThan, + LessEqual: cls.visit_LessEqual, + Equal: cls.visit_Equal, + And: cls.visit_And, + Or: cls.visit_Or, + Not: cls.visit_Not, + BitwiseAnd: cls.visit_BitwiseAnd, + BitwiseOr: cls.visit_BitwiseOr, + BitwiseNot: cls.visit_BitwiseNot, + LeftShift: cls.visit_LeftShift, + RightShift: cls.visit_RightShift, + TensorElement: cls.visit_TensorElement, + TensorSlice: cls.visit_TensorSlice, + IfThenElse: cls.visit_IfThenElse, + Call: cls.visit_Call, + Let: cls.visit_Let, + Var: cls.visit_Var, + Constant: cls.visit_Constant, + Cast: cls.visit_Cast, + Dereference: cls.visit_Dereference, + Address: cls.visit_Address, + Reference: cls.visit_Reference, + TensorNode: cls.visit_TensorNode, + ScalarNode: cls.visit_ScalarNode, + AnyExpr: cls.visit_AnyExpr, + } + + def visit_Add(self, e: Add): + raise NotImplementedError() + + def visit_Sub(self, e: Sub): + raise NotImplementedError() + + def visit_Multiply(self, e: Multiply): + raise NotImplementedError() + + def visit_Div(self, e: Div): + raise NotImplementedError() + + def visit_Mod(self, e: Mod): + raise NotImplementedError() + + def visit_FloorDiv(self, e: FloorDiv): + raise NotImplementedError() + + def visit_LessThan(self, e: LessThan): + raise NotImplementedError() + + def visit_LessEqual(self, e: LessEqual): + raise NotImplementedError() + + def visit_Equal(self, e: Equal): + raise NotImplementedError() + + def visit_And(self, e: And): + raise NotImplementedError() + + def visit_Or(self, e: Or): + raise NotImplementedError() + + def visit_Neg(self, e: Neg): + raise NotImplementedError() + + def visit_Not(self, e: Not): + raise NotImplementedError() + + def visit_BitwiseAnd(self, e: BitwiseAnd): + raise NotImplementedError() + + def visit_BitwiseOr(self, e: BitwiseOr): + raise NotImplementedError() + + def visit_BitwiseNot(self, e: BitwiseNot): + raise NotImplementedError() + + def visit_LeftShift(self, e: LeftShift): + raise NotImplementedError() + + def visit_RightShift(self, e: RightShift): + raise NotImplementedError() + + def visit_TensorElement(self, e: TensorElement): + raise NotImplementedError() + + def visit_TensorSlice(self, e: TensorSlice): + raise NotImplementedError() + + def visit_IfThenElse(self, e: IfThenElse): + raise NotImplementedError() + + def visit_Cast(self, e: Cast): + raise NotImplementedError() + + def visit_Dereference(self, e: Dereference): + raise NotImplementedError() + + def visit_Address(self, e: Address): + raise NotImplementedError() + + def visit_Reference(self, e: Reference): + raise NotImplementedError() + + def visit_Call(self, e: Call): + raise NotImplementedError() + + def visit_Let(self, e: Let): + raise NotImplementedError() + + def visit_Var(self, e: Var): + raise NotImplementedError() + + def visit_Constant(self, e: Constant): + raise NotImplementedError() + + def visit_ScalarNode(self, e: ScalarNode): + raise NotImplementedError() + + def visit_TensorNode(self, e: TensorNode): + raise NotImplementedError() + + def visit_AnyExpr(self, e: AnyExpr): + raise NotImplementedError() + + +class ExprVisitor(ExprFunctor): + def visit_Add(self, e: Add): + self.visit(e.a) + self.visit(e.b) + + def visit_Sub(self, e: Sub): + self.visit(e.a) + self.visit(e.b) + + def visit_Multiply(self, e: Multiply): + self.visit(e.a) + self.visit(e.b) + + def visit_Div(self, e: Div): + self.visit(e.a) + self.visit(e.b) + + def visit_Mod(self, e: Mod): + self.visit(e.a) + self.visit(e.b) + + def visit_FloorDiv(self, e: FloorDiv): + self.visit(e.a) + self.visit(e.b) + + def visit_LessThan(self, e: LessThan): + self.visit(e.a) + self.visit(e.b) + + def visit_LessEqual(self, e: LessThan): + self.visit(e.a) + self.visit(e.b) + + def visit_Equal(self, e: Equal): + self.visit(e.a) + self.visit(e.b) + + def visit_And(self, e: And): + self.visit(e.a) + self.visit(e.b) + + def visit_Or(self, e: Or): + self.visit(e.a) + self.visit(e.b) + + def visit_Neg(self, e: Neg): + self.visit(e.a) + + def visit_Not(self, e: Not): + self.visit(e.a) + + def visit_BitwiseAnd(self, e: BitwiseAnd): + self.visit(e.a) + self.visit(e.b) + + def visit_BitwiseOr(self, e: BitwiseOr): + self.visit(e.a) + self.visit(e.b) + + def visit_BitwiseNot(self, e: BitwiseNot): + self.visit(e.base) + + def visit_LeftShift(self, e: LeftShift): + self.visit(e.base) + self.visit(e.cnt) + + def visit_RightShift(self, e: RightShift): + self.visit(e.base) + self.visit(e.cnt) + + def visit_TensorElement(self, e: TensorElement): + self.visit(e.base) + for idx in e.indices: + self.visit(idx) + + def visit_TensorSlice(self, e: TensorSlice): + self.visit(e.base) + for idx, start, end in zip(e.starts, e.indices, e.ends): + for obj in [idx, start, end]: + if obj is not None: + self.visit(obj) + + def visit_IfThenElse(self, e: IfThenElse): + self.visit(e.cond) + self.visit(e.then_expr) + self.visit(e.else_expr) + + def visit_Call(self, e: Call): + self.visit(e.func_var) + for arg in e.args: + self.visit(arg) + + def visit_Let(self, e: Let): + self.visit(e.value) + self.visit(e.var) + self.visit(e.body) + + def visit_Var(self, e: Var): + pass + + def visit_Constant(self, e: Constant): + pass + + # compute dialect + def visit_ScalarNode(self, e: ScalarNode): + if e.reduce_compute: + self.visit(e.reduce_compute.value) + + def visit_TensorNode(self, e: TensorNode): + if e.grid_compute: + self.visit(e.grid_compute.value) + + # lowlevel dialect + def visit_Cast(self, e: Cast): + self.visit(e.expr) + + def visit_Dereference(self, e: Dereference): + self.visit(e.expr) + + def visit_Address(self, e: Address): + self.visit(e.expr) + + def visit_Reference(self, e: Reference): + self.visit(e.expr) + + def visit_AnyExpr(self, e: AnyExpr): + pass + + +class ExprRewriter(ExprFunctor): + def rewrite(self, e): + return self.visit(e) + + def visit_Binary(self, e: BinaryOp): + a = self(e.a) + b = self(e.b) + if a is e.a and b is e.b: + return e + else: + return e.__class__(a, b) + + def visit_Add(self, e: Add): + return self.visit_Binary(e) + + def visit_Sub(self, e: Sub): + return self.visit_Binary(e) + + def visit_Multiply(self, e: Multiply): + return self.visit_Binary(e) + + def visit_Div(self, e: Div): + return self.visit_Binary(e) + + def visit_Mod(self, e: Mod): + return self.visit_Binary(e) + + def visit_FloorDiv(self, e: FloorDiv): + return self.visit_Binary(e) + + def visit_LessThan(self, e: LessThan): + return self.visit_Binary(e) + + def visit_LessEqual(self, e: LessEqual): + return self.visit_Binary(e) + + def visit_Equal(self, e: Equal): + return self.visit_Binary(e) + + def visit_And(self, e: And): + return self.visit_Binary(e) + + def visit_Or(self, e: Or): + return self.visit_Binary(e) + + def visit_Neg(self, e: Neg): + a = self(e.a) + if a is e.a: + return e + else: + return Neg(a) + + def visit_Not(self, e: Not): + a = self(e.a) + if a is e.a: + return e + else: + return Not(a) + + def visit_BitwiseAnd(self, e: BitwiseAnd): + return self.visit_Binary(e) + + def visit_BitwiseOr(self, e: BitwiseOr): + return self.visit_Binary(e) + + def visit_BitwiseNot(self, e: BitwiseNot): + base = self.visit(e.base) + if base is e.base: + return e + else: + return BitwiseNot(base) + + def visit_LeftShift(self, e: LeftShift): + base = self.visit(e.base) + cnt = self.visit(e.cnt) + if base is e.base and cnt is e.cnt: + return e + else: + return LeftShift(base, cnt) + + def visit_RightShift(self, e: RightShift): + base = self.visit(e.base) + cnt = self.visit(e.cnt) + if base is e.base and cnt is e.cnt: + return e + else: + return RightShift(base, cnt) + + def visit_TensorElement(self, e: TensorElement): + base = self(e.base) + indices = [self(idx) if idx is not None else None for idx in e.indices] + if base is e.base and same_list(indices, e.indices): + return e + else: + return TensorElement(base, indices) + + def visit_TensorSlice(self, e: TensorSlice): + base = self(e.base) + indices = [self(idx) if idx is not None else None for idx in e.indices] + starts = [self(start) if start is not None else None for start in e.starts] + ends = [self(end) if end is not None else None for end in e.ends] + if base is e.base and same_list(indices, e.indices) and same_list(starts, e.starts) and same_list(ends, e.ends): + return e + else: + return TensorSlice(base, indices, starts, ends) + + def visit_IfThenElse(self, e: IfThenElse): + cond = self(e.cond) + then_expr = self(e.then_expr) + else_expr = self(e.else_expr) + if cond is e.cond and then_expr is e.then_expr and else_expr is e.else_expr: + return e + else: + return IfThenElse(cond, then_expr, else_expr) + + def visit_Cast(self, e: Cast): + expr = self(e.expr) + if expr is e.expr: + return e + else: + return Cast(expr, e.target_type) + + def visit_Dereference(self, e: Dereference): + expr = self(e.expr) + if expr is e.expr: + return e + else: + return Dereference(expr) + + def visit_Address(self, e: Address): + expr = self(e.expr) + if expr is e.expr: + return e + else: + return Address(expr) + + def visit_Reference(self, e: Reference): + expr = self(e.expr) + if expr is e.expr: + return e + else: + return Reference(expr) + + def visit_Call(self, e: Call): + func_var = self(e.func_var) + args = [self(arg) for arg in e.args] + if func_var is e.func_var and same_list(args, e.args): + return e + else: + return Call(func_var, args) + + def visit_Let(self, e: Let): + var = e.var + value = self(e.value) + body = self(e.body) + if same_list([var, value, body], [e.var, e.value, e.body]): + return e + else: + return Let(var, value, body) + + def visit_Var(self, e: Var): + return e + + def visit_Constant(self, e: Constant): + return e + + def visit_ScalarNode(self, e: ScalarNode): + if e.reduce_compute is None: + return e + else: + rc = e.reduce_compute + axes = self(rc.axes) + value = self(rc.value) + shape = self(rc.shape) + if value is rc.value and same_list(axes, rc.axes) and same_list(shape, rc.shape): + return e + else: + return ScalarNode(e.name, e.data_type, ReduceCompute(shape, axes, value, rc.reduce_type)) + + def visit_TensorNode(self, e: TensorNode): + if e.grid_compute is None: + return e + else: + gc = e.grid_compute + axes = self(gc.axes) + value = self(gc.value) + shape = self(gc.shape) + if value is gc.value and same_list(axes, gc.axes) and same_list(shape, gc.shape): + return e + else: + return TensorNode(e.name, e.data_type, GridCompute(shape, axes, value)) + + def visit_AnyExpr(self, e: AnyExpr): + return e + + +class StmtFunctor(NodeFunctor): + @staticmethod + def get_dispatch_mapping(cls) -> Mapping[Type[Node], Any]: + return { + EvaluateStmt: cls.visit_EvaluateStmt, + BufferStoreStmt: cls.visit_BufferStoreStmt, + AssignStmt: cls.visit_AssignStmt, + LetStmt: cls.visit_LetStmt, + ForStmt: cls.visit_ForStmt, + IfStmt: cls.visit_IfStmt, + ReturnStmt: cls.visit_ReturnStmt, + AsmStmt: cls.visit_AsmStmt, + AssertStmt: cls.visit_AssertStmt, + BlackBoxStmt: cls.visit_BlackBoxStmt, + SeqStmt: cls.visit_SeqStmt, + } + + def visit_expr(self, e: Expr): + raise NotImplementedError() + + def visit_EvaluateStmt(self, stmt: EvaluateStmt): + raise NotImplementedError() + + def visit_BufferStoreStmt(self, stmt: BufferStoreStmt): + raise NotImplementedError() + + def visit_AssignStmt(self, stmt: AssignStmt): + raise NotImplementedError() + + def visit_LetStmt(self, stmt: LetStmt): + raise NotImplementedError() + + def visit_ForStmt(self, stmt: ForStmt): + raise NotImplementedError() + + def visit_IfStmt(self, stmt: IfStmt): + raise NotImplementedError() + + def visit_ReturnStmt(self, stmt: ReturnStmt): + raise NotImplementedError() + + def visit_AssertStmt(self, stmt: AssertStmt): + raise NotImplementedError() + + def visit_AsmStmt(self, stmt: AsmStmt): + raise NotImplementedError() + + def visit_BlackBoxStmt(self, stmt: BlackBoxStmt): + raise NotImplementedError() + + def visit_SeqStmt(self, stmt: SeqStmt): + raise NotImplementedError() + + +class StmtVisitor(StmtFunctor): + def visit_expr(self, e: Expr): + pass + + def visit_EvaluateStmt(self, stmt: EvaluateStmt): + self.visit_expr(stmt.expr) + + def visit_BufferStoreStmt(self, stmt: BufferStoreStmt): + self.visit_expr(stmt.buf) + self.visit_expr(stmt.value) + for idx in stmt.indices: + self.visit_expr(idx) + + def visit_AssignStmt(self, stmt: AssignStmt): + self.visit_expr(stmt.var) + self.visit_expr(stmt.value) + + def visit_LetStmt(self, stmt: LetStmt): + for bind_var, bind_value in zip(stmt.bind_vars, stmt.bind_values): + self.visit_expr(bind_value) + self.visit(stmt.body) + + def visit_ForStmt(self, stmt: ForStmt): + self.visit_expr(stmt.extent) + self.visit(stmt.body) + + def visit_IfStmt(self, stmt: IfStmt): + self.visit_expr(stmt.cond) + self.visit(stmt.then_body) + if stmt.else_body: + self.visit(stmt.else_body) + + def visit_ReturnStmt(self, stmt: ReturnStmt): + self.visit(stmt.ret_value) + + def visit_AssertStmt(self, stmt: AssertStmt): + self.visit(stmt.cond) + + def visit_AsmStmt(self, stmt: AsmStmt): + for expr in stmt.input_exprs: + self.visit_expr(expr) + for expr in stmt.output_exprs: + self.visit_expr(expr) + + def visit_BlackBoxStmt(self, stmt: BlackBoxStmt): + for expr in stmt.exprs: + self.visit_expr(expr) + + def visit_SeqStmt(self, stmt: SeqStmt): + for s in stmt.seq: + self.visit(s) + + +class StmtRewriter(StmtFunctor): + def visit_expr(self, e: Expr): + return e + + def visit_EvaluateStmt(self, stmt: EvaluateStmt): + e = self.visit_expr(stmt.expr) + if e is stmt.expr: + return stmt + else: + return EvaluateStmt(e) + + def visit_BufferStoreStmt(self, stmt: BufferStoreStmt): + buf = self.visit_expr(stmt.buf) + indices = [self.visit_expr(e) for e in stmt.indices] + value = self.visit_expr(stmt.value) + if buf is stmt.buf and all(a is b for a, b in zip(indices, stmt.indices)) and value is stmt.value: + return stmt + else: + return BufferStoreStmt(buf, indices, value) + + def visit_AssignStmt(self, stmt: AssignStmt): + var = self.visit_expr(stmt.var) + value = self.visit_expr(stmt.value) + if var is stmt.var and value is stmt.value: + return stmt + else: + return AssignStmt(var, value) + + def visit_LetStmt(self, stmt: LetStmt): + bind_values = [self.visit_expr(bind_value) for bind_value in stmt.bind_values] + body = self.visit(stmt.body) + if same_list(bind_values, stmt.bind_values) and body is stmt.body: + return stmt + else: + return LetStmt(stmt.bind_vars, bind_values, body) + + def visit_ForStmt(self, stmt: ForStmt): + loop_var = stmt.loop_var + extent = self.visit_expr(stmt.extent) + body = self.visit(stmt.body) + if loop_var is stmt.loop_var and body is stmt.body: + return stmt + else: + return ForStmt(loop_var, extent, stmt.unroll, body) + + def visit_IfStmt(self, stmt: IfStmt): + cond = self.visit_expr(stmt.cond) + then_body = self.visit(stmt.then_body) + else_body = self.visit(stmt.else_body) if stmt.else_body else None + if cond is stmt.cond and then_body is stmt.then_body and else_body is stmt.else_body: + return stmt + else: + return IfStmt(cond, then_body, else_body) + + def visit_ReturnStmt(self, stmt: ReturnStmt): + ret_value = self.visit_expr(stmt.ret_value) if stmt.ret_value is not None else None + if ret_value is stmt.ret_value: + return stmt + else: + return ReturnStmt(ret_value) + + def visit_AssertStmt(self, stmt: AssertStmt): + cond = self.visit_expr(stmt.cond) + if cond is stmt.cond: + return stmt + else: + return AssertStmt(cond, stmt.msg) + + def visit_AsmStmt(self, stmt: AsmStmt): + input_exprs = [self.visit_expr(e) for e in stmt.input_exprs] + output_exprs = [self.visit_expr(e) for e in stmt.output_exprs] + if same_list(input_exprs, stmt.input_exprs) and same_list(output_exprs, stmt.output_exprs): + return stmt + else: + return AsmStmt(stmt.template_string, list(zip(stmt.output_labels, output_exprs)), + list(zip(stmt.input_labels, input_exprs)), stmt.is_volatile) + + def visit_BlackBoxStmt(self, stmt: BlackBoxStmt): + exprs = [self.visit_expr(e) for e in stmt.exprs] + if same_list(exprs, stmt.exprs): + return stmt + else: + return BlackBoxStmt(stmt.template_string, *exprs) + + def visit_SeqStmt(self, stmt: SeqStmt): + seq = [] + for s in stmt.seq: + seq.append(self.visit(s)) + if all(a is b for a, b in zip(seq, stmt.seq)): + return stmt + else: + return SeqStmt(seq) + + +class StmtExprFunctor(ExprFunctor, StmtFunctor): + def visit_expr(self, e: Expr): + return self.visit(e) + + +class StmtExprVisitor(ExprVisitor, StmtVisitor): + def visit_expr(self, e: Expr): + return self.visit(e) + + +class FuncStmtExprVisitor(StmtExprVisitor): + @staticmethod + def get_dispatch_mapping(cls) -> Mapping[Type[Node], Any]: + return {Function: cls.visit_Function} + + def visit_Function(self, func: Function): + self(func.body) + + +class StmtExprRewriter(ExprRewriter, StmtRewriter): + def visit_expr(self, e: Expr): + return self.visit(e) + + +class FuncStmtExprRewriter(StmtExprRewriter): + @staticmethod + def get_dispatch_mapping(cls) -> Mapping[Type[Node], Any]: + return {Function: cls.visit_Function} + + def visit_Function(self, func: Function): + body = self(func.body) + if body is func.body: + return func + else: + return Function(func.name, params=func.params, body=body, ret_type=func.ret_type, kind=func.kind, local_vars=func.local_vars, + local_const_vars=func.local_const_vars, extern_vars=func.extern_vars, attrs=func.attrs) + + +class TypeFunctor(NodeFunctor): + @staticmethod + def get_dispatch_mapping(cls) -> Mapping[Type[Node], Any]: + return { + ScalarType: cls.visit_ScalarType, + TensorType: cls.visit_TensorType, + PointerType: cls.visit_PointerType, + TensorPointerType: cls.visit_TensorPointerType, + ReferenceType: cls.visit_ReferenceType, + VoidType: cls.visit_VoidType, + } + + def visit_ScalarType(self, t: ScalarType): + raise NotImplementedError() + + def visit_TensorType(self, t: TensorType): + raise NotImplementedError() + + def visit_PointerType(self, t: PointerType): + raise NotImplementedError() + + def visit_TensorPointerType(self, t: TensorPointerType): + raise NotImplementedError() + + def visit_ReferenceType(self, t: ReferenceType): + raise NotImplementedError() + + def visit_VoidType(self, t: VoidType): + raise NotImplementedError() + + +def same_list(lhs: Sequence, rhs: Sequence): + if len(lhs) != len(rhs): + return False + return all(a is b for a, b in zip(lhs, rhs)) diff --git a/python/hidet/ir/functors/compute_inliner.py b/python/hidet/ir/functors/compute_inliner.py new file mode 100644 index 0000000..d98e72d --- /dev/null +++ b/python/hidet/ir/functors/compute_inliner.py @@ -0,0 +1,45 @@ +from hidet.ir.dialects.compute import TensorNode +from hidet.ir.expr import TensorElement +from hidet.utils import prod +from .base import ExprRewriter +from .util_functors import rewrite + + +class ComputeInlineRewriter(ExprRewriter): + def __init__(self, reduce_limit=0): + super().__init__() + self.reduce_limit = reduce_limit + + def visit_TensorElement(self, e: TensorElement): + base = self(e.base) + if isinstance(base, TensorNode) and base.grid_compute: + grid_compute = base.grid_compute + input_scalars = grid_compute.input_scalars + cnt = sum(prod(input_scalar.reduce_compute.const_shape()) for input_scalar in input_scalars if input_scalar.reduce_compute) + if cnt == 0 or cnt <= self.reduce_limit: + return rewrite(grid_compute.value, {axis: index for axis, index in zip(grid_compute.axes, e.indices)}) + return e + + +def inline_compute(expr: TensorNode, reduce_limit=0) -> TensorNode: + """ + Inline the computation. + + GridCompute(axes => value)[indices] => value (axes replaced by indices). + + Parameters + ---------- + expr: TensorNode + the computation node to inline. + reduce_limit: int, default 0 + reduce_limit < 0: Do not allow reducing compute in the computation, raise an Exception when encountered. + reduce_limit == 0: Allow reducing compute in the computation, but do not expand it. + reduce_limit > 0: Allow reducing compute, and expand it when its extent is equal or greater than reduce_limit. + + Returns + ------- + ret: TensorNode + Inlined compute. + """ + inliner = ComputeInlineRewriter(reduce_limit) + return inliner(expr) diff --git a/python/hidet/ir/functors/hasher.py b/python/hidet/ir/functors/hasher.py new file mode 100644 index 0000000..247e7a6 --- /dev/null +++ b/python/hidet/ir/functors/hasher.py @@ -0,0 +1,159 @@ +from hidet.ir.dialects.compute import TensorNode, ScalarNode +from hidet.ir.dialects.lowlevel import Reference, Address, ReferenceType, TensorPointerType, Dereference, VoidType, PointerType +# from hidet.ir.dialects.pattern import ScalarExprPattern, TensorComputePattern, ReduceComputePattern, AnyExpr +from hidet.ir.dialects.pattern import AnyExpr +from hidet.ir.node import Node +from hidet.ir.expr import Call, TensorElement, Not, Or, And, Constant, Var, Let, Equal, LessThan, FloorDiv, Mod, Div, Multiply, Sub, Add, TensorType, ScalarType, Expr, IfThenElse, RightShift, LeftShift, BitwiseNot, BitwiseOr, BitwiseAnd, TensorSlice, Neg, Cast +from hidet.ir.functors import ExprFunctor, TypeFunctor, NodeFunctor +from hidet.ir.type import TypeNode +from hidet.ir.utils.hash_sum import HashSum + + +class ExprHash(ExprFunctor, TypeFunctor): + def __init__(self): + super().__init__() + + def visit(self, e): + if e in self.memo: + return self.memo[e] + if isinstance(e, (str, float, int)): + ret = HashSum(e) + elif isinstance(e, tuple): + ret = HashSum(tuple(self(v) for v in e)) + elif isinstance(e, (Expr, TypeNode)): + ret = NodeFunctor.visit(self, e) + elif e is None: + ret = HashSum(None) + else: + # for stmt/func/... + ret = HashSum(e) + self.memo[e] = ret + return ret + + def hash(self, expr): + self.memo.clear() + return self(expr) + + def visit_Var(self, e: Var): + return HashSum(e) + e.class_index() + + def visit_Constant(self, e: Constant): + return HashSum(e.value) + self(e.data_type) + e.class_index() + + def visit_Add(self, e: Add): + return (self(e.a) & self(e.b)) + e.class_index() + + def visit_Sub(self, e: Sub): + return self(e.a) + self(e.b) + e.class_index() + + def visit_Multiply(self, e: Multiply): + return (self(e.a) & self(e.b)) + e.class_index() + + def visit_Div(self, e: Div): + return self(e.a) + self(e.b) + e.class_index() + + def visit_Mod(self, e: Mod): + return self(e.a) + self(e.b) + e.class_index() + + def visit_FloorDiv(self, e: FloorDiv): + return self(e.a) + self(e.b) + e.class_index() + + def visit_Neg(self, e: Neg): + return self(e.a) + e.class_index() + + def visit_LessThan(self, e: LessThan): + return self(e.a) + self(e.b) + e.class_index() + + def visit_LessEqual(self, e: LessThan): + return self(e.a) + self(e.b) + e.class_index() + + def visit_Equal(self, e: Equal): + return (self(e.a) & self(e.b)) + e.class_index() + + def visit_IfThenElse(self, e: IfThenElse): + return self(e.cond) + self(e.then_expr) + self(e.else_expr) + e.class_index() + + def visit_And(self, e: And): + return (self(e.a) & self(e.b)) + e.class_index() + + def visit_Or(self, e: Or): + return (self(e.a) & self(e.b)) + e.class_index() + + def visit_Not(self, e: Not): + return self(e.a) + e.class_index() + + def visit_BitwiseAnd(self, e: BitwiseAnd): + return (self(e.a) & self(e.b)) + e.class_index() + + def visit_BitwiseOr(self, e: BitwiseOr): + return (self(e.a) & self(e.b)) + e.class_index() + + def visit_BitwiseNot(self, e: BitwiseNot): + return self(e.base) + e.class_index() + + def visit_LeftShift(self, e: LeftShift): + return (self(e.base) + self(e.cnt)) + e.class_index() + + def visit_RightShift(self, e: RightShift): + return (self(e.base) + self(e.cnt)) + e.class_index() + + def visit_TensorElement(self, e: TensorElement): + return self(e.base) + self(e.indices) + e.class_index() + + def visit_Cast(self, e: Cast): + return self(e.expr) + self(e.target_type) + e.class_index() + + def visit_Dereference(self, e: Dereference): + return self(e.expr) + e.class_index() + + def visit_Address(self, e: Address): + return self(e.expr) + e.class_index() + + def visit_Reference(self, e: Reference): + return self(e.expr) + e.class_index() + + def visit_Call(self, e: Call): + return self(e.func_var) + self(e.args) + e.class_index() + + def visit_Let(self, e: Let): + return self(e.var) + self(e.value) + self(e.body) + e.class_index() + + def visit_ScalarType(self, t: ScalarType): + return self(t.name) + t.class_index() + + def visit_TensorType(self, t: TensorType): + return self(t.scalar_type) + self(t.scope.name) + self(t.shape) + t.class_index() + + def visit_PointerType(self, t: PointerType): + return self(t.base_type) + t.class_index() + + def visit_TensorPointerType(self, t: TensorPointerType): + return self(t.tensor_type) + t.class_index() + + def visit_ReferenceType(self, t: ReferenceType): + return self(t.base_type) + t.class_index() + + def visit_VoidType(self, t: VoidType): + return t.class_index() + + def visit_TensorSlice(self, e: TensorSlice): + return self(e.base) + self(e.indices) + self(e.starts) + self(e.ends) + e.class_index() + + def visit_ScalarNode(self, e: ScalarNode): + if e.reduce_compute: + rc = e.reduce_compute + return self(rc.axes) + self(rc.value) + self(rc.shape) + e.class_index() + else: + return HashSum(e) + e.class_index() + + def visit_TensorNode(self, e: TensorNode): + if e.grid_compute: + rc = e.grid_compute + return self(rc.axes) + self(rc.value) + self(rc.shape) + e.class_index() + else: + return HashSum(e) + e.class_index() + + def visit_AnyExpr(self, e: AnyExpr): + return HashSum(e) + e.class_index() + + diff --git a/python/hidet/ir/functors/printer.py b/python/hidet/ir/functors/printer.py new file mode 100644 index 0000000..5722167 --- /dev/null +++ b/python/hidet/ir/functors/printer.py @@ -0,0 +1,422 @@ +from typing import Dict, Optional, List +from hidet.ir.node import Node +from hidet.ir.func import IRModule, Function +from hidet.ir.type import ScalarType, TensorType, TypeNode +from hidet.ir.expr import Constant, Var, Call, TensorElement, Add, Multiply, Expr, LessThan, FloorDiv, Mod, Equal, Div, Sub, Not, Or, And, Let, IfThenElse, TensorSlice, RightShift, LeftShift, BitwiseNot, BitwiseOr, BitwiseAnd, Neg, Cast +from hidet.ir.stmt import SeqStmt, IfStmt, ForStmt, AssignStmt, BufferStoreStmt, EvaluateStmt, Stmt, AssertStmt, BlackBoxStmt, AsmStmt, ReturnStmt, LetStmt +from hidet.ir.dialects.compute import TensorNode, ScalarNode +from hidet.ir.dialects.lowlevel import VoidType, PointerType, Dereference, Address, ReferenceType, TensorPointerType, Reference +from hidet.ir.dialects.pattern import AnyExpr +from hidet.ir.layout import RowMajorLayout, ColumnMajorLayout +from hidet.ir.task import Task, Prologue, Epilogue, InverseMap +from hidet.utils.doc import Doc, NewLine, Text, doc_join +from hidet.utils.namer import Namer + +from .base import StmtExprFunctor, TypeFunctor, NodeFunctor + + +class IRPrinter(StmtExprFunctor, TypeFunctor): + def __init__(self): + super().__init__() + self.namer = Namer() + self.ir_module: Optional[IRModule] = None + + def __call__(self, node): + return self.visit(node) + + def visit(self, obj): + if isinstance(obj, (list, tuple)): + return doc_join([self(v) for v in obj], ', ') + elif isinstance(obj, dict): + return doc_join([self(k) + ': ' + self(v) for k, v in obj.items()], ', ') + elif isinstance(obj, str): + return Text(obj.replace('\n', '\\n').replace('\t', '\\t')) + elif isinstance(obj, (int, float)): + return Text(str(obj)) + elif isinstance(obj, TypeNode): + return TypeFunctor.visit(self, obj) + elif isinstance(obj, Function): + return self.visit_Function(obj) + elif isinstance(obj, IRModule): + return self.visit_IRModule(obj) + elif isinstance(obj, (Expr, Stmt)): + return NodeFunctor.visit(self, obj) + elif isinstance(obj, Task): + return self.visit_Task(obj) + elif isinstance(obj, Prologue): + return self.visit_Prologue(obj) + elif isinstance(obj, Epilogue): + return self.visit_Epilogue(obj) + elif isinstance(obj, InverseMap): + return self.visit_InverseMap(obj) + elif obj is None: + return Text('None') + else: + return object.__repr__(obj) + + def visit_Function(self, func: Function): + self.namer.clear() + doc = Doc() + + # parameters + doc += 'fn(' + param_docs = [] + for i in range(len(func.params)): + param = func.params[i] + param_docs.append([NewLine(), self(param), ': ', self(param.type)]) + doc += doc_join(param_docs, Text(', ')) + doc += ')' + doc = doc.indent(6) + + # const locals + for local_var, local_value in func.local_const_vars: + doc += (NewLine() + Text('declare ') + self(local_var) + Text(': ') + self(local_var.type) + ' = ' + self(local_value)).indent(4) + + # locals + for local_var in func.local_vars: + doc += (NewLine() + Text('declare ') + self(local_var) + Text(': ') + self(local_var.type)).indent(4) + + # body + doc += self(func.body).indent(4) + + return doc + + def visit_IRModule(self, ir_module: IRModule): + doc = Doc() + self.ir_module = ir_module + if ir_module.task is not None: + doc += str(ir_module.task) + doc += NewLine() + for name, func in ir_module.functions.items(): + doc += ['def ', name, ' ', self(func), NewLine(), NewLine()] + return doc + + def visit_Add(self, e: Add): + return Text('(') + self(e.a) + ' + ' + self(e.b) + ')' + + def visit_Sub(self, e: Sub): + return Text('(') + self(e.a) + ' - ' + self(e.b) + ')' + + def visit_Multiply(self, e: Multiply): + return Text('(') + self(e.a) + ' * ' + self(e.b) + ')' + + def visit_Div(self, e: Div): + return Text('(') + self(e.a) + ' / ' + self(e.b) + ')' + + def visit_Mod(self, e: Mod): + return Text('(') + self(e.a) + ' % ' + self(e.b) + ')' + + def visit_FloorDiv(self, e: FloorDiv): + return Text('(') + self(e.a) + ' / ' + self(e.b) + ')' + + def visit_Neg(self, e: Neg): + return Text('(-') + self(e.a) + ')' + + def visit_LessThan(self, e: LessThan): + return Text('(') + self(e.a) + ' < ' + self(e.b) + ')' + + def visit_LessEqual(self, e: LessThan): + return Text('(') + self(e.a) + ' <= ' + self(e.b) + ')' + + def visit_Equal(self, e: Equal): + return Text('(') + self(e.a) + ' == ' + self(e.b) + ')' + + def visit_And(self, e: And): + return Text('(') + self(e.a) + ' && ' + self(e.b) + ')' + + def visit_Or(self, e: Or): + return Text('(') + self(e.a) + ' || ' + self(e.b) + ')' + + def visit_Not(self, e: Not): + return Text('!') + self(e.a) + + def visit_BitwiseAnd(self, e: BitwiseAnd): + return '(' + self(e.a) + ' & ' + self(e.b) + ')' + + def visit_BitwiseOr(self, e: BitwiseOr): + return '(' + self(e.a) + ' | ' + self(e.b) + ')' + + def visit_BitwiseNot(self, e: BitwiseNot): + return '(~' + self(e.base) + ')' + + def visit_LeftShift(self, e: LeftShift): + return '(' + self(e.base) + ' << ' + self(e.cnt) + ')' + + def visit_RightShift(self, e: RightShift): + return '(' + self(e.base) + ' >> ' + self(e.cnt) + ')' + + def visit_TensorElement(self, e: TensorElement): + return self(e.base) + '[' + self(e.indices) + ']' + + def visit_TensorSlice(self, e: TensorSlice): + subscriptions = [] + for index, start, end in zip(e.indices, e.starts, e.ends): + if index is not None: + subscriptions.append(self(index)) + else: + doc = Doc() + if start is not None: + doc += self(start) + doc += ':' + if end is not None: + doc += self(end) + subscriptions.append(doc) + return self(e.base) + '[' + doc_join(subscriptions, ', ') + ']' + + def visit_IfThenElse(self, e: IfThenElse): + return '(' + self(e.cond) + ' ? ' + self(e.then_expr) + ' : ' + self(e.else_expr) + ')' + + def visit_Call(self, e: Call): + doc = Doc() + # name + doc += e.func_var.hint + # launch + func_name = e.func_var.hint + if self.ir_module and func_name in self.ir_module.functions: + func = self.ir_module.functions[func_name] + if func.kind == 'cuda_kernel': + doc += '<<<' + self(func.attrs['cuda_grid_dim']) + ', ' + self(func.attrs['cuda_block_dim']) + '>>>' + # params + doc += '(' + self(e.args) + ')' + return doc + + def visit_Let(self, e: Let): + return Text('let(') + self(e.var) + '=' + self(e.value) + ': ' + self(e.body) + ')' + + def visit_Cast(self, e: Cast): + return Text('cast(') + self(e.target_type) + ', ' + self(e.expr) + ')' + + def visit_Reference(self, e: Reference): + return Text('Ref(') + self(e.expr) + ')' + + def visit_Dereference(self, e: Dereference): + return Text('*') + self(e.expr) + + def visit_Address(self, e: Address): + return Text('&') + self(e.expr) + + def visit_Var(self, e: Var): + return Text(self.namer.get_name(e)) + + def visit_Constant(self, e: Constant): + if e.value is None: + return self('Constant(None, type=') + self(e.data_type) + ')' + if e.is_tensor(): + return 'ConstTensor({}, {})'.format(e.value.shape, e.data_type) + else: + dtype = e.data_type.name + if dtype == 'float32': + ret = '{}f'.format(float(e.value)) + elif dtype == 'float16': + ret = 'half({})'.format(float(e.value)) + elif dtype == 'int32': + ret = '{}'.format(int(e.value)) + else: + ret = '{}({})'.format(dtype, e.value) + return Text(ret) + + def visit_EvaluateStmt(self, stmt: EvaluateStmt): + return NewLine() + self(stmt.expr) + + def visit_BufferStoreStmt(self, stmt: BufferStoreStmt): + doc = NewLine() + doc += self(stmt.buf) + doc += '[' + self(stmt.indices) + ']' + doc += ' = ' + self(stmt.value) + return doc + + def visit_AssignStmt(self, stmt: AssignStmt): + return NewLine() + self(stmt.var) + ' = ' + self(stmt.value) + + def visit_LetStmt(self, stmt: LetStmt): + doc = Doc() + for bind_var, bind_value in zip(stmt.bind_vars, stmt.bind_values): + doc += NewLine() + 'let ' + self(bind_var) + ' = ' + self(bind_value) + doc += self(stmt.body) + # doc += self(stmt.body).indent() + return doc + + def visit_ForStmt(self, stmt: ForStmt): + rng = Text('range(') + self(stmt.extent) + ')' + doc = NewLine() + Text('for ') + self(stmt.loop_var) + ' in ' + rng + if stmt.unroll is not None: + if stmt.unroll: + doc += '[unroll]' + else: + doc += '[no-unroll]' + doc += self(stmt.body).indent(4) + return doc + + def visit_IfStmt(self, stmt: IfStmt): + doc = NewLine() + Text('if ') + self(stmt.cond) + doc += self(stmt.then_body).indent(4) + if stmt.else_body: + doc += NewLine() + Text('else') + doc += self(stmt.else_body).indent(4) + return doc + + def visit_ReturnStmt(self, stmt: ReturnStmt): + doc = NewLine() + Text('return') + if stmt.ret_value: + doc += ' ' + self(stmt.ret_value) + return doc + + def visit_AssertStmt(self, stmt: AssertStmt): + return NewLine() + 'assert(' + self(stmt.cond) + ', ' + stmt.msg + ')' + + def visit_AsmStmt(self, stmt: AsmStmt): + volatile_doc = 'volatile ' if stmt.is_volatile else '' + template_doc = '"' + Text(stmt.template_string) + '"' + output_docs = [] + for label, expr in zip(stmt.output_labels, stmt.output_exprs): + output_docs.append('"' + Text(label) + '"' + '(' + self(expr) + ')') + input_docs = [] + for label, expr in zip(stmt.input_labels, stmt.input_exprs): + input_docs.append('"' + Text(label) + '"' + '(' + self(expr) + ')') + return NewLine() + 'asm ' + volatile_doc + '(' + template_doc + ' : ' + doc_join(output_docs, ', ') + ' : ' + doc_join(input_docs, ', ') + ');' + + def visit_BlackBoxStmt(self, stmt: BlackBoxStmt): + expr_docs = [str(self(e)) for e in stmt.exprs] + stmt_string: str = stmt.template_string.format(*expr_docs) + lines = stmt_string.split('\n') + doc = Text('') + for line in lines: + doc += NewLine() + line + return doc + + def visit_SeqStmt(self, stmt: SeqStmt): + doc = Doc() + for idx, s in enumerate(stmt.seq): + doc += self(s) + return doc + + def visit_ScalarType(self, t: ScalarType): + return Text('{}'.format(t.name)) + + def visit_TensorType(self, t: TensorType): + assert t.scope is not None + if isinstance(t.layout, RowMajorLayout): + layout = 'row_major' + elif isinstance(t.layout, ColumnMajorLayout): + layout = 'column_major' + elif t.layout is None: + layout = 'None' + else: + layout = type(t.layout).__name__ + items = [self(t.scalar_type), '[' + self(t.shape) + ']', self(t.scope.name), self(layout)] + return Text('tensor(') + doc_join(items, ', ') + ')' + + def visit_PointerType(self, t: PointerType): + return Text('PointerType(') + self(t.base_type) + ')' + + def visit_TensorPointerType(self, t: TensorPointerType): + return Text('TensorPointerType(') + self(t.tensor_type) + ')' + + def visit_ReferenceType(self, t: ReferenceType): + return Text('ReferenceType(') + self(t.base_type) + ')' + + def visit_VoidType(self, t: VoidType): + return Text('VoidType') + + def visit_AnyExpr(self, e: AnyExpr): + return Text('AnyExpr') + + def print_tensor_nodes(self, nodes: List[TensorNode], exclude_nodes: List[TensorNode] = None) -> Doc: + from hidet.ir.functors import collect + if exclude_nodes is None: + exclude_nodes = [] + nodes: List[TensorNode] = collect(nodes, TensorNode) + doc = Doc() + for node in reversed(nodes): + if node in exclude_nodes: + continue + if node.grid_compute is None: + doc += NewLine() + self.namer.get_name(node) + ': ' + self(node.data_type) + else: + gc = node.grid_compute + items = [ + '[' + self(gc.shape) + ']', + '(' + self(gc.axes) + ') => ' + self(gc.value), + ] + doc += NewLine() + self.namer.get_name(node) + ': ' + 'grid(' + doc_join(items, ', ') + ')' + return doc + + def visit_Task(self, e: Task): + lines = [ + Text('name: ') + e.name, + Text('parameters: ') + (NewLine() + doc_join(['{}: {}'.format(self.namer.get_name(v), self(v.data_type)) for v in e.parameters], NewLine())).indent(), + Text('inputs: ') + '[' + doc_join([self.namer.get_name(v) for v in e.inputs], ', ') + ']', + Text('outputs: ') + '[' + doc_join([self.namer.get_name(v) for v in e.outputs], ', ') + ']', + Text('computations: ') + self.print_tensor_nodes(e.outputs).indent(), + Text('attributes: {') + self(e.attributes) + '}' + ] + front_part = doc_join(lines, NewLine()) + inverse_map_doc = Doc() + prologue_doc = Doc() + epilogue_doc = Doc() + if e.inverse_map: + inverse_map_doc += NewLine() + Text('inverse_map:') + for tensor, inverse_map in e.inverse_map.items(): + inverse_map_doc += (NewLine() + self.namer.get_name(tensor) + ': ' + self(inverse_map)).indent() + if e.prologues: + prologue_doc += NewLine() + Text('prologue:') + for tensor, prologue in e.prologues.items(): + prologue_doc += (NewLine() + self.namer.get_name(tensor) + ': ' + self(prologue)).indent() + if e.epilogues: + epilogue_doc += NewLine() + Text('epilogue:') + for tensor, epilogue in e.epilogues.items(): + epilogue_doc += (NewLine() + self.namer.get_name(tensor) + ': ' + self(epilogue)).indent() + return Text('Task(') + (NewLine() + front_part + inverse_map_doc + prologue_doc + epilogue_doc).indent() + NewLine() + ')' + + def visit_Prologue(self, e: Prologue): + from hidet.ir.functors import collect + items = [ + '(' + self(e.indices) + ') => ' + self(e.value), + 'extra_inputs: [' + self(e.extra_inputs) + ']' + ] + doc = 'Prologue(' + doc_join(items, ', ') + ')' + nodes = [node for node in collect(e.value, TensorNode) if node.grid_compute is not None] + if len(nodes) > 0: + doc += self.print_tensor_nodes(nodes, exclude_nodes=[]).indent() + return doc + + def visit_Epilogue(self, e: Epilogue): + from hidet.ir.functors import collect + items = [ + '(' + self(e.indices) + ')', + self(e.orig_value) + ' => ' + self(e.value), + 'out_indices=(' + self(e.out_indices) + ')', + 'out_tensor=' + self(e.out_tensor) + ')' + ] + doc = doc_join(items, ', ') + # ret = 'Epilogue((' + self(e.indices) + '), ' + self(e.orig_value) + ' => ' + self(e.value) + ', out_indices=(' + self(e.out_indices) + '), out_tensor=' + self(e.out_tensor) + ')' + nodes = [node for node in collect(e.value, TensorNode) if node.grid_compute is not None] + if len(nodes) > 0: + doc += self.print_tensor_nodes(nodes, exclude_nodes=[]).indent() + return doc + + def visit_InverseMap(self, e: InverseMap): + return 'InverseMap([' + self(e.axes) + '] => [' + self(e.indices) + '])' + + def visit_ScalarNode(self, e: ScalarNode): + if e.reduce_compute is None: + return self.namer.get_name(e, e.name) + else: + rc = e.reduce_compute + items = [ + '[' + self(rc.shape) + ']', + '(' + self(rc.axes) + ') => ' + self(rc.value), + self(rc.reduce_type) + ] + return 'reduce(' + doc_join(items, ', ') + ')' + + def visit_TensorNode(self, e: TensorNode): + return self.namer.get_name(e) + + +def astext(obj: Node) -> str: + if isinstance(obj, Node): + printer = IRPrinter() + return str(printer(obj)) + else: + raise ValueError() diff --git a/python/hidet/ir/functors/simplifier.py b/python/hidet/ir/functors/simplifier.py new file mode 100644 index 0000000..51121b4 --- /dev/null +++ b/python/hidet/ir/functors/simplifier.py @@ -0,0 +1,143 @@ +from typing import Union +import operator +from hidet.ir.expr import Expr, BinaryOp, Add, Sub, Multiply, Div, Mod, FloorDiv, LessThan, LessEqual, Equal, Constant, And, Or, Not +from hidet.ir.expr import is_one, is_zero, is_true, is_false, convert +from hidet.ir.stmt import Stmt, IfStmt, SeqStmt, ForStmt +from hidet.ir.functors import StmtExprRewriter, same_list, rewrite + + +class Simplifier(StmtExprRewriter): + def visit_Binary(self, e: BinaryOp): + a = self(e.a) + b = self(e.b) + if isinstance(e, Add): + if is_zero(a): + return b + if is_zero(b): + return a + elif isinstance(e, Sub): + if is_zero(b): + return a + elif isinstance(e, Multiply): + if is_one(a): + return b + if is_one(b): + return a + if is_zero(a) or is_zero(b): + return convert(0) + elif isinstance(e, Div): + if is_one(b): + return a + elif isinstance(e, Mod): + if is_one(e.b): + return convert(0) + elif isinstance(e, FloorDiv): + if is_one(b): + return a + elif isinstance(e, LessThan): + pass + elif isinstance(e, LessEqual): + pass + elif isinstance(e, Equal): + pass + elif isinstance(e, And): + if is_false(a) or is_false(b): + return convert(False) + if is_true(a): + return b + if is_true(b): + return a + elif isinstance(e, Or): + if is_true(a) or is_true(b): + return convert(True) + if is_false(a): + return b + if is_false(b): + return a + else: + raise ValueError() + + if isinstance(a, Constant) and isinstance(b, Constant): + op_dict = { + Add: operator.add, + Sub: operator.sub, + Multiply: operator.mul, + Div: operator.truediv, + Mod: operator.mod, + FloorDiv: operator.floordiv, + LessThan: operator.lt, + Equal: operator.eq + } + if e.__class__ in op_dict: + if a.data_type.name == 'int32' and b.data_type.name == 'int32' and isinstance(e, Div): + # the Div for int32 will use floordiv. Override the native behavior of python + return a.value // b.value + else: + return convert(op_dict[e.__class__](a.value, b.value)) + elif isinstance(e, And): + return convert(a.value and b.value) + elif isinstance(e, Or): + return convert(a.value or b.value) + else: + raise ValueError() + if a is e.a and b is e.b: + return e + return e.__class__(a, b) + + def visit_Not(self, e: Not): + a = self(e.a) + if isinstance(a, Constant): + return convert(not a.value) + if a is e.a: + return e + else: + return Not(a) + + def visit_IfStmt(self, stmt: IfStmt): + cond = self.visit_expr(stmt.cond) + then_body = self.visit(stmt.then_body) + else_body = self.visit(stmt.else_body) if stmt.else_body else None + if is_true(cond): + return then_body + elif is_false(cond): + if else_body: + return else_body + else: + return SeqStmt([]) + else: + if cond is stmt.cond and then_body is stmt.then_body and else_body is stmt.else_body: + return stmt + else: + return IfStmt(cond, then_body, else_body) + + def visit_ForStmt(self, stmt: ForStmt): + loop_var = self(stmt.loop_var) + extent = self(stmt.extent) + body = self(stmt.body) + if is_one(extent): + return rewrite(stmt.body, {loop_var: convert(0)}) + else: + if loop_var is stmt.loop_var and body is stmt.body: + return stmt + else: + return ForStmt(loop_var, extent, stmt.unroll, body) + + +def simplify(node: Union[Stmt, Expr], repeat_limit=10): + if isinstance(node, (int, float)): + return node + simplifier = Simplifier() + for i in range(repeat_limit): + old_node = node + node = simplifier(node) + if old_node is node: + break + return node + + +def simplify_to_int(node: Union[Expr, int], repeat_limit=10) -> int: + if isinstance(node, int): + return node + node = simplify(node, repeat_limit) + assert isinstance(node, Constant) and node.data_type.name in ['int32', 'uint8'] + return node.value diff --git a/python/hidet/ir/functors/type_infer.py b/python/hidet/ir/functors/type_infer.py new file mode 100644 index 0000000..d55a645 --- /dev/null +++ b/python/hidet/ir/functors/type_infer.py @@ -0,0 +1,152 @@ +from typing import List +from hidet.ir.type import ScalarType, TensorType, FuncType +from hidet.ir.expr import BinaryOp, Add, Sub, Multiply, Div, Mod, FloorDiv, Condition, LessThan, Equal, IfThenElse, TensorSlice, Not, Or, And, LessEqual, Let, RightShift, LeftShift, BitwiseNot, BitwiseOr, BitwiseAnd, Neg +from hidet.ir.expr import Var, Constant, TensorElement, Call, Cast +from hidet.ir.dialects.compute import TensorNode, ScalarNode +from hidet.ir.dialects.lowlevel import PointerType, Dereference, Reference, Address, TensorPointerType + +from .base import ExprFunctor +from ..dialects.pattern import AnyExpr + + +def is_bool(tp): + return isinstance(tp, ScalarType) and tp.name == 'bool' + + +class TypeInfer(ExprFunctor): + def visit_Address(self, e: Address): + base_type = self(e.expr) + return PointerType(base_type=base_type) + + def visit_Reference(self, e: Reference): + return self(e.expr) + + def visit_Binary(self, e: BinaryOp): + a_dtype: ScalarType = self.visit(e.a) + b_dtype: ScalarType = self.visit(e.b) + # if not atype or not btype: + # return ScalarType(name=None) + if isinstance(e, (Add, Sub, Multiply, Div, Mod, FloorDiv)): + return ScalarType(max(a_dtype, b_dtype)) + elif isinstance(e, Condition): + return ScalarType('bool') + else: + raise NotImplementedError('Binary op type infer {}'.format(type(e))) + + def visit_Neg(self, e: Neg): + return self(e.a) + + def visit_Add(self, e: Add): + return self.visit_Binary(e) + + def visit_Sub(self, e: Sub): + return self.visit_Binary(e) + + def visit_Multiply(self, e: Multiply): + return self.visit_Binary(e) + + def visit_Div(self, e: Div): + return self.visit_Binary(e) + + def visit_Mod(self, e: Mod): + return self.visit_Binary(e) + + def visit_FloorDiv(self, e: FloorDiv): + return self.visit_Binary(e) + + def visit_LessThan(self, e: LessThan): + return self.visit_Binary(e) + + def visit_Equal(self, e: Equal): + return self.visit_Binary(e) + + def visit_LessEqual(self, e: LessEqual): + return self.visit_Binary(e) + + def visit_And(self, e: And): + return self.visit_Binary(e) + + def visit_Or(self, e: Or): + return self.visit_Binary(e) + + def visit_Not(self, e: Not): + assert is_bool(self.visit(e.a)) + return ScalarType('bool') + + def visit_BitwiseAnd(self, e: BitwiseAnd): + return self.visit(e.a) + + def visit_BitwiseOr(self, e: BitwiseOr): + return self.visit(e.a) + + def visit_BitwiseNot(self, e: BitwiseNot): + return self.visit(e.base) + + def visit_LeftShift(self, e: LeftShift): + return self.visit(e.base) + + def visit_RightShift(self, e: RightShift): + return self.visit(e.base) + + def visit_TensorElement(self, e: TensorElement): + base_type = self.visit(e.base) + if isinstance(base_type, TensorType): + return base_type.scalar_type + elif isinstance(base_type, PointerType): + return base_type.base_type + elif isinstance(base_type, TensorPointerType): + return base_type.tensor_type.scalar_type + else: + raise NotImplementedError() + + def visit_TensorSlice(self, e: TensorSlice): + raise NotImplementedError() + + def visit_IfThenElse(self, e: IfThenElse): + cond_type = self.visit(e.cond) + true_type = self.visit(e.then_expr) + false_type = self.visit(e.else_expr) + assert is_bool(cond_type) + if not (isinstance(true_type, ScalarType) and isinstance(false_type, ScalarType) and true_type.name == false_type.name): + raise ValueError('If-then-else operand 1 and 2 have different types ({} vs {}): {}'.format(true_type, false_type, e)) + return true_type + + def visit_Let(self, e: Let): + self.visit(e.value) + return self.visit(e.body) + + def visit_Call(self, e: Call): + func_var = e.func_var + func_type = func_var.type + if not isinstance(func_type, FuncType): + raise ValueError('Type infer failed, expect a function var "{}" but got variable with type "{}"'.format(func_var, func_type)) + args_type = [self(arg) for arg in e.args] + return func_type.ret_type_on(args_type) + + def visit_Cast(self, e: Cast): + return e.target_type + + def visit_Dereference(self, e: Dereference): + tp = self.visit(e.expr) + assert isinstance(tp, PointerType) + return tp.base_type + + def visit_Var(self, e: Var): + return e.type + + def visit_Constant(self, e: Constant): + return e.data_type + + def visit_ScalarNode(self, e: ScalarNode): + return e.data_type + + def visit_TensorNode(self, e: TensorNode): + return e.data_type + + def visit_AnyExpr(self, e: AnyExpr): + raise ValueError('Can not infer type of an AnyExpr.') + + +def infer_type(expr): + infer = TypeInfer() + return infer(expr) diff --git a/python/hidet/ir/functors/util_functors.py b/python/hidet/ir/functors/util_functors.py new file mode 100644 index 0000000..8cfdd10 --- /dev/null +++ b/python/hidet/ir/functors/util_functors.py @@ -0,0 +1,141 @@ +from typing import Union, Mapping +from hidet.ir.expr import Let +from hidet.ir.func import Function +from hidet.ir.stmt import Stmt, ForStmt, LetStmt +from hidet.ir.dialects.compute import * + +from .base import StmtExprVisitor, StmtExprRewriter, FuncStmtExprVisitor + + +class StmtExprMapRewriter(StmtExprRewriter): + def __init__(self, rmap): + super().__init__() + self.rmap = rmap + + def visit(self, e): + if e not in self.memo: + if e in self.rmap: + self.memo[e] = self.rmap[e] + else: + self.memo[e] = StmtExprRewriter.visit(self, e) + return self.memo[e] + + +class SubStmtExprCollector(FuncStmtExprVisitor): + def __init__(self, expr_types, stop_when_found=False): + super().__init__() + self.expr_types = expr_types + self.stop_when_found = stop_when_found + self.exprs = [] + + def collect(self, e): + self.exprs.clear() + self.visit(e) + return self.exprs + + def visit(self, e): + if e in self.memo: + return self.memo[e] + if isinstance(e, self.expr_types): + self.exprs.append(e) + if self.stop_when_found: + self.memo[e] = None + return + StmtExprVisitor.visit(self, e) + + +class FreeVarCollector(StmtExprVisitor): + def __init__(self): + super().__init__() + self.defined = set() + self.free_vars = set() + + def collect(self, e): + self.defined.clear() + self.visit(e) + return self.free_vars + + def visit_LetStmt(self, stmt: LetStmt): + for bind_var, bind_value in zip(stmt.bind_vars, stmt.bind_values): + self.visit(bind_value) + self.defined.add(bind_var) + self.visit(stmt.body) + for bind_var in stmt.bind_vars: + self.defined.remove(bind_var) + + def visit_ForStmt(self, stmt: ForStmt): + self.defined.add(stmt.loop_var) + StmtExprVisitor.visit_ForStmt(self, stmt) + self.defined.remove(stmt.loop_var) + + def visit_Var(self, e: Var): + if e not in self.defined: + self.free_vars.add(e) + + +class CloneRewriter(StmtExprRewriter): + def clone(self, obj: Union[Stmt, Expr]): + return self(obj) + + def visit_LetStmt(self, stmt: LetStmt): + bind_vars = [] + bind_values = [] + for bind_var, bind_value in zip(stmt.bind_vars, stmt.bind_values): + bind_vars.append(Var(bind_var.hint, bind_var.type)) + self.memo[bind_var] = bind_vars[-1] + bind_values.append(self(bind_value)) + return LetStmt(bind_vars, bind_values, self(stmt.body)) + + def visit_Let(self, e: Let): + v = Var(e.var.hint, e.var.type) + self.memo[e.var] = v + return Let(v, self(e.value), self(e.body)) + + +def rewrite(node: Union[Expr, Stmt, tuple], rewrite_map: Mapping[Union[Stmt, Expr], Union[Stmt, Expr]]): + assert isinstance(rewrite_map, dict) + rewriter = StmtExprMapRewriter(rewrite_map) + return rewriter.rewrite(node) + + +def collect(node: Union[Function, Expr, Stmt, list, tuple], node_types, stop_when_found=False) -> list: + """ + Collect sub-nodes in given node with specific types. + + Parameters + ---------- + node: Union[Function, Expr, Stmt, list, tuple] + The root node to start collecting. + node_types: Sequence[Type[Union[Stmt, Expr]]], or Type[Stmt], or Type[Expr] + The node types to collect, can be arbitrary subclass of Expr and Stmt + stop_when_found: bool + When found node of given type, whether to collect the sub-nodes of that node. + + Returns + ------- + ret: List[Node] + The collected nodes. + + """ + if not isinstance(node_types, tuple): + if isinstance(node_types, list): + node_types = tuple(node_types) + elif issubclass(node_types, (Stmt, Expr)): + node_types = (node_types,) + else: + raise ValueError() + if isinstance(node, list): + node = tuple(node) + + collector = SubStmtExprCollector(node_types, stop_when_found) + collected = collector.collect(node) + return collected + + +def clone(node: Union[Stmt, Expr]) -> Union[Stmt, Expr]: + return CloneRewriter()(node) + + +def collect_free_vars(node: Union[Expr, Stmt]): + collector = FreeVarCollector() + return collector.collect(node) diff --git a/python/hidet/ir/layout/__init__.py b/python/hidet/ir/layout/__init__.py new file mode 100644 index 0000000..e7f6d1d --- /dev/null +++ b/python/hidet/ir/layout/__init__.py @@ -0,0 +1,7 @@ +from . import generic + +from .task_layout import TaskLayout, TaskLayoutExpander +from .generic import get_task_layouts, register_task_layout, register_task_layout_generator, TaskLayoutGenerator +from .task_layout import row_major_layout, col_major_layout, full_layout, row_map, col_map, repeat_map, grid_map + +from .data_layout import DataLayout, StridesLayout, RowMajorLayout, ColumnMajorLayout, row_layout, col_layout, local_layout, data_layout diff --git a/python/hidet/ir/layout/data_layout.py b/python/hidet/ir/layout/data_layout.py new file mode 100644 index 0000000..1d56dff --- /dev/null +++ b/python/hidet/ir/layout/data_layout.py @@ -0,0 +1,387 @@ +from collections import OrderedDict +from typing import Sequence, Union, List, Callable, Mapping, Dict, Tuple, Optional + +from hidet import ir +from hidet.ir.node import Node +from hidet.utils import prod + +# typing forward declaration +Expr = 'Expr' +Int = Union['Expr', int] +Bool = Union['Expr', bool] + + +def is_atom(expr: Expr): + from hidet.ir import Constant, Var + return isinstance(expr, (Constant, Var)) + + +def variablize(expr_list: Sequence[Expr], var2value: Dict['Var', Expr]) -> List['Var']: + from hidet.ir import var + out = [] + for expr in expr_list: + if is_atom(expr): + out.append(expr) + else: + v = var('v') + var2value[v] = expr + out.append(v) + return out + + +def concat_let_expr(var2value, body: Expr): + from hidet.ir import Let + for var, value in reversed(var2value.items()): + body = Let(var, value, body) + return body + + +def to_data_layout(obj): + if isinstance(obj, (tuple, list)): + assert all(isinstance(v, int) for v in obj) + return DataLayout.row_major(obj) + elif isinstance(obj, DataLayout): + return obj + else: + raise ValueError('Can not convert {} to a DataLayout, expect a list or tuple of ints'.format(obj)) + + +# data layout +class DataLayout(Node): + def __init__(self, shape=None, size=None): + self.shape: Tuple[Int] = tuple([int(v) if isinstance(v, ir.Constant) else v for v in shape]) if shape is not None else None + self.size: Int = size + + def __call__(self, *args: Int): + return self.serialize(*args) + + def __add__(self, other): + return DataLayout.concat(lhs=self, rhs=other) + + def __radd__(self, other): + return DataLayout.concat(lhs=other, rhs=self) + + def __mul__(self, other): + return DataLayout.product(outer=self, inner=other) + + def const_shape(self) -> List[int]: + return [int(v) for v in self.shape] + + def global2local(self, *args: Int) -> Int: + raise NotImplementedError() + + def global2cond(self, *args: Int) -> Bool: + raise NotImplementedError() + + def serialize(self, *args: Int): + if len(args) == 1 and isinstance(args[0], (tuple, list)): + # support usage such as within_bound([1, 2, 3]) + args = args[0] + assert len(args) == len(self.shape) + # var2value = OrderedDict() + # arg_vars = variablize(args, var2value) + # scalar_index = self.global2local(*arg_vars) + # scalar_index = concat_let_expr(var2value=var2value, body=scalar_index) + scalar_index = self.global2local(*args) + return scalar_index + + def within_bound(self, *args: Int): + if isinstance(args[0], (tuple, list)) and len(args) == 1: + # support usage such as within_bound([1, 2, 3]) + args = args[0] + assert len(args) == len(self.shape) + var2value = OrderedDict() + arg_vars = variablize(args, var2value) + cond = self.global2cond(*arg_vars) + cond = concat_let_expr(var2value=var2value, body=cond) + return cond + + def tile(self, inner_shape: Sequence[Int]): + return TiledDataLayout(base=self, inner_shape=inner_shape) + + def split(self, dim2factor: Mapping[int, Int]): + return SplitDataLayout(base=self, dim2factor=dim2factor) + + def reorder(self, order: Sequence[int]): + return self.fuse(order) + + def fuse(self, dim2fuse: Sequence[Union[Sequence[int], int]]): + return FusedDataLayout(base=self, dim2fuse=dim2fuse) + + def slice_out(self, dims: Sequence[int]): + return SliceOutDataLayout(base=self, dims=dims) + + @staticmethod + def product(outer, inner): + return ProductDataLayout(outer, inner) + + @staticmethod + def concat(lhs, rhs): + lhs = to_data_layout(lhs) + rhs = to_data_layout(rhs) + return ConcatDataLayout(lhs, rhs) + + @staticmethod + def local(shape: Sequence[Int]): + return LocalLayout(shape=shape) + + @staticmethod + def row_major(shape: Sequence[Int]): + return RowMajorLayout(shape) + + @staticmethod + def column_major(shape: Sequence[Int]): + return ColumnMajorLayout(shape) + + +class StridesLayout(DataLayout): + def __init__(self, shape, strides): + super().__init__(shape=shape, + size=StridesLayout.storage_size(shape, strides)) + self.strides: List[Int] = strides + + def global2local(self, *args: Int) -> Int: + return sum(v * self.strides[i] for i, v in enumerate(args)) + + def global2cond(self, *args: Int) -> Bool: + from hidet.ir.expr import And + return And.join_list([v < s for s, v in zip(self.shape, args)]) + + @staticmethod + def storage_size(shape, strides) -> Expr: + # assume the strides are positive, but do not assume the tensor is contiguous. + max_index = sum([(a - 1) * b for a, b in zip(shape, strides)]) + 1 + return ir.functors.simplify(max_index) + + @staticmethod + def from_shape(shape: Sequence[Int], perm: Sequence[int]): + return StridesLayout(shape, StridesLayout.shape2strides(shape, perm)) + + @staticmethod + def shape2strides(shape: Sequence[Int], perm: Sequence[int]): + assert len(shape) == len(perm) + rank = len(shape) + tuples = [[i, p, None] for i, p in zip(range(rank), perm)] + tuples = sorted(tuples, key=lambda t: t[1]) + reordered_shape = [shape[t[0]] for t in tuples] + for i in range(rank): + tuples[i][2] = prod(reordered_shape[i + 1:]) + tuples = sorted(tuples, key=lambda t: t[0]) + strides = [t[2] for t in tuples] + return strides + + +class RowMajorLayout(StridesLayout): + def __init__(self, shape): + super().__init__(shape, StridesLayout.shape2strides(shape, list(range(len(shape))))) + + +class ColumnMajorLayout(StridesLayout): + def __init__(self, shape): + super().__init__(shape, StridesLayout.shape2strides(shape, list(reversed(range(len(shape)))))) + + +class LocalLayout(DataLayout): + def __init__(self, shape): + super().__init__(shape=shape, size=1) + + def global2local(self, *args: Int) -> Int: + return 0 + + def global2cond(self, *args: Int) -> Bool: + from hidet.ir.expr import And + return And.join_list([v < s for s, v in zip(self.shape, args)]) + + +class TiledDataLayout(DataLayout): + def __init__(self, base: DataLayout, inner_shape: Sequence[Int]): + assert len(inner_shape) == len(base.shape) + assert all(b % a == 0 for a, b in zip(inner_shape, base.shape) if isinstance(a, int) and isinstance(b, int)) + self.base = base + self.inner_shape = inner_shape + super().__init__(shape=[b // a for a, b in zip(inner_shape, self.shape)] + list(inner_shape), size=base.size) + + def base_args(self, *args): + outer_args, inner_args = args[:len(args) // 2], args[len(args) // 2:] + return [o * factor + i for factor, o, i in zip(self.inner_shape, outer_args, inner_args)] + + def global2local(self, *args): + return self.base(*self.base_args(args)) + + def global2cond(self, *args): + return self.base.within_bound(*self.base_args(args)) + + +class SplitDataLayout(DataLayout): + """ + 3-dimension tensor with shape [a, b, c] + after split(dim2factor={0: 2, 1: 3}) got + 5-dimension tensor with shape [(a + 1) // 2, 2, (b + 2) // 3, 3, c] + """ + + def __init__(self, base: DataLayout, dim2factor: Mapping[int, Int]): + self.base = base + self.dim2factor = dim2factor + shape = [] + for i, s in enumerate(base.shape): + if i in dim2factor: + factor = dim2factor[i] + outer = (s + factor - 1) // factor + shape.extend([outer, factor]) + else: + shape.append(s) + super().__init__(shape=shape, size=base.size) + + def base_args(self, *args): + merged_args = [] + c = 0 + for i in range(len(self.base.shape)): + if i in self.dim2factor: + outer_idx = args[c] + inner_idx = args[c + 1] + merged_args.append(outer_idx * self.dim2factor[i] + inner_idx) + c += 2 + else: + merged_args.append(args[c]) + c += 1 + return merged_args + + def global2local(self, *args): + return self.base(*self.base_args(*args)) + + def global2cond(self, *args: Int) -> Bool: + return self.base.within_bound(*self.base_args(*args)) + + +class FusedDataLayout(DataLayout): + """ + 3-dimension tensor with shape [a, b, c] + after fuse([2, [1, 0]]) got + 3-dimension tensor with shape [c, b * a] + (i, j, k) of the result data layout will be mapped to (k, j * I + i) of the original data layout + """ + + def __init__(self, base: DataLayout, dim2fuse: Sequence[Union[Sequence[int], int]]): + self.base = base + self.dim2fuse = dim2fuse + covered = [] + shape = [] + self.dims = [] + for i in range(len(dim2fuse)): + item = dim2fuse[i] + if isinstance(item, int): + item = [item] + else: + item = list(item) + self.dims.append(item) + covered.extend(item) + shape.append(prod([base.shape[i] for i in item])) + assert len(covered) == len(base.shape) and len(set(covered)) == len(covered), "missing some dimension or duplicated dimension" + super().__init__(shape=shape, size=base.size) + + def base_args(self, *args: Int): + original_args = [None] * len(self.base.shape) + for i in range(len(self.dims)): + dim_sizes = [self.base.shape[v] for v in self.dims[i]] + for j, dim in enumerate(self.dims[i]): + original_args[dim] = args[i] // prod(dim_sizes[j + 1:]) % dim_sizes[j] + return original_args + + def global2local(self, *args: Int) -> Int: + return self.base(*self.base_args(*args)) + + def global2cond(self, *args: Int) -> Bool: + return self.base.within_bound(*self.base_args(*args)) + + +class SliceOutDataLayout(DataLayout): + """ + 3-dimension tensor with shape [a, b, c] + after cut({0, 2}) got + 1-dimension tensor with shape [b] + """ + + def __init__(self, base: DataLayout, dims: Sequence[int]): + assert all(d < len(base.shape) for d in dims) + self.base = base + self.dims = set(dims) + super().__init__(shape=[s for r, s in enumerate(base.shape) if r not in dims], + size=base.size) # todo: update size + + def base_args(self, *args: Int): + merged_args = [] + c = 0 + for i in range(len(self.base.shape)): + if i in self.dims: + merged_args.append(0) + else: + merged_args.append(args[c]) + c += 1 + return merged_args + + def global2local(self, *args: Int) -> Int: + return self.base(*self.base_args(*args)) + + def global2cond(self, *args: Int) -> Bool: + return self.base.within_bound(*self.base_args(*args)) + + +class ProductDataLayout(DataLayout): + def __init__(self, outer: DataLayout, inner: DataLayout): + assert len(outer.shape) == len(inner.shape) + super().__init__( + shape=[a * b for a, b in zip(outer.shape, inner.shape)], + size=outer.size * inner.size + ) + self.outer = outer + self.inner = inner + + def global2local(self, *args: Int) -> Int: + outer_args = [v // b for v, b in zip(args, self.inner.shape)] + inner_args = [v % b for v, b in zip(args, self.inner.shape)] + return self.outer(*outer_args) * self.inner.size + self.inner(*inner_args) + + def global2cond(self, *args: Int) -> Bool: + from hidet.ir.expr import And + outer_args = [v // b for v, b in zip(args, self.inner.shape)] + inner_args = [v % b for v, b in zip(args, self.inner.shape)] + return And(self.outer.within_bound(*outer_args), self.inner.within_bound(*inner_args)) + + +class ConcatDataLayout(DataLayout): + def __init__(self, lhs: DataLayout, rhs: DataLayout): + super().__init__( + shape=list(lhs.shape) + list(rhs.shape), + size=lhs.size * rhs.size) + self.lhs = lhs + self.rhs = rhs + + def global2local(self, *args: Int) -> Int: + lhs_args = args[:len(self.lhs.shape)] + rhs_args = args[len(self.lhs.shape):] + return self.lhs(*lhs_args) * self.rhs.size + self.rhs(*rhs_args) + + def global2cond(self, *args: Int) -> Bool: + from hidet.ir.expr import And + lhs_args = args[:len(self.lhs.shape)] + rhs_args = args[len(self.lhs.shape):] + return And(self.lhs.within_bound(*lhs_args), self.rhs.within_bound(*rhs_args)) + + +def row_layout(*shape: int): + return DataLayout.row_major(shape) + + +def col_layout(*shape: int): + return DataLayout.column_major(shape) + + +def local_layout(*shape: int): + return DataLayout.local(shape) + + +def data_layout(shape: List[int], perm: Optional[List[int]] = None): + if perm is None: + perm = list(range(len(shape))) + return StridesLayout.from_shape(shape, perm) + diff --git a/python/hidet/ir/layout/generic.py b/python/hidet/ir/layout/generic.py new file mode 100644 index 0000000..1aaa044 --- /dev/null +++ b/python/hidet/ir/layout/generic.py @@ -0,0 +1,192 @@ +from typing import Tuple, List, Optional, Iterable, Union, Sequence, Iterator +from functools import reduce, partial +from itertools import product +import operator +from sympy.ntheory import divisors + +from .task_layout import TaskLayout, Int + + +class TaskLayoutGenerator: + registered = [] + + def get_layouts(self, + num_workers: Optional[int] = None, + task_shape: Optional[Tuple[int, ...]] = None, + rank: Optional[int] = None) -> Iterable[TaskLayout]: + raise NotImplementedError() + + +def register_task_layout(layout: TaskLayout): + TaskLayout.registered.append(layout) + + +def register_task_layout_generator(layout_generator: TaskLayoutGenerator): + TaskLayoutGenerator.registered.append(layout_generator) + + +def get_task_layouts(valid_num_workers: Optional[Union[int, Sequence[int]]] = None, + task_shape: Optional[Sequence[int]] = None, + rank: Optional[int] = None) -> Iterator[TaskLayout]: + if isinstance(valid_num_workers, int): + valid_num_workers = [valid_num_workers] + assert all(isinstance(v, int) for v in valid_num_workers) + if task_shape is not None: + assert all(isinstance(v, int) for v in task_shape) + for idx, layout in enumerate(TaskLayout.registered): + if valid_num_workers is not None and layout.num_workers not in valid_num_workers: + continue + if task_shape is not None: + if tuple(task_shape) != tuple(layout.task_shape): + continue + if rank is not None and len(task_shape) != rank: + continue + yield layout + for layout_generator in TaskLayoutGenerator.registered: + for num_workers in valid_num_workers: + layouts = layout_generator.get_layouts(num_workers, task_shape, rank) + for layout in layouts: + yield layout + + +def decompose_integer(n, num_items): + """ + decompose n into num_items items such that the product of these items equals to n. Order sensitive. + return a list of valid decomposition. E.g., + decompose_integer(n=12, num_items=2) => [(1, 12), (2, 6), (3, 4), (4, 3), (6, 2), (12, 1)] + """ + results = [] + current_result = [None] * num_items + + def helper(remaining, ith): + if ith + 1 == num_items: + current_result[ith] = remaining + results.append(tuple(current_result)) + return + for v in divisors(remaining): + current_result[ith] = v + helper(remaining // v, ith + 1) + + helper(n, 0) + return results + + +class FullLayout(TaskLayoutGenerator): + @classmethod + def get_layouts(cls, + num_workers: Optional[int] = None, + task_shape: Optional[Tuple[int, ...]] = None, + rank: Optional[int] = None) -> Iterable[TaskLayout]: + if num_workers is not None and num_workers != 1: + return + if task_shape is None: + return + yield TaskLayout(1, task_shape, worker2task=partial(cls.worker2task, task_shape=task_shape)) + + @staticmethod + def worker2task(worker_index: Int, task_shape: Tuple[int, ...]) -> List[Tuple[Int, ...]]: + ranges = [range(s) for s in task_shape] + result = list(product(*ranges)) + return result + + @staticmethod + def task2worker(task_index: Tuple[Int, ...], task_shape: Tuple[int, ...]) -> Int: + return 0 + + +class RowMajorLayout(TaskLayoutGenerator): + @classmethod + def get_layouts(cls, + num_workers: Optional[int] = None, + task_shape: Optional[Tuple[int, ...]] = None, + rank: Optional[int] = None) -> Iterable[TaskLayout]: + if num_workers is None and task_shape is None: + return + elif num_workers is None: + num_workers = reduce(operator.mul, task_shape) + task_shapes = [task_shape] + elif task_shape is None: + assert rank is not None + task_shapes = decompose_integer(num_workers, rank) + else: + assert num_workers == reduce(operator.mul, task_shape) + task_shapes = [task_shape] + + for task_shape in task_shapes: + yield TaskLayout(num_workers, task_shape, partial(cls.worker2task, task_shape=task_shape)) + + @staticmethod + def worker2task(worker_index: Int, task_shape: Tuple[Int, ...]) -> List[Tuple[Int, ...]]: + task_index = [] + rank = len(task_shape) + bases = [reduce(operator.mul, task_shape[i + 1:], 1) for i in range(rank)] + for i in range(rank): + task_index.append(worker_index // bases[i] % task_shape[i]) + return [tuple(task_index)] + + @staticmethod + def task2worker(task_index: Tuple[Int, ...], task_shape: Tuple[Int, ...]) -> Int: + worker_index = 0 + rank = len(task_shape) + bases = [reduce(operator.mul, task_shape[i + 1:], 1) for i in range(rank)] + for i in range(rank): + worker_index += task_index[i] * bases[i] + return worker_index + + +class ColumnMajorLayout(TaskLayoutGenerator): + @classmethod + def get_layouts(cls, + num_workers: Optional[int] = None, + task_shape: Optional[Tuple[int, ...]] = None, + rank: Optional[int] = None) -> Iterable[TaskLayout]: + if num_workers is None and task_shape is None: + return + elif num_workers is None: + num_workers = reduce(operator.mul, task_shape) + task_shapes = [task_shape] + elif task_shape is None: + task_shapes = decompose_integer(num_workers, rank) + else: + assert num_workers == reduce(operator.mul, task_shape) + task_shapes = [task_shape] + + for task_shape in task_shapes: + yield TaskLayout(num_workers, task_shape, + partial(cls.worker2task, task_shape=task_shape)) + + @staticmethod + def worker2task(worker_index: Int, task_shape) -> List[Tuple[Int, ...]]: + task_index = [] + rank = len(task_shape) + bases = [reduce(operator.mul, task_shape[:i], 1) for i in range(rank)] + for i in range(rank): + task_index.append(worker_index // bases[i] % task_shape[i]) + return [tuple(task_index)] + + @staticmethod + def task2worker(task_index: Tuple[Int, ...], task_shape) -> Int: + worker_index = 0 + rank = len(task_shape) + bases = [reduce(operator.mul, task_shape[:i], 1) for i in range(rank)] + for i in range(rank): + worker_index += task_index[i] * bases[i] + return worker_index + + +def full_layout(*task_shape) -> TaskLayout: + layouts = list(FullLayout.get_layouts(task_shape=task_shape)) + assert len(layouts) == 1 + return layouts[0] + + +def row_major_layout(*task_shape) -> TaskLayout: + layouts = list(RowMajorLayout.get_layouts(task_shape=task_shape)) + assert len(layouts) == 1 + return layouts[0] + + +def col_major_layout(*task_shape) -> TaskLayout: + layouts = list(ColumnMajorLayout.get_layouts(task_shape=task_shape)) + assert len(layouts) == 1 + return layouts[0] diff --git a/python/hidet/ir/layout/task_layout.py b/python/hidet/ir/layout/task_layout.py new file mode 100644 index 0000000..10c8597 --- /dev/null +++ b/python/hidet/ir/layout/task_layout.py @@ -0,0 +1,235 @@ +from typing import Union, Tuple, List, Type, Dict, Optional, Sequence, Iterator, Callable, Iterable, Mapping +import numpy as np +import itertools +from hidet.utils import prod + +Int = Union['Expr', int] + + +def is_atom(expr): + from hidet.ir import Constant, Var + return isinstance(expr, (Constant, Var)) + + +def var(hint): + from hidet import ir + return ir.var(hint) + + +class TaskLayout: + registered = [] + + def __init__(self, + num_workers: int = None, + task_shape: Tuple[int, ...] = None, + worker2task: Optional[Callable[[Int], List[Tuple[Int, ...]]]] = None): + self.num_workers: int = num_workers + self.task_shape: Tuple[int, ...] = task_shape + self.worker2task: Callable[[Int], List[Tuple[Int]]] = worker2task + if num_workers is not None: + assert isinstance(num_workers, int) + if task_shape is not None: + assert all(isinstance(s, int) for s in task_shape) + + def __call__(self, w: Int) -> List[Tuple[Int, ...]]: + return self.worker2task(w) + + def __mul__(self, other) -> 'TaskLayout': + return ComposedTaskLayout(outer=self, inner=other) + + def __str__(self): + worker_id = np.empty(shape=self.task_shape, dtype=np.int32) + for w in range(self.num_workers): + for task_indices in self.worker2task(w): + worker_id[task_indices] = w + return np.array2string(worker_id) + + @staticmethod + def row_major(task_shape: Sequence[int]): + return GridTaskLayout(task_shape, perm=list(range(len(task_shape)))) + + @staticmethod + def column_major(task_shape: Sequence[int]): + return GridTaskLayout(task_shape, perm=list(reversed(range(len(task_shape))))) + + @staticmethod + def full_layout(task_shape: Sequence[int]): + return FullTaskLayout(task_shape) + + def projection(self, dim2value: Mapping[int, Int]) -> 'TaskLayout': + return ProjectedTaskLayout(base=self, dim2value=dim2value) + + +class FullTaskLayout(TaskLayout): + def __init__(self, task_shape: Sequence[int]): + super().__init__(num_workers=1, task_shape=tuple(task_shape), worker2task=self._worker2task) + + # noinspection PyUnusedLocal + def _worker2task(self, w): + ranges = [range(s) for s in self.task_shape] + return list(itertools.product(*ranges)) + + +class GridTaskLayout(TaskLayout): + def __init__(self, task_shape: Sequence[int], perm: Sequence[int]): + assert len(task_shape) == len(perm) + super().__init__(num_workers=prod(task_shape), task_shape=tuple(task_shape), worker2task=self._worker2task) + self.perm = list(perm) + self.bases = self._get_bases() + + def _get_bases(self): + rank = len(self.perm) + bases: List[Optional[int]] = [None] * rank + s = 1 + for i in reversed(range(rank)): + j = self.perm.index(i) + bases[j] = s + s *= self.task_shape[j] + return bases + + def _worker2task(self, w: Int) -> List[Tuple[Int]]: + task = [] + for mod, b in zip(self.task_shape, self.bases): + task.append((w // b) % mod) + return [tuple(task)] + + +class ProjectedTaskLayout(TaskLayout): + def __init__(self, base: TaskLayout, dim2value: Mapping[int, Int]): + assert all(int(v) == 0 for v in dim2value.values()) + super().__init__(num_workers=base.num_workers, + task_shape=tuple(base.task_shape[i] if i not in dim2value else 1 for i in range(len(base.task_shape))), + worker2task=self._worker2task) + self.base = base + self.dim2value = dim2value + + def _worker2task(self, w: Int) -> List[Tuple[Int]]: + rank = len(self.task_shape) + projected_tasks = [] + for task in self.base(w): + projected_tasks.append(tuple(self.dim2value[i] if i in self.dim2value else task[i] for i in range(rank))) + return projected_tasks + + +class ComposedTaskLayout(TaskLayout): + def __init__(self, outer: TaskLayout, inner: TaskLayout): + assert len(outer.task_shape) == len(inner.task_shape) + super().__init__( + num_workers=outer.num_workers * inner.num_workers, + task_shape=tuple([a * b for a, b in zip(outer.task_shape, inner.task_shape)]), + worker2task=self._worker2task + ) + self.outer = outer + self.inner = inner + + def _worker2task(self, worker_index: Int) -> List[Tuple[Int]]: + outer_worker_index = worker_index // self.inner.num_workers + inner_worker_index = worker_index % self.inner.num_workers + outer_tasks = self.outer.worker2task(outer_worker_index) + inner_tasks = self.inner.worker2task(inner_worker_index) + tasks = [] + for outer_task in outer_tasks: + for inner_task in inner_tasks: + tasks.append(tuple(a * self.inner.task_shape[i] + b for i, (a, b) in enumerate(zip(outer_task, inner_task)))) + return tasks + + +class TaskLayoutExpander: + def __init__(self): + from hidet.ir.stmt import ForStmt, LetStmt + self.stmts: List[Union[LetStmt, ForStmt]] = [] + + def variablize(self, e): + from hidet.ir import LetStmt + if is_atom(e): + return e + else: + v = var('p') + self.stmts.append(LetStmt(v, e)) + return v + + def expand(self, w: Int, task_layout: TaskLayout) -> List[Sequence[Int]]: + vtable = { + FullTaskLayout: self.expand_full, + GridTaskLayout: self.expand_grid, + ComposedTaskLayout: self.expand_composed, + ProjectedTaskLayout: self.expand_projected, + TaskLayout: self.expand_atom, + } + w = self.variablize(w) + # noinspection PyArgumentList + return vtable[task_layout.__class__](w, task_layout) + + def expand_composed(self, w: Int, layout: ComposedTaskLayout): + outer_w = self.variablize(w // layout.inner.num_workers) + inner_w = self.variablize(w % layout.inner.num_workers) + outer_fields = self.expand(outer_w, layout.outer) + inner_fields = self.expand(inner_w, layout.inner) + fields = [] + for outer_field in outer_fields: + scaled_outer_field = [self.variablize(a * b) for a, b in zip(outer_field, layout.inner.task_shape)] + for inner_field in inner_fields: + fields.append(tuple(a + b for a, b in zip(scaled_outer_field, inner_field))) + return fields + + def expand_projected(self, w: Int, layout: ProjectedTaskLayout): + rank = len(layout.task_shape) + base_fields = self.expand(w, layout.base) + projected_fields = [] + for field in base_fields: + projected_fields.append(tuple(layout.dim2value[i] if i in layout.dim2value else field[i] for i in range(rank))) + return projected_fields + + def expand_grid(self, w: Int, layout: GridTaskLayout): + return [[self.variablize(v) for v in layout(w)[0]]] + + def expand_full(self, w: Int, layout: FullTaskLayout): + unroll_limit = 1024 + if prod(layout.task_shape) < unroll_limit: + # unroll automatically + return layout(w) + else: + # do not expand, use for loop + from hidet.ir import var, ForStmt + shape = layout.task_shape + axes = [] + for i, s in enumerate(shape): + axis = var(chr(ord('i') + i)) + self.stmts.append(ForStmt(loop_var=axis, extent=s)) + axes.append(axis) + return [axes] + + @staticmethod + def expand_atom(w: Int, layout: TaskLayout): + return layout(w) + + +def row_major_layout(*task_shape: int): + return GridTaskLayout(task_shape, perm=list(range(len(task_shape)))) + + +def col_major_layout(*task_shape: int): + return GridTaskLayout(task_shape, perm=list(reversed(range(len(task_shape))))) + + +def full_layout(*task_shape: int): + return FullTaskLayout(task_shape) + + +def row_map(*task_shape: int): + return row_major_layout(*task_shape) + + +def col_map(*task_shape: int): + return col_major_layout(*task_shape) + + +def repeat_map(*task_shape: int): + return full_layout(*task_shape) + + +def grid_map(task_shape: List[int], order: Optional[List[int]] = None): + if order is None: + order = list(range(len(task_shape))) + assert len(order) == len(task_shape) + return GridTaskLayout(task_shape, order) diff --git a/python/hidet/ir/node.py b/python/hidet/ir/node.py new file mode 100644 index 0000000..75b1266 --- /dev/null +++ b/python/hidet/ir/node.py @@ -0,0 +1,33 @@ +from typing import Mapping, Type, Any, List + + +class Node: + _dispatch_index = {None: 0} + + def __str__(self): + from hidet.ir.functors.printer import astext + return astext(self) + + def __repr__(self): + return str(self) + + def __int__(self): + return None + + @classmethod + def class_index(cls): + if not hasattr(cls, '_class_index'): + setattr(cls, '_class_index', len(Node._dispatch_index)) + Node._dispatch_index[cls] = getattr(cls, '_class_index') + return getattr(cls, '_class_index') + + @staticmethod + def dispatch_table(mapping: Mapping[Type['Node'], Any]) -> List[Any]: + table = [] + for cls, target in mapping.items(): + idx = cls.class_index() + while idx >= len(table): + table.append(None) + table[idx] = target + return table + diff --git a/python/hidet/ir/primitives/__init__.py b/python/hidet/ir/primitives/__init__.py new file mode 100644 index 0000000..c3dcccc --- /dev/null +++ b/python/hidet/ir/primitives/__init__.py @@ -0,0 +1,8 @@ +from .func import register_primitive_function, is_primitive_function, lookup_primitive_function + +# base primitive functions +from .base import max, min, exp, pow, sqrt, rsqrt, erf, sin, cos, tanh, round, floor, ceil, printf + +# cuda primitive functions and variables +from .cuda import thread_idx, block_idx +from .cuda import syncthreads, syncwarp, lds128, sts128, shfl_sync, shfl_up_sync, shfl_down_sync, shfl_xor_sync, active_mask, set_kernel_max_dynamic_smem_bytes diff --git a/python/hidet/ir/primitives/base.py b/python/hidet/ir/primitives/base.py new file mode 100644 index 0000000..2a7259e --- /dev/null +++ b/python/hidet/ir/primitives/base.py @@ -0,0 +1,4 @@ +# def is_reserved_name(name: str) -> bool: +# from . import func, vars +# # noinspection PyProtectedMember +# return name in func._primitive_functions or name in vars._primitive_variables diff --git a/python/hidet/ir/primitives/base/__init__.py b/python/hidet/ir/primitives/base/__init__.py new file mode 100644 index 0000000..74d2353 --- /dev/null +++ b/python/hidet/ir/primitives/base/__init__.py @@ -0,0 +1 @@ +from .funcs import max, min, exp, pow, sqrt, rsqrt, erf, sin, cos, tanh, round, floor, ceil, printf diff --git a/python/hidet/ir/primitives/base/funcs.py b/python/hidet/ir/primitives/base/funcs.py new file mode 100644 index 0000000..d7e3d6b --- /dev/null +++ b/python/hidet/ir/primitives/base/funcs.py @@ -0,0 +1,161 @@ +from typing import List, Optional, Union +import builtins +import math +from hidet.ir.type import ScalarType +from hidet.ir.expr import Expr, Call, Var, cast +from hidet.ir.stmt import BlackBoxStmt, AsmStmt, ReturnStmt +from hidet.ir.builders import FunctionBuilder, StmtBuilder +from ..func import FuncType, register_primitive_function, is_primitive_function, lookup_primitive_function, registered_primitive_functions +from ..func import primitive_func_pool as pool +from hidet.utils import initialize + +ExprLike = Union[Expr, int, float] + + +def call_base(name: str, args: List[ExprLike]) -> Call: + entry = pool.lookup_by_name(target='base', name=name) + if entry.func_type.type_infer_func is None: + param_types = entry.func_type.param_types + if len(param_types) != len(args): + raise ValueError('Function {} expect {} arguments, got {}.'.format(name, len(param_types), len(args))) + return Call(entry.var, args) + + +def max(a: ExprLike, b: ExprLike) -> Expr: + return call_base('max', [a, b]) + + +def min(a: ExprLike, b: ExprLike) -> Expr: + return call_base('min', [a, b]) + + +def exp(a: ExprLike) -> Expr: + return call_base('exp', [a]) + + +def pow(a: ExprLike, b: ExprLike) -> Expr: + return call_base('pow', [a, b]) + + +def sqrt(a: ExprLike) -> Expr: + return call_base('sqrt', [a]) + + +def rsqrt(a: ExprLike) -> Expr: + return call_base('rsqrt', [a]) + + +def erf(a: ExprLike) -> Expr: + return call_base('erf', [a]) + + +def sin(a: ExprLike) -> Expr: + return call_base('sin', [a]) + + +def cos(a: ExprLike) -> Expr: + return call_base('cos', [a]) + + +def tanh(a: ExprLike) -> Expr: + return call_base('tanh', [a]) + + +def round(a: ExprLike) -> Expr: + return call_base('round', [a]) + + +def floor(a: ExprLike) -> Expr: + return call_base('floor', [a]) + + +def ceil(a: ExprLike) -> Expr: + return call_base('ceil', [a]) + + +def printf(format_string, *args): + """ + usage: + printf(r"%d %d\n", expr_1, expr_2) + """ + arg_string = ', '.join(['{}'] * len(args)) + template_string = f'printf("{format_string}", {arg_string});' + return BlackBoxStmt(template_string, *args) + + +def type_infer_func(arg_types: List[ScalarType]) -> ScalarType: + # level = { + # 'float64': 10, + # 'float32': 9, + # 'bfloat16': 8, + # 'float16': 7, + # + # 'int64': 5, + # 'uint64': 4.5, + # 'int32': 4, + # 'uint32': 3.5, + # 'int16': 3, + # 'uint16': 2.5, + # 'int8': 2, + # 'uint8': 1.5 + # } + # return list(sorted(arg_types, key=lambda a: level[a.name]))[-1] + return builtins.max(arg_types) + + +@initialize() +def register_primitive_functions_generic(): + unary_names = [ + 'neg', 'sin', 'cos', 'tanh', 'exp', 'round', 'floor', 'ceil', 'rsqrt', 'sqrt', 'erf' + ] + binary_names = [ + 'min', 'max', 'pow' + ] + ternary_names = [ + 'fma' + ] + for unary in unary_names: + register_primitive_function('base', unary, FuncType(type_infer_func=type_infer_func), generic=True) + for binary in binary_names: + register_primitive_function('base', binary, FuncType(type_infer_func=type_infer_func), generic=True) + for ternary in ternary_names: + register_primitive_function('base', ternary, FuncType(type_infer_func=type_infer_func), generic=True) + + +@initialize() +def register_primitive_functions_float32(): + unary_names = [ + 'sinf', 'cosf', 'tanhf', 'expf', 'roundf', 'floorf', 'ceilf', 'rsqrtf', 'sqrtf', 'erff' + ] + binary_names = [ + 'fminf', 'fmaxf', 'powf' + ] + ternary_names = [ + 'fmaf' + ] + base2float32 = { + 'sin': 'sinf', + 'cos': 'cosf', + 'tanh': 'tanhf', + 'exp': 'expf', + 'round': 'roundf', + 'floor': 'floorf', + 'ceil': 'ceilf', + 'rsqrt': 'rsqrtf', + 'erf': 'erff', + 'sqrt': 'sqrtf', + 'min': 'fminf', + 'max': 'fmaxf', + 'pow': 'powf', + 'fma': 'fmaf' + } + for unary in unary_names: + register_primitive_function('base', unary, FuncType(param_types=['float32'], ret_type='float32')) + for binary in binary_names: + register_primitive_function('base', binary, FuncType(param_types=['float32', 'float32'], ret_type='float32')) + for ternary in ternary_names: + register_primitive_function('base', ternary, FuncType(param_types=['float32', 'float32', 'float32'], ret_type='float32')) + for base_name, fp32_name in base2float32.items(): + pool.lookup_by_name('base', base_name).dispatch_dtype(dtype='float32', space='base', func_name=fp32_name) + + diff --git a/python/hidet/ir/primitives/cuda/__init__.py b/python/hidet/ir/primitives/cuda/__init__.py new file mode 100644 index 0000000..32e483e --- /dev/null +++ b/python/hidet/ir/primitives/cuda/__init__.py @@ -0,0 +1,6 @@ +from . import float16 +from . import bfloat16 + +from .funcs import syncthreads, syncwarp, lds128, sts128, shfl_sync, shfl_up_sync, shfl_down_sync, shfl_xor_sync, active_mask, set_kernel_max_dynamic_smem_bytes +from .vars import thread_idx, block_idx, is_primitive_variable, get_primitive_variable +from .wmma import wmma_load, wmma_mma, wmma_store diff --git a/python/hidet/ir/primitives/cuda/bfloat16.py b/python/hidet/ir/primitives/cuda/bfloat16.py new file mode 100644 index 0000000..f26c47a --- /dev/null +++ b/python/hidet/ir/primitives/cuda/bfloat16.py @@ -0,0 +1,47 @@ +from hidet.utils import initialize +from ..func import FuncType, register_primitive_function, primitive_func_pool +from .funcs import register_unary_dialect_primitive_function, register_binary_dialect_primitive_function +from hidet.ir.primitives.base.funcs import erf, tanh, pow + + +@initialize() +def register_primitive_functions_bfloat16(): + unary_names = [ + '__hneg', 'hsin', 'hcos', 'hexp', 'hrint', 'hfloor', 'hceil', 'hrsqrt', 'hsqrt', + ] + binary_names = [ + '__hmin', '__hmax', + ] + ternary_names = [ + '__hfma', + ] + base2bfloat16 = { + 'neg': '__hneg', + 'sin': 'hsin', + 'cos': 'hcos', + 'exp': 'hexp', + 'round': 'hrint', + 'floor': 'hceil', + 'ceil': 'hceil', + 'rsqrt': 'hrsqrt', + 'sqrt': 'hsqrt', + 'min': '__hmin', + 'max': '__hmax', + + # cuda c does not provide the following functions, we use f16 -> f32 -> f -> f16 path + 'tanh': 'htanh', + 'erf': 'herf', + 'pow': 'hpow' + } + for unary in unary_names: + register_primitive_function('bfloat16', unary, FuncType(param_types=['bfloat16'], ret_type='bfloat16')) + for binary in binary_names: + register_primitive_function('bfloat16', binary, FuncType(param_types=['bfloat16', 'bfloat16'], ret_type='bfloat16')) + for ternary in ternary_names: + register_primitive_function('bfloat16', ternary, FuncType(param_types=['bfloat16', 'bfloat16', 'bfloat16'], ret_type='bfloat16')) + + register_unary_dialect_primitive_function(space='bfloat16', func_name='htanh', generic_func=tanh, target_dtype='bfloat16', dialect_dtype='float32') + register_unary_dialect_primitive_function(space='bfloat16', func_name='herf', generic_func=erf, target_dtype='bfloat16', dialect_dtype='float32') + register_binary_dialect_primitive_function(space='bfloat16', func_name='hpow', generic_func=pow, target_dtype='bfloat16', dialect_dtype='float32') + for base_name, bf16_name in base2bfloat16.items(): + primitive_func_pool.lookup_by_name('base', base_name).dispatch_dtype(dtype='bfloat16', space='bfloat16', func_name=bf16_name) diff --git a/python/hidet/ir/primitives/cuda/float16.py b/python/hidet/ir/primitives/cuda/float16.py new file mode 100644 index 0000000..79705fd --- /dev/null +++ b/python/hidet/ir/primitives/cuda/float16.py @@ -0,0 +1,47 @@ +from hidet.utils import initialize +from ..func import FuncType, register_primitive_function, primitive_func_pool +from .funcs import register_unary_dialect_primitive_function, register_binary_dialect_primitive_function +from hidet.ir.primitives.base.funcs import erf, tanh, pow + + +@initialize() +def register_primitive_functions_float16(): + unary_names = [ + '__hneg', 'hsin', 'hcos', 'hexp', 'hrint', 'hfloor', 'hceil', 'hrsqrt', 'hsqrt', + ] + binary_names = [ + '__hmin', '__hmax', + ] + ternary_names = [ + '__hfma', + ] + base2float16 = { + 'neg': '__hneg', + 'sin': 'hsin', + 'cos': 'hcos', + 'exp': 'hexp', + 'round': 'hrint', + 'floor': 'hceil', + 'ceil': 'hceil', + 'rsqrt': 'hrsqrt', + 'sqrt': 'hsqrt', + 'min': '__hmin', + 'max': '__hmax', + + # cuda c does not provide the following functions, we use f16 -> f32 -> f -> f16 path + 'tanh': 'htanh', + 'erf': 'herf', + 'pow': 'hpow' + } + for unary in unary_names: + register_primitive_function('float16', unary, FuncType(param_types=['float16'], ret_type='float16')) + for binary in binary_names: + register_primitive_function('float16', binary, FuncType(param_types=['float16', 'float16'], ret_type='float16')) + for ternary in ternary_names: + register_primitive_function('float16', ternary, FuncType(param_types=['float16', 'float16', 'float16'], ret_type='float16')) + + register_unary_dialect_primitive_function(space='float16', func_name='htanh', generic_func=tanh, target_dtype='float16', dialect_dtype='float32') + register_unary_dialect_primitive_function(space='float16', func_name='herf', generic_func=erf, target_dtype='float16', dialect_dtype='float32') + register_binary_dialect_primitive_function(space='float16', func_name='hpow', generic_func=pow, target_dtype='float16', dialect_dtype='float32') + for base_name, fp16_name in base2float16.items(): + primitive_func_pool.lookup_by_name('base', base_name).dispatch_dtype(dtype='float16', space='float16', func_name=fp16_name) diff --git a/python/hidet/ir/primitives/cuda/funcs.py b/python/hidet/ir/primitives/cuda/funcs.py new file mode 100644 index 0000000..fb59e2e --- /dev/null +++ b/python/hidet/ir/primitives/cuda/funcs.py @@ -0,0 +1,141 @@ +from hidet.ir.type import ScalarType +from typing import List, Optional, Union, Tuple + +from hidet.ir.builders import FunctionBuilder, StmtBuilder +from hidet.ir.dialects.lowlevel import PointerType, ReferenceType +from hidet.ir.dialects.lowlevel import VoidType +from hidet.ir.expr import Expr, Call, cast +from hidet.ir.expr import Var +from hidet.ir.stmt import AsmStmt, BlackBoxStmt, ReturnStmt +from hidet.ir.type import ScalarType +from hidet.ir.primitives.func import FuncType, register_primitive_function, primitive_func_pool +from hidet.utils import initialize + + +def register_unary_dialect_primitive_function(space, func_name, generic_func, target_dtype: str, dialect_dtype: str): + with FunctionBuilder(func_name, kind='cuda_device', ret_type=ScalarType(target_dtype)) as fb: + # params + x = Var('x', type=ScalarType(target_dtype)) + fb.extend_params([x]) + # body + sb = StmtBuilder() + sb += ReturnStmt(cast(generic_func(cast(x, dialect_dtype)), target_dtype)) + fb.set_body(sb.finish()) + register_primitive_function(space, func_name, fb.get()) + + +def register_binary_dialect_primitive_function(space, func_name, generic_func, target_dtype: str, dialect_dtype: str): + with FunctionBuilder(func_name, kind='cuda_device', ret_type=ScalarType(target_dtype)) as fb: + # params + x = Var('x', type=ScalarType(target_dtype)) + y = Var('y', type=ScalarType(target_dtype)) + fb.extend_params([x, y]) + # body + sb = StmtBuilder() + sb += ReturnStmt(cast(generic_func(cast(x, dialect_dtype), cast(y, dialect_dtype)), target_dtype)) + fb.set_body(sb.finish()) + register_primitive_function(space, func_name, fb.get()) + + +@initialize() +def register_primitive_functions_with_body(): + # lds128 + with FunctionBuilder('lds128', kind='cuda_device') as fb: + # params + regs_vars = [Var(f'reg{i}', ReferenceType(ScalarType('float32'))) for i in range(4)] + smem_addr_var = Var('smem_addr', PointerType(ScalarType('float32'))) + fb.extend_params(regs_vars + [smem_addr_var]) + # body + body = AsmStmt( + r"{" + r" .reg.u64 u64addr;" + r" cvta.to.shared.u64 u64addr, %4;" + r" ld.shared.v4.f32 {%0, %1, %2, %3}, [u64addr];" + r"}", + outputs=[('=f', reg) for reg in regs_vars], + inputs=[('l', smem_addr_var)], + is_volatile=True + ) + fb.set_body(body) + register_primitive_function('cuda', 'lds128', fb.get()) + + # sts128 + with FunctionBuilder('sts128', kind='cuda_device') as fb: + # params + regs_vars = [Var(f'reg{i}', ReferenceType(ScalarType('float32'))) for i in range(4)] + smem_addr_var = Var('smem_addr', PointerType(ScalarType('float32'))) + fb.extend_params(regs_vars + [smem_addr_var]) + # body + body = AsmStmt( + r"{" + r" .reg.u64 u64addr;" + r" cvta.to.shared.u64 u64addr, %0;" + r" st.shared.v4.f32 [u64addr], {%1, %2, %3, %4};" + r"}", + outputs=[], + inputs=[('l', smem_addr_var)] + [('f', reg) for reg in regs_vars], + is_volatile=True + ) + fb.set_body(body) + register_primitive_function('cuda', 'sts128', fb.get()) + + +@initialize() +def register_primitive_functions(): + functions = { + '__syncthreads': FuncType([], VoidType()), + '__syncwarp': FuncType([], VoidType()), + '__activemask': FuncType([], 'int32'), + '__shfl_sync': FuncType(type_infer_func=lambda arg_types: arg_types[1]), # T __shfl_sync(unsigned mask, T var, int srcLane, int width=warpSize) + '__shfl_up_sync': FuncType(type_infer_func=lambda arg_types: arg_types[1]), + '__shfl_down_sync': FuncType(type_infer_func=lambda arg_types: arg_types[1]), + } + for name, func_type in functions.items(): + register_primitive_function('cuda', name, func_type) + + +def call_cuda(func_name, args: List[Expr]) -> Call: + entry = primitive_func_pool.lookup_by_name('cuda', func_name) + return Call(entry.var, args) + + +def syncthreads() -> Call: + return call_cuda('__syncthreads', []) + + +def syncwarp() -> Call: + return call_cuda('__syncwarp', []) + + +def lds128(reg0, reg1, reg2, reg3, smem_addr) -> Call: + return call_cuda('lds128', [reg0, reg1, reg2, reg3, smem_addr]) + + +def sts128(reg0, reg1, reg2, reg3, smem_addr) -> Call: + return call_cuda('sts128', [reg0, reg1, reg2, reg3, smem_addr]) + + +def shfl_sync(mask, var, src_lane, width=32): + return call_cuda('__shfl_sync', [mask, var, src_lane, width]) + + +def shfl_up_sync(mask, var, delta, width=32): + return call_cuda('__shfl_up_sync', [mask, var, delta, width]) + + +def shfl_down_sync(mask, var, delta, width=32): + return call_cuda('__shfl_down_sync', [mask, var, delta, width]) + + +def shfl_xor_sync(mask, var, lane_mask, width=32): + return call_cuda('__shfl_down_sync', [mask, var, lane_mask, width]) + + +def active_mask(): + return call_cuda('__activemask', []) + + +def set_kernel_max_dynamic_smem_bytes(func, max_dynamic_smem_bytes): + template_string = r'cudaFuncSetAttribute({}, cudaFuncAttributeMaxDynamicSharedMemorySize, {});' + raise ValueError('update to use func instead of func_var') + return BlackBoxStmt(template_string, func, max_dynamic_smem_bytes) diff --git a/python/hidet/ir/primitives/cuda/vars.py b/python/hidet/ir/primitives/cuda/vars.py new file mode 100644 index 0000000..9d7fee3 --- /dev/null +++ b/python/hidet/ir/primitives/cuda/vars.py @@ -0,0 +1,39 @@ +from typing import Dict, Optional + +from hidet.ir.expr import Var +from hidet.ir.type import ScalarType + +_primitive_variables: Dict[str, Var] = {} + + +def attach_pool(var): + if '_primitive_variables' not in var.__dict__: + var.__dict__['_primitive_variables'] = _primitive_variables + return var + + +def thread_idx(dim='x') -> Var: + assert dim in ['x', 'y', 'z'] + name = 'threadIdx.{}'.format(dim) + if name not in _primitive_variables: + _primitive_variables[name] = attach_pool(Var(hint=name, type=ScalarType('int32'), name=name)) + return _primitive_variables[name] + + +def block_idx(dim='x') -> Var: + assert dim in ['x', 'y', 'z'] + name = 'blockIdx.{}'.format(dim) + if name not in _primitive_variables: + _primitive_variables[name] = attach_pool(Var(hint=name, type=ScalarType('int32'), name=name)) + return _primitive_variables[name] + + +def is_primitive_variable(name: str) -> bool: + return name in _primitive_variables + + +def get_primitive_variable(name: str) -> Optional[Var]: + if name in _primitive_variables: + return _primitive_variables[name] + else: + return None diff --git a/python/hidet/ir/primitives/cuda/wmma.py b/python/hidet/ir/primitives/cuda/wmma.py new file mode 100644 index 0000000..f4d2444 --- /dev/null +++ b/python/hidet/ir/primitives/cuda/wmma.py @@ -0,0 +1,304 @@ +from collections import namedtuple +from typing import List, Optional, Union, Tuple + +from hidet.ir.builders import FunctionBuilder +from hidet.ir.dialects.lowlevel import PointerType +from hidet.ir.expr import Expr, cast +from hidet.ir.expr import Var +from hidet.ir.primitives.cuda.funcs import call_cuda +from hidet.ir.primitives.func import register_primitive_function +from hidet.ir.stmt import AsmStmt, AssignStmt +from hidet.ir.type import ScalarType +from hidet.utils import initialize + +""" +Documentation of wmma and mma instructions in PTX: +https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-instructions +""" + +dtype_short2long = { + 'f16': 'float16', + 'bf16': 'bfloat16', + 'tf32': 'tfloat32', + 'f32': 'float32' +} +dtype_long2short = { + 'float16': 'f16', + 'bfloat16': 'bf16', + 'tfloat32': 'tf32', + 'float32': 'f32' +} + +WmmaConfig = namedtuple('WmmaConfig', ['shape', 'a_dtype', 'b_dtype', 'c_dtype', 'a_layout', 'b_layout', 'c_layout', 'a_regs', 'b_regs', 'c_regs']) +wmma_configs: List[WmmaConfig] = [] + + +@initialize() +def init_wmma_configs(): + # todo: Add integer, float64, and sub-byte type tensor core wmma instructions when needed. + + # f16 x f16 => f16 or f32 + for shape in [(16, 16, 16), (8, 32, 16), (32, 8, 16)]: + a_dtype = 'f16' + b_dtype = 'f16' + for c_dtype in ['f16', 'f32']: + for a_layout in ['row', 'col']: + for b_layout in ['row', 'col']: + for c_layout in ['row', 'col']: + a_regs = 8 + b_regs = 8 + c_regs = 4 if c_dtype == 'f16' else 8 + config = WmmaConfig(shape, a_dtype, b_dtype, c_dtype, a_layout, b_layout, c_layout, a_regs, b_regs, c_regs) + wmma_configs.append(config) + + # bf16 x bf16 => f32 + for shape in [(16, 16, 16), (8, 32, 16), (32, 8, 16)]: + a_dtype = 'bf16' + b_dtype = 'bf16' + c_dtype = 'f32' + for a_layout in ['row', 'col']: + for b_layout in ['row', 'col']: + for c_layout in ['row', 'col']: + m, n, k = shape + regs_map = {16: 4, 8: 2, 32: 8} + a_regs = regs_map[m] + b_regs = regs_map[n] + c_regs = 8 + config = WmmaConfig(shape, a_dtype, b_dtype, c_dtype, a_layout, b_layout, c_layout, a_regs, b_regs, c_regs) + wmma_configs.append(config) + + # tf32 x tf32 => f32 + for shape in [(16, 16, 8)]: + a_dtype = 'tf32' + b_dtype = 'tf32' + c_dtype = 'f32' + for a_layout in ['row', 'col']: + for b_layout in ['row', 'col']: + for c_layout in ['row', 'col']: + a_regs = 4 + b_regs = 4 + c_regs = 8 + config = WmmaConfig(shape, a_dtype, b_dtype, c_dtype, a_layout, b_layout, c_layout, a_regs, b_regs, c_regs) + wmma_configs.append(config) + + +@initialize() +def register_wmma_load_instructions(): + WmmaLoadConfig = namedtuple('WmmaLoadConfig', ['matrix', 'layout', 'dtype', 'shape', 'num_regs']) + configs = set() + for wmma_config in wmma_configs: + shape, a_dtype, b_dtype, c_dtype, a_layout, b_layout, c_layout, a_regs, b_regs, c_regs = wmma_config + configs.add(WmmaLoadConfig('a', a_layout, a_dtype, shape, a_regs)) + configs.add(WmmaLoadConfig('b', b_layout, b_dtype, shape, b_regs)) + + for matrix, layout, short_dtype, shape, num_regs in configs: + inst_name = 'wmma.load.{matrix}.sync.aligned.{layout}.{shape}.{dtype}'.format( + matrix=matrix, layout=layout, shape='m{}n{}k{}'.format(*shape), dtype=short_dtype + ) + func_name = inst_name.replace('.', '_') + dtype: ScalarType = ScalarType(dtype_short2long[short_dtype]) + with FunctionBuilder(name=func_name, kind='cuda_device') as fb: + # parameters: dst, src, stride + dst = Var('dst', PointerType(ScalarType('uint32'))) + src = Var('src', PointerType(dtype)) + stride = Var('stride', ScalarType('int32')) + fb.extend_params([dst, src, stride]) + + # body + assert num_regs > 0 + template_sub_strings = [ + inst_name, + '{{{}}},'.format(', '.join([f'%{i}' for i in range(num_regs)])), + '[%{}],'.format(num_regs), + '%{};'.format(num_regs + 1) + ] + fb += AsmStmt( + template_string=' '.join(template_sub_strings), + outputs=[('=r', dst[i]) for i in range(num_regs)], + inputs=[('l', src), ('r', stride)], + is_volatile=False + ) + register_primitive_function(target='cuda', name=func_name, func_or_type=fb.func) + + +@initialize() +def register_wmma_mma_instructions(): + WmmaMmaConfig = namedtuple('WmmaMmaConfig', ['shape', 'a_layout', 'b_layout', 'a_dtype', 'b_dtype', 'c_dtype', 'a_num_regs', 'b_num_regs', 'c_num_regs']) + configs = set() + for wmma_config in wmma_configs: + shape, a_dtype, b_dtype, c_dtype, a_layout, b_layout, c_layout, a_regs, b_regs, c_regs = wmma_config + configs.add(WmmaMmaConfig(shape, a_layout, b_layout, a_dtype, b_dtype, c_dtype, a_regs, b_regs, c_regs)) + + for shape, a_layout, b_layout, a_dtype, b_dtype, c_dtype, a_num_regs, b_num_regs, c_num_regs in configs: + if a_dtype == 'f16' and b_dtype == 'f16': + inst_name = 'wmma.mma.sync.aligned.{a_layout}.{b_layout}.{shape}.{d_dtype}.{c_dtype}'.format( + a_layout=a_layout, b_layout=b_layout, shape='m{}n{}k{}'.format(*shape), d_dtype=c_dtype, c_dtype=c_dtype + ) + else: + inst_name = 'wmma.mma.sync.aligned.{a_layout}.{b_layout}.{shape}.{d_dtype}.{a_dtype}.{b_dtype}.{c_dtype}'.format( + a_layout=a_layout, b_layout=b_layout, shape='m{}n{}k{}'.format(*shape), d_dtype=c_dtype, a_dtype=a_dtype, b_dtype=b_dtype, c_dtype=c_dtype + ) + func_name = inst_name.replace('.', '_') + uint32_dtype = ScalarType('uint32') + with FunctionBuilder(name=func_name, kind='cuda_device') as fb: + # parameters: a, b, c + a = Var('a', PointerType(uint32_dtype)) + b = Var('b', PointerType(uint32_dtype)) + c = Var('c', PointerType(uint32_dtype)) + fb.extend_params([a, b, c]) + + # body + template_sub_strings = [ + inst_name, + '{{{}}},'.format(', '.join([f'%{i}' for i in range(c_num_regs)])), + '{{{}}},'.format(', '.join([f'%{i}' for i in range(c_num_regs, c_num_regs + a_num_regs)])), + '{{{}}},'.format(', '.join([f'%{i}' for i in range(c_num_regs + a_num_regs, c_num_regs + a_num_regs + b_num_regs)])), + '{{{}}};'.format(', '.join([f'%{i}' for i in range(c_num_regs)])) + ] + template_string = ' '.join(template_sub_strings) + fb += AsmStmt( + template_string=template_string, + outputs=[('+r', c[i]) for i in range(c_num_regs)], + inputs=[('r', a[i]) for i in range(a_num_regs)] + [('r', b[i]) for i in range(b_num_regs)], + is_volatile=False + ) + register_primitive_function(target='cuda', name=func_name, func_or_type=fb.func) + + +@initialize() +def register_wmma_store_instructions(): + WmmaStoreConfig = namedtuple('WmmaStoreConfig', ['shape', 'layout', 'dtype', 'num_regs']) + configs = set() + for wmma_config in wmma_configs: + shape, a_dtype, b_dtype, c_dtype, a_layout, b_layout, c_layout, a_regs, b_regs, c_regs = wmma_config + configs.add(WmmaStoreConfig(shape, c_layout, c_dtype, c_regs)) + + for shape, layout, dtype, num_regs in configs: + inst_name = 'wmma.store.d.sync.aligned.{layout}.{shape}.{dtype}'.format( + layout=layout, shape='m{}n{}k{}'.format(*shape), dtype=dtype + ) + func_name = inst_name.replace('.', '_') + dtype = ScalarType(dtype_short2long[dtype]) + with FunctionBuilder(name=func_name, kind='cuda_device') as fb: + # parameters: dst, src + dst = Var('dst', PointerType(dtype)) + src = Var('src', PointerType(ScalarType('uint32'))) + stride = Var('stride', ScalarType('int32')) + fb.extend_params([dst, src, stride]) + + # body + template_sub_strings = [ + inst_name, + '[%{}],'.format(num_regs), + '{{{}}},'.format(', '.join([f'%{i}' for i in range(num_regs)])), + '%{};'.format(num_regs + 1) + ] + template_string = ' '.join(template_sub_strings) + fb += AsmStmt( + template_string=template_string, + outputs=[], + inputs=[('r', src[i]) for i in range(num_regs)] + [('l', dst)] + [('r', stride)], + is_volatile=False + ) + register_primitive_function(target='cuda', name=func_name, func_or_type=fb.func) + + +def default_stride(matrix: str, layout: str, shape: Tuple[int, int, int]) -> int: + assert matrix in ['a', 'b', 'c'] + assert layout in ['row', 'col'] + m, n, k = shape + matrix_shape = { + 'a': (m, k), + 'b': (k, n), + 'c': (m, n) + } + a, b = matrix_shape[matrix] + return b if layout == 'row' else a + + +def wmma_load( + config: WmmaConfig, + matrix: str, + shape: Tuple[int, int, int], + dtype: Union[str, ScalarType], + reg_addr: Expr, + mem_addr: Expr, + stride: Optional[Union[Expr, int]] = None, + layout: Optional[str] = 'row' +): + func_name = 'wmma.load.{matrix}.sync.aligned.{layout}.{shape}.{dtype}'.format( + matrix=matrix, layout=layout, shape='m{}n{}k{}'.format(*shape), dtype=dtype + ).replace('.', '_') + if stride is None: + stride = default_stride(matrix, layout, shape) + return call_cuda(func_name, args=[reg_addr, mem_addr, stride]) + + +def wmma_load_a( + config: WmmaConfig, + reg_addr: Expr, + mem_addr: Expr, + stride: Optional[Union[Expr, int]] = None, +): + func_name = 'wmma.load.{matrix}.sync.aligned.{layout}.{shape}.{dtype}'.format( + matrix='a', layout=config.a_layout, shape='m{}n{}k{}'.format(*config.shape), dtype=config.a_dtype + ).replace('.', '_') + def_stride = default_stride(matrix='a', layout=config.a_layout, shape=config.shape) + if stride is None: + stride = def_stride + else: + assert stride % def_stride == 0 + return call_cuda(func_name, args=[reg_addr, mem_addr, stride]) + + +def wmma_load_b( + config: WmmaConfig, + reg_addr: Expr, + mem_addr: Expr, + stride: Optional[Union[Expr, int]] = None, +): + func_name = 'wmma.load.{matrix}.sync.aligned.{layout}.{shape}.{dtype}'.format( + matrix='b', layout=config.b_layout, shape='m{}n{}k{}'.format(*config.shape), dtype=config.b_dtype + ).replace('.', '_') + def_stride = default_stride(matrix='b', layout=config.b_layout, shape=config.shape) + if stride is None: + stride = def_stride + else: + assert stride % def_stride == 0 + return call_cuda(func_name, args=[reg_addr, mem_addr, stride]) + + +def wmma_mma( + config: WmmaConfig, + a_regs_addr: Expr, + b_regs_addr: Expr, + c_regs_addr: Expr +): + head_part = 'wmma.mma.sync.aligned.{a_layout}.{b_layout}.{shape}'.format( + a_layout=config.a_layout, b_layout=config.b_layout, shape='m{}n{}k{}'.format(*config.shape) + ) + if config.a_dtype == 'f16' and config.b_dtype == 'f16': + type_part = '.{d_dtype}.{c_dtype}'.format(d_dtype=config.c_dtype, c_dtype=config.c_dtype) + else: + type_part = '.{d_dtype}.{a_dtype}.{b_dtype}.{c_dtype}'.format(d_dtype=config.c_dtype, a_dtype=config.a_dtype, b_dtype=config.b_dtype, c_dtype=config.c_dtype) + func_name = (head_part + type_part).replace('.', '_') + return call_cuda(func_name, args=[a_regs_addr, b_regs_addr, c_regs_addr]) + + +def wmma_store( + config: WmmaConfig, + mem_addr: Expr, + reg_addr: Expr, + stride: Optional[Union[Expr, int]] = None, +): + func_name = 'wmma.store.d.sync.aligned.{layout}.{shape}.{dtype}'.format( + layout=config.c_layout, shape='m{}n{}k{}'.format(*config.shape), dtype=config.c_dtype + ).replace('.', '_') + def_stride = default_stride(matrix='c', layout=config.c_layout, shape=config.shape) + + if stride is None: + stride = def_stride + else: + assert stride % def_stride == 0 + + return call_cuda(func_name, args=[mem_addr, reg_addr, stride]) diff --git a/python/hidet/ir/primitives/func.py b/python/hidet/ir/primitives/func.py new file mode 100644 index 0000000..4f96cb0 --- /dev/null +++ b/python/hidet/ir/primitives/func.py @@ -0,0 +1,130 @@ +from typing import Dict, Union, Optional, List + +from hidet.ir.expr import Var +from hidet.ir.func import Function +from hidet.ir.type import FuncType + + +class PrimitiveFunctionRegistry: + def __init__(self, space: str, name: str, func_type: FuncType, function: Optional[Function] = None, generic: bool = False): + key = '{}.{}'.format(space, name) + self.var = Var(hint=key, type=func_type) + self.space: str = space + self.name: str = name + self.func_type: FuncType = func_type + self.function: Optional[Function] = function + + self.generic = generic + self.dispatch_dtype_rules: Dict[str, str] = {} + + def dispatch_dtype(self, dtype: str, space: str, func_name: str): + if not self.generic: + raise ValueError('Can only dispatch a generic function.') + func_key = '{}.{}'.format(space, func_name) + self.dispatch_dtype_rules[dtype] = func_key + + +class PrimitiveFunctionPool: + def __init__(self): + self.key2func: Dict[str, PrimitiveFunctionRegistry] = {} + + def register(self, space: str, name: str, func_or_type: Union[Function, FuncType], generic): + if isinstance(func_or_type, Function): + registry = PrimitiveFunctionRegistry( + name=name, + func_type=FuncType.from_func(func_or_type), + space=space, + function=func_or_type, + generic=generic + ) + elif isinstance(func_or_type, FuncType): + registry = PrimitiveFunctionRegistry( + name=name, + func_type=func_or_type, + space=space, + function=None, + generic=generic + ) + else: + raise TypeError('Expect a Function or FuncType to register a primitive function, got {}'.format(type(func_or_type))) + key = '{}.{}'.format(space, name) + if key in self.key2func: + raise KeyError('Primitive function {} has already registered.'.format(key)) + self.key2func[key] = registry + return registry + + def lookup(self, func_var: Var) -> PrimitiveFunctionRegistry: + if func_var.hint not in self.key2func: + raise KeyError('Can not find primitive function via variable: {}.'.format(func_var)) + return self.key2func.get(func_var.hint) + + def lookup_by_key(self, key: str) -> PrimitiveFunctionRegistry: + if key not in self.key2func: + raise KeyError('Can not find primitive function with key: {}.'.format(key)) + return self.key2func[key] + + def lookup_by_name(self, target: str, name: str) -> PrimitiveFunctionRegistry: + key = '{}.{}'.format(target, name) + if key not in self.key2func: + candidates = '\n'.join(self.registered_names()[target]) + raise ValueError('Can not find primitive function with target "{}" and name "{}", candidates:\n{}'.format(target, name, candidates)) + return self.key2func[key] + + def registered_names(self) -> Dict[str, List[str]]: + ret = {} + for name in self.key2func: + target, func_name = name.split('.') + if target not in ret: + ret[target] = [] + ret[target].append(func_name) + return ret + + def has_registered(self, key: str) -> bool: + return key in self.key2func + + +primitive_func_pool = PrimitiveFunctionPool() + + +def is_primitive_function(key: str): + return key in primitive_func_pool.key2func + + +def lookup_primitive_function(key: str) -> PrimitiveFunctionRegistry: + return primitive_func_pool.lookup_by_key(key) + + +def registered_primitive_functions() -> List[str]: + return list(primitive_func_pool.key2func.keys()) + + +def register_primitive_function(target, name, func_or_type: Union[Function, FuncType], generic=False) -> PrimitiveFunctionRegistry: + """ + Register a primitive function. + + Parameters + ---------- + target: str + The target device of the primitive function works on. Candidates: 'base', 'cuda', 'cpu'. + 'base' indicates this function is generic to different devices. + 'cuda' indicates this is a primitive function in CUDA programming platform. + 'cpu' indicates this is a primitive function specific in CPU. + + name: str + The name of the primitive function. + + func_or_type: Union[Function, FuncType] + Function definition or function type of the primitive function. + When function type is given, this function is implemented by underlying language (e.g., cuda c). + + generic: bool + Whether this function is a generic function. A generic function will be lowered to a concrete primitive + function according to the calling arguments' type. + + Returns + ------- + ret: PrimitiveFunctionRegistry + The entry of registered primitive function. + """ + return primitive_func_pool.register(target, name, func_or_type, generic) + diff --git a/python/hidet/ir/stmt.py b/python/hidet/ir/stmt.py new file mode 100644 index 0000000..257cab3 --- /dev/null +++ b/python/hidet/ir/stmt.py @@ -0,0 +1,115 @@ +from typing import Sequence, Tuple +from typing import List, Union, Optional +from hidet.ir.node import Node +from hidet.ir.expr import Var, Expr, convert, Constant + + +class Stmt(Node): + pass + + +class EvaluateStmt(Stmt): + def __init__(self, expr): + super().__init__() + self.expr = convert(expr) + + +class DeclareStmt(Stmt): + def __init__(self, var, init: Optional[Expr] = None): + super().__init__() + self.var: Var = var + self.init: Optional[Expr] = init + + +class BufferStoreStmt(Stmt): + def __init__(self, buf, indices, value): + super().__init__() + assert isinstance(indices, (list, tuple)), type(indices) + self.buf = buf + self.indices = convert(indices) + self.value = convert(value) + + +class AssignStmt(Stmt): + def __init__(self, var, value): + super().__init__() + self.var = var + self.value = convert(value) + + +class ReturnStmt(Stmt): + def __init__(self, ret_value: Optional[Expr] = None): + super().__init__() + self.ret_value = ret_value + + +class LetStmt(Stmt): + def __init__(self, bind_vars, bind_values, body=None): + if not isinstance(bind_vars, (list, tuple)): + bind_vars = [bind_vars] + if not isinstance(bind_values, (list, tuple)): + bind_values = [bind_values] + assert len(bind_vars) == len(bind_values) + assert len(bind_vars) > 0 + bind_values = [convert(bind_value) for bind_value in bind_values] + self.bind_vars = bind_vars + self.bind_values = bind_values + self.body = body + + +class ForStmt(Stmt): + DEFAULT_UNROLL_LIMIT = 32 + + def __init__(self, loop_var, extent, unroll: Optional[Union[int, bool]] = None, body=None): + from hidet.ir.functors import simplify + super().__init__() + self.loop_var: Var = loop_var + self.extent = simplify(convert(extent)) + self.unroll = unroll + self.body = body + + +class IfStmt(Stmt): + def __init__(self, cond: Expr, then_body=None, else_body=None): + super().__init__() + self.cond = convert(cond) + self.then_body = then_body + self.else_body = else_body + + +class AssertStmt(Stmt): + def __init__(self, cond: Union[Expr, bool], msg: str): + super().__init__() + self.cond = convert(cond) + self.msg = msg + + +class AsmStmt(Stmt): + def __init__(self, + template_string: str = "", + outputs: Sequence[Tuple[str, Expr]] = (), + inputs: Sequence[Tuple[str, Expr]] = (), + is_volatile=False): + self.template_string = template_string + self.output_labels = [pr[0] for pr in outputs] + self.output_exprs = [pr[1] for pr in outputs] + self.input_labels = [pr[0] for pr in inputs] + self.input_exprs = [pr[1] for pr in inputs] + self.is_volatile = is_volatile + + +class BlackBoxStmt(Stmt): + def __init__(self, template_string: str, *exprs: Sequence[Expr]): + super().__init__() + self.template_string: str = template_string + self.exprs: Tuple[Expr] = convert(exprs) + expect_args_num = self.template_string.count('{}') + assert expect_args_num == len(exprs) + + +class SeqStmt(Stmt): + def __init__(self, seq: List[Stmt]): + super().__init__() + self.seq: Tuple[Stmt] = tuple(seq) + for stmt in seq: + assert isinstance(stmt, Stmt) diff --git a/python/hidet/ir/task.py b/python/hidet/ir/task.py new file mode 100644 index 0000000..df299ff --- /dev/null +++ b/python/hidet/ir/task.py @@ -0,0 +1,257 @@ +from __future__ import annotations +from typing import Any +import copy +import os +import pickle +from typing import Dict, List, Union, Optional, Sequence, Type, Tuple, Callable, TypeVar +from hidet.ir.node import Node +from hidet.ir.expr import Expr, Var, TensorElement, var +from hidet.ir.func import IRModule +from hidet.ir.dialects.compute import TensorNode, ScalarNode, GridCompute + + +class Target: + _supported_targets = ['cuda', 'cpu'] + + def __init__(self, name: str, attrs: List[str]): + if name not in self._supported_targets: + raise ValueError('Does not support target {}, candidates {}.'.format(name, self._supported_targets)) + self.name = name + self.attrs = attrs + + @staticmethod + def from_string(target_string: str) -> Target: + items = target_string.split() + name, attrs = items[0], items[1:] + return Target(name, attrs) + + +class Prologue(Node): + def __init__(self, extra_inputs, indices, value): + self.extra_inputs: List[TensorNode] = extra_inputs + self.indices: List[Var] = indices + self.value: Expr = value + + +class Epilogue(Node): + def __init__(self, extra_inputs, indices, orig_value, value, out_indices, out_tensor): + self.extra_inputs: List[TensorNode] = extra_inputs + self.indices: List[Var] = indices + self.orig_value: Var = orig_value + self.value: Expr = value + self.out_indices: List[Expr] = out_indices + self.out_tensor: TensorNode = out_tensor + + +class TaskContext: + contexts = [] + + def __init__(self, space_level: int = 0, resolve_out_dir: str = None): + self.space_level = space_level + self.resolve_out_dir = resolve_out_dir + + def __enter__(self): + self.contexts.append(self) + + def __exit__(self, exc_type, exc_val, exc_tb): + self.contexts.pop() + + @staticmethod + def current() -> TaskContext: + return TaskContext.contexts[-1] + + +TaskContext.contexts.append(TaskContext()) # fallback context + + +class InverseMap: + def __init__(self, axes: List[Var], indices: List[Expr]): + from hidet.ir.functors import simplify + self.axes: List[Var] = axes + self.indices: List[Expr] = [simplify(e) for e in indices] + + @staticmethod + def from_lambda(func, num_args=None) -> InverseMap: + num_args = num_args if num_args is not None else func.__code__.co_argcount + axes = [var('v') for v in range(num_args)] + indices = list(func(*axes)) + return InverseMap(axes, indices) + + @staticmethod + def identity(num_args: int) -> InverseMap: + return InverseMap.from_lambda(lambda *indices: list(indices), num_args=num_args) + + def __add__(self, other) -> InverseMap: + from hidet.ir.functors import rewrite + if not isinstance(other, InverseMap): + raise ValueError('Can not concat InverseMap with {}'.format(type(other))) + lhs, rhs = self, other + if len(lhs.indices) != len(rhs.axes): + raise ValueError('Can not concat InverseMap a and b, ' + 'where a has {} indices and b has {} axes'.format(len(lhs.indices), len(rhs.axes))) + rmap = {a: b for a, b in zip(rhs.axes, lhs.indices)} + indices = [rewrite(index_expr, rmap) for index_expr in rhs.indices] + return InverseMap(lhs.axes, indices) + + +class Task(Node): + def __init__(self, name, inputs, outputs, prologues=None, epilogues=None, parameters=None, inverse_map=None, attributes: Optional[Dict] = None): + """ + A Task is a computation definition. + + param_inputs ===========> task_inputs ===================> task_outputs ===========> param_outputs + | prologues task computations epilogues + | ^ ^ + v | | + +-----------+-------------------------------------------------------------+ + + Constraints: + 1. Each task input can have zero or one prologue. + 2. Each task output can have zero or one epilogue. + 3. Prologue and epilogue can only have extra inputs from param inputs. + 4. When a task input has prologue, it should not appear in param input. + 5. When a task output has epilogue, it should not appear in param output. + + + Parameters + ---------- + name: str + The name of the task. Can only contain a-z, A-Z, underscore, and digits. + inputs: List[TensorNode] + The inputs of the task computation. + outputs: List[TensorNode] + The outputs of the task computation. + prologues: Dict[TensorNode, Prologue] + The prologues. + epilogues: Dict[TensorNode, Epilogue] + The epilogues. + parameters: List[TensorNode] + The list of parameters in the final kernel. + inverse_map: Dict[TensorNode, Union[InverseMap, Callable[[Any], Any]]] + If the mapping of input axes to output axes are invertible, then inverse_map contains + the inverse map. It is used to convert a task to epilogue of previous task. + """ + self.attributes: Dict[str, Union[str, float, int, bool]] = attributes if attributes is not None else {} + self.name = name + self.inputs: List[TensorNode] = inputs + self.outputs: List[TensorNode] = outputs + self.prologues: Dict[TensorNode, Prologue] = prologues if prologues else {} + self.epilogues: Dict[TensorNode, Epilogue] = epilogues if epilogues else {} + self.parameters: List[TensorNode] = parameters if parameters else inputs + outputs + + inverse_map = inverse_map if inverse_map else {} + if not isinstance(inverse_map, dict): + raise ValueError('inverse_map should be a dict') + self.inverse_map: Dict[TensorNode, InverseMap] = { + a: (b if isinstance(b, InverseMap) else InverseMap.from_lambda(b)) for a, b in inverse_map.items() + } + + def implement(self, target: Union[Target, str]) -> IRModule: + from hidet.tos.ops.schedules import generic_cuda_schedule, generic_cpu_schedule + if isinstance(target, str): + target = Target.from_string(target) + if target.name == 'cuda': + ret = self.implement_cuda() + if ret is NotImplemented: + ret = generic_cuda_schedule(self) + elif target.name == 'cpu': + ret = self.implement_cpu() + if ret is NotImplemented: + ret = generic_cpu_schedule(self) + else: + raise ValueError() + if not isinstance(ret, IRModule): + raise AssertionError('The task implement function should return an IRModule, but got a {}.'.format(type(ret))) + return ret + + def implement_cuda(self) -> IRModule: + return NotImplemented + + def implement_cpu(self) -> IRModule: + return NotImplemented + + def fast_implement(self, space_level: int) -> bool: + if space_level == 0: + return True + else: + if 'implement_cuda' not in self.__class__.__dict__: + return True + else: + return False + + def copy(self: Task) -> Task: + cls = type(self) + task = object.__new__(cls) + task.name = self.name + task.inputs = self.inputs.copy() + task.outputs = self.outputs.copy() + task.prologues = self.prologues.copy() + task.epilogues = self.epilogues.copy() + task.parameters = self.parameters.copy() + task.inverse_map = self.inverse_map.copy() + for name in self.__dict__: + if name not in task.__dict__: + task.__dict__[name] = copy.copy(self.__dict__[name]) + return task + + def save(self, fname: str): + dirname = os.path.dirname(fname) + os.makedirs(dirname, exist_ok=True) + with open(fname, 'wb') as f: + pickle.dump(self, f) + + @staticmethod + def load(fname: str) -> Task: + with open(fname, 'rb') as f: + return pickle.load(f) + + +def is_injective_task(task: Task) -> bool: + """ + Check whether a task is an injective task. A task is injective if and only if there is no reduce compute in + the task. + + Parameters + ---------- + task: Task + The task to check. + + Returns + ------- + ret: bool + Whether the task is injective. + """ + from hidet.ir.functors import collect + scalar_nodes: List[ScalarNode] = collect(task.outputs, ScalarNode, stop_when_found=False) + return all(sn.reduce_compute is None for sn in scalar_nodes) + + +def is_unary_injective_task(task: Task) -> bool: + return len(task.inputs) == 1 and len(task.outputs) == 1 and is_injective_task(task) + + +def is_elementwise_task(task: Task) -> bool: + """ + Check whether a task is an elementwise task. A task is elementwise if and only if it is a unary injective and + invertible. + + Parameters + ---------- + task: Task + The task to check. + + Returns + ------- + ret: bool + Whether the task is elementwise. + """ + return is_unary_injective_task(task) and len(task.inverse_map) > 0 + + +def save_task(task: Task, fname: str): + task.save(fname) + + +def load_task(fname: str) -> Task: + return Task.load(fname) + diff --git a/python/hidet/ir/type.py b/python/hidet/ir/type.py new file mode 100644 index 0000000..25fa92c --- /dev/null +++ b/python/hidet/ir/type.py @@ -0,0 +1,272 @@ +from __future__ import annotations +from typing import Sequence, Optional, Union, List, Tuple, Mapping, Callable, Iterable +import numpy as np + +from hidet import ir +from hidet.ir.node import Node + +# typing forward declaration +Expr = 'Expr' +Int = Union['Expr', int] + + +class TypeNode(Node): + pass + + +# scope +class Scope(Node): + def __init__(self, name): + assert name in ['host', 'global', 'shared', 'register', 'unspecified'] + self.name = name + + +dtype_list = [ + 'int64', + 'float64', + 'int32', + 'uint32', + 'float32', + 'tfloat32', + 'bfloat16', + 'int32', + 'float16', + 'uint8', + 'bool' +] + +float_dtype_rank = {} +for idx, dtype in enumerate(dtype_list): + float_dtype_rank[dtype] = len(dtype_list) - idx + + +class ScalarType(TypeNode): + def __init__(self, name): + if name not in dtype_list: + raise ValueError('Can not recognize data type {}, candidates:\n{}'.format(name, dtype_list)) + self.name = name + + def __eq__(self, other): + if isinstance(other, str): + other = ScalarType(other) + return self.name == other.name + + def __ne__(self, other): + if isinstance(other, str): + other = ScalarType(other) + return self.name != other.name + + def __le__(self, other): + if isinstance(other, str): + other = ScalarType(other) + return float_dtype_rank[self.name] <= float_dtype_rank[other.name] + + def __lt__(self, other): + if isinstance(other, str): + other = ScalarType(other) + return float_dtype_rank[self.name] < float_dtype_rank[other.name] + + def __ge__(self, other): + if isinstance(other, str): + other = ScalarType(other) + return float_dtype_rank[self.name] >= float_dtype_rank[other.name] + + def __gt__(self, other): + if isinstance(other, str): + other = ScalarType(other) + return float_dtype_rank[self.name] > float_dtype_rank[other.name] + + def __hash__(self): + return hash(self.name) + + @staticmethod + def from_numpy_dtype(np_dtype): + if np_dtype == np.float32: + return ScalarType('float32') + elif np_dtype == np.int32: + return ScalarType('int32') + elif np_dtype == np.int64: + return ScalarType('int64') + else: + raise ValueError("Unrecognized numpy data type: '{}'".format(np_dtype)) + + @staticmethod + def float16() -> ScalarType: + return ScalarType('float16') + + @staticmethod + def float32() -> ScalarType: + return ScalarType('float32') + + @staticmethod + def int32() -> ScalarType: + return ScalarType('int32') + + @staticmethod + def int64() -> ScalarType: + return ScalarType('int64') + + @staticmethod + def uint32() -> ScalarType: + return ScalarType('uint32') + + @staticmethod + def uint64() -> ScalarType: + return ScalarType('uint64') + + @staticmethod + def uint8() -> ScalarType: + return ScalarType('uint8') + + def nbytes(self) -> int: + bytes_dict = { + 'float32': 4, + 'tfloat32': 4, + 'bfloat16': 2, + 'float16': 2, + 'int32': 4, + 'uint8': 1, + 'uint32': 4, + 'int64': 8, + 'bool': 1 + } + return bytes_dict[self.name] + + def is_float(self) -> bool: + return self.name in ['float16', 'bfloat16', 'float32', 'float64'] + + def is_integer(self) -> bool: + return self.name in ['bool', 'uint8', 'int32', 'uint32', 'int64'] + + @staticmethod + def resolve_out_dtype(lhs: Union[ScalarType, str], rhs: Union[ScalarType, str]) -> ScalarType: + lhs = ScalarType(lhs) if isinstance(lhs, str) else lhs + rhs = ScalarType(rhs) if isinstance(rhs, str) else rhs + if lhs.is_float() and rhs.is_float(): + nbytes = max(lhs.nbytes(), rhs.nbytes()) + return ScalarType('float{}'.format(nbytes * 8)) + elif lhs.is_integer() and rhs.is_integer(): + nbytes = max(lhs.nbytes(), rhs.nbytes()) + return ScalarType('int{}'.format(nbytes * 8)) + else: + raise NotImplementedError('resolve out dtype for {} and {}'.format(lhs, rhs)) + + +class TensorType(TypeNode): + def __init__(self, + scope: Optional[Scope] = None, + dtype: Optional[ScalarType] = None, + shape: Optional[Tuple[Expr, ...]] = None, + layout: Optional['DataLayout'] = None): + from hidet.ir.layout import DataLayout + self.scope: Scope = scope + self.scalar_type: ScalarType = dtype + self.shape: Tuple[Expr] = shape + self.layout: DataLayout = layout + + def storage_bytes(self) -> Expr: + return self.layout.size * self.scalar_type.nbytes() + + def slice_out(self, dims: Sequence[int]) -> 'TensorType': + layout = self.layout.slice_out(dims) + return tensor_type(self.scope, self.scalar_type, layout=layout) + + def split(self, dim2factor: Mapping[int, Int]) -> 'TensorType': + layout = self.layout.split(dim2factor) + return tensor_type(self.scope, self.scalar_type, layout=layout) + + def reorder(self, order: Sequence[int]): + layout = self.layout.reorder(order) + return tensor_type(self.scope, self.scalar_type, layout=layout) + + def const_shape(self) -> List[int]: + return [int(v) for v in self.shape] + + +TypeLike = Union[str, TypeNode] + + +class FuncType(TypeNode): + def __init__(self, + param_types: Optional[List[TypeLike]] = None, + ret_type: Optional[TypeLike] = None, + type_infer_func: Optional[Callable] = None): # Callable[[a number of TypeNode], TypeNode] + self.param_types = [self._convert_type(tp) for tp in param_types] if param_types is not None else None + self.ret_type = self._convert_type(ret_type) if ret_type is not None else None + self.type_infer_func = type_infer_func + assert not all(v is None for v in [ret_type, type_infer_func]), 'Please provide either a static type or a type infer func' + + def ret_type_on(self, arg_types: List[TypeNode]) -> TypeNode: + if self.ret_type is not None: + # todo: add type checking + return self.ret_type + else: + return self.type_infer_func(arg_types) + + def _convert_type(self, tp: Union[str, TypeNode]): + if isinstance(tp, str): + return ScalarType(tp) + else: + return tp + + @staticmethod + def from_func(func): + return FuncType([param.type for param in func.params], func.ret_type) + + +def scalar_type(type_name): + return ScalarType(type_name) + + +def tensor_type(scope, dtype, shape: Optional[List[Union[int, Expr]]] = None, layout: Optional['DataLayout'] = None): + """ + Construct a tensor type. Shape and layout must be given at least one. + + Parameters + ---------- + scope: str or Scope + The scope of the tensor. Scope can be 'host', 'global', 'shared', and 'local' + + dtype: str or ScalarType + The scalar type of this tensor. + + shape: Optional[List[Union[int, Expr]]] + The shape of the tensor. If not given, the shape in layout will be used. + + layout: Optional[DataLayout] + The layout of the tensor. If not given, the row major layout of given shape will + be used. + + Returns + ------- + ret: TensorType + The constructed tensor type + """ + from hidet.ir.expr import convert, Constant + from hidet.ir.layout import DataLayout, StridesLayout + if isinstance(scope, str): + scope = Scope(scope) + if not isinstance(scope, Scope): + raise ValueError('Tensor type scope expect a "str" or "Scope", but got {}'.format(type(scope))) + if isinstance(dtype, str): + dtype = ScalarType(dtype) + if not isinstance(dtype, ScalarType): + raise ValueError('Scalar type expect a "str" or "ScalarType", but got {}'.format(type(dtype))) + if shape is None and layout is None: + raise ValueError('Tensor type must give either shape or layout') + elif shape is None: + assert isinstance(layout, DataLayout) + shape = layout.shape + elif layout is None: + layout = DataLayout.row_major([int(v) for v in shape]) + else: + assert isinstance(layout, DataLayout) + assert isinstance(shape, (list, tuple)) + for a, b in zip(shape, layout.shape): + assert int(a) == int(b) + shape = convert(shape) + return TensorType(scope, dtype, shape, layout) + + +def max_float_dtype(float_dtypes: Iterable[str]) -> str: + return max(float_dtypes, key=lambda dtype: float_dtype_rank[dtype]) diff --git a/python/hidet/ir/utils/__init__.py b/python/hidet/ir/utils/__init__.py new file mode 100644 index 0000000..da4b7c5 --- /dev/null +++ b/python/hidet/ir/utils/__init__.py @@ -0,0 +1,5 @@ +from . import call_graph +from . import hash_sum +from . import index_transform + +from .index_transform import index_serialize, index_deserialize diff --git a/python/hidet/ir/utils/call_graph.py b/python/hidet/ir/utils/call_graph.py new file mode 100644 index 0000000..181393d --- /dev/null +++ b/python/hidet/ir/utils/call_graph.py @@ -0,0 +1,88 @@ +from typing import Union, List +from collections import defaultdict +from hidet.ir.expr import Call +from hidet.ir.func import IRModule, Function +from hidet.ir.functors import collect +from hidet.ir.primitives import is_primitive_function, lookup_primitive_function + + +class CallGraphNode: + def __init__(self, func): + self.func = func + self.callers = [] + self.callees = [] + + def add_caller(self, caller): + if caller not in self.callers: + self.callers.append(caller) + + def add_callee(self, callee): + if callee not in self.callees: + self.callees.append(callee) + + +class CallGraph: + def __init__(self, ir_module: IRModule): + self.nodes: List[CallGraphNode] = [] + self.func2node = {} + + self.order: List[CallGraphNode] = [] # topological order, from caller to callee + self.reversed_order: List[CallGraphNode] = [] + + for func in ir_module.functions.values(): + node = CallGraphNode(func) + self.func2node[func] = node + self._add_node(node) + + for func in ir_module.functions.values(): + caller = func + for call in collect(func.body, Call): + if is_primitive_function(call.func_var.hint): + entry = lookup_primitive_function(call.func_var.hint) + if entry.function is not None: + if '.' in call.func_var.hint: + target, name = call.func_var.hint.split('.') + else: + name = call.func_var.hint + callee = ir_module.lookup(name) + else: + continue + else: + callee = ir_module.lookup(call.func_var.hint) + self._add_edge(caller, callee) + + self._init_order() + + def _add_node(self, node): + if node not in self.nodes: + self.nodes.append(node) + + def _add_edge(self, caller: Union[Function, CallGraphNode], callee: Union[Function, CallGraphNode]): + if isinstance(caller, Function): + caller = self.func2node[caller] + if isinstance(callee, Function): + callee = self.func2node[callee] + caller.add_callee(callee) + callee.add_caller(caller) + + def _init_order(self): + in_degree = defaultdict(int) + # do not support recursive calling now + for node in self.nodes: + for callee in node.callees: + in_degree[callee] += 1 + qu = [] + for node in self.nodes: + if in_degree[node] == 0: + qu.append(node) + self.order = [] + while len(qu) > 0: + u = qu.pop() + self.order.append(u) + for callee in u.callees: + in_degree[callee] -= 1 + if in_degree[callee] == 0: + qu.append(callee) + + self.reversed_order = list(reversed(self.order)) + diff --git a/python/hidet/ir/utils/hash_sum.py b/python/hidet/ir/utils/hash_sum.py new file mode 100644 index 0000000..2a30c19 --- /dev/null +++ b/python/hidet/ir/utils/hash_sum.py @@ -0,0 +1,38 @@ +from typing import Iterable +from hidet.ir.functors import ExprRewriter +import numpy as np + + +class HashSum: + def __init__(self, obj): + if isinstance(obj, np.ndarray): + self.value = id(obj) + else: + self.value = hash(obj) + self.hashed_obj = obj + + def __str__(self): + return str(self.value % 107) + + def __add__(self, other): + return HashSum((self.value, other)) + + def __iadd__(self, other): + self.value = HashSum((self.value, other.value)).value + return self + + def __and__(self, other): + return HashSum.hash_set([self, other]) + + def __hash__(self): + return self.value + + def __eq__(self, other): + assert isinstance(other, HashSum) + return self.value == other.value + + @staticmethod + def hash_set(objs: Iterable) -> 'HashSum': + return HashSum(tuple(sorted([hash(obj) for obj in objs]))) + + diff --git a/python/hidet/ir/utils/index_transform.py b/python/hidet/ir/utils/index_transform.py new file mode 100644 index 0000000..b00d240 --- /dev/null +++ b/python/hidet/ir/utils/index_transform.py @@ -0,0 +1,28 @@ +from typing import List +from ..expr import Expr, convert + + +def index_serialize(indices: List[Expr], shape: List[int]) -> Expr: + if len(shape) == 0: + return convert(0) + scalar_index: Expr = convert(0) + acc = 1 + for idx_value, extent in reversed(list(zip(indices, shape))): + scalar_index += idx_value * acc + acc = extent + return scalar_index + + +def index_deserialize(scalar_index: Expr, shape: List[int]) -> List[Expr]: + if len(shape) == 0: + return [] + indices = [] + acc = 1 + for r, extent in enumerate(reversed(shape)): + if r < len(shape) - 1: + indices.append(scalar_index // acc % extent) + else: + indices.append(scalar_index // acc) + acc *= extent + return list(reversed(indices)) + diff --git a/python/hidet/libinfo.py b/python/hidet/libinfo.py new file mode 100644 index 0000000..0715af7 --- /dev/null +++ b/python/hidet/libinfo.py @@ -0,0 +1,25 @@ +from typing import List +import os + + +def get_include_dir(): + cur_file = os.path.abspath(__file__) + root = os.path.join(os.path.dirname(cur_file), '..', '..') + include_dir = os.path.join(root, 'include') + return os.path.abspath(include_dir) + + +def get_library_search_dirs() -> List[str]: + cur_file = os.path.abspath(__file__) + root = os.path.join(os.path.dirname(cur_file), '..', '..') + relative_dirs = [ + './lib', + './build/lib', + './build-release/lib', + './build-debug/lib', + ] + return [os.path.abspath(os.path.join(root, relative)) for relative in relative_dirs] + + +if __name__ == '__main__': + print(get_include_dir()) diff --git a/python/hidet/runtime/__init__.py b/python/hidet/runtime/__init__.py new file mode 100644 index 0000000..c3f2f43 --- /dev/null +++ b/python/hidet/runtime/__init__.py @@ -0,0 +1,8 @@ +from . import module +from . import cuda_event + +from .module import CompiledModule, CompiledFunction +from .storage import Storage +from .cuda_event import cuda_event_pool, CudaEventPool, CudaEvent + + diff --git a/python/hidet/runtime/cuda_event.py b/python/hidet/runtime/cuda_event.py new file mode 100644 index 0000000..8345758 --- /dev/null +++ b/python/hidet/runtime/cuda_event.py @@ -0,0 +1,37 @@ +from typing import List +from hidet.ffi import cuda + + +class CudaEvent: + def __init__(self, handle, pool): + self.handle: int = handle + self.pool: CudaEventPool = pool + + def __del__(self): + self.pool.event_handles.append(self.handle) + self.handle = 0 + + def elapsed_time_since(self, start_event) -> float: + return cuda.event_elapsed_time(start_event.handle, self.handle) + + def record_on(self, stream_handle: int = 0): + cuda.event_record(self.handle, stream_handle) + + +class CudaEventPool: + def __init__(self): + self.event_handles: List[int] = [] + + def new_event(self) -> CudaEvent: + if len(self.event_handles) > 0: + return CudaEvent(self.event_handles.pop(), self) + else: + return CudaEvent(cuda.create_event(), self) + + def __del__(self): + while len(self.event_handles) > 0: + handle = self.event_handles.pop() + cuda.destroy_event(handle) + + +cuda_event_pool = CudaEventPool() diff --git a/python/hidet/runtime/cuda_graph.py b/python/hidet/runtime/cuda_graph.py new file mode 100644 index 0000000..670685e --- /dev/null +++ b/python/hidet/runtime/cuda_graph.py @@ -0,0 +1,133 @@ +from typing import List, Optional +import time +from hidet.tos import randn_like, zeros_like +from hidet.ffi import cuda +from hidet.runtime.storage import CudaMemoryPool +from hidet.runtime.cuda_stream import CudaStream +from hidet.tos import Tensor, FlowGraph + + +def dummy_input_like(tensor: Tensor) -> Tensor: + if tensor.dtype == 'float32': + return randn_like(tensor) + elif tensor.dtype in ['int64', 'int32', 'int8', 'uint64', 'uint32', 'uint8']: + return zeros_like(tensor) + else: + raise ValueError('Can not generate dummy input for data type {}'.format(tensor.dtype)) + + +class CudaGraphExec: + def __init__(self, exec_handle: int): + self.exec_handle = exec_handle + + def __del__(self): + cuda.destroy_graph_exec(self.exec_handle) + + def launch(self, stream: Optional[CudaStream] = None): + stream_handle = stream.handle if stream else 0 + cuda.launch_graph_exec(self.exec_handle, stream_handle) + + +class CudaGraphImpl: + def __init__(self): + self.stream = CudaStream() + self.graph_handle: Optional[int] = None + + def __enter__(self): + self.stream.__enter__() + cuda.stream_begin_capture(self.stream.handle) + + def __exit__(self, exc_type, exc_val, exc_tb): + self.stream.__exit__(exc_type, exc_val, exc_tb) + if self.graph_handle is not None: + cuda.destroy_graph(self.graph_handle) + self.graph_handle = cuda.stream_end_capture(self.stream.handle) + + def __del__(self): + if self.graph_handle is not None: + cuda.destroy_graph(self.graph_handle) + + def instantiate(self) -> CudaGraphExec: + exec_handle = cuda.instantiate_graph(self.graph_handle) + graph_exec = CudaGraphExec(exec_handle) + return graph_exec + + +class CudaGraph: + def __init__(self, flow_graph: FlowGraph): + flow_graph.update_nodes() + self.flow_graph = flow_graph + self.mem_pool = CudaMemoryPool( + block_size=4096, + max_reserve_size=10 * 1024 ** 3 + ) + self.cuda_graph_impl = CudaGraphImpl() + with self.mem_pool: + self.inputs = [dummy_input_like(tensor) for tensor in flow_graph.inputs] + # run twice to avoid any memory allocation during capturing + self.outputs = flow_graph.forward(*self.inputs) + self.outputs = flow_graph.forward(*self.inputs) + self.mem_pool.storage_device.freeze(True) + with self.cuda_graph_impl: + self.outputs = flow_graph.forward(*self.inputs) + if isinstance(self.outputs, Tensor): + self.outputs = [self.outputs] + # self.mem_pool.storage_device.freeze(False) + self.cuda_graph_exec = self.cuda_graph_impl.instantiate() + + def get_input_tensors(self) -> List[Tensor]: + return self.inputs + + def get_output_tensors(self) -> List[Tensor]: + return self.outputs + + def set_input_tensors(self, input_tensors: List[Tensor]): + if len(input_tensors) != len(self.inputs): + raise ValueError('Expect {} input tensors, got {}'.format(len(self.inputs), len(input_tensors))) + for idx, tensor in enumerate(input_tensors): + self.set_input_tensor(idx, tensor) + + def set_input_tensor(self, idx: int, input_tensor: Tensor): + src = input_tensor + dst = self.inputs[idx] + if src.device != 'cuda': + src = src.cuda() + if src.dtype != dst.dtype: + msg = 'The i-th {} input tensor expect data type {}, but got a tensor with data type {}.'.format(idx, dst.dtype, src.dtype) + raise ValueError(msg) + if any(a != b for a, b in zip(input_tensor.shape, self.inputs[idx].shape)): + msg = 'The i-th {} input tensor expect shape {}, bot got a tensor with shape {}'.format(idx, dst.shape, src.shape) + raise ValueError(msg) + cuda.memcpy_async(src.storage.addr, dst.storage.addr, num_bytes=dst.nbytes, kind=cuda.DeviceToDevice) + + def run_with_inputs(self, inputs: List[Tensor], stream: Optional[CudaStream] = None) -> List[Tensor]: + self.set_input_tensors(inputs) + cuda.device_synchronize() + self.run(stream) + cuda.device_synchronize() + return self.get_output_tensors() + + def run(self, stream: Optional[CudaStream] = None): + self.cuda_graph_exec.launch(stream) + + def profile(self, warmup, number, repeat) -> List[float]: + latency_list = [] + for i in range(warmup): + self.run() + for i in range(repeat): + cuda.device_synchronize() + start = time.time() + for j in range(number): + self.run() + cuda.device_synchronize() + end = time.time() + latency_list.append((end - start) / number) + return latency_list + + def __del__(self): + self.mem_pool.storage_device.freeze(False) + + +def create_cuda_graph(flow_graph: FlowGraph) -> CudaGraph: + exec_ctx = CudaGraph(flow_graph) + return exec_ctx diff --git a/python/hidet/runtime/cuda_stream.py b/python/hidet/runtime/cuda_stream.py new file mode 100644 index 0000000..e7e1fc0 --- /dev/null +++ b/python/hidet/runtime/cuda_stream.py @@ -0,0 +1,30 @@ +from __future__ import annotations +from typing import List +from hidet.ffi import cuda +from hidet.ffi import runtime_api + + +class CudaStream: + stack: List[CudaStream] = [] + + def __init__(self): + self.handle = cuda.create_stream() + + def __enter__(self): + CudaStream.stack.append(self) + runtime_api.set_current_stream(self.handle) + + def __exit__(self, exc_type, exc_val, exc_tb): + CudaStream.stack.pop() + if len(CudaStream.stack) == 0: + # set to default stream + runtime_api.set_current_stream(0) + else: + cur_stream = CudaStream.stack[-1] + runtime_api.set_current_stream(cur_stream.handle) + + def __del__(self): + cuda.destroy_stream(self.handle) + + def synchronize(self): + cuda.stream_synchronize(self.handle) diff --git a/python/hidet/runtime/module.py b/python/hidet/runtime/module.py new file mode 100644 index 0000000..c370f45 --- /dev/null +++ b/python/hidet/runtime/module.py @@ -0,0 +1,25 @@ +from typing import Dict +from hidet.ir.func import Function, IRModule + + +class CompiledModule: + def __init__(self, ir_module, funcs): + self.ir_module: IRModule = ir_module + self.funcs: Dict[str, CompiledFunction] = funcs + + def __getitem__(self, item: str): + return self.funcs[item] + + +class CompiledFunction: + def __init__(self, name, packed_func): + from hidet.ffi import PackedFunc + self.name: str = name + self.packed_func: PackedFunc = packed_func + + def __call__(self, *args): + self.packed_func(*args) + + def profile(self, *args, warmup=1, number=1, repeat=10): + return self.packed_func.profile(*args, warmup=warmup, number=number, repeat=repeat) + diff --git a/python/hidet/runtime/storage.py b/python/hidet/runtime/storage.py new file mode 100644 index 0000000..1eb3b26 --- /dev/null +++ b/python/hidet/runtime/storage.py @@ -0,0 +1,356 @@ +from __future__ import annotations +from typing import Callable, Dict, List, Type +import warnings +from collections import defaultdict +import ctypes +import numpy as np +from hidet.ffi import cuda + + +def nbytes2str(nbytes: int) -> str: + if nbytes > 1024 * 1024: + size = nbytes // 1024 // 1024 + unit = 'MiB' + elif nbytes > 1024: + size = nbytes // 1024 + unit = 'KiB' + else: + size = nbytes + unit = 'Bytes' + return '{} {}'.format(size, unit) + + +class StorageDevice: + def __init__(self): + self.froze = False + + def name(self): + raise NotImplementedError() + + def freeze(self, flag: bool): # when freeze, no allocate or free should happen. used in CudaGraph + self.froze = flag + + def allocate(self, nbytes) -> int: + raise NotImplementedError() + + def free(self, addr): + raise NotImplementedError() + + def allocated_memory(self) -> int: + raise NotImplementedError() + + def peak_allocated_memory(self) -> int: + raise NotImplementedError() + + def free_memory(self) -> int: + raise NotImplementedError() + + def total_memory(self) -> int: + raise NotImplementedError() + + +class CudaStorageDevice(StorageDevice): + def __init__(self): + super().__init__() + self.addr2nbytes = {} + self._peak_allocated_memory = 0 + self._allocated_memory = 0 + + def name(self): + return 'cuda' + + def allocate(self, nbytes): + if self.froze: + raise MemoryError('Should not allocate when the device is frozen.') + + addr = cuda.malloc_async(nbytes) + if addr == 0 and nbytes != 0: + # out of memory + return 0 + self._allocated_memory += nbytes + self._peak_allocated_memory = max(self._peak_allocated_memory, self._allocated_memory) + self.addr2nbytes[addr] = nbytes + return addr + + def free(self, addr): + if self.froze: + raise MemoryError('Should not free when the device is frozen.') + + cuda.free_async(addr) + self._allocated_memory -= self.addr2nbytes.pop(addr) + + def allocated_memory(self) -> int: + return self._allocated_memory + + def peak_allocated_memory(self) -> int: + return self._peak_allocated_memory + + def free_memory(self) -> int: + return cuda.mem_info()[0] + + def total_memory(self) -> int: + return cuda.mem_info()[1] + + +class CpuStorageDevice(StorageDevice): + def __init__(self): + super().__init__() + self.addr2nbytes = {} + self._allocated_memory = 0 + self._peak_allocated_memory = 0 + + def name(self): + return 'cpu' + + def allocate(self, nbytes): + if self.froze: + raise MemoryError('Should not allocate when the device is frozen.') + + addr = cuda.malloc_host(nbytes) + if addr == 0 and nbytes != 0: + return 0 + self._allocated_memory += nbytes + self._peak_allocated_memory = max(self._peak_allocated_memory, self._allocated_memory) + self.addr2nbytes[addr] = nbytes + return addr + + def free(self, addr): + if self.froze: + raise MemoryError('Should not free when the device is frozen.') + + cuda.free_host(addr) + self._allocated_memory -= self.addr2nbytes.pop(addr) + + def allocated_memory(self) -> int: + return self._allocated_memory + + def peak_allocated_memory(self) -> int: + return self._peak_allocated_memory + + def free_memory(self) -> int: + raise NotImplementedError() + + def total_memory(self) -> int: + raise NotImplementedError() + + +class Storage: + + def __init__(self, device, addr, num_bytes, free_handler): + self.device: str = device + self.addr: int = addr + self.num_bytes: int = num_bytes + self.free_handler: Callable[[Storage], None] = free_handler + + def __del__(self): + if self.addr != 0: + self.free_handler(self) + + def __getstate__(self): + raise ValueError() + + def __setstate__(self, state): + raise ValueError() + + def cpu(self): + if self.device == 'cpu': + return self + elif self.device == 'cuda': + host_storage = self.new('cpu', self.num_bytes) + cuda.memcpy_async(src_addr=self.addr, dst_addr=host_storage.addr, num_bytes=self.num_bytes, kind=cuda.DeviceToHost) + return host_storage + else: + raise NotImplementedError() + + def cuda(self): + if self.device == 'cuda': + return self + elif self.device == 'cpu': + cuda_storage = self.new('cuda', self.num_bytes) + cuda.memcpy_async(src_addr=self.addr, dst_addr=cuda_storage.addr, num_bytes=self.num_bytes, kind=cuda.HostToDevice) + return cuda_storage + else: + raise NotImplementedError() + + @staticmethod + def new(device: str, num_bytes: int) -> 'Storage': + if device == 'cpu': + return CpuMemoryPool.current().allocate(nbytes=num_bytes) + elif device == 'cuda': + return CudaMemoryPool.current().allocate(nbytes=num_bytes) + else: + raise ValueError("Unrecognized device '{}', candidates: {}".format(device, ['cpu', 'cuda'])) + + def as_array(self, num_elements: int, dtype: str = 'float32') -> np.ndarray: + """ + Convert to one-dimension numpy array, sharing the underlying storage. + + Parameters + ---------- + num_elements: int + The number of elements in the array. Because the storage may have a larger allocated memory, we can not + infer the desired number of elements. + + dtype: str, default 'float32' + The type of data in this storage. + + Returns + ------- + ret: numpy.ndarray + A numpy ndarray with one dimension that share the same data as the storage. + """ + dtype2ctype = { + 'float32': ctypes.c_float, + 'float16': ctypes.c_uint16, + 'int32': ctypes.c_int32, + 'int64': ctypes.c_int64, + 'bool': ctypes.c_bool + } + dtype2nptype = { + 'float16': np.float16 + } + + if self.device != 'cpu': + raise ValueError('The storage must be cpu storage. Please use .cpu() to convert first.') + buf = (dtype2ctype[dtype] * num_elements).from_address(self.addr) + buf._hidet_storage = self # so this storage will not be freed as long as the buffer not been freed. + assert ctypes.sizeof(buf) <= self.num_bytes, 'Trying to view a storage as a larger array' + with warnings.catch_warnings(): + # temporarily ignore a warning due to python bug. + # See: https://stackoverflow.com/questions/4964101/pep-3118-warning-when-using-ctypes-array-as-numpy-array + warnings.simplefilter('ignore') + array = np.ctypeslib.as_array(buf) + if dtype in dtype2nptype: + # reinterpret the array when needed + array = array.view(dtype2nptype[dtype]) + return array + + +class MemoryPool: + def __init__(self, storage_device: StorageDevice, block_size: int, max_reserve_size: int): + self.storage_device = storage_device + self.block_size: int = block_size + self.max_reserve_size: int = max_reserve_size + + self.reserved_size: int = 0 + self.active_blocks = 0 + self.memory_blocks: Dict[int, List[Storage]] = defaultdict(list) + + def allocate(self, nbytes: int) -> Storage: + allocated = (nbytes + self.block_size - 1) // self.block_size * self.block_size + block_list = self.memory_blocks[allocated] + if len(block_list) > 0: + storage = block_list.pop() + addr = storage.addr + self.reserved_size -= storage.num_bytes + else: + addr = self.storage_device.allocate(allocated) + if addr == 0 and allocated != 0: + # out of memory + self.clear() + addr = self.storage_device.allocate(allocated) + if addr == 0: + raise MemoryError('Can not allocate memory from {} device, total {}, hidet allocated {}, free {}, requesting {}.'.format( + self.storage_device.name(), + nbytes2str(self.storage_device.total_memory()), + nbytes2str(self.storage_device.allocated_memory()), + nbytes2str(self.storage_device.free_memory()), + nbytes2str(allocated) + )) + return Storage( + device=self.storage_device.name(), + addr=addr, + num_bytes=allocated, + free_handler=self.free + ) + + def free(self, storage: Storage): + self.memory_blocks[storage.num_bytes].append(storage) + self.reserved_size += storage.num_bytes + if self.reserved_size > self.max_reserve_size: + self.clear() + + def clear(self): + cuda.device_synchronize() + for block_list in self.memory_blocks.values(): + for storage in block_list: + self.storage_device.free(storage.addr) + storage.addr = 0 + # print('Cleared memory pool, returned {} memory back to {} device'.format( + # nbytes2str(self.reserved_size), self.storage_device.name() + # )) + self.memory_blocks.clear() + self.reserved_size = 0 + + def status(self) -> str: + allocated = self.storage_device.allocated_memory() + peak_allocated = self.storage_device.peak_allocated_memory() + items = [ + ['Allocated', allocated], + ['Peak', peak_allocated], + ['Reserved', self.reserved_size], + ['Active', allocated - self.reserved_size] + ] + lines = [ + 'Status of {} memory pool'.format(self.storage_device.name()), + *['{:>12}: {}'.format(name, nbytes2str(nbytes)) for name, nbytes in items] + ] + return '\n'.join(lines) + + def __str__(self): + return self.status() + + def __del__(self): + self.clear() + + +class CudaMemoryPool(MemoryPool): + stack = [] + + def __init__(self, block_size: int = 4096, max_reserve_size: int = 4 * 1024 ** 3): + super().__init__(CudaStorageDevice(), block_size, max_reserve_size) + + def __enter__(self): + CudaMemoryPool.stack.append(self) + + def __exit__(self, exc_type, exc_value, traceback): + CudaMemoryPool.stack.pop() + + @staticmethod + def current() -> CudaMemoryPool: + return CudaMemoryPool.stack[-1] + + +CudaMemoryPool.stack.append(CudaMemoryPool()) + + +class CpuMemoryPool(MemoryPool): + stack = [] + + def __init__(self, block_size: int = 4096, max_reserve_size: int = 128 * 1024 ** 2): + super().__init__(CpuStorageDevice(), block_size, max_reserve_size) + + def __enter__(self): + CpuMemoryPool.stack.append(self) + + def __exit__(self, exc_type, exc_value, traceback): + CpuMemoryPool.stack.pop() + + @staticmethod + def current() -> CpuMemoryPool: + return CpuMemoryPool.stack[-1] + + +CpuMemoryPool.stack.append(CpuMemoryPool()) + +# cpu_pool = MemoryPool( +# storage_device=CpuStorageDevice(), +# block_size=4 * 1024, # 4 KiB +# max_reserve_size=128 * 1024 * 1024 # 128 MiB +# ) + +# cuda_pool = MemoryPool( +# storage_device=CudaStorageDevice(), +# block_size=4 * 1024, # 4 KiB +# max_reserve_size=3 * 1024 * 1024 * 1024 # 5 GiB +# ) diff --git a/python/hidet/testing/__init__.py b/python/hidet/testing/__init__.py new file mode 100644 index 0000000..75fbb6f --- /dev/null +++ b/python/hidet/testing/__init__.py @@ -0,0 +1,8 @@ +from . import bench +from . import check +from . import tos_models +from . import onnx_models + +from .check import check_unary, check_binary + +from .bench import Conv2dSetting diff --git a/python/hidet/testing/bench.py b/python/hidet/testing/bench.py new file mode 100644 index 0000000..6d0667b --- /dev/null +++ b/python/hidet/testing/bench.py @@ -0,0 +1,85 @@ +from typing import Tuple, List +from collections import OrderedDict +from hidet.utils import prod + + +class Conv2dSetting: + def __init__(self, batch_size, in_channels, image_size, out_channels, kernel, stride, padding): + image_size, kernel, stride, padding = self.normalize(image_size, kernel, stride, padding) + self.batch_size: int = batch_size + self.in_channels: int = in_channels + self.image_size: Tuple[int, int] = image_size + self.out_channels: Tuple[int, int] = out_channels + self.kernel: Tuple[int, int] = kernel + self.stride: Tuple[int, int] = stride + self.padding: Tuple[int, int] = padding + self.output_image_size = tuple([(image_size[i] + 2 * padding[i] - kernel[i]) // stride[i] + 1 for i in range(2)]) + + def __str__(self): + return 'input_{}x{}x{}x{}__kernel_{}x{}_stride_{}x{}_padding_{}x{}_output_{}x{}x{}x{}_flops_{:.0f}'.format( + self.batch_size, self.in_channels, *self.image_size, *self.kernel, *self.stride, + *self.padding, self.batch_size, self.out_channels, *self.output_image_size, self.flops() + ) + + def __repr__(self): + return str(self) + + def flops(self): + return self.batch_size * self.out_channels * prod(self.output_image_size) * self.in_channels * prod(self.kernel) / 10 ** 6 # M FLOPs + + def keys(self) -> List[str]: + return ['n', 'ic', 'h', 'w', 'oc', 'kx', 'ky', 'px', 'py', 'sx', 'sy'] + + def values(self) -> List[int]: + return [self.batch_size, self.in_channels, self.image_size[0], self.image_size[1], self.out_channels, + self.kernel[0], self.kernel[1], self.padding[0], self.padding[1], self.stride[0], self.stride[1]] + + @staticmethod + def normalize(*args): + for arg in args: + if not isinstance(arg, (tuple, list)): + arg = (arg, arg) + yield arg + + @staticmethod + def resnet50_conv2ds(batch_size=1): + workloads = OrderedDict() + workloads[Conv2dSetting(batch_size=batch_size, in_channels=3, image_size=224, out_channels=64, kernel=7, stride=2, padding=3)] = 1 + for image_size, channels, repeat in zip([56, 28, 14, 7], [64, 128, 256, 512], [3, 4, 6, 3]): + if image_size == 56: + lowering_convs = [ + (Conv2dSetting(batch_size=batch_size, in_channels=channels, image_size=image_size, out_channels=channels, kernel=1, stride=1, padding=0), 1), + (Conv2dSetting(batch_size=batch_size, in_channels=channels, image_size=image_size, out_channels=channels, kernel=3, stride=1, padding=1), 1), + (Conv2dSetting(batch_size=batch_size, in_channels=channels, image_size=image_size, out_channels=channels * 4, kernel=1, stride=1, padding=0), 1), + (Conv2dSetting(batch_size=batch_size, in_channels=channels, image_size=image_size, out_channels=channels * 4, kernel=1, stride=1, padding=0), 1) # skip connection + ] + else: + lowering_convs = [ + (Conv2dSetting(batch_size=batch_size, in_channels=channels * 2, image_size=image_size * 2, out_channels=channels, kernel=1, stride=1, padding=0), 1), + (Conv2dSetting(batch_size=batch_size, in_channels=channels, image_size=image_size * 2, out_channels=channels, kernel=3, stride=2, padding=1), 1), + (Conv2dSetting(batch_size=batch_size, in_channels=channels, image_size=image_size, out_channels=channels * 4, kernel=1, stride=1, padding=0), 1), + (Conv2dSetting(batch_size=batch_size, in_channels=channels * 2, image_size=image_size * 2, out_channels=channels * 4, kernel=1, stride=2, padding=0), 1) # skip connection + ] + normal_convs = [ + (Conv2dSetting(batch_size=batch_size, in_channels=channels * 4, image_size=image_size, out_channels=channels, kernel=1, stride=1, padding=0), repeat - 1), + (Conv2dSetting(batch_size=batch_size, in_channels=channels, image_size=image_size, out_channels=channels, kernel=3, stride=1, padding=1), repeat - 1), + (Conv2dSetting(batch_size=batch_size, in_channels=channels, image_size=image_size, out_channels=channels * 4, kernel=1, stride=1, padding=0), repeat - 1), + ] + for conv, r in lowering_convs + normal_convs: + if conv not in workloads: + workloads[conv] = 0 + workloads[conv] += r + return workloads + + def __eq__(self, other): + if len(self.__dict__) != len(other.__dict__): + return False + for k in self.__dict__: + if k not in other.__dict__: + return False + if self.__dict__[k] != other.__dict__[k]: + return False + return True + + def __hash__(self): + return hash((self.batch_size, self.in_channels, self.image_size, self.out_channels, self.kernel, self.stride, self.padding)) diff --git a/python/hidet/testing/check.py b/python/hidet/testing/check.py new file mode 100644 index 0000000..2a1a816 --- /dev/null +++ b/python/hidet/testing/check.py @@ -0,0 +1,21 @@ +from typing import Union + +import numpy as np +import hidet as hi + + +def check_unary(shape, numpy_op, hidet_op, dtype: Union[str, np.dtype] = np.float32, atol=0, rtol=0): + # wrap np.array(...) in case shape = [] + data = np.array(np.random.randn(*shape)).astype(dtype) + numpy_result = numpy_op(data) + hidet_result = hidet_op(hi.array(data).cuda()).cpu().numpy() + np.testing.assert_allclose(actual=hidet_result, desired=numpy_result, atol=atol, rtol=rtol) + + +def check_binary(a_shape, b_shape, numpy_op, hidet_op, dtype: Union[str, np.dtype] = np.float32, atol=0.0, rtol=0.0): + a = np.array(np.random.randn(*a_shape)).astype(dtype) + b = np.array(np.random.randn(*b_shape)).astype(dtype) + numpy_result = numpy_op(a, b) + hidet_result = hidet_op(hi.array(a).cuda(), hi.array(b).cuda()).cpu().numpy() + np.testing.assert_allclose(actual=hidet_result, desired=numpy_result, atol=atol, rtol=rtol) + diff --git a/python/hidet/testing/onnx_models/__init__.py b/python/hidet/testing/onnx_models/__init__.py new file mode 100644 index 0000000..0e95f20 --- /dev/null +++ b/python/hidet/testing/onnx_models/__init__.py @@ -0,0 +1 @@ +from .all import get_onnx_model diff --git a/python/hidet/testing/onnx_models/all.py b/python/hidet/testing/onnx_models/all.py new file mode 100644 index 0000000..10c2ae9 --- /dev/null +++ b/python/hidet/testing/onnx_models/all.py @@ -0,0 +1,120 @@ +from typing import Tuple, List +import numpy as np +import hidet +from hidet.tos import Tensor +from hidet.utils import download, hidet_cache_dir, hidet_cache_file +from hidet.utils.transformers_utils import export_transformer_model_as_onnx +from hidet.utils.torch_utils import export_torchvision_model_as_onnx +from .model_blocks import get_bert_block, get_resnet50_block +from .operators import get_onnx_operator + + +def get_onnx_model(name: str, batch_size: int = 1, **kwargs) -> Tuple[str, List[str], List[Tensor]]: + """ + kwargs candidates: + seq_length=128 + """ + if name == 'resnet50': + model_path = hidet_cache_file('onnx', f'{name}.onnx') + export_torchvision_model_as_onnx(model_name=name, output_path=model_path) + input_names = ['data'] + input_tensors = [hidet.randn(shape=[batch_size, 3, 224, 224])] + return model_path, input_names, input_tensors + elif name == 'inception_v3': + model_path = hidet_cache_file('onnx', f'{name}.onnx') + export_torchvision_model_as_onnx(model_name=name, output_path=model_path) + input_names = ['data'] + input_tensors = [hidet.randn(shape=[batch_size, 3, 299, 299])] + return model_path, input_names, input_tensors + elif name == 'mobilenet_v2': + model_path = hidet_cache_file('onnx', f'{name}.onnx') + export_torchvision_model_as_onnx(model_name=name, output_path=model_path) + input_names = ['data'] + input_tensors = [hidet.randn(shape=[batch_size, 3, 224, 224])] + return model_path, input_names, input_tensors + elif name == 'bert': + model_path = hidet_cache_file('onnx', 'bert.onnx') + export_transformer_model_as_onnx( + model_name='bert-base-uncased', + output_path=model_path + ) + vocab_size = 30522 + seq_length = kwargs.get('seq_length', 128) + input_names = [ + 'input_ids', + 'attention_mask', + 'token_type_ids' + ] + input_tensors = [ + hidet.array(np.random.randint(0, vocab_size-1, size=[batch_size, seq_length], dtype=np.int64)), + hidet.ones(shape=[batch_size, seq_length], dtype='int64'), + hidet.zeros(shape=[batch_size, seq_length], dtype='int64') + ] + return model_path, input_names, input_tensors + elif name == 'bart': + model_path = hidet_cache_file('onnx', 'bart.onnx') + export_transformer_model_as_onnx( + model_name='facebook/bart-base', + output_path=model_path + ) + vocab_size = 50265 + seq_length = kwargs.get('seq_length', 128) + input_names = [ + 'input_ids', + 'attention_mask', + 'decoder_input_ids', + 'decoder_attention_mask' + ] + input_tensors = [ + hidet.array(np.random.randint(0, vocab_size-1, size=[batch_size, seq_length], dtype=np.int64)), + hidet.ones(shape=[batch_size, seq_length], dtype='int64'), + hidet.array(np.random.randint(0, vocab_size-1, size=[batch_size, seq_length], dtype=np.int64)), + hidet.ones(shape=[batch_size, seq_length], dtype='int64') + ] + return model_path, input_names, input_tensors + elif name == 'gpt2': + model_path = hidet_cache_file('onnx', 'gpt2.onnx') + export_transformer_model_as_onnx( + model_name='gpt2', + output_path=model_path + ) + vocab_size = 50257 + seq_length = kwargs.get('seq_length', 128) + input_names = [ + 'input_ids', + 'attention_mask', + ] + input_tensors = [ + hidet.array(np.random.randint(0, vocab_size-1, size=[batch_size, seq_length], dtype=np.int64)), + hidet.ones(shape=[batch_size, seq_length], dtype='int64'), + ] + return model_path, input_names, input_tensors + elif name.startswith('resnet50_'): + return get_resnet50_block(name, batch_size=batch_size, **kwargs) + elif name.startswith('bert_'): + return get_bert_block(name, batch_size=batch_size, **kwargs) + elif name.startswith('op_'): + return get_onnx_operator(name, batch_size) + else: + raise NotImplementedError('Can not recognize model {}'.format(name)) + + +if __name__ == '__main__': + names = [ + 'resnet50', + 'inception_v3', + 'mobilenet_v2', + 'bert', + 'bart', + 'gpt2' + ] + configs = { + 'bert': {'seq_length': 512}, + 'bart': {'seq_length': 512}, + 'gpt2': {'seq_length': 512}, + } + for model_name in names: + kwargs = {} + if model_name in configs: + kwargs.update(configs[model_name]) + get_onnx_model(model_name, **kwargs) diff --git a/python/hidet/testing/onnx_models/model_blocks/__init__.py b/python/hidet/testing/onnx_models/model_blocks/__init__.py new file mode 100644 index 0000000..485eb7f --- /dev/null +++ b/python/hidet/testing/onnx_models/model_blocks/__init__.py @@ -0,0 +1,2 @@ +from .bert_blocks import get_bert_block +from .resnet50_blocks import get_resnet50_block diff --git a/python/hidet/testing/onnx_models/model_blocks/bert_blocks.py b/python/hidet/testing/onnx_models/model_blocks/bert_blocks.py new file mode 100644 index 0000000..0d385d0 --- /dev/null +++ b/python/hidet/testing/onnx_models/model_blocks/bert_blocks.py @@ -0,0 +1,455 @@ +import transformers +from typing import Optional, List, Tuple +import os +import tempfile +import math +import torch +import onnx +from torch import nn, Tensor +import hidet +from hidet.utils import hidet_cache_file +from ..utils import export_torch_to_onnx + + +# Acknowledgement: adopted the bert implementation from huggingface transformers package, with some simplification + +class BertConfig: + def __init__(self): + self.vocab_size = 30522 + self.hidden_size = 768 + self.num_hidden_layers = 12 + self.num_attention_heads = 12 + self.max_position_embeddings = 512 + self.intermediate_size = 3072 + self.type_vocab_size = 2 + + +class BertEmbeddings(nn.Module): + def __init__(self, config: BertConfig): + super().__init__() + self.config = config + self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size) + self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size) + self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size) + self.layer_norm = nn.LayerNorm(config.hidden_size) + + def forward( + self, + input_ids: Tensor, # [batch_size, seq_length] in [0, vocab_size) + token_type_ids: Optional[Tensor] = None, # [batch_size, seq_length] in [0, type_vocab_size) + position_ids: Optional[Tensor] = None # [batch_size, seq_length] in [0, max_position_embeddings) + ): + batch_size, seq_length = input_ids.shape + + if position_ids is None: + ids = torch.arange(seq_length, dtype=torch.int64).expand((batch_size, -1)) + position_ids = ids + if token_type_ids is None: + token_type_ids = torch.zeros([batch_size, seq_length], dtype=torch.int64) + + input_embeds = self.word_embeddings(input_ids) + token_type_embeddings = self.token_type_embeddings(token_type_ids) + position_embeddings = self.position_embeddings(position_ids) + embeddings = input_embeds + token_type_embeddings + position_embeddings + embeddings = self.layer_norm(embeddings) + return embeddings + + +class BertSelfAttention(nn.Module): + def __init__(self, config: BertConfig): + super().__init__() + if config.hidden_size % config.num_attention_heads != 0: + raise ValueError('Multi head attention expects hidden_size % num_attention_heads == 0, ' + 'got {} and {}'.format(config.hidden_size, config.num_attention_heads)) + self.num_attention_heads = config.num_attention_heads + self.attention_head_size = config.hidden_size // config.num_attention_heads + + self.query_layer = nn.Linear(config.hidden_size, config.hidden_size) + self.key_layer = nn.Linear(config.hidden_size, config.hidden_size) + self.value_layer = nn.Linear(config.hidden_size, config.hidden_size) + + def transpose_for_scores(self, x: Tensor) -> Tensor: + batch_size, seq_length, hidden_size = x.shape + x = x.reshape([batch_size, seq_length, self.num_attention_heads, self.attention_head_size]) + x = x.permute(0, 2, 1, 3) + return x + + def forward(self, hidden_states: Tensor, attention_mask: Tensor): + batch_size, seq_length, hidden_size = hidden_states.shape + query = self.transpose_for_scores(self.query_layer(hidden_states)) + key = self.transpose_for_scores(self.key_layer(hidden_states)) + value = self.transpose_for_scores(self.value_layer(hidden_states)) + attention_scores = torch.matmul(query, torch.transpose(key, -1, -2)) / math.sqrt(self.attention_head_size) + attention_scores = attention_scores + attention_mask + attention_probs = torch.softmax(attention_scores, dim=-1) + context = torch.matmul(attention_probs, value) + context = context.permute(0, 2, 1, 3).reshape([batch_size, seq_length, hidden_size]) + return context + + +class BertSelfAttentionQuery(nn.Module): + def __init__(self, config: BertConfig): + super().__init__() + self.num_attention_heads = config.num_attention_heads + self.attention_head_size = config.hidden_size // config.num_attention_heads + self.query_layer = nn.Linear(config.hidden_size, config.hidden_size) + + def transpose_for_scores(self, x: Tensor) -> Tensor: + batch_size, seq_length, hidden_size = x.shape + x = x.reshape([batch_size, seq_length, self.num_attention_heads, self.attention_head_size]) + x = x.permute(0, 2, 1, 3) + return x + + def forward(self, hidden_states: Tensor): + query = self.transpose_for_scores(self.query_layer(hidden_states)) + return query + + +class BertSelfAttentionQueryKeyValue(nn.Module): + def __init__(self, config: BertConfig): + super().__init__() + self.num_attention_heads = config.num_attention_heads + self.attention_head_size = config.hidden_size // config.num_attention_heads + self.query_layer = nn.Linear(config.hidden_size, config.hidden_size) + self.key_layer = nn.Linear(config.hidden_size, config.hidden_size) + self.value_layer = nn.Linear(config.hidden_size, config.hidden_size) + + def transpose_for_scores(self, x: Tensor) -> Tensor: + batch_size, seq_length, hidden_size = x.shape + x = x.reshape([batch_size, seq_length, self.num_attention_heads, self.attention_head_size]) + x = x.permute(0, 2, 1, 3) + return x + + def forward(self, hidden_states: Tensor): + query = self.transpose_for_scores(self.query_layer(hidden_states)) + key = self.transpose_for_scores(self.key_layer(hidden_states)) + value = self.transpose_for_scores(self.value_layer(hidden_states)) + return [query, key, value] + +class BertSelfAttentionQueryKeyValueV2(nn.Module): + def __init__(self, config: BertConfig): + super().__init__() + self.num_attention_heads = config.num_attention_heads + self.attention_head_size = config.hidden_size // config.num_attention_heads + self.query_layer = nn.Linear(config.hidden_size, config.hidden_size) + self.key_layer = nn.Linear(config.hidden_size, config.hidden_size) + self.value_layer = nn.Linear(config.hidden_size, config.hidden_size) + + def transpose_for_scores(self, x: Tensor) -> Tensor: + # batch_size, seq_length, hidden_size = x.shape + # x = x.reshape([batch_size, seq_length, self.num_attention_heads, self.attention_head_size]) + # x = x.permute(0, 2, 1, 3) + return x + + def forward(self, hidden_states: Tensor): + query = self.transpose_for_scores(self.query_layer(hidden_states)) + key = self.transpose_for_scores(self.key_layer(hidden_states)) + value = self.transpose_for_scores(self.value_layer(hidden_states)) + return [query, key, value] + + +class BertSelfAttentionSoftmax(nn.Module): + def __init__(self, config: BertConfig): + super().__init__() + self.config = config + + def forward(self, attention_scores: Tensor) -> Tensor: + attention_probs = torch.softmax(attention_scores, dim=-1) + return attention_probs + + +class BertSelfAttentionContext(nn.Module): + def forward(self, attention_probs: Tensor, value: Tensor) -> Tensor: + context = torch.matmul(attention_probs, value) + return context + + +class BertSelfOutput(nn.Module): + def __init__(self, config: BertConfig): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.layer_norm = nn.LayerNorm(config.hidden_size) + + def forward(self, hidden_states: Tensor, skip_hidden_states: Tensor) -> Tensor: + hidden_states = self.dense(hidden_states) + hidden_states = self.layer_norm(hidden_states + skip_hidden_states) + return hidden_states + + +class BertAttention(nn.Module): + def __init__(self, config: BertConfig): + super().__init__() + self.self_attention = BertSelfAttention(config) + self.output_layer = BertSelfOutput(config) + + def forward(self, hidden_states: Tensor, attention_mask: Tensor): + attention_output = self.self_attention(hidden_states, attention_mask) + return self.output_layer(attention_output, hidden_states) + + +# well known as FeedForward +class BertIntermediate(nn.Module): + def __init__(self, config: BertConfig): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.intermediate_size) + self.gelu = nn.GELU() + + def forward(self, hidden_states: Tensor) -> Tensor: + hidden_states = self.dense(hidden_states) + hidden_states = self.gelu(hidden_states) + return hidden_states + + +class BertOutput(nn.Module): + def __init__(self, config: BertConfig): + super().__init__() + self.dense = nn.Linear(config.intermediate_size, config.hidden_size) + self.layer_norm = nn.LayerNorm(config.hidden_size) + + def forward(self, hidden_states: Tensor, skip_hidden_states: Tensor) -> Tensor: + hidden_states = self.dense(hidden_states) + hidden_states = self.layer_norm(hidden_states + skip_hidden_states) + return hidden_states + + +class BertLayer(nn.Module): + def __init__(self, config: BertConfig): + super().__init__() + self.attention_layer = BertAttention(config) + self.intermediate_layer = BertIntermediate(config) + self.output_layer = BertOutput(config) + + def forward(self, hidden_states: Tensor, attention_mask: Tensor) -> Tensor: + attention_output = self.attention_layer(hidden_states, attention_mask) + intermediate_output = self.intermediate_layer(attention_output) + layer_output = self.output_layer(intermediate_output, attention_output) + return layer_output + + +class BertEncoder(nn.Module): + def __init__(self, config: BertConfig): + super().__init__() + self.config = config + layers = [] + for _ in range(config.num_hidden_layers): + layers.append(BertLayer(config)) + self.layers = nn.ModuleList(layers) + + def forward(self, hidden_states: Tensor, attention_mask: Tensor): + for i, layer_module in enumerate(self.layers): + hidden_states = layer_module(hidden_states, attention_mask) + return hidden_states + + +class BertPooler(nn.Module): + def __init__(self, config: BertConfig): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.activation = nn.Tanh() + + def forward(self, hidden_states: Tensor): + first_token_tensor = hidden_states[:, 0] + pooled_output = self.dense(first_token_tensor) + pooled_output = self.activation(pooled_output) + return pooled_output + + +class BertModel(nn.Module): + def __init__(self, config: BertConfig): + super().__init__() + self.config = config + self.embeddings = BertEmbeddings(config) + self.encoder = BertEncoder(config) + self.pooler = BertPooler(config) + + @staticmethod + def extend_attention_mask(attention_mask: Tensor) -> Tensor: + # attention_mask: [batch_size, seq_length] in {0, 1} + # return: [batch_size, 1, 1, seq_length] in {-10000, 0} + attention_mask = attention_mask[:, None, None, :] + attention_mask = (1.0 - attention_mask) * -10000.0 + return attention_mask + + def forward(self, input_ids: Tensor, token_type_ids: Tensor, attention_mask: Tensor): + embeds = self.embeddings.forward(input_ids, token_type_ids) + attention_mask = self.extend_attention_mask(attention_mask) + hidden_states = self.encoder.forward(embeds, attention_mask) + pooled_output = self.pooler(hidden_states) + return [hidden_states, pooled_output] + + +def get_bert_block(name: str, batch_size=1, seq_length=128, config: Optional[BertConfig] = None, nocache=False) -> Tuple[str, List[str], List["hidet.Tensor"]]: + if config is None: + config = BertConfig() + hidden_size = config.hidden_size + if name == 'bert_all': + model = BertModel(config) + input_names = [ + 'input_ids', + 'token_type_ids', + 'attention_mask' + ] + inputs = [ + torch.randint(0, config.vocab_size, [batch_size, seq_length], dtype=torch.int64), + torch.zeros([batch_size, seq_length], dtype=torch.int64), + torch.ones([batch_size, seq_length], dtype=torch.int64) + ] + elif name == 'bert_embeddings': + model = BertEmbeddings(config) + input_names = [ + 'input_ids', + 'token_type_ids', + 'position_ids' + ] + inputs = [ + torch.randint(0, config.vocab_size, [batch_size, seq_length], dtype=torch.int64), + torch.zeros([batch_size, seq_length], dtype=torch.int64), + torch.arange(seq_length, dtype=torch.int64).expand(batch_size, seq_length) + ] + elif name == 'bert_encoder': + model = BertEncoder(config) + input_names = [ + 'hidden_states', + 'attention_mask' + ] + inputs = [ + torch.randn([batch_size, seq_length, hidden_size]), + torch.zeros([batch_size, 1, 1, seq_length], dtype=torch.float32) + ] + elif name == 'bert_pooler': + model = BertPooler(config) + input_names = [ + 'hidden_states' + ] + inputs = [ + torch.randn([batch_size, seq_length, hidden_size]), + ] + elif name == 'bert_layer': + model = BertLayer(config) + input_names = [ + 'hidden_states', + 'attention_mask' + ] + inputs = [ + torch.randn([batch_size, seq_length, hidden_size]), + torch.zeros([batch_size, 1, 1, seq_length], dtype=torch.float32) + ] + elif name == 'bert_attention': + model = BertAttention(config) + input_names = [ + 'hidden_states', + 'attention_mask' + ] + inputs = [ + torch.randn([batch_size, seq_length, hidden_size]), + torch.zeros([batch_size, 1, 1, seq_length], dtype=torch.float32) + ] + elif name == 'bert_intermediate': + model = BertIntermediate(config) + input_names = [ + 'hidden_states', + ] + inputs = [ + torch.randn([batch_size, seq_length, hidden_size]), + ] + elif name == 'bert_output': + model = BertOutput(config) + input_names = [ + 'hidden_states', + 'skip_hidden_states' + ] + inputs = [ + torch.randn([batch_size, seq_length, config.intermediate_size]), + torch.randn([batch_size, seq_length, hidden_size]), + ] + elif name == 'bert_self_attention': + model = BertSelfAttention(config) + input_names = [ + 'hidden_states', + 'attention_mask' + ] + inputs = [ + torch.randn([batch_size, seq_length, hidden_size]), + torch.zeros([batch_size, 1, 1, seq_length], dtype=torch.float32) + ] + elif name == 'bert_self_output': + model = BertSelfOutput(config) + input_names = [ + 'hidden_states', + 'skip_hidden_states' + ] + inputs = [ + torch.randn([batch_size, seq_length, hidden_size]), + torch.randn([batch_size, seq_length, hidden_size]), + ] + elif name == 'bert_self_at_query': + model = BertSelfAttentionQuery(config) + input_names = [ + 'hidden_states' + ] + inputs = [ + torch.randn([batch_size, seq_length, hidden_size]) + ] + elif name == 'bert_self_at_qkv': + model = BertSelfAttentionQueryKeyValue(config) + input_names = [ + 'hidden_states' + ] + inputs = [ + torch.randn([batch_size, seq_length, hidden_size]) + ] + elif name == 'bert_self_at_qkv_v2': + model = BertSelfAttentionQueryKeyValueV2(config) + input_names = [ + 'hidden_states' + ] + inputs = [ + torch.randn([batch_size, seq_length, hidden_size]) + ] + elif name == 'bert_self_at_softmax': + model = BertSelfAttentionSoftmax(config) + input_names = [ + 'attention_scores' + ] + inputs = [ + torch.randn([batch_size, config.num_attention_heads, seq_length, seq_length]) + ] + elif name == 'bert_self_at_context': + model = BertSelfAttentionContext() + input_names = [ + 'attention_probs', + 'value' + ] + attention_head_size = config.hidden_size // config.num_attention_heads + inputs = [ + torch.randn([batch_size, config.num_attention_heads, seq_length, seq_length]), + torch.randn([batch_size, config.num_attention_heads, seq_length, attention_head_size]), + ] + else: + raise ValueError() + + onnx_path = hidet_cache_file('onnx', 'bert', f'bs{batch_size}_{name}.onnx') + return export_torch_to_onnx( + onnx_path=onnx_path, + model=model, + input_names=input_names, + inputs=inputs, + nocache=nocache + ) + + +if __name__ == '__main__': + for name in [ + 'bert_all', + 'bert_embeddings', + 'bert_encoder', + 'bert_pooler', + 'bert_layer', + 'bert_attention', + 'bert_intermediate', + 'bert_output', + 'bert_self_attention', + 'bert_self_output', + ]: + print(name) + get_bert_block(name) diff --git a/python/hidet/testing/onnx_models/model_blocks/resnet50_blocks.py b/python/hidet/testing/onnx_models/model_blocks/resnet50_blocks.py new file mode 100644 index 0000000..62f84c3 --- /dev/null +++ b/python/hidet/testing/onnx_models/model_blocks/resnet50_blocks.py @@ -0,0 +1,181 @@ +import os.path +from typing import List, Tuple, Optional +from collections import namedtuple, defaultdict +import torch +import tempfile +import onnx +import torchvision.models +import hidet +from torch import nn +from hidet.utils import hidet_cache_file +from ..utils import export_torch_to_onnx + + +class ConvBnRelu(nn.Module): + def __init__(self, in_channels, out_channels, kernel_size, stride, padding, bias): + super().__init__() + self.conv = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=stride, padding=padding, bias=bias) + self.bn = nn.BatchNorm2d(num_features=out_channels) + self.relu = nn.ReLU() + + def forward(self, x): + x = self.relu(self.bn(self.conv(x))) + return x + + +def conv_bn_relu(batch_size, height, width, in_channels, out_channels, kernel_size, stride, padding, bias=True) -> onnx.ModelProto: + module = ConvBnRelu(in_channels, out_channels, kernel_size, stride, padding, bias) + module.eval() + x = torch.randn([batch_size, in_channels, height, width], dtype=torch.float32) + module(x) + + _, path = tempfile.mkstemp() + + torch.onnx.export(module, + args=x, + f=path, + training=torch.onnx.TrainingMode.PRESERVE, + input_names=['x'], + output_names=['y'], + opset_version=12, + dynamic_axes={ + 'x': {0: 'bs'}, + 'y': {0: 'bs'} + }, + do_constant_folding=False) + onnx.checker.check_model(path) + onnx_model = onnx.load_model(path) + return onnx_model + + +Conv2dConfig = namedtuple('Conv2dConfig', field_names=['batch_size', 'height', 'width', 'in_channels', 'out_channels', 'kernel_size', 'stride', 'padding']) + + +def get_resnet50_configs(batch_size: int = 1) -> List[Conv2dConfig]: + resnet50 = torchvision.models.resnet50() + config_count = defaultdict(int) + + def hook(module: nn.Module, inputs: Tuple[torch.Tensor]): + if isinstance(module, nn.Conv2d): + c = module + x = inputs[0] + w = module.weight + assert isinstance(x, torch.Tensor) + config = Conv2dConfig( + batch_size=x.size(0), height=x.size(2), width=x.size(3), in_channels=x.size(1), + out_channels=w.size(0), kernel_size=(w.size(2), w.size(3)), stride=c.stride, padding=c.padding + ) + config_count[config] += 1 + # print(config) + + def register_hook(module: nn.Module): + module.register_forward_pre_hook(hook) + + resnet50.apply(register_hook) + resnet50(torch.randn(batch_size, 3, 224, 224)) + # for a, b in config_count.items(): + # print(b, a) + # print(a, b) + # as of Python 3.6, the order of dict keys is the insertion order in CPython. + return list(config_count.keys()) + + +def print_implicit_gemm_workloads(configs: List[Conv2dConfig] = None): + if configs is None: + configs = get_resnet50_configs() + for idx, config in enumerate(configs): + n, c, h, w = config.batch_size, config.in_channels, config.height, config.width + oc = config.out_channels + kx, ky = config.kernel_size + px, py = config.padding + sx, sy = config.stride + oh, ow = (h + px * 2 - kx) // sx + 1, (w + py * 2 - ky) // sy + 1 + m_size = n * oh * ow + n_size = oc + k_size = kx * ky * c + print(m_size, n_size, k_size) + + +def conv_bn_relu_onnx_path(idx: int) -> str: + path = hidet.utils.hidet_cache_file('onnx', f'conv_{idx}.onnx') + if not os.path.exists(path): + export_conv_bn_relu() + if not os.path.exists(path): + raise ValueError('failed generate onnx model') + return path + + +def conv_bn_relu_input_shape(bs: int, idx: int) -> List[int]: + shapes = { + 0: [3, 224, 224], + 1: [64, 56, 56], + 2: [64, 56, 56], + 3: [64, 56, 56], + 4: [256, 56, 56], + 5: [256, 56, 56], + 6: [128, 56, 56], + 7: [128, 28, 28], + 8: [256, 56, 56], + 9: [512, 28, 28], + 10: [128, 28, 28], + 11: [512, 28, 28], + 12: [256, 28, 28], + 13: [256, 14, 14], + 14: [512, 28, 28], + 15: [1024, 14, 14], + 16: [256, 14, 14], + 17: [1024, 14, 14], + 18: [512, 14, 14], + 19: [512, 7, 7], + 20: [1024, 14, 14], + 21: [2048, 7, 7], + 22: [512, 7, 7], + } + return [bs] + shapes[idx] + + +def get_resnet50_block(name: str, batch_size=1, nocache=False) -> Tuple[str, List[str], List["hidet.Tensor"]]: + a, b, c = name.split('_') # resnet50_conv_0 to resnet50_conv_22 + conv_idx = int(c) + configs = get_resnet50_configs(batch_size) + config = configs[conv_idx] + x_shape = conv_bn_relu_input_shape(batch_size, conv_idx) + model = ConvBnRelu(in_channels=config.in_channels, out_channels=config.out_channels, kernel_size=config.kernel_size, stride=config.stride, padding=config.padding, bias=True) + + x = torch.randn(x_shape) + return export_torch_to_onnx( + onnx_path=hidet_cache_file('onnx', 'resnet50', f'{name}.onnx'), + model=model, + input_names=['x'], + inputs=[x], + nocache=nocache + ) + + +if __name__ == '__main__': + for name in [ + 'resnet50_conv_0', + 'resnet50_conv_1', + 'resnet50_conv_2', + 'resnet50_conv_3', + 'resnet50_conv_4', + 'resnet50_conv_5', + 'resnet50_conv_6', + 'resnet50_conv_7', + 'resnet50_conv_8', + 'resnet50_conv_9', + 'resnet50_conv_10', + 'resnet50_conv_11', + 'resnet50_conv_12', + 'resnet50_conv_13', + 'resnet50_conv_14', + 'resnet50_conv_15', + 'resnet50_conv_16', + 'resnet50_conv_17', + 'resnet50_conv_18', + 'resnet50_conv_19', + 'resnet50_conv_20', + 'resnet50_conv_21', + 'resnet50_conv_22', + ]: + get_resnet50_block(name) diff --git a/python/hidet/testing/onnx_models/operators.py b/python/hidet/testing/onnx_models/operators.py new file mode 100644 index 0000000..ae65826 --- /dev/null +++ b/python/hidet/testing/onnx_models/operators.py @@ -0,0 +1,93 @@ +from typing import List, Tuple +from .utils import export_torch_to_onnx +import torch +import hidet +from hidet.utils import hidet_cache_file +from torch import nn + + +class ReduceSum(nn.Module): + def __init__(self, dims: List[int], keepdim=True): + super().__init__() + self.dims = dims + self.keepdim = keepdim + + def forward(self, x: torch.Tensor): + return x.sum(self.dims, self.keepdim) + + +class Matmul(nn.Module): + def __init__(self, layout: str): + super().__init__() + assert layout in ['NN', 'NT', 'TN', 'TT'] + self.layout = layout + + def forward(self, x: torch.Tensor, y: torch.Tensor): + if self.layout[0] == 'T': + x = torch.transpose(x, -1, -2) + if self.layout[1] == 'T': + y = torch.transpose(y, -1, -2) + return torch.matmul(x, y) + + +def get_onnx_operator(name: str, batch_size=1) -> Tuple[str, List[str], List["hidet.Tensor"]]: + onnx_path = hidet_cache_file('onnx', 'op', f'{name}.onnx') + if name.startswith('op_sum_'): + a, b, c = name.split('_') # op_sum_0 + op_idx = int(c) + idx_2_configs = { + 0: [[batch_size, 8, 128, 768], [1], False], + 1: [[batch_size, 8, 128, 768], [3], False], + } + shape, dims, keepdim = idx_2_configs[op_idx] + return export_torch_to_onnx( + onnx_path=onnx_path, + model=ReduceSum(dims=dims, keepdim=keepdim), + input_names=['x'], + inputs=[torch.randn(shape)], + ) + elif name.startswith('op_resnet50_conv'): + a, b, c, d = name.split('_') + op_idx = int(d) + idx_2_configs = { + 2: [[batch_size, 256, 28, 28], 256, 3, 1, 2], + } + x_shape, out_channels, kernel, padding, strides = idx_2_configs[op_idx] + return export_torch_to_onnx( + onnx_path=onnx_path, + model=nn.Conv2d(in_channels=x_shape[1], out_channels=out_channels, kernel_size=kernel, stride=strides, padding=padding, bias=False), + input_names=['x'], + inputs=[torch.randn(x_shape)] + ) + elif name.startswith('op_matmul_'): # like 'op_matmul_nn_0' + a, b, layout, idx = name.split('_') + layout = str(layout).upper() + workloads = { + 0: [batch_size, 128, 128, 64], + 1: [batch_size, 128, 768, 2304], + 2: [batch_size, 128, 768, 2304], + 3: [batch_size, 128, 768, 2304], + 4: [batch_size, 2048, 2048, 2048], + 5: [batch_size, 2039, 2039, 2039], + 6: [batch_size, 2047, 2047, 2047], + 7: [batch_size, 2046, 2046, 2046], + 8: [batch_size, 2045, 2045, 2045], + 9: [batch_size, 2044, 2044, 2044], + 10: [batch_size, 2043, 2043, 2043], + 11: [batch_size, 2042, 2042, 2042], + } + batch_size, m_size, n_size, k_size = workloads[int(idx)] + x = torch.randn([batch_size, m_size, k_size]) + y = torch.randn([batch_size, k_size, n_size]) + if layout[0] == 'T': + x = torch.transpose(x, -1, -2) + if layout[1] == 'T': + y = torch.transpose(y, -1, -2) + return export_torch_to_onnx( + onnx_path=onnx_path, + model=Matmul(layout), + input_names=['x', 'y'], + inputs=[x, y], + ) + else: + raise ValueError('') diff --git a/python/hidet/testing/onnx_models/utils.py b/python/hidet/testing/onnx_models/utils.py new file mode 100644 index 0000000..b768acf --- /dev/null +++ b/python/hidet/testing/onnx_models/utils.py @@ -0,0 +1,37 @@ +from typing import List +import tempfile +import os +import torch +import onnx +import hidet +from torch import nn + + +def export_torch_to_onnx( + onnx_path: str, + model: nn.Module, + input_names: List[str], + inputs: List[torch.Tensor], + nocache=False +): + # onnx_path = hidet_cache_file('onnx', 'bert', f'{name}.onnx') + if nocache and os.path.exists(onnx_path): + os.remove(onnx_path) + if not os.path.exists(onnx_path): + model.eval() + model(*inputs) + _, path = tempfile.mkstemp() + torch.onnx.export(model, + args=tuple(inputs), + f=path, + training=torch.onnx.TrainingMode.PRESERVE, + input_names=input_names, + opset_version=12, + do_constant_folding=True) + dirname = os.path.dirname(onnx_path) + os.makedirs(dirname, exist_ok=True) + os.rename(path, onnx_path) + onnx.checker.check_model(onnx_path) + + hidet_inputs = [hidet.array(torch_tensor.numpy()).cuda() for torch_tensor in inputs] + return onnx_path, input_names, hidet_inputs diff --git a/python/hidet/testing/torch_models/__init__.py b/python/hidet/testing/torch_models/__init__.py new file mode 100644 index 0000000..8b13789 --- /dev/null +++ b/python/hidet/testing/torch_models/__init__.py @@ -0,0 +1 @@ + diff --git a/python/hidet/testing/torch_models/all.py b/python/hidet/testing/torch_models/all.py new file mode 100644 index 0000000..9a3a8e6 --- /dev/null +++ b/python/hidet/testing/torch_models/all.py @@ -0,0 +1,67 @@ +from typing import Tuple, List, Dict +import torch +import torchvision +import transformers +from torch import nn + + +def get_torch_model(name: str, batch_size: int = 1, **kwargs) -> Tuple[nn.Module, Dict[str, torch.Tensor]]: + if name == 'resnet50': + model = torchvision.models.resnet50(pretrained=True).eval().cuda() + inputs = { + 'x': torch.randn([batch_size, 3, 224, 224]).cuda() + } + return model, inputs + elif name == 'inception_v3': + model = torchvision.models.inception_v3(pretrained=True).eval().cuda() + model.eval() + inputs = { + 'x': torch.randn([batch_size, 3, 299, 299]).cuda() + } + return model, inputs + elif name == 'mobilenet_v2': + model = torchvision.models.mobilenet_v2(pretrained=True).eval().cuda() + inputs = { + 'x': torch.randn([batch_size, 3, 224, 224]).cuda() + } + return model, inputs + elif name == 'bert': + config = transformers.BertConfig() + model = transformers.BertModel(config).eval().cuda() + model.eval() + vocab_size = 30522 + seq_length = kwargs.get('seq_length', 128) + inputs = { + 'input_ids': torch.randint(0, vocab_size - 1, size=[batch_size, seq_length]).cuda(), + 'attention_mask': torch.ones(size=[batch_size, seq_length], dtype=torch.int64).cuda(), + 'token_type_ids': torch.zeros(size=[batch_size, seq_length], dtype=torch.int64).cuda() + } + return model, inputs + elif name == 'gpt2': + config = transformers.GPT2Config() + model = transformers.GPT2Model(config).eval().cuda() + model.eval() + vocab_size = 50257 + seq_length = kwargs.get('seq_length', 128) + inputs = { + 'input_ids': torch.randint(0, vocab_size - 1, size=[batch_size, seq_length]).cuda(), + 'attention_mask': torch.ones(size=[batch_size, seq_length], dtype=torch.int64).cuda(), + } + return model, inputs + else: + raise ValueError('Can not recognize model: {}'.format(name)) + + +if __name__ == '__main__': + from time import time + for name in ['resnet50', 'inception_v3', 'mobilenet_v2', 'bert', 'gpt2']: + model, inputs = get_torch_model(name) + outputs = model(**inputs) + repeats = 10 + torch.cuda.synchronize() + t1 = time() + for t in range(repeats): + outputs = model(**inputs) + torch.cuda.synchronize() + t2 = time() + print('{} {:.1f}'.format(name, (t2 - t1) / repeats * 1000.0)) diff --git a/python/hidet/testing/tos_models/__init__.py b/python/hidet/testing/tos_models/__init__.py new file mode 100644 index 0000000..ef895c1 --- /dev/null +++ b/python/hidet/testing/tos_models/__init__.py @@ -0,0 +1,7 @@ +from . import bert +from . import resnet +from . import inception + +from .resnet import Bottleneck, BasicBlock, ResNet +from .inception import InceptionHead, InceptionA, InceptionB, InceptionC, InceptionD, InceptionE, InceptionTail, InceptionV3 +from .bert import BertModel diff --git a/python/hidet/testing/tos_models/bert.py b/python/hidet/testing/tos_models/bert.py new file mode 100644 index 0000000..47bacdf --- /dev/null +++ b/python/hidet/testing/tos_models/bert.py @@ -0,0 +1,208 @@ +import transformers +from typing import Optional +import math +import numpy as np +from hidet.tos import randn, zeros, array, ones +from hidet.tos import nn, Tensor +from hidet.tos import ops + + +# Acknowledgement: adopted the bert implementation from huggingface transformers package, with some simplification + +class BertConfig: + def __init__(self): + self.vocab_size = 30522 + self.hidden_size = 768 + self.num_hidden_layers = 12 + self.num_attention_heads = 12 + self.max_position_embeddings = 512 + self.intermediate_size = 3072 + self.type_vocab_size = 2 + + +class BertEmbeddings(nn.Module): + def __init__(self, config: BertConfig): + super().__init__() + self.config = config + self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size) + self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size) + self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size) + self.layer_norm = nn.LayerNorm(config.hidden_size) + + def forward( + self, + input_ids: Tensor, # [batch_size, seq_length] in [0, vocab_size) + token_type_ids: Optional[Tensor] = None, # [batch_size, seq_length] in [0, type_vocab_size) + position_ids: Optional[Tensor] = None # [batch_size, seq_length] in [0, max_position_embeddings) + ): + batch_size, seq_length = input_ids.shape + + if position_ids is None: + ids = array(np.arange(seq_length).astype(np.int64)) + ids = ops.tile(ops.unsqueeze(ids, [0]), repeats=[batch_size, 1]) + position_ids = ids + if token_type_ids is None: + token_type_ids = zeros([batch_size, seq_length], dtype='int64') + + input_embeds = self.word_embeddings(input_ids) + token_type_embeddings = self.token_type_embeddings(token_type_ids) + position_embeddings = self.position_embeddings(position_ids) + embeddings = input_embeds + token_type_embeddings + position_embeddings + embeddings = self.layer_norm(embeddings) + return embeddings + + +class BertSelfAttention(nn.Module): + def __init__(self, config: BertConfig): + super().__init__() + if config.hidden_size % config.num_attention_heads != 0: + raise ValueError('Multi head attention expects hidden_size % num_attention_heads == 0, ' + 'got {} and {}'.format(config.hidden_size, config.num_attention_heads)) + self.num_attention_heads = config.num_attention_heads + self.attention_head_size = config.hidden_size // config.num_attention_heads + + self.query_layer = nn.Linear(config.hidden_size, config.hidden_size) + self.key_layer = nn.Linear(config.hidden_size, config.hidden_size) + self.value_layer = nn.Linear(config.hidden_size, config.hidden_size) + + def transpose_for_scores(self, x: Tensor) -> Tensor: + batch_size, seq_length, hidden_size = x.shape + x = x.reshape([batch_size, seq_length, self.num_attention_heads, self.attention_head_size]) + x = x.rearrange([[0, 2], [1], [3]]) + return x + + def forward(self, hidden_states: Tensor, attention_mask: Tensor): + batch_size, seq_length, hidden_size = hidden_states.shape + query = self.transpose_for_scores(self.query_layer(hidden_states)) + key = self.transpose_for_scores(self.key_layer(hidden_states)) + value = self.transpose_for_scores(self.value_layer(hidden_states)) + attention_scores = ops.matmul(query, key.transpose([-1, -2])) / math.sqrt(self.attention_head_size) + attention_scores = attention_scores + attention_mask + attention_probs = ops.softmax(attention_scores, axis=-1) + context = ops.matmul(attention_probs, value) + context = context.reshape([batch_size, self.num_attention_heads, seq_length, self.attention_head_size]) + context = context.rearrange([[0], [2], [1, 3]]) + return context + + +class BertSelfOutput(nn.Module): + def __init__(self, config: BertConfig): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.layer_norm = nn.LayerNorm(config.hidden_size) + + def forward(self, hidden_states: Tensor, skip_hidden_states: Tensor) -> Tensor: + hidden_states = self.dense(hidden_states) + hidden_states = self.layer_norm(hidden_states + skip_hidden_states) + return hidden_states + + +class BertAttention(nn.Module): + def __init__(self, config: BertConfig): + super().__init__() + self.self_attention = BertSelfAttention(config) + self.output_layer = BertSelfOutput(config) + + def forward(self, hidden_states: Tensor, attention_mask: Tensor): + attention_output = self.self_attention(hidden_states, attention_mask) + return self.output_layer(attention_output, hidden_states) + + +# well known as FeedForward +class BertIntermediate(nn.Module): + def __init__(self, config: BertConfig): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.intermediate_size) + self.gelu = nn.Gelu() + + def forward(self, hidden_states: Tensor) -> Tensor: + hidden_states = self.dense(hidden_states) + hidden_states = self.gelu(hidden_states) + return hidden_states + + +class BertOutput(nn.Module): + def __init__(self, config: BertConfig): + super().__init__() + self.dense = nn.Linear(config.intermediate_size, config.hidden_size) + self.layer_norm = nn.LayerNorm(config.hidden_size) + + def forward(self, hidden_states: Tensor, skip_hidden_states: Tensor) -> Tensor: + hidden_states = self.dense(hidden_states) + hidden_states = self.layer_norm(hidden_states + skip_hidden_states) + return hidden_states + + +class BertLayer(nn.Module): + def __init__(self, config: BertConfig): + super().__init__() + self.attention_layer = BertAttention(config) + self.intermediate_layer = BertIntermediate(config) + self.output_layer = BertOutput(config) + + def forward(self, hidden_states: Tensor, attention_mask: Tensor) -> Tensor: + attention_output = self.attention_layer(hidden_states, attention_mask) + intermediate_output = self.intermediate_layer(attention_output) + layer_output = self.output_layer(intermediate_output, attention_output) + return layer_output + + +class BertEncoder(nn.Module): + def __init__(self, config: BertConfig): + super().__init__() + self.config = config + layers = [] + for _ in range(config.num_hidden_layers): + layers.append(BertLayer(config)) + self.layers = nn.ModuleList(layers) + + def forward(self, hidden_states: Tensor, attention_mask: Tensor): + for i, layer_module in enumerate(self.layers.submodules.values()): + hidden_states = layer_module(hidden_states, attention_mask) + return hidden_states + + +class BertPooler(nn.Module): + def __init__(self, config: BertConfig): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.activation = nn.Tanh() + + def forward(self, hidden_states: Tensor): + first_token_tensor = hidden_states[:, 0] + pooled_output = self.dense(first_token_tensor) + pooled_output = self.activation(pooled_output) + return pooled_output + + +class BertModel(nn.Module): + def __init__(self, config: BertConfig): + super().__init__() + self.config = config + self.embeddings = BertEmbeddings(config) + self.encoder = BertEncoder(config) + self.pooler = BertPooler(config) + + def extend_attention_mask(self, attention_mask: Tensor) -> Tensor: + # attention_mask: [batch_size, seq_length] in {0, 1} + # return: [batch_size * num_attention_heads, 1, seq_length] in {-10000, 0} + attention_mask = attention_mask.unsqueeze([1]) + attention_mask = ops.tile(attention_mask, repeats=[self.config.num_attention_heads, 1, 1]) + attention_mask = (1.0 - attention_mask) * -10000.0 + return attention_mask + + def forward(self, input_ids: Tensor, token_type_ids: Tensor, attention_mask: Tensor): + embeds = self.embeddings.forward(input_ids, token_type_ids) + attention_mask = self.extend_attention_mask(attention_mask) + hidden_states = self.encoder.forward(embeds, attention_mask) + pooled_output = self.pooler(hidden_states) + return [hidden_states, pooled_output] + + +def bert(batch_size=1, seq_length=128): + config = BertConfig() + model = BertModel(config) + input_ids = array(np.random.randint(0, config.vocab_size, size=[batch_size, seq_length])) + token_type_ids = zeros([batch_size, seq_length], dtype='int64') + attention_mask = ones([batch_size, seq_length], dtype='int64') + return model, [input_ids, token_type_ids, attention_mask] diff --git a/python/hidet/testing/tos_models/inception.py b/python/hidet/testing/tos_models/inception.py new file mode 100644 index 0000000..3f6dc4d --- /dev/null +++ b/python/hidet/testing/tos_models/inception.py @@ -0,0 +1,283 @@ +from typing import Tuple, List, Union +import hidet +from hidet.tos import nn, ops, Tensor + +# import torchvision.models.inception + +# Acknowledgement: the model definitions are adopted from torchvision.models.inception + +Ints = Union[int, List[int], Tuple[int]] + + +class BasicConv2d(nn.Module): + def __init__(self, in_channels: int, out_channels: int, kernel_size: Ints, padding: Ints = 0, stride: Ints = 1, groups: int = 1) -> None: + super().__init__() + self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, padding, stride, groups) + self.bn = nn.BatchNorm2d(out_channels, eps=0.001) + + def forward(self, x: Tensor) -> Tensor: + x = self.conv(x) + x = self.bn(x) + return ops.relu(x) + + +class InceptionHead(nn.Module): + def __init__(self): + super().__init__() + self.conv2d_1a = BasicConv2d(3, 32, kernel_size=3, stride=2) + self.conv2d_2a = BasicConv2d(32, 32, kernel_size=3) + self.conv2d_2b = BasicConv2d(32, 64, kernel_size=3, padding=1) + self.max_pool_1 = nn.MaxPool2d(kernel_size=3, stride=2) + self.conv2d_3b = BasicConv2d(64, 80, kernel_size=1) + self.conv2d_4a = BasicConv2d(80, 192, kernel_size=3) + self.max_pool_2 = nn.MaxPool2d(kernel_size=3, stride=2) + + def forward(self, x): + x = self.conv2d_1a(x) + x = self.conv2d_2a(x) + x = self.conv2d_2b(x) + x = self.max_pool_1(x) + x = self.conv2d_3b(x) + x = self.conv2d_4a(x) + x = self.max_pool_2(x) + return x + + +class InceptionA(nn.Module): + def __init__(self, in_channels, pool_features: int): + super().__init__() + self.branch1x1 = BasicConv2d(in_channels, 64, kernel_size=1) + + self.branch5x5_1 = BasicConv2d(in_channels, 48, kernel_size=1) + self.branch5x5_2 = BasicConv2d(48, 64, kernel_size=5, padding=2) + + self.branch3x3_1 = BasicConv2d(in_channels, 64, kernel_size=1) + self.branch3x3_2 = BasicConv2d(64, 96, kernel_size=3, padding=1) + self.branch3x3_3 = BasicConv2d(96, 96, kernel_size=3, padding=1) + + self.branch_pool = BasicConv2d(in_channels, pool_features, kernel_size=1) + + def forward(self, x: Tensor) -> Tensor: + branch1x1 = self.branch1x1(x) + branch5x5 = self.branch5x5_2(self.branch5x5_1(x)) + branch3x3 = self.branch3x3_3(self.branch3x3_2(self.branch3x3_1(x))) + branch_pool = self.branch_pool(ops.avg_pool2d(x, kernel=3, stride=1, padding=1)) + outputs = [branch1x1, branch5x5, branch3x3, branch_pool] + return ops.concat(outputs, axis=1) + + +class InceptionB(nn.Module): + def __init__(self, in_channels: int): + super().__init__() + self.branch3x3 = BasicConv2d(in_channels, 384, kernel_size=3, stride=2) + self.branch3x3_p1 = BasicConv2d(in_channels, 64, kernel_size=1) + self.branch3x3_p2 = BasicConv2d(64, 96, kernel_size=3, padding=1) + self.branch3x3_p3 = BasicConv2d(96, 96, kernel_size=3, stride=2) + + def forward(self, x): + branch3x3 = self.branch3x3(x) + + branch3x3_p = self.branch3x3_p1(x) + branch3x3_p = self.branch3x3_p2(branch3x3_p) + branch3x3_p = self.branch3x3_p3(branch3x3_p) + + branch_pool = ops.max_pool2d(x, kernel=3, stride=2, padding=0) + + return ops.concat([branch3x3, branch3x3_p, branch_pool], axis=1) + + +class InceptionC(nn.Module): + def __init__(self, in_channels: int, channels_7x7: int): + super().__init__() + self.branch1x1 = BasicConv2d(in_channels, 192, kernel_size=1) + + c7 = channels_7x7 + self.branch7x7_1 = BasicConv2d(in_channels, c7, kernel_size=1) + self.branch7x7_2 = BasicConv2d(c7, c7, kernel_size=[1, 7], padding=[0, 3]) + self.branch7x7_3 = BasicConv2d(c7, 192, kernel_size=[7, 1], padding=[3, 0]) + + self.branch7x7_dbl_1 = BasicConv2d(in_channels, c7, kernel_size=1) + self.branch7x7_dbl_2 = BasicConv2d(c7, c7, kernel_size=[7, 1], padding=[3, 0]) + self.branch7x7_dbl_3 = BasicConv2d(c7, c7, kernel_size=[1, 7], padding=[0, 3]) + self.branch7x7_dbl_4 = BasicConv2d(c7, c7, kernel_size=[7, 1], padding=[3, 0]) + self.branch7x7_dbl_5 = BasicConv2d(c7, 192, kernel_size=[1, 7], padding=[0, 3]) + + self.branch_pool = BasicConv2d(in_channels, 192, kernel_size=1) + + def forward(self, x): + branch1x1 = self.branch1x1(x) + + branch7x7 = self.branch7x7_1(x) + branch7x7 = self.branch7x7_2(branch7x7) + branch7x7 = self.branch7x7_3(branch7x7) + + branch7x7_dbl = self.branch7x7_dbl_1(x) + branch7x7_dbl = self.branch7x7_dbl_2(branch7x7_dbl) + branch7x7_dbl = self.branch7x7_dbl_3(branch7x7_dbl) + branch7x7_dbl = self.branch7x7_dbl_4(branch7x7_dbl) + branch7x7_dbl = self.branch7x7_dbl_5(branch7x7_dbl) + + branch_pool = ops.avg_pool2d(x, kernel=3, stride=1, padding=1) + branch_pool = self.branch_pool(branch_pool) + + return ops.concat([ + branch1x1, + branch7x7, + branch7x7_dbl, + branch_pool + ], axis=1) + + +class InceptionD(nn.Module): + def __init__(self, in_channels: int): + super().__init__() + self.branch3x3_1 = BasicConv2d(in_channels, 192, kernel_size=1) + self.branch3x3_2 = BasicConv2d(192, 320, kernel_size=3, stride=2) + + self.branch7x7x3_1 = BasicConv2d(in_channels, 192, kernel_size=1) + self.branch7x7x3_2 = BasicConv2d(192, 192, kernel_size=[1, 7], padding=[0, 3]) + self.branch7x7x3_3 = BasicConv2d(192, 192, kernel_size=[7, 1], padding=[3, 0]) + self.branch7x7x3_4 = BasicConv2d(192, 192, kernel_size=3, stride=2) + + def forward(self, x): + branch3x3 = self.branch3x3_1(x) + branch3x3 = self.branch3x3_2(branch3x3) + + branch7x7x3 = self.branch7x7x3_1(x) + branch7x7x3 = self.branch7x7x3_2(branch7x7x3) + branch7x7x3 = self.branch7x7x3_3(branch7x7x3) + branch7x7x3 = self.branch7x7x3_4(branch7x7x3) + + branch_pool = ops.max_pool2d(x, kernel=3, stride=2, padding=0) + return ops.concat([branch3x3, branch7x7x3, branch_pool], axis=1) + + +class InceptionE(nn.Module): + def __init__(self, in_channels: int): + super().__init__() + self.branch1x1 = BasicConv2d(in_channels, 320, kernel_size=1) + self.branch3x3_1 = BasicConv2d(in_channels, 384, kernel_size=1) + self.branch3x3_2a = BasicConv2d(384, 384, kernel_size=[1, 3], padding=[0, 1]) + self.branch3x3_2b = BasicConv2d(384, 384, kernel_size=[3, 1], padding=[1, 0]) + + self.branch3x3_dbl_1 = BasicConv2d(in_channels, 448, kernel_size=1) + self.branch3x3_dbl_2 = BasicConv2d(448, 384, kernel_size=3, padding=1) + self.branch3x3_dbl_3a = BasicConv2d(384, 384, kernel_size=[1, 3], padding=[0, 1]) + self.branch3x3_dbl_3b = BasicConv2d(384, 384, kernel_size=[3, 1], padding=[1, 0]) + + self.branch_pool = BasicConv2d(in_channels, 192, kernel_size=1) + + def forward(self, x): + branch1x1 = self.branch1x1(x) + + branch3x3 = self.branch3x3_1(x) + branch3x3 = [ + self.branch3x3_2a(branch3x3), + self.branch3x3_2b(branch3x3) + ] + branch3x3 = ops.concat(branch3x3, axis=1) + + branch3x3_dbl = self.branch3x3_dbl_1(x) + branch3x3_dbl = self.branch3x3_dbl_2(branch3x3_dbl) + branch3x3_dbl = [ + self.branch3x3_dbl_3a(branch3x3_dbl), + self.branch3x3_dbl_3b(branch3x3_dbl) + ] + branch3x3_dbl = ops.concat(branch3x3_dbl, axis=1) + + branch_pool = ops.avg_pool2d(x, kernel=3, stride=1, padding=1) + branch_pool = self.branch_pool(branch_pool) + + return ops.concat([branch1x1, branch3x3, branch3x3_dbl, branch_pool], axis=1) + + +class InceptionTail(nn.Module): + def __init__(self, num_classes=1000): + super().__init__() + self.fc = nn.Linear(2048, num_classes) + + def forward(self, x: Tensor): + x = ops.avg_pool2d(x, kernel=x.shape[2:], stride=1, padding=0) + x = ops.squeeze(x, dims=[2, 3]) + x = self.fc(x) + return x + + +class InceptionV3(nn.Module): + def __init__(self): + super().__init__() + self.blocks = nn.Sequential( + InceptionHead(), + InceptionA(in_channels=192, pool_features=32), + InceptionA(in_channels=256, pool_features=64), + InceptionA(in_channels=288, pool_features=64), + InceptionB(in_channels=288), + InceptionC(in_channels=768, channels_7x7=128), + InceptionC(in_channels=768, channels_7x7=160), + InceptionC(in_channels=768, channels_7x7=160), + InceptionC(in_channels=768, channels_7x7=192), + InceptionD(in_channels=768), + InceptionE(in_channels=1280), + InceptionE(in_channels=2048), + InceptionTail() + ) + + def forward(self, x): + return self.blocks(x) + + +def basic_conv2d(batch_size, in_channels, height, width, out_channels, kernel_size, padding, stride, groups) -> Tuple[nn.Module, List[Tensor]]: + inputs = [hidet.randn([batch_size, in_channels, height, width])] + model = BasicConv2d(in_channels, out_channels, kernel_size, padding, stride, groups) + return model, inputs + + +def inception_head(batch_size=1): + inputs = [hidet.randn([batch_size, 3, 299, 299])] + model = InceptionHead() + return model, inputs + + +def inception_a(in_channels: int = 192, pool_features: int = 32, batch_size=1): + assert (in_channels, pool_features) in [(192, 32), (256, 64), (288, 64)] + inputs = [hidet.randn([batch_size, in_channels, 35, 35])] + model = InceptionA(in_channels, pool_features) + return model, inputs + + +def inception_b(batch_size=1): + inputs = [hidet.randn([batch_size, 288, 35, 35])] + model = InceptionB(in_channels=288) + return model, inputs + + +def inception_c(in_channels=768, channels_7x7=128, batch_size=1): + assert (in_channels, channels_7x7) in [(768, 128), (768, 160), (768, 160), (768, 192)] + inputs = [hidet.randn([batch_size, in_channels, 17, 17])] + model = InceptionC(in_channels, channels_7x7=channels_7x7) + return model, inputs + + +def inception_d(in_channels=768, batch_size=1): + inputs = [hidet.randn([batch_size, in_channels, 17, 17])] + model = InceptionD(in_channels) + return model, inputs + + +def inception_e(in_channels=1280, batch_size=1): + assert in_channels in [1280, 2048] + inputs = [hidet.randn([batch_size, in_channels, 8, 8])] + model = InceptionE(in_channels) + return model, inputs + + +def inception_tail(batch_size=1): + inputs = [hidet.randn([batch_size, 2048, 8, 8])] + model = InceptionTail() + return model, inputs + + +def inception_v3(batch_size=1): + inputs = [hidet.randn([batch_size, 3, 299, 299])] + model = InceptionV3() + return model, inputs diff --git a/python/hidet/testing/tos_models/resnet.py b/python/hidet/testing/tos_models/resnet.py new file mode 100644 index 0000000..b8735b7 --- /dev/null +++ b/python/hidet/testing/tos_models/resnet.py @@ -0,0 +1,136 @@ +from typing import Type, Union, List, Callable, Any + +import hidet.utils +from hidet.tos.modules import nn + + +def conv1x1(in_channels, out_channels, stride=1) -> nn.Conv2d: + return nn.Conv2d(in_channels, out_channels, kernel_size=1, padding=0, stride=stride) + + +def conv3x3(in_channels, out_channels, stride=1) -> nn.Conv2d: + return nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, stride=stride) + + +class BasicBlock(nn.Module): + expansion = 1 + + def __init__(self, + in_channels: int, + channels: int, + stride: int = 1 + ): + super().__init__() + self.conv1 = conv3x3(in_channels, channels, stride) + self.bn1 = nn.BatchNorm2d(channels) + self.conv2 = conv3x3(channels, channels) + self.bn2 = nn.BatchNorm2d(channels) + self.relu = nn.Relu() + if in_channels != channels * self.expansion or stride != 1: + self.skip = nn.Sequential( + conv1x1(in_channels=in_channels, out_channels=channels * self.expansion, stride=stride), + nn.BatchNorm2d(channels * self.expansion) + ) + else: + self.skip = (lambda x: x) + + def forward(self, x): + out = self.bn2(self.conv2(self.relu(self.bn1(self.conv1(x))))) + out = self.relu(out + self.skip(x)) + return out + + +class Bottleneck(nn.Module): + expansion = 4 + + def __init__(self, + in_channels: int, + channels: int, + stride: int = 1 + ): + super().__init__() + expansion = 4 + self.conv1 = conv1x1(in_channels, channels) + self.bn1 = nn.BatchNorm2d(channels) + self.conv2 = conv3x3(channels, channels, stride) + self.bn2 = nn.BatchNorm2d(channels) + self.conv3 = conv1x1(channels, channels * expansion) + self.bn3 = nn.BatchNorm2d(channels * expansion) + self.relu = nn.Relu() + if in_channels != channels * expansion or stride != 1: + self.skip = nn.Sequential( + conv1x1(in_channels=in_channels, out_channels=channels * self.expansion, stride=stride), + nn.BatchNorm2d(channels * self.expansion) + ) + else: + self.skip = (lambda x: x) + + def forward(self, x): + out = self.relu(self.bn1(self.conv1(x))) + out = self.relu(self.bn2(self.conv2(out))) + out = self.relu(self.bn3(self.conv3(out)) + self.skip(x)) + return out + + +class ResNet(nn.Module): + def __init__(self, + block: Type[Union[BasicBlock, Bottleneck]], + layers: List[int], + num_classes: int = 1000): + super().__init__() + self.in_channels = 64 + self.conv1 = nn.Conv2d(3, self.in_channels, kernel_size=7, stride=2, padding=3) + self.bn1 = nn.BatchNorm2d(self.in_channels) + self.relu = nn.Relu() + self.max_pool = nn.MaxPool2d(kernel_size=7, stride=2, padding=3) + self.layer1 = self.make_layer(block, 64, layers[0]) + self.layer2 = self.make_layer(block, 128, layers[1], stride=2) + self.layer3 = self.make_layer(block, 256, layers[2], stride=2) + self.layer4 = self.make_layer(block, 512, layers[3], stride=2) + self.avg_pool = nn.AdaptiveAvgPool2d(output_size=(1, 1)) + self.fc = nn.Linear(512 * block.expansion, num_classes) + + def make_layer(self, + block: Type[Union[BasicBlock, Bottleneck]], + channels: int, + blocks: int, + stride: int = 1): + layers = [] + for i in range(blocks): + if i == 0: + layers.append(block(self.in_channels, channels, stride)) + self.in_channels = channels * block.expansion + else: + layers.append(block(self.in_channels, channels, stride=1)) + return nn.Sequential(*layers) + + def forward(self, x): + x = self.max_pool(self.relu(self.bn1(self.conv1(x)))) + x = self.layer1(x) + x = self.layer2(x) + x = self.layer3(x) + x = self.layer4(x) + x = self.avg_pool(x) + x = x.squeeze(dims=(2, 3)) + x = self.fc(x) + return x + + +def resnet18(): + return ResNet(block=BasicBlock, layers=[2, 2, 2, 2]) + + +def resnet34(): + return ResNet(block=BasicBlock, layers=[3, 4, 6, 3]) + + +def resnet50(): + return ResNet(block=Bottleneck, layers=[3, 4, 6, 3]) + + +def resnet101(): + return ResNet(block=Bottleneck, layers=[3, 4, 23, 3]) + + +def resnet152(): + return ResNet(block=Bottleneck, layers=[3, 8, 36, 3]) diff --git a/python/hidet/tos/__init__.py b/python/hidet/tos/__init__.py new file mode 100644 index 0000000..1bf23d8 --- /dev/null +++ b/python/hidet/tos/__init__.py @@ -0,0 +1,21 @@ +from . import tensor +from . import operator +from . import module +from . import modules +from . import ops +from . import ir +from . import frontend + +from .tensor import Tensor +from .operator import Operator +from .module import Module +from .ir import FlowGraph +from .transforms import GraphPass, PassContext + +from .tensor import array, randn, empty, zeros, ones, symbol, randn_like, empty_like, zeros_like, ones_like, symbol_like +from .tensor import full, full_like +from .operator import space_level, get_space_level +from .ir import trace_from, load_graph, save_graph +from .transforms import optimize +from .modules import nn +from .jit import jit diff --git a/python/hidet/tos/common.py b/python/hidet/tos/common.py new file mode 100644 index 0000000..4d17932 --- /dev/null +++ b/python/hidet/tos/common.py @@ -0,0 +1,5 @@ +def normalize(v, num=2): + if isinstance(v, (list, tuple)): + return v + else: + return [v for _ in range(num)] diff --git a/python/hidet/tos/frontend/__init__.py b/python/hidet/tos/frontend/__init__.py new file mode 100644 index 0000000..49cc982 --- /dev/null +++ b/python/hidet/tos/frontend/__init__.py @@ -0,0 +1,4 @@ +from . import onnx_utils +from . import torch_utils + +from .onnx_utils import from_onnx diff --git a/python/hidet/tos/frontend/onnx_utils.py b/python/hidet/tos/frontend/onnx_utils.py new file mode 100644 index 0000000..d6772ef --- /dev/null +++ b/python/hidet/tos/frontend/onnx_utils.py @@ -0,0 +1,844 @@ +from typing import List, Union, Sequence, Optional, Dict, Callable +from collections import defaultdict +import os +import numpy as np +import hidet +from hidet.tos.modules import nn +from hidet.tos import ops +from hidet.tos.tensor import Tensor, from_numpy, randn +from hidet.utils import line_profile, prod + +""" +Please refers to https://github.com/onnx/onnx/blob/main/docs/Operators.md for operator definition when adding new operators. +Please refers to https://github.com/onnx/onnx/blob/main/onnx/onnx.proto for proto structure of onnx format. +""" + + +class OnnxOperator: + def __init__(self, node, opset: int = 11): + """ + Parameters + ---------- + node: onnx.NodeProto + """ + import onnx.numpy_helper + self.node = node + self.opset = opset + self.input_names = [name for name in node.input] + self.output_names = [name for name in node.output] + self.attrs = {} + for attr in node.attribute: + if attr.type == 1: # float + v = attr.f + elif attr.type == 2: # int + v = attr.i + elif attr.type == 3: # string + v = attr.s.decode('utf-8') + elif attr.type == 4: # tensor + v = from_numpy(onnx.numpy_helper.to_array(tensor=attr.t)).cuda() + elif attr.type == 6: # floats + v = list(attr.floats) + elif attr.type == 7: # ints + v = list(attr.ints) + elif attr.type == 8: # strings + v = [s.decode('utf-8') for s in attr.strings] + else: + raise ValueError('Can not recognize type id {} of attribute {}'.format(attr.type, attr.name)) + self.attrs[attr.name] = v + + def run(self, inputs: List[Tensor]) -> List[Tensor]: + for opset in range(self.opset, 0, -1): + run_func: Callable[[List[Tensor]], List[Tensor]] = getattr(self, 'run_v{}'.format(opset)) + outs = run_func(inputs) + if outs is NotImplemented: + continue + else: + return outs + raise ValueError('Can not dispatch operator {} in opset {}'.format(self.__class__.__name__, self.opset)) + + def run_v1(self, inputs: List[Tensor]) -> List[Tensor]: + return NotImplemented + + def run_v2(self, inputs: List[Tensor]) -> List[Tensor]: + return NotImplemented + + def run_v3(self, inputs: List[Tensor]) -> List[Tensor]: + return NotImplemented + + def run_v4(self, inputs: List[Tensor]) -> List[Tensor]: + return NotImplemented + + def run_v5(self, inputs: List[Tensor]) -> List[Tensor]: + return NotImplemented + + def run_v6(self, inputs: List[Tensor]) -> List[Tensor]: + return NotImplemented + + def run_v7(self, inputs: List[Tensor]) -> List[Tensor]: + return NotImplemented + + def run_v8(self, inputs: List[Tensor]) -> List[Tensor]: + return NotImplemented + + def run_v9(self, inputs: List[Tensor]) -> List[Tensor]: + return NotImplemented + + def run_v10(self, inputs: List[Tensor]) -> List[Tensor]: + return NotImplemented + + def run_v11(self, inputs: List[Tensor]) -> List[Tensor]: + return NotImplemented + + def run_v12(self, inputs: List[Tensor]) -> List[Tensor]: + return NotImplemented + + def run_v13(self, inputs: List[Tensor]) -> List[Tensor]: + return NotImplemented + + @staticmethod + def tensor2list(tensor: Tensor) -> Union[List, int, float]: + return tensor.cpu().numpy().tolist() + + @staticmethod + def optional_inputs(inputs: List[Tensor], requires: List[bool]) -> List[Union[Tensor, None]]: + diff = len(requires) - len(inputs) + assert diff >= 0, 'Onnx get {} inputs but expect at most {}.'.format(len(inputs), len(requires)) + ret: List[Union[Tensor, None]] = [] + ret += inputs + ret += [None for _ in range(diff)] + for i, (t, r) in enumerate(zip(ret, requires)): + if t is None and r: + raise 'The {}th input is required.'.format(i) + return ret + + +class OnnxConv(OnnxOperator): + def run_v1(self, inputs: List[Tensor]) -> List[Tensor]: + padding = self.attrs.get('pads', [0, 0, 0, 0]) + strides = self.attrs.get('strides', [1, 1]) + groups = self.attrs.get('group', 1) + if len(inputs) == 2: + x, w = inputs + bias = None + else: + x, w, bias = inputs + x = ops.pad(x, ops.utils.normalize_padding(padding)) + output = ops.conv2d(x, w, stride=strides, groups=groups) + if bias is not None: + bias = ops.unsqueeze(bias, [0, 2, 3]) + output = output + bias + return [output] + + def run_v11(self, inputs: List[Tensor]) -> List[Tensor]: + return self.run_v1(inputs) + + +class OnnxBatchNormalization(OnnxOperator): + def __init__(self, node): + super().__init__(node) + self.epsilon: float = self.attrs.get('epsilon', 1e-5) + self.momentum: float = self.attrs.get('momentum', 0.9) + self.training_mode: int = self.attrs.get('training_mode', 0) + assert self.training_mode == 0, 'BatchNorm in training mode occurs, currently, hidet does not support training.' + + def run(self, inputs: List[Tensor]) -> List[Tensor]: + x, scale, bias, running_mean, running_var = inputs + y = ops.batch_norm_infer(x, running_mean=running_mean, running_var=running_var, epsilon=self.epsilon, axis=1) + return [y * scale.unsqueeze([0, 2, 3]) + bias.unsqueeze([0, 2, 3])] + + +class OnnxRelu(OnnxOperator): + def run(self, inputs: List[Tensor]) -> List[Tensor]: + return [ops.relu(inputs[0])] + + +class OnnxSin(OnnxOperator): + def run(self, inputs: List[Tensor]) -> List[Tensor]: + return [ops.sin(inputs[0])] + + +class OnnxCos(OnnxOperator): + def run(self, inputs: List[Tensor]) -> List[Tensor]: + return [ops.cos(inputs[0])] + + +class OnnxPow(OnnxOperator): + def run(self, inputs: List[Tensor]) -> List[Tensor]: + x, y = inputs + return [ops.pow(x, y)] + + +class OnnxDiv(OnnxOperator): + def run(self, inputs: List[Tensor]) -> List[Tensor]: + x, y = inputs + return [ops.divide(x, y)] + + +class OnnxSqrt(OnnxOperator): + def run(self, inputs: List[Tensor]) -> List[Tensor]: + return [ops.sqrt(inputs[0])] + + +class OnnxErf(OnnxOperator): + def run(self, inputs: List[Tensor]) -> List[Tensor]: + return [ops.erf(inputs[0])] + + +class OnnxTanh(OnnxOperator): + def run(self, inputs: List[Tensor]) -> List[Tensor]: + return [ops.tanh(inputs[0])] + + +class OnnxMaxPool(OnnxOperator): + def __init__(self, node): + super().__init__(node) + self.kernel_size = list(self.attrs.get('kernel_shape')) + self.padding = list(self.attrs.get('pads', [0, 0, 0, 0])) + self.strides = list(self.attrs.get('strides')) + + def run(self, inputs: List[Tensor]) -> List[Tensor]: + return [ops.max_pool2d(inputs[0], self.kernel_size, self.strides, self.padding)] + + +class OnnxReduceMean(OnnxOperator): + def __init__(self, node): + super().__init__(node) + self.dims = self.attrs.get('axes') + self.keep_dim = self.attrs.get('keepdims', 1) == 1 + + def run(self, inputs: List[Tensor]) -> List[Tensor]: + return [ops.reduce_mean(inputs[0], self.dims, self.keep_dim)] + + +class OnnxSqueezeOp(OnnxOperator): + def __init__(self, node): + super().__init__(node) + self.dims = self.attrs.get('axes', None) + + def run(self, inputs: List[Tensor]) -> List[Tensor]: + data = inputs[0] + if self.dims is None: + # squeeze all dimensions with extent 1 + dims = [i for i, dim in enumerate(data.shape) if dim == 1] + else: + dims = list(self.dims) + return [ops.squeeze(inputs[0], dims)] + + +class OnnxAdd(OnnxOperator): + def run(self, inputs: List[Tensor]) -> List[Tensor]: + return [inputs[0] + inputs[1]] + + +class OnnxSub(OnnxOperator): + def run(self, inputs: List[Tensor]) -> List[Tensor]: + return [inputs[0] - inputs[1]] + + +class OnnxMul(OnnxOperator): + def run(self, inputs: List[Tensor]) -> List[Tensor]: + return [inputs[0] * inputs[1]] + + +class OnnxMatMul(OnnxOperator): + def run(self, inputs: List[Tensor]) -> List[Tensor]: + a, b = inputs + assert len(a.shape) >= 2 and len(b.shape) >= 2 + if len(a.shape) == 2 and len(b.shape) == 2: + return [ops.matmul(a, b)] + else: + prefix_shape = hidet.tos.ops.definitions.arithmatic.broadcast_shape(a.shape[:-2], b.shape[:-2]) + a = ops.broadcast(a, prefix_shape + a.shape[-2:]) + b = ops.broadcast(b, prefix_shape + b.shape[-2:]) + a = ops.flatten(a, end_dim=-2) # [B, M, K] + b = ops.flatten(b, end_dim=-2) # [B, K, N] + c = ops.matmul(a, b) # [B, M, N] + c_expect_shape = prefix_shape + [a.shape[-2], b.shape[-1]] + c = c.reshape(c_expect_shape) + return [c] + + +class OnnxSoftmax(OnnxOperator): + def __init__(self, node): + super().__init__(node) + self.axis = self.attrs.get('axis') + + def run(self, inputs: List[Tensor]) -> List[Tensor]: + return [ops.softmax(inputs[0], self.axis)] + + +class OnnxGlobalAveragePool(OnnxOperator): + def __init__(self, node): + super().__init__(node) + + def run(self, inputs: List[Tensor]) -> List[Tensor]: + x, = inputs + n, c, h, w = x.shape + return [ops.avg_pool2d(x, kernel=(h, w), stride=(1, 1), padding=(0, 0))] + + +class OnnxFlatten(OnnxOperator): + def __init__(self, node): + super().__init__(node) + self.axis = self.attrs.get('axis', 1) + + def run(self, inputs: List[Tensor]) -> List[Tensor]: + x = inputs[0] + rank = len(x.shape) + axis = (self.axis + rank) % rank + dims = list(range(rank)) + return [ops.rearrange(x, plan=[dims[:axis], dims[axis:]])] + + +class OnnxUnsqueeze(OnnxOperator): + def __init__(self, node): + super().__init__(node) + + def run_v1(self, inputs: List[Tensor]) -> List[Tensor]: + axes = self.attrs['axes'] # in [-output_rank, output_rank - 1] + x = inputs[0] + rank = len(x.shape) + len(axes) + axes = [(axis + rank) % rank for axis in axes] + return [ops.unsqueeze(x, axes)] + + def run_v13(self, inputs: List[Tensor]) -> List[Tensor]: + x, axes = inputs + axes = self.tensor2list(axes) + rank = len(x.shape) + len(axes) + axes = [(axis + rank) % rank for axis in axes] + return [ops.unsqueeze(x, axes)] + + +class OnnxReshape(OnnxOperator): + def __init__(self, node): + super().__init__(node) + self.allow_zero = self.attrs.get('allowzero', 0) + + def run(self, inputs: List[Tensor]) -> List[Tensor]: + x, shape = inputs + shape = self.tensor2list(shape) + return [ops.reshape(x, shape)] + + +class OnnxTranspose(OnnxOperator): + def __init__(self, node): + super().__init__(node) + self.perm = self.attrs.get('perm', None) + + def run(self, inputs: List[Tensor]) -> List[Tensor]: + x = inputs[0] + perm = self.perm if self.perm else list(reversed(range(len(x.shape)))) + return [ops.transpose(x, perm)] + + +class OnnxConcat(OnnxOperator): + def __init__(self, node): + super().__init__(node) + self.axis = self.attrs.get('axis') + + def run(self, inputs: List[Tensor]) -> List[Tensor]: + return [ops.concat(inputs, self.axis)] + + +class OnnxArgMax(OnnxOperator): + def run(self, inputs: List[Tensor]) -> List[Tensor]: + return inputs + # raise NotImplementedError('ArgMax') + + +class OnnxGemm(OnnxOperator): + def __init__(self, node): + super().__init__(node) + self.alpha = self.attrs.get('alpha', 1.0) + self.beta = self.attrs.get('beta', 0.0) + self.trans_a = self.attrs.get('transA', 0) + self.trans_b = self.attrs.get('transB', 0) + + def run(self, inputs: List[Tensor]) -> List[Tensor]: + a, b = inputs[:2] + c = inputs[2] if len(inputs) > 2 else None + if self.trans_a == 1: + a = ops.rearrange(a, plan=[[1], [0]]) + if self.trans_b == 1: + b = ops.rearrange(b, plan=[[1], [0]]) + assert a.shape[1] == b.shape[0] + d = ops.matmul(a, b) + if self.alpha != 1.0: + d = d * self.alpha + if c and self.beta != 0.0: + d = d + c * self.beta + return [d] + + +class OnnxCast(OnnxOperator): + code2dtype = { + 1: 'float32', + 2: 'uint8', + 3: 'int8', + 4: 'uint16', + 5: 'int16', + 6: 'int32', + 7: 'int64', + 8: 'string', + 9: 'bool', + 10: 'float16', + 11: 'double', + 12: 'uint32', + 13: 'uint64', + 14: 'complex64', + 15: 'complex128', + 16: 'bfloat16', + } + + def __init__(self, node): + super().__init__(node) + self.to = self.attrs.get('to') + + def run(self, inputs: List[Tensor]) -> List[Tensor]: + x = inputs[0] + dtype = self.code2dtype[self.to] + return [ops.cast(x, dtype)] + + +class OnnxShape(OnnxOperator): + def __init__(self, node): + super().__init__(node) + self.start = self.attrs.get('start', 0) + self.end: Optional[int] = self.attrs.get('end', None) + + def run(self, inputs: List[Tensor]) -> List[Tensor]: + x = inputs[0] + rank = len(x.shape) + start = self.start + rank if self.start < 0 else self.start + if self.end is not None: + end = self.end + rank if self.end < 0 else self.end + else: + end = rank + start = max(min(start, rank), 0) + end = max(min(end, rank), 0) + return [hidet.array(x.shape[start:end]).cuda()] + + +class OnnxConstant(OnnxOperator): + def __init__(self, node): + super().__init__(node) + self.value = self.attrs.get('value') + if self.value is None: + raise NotImplementedError('Currently, only support Tensor constant in onnx importer') + + def run(self, inputs: List[Tensor]) -> List[Tensor]: + assert len(inputs) == 0 + return [self.value] + + +class OnnxGather(OnnxOperator): + def __init__(self, node): + super().__init__(node) + self.axis = self.attrs.get('axis', 0) + + def run(self, inputs: List[Tensor]) -> List[Tensor]: + data, indices = inputs + return [ops.take(data, indices, self.axis)] + + +class OnnxSlice(OnnxOperator): + def __init__(self, node): + super().__init__(node) + + def run(self, inputs: List[Tensor]) -> List[Tensor]: + data, starts, ends = inputs[:3] + axes = inputs[3] if len(inputs) > 3 else None + steps = inputs[4] if len(inputs) > 4 else None + starts = self.tensor2list(starts) + ends = self.tensor2list(ends) + axes = self.tensor2list(axes) if axes else None + steps = self.tensor2list(steps) if steps else None + return [ops.strided_slice(data, starts, ends, axes, steps)] + + +class OnnxSigmoid(OnnxOperator): + def run(self, inputs: List[Tensor]) -> List[Tensor]: + return [ops.sigmoid(inputs[0])] + + +class OnnxInstanceNormalization(OnnxOperator): + def __init__(self, node): + super().__init__(node) + self.epsilon = self.attrs.get('epsilon', 1e-5) + + def run(self, inputs: List[Tensor]) -> List[Tensor]: + x, scale, bias = inputs + rank = len(x.shape) + dims = [0] + list(range(2, rank)) + scale = ops.unsqueeze(scale, dims) # [1, C, D1, ...] + bias = ops.unsqueeze(bias, dims) # [1, C, D1, ...] + return [ops.instance_norm(x, self.epsilon) * scale + bias] + + +class OnnxConstantOfShape(OnnxOperator): + def __init__(self, node): + super().__init__(node) + self.value = self.attrs.get('value') + if self.value is None: + self.value = hidet.zeros([1], dtype='float32') + + def run(self, inputs: List[Tensor]) -> List[Tensor]: + shape = inputs[0].cpu().numpy().tolist() + assert all(v >= 0 for v in shape) + return [ops.broadcast(self.value, shape)] + + +class OnnxPad(OnnxOperator): + def run_v2(self, inputs: List[Tensor]) -> List[Tensor]: + data = inputs[0] + mode = self.attrs.get('mode', 'constant') + pads = self.attrs.get('pads') + value = self.attrs.get('value', 0.0) + return [ops.pad(data, pads, mode, value)] + + def run_v13(self, inputs: List[Tensor]) -> List[Tensor]: + mode = self.attrs.get('mode', 'constant') + data, pads = inputs[:2] + value = self.tensor2list(inputs[2]) if len(inputs) > 2 else 0.0 + pads = self.tensor2list(pads) + return [ops.pad(data, pads, mode, value)] + + +class OnnxResize(OnnxOperator): + def __init__(self, node): + super().__init__(node) + self.coordinate_transformation_mode = self.attrs.get('coordinate_transformation_mode', 'half_pixel') + self.cubic_coeff_a = self.attrs.get('cubic_coeff_a', -0.75) + self.exclude_outside = self.attrs.get('exclude_outside', 0) + self.extrapolation_value = self.attrs.get('extrapolation_value', 0.0) + self.mode = self.attrs.get('mode', 'nearest') + self.nearest_mode = self.attrs.get('nearest_mode', 'round_prefer_floor') + + def run(self, inputs: List[Tensor]) -> List[Tensor]: + x, roi, scales, sizes = self.optional_inputs(inputs, requires=[True, False, False, False]) + if roi is not None: + roi = self.tensor2list(roi) + target_size = None + if scales is not None: + scales = self.tensor2list(scales) + assert len(x.shape) == len(scales) + target_size = [int(a * b) for a, b in zip(x.shape, scales)] + if sizes is not None: + sizes = self.tensor2list(sizes) + target_size = [int(v) for v in sizes] + if target_size is None: + raise ValueError('Resize operator in onnx must give either scales or sizes.') + if len(x.shape) == 4: + if not (target_size[0] == x.shape[0] and target_size[1] == x.shape[1]): + raise ValueError('Unsupported resize on batch and channel dimension.') + return [ops.resize2d(x, target_size[2:], self.mode, self.coordinate_transformation_mode, self.nearest_mode, + roi, self.cubic_coeff_a, self.exclude_outside, self.extrapolation_value)] + else: + raise NotImplementedError('Current only support 2d resize, got x {}.'.format(x.shape)) + + +class OnnxExpand(OnnxOperator): + def run_v8(self, inputs: List[Tensor]) -> List[Tensor]: + data, new_shape = inputs + new_shape = self.tensor2list(new_shape) + new_shape = hidet.tos.ops.definitions.arithmatic.broadcast_shape(data.shape, new_shape) + return [ops.broadcast(data, new_shape)] + + +class OnnxRange(OnnxOperator): + def run_v11(self, inputs: List[Tensor]) -> List[Tensor]: + start, limit, delta = [self.tensor2list(t) for t in inputs] + array = np.arange(start=start, stop=limit, step=delta) + array = hidet.array(array).cuda().cast(dtype=inputs[0].dtype) + return [array] + + +class OnnxTile(OnnxOperator): + def run(self, inputs: List[Tensor]) -> List[Tensor]: + data, repeats = inputs + repeats = self.tensor2list(repeats) + return [ops.tile(data, repeats)] + + +class OnnxAveragePool(OnnxOperator): + def __init__(self, node): + super().__init__(node) + self.auto_pad = self.attrs.get('auto_pad', 'NOTSET') + self.ceil_mode = self.attrs.get('ceil_mode', 0) + self.count_include_pad = self.attrs.get('count_include_pad', 0) + self.kernel_shape = self.attrs.get('kernel_shape') + self.pads = self.attrs.get('pads') + self.strides = self.attrs.get('strides') + if self.auto_pad != 'NOTSET' or self.ceil_mode != 0 or self.count_include_pad != 0: + raise NotImplementedError(self) + + def run(self, inputs: List[Tensor]) -> List[Tensor]: + x = inputs[0] + if len(x.shape) != 4: + raise NotImplementedError('Currently only support 2-d avg pooling') + x = ops.avg_pool2d(x, self.kernel_shape, self.strides, self.pads) + return [x] + + +class OnnxClip(OnnxOperator): + def run_v1(self, inputs: List[Tensor]) -> List[Tensor]: + raise NotImplementedError() + + def run_v6(self, inputs: List[Tensor]) -> List[Tensor]: + x = inputs[0] + min_value = self.attrs.get('min', None) + max_value = self.attrs.get('max', None) + x = ops.clip(x, min_value, max_value) + return [x] + + def run_v11(self, inputs: List[Tensor]) -> List[Tensor]: + raise NotImplementedError() + + def run_v12(self, inputs: List[Tensor]) -> List[Tensor]: + raise NotImplementedError() + + +class OnnxEqual(OnnxOperator): + def run_v11(self, inputs: List[Tensor]) -> List[Tensor]: + a, b = inputs + return [ops.equal(a, b)] + + +class OnnxLess(OnnxOperator): + def run_v9(self, inputs: List[Tensor]) -> List[Tensor]: + a, b = inputs + return [ops.less(a, b)] + + +class OnnxWhere(OnnxOperator): + def run_v9(self, inputs: List[Tensor]) -> List[Tensor]: + cond, a, b = inputs + return [ops.where(cond, a, b)] + + +class OnnxSplit(OnnxOperator): + def run_v2(self, inputs: List[Tensor]) -> List[Tensor]: + axis = self.attrs.get('axis', 0) + parts = self.attrs['split'] + data = inputs[0] + return ops.split(data, axis, parts) + + def run_v13(self, inputs: List[Tensor]) -> List[Tensor]: + axis = self.attrs.get('axis', 0) + data, parts = inputs + parts = self.tensor2list(parts) + return ops.split(data, axis, parts) + + +class OnnxReduceSum(OnnxOperator): + def run_v1(self, inputs: List[Tensor]) -> List[Tensor]: + axes = self.attrs['axes'] + keepdims = self.attrs.get('keepdims', True) + data = inputs[0] + return [ops.reduce_sum(data, dims=axes, keep_dim=keepdims)] + + def run_v11(self, inputs: List[Tensor]) -> List[Tensor]: + return self.run_v1(inputs) + + def run_v13(self, inputs: List[Tensor]) -> List[Tensor]: + raise NotImplementedError() + + +def dispatch(node, opset: int = 11) -> OnnxOperator: + dispatch_table = { + 'Conv': OnnxConv, + 'Relu': OnnxRelu, + 'Pow': OnnxPow, + 'Div': OnnxDiv, + 'Sqrt': OnnxSqrt, + 'Erf': OnnxErf, + 'Tanh': OnnxTanh, + 'MaxPool': OnnxMaxPool, + 'ReduceMean': OnnxReduceMean, + 'Squeeze': OnnxSqueezeOp, + 'Add': OnnxAdd, + 'Sub': OnnxSub, + 'Mul': OnnxMul, + 'MatMul': OnnxMatMul, + 'Softmax': OnnxSoftmax, + 'ArgMax': OnnxArgMax, + 'BatchNormalization': OnnxBatchNormalization, + 'GlobalAveragePool': OnnxGlobalAveragePool, + 'Flatten': OnnxFlatten, + 'Unsqueeze': OnnxUnsqueeze, + 'Concat': OnnxConcat, + 'Cast': OnnxCast, + 'Constant': OnnxConstant, + 'Reshape': OnnxReshape, + 'Shape': OnnxShape, + 'Gemm': OnnxGemm, + 'Gather': OnnxGather, + 'Slice': OnnxSlice, + 'Transpose': OnnxTranspose, + 'Sin': OnnxSin, + 'Cos': OnnxCos, + 'Sigmoid': OnnxSigmoid, + 'InstanceNormalization': OnnxInstanceNormalization, + 'ConstantOfShape': OnnxConstantOfShape, + 'Pad': OnnxPad, + 'Resize': OnnxResize, + 'Expand': OnnxExpand, + 'Range': OnnxRange, + 'Tile': OnnxTile, + 'AveragePool': OnnxAveragePool, + 'Clip': OnnxClip, + 'Equal': OnnxEqual, + 'Less': OnnxLess, + 'Where': OnnxWhere, + 'Split': OnnxSplit, + 'ReduceSum': OnnxReduceSum, + } + op_type = node.op_type + if op_type not in dispatch_table: + raise NotImplementedError("Operator '{}' (opset {}) from onnx has not been supported yet.".format(op_type, opset)) + op = dispatch_table[op_type](node) + op.opset = opset + return op + + +def run_trt(node: OnnxOperator, inputs: List[Tensor]) -> List[Tensor]: + import onnx + from onnx.helper import make_value_info, make_tensor_type_proto + from onnx import TensorProto + import onnxruntime + hidet_outputs = node.run(inputs) + dtype_map = { + 'float32': TensorProto.FLOAT, + 'int64': TensorProto.INT64, + 'bool': TensorProto.BOOL + } + inputs_value_info = [ + make_value_info( + name=name, + type_proto=make_tensor_type_proto( + elem_type=dtype_map[tensor.dtype], + shape=tensor.shape + ) + ) for name, tensor in zip(node.input_names, inputs) + ] + outputs_value_info = [ + make_value_info( + name=name, + type_proto=make_tensor_type_proto( + elem_type=dtype_map[tensor.dtype], + shape=tensor.shape + ) + ) for name, tensor in zip(node.output_names, hidet_outputs) + ] + graph = onnx.helper.make_graph( + nodes=[node.node], + name='test', + inputs=inputs_value_info, + outputs=outputs_value_info + ) + model = onnx.helper.make_model(graph, opset_imports=[onnx.helper.make_opsetid("", node.opset)]) + # print(model) + onnx.checker.check_model(model) + # serialized_model = onnx._serialize(model) + serialized_model = model.SerializeToString() + session = onnxruntime.InferenceSession(serialized_model, providers=['CPUExecutionProvider']) + outputs = session.run(node.output_names, input_feed={ + name: tensor.cpu().numpy() for name, tensor in zip(node.input_names, inputs) + }) + return [hidet.array(output).cuda() for output in outputs] + + +class OnnxModule(nn.Module): + def __init__(self, model): + """ + Parameters + ---------- + model: onnx.ModelProto + """ + super().__init__() + import onnx.numpy_helper + import onnx.external_data_helper + graph = model.graph + self.name: str = graph.name + self.model = model + for param in graph.initializer: + numpy_array = onnx.numpy_helper.to_array(tensor=param) + self.parameters[param.name] = from_numpy(numpy_array).cuda() + self.input_names: List[str] = [input.name for input in graph.input if input.name not in self.parameters] + self.output_names: List[str] = [output.name for output in graph.output] + self.opset = [opset_import.version for opset_import in model.opset_import] + assert len(self.opset) == 1 + self.operators: List[OnnxOperator] = [dispatch(node, opset=self.opset[0]) for node in graph.node] + self.usage_count: Dict[str, int] = self.count_usage() + + def forward(self, *args): + name2tensor = {} + assert len(args) == len(self.input_names) + # parameters + for name, param in self.parameters.items(): + name2tensor[name] = param + # inputs + for name, input in zip(self.input_names, args): + name2tensor[name] = input + # run nodes + usage_count = self.usage_count.copy() + for operator in self.operators: + inputs = [name2tensor[name] for name in operator.input_names] + outputs = operator.run(inputs) + + check = False + if check: + outputs_trt = run_trt(operator, inputs) + for a, b in zip(outputs, outputs_trt): + try: + np.testing.assert_allclose(a.cpu().numpy(), b.cpu().numpy(), atol=1e-3, rtol=1e-3) + except AssertionError as e: + print('Operator check failed: {:>20}'.format(operator.node.name)) + # print('{}'.format(', '.join(out.signature() for out in outputs))) + raise e + + assert len(outputs) == len(operator.output_names) + for name, tensor in zip(operator.output_names, outputs): + name2tensor[name] = tensor + for name in operator.input_names: + usage_count[name] -= 1 + if usage_count[name] == 0: + # free memory + del name2tensor[name] + # put outputs + results = [name2tensor[name] for name in self.output_names] + if len(results) == 1: + return results[0] + else: + return results + + def count_usage(self): + usage_count = defaultdict(int) + for op in self.operators: + for input_name in op.input_names: + usage_count[input_name] += 1 + for graph_output_name in self.output_names: + usage_count[graph_output_name] += 1 + return usage_count + + +def from_onnx(model: Union[str, 'onnx.ModelProto']) -> OnnxModule: + """ + Load an onnx model to hidet.tos.nn.Module. + + Parameters + ---------- + model: Union[str, onnx.ModelProto] + The path or model proto of given onnx model. + + Returns + ------- + ret: OnnxModule + The loaded model. + """ + import onnx + if isinstance(model, str): + model = os.path.expanduser(model) + model = onnx.load_model(model, load_external_data=False) + onnx.checker.check_model(model, full_check=True) + return OnnxModule(model) diff --git a/python/hidet/tos/frontend/torch_utils.py b/python/hidet/tos/frontend/torch_utils.py new file mode 100644 index 0000000..e69de29 diff --git a/python/hidet/tos/ir/__init__.py b/python/hidet/tos/ir/__init__.py new file mode 100644 index 0000000..9514c23 --- /dev/null +++ b/python/hidet/tos/ir/__init__.py @@ -0,0 +1,5 @@ +from . import graph +from . import functors + +from .graph import FlowGraph, Tensor, Operator, trace_from, load_graph, save_graph +from .functors import GraphRewriter, GraphVisitor diff --git a/python/hidet/tos/ir/functors.py b/python/hidet/tos/ir/functors.py new file mode 100644 index 0000000..cadd0e1 --- /dev/null +++ b/python/hidet/tos/ir/functors.py @@ -0,0 +1,160 @@ +from typing import Union, Type, Dict, List, Tuple, Optional +from collections import defaultdict +from hidet.tos.ir.graph import FlowGraph, Operator, Tensor +from hidet.utils import same_list + + +class GraphVisitor: + def __init__(self): + self.memo = {} + + def __call__(self, obj): + return self.visit(obj) + + def visit(self, obj: Union[FlowGraph, Operator, Tensor, list, tuple]): + key = obj if not isinstance(obj, list) else id(obj) + if self.memo is not None and obj in self.memo: + return self.memo[key] + if isinstance(obj, FlowGraph): + self.visit_FlowGraph(obj) + elif isinstance(obj, Operator): + self.visit_Operator(obj) + elif isinstance(obj, Tensor): + self.visit_Tensor(obj) + elif isinstance(obj, (list, tuple)): + self.visit_Sequence(obj) + else: + raise ValueError(type(obj)) + if self.memo is not None: + self.memo[key] = None + + def visit_FlowGraph(self, graph: FlowGraph): + for output in graph.outputs: + self(output) + + def visit_Operator(self, op: Operator): + for input in op.inputs: + self(input) + + def visit_Tensor(self, tensor: Tensor): + if tensor.trace is None: + return tensor + self(tensor.trace[0]) + + def visit_Sequence(self, seq: Union[list, tuple]): + for obj in seq: + self(obj) + + +class GraphRewriter: + def __init__(self): + self.memo = {} + + def __call__(self, obj): + return self.visit(obj) + + def visit(self, obj: Union[FlowGraph, Operator, Tensor, list, tuple]): + key = obj if not isinstance(obj, list) else id(obj) + if self.memo is not None and obj in self.memo: + return self.memo[key] + if isinstance(obj, FlowGraph): + ret = self.visit_FlowGraph(obj) + elif isinstance(obj, Operator): + ret = self.visit_Operator(obj) + elif isinstance(obj, Tensor): + ret = self.visit_Tensor(obj) + elif isinstance(obj, (list, tuple)): + ret = self.visit_Sequence(obj) + else: + raise ValueError(type(obj)) + if self.memo is not None: + self.memo[key] = ret + return ret + + def visit_FlowGraph(self, graph: FlowGraph): + outputs = [self.visit(output) for output in graph.outputs] + if same_list(outputs, graph.outputs): + return graph + else: + return FlowGraph(outputs, graph.inputs) + + def visit_Operator(self, op: Operator): + inputs = [self(input) for input in op.inputs] + if same_list(inputs, op.inputs): + return + else: + updated_outputs = op.clone(inputs) + for original, updated in zip(op.outputs, updated_outputs): + self.memo[original] = updated + + def visit_Tensor(self, tensor: Tensor): + if tensor.trace is None: + # input + return tensor + self(tensor.trace[0]) + if tensor in self.memo: + # the operator has been updated + return self.memo[tensor] + else: + return tensor + + def visit_Sequence(self, seq: Union[list, tuple]): + return seq.__class__([self(obj) for obj in seq]) + + +class GraphCloneRewriter(GraphRewriter): + def visit_FlowGraph(self, graph: FlowGraph): + outputs = [self.visit(output) for output in graph.outputs] + return FlowGraph(outputs, graph.inputs) + + def visit_Operator(self, op: Operator): + inputs = [self(x) for x in op.inputs] + updated_outputs = op.clone(inputs) + for original, updated in zip(op.outputs, updated_outputs): + self.memo[original] = updated + + def visit_Tensor(self, tensor: Tensor): + if tensor.trace is None: + # keep the input tensor the same + return tensor + else: + self(tensor.trace[0]) + return self.memo[tensor] + + +class GraphUsageAnalyzer(GraphVisitor): + def __init__(self): + super().__init__() + self.usage: Dict[Tensor, List[Tuple[Optional[Operator], int]]] = defaultdict(list) + + def analyze(self, graph: FlowGraph): + self.usage = defaultdict(list) + self.visit(graph) + return self.usage + + def visit_FlowGraph(self, graph: FlowGraph): + for idx, output in enumerate(graph.outputs): + self(output) + self.usage[output].append((None, idx)) + GraphVisitor.visit_FlowGraph(self, graph) + + def visit_Operator(self, op: Operator): + for idx, input in enumerate(op.inputs): + self.usage[input].append((op, idx)) + GraphVisitor.visit_Operator(self, op) + + +def analyze_usage(graph: FlowGraph) -> Dict[Tensor, List[Tuple[Operator, int]]]: + analyzer = GraphUsageAnalyzer() + return analyzer.analyze(graph) + + +def clone(graph: FlowGraph): + return GraphCloneRewriter().visit(graph) + + +def graph_collect(obj: Union[FlowGraph, Operator, Tensor], cls: Type[Union[Operator, Tensor]]): + visitor = GraphVisitor() + visitor.visit(obj) + return [v for v in visitor.memo if isinstance(v, cls)] + diff --git a/python/hidet/tos/ir/graph.py b/python/hidet/tos/ir/graph.py new file mode 100644 index 0000000..298852f --- /dev/null +++ b/python/hidet/tos/ir/graph.py @@ -0,0 +1,279 @@ +from __future__ import annotations +from typing import List, Union, Dict, Set, Optional, Tuple +import os +import pickle +import warnings +from collections import defaultdict + +import hidet.tos.operator +from hidet.tos.tensor import Tensor +from hidet.tos.operator import Operator +from hidet.utils import tracer +from hidet.utils.doc import Doc, NewLine, Text, doc_join +from hidet.utils.namer import Namer + + +class FlowGraph: + def __init__(self, outputs: List[Tensor], inputs=None, nodes=None): + self.outputs: List[Tensor] = outputs + self.inputs: Optional[List[Tensor]] = inputs + self.nodes: Optional[List[Operator]] = nodes + self.usage_count: Optional[Dict[Tensor, int]] = None + + def __call__(self, *inputs: Tensor) -> Union[List[Tensor], Tensor]: + return self.forward(*inputs) + + def __str__(self): + if any(v is None for v in [self.inputs, self.nodes, self.usage_count]): + self.update_nodes() + namer = Namer() + + def get_tensor_sig(x: Tensor) -> Doc: + return Text(x.dtype) + '[' + doc_join([str(v) for v in x.shape], ', ') + ']' + + def get_attr_repr(value: Union[float, int, bool, str, list, tuple]) -> Doc: + if isinstance(value, (float, int, bool)): + return Text(str(value)) + elif isinstance(value, str): + return Text('"{}"'.format(value)) + elif isinstance(value, list): + return '[' + doc_join([get_attr_repr(v) for v in value], ', ') + ']' + elif isinstance(value, tuple): + return '(' + doc_join([get_attr_repr(v) for v in value], ', ') + ')' + else: + raise ValueError(value) + + param_docs = [] + for x in self.inputs: + name = namer(x) + param_docs.append(Text(name) + ': ' + get_tensor_sig(x)) + + # head + head_doc = 'Graph(' + doc_join(param_docs, ', ') + ')' + + # body + body_doc = Doc() + for op in self.nodes: + # const inputs + for x in op.inputs: + if x not in namer.obj_name: + assert x.storage is not None + body_doc += NewLine() + namer.get_name(x, hint='c') + ' = ' + 'Constant(' + get_tensor_sig(x) + ')' + outputs = op.outputs + if len(outputs) > 1: + raise NotImplementedError() + output: Tensor = outputs[0] + line_doc = Doc() + line_doc += namer(output) + ' = ' + line_doc += op.name + ('*' if len(op.task.prologues) + len(op.task.epilogues) > 0 else '') + '(' + line_doc += doc_join([namer(x) for x in op.inputs], sep=', ') + if op.attrs: + line_doc += ', ' + doc_join([Text(name) + '=' + get_attr_repr(value) for name, value in op.attrs.items()], ', ') + line_doc += ')' + line_doc += ' # ' + get_tensor_sig(output) + body_doc += NewLine() + line_doc + + # return statement + body_doc += NewLine() + Text('return ') + doc_join([namer(x) for x in self.outputs], ', ') + + graph_doc = head_doc + '{' + body_doc.indent() + NewLine() + '}' + return str(graph_doc) + + def build(self): + tasks = [] + tunable_tasks = [] + task_keys = set() + space_level = hidet.get_space_level() + for node in self.nodes: + if node.task_func is None: + # if space_level == 0 or 'implement_cuda' not in node.task.__class__.__dict__: + task_key = hash(str(node.task)) + if task_key in task_keys: + continue + task_keys.add(task_key) + if node.task.fast_implement(space_level): + tasks.append(node.task) + else: + tunable_tasks.append(node.task) + hidet.driver.build_batch_task(tasks, space_level, parallel=True) + hidet.driver.build_batch_task(tunable_tasks, space_level, parallel=False) + + def forward(self, *inputs: Tensor) -> Union[List[Tensor], Tensor]: + if any(v is None for v in [self.inputs, self.nodes, self.usage_count]): + self.update_nodes() + + self.build() + + if len(inputs) != len(self.inputs): + raise ValueError('FlowGraph expects {} inputs, but got {}.'.format(len(self.inputs), len(inputs))) + for idx, tensor in enumerate(inputs): + if tensor.storage is None: + raise ValueError('FlowGraph expects all input tensors are non-symbolic, ' + 'but the input {} ({}) is a symbol tensor.'.format(idx, tensor.signature())) + usage_count = self.usage_count.copy() + tensor_map: Dict[Tensor, Tensor] = {} + for st, at in zip(self.inputs, inputs): + tensor_map[st] = at + for node in self.nodes: + # prepare node inputs + node_inputs = [] + for node_input in node.inputs: + if node_input.storage is None: + # symbolic input + node_inputs.append(tensor_map[node_input]) + usage_count[node_input] -= 1 + if usage_count[node_input] == 0: + # free the memory + del tensor_map[node_input] + else: + # constant input + node_inputs.append(node_input) + # run node + args = {f'input_{idx}': f'{tensor.dtype}{tensor.shape}' for idx, tensor in enumerate(node.inputs)} + args.update(node.attrs) + with tracer.profile(node.name, category='op', args=args, trace_cuda=True): + node_outputs = node.imperative_run(node_inputs) + for st, at in zip(node.outputs, node_outputs): + tensor_map[st] = at + ret = [tensor_map[st] for st in self.outputs] + return ret[0] if len(ret) == 1 else ret + + def save(self, fname: str): + # before save, clear the packed func cache because ctypes object can not be pickled + for node in self.nodes: + node.task_func = None + self.usage_count, self.nodes = None, None + + dirname = os.path.dirname(fname) + os.makedirs(dirname, exist_ok=True) + # save to a temporary file first, in case pickle fails. + with open(fname + '.temp', 'wb') as f: + pickle.dump(self, f) + os.rename(fname + '.temp', fname) + + @staticmethod + def load(fname: str) -> FlowGraph: + with open(fname, 'rb') as f: + ret = pickle.load(f) + if not isinstance(ret, FlowGraph): + raise TypeError('Expect to load FlowGraph, got {}'.format(type(ret))) + ret.update_nodes() + return ret + + def update_nodes(self): + inputs, self.nodes, self.usage_count = self._analyze(self.outputs) + if self.inputs: + if len(inputs) != len(self.inputs): + raise ValueError('Found {} symbol inputs, but {} given'.format(len(inputs), len(self.inputs))) + if any(a not in self.inputs for a in inputs): + raise ValueError('There is a symbol tensor not given in inputs') + else: + if len(inputs) > 1: + warnings.warn('There are {} symbol inputs traced, ' + 'but the inputs has not given to specify the order.'.format(len(inputs))) + self.inputs = inputs + return self + + def cuda_graph(self): + from hidet.runtime.cuda_graph import create_cuda_graph + return create_cuda_graph(self) + + @staticmethod + def _analyze(outputs: List[Tensor]) -> Tuple[List[Tensor], List[Operator], Dict[Tensor, int]]: + inputs = [] + nodes: List[Operator] = [] + # find out all nodes + all_nodes: Set[Operator] = set() + + def find_all_nodes(u: Operator): + all_nodes.add(u) + for it in u.inputs: + if it.op is None: + continue + v: Operator = it.op + if v not in all_nodes: + find_all_nodes(v) + for ot in outputs: + if ot.trace: + find_all_nodes(ot.op) + + # topological sort + out_degree: Dict[Operator, int] = {u: 0 for u in all_nodes} + for u in all_nodes: + for it in u.inputs: + if it.op is None: + continue + out_degree[it.op] += 1 + for u in outputs: + if u.op: + out_degree[u.op] += 1 + + stack: List[Operator] = [] + for u in outputs: + if u.op: + out_degree[u.op] -= 1 + if out_degree[u.op] == 0: + stack.append(u.op) + while len(stack) > 0: + op = stack.pop() + nodes.append(op) + for it in op.inputs: + if it.op is None: + if it.storage is None and it not in inputs: + # input + inputs.append(it) + else: + out_degree[it.op] -= 1 + if out_degree[it.op] == 0: + stack.append(it.op) + nodes = list(reversed(nodes)) + assert len(nodes) == len(all_nodes), 'all_nodes {} topo_order {}'.format(len(all_nodes), len(nodes)) + + # tensor usage count + usage_count: Dict[Tensor, int] = defaultdict(int) + for op in all_nodes: + for inp in op.inputs: + usage_count[inp] += 1 + for graph_output in outputs: + usage_count[graph_output] += 1 + + return inputs, nodes, usage_count + + +def trace_from(tensor: Union[Tensor, List[Tensor]], inputs: Optional[Union[Tensor, List[Tensor]]] = None) -> FlowGraph: + """ + Trace the flow graph given the output tensor(s). + + Parameters + ---------- + tensor: Tensor or List[Tensor] + The output tensor(s) that we trace from. + inputs: Optional, Tensor or List[Tensor] + The inputs of the flow graph. When there is only a single symbol tensor in the flow graph, it is + optional. When there are multiple inputs, this is required to specify the input order. + + Returns + ------- + ret: FlowGraph + The flow graph that outputs the given input tensor(s). + """ + if isinstance(tensor, Tensor): + if tensor.trace is None: + raise ValueError('trace_from expects symbol tensor(s).') + outputs = [tensor] + else: + outputs = list(tensor) + if inputs is not None: + if isinstance(inputs, Tensor): + inputs = [inputs] + else: + inputs = list(inputs) + return FlowGraph(outputs, inputs).update_nodes() + + +def save_graph(graph: FlowGraph, fname: str): + graph.save(fname) + + +def load_graph(fname: str) -> FlowGraph: + return FlowGraph.load(fname) diff --git a/python/hidet/tos/jit.py b/python/hidet/tos/jit.py new file mode 100644 index 0000000..c4c11e9 --- /dev/null +++ b/python/hidet/tos/jit.py @@ -0,0 +1,144 @@ +from typing import Optional, Callable, Dict, List, Union +import numpy as np +import os +import time +import functools +import inspect +import hidet +from hidet.tos import Tensor +from hidet.tos.ir.graph import FlowGraph +from hidet.tos.tensor import symbol_like +from hidet.ffi import cuda + + +def get_type_repr(value): + import numpy as np + from hidet.tos import Tensor + + if isinstance(value, (str, int, float)): + return str(type(value).__name__) + elif isinstance(value, list): + items = [get_type_repr(v) for v in value] + return '[{}]'.format(', '.join(items)) + elif isinstance(value, tuple): + items = [get_type_repr(v) for v in value] + return '({})'.format(', '.join(items)) + elif isinstance(value, dict): + for v in value.keys(): + if not isinstance(v, str): + raise TypeError('Only support str as dict key, got {}'.format(type(v))) + keys = list(v for v in value.keys()) + items = [get_type_repr(v) for v in value.values()] + return '{{{}}}'.format(', '.join('{}: {}'.format(k, v) for k, v in zip(keys, items))) + elif isinstance(value, Tensor): + shape_repr = ', '.join(str(v) for v in value.shape) + return '{}[{}]'.format(value.dtype, shape_repr) + elif isinstance(value, np.ndarray): + shape_repr = ', '.join(str(v) for v in value.shape) + return 'np.{}[{}]'.format(value.dtype, shape_repr) + else: + raise TypeError('Does not support type {} for jit.'.format(type(value))) + + +def get_bind_repr(bind: inspect.BoundArguments) -> str: + items = [] + for name, value in bind.arguments: + items += '{}: {}'.format(name, get_type_repr(value)) + return 'BindRepr({})'.format(', '.join(items)) + + +class JitGraph: + # todo: use inspect package to support more wide range input and outputs + def __init__( + self, + func: Callable, + opt: bool = False, + parallel_k: str = 'default', + save_ir_dir: Optional[str] = './outs', + mma: str = 'wmma_tf32_f32', + ): + self.func: Callable = func + self.cached_graph: Dict[str, FlowGraph] = {} + + self.parallel_k = parallel_k + self.opt = opt + self.save_ir_dir = os.path.join(save_ir_dir, func.__name__) + self.mma = mma + + def __str__(self): + items = [] + for args_repr, graph in self.cached_graph.items(): + items.extend([args_repr, ' => ', str(graph), '\n']) + return ''.join(items) + + @staticmethod + def args_representation(*args): + for arg in args: + if not isinstance(arg, Tensor): + raise NotImplementedError('Currently only support Tensor argument, got {}.'.format(type(arg))) + + args_repr = get_type_repr(args) + return args_repr + + def flow_graph_for(self, *args) -> FlowGraph: + args_repr = self.args_representation(*args) + + if args_repr not in self.cached_graph: + symbol_inputs = [symbol_like(arg) for arg in args] + symbol_outputs = self.func(*symbol_inputs) + graph = hidet.trace_from(symbol_outputs, inputs=symbol_inputs) + if self.opt: + with hidet.tos.PassContext() as ctx: + ctx.save_graph_instrument(self.save_ir_dir) + ctx.set_mma(self.mma) + if self.parallel_k == 'default': + ctx.set_parallel_k(default=True) + elif self.parallel_k == 'disabled': + ctx.set_parallel_k(disabled=True) + else: + ctx.set_parallel_k(nparts=int(self.parallel_k)) + graph = hidet.tos.optimize(graph) + self.cached_graph[args_repr] = graph + graph: FlowGraph = self.cached_graph[args_repr] + return graph + + def __call__(self, *args): + graph = self.flow_graph_for(*args) + return graph(*args) + + def benchmark(self, *args, warmup=10, number=10, repeat=10, median=True) -> Union[float, List[float]]: + graph = self.flow_graph_for(*args) + cuda_graph = graph.cuda_graph() + cuda_graph.set_input_tensors(args) + + results = [] + for i in range(warmup): + cuda_graph.run() + cuda.device_synchronize() + for i in range(repeat): + cuda.device_synchronize() + start_time = time.time() + for j in range(number): + cuda_graph.run() + cuda.device_synchronize() + end_time = time.time() + results.append((end_time - start_time) * 1000 / number) + + if median: + return float(np.median(results)) + else: + return results + + +def jit(opt=False, save_ir_dir='./outs', parallel_k='default', mma='simt'): + def decorator(func): + jit_graph = JitGraph( + func=func, + opt=opt, + parallel_k=parallel_k, + save_ir_dir=save_ir_dir, + mma=mma + ) + return jit_graph + + return decorator diff --git a/python/hidet/tos/module.py b/python/hidet/tos/module.py new file mode 100644 index 0000000..8ad2774 --- /dev/null +++ b/python/hidet/tos/module.py @@ -0,0 +1,62 @@ +from typing import Optional +from collections import OrderedDict +from hidet.tos.tensor import Tensor + + +class Module: + def __init__(self): + self.name = None + self.parameters: OrderedDict[str, Optional[Tensor]] = OrderedDict() + self.submodules: OrderedDict[str, Optional[Module]] = OrderedDict() + + def __setattr__(self, key, value): + parameters = self.__dict__.get('parameters') + submodules = self.__dict__.get('submodules') + if isinstance(value, Tensor): + value.name = key + self.parameters[key] = value + elif isinstance(value, Module): + value.name = '{}.{}'.format(self.name, key) if self.name else key + self.submodules[key] = value + elif parameters and submodules and value is None and (key in parameters or key in submodules): + if key in self.parameters: + self.parameters[key] = value + if key in self.submodules: + self.submodules[key] = value + else: + super().__setattr__(key, value) + cnt = sum([1 for collection in [parameters, submodules, self.__dict__] if collection and key in collection]) + assert cnt <= 1, 'duplicated definition of {}'.format(key) + + def __getattr__(self, item): + if item in self.parameters: + return self.parameters[item] + if item in self.submodules: + return self.submodules[item] + raise AttributeError(item) + + def __str__(self): + lines = [] + args_lines = self.extra_str().split('\n') + lines.extend([line for line in args_lines if len(line) > 0]) + for key, submodule in self.submodules.items(): + substr = str(submodule) + sub_lines = substr.split('\n') + sub_lines[0] = '({}): {}'.format(key, sub_lines[0]) + lines.extend(sub_lines) + indent = 2 + name = self.__class__.__name__ + if len(lines) <= 1: + return '{}({})'.format(name, '\n'.join(lines)) + else: + lines = [' ' * indent + line for line in lines] + return '{}(\n{}\n)'.format(name, '\n'.join(lines)) + + def __call__(self, *args): + return self.forward(*args) + + def extra_str(self) -> str: + return '' + + def forward(self, *args): + raise NotImplementedError() diff --git a/python/hidet/tos/modules/__init__.py b/python/hidet/tos/modules/__init__.py new file mode 100644 index 0000000..d972f26 --- /dev/null +++ b/python/hidet/tos/modules/__init__.py @@ -0,0 +1,2 @@ +from . import container +from . import nn diff --git a/python/hidet/tos/modules/container.py b/python/hidet/tos/modules/container.py new file mode 100644 index 0000000..110861e --- /dev/null +++ b/python/hidet/tos/modules/container.py @@ -0,0 +1,30 @@ +from __future__ import annotations +from typing import Optional, Iterable +from collections import OrderedDict +from hidet.tos.module import Module + + +class Sequential(Module): + def __init__(self, *args): + super().__init__() + if len(args) == 1 and isinstance(args[0], OrderedDict): + for key, module in args[0].items(): + self.__setattr__(key, module) + else: + for idx, module in enumerate(args): + self.__setattr__(str(idx), module) + + def forward(self, x): + for module in self.submodules.values(): + x = module(x) + return x + + +class ModuleList(Module): + def __init__(self, modules: Iterable[Module] = None): + super().__init__() + for idx, module in enumerate(modules): + self.submodules[str(idx)] = module + + def forward(self, *args): + raise ValueError('Should not forward ModuleList.') diff --git a/python/hidet/tos/modules/nn.py b/python/hidet/tos/modules/nn.py new file mode 100644 index 0000000..3d1c65c --- /dev/null +++ b/python/hidet/tos/modules/nn.py @@ -0,0 +1,151 @@ +from typing import Optional, Union, List +import math +from hidet.tos import ops +from hidet.tos.common import normalize +from hidet.tos.module import Module, Tensor +from hidet.tos.tensor import randn, zeros, ones +from hidet.tos.modules.container import Sequential, ModuleList + + +class Conv2d(Module): + def __init__(self, in_channels, out_channels, kernel_size, padding=0, stride=1, groups=1): + super().__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.kernel = normalize(kernel_size) + self.padding = normalize(padding) + self.stride = normalize(stride) + self.groups = groups + self.weight = randn(shape=[out_channels, in_channels, *self.kernel], dtype='float32', stddev=1.0 / math.sqrt(out_channels)) + + def extra_str(self) -> str: + return 'in_channels={}, out_channels={}, kernel_size={}, stride={}, padding={}'.format(self.in_channels, self.out_channels, self.kernel, self.stride, self.padding) + + def forward(self, x): + x = ops.pad(x, ops.utils.normalize_padding(self.padding)) + return ops.conv2d(x, self.weight, self.stride, self.groups) + + +class BatchNorm2d(Module): + def __init__(self, num_features, eps=1e-5): + super().__init__() + self.eps = eps + self.running_mean = zeros(shape=[num_features]) + self.running_var = ones(shape=[num_features]) + + def extra_str(self) -> str: + return 'eps={}'.format(self.eps) + + def forward(self, x: Tensor): + return ops.batch_norm_infer(x, self.running_mean, self.running_var, self.eps) + + +class Linear(Module): + def __init__(self, in_features, out_features, bias: bool = True): + super().__init__() + self.in_features = in_features + self.out_features = out_features + self.weight = randn(shape=[in_features, out_features], stddev=1.0 / math.sqrt(in_features)) + if bias: + self.bias = zeros(shape=[out_features]) + else: + self.bias = None + + def extra_str(self) -> str: + return 'in_features={}, out_features={}'.format(self.in_features, self.out_features) + + def forward(self, x: Tensor) -> Tensor: + return ops.matmul(x, self.weight) + self.bias + + +class Relu(Module): + def forward(self, x): + return ops.relu(x) + + +class MaxPool2d(Module): + def __init__(self, kernel_size, stride=1, padding=0): + super().__init__() + self.kernel = kernel_size + self.stride = stride + self.padding = padding + + def extra_str(self) -> str: + return 'kernel_size={}, stride={}, padding={}'.format(self.kernel, self.stride, self.padding) + + def forward(self, x): + return ops.max_pool2d(x, self.kernel, self.stride, self.padding) + + +class AvgPool2d(Module): + def __init__(self, kernel_size, stride, padding): + super().__init__() + self.kernel = kernel_size + self.stride = stride + self.padding = padding + + def extra_str(self) -> str: + return 'kernel_size={}, stride={}, padding={}'.format(self.kernel, self.stride, self.padding) + + def forward(self, x): + return ops.avg_pool2d(x, self.kernel, self.stride, self.padding) + + +class AdaptiveAvgPool2d(Module): + def __init__(self, output_size): + super().__init__() + self.output_size = normalize(output_size) + assert tuple(self.output_size) == (1, 1), 'current only support this' + + def extra_str(self) -> str: + return 'output_size={}'.format(self.output_size) + + def forward(self, x: Tensor) -> Tensor: + n, c, h, w = x.shape + return ops.avg_pool2d(x, kernel=(h, w), stride=(1, 1), padding=(0, 0)) + + +class Embedding(Module): + def __init__(self, num_embeddings: int, embedding_dim: int): + super().__init__() + self.num_embeddings = num_embeddings + self.embedding_dim = embedding_dim + self.weight = randn(shape=[num_embeddings, embedding_dim], dtype='float32', mean=0.0, stddev=1.0) + + def forward(self, indices: Tensor) -> Tensor: + return ops.take(self.weight, indices, axis=0) + + +class LayerNorm(Module): + def __init__(self, normalized_shape: Union[int, List[int]], eps: float = 1e-5, elementwise_affine: bool = True): + super().__init__() + if isinstance(normalized_shape, int): + normalized_shape = (normalized_shape,) + self.normalized_shape = tuple(normalized_shape) + self.eps = eps + self.elementwise_affine = elementwise_affine + if elementwise_affine: + self.weight = ones(normalized_shape) + self.bias = zeros(normalized_shape) + else: + self.weight = None + self.bias = None + + def forward(self, x: Tensor) -> Tensor: + x = ops.layer_norm(x) + if self.weight: + x = x * self.weight + if self.bias: + x = x + self.bias + return x + + +class Gelu(Module): + def forward(self, x): + return x * (ops.erf(x * (1.0 / 1.4142135381698608)) + 1.0) * 0.5 + + +class Tanh(Module): + def forward(self, x): + return ops.tanh(x) + diff --git a/python/hidet/tos/operator.py b/python/hidet/tos/operator.py new file mode 100644 index 0000000..5776c4a --- /dev/null +++ b/python/hidet/tos/operator.py @@ -0,0 +1,147 @@ +from typing import List, Optional, Dict, Any, Iterable, Tuple, Union +from collections import defaultdict + +from hidet.ir.task import Task +from hidet.runtime import CompiledFunction +from hidet.driver import build_task +from hidet.tos.tensor import empty, empty_like, Tensor + + +def trim_op_ending(name: str): + return name[:-2] if name.endswith('Op') else name + + +class Operator: + _current_space_level = 0 + _use_cache = True + + _task_cache: Dict[int, Dict[str, CompiledFunction]] = defaultdict(dict) + + def __init__( + self, + inputs: List[Tensor], + task: Optional[Task], + outputs: Optional[List[Tensor]] = None, + name: Optional[str] = None, + attributes: Optional[Dict[str, Any]] = None): + self.inputs: List[Tensor] = inputs + self.task: Optional[Task] = task + self.attrs: Dict[str, Any] = attributes if attributes is not None else {} + self.outputs: Optional[List[Tensor]] = outputs + self.name = name if name else trim_op_ending(self.__class__.__name__) + + assert all(isinstance(v, Tensor) for v in inputs) + + # cache + self.task_func: Optional[CompiledFunction] = None + + def __str__(self): + arguments = ['{}: {}{}'.format(i, t.dtype, t.shape) for i, t in enumerate(self.inputs)] + attributes = ['{}={}'.format(name, str(value)) for name, value in self.attrs.items()] + return '{}({})'.format(self.name, ', '.join(arguments + attributes)) + + def __dir__(self) -> Iterable[str]: + return ['task', 'inputs', 'outputs', 'attributes', 'name'] + list(self.attrs) + + def run(self) -> List[Tensor]: + if all(t.storage is not None for t in self.inputs): + return self.imperative_run(self.inputs) + else: + self.outputs = self.lazy_run() + return self.outputs + + def get_output(self, idx: int) -> Tensor: + if self.outputs is None: + outputs = self.run() + else: + outputs = self.outputs + return outputs[idx] + + def imperative_run(self, inputs: List[Tensor]) -> List[Tensor]: + if self.task_func is None: + task_string = str(self.task) + level = self._current_space_level + if task_string in self._task_cache[level]: + self.task_func = self._task_cache[level][task_string] + else: + self.task_func = build_task(self.task, space_level=self._current_space_level, use_cache=self._use_cache) + self._task_cache[level][task_string] = self.task_func + assert len(inputs) + len(self.task.outputs) == len(self.task.parameters) + output_types = [output.data_type for output in self.task.parameters[-len(self.task.outputs):]] + outputs = [empty(shape=type.const_shape(), dtype=type.scalar_type.name, device='cuda', layout=type.layout) for type in output_types] + self.task_func(*inputs, *outputs) + return outputs + + def lazy_run(self) -> List[Tensor]: + output_types = [output.data_type for output in self.task.parameters[-len(self.task.outputs):]] + outputs = [Tensor(shape=type.const_shape(), dtype=type.scalar_type.name, device='cuda', storage=None, layout=type.layout, trace=(self, i)) for i, type in enumerate(output_types)] + return outputs + + def reforward(self, inputs: List[Tensor], update_attributes: Optional[Dict[str, Any]] = None) -> List[Tensor]: + cls = self.__class__ + if not isinstance(self, Operator) or cls is Operator: + raise ValueError('Can only reforward operator whose class is a proper class of Operator. Please use .clone') + attributes = self.attrs.copy() + if update_attributes is not None: + attributes.update(update_attributes) + return cls(*inputs, **attributes).run() + + def clone(self, inputs: List[Tensor], update_attributes: Optional[Dict[str, Any]] = None) -> List[Tensor]: + cls = self.__class__ + attributes = self.attrs.copy() + if update_attributes is not None: + attributes.update(update_attributes) + + new_op = cls.__new__(cls) + new_op.name = self.name + new_op.inputs = inputs + new_op.task = self.task + new_op.attrs = attributes + new_op.outputs = new_op.run() + new_op.task_func = None + return new_op.outputs + + def latency(self, warmup=3, number=20, repeat=5, median=True) -> Union[List[float], float]: + from hidet.ffi import cuda + from time import time + import numpy as np + dummy_inputs = [] + for x in self.inputs: + if x.storage is not None: + dummy_inputs.append(x) + else: + if x.dtype in ['float32', 'float16', 'bfloat16']: + dummy_inputs.append(empty_like(x)) + else: + raise ValueError('Can not generate dummpy input for dtype {}'.format(x.dtype)) + output_types = [output.data_type for output in self.task.parameters[-len(self.task.outputs):]] + outputs = [empty(shape=type.const_shape(), dtype=type.scalar_type.name, device='cuda', layout=type.layout) for type in output_types] + + self.imperative_run(dummy_inputs) + for t in range(warmup): + self.task_func(*dummy_inputs, *outputs) + cuda.device_synchronize() + results = [] + for i in range(repeat): + cuda.device_synchronize() + t1 = time() + for j in range(number): + self.task_func(*dummy_inputs, *outputs) + cuda.device_synchronize() + t2 = time() + results.append((t2 - t1) / number) + if median: + return float(np.median(results)) + return results + + +def space_level(level=0): + Operator._current_space_level = level + + +def get_space_level() -> int: + return Operator._current_space_level + + +def cache_operator(use_cache=True): + Operator._use_cache = use_cache diff --git a/python/hidet/tos/ops/__init__.py b/python/hidet/tos/ops/__init__.py new file mode 100644 index 0000000..cc1b3b3 --- /dev/null +++ b/python/hidet/tos/ops/__init__.py @@ -0,0 +1,17 @@ +from . import definitions + +from .definitions.conv2d import conv2d, conv2d_winograd, conv2d_gemm, conv2d_gemm_image_transform +from .definitions.matmul import matmul, parallel_k_batched_matmul +from .definitions.pool import max_pool2d, avg_pool2d +from .definitions.softmax import softmax +from .definitions.activation import relu, sigmoid, clip, relu6 +from .definitions.norm import batch_norm_infer, instance_norm, layer_norm +from .definitions.image import resize2d +from .definitions.arithmatic import add, sub, multiply, divide, neg, sqrt, rsqrt, sin, cos, pow, erf, tanh, equal, less, where, square +from .definitions.reduce import reduce_mean, reduce_sum, reduce_var +from .definitions.transform import squeeze, unsqueeze, flatten, concat, cast, take, rearrange, strided_slice, reshape, transpose, broadcast, pad, tile, split, conv_pad +from .definitions.special import barrier + +from .definitions import utils + +from . import schedules diff --git a/python/hidet/tos/ops/definitions/__init__.py b/python/hidet/tos/ops/definitions/__init__.py new file mode 100644 index 0000000..483e21f --- /dev/null +++ b/python/hidet/tos/ops/definitions/__init__.py @@ -0,0 +1,22 @@ +from .conv2d import conv2d, conv2d_winograd, conv2d_gemm +from .conv2d import conv2d_gemm_image_transform, conv2d_gemm_filter_transform, conv2d_gemm_inverse_transform +from .conv2d import conv2d_winograd_image_transform, conv2d_winograd_filter_transform, conv2d_winograd_inverse_transform + +from .matmul import matmul +from .pool import max_pool2d, avg_pool2d +from .softmax import softmax +from .activation import relu, sigmoid, relu6, clip +from .norm import batch_norm_infer, instance_norm +from .image import resize2d +from .arithmatic import add, sub, multiply, divide, neg, sqrt, rsqrt, equal, less, where +from .reduce import reduce_mean +from .transform import squeeze, unsqueeze, flatten, concat, cast, take, rearrange, strided_slice, split +from .special import barrier + +from .matmul import MatmulOp +from .conv2d import Conv2dOp +from .arithmatic import ErfOp, PowOp, AddOp, SubOp, MultiplyOp, DivideOp, EqualOp, WhereOp +from .reduce import ReduceSumOp, ReduceMeanOp +from .transform import PadOp + +from . import utils diff --git a/python/hidet/tos/ops/definitions/activation.py b/python/hidet/tos/ops/definitions/activation.py new file mode 100644 index 0000000..6f8bbdc --- /dev/null +++ b/python/hidet/tos/ops/definitions/activation.py @@ -0,0 +1,54 @@ +from typing import Optional +import math +from hidet.ir import primitives as prim, convert +from hidet.ir.expr import const_like + +from .utils import Tensor +from .arithmatic import UnaryElementwiseOp, erf, tanh, cube + + +class ReluOp(UnaryElementwiseOp): + def __init__(self, x): + super().__init__(x, op=lambda v: prim.max(v, const_like(0.0, v)), name='relu') + + +class SigmoidOp(UnaryElementwiseOp): + def __init__(self, x): + super().__init__(x, op=lambda v: const_like(1.0, v) / (const_like(1.0, v) + prim.exp(-v)), name='sigmoid') + + +class ClipOp(UnaryElementwiseOp): + def __init__(self, x, min_val: Optional[float] = None, max_val: Optional[float] = None): + def op(v): + if min_val is not None: + v = prim.max(v, const_like(min_val, v)) + if max_val is not None: + v = prim.min(v, const_like(max_val, v)) + return v + + super().__init__(x, op=op, name='clip') + + +class GeluOp(UnaryElementwiseOp): + def __init__(self, x): + super().__init__(x, op=lambda v: const_like(0.5, v) * v * (const_like(1.0, v) + prim.erf(v * const_like(1 / math.sqrt(2), v))), name='gelu') + + +def relu(x) -> Tensor: + return ReluOp(x).get_output(0) + + +def sigmoid(x: Tensor) -> Tensor: + return SigmoidOp(x).get_output(0) + + +def clip(x: Tensor, min_val: Optional[float], max_val: Optional[float]) -> Tensor: + return ClipOp(x, min_val, max_val).get_output(0) + + +def relu6(x: Tensor) -> Tensor: + return clip(x, 0.0, 6.0) + + +def gelu(x: Tensor) -> Tensor: + return GeluOp(x).get_output(0) diff --git a/python/hidet/tos/ops/definitions/arithmatic.py b/python/hidet/tos/ops/definitions/arithmatic.py new file mode 100644 index 0000000..c937002 --- /dev/null +++ b/python/hidet/tos/ops/definitions/arithmatic.py @@ -0,0 +1,360 @@ +from typing import List, Callable, Any, Union, Type, Optional, Dict + +import operator +from hidet.ir import primitives +from hidet.ir import expr +from hidet.ir.expr import const_like +from hidet.utils import prod +from .utils import Task, Operator, Tensor, TensorNode, InverseMap, compute, input_like +from hidet.tos.tensor import convert + + +def broadcast_shape(x_shape: List[int], y_shape: List[int]) -> List[int]: + """ + Broadcast two shapes with the same rule as numpy. + Please refer to https://numpy.org/doc/stable/user/basics.broadcasting.html for details. + """ + orig_shapes = x_shape, y_shape + while len(x_shape) < len(y_shape): + x_shape = [1] + x_shape + while len(y_shape) < len(x_shape): + y_shape = [1] + y_shape + result_shape = [] + for p, q in zip(x_shape, y_shape): + if p != q and p != 1 and q != 1: + raise ValueError('can not broadcast two arrays with shape {} and {}'.format(orig_shapes[0], orig_shapes[1])) + result_shape.append(max(p, q)) + return result_shape + + +class UnaryElementwiseTask(Task): + def __init__(self, name: str, x: TensorNode, op: Callable[[Any], Any]): + shape = x.const_shape() + y = compute( + name='y', + shape=shape, + fcompute=lambda *indices: op(x.__getitem__(indices)), + scope='global' + ) + super().__init__( + name=name, + inputs=[x], + outputs=[y], + inverse_map={ + x: InverseMap.from_lambda(lambda *indices: list(indices), num_args=len(x.data_type.shape)) + } + ) + + +def broadcast_indices(indices, shape, out_shape): + # used to support broadcast + pad_dim = len(out_shape) - len(shape) + indices = list(indices[pad_dim:]) + for idx, dim in enumerate(shape): + if int(dim) == 1: + indices[idx] = 0 + return indices + + +class BinaryElementwiseTask(Task): + def __init__(self, name: str, x: TensorNode, y: TensorNode, op: Callable[[Any, Any], Any]): + x_shape = x.const_shape() + y_shape = y.const_shape() + z_shape = broadcast_shape(x_shape, y_shape) + + z = compute( + name='z', + shape=z_shape, + fcompute=lambda *indices: op(x[broadcast_indices(indices, x_shape, z_shape)], y[broadcast_indices(indices, y_shape, z_shape)]), + scope='global' + ) + + super().__init__( + name=name, + inputs=[x, y], + outputs=[z], + inverse_map={v: InverseMap.identity(len(v_shape)) for v, v_shape + in zip([x, y], [x_shape, y_shape]) if prod(v_shape) == prod(z_shape)} + ) + + +class WhereTask(Task): + def __init__(self, cond: TensorNode, x: TensorNode, y: TensorNode): + cond_shape = cond.const_shape() + x_shape = x.const_shape() + y_shape = y.const_shape() + z_shape = broadcast_shape(cond_shape, broadcast_shape(x_shape, y_shape)) + + z = compute( + name='z', + shape=z_shape, + fcompute=lambda *indices: expr.if_then_else( + cond=cond[broadcast_indices(indices, cond_shape, z_shape)], + then_expr=x[broadcast_indices(indices, x_shape, z_shape)], + else_expr=y[broadcast_indices(indices, y_shape, z_shape)] + ) + ) + + super().__init__( + name='where', + inputs=[cond, x, y], + outputs=[z], + inverse_map={v: InverseMap.identity(len(v_shape)) for v, v_shape + in zip([cond, x, y], [cond_shape, x_shape, y_shape]) if prod(v_shape) == prod(z_shape)} + ) + + +class UnaryElementwiseOp(Operator): + def __init__(self, x: Tensor, op, name: str, attributes: Optional[Dict[str, Any]] = None): + super().__init__( + inputs=[x], + task=UnaryElementwiseTask(name, input_like(x, 'x'), op=op), + attributes=attributes + ) + + +class BinaryElementwiseOp(Operator): + def __init__(self, x: Tensor, y: Tensor, op, name: str): + super().__init__( + inputs=[x, y], + task=BinaryElementwiseTask(name, input_like(x, 'x'), input_like(y, 'y'), op=op) + ) + + +class AddScalarOp(UnaryElementwiseOp): + def __init__(self, x: Tensor, scalar: Union[float, int]): + super().__init__(x, op=lambda v: v + const_like(scalar, v), attributes={'scalar': scalar}, name='adds') + + +class SubScalarOp(UnaryElementwiseOp): + def __init__(self, x: Tensor, scalar: Union[float, int]): + super().__init__(x, op=lambda v: v - const_like(scalar, v), attributes={'scalar': scalar}, name='subs') + + +class RSubScalarOp(UnaryElementwiseOp): + def __init__(self, x: Tensor, scalar: Union[float, int]): + super().__init__(x, op=lambda v: const_like(scalar, v) - v, attributes={'scalar': scalar}, name='rsubs') + + +class MultiplyScalarOp(UnaryElementwiseOp): + def __init__(self, x: Tensor, scalar: Union[float, int]): + super().__init__(x, op=lambda v: v * const_like(scalar, v), attributes={'scalar': scalar}, name='muls') + + +class DivideScalarOp(UnaryElementwiseOp): + def __init__(self, x: Tensor, scalar: Union[float, int]): + super().__init__(x, op=lambda v: v / const_like(scalar, v), attributes={'scalar': scalar}, name='divs') + + +class RDivideScalarOp(UnaryElementwiseOp): + def __init__(self, x: Tensor, scalar: Union[float, int]): + super().__init__(x, op=lambda v: const_like(scalar, v) / v, attributes={'scalar': scalar}, name='rdivs') + + +class SqrtOp(UnaryElementwiseOp): + def __init__(self, x): + super().__init__(x, op=lambda v: primitives.sqrt(v), name='sqrt') + + +class ErfOp(UnaryElementwiseOp): + def __init__(self, x): + super().__init__(x, op=lambda v: primitives.erf(v), name='erf') + + +class TanhOp(UnaryElementwiseOp): + def __init__(self, x): + super().__init__(x, op=lambda v: primitives.tanh(v), name='erf') + + +class RsqrtOp(UnaryElementwiseOp): + def __init__(self, x): + super().__init__(x, op=lambda v: primitives.rsqrt(v), name='rsqrt') + + +class PowOp(BinaryElementwiseOp): + def __init__(self, x, y): + super().__init__(x, y, op=lambda x, y: primitives.pow(x, y), name='pow') + + +class NegOp(UnaryElementwiseOp): + def __init__(self, x): + super().__init__(x, op=lambda v: -v, name='neg') + + +class AddOp(BinaryElementwiseOp): + def __init__(self, x: Tensor, y: Tensor): + super().__init__(x, y, op=lambda a, b: a + b, name='add') + + +class SubOp(BinaryElementwiseOp): + def __init__(self, x: Tensor, y: Tensor): + super().__init__(x, y, op=lambda a, b: a - b, name='sub') + + +class MultiplyOp(BinaryElementwiseOp): + def __init__(self, x: Tensor, y: Tensor): + super().__init__(x, y, op=lambda a, b: a * b, name='mul') + + +class DivideOp(BinaryElementwiseOp): + def __init__(self, x: Tensor, y: Tensor): + super().__init__(x, y, op=lambda a, b: a / b, name='div') + + +class SinOp(UnaryElementwiseOp): + def __init__(self, x: Tensor): + super().__init__(x, op=lambda a: primitives.sin(a), name='sin') + + +class CosOp(UnaryElementwiseOp): + def __init__(self, x: Tensor): + super().__init__(x, op=lambda a: primitives.cos(a), name='cos') + + +class SquareOp(UnaryElementwiseOp): + def __init__(self, x: Tensor): + super().__init__(x, op=lambda a: a * a, name='square') + + +class CubeOp(UnaryElementwiseOp): + def __init__(self, x: Tensor): + super().__init__(x, op=lambda a: a * a * a, name='cube') + + +class EqualOp(BinaryElementwiseOp): + def __init__(self, x: Tensor, y: Tensor): + super().__init__(x, y, lambda a, b: expr.Equal(a, b), name='equal') + + +class LessOp(BinaryElementwiseOp): + def __init__(self, x: Tensor, y: Tensor): + super().__init__(x, y, lambda a, b: a < b, name='less') + + +class WhereOp(Operator): + def __init__(self, cond: Tensor, x: Tensor, y: Tensor): + super().__init__( + inputs=[cond, x, y], + task=WhereTask(input_like(cond, 'cond'), input_like(x, 'x'), input_like(y, 'y')), + name='where' + ) + + +PythonScalar = Union[float, int] + + +def binary_arithmatic( + x: Union[Tensor, float, int], + y: Union[Tensor, float, int], + tensor_scalar_op, + scalar_tensor_op, + tensor_tensor_op +) -> Union[Tensor, float, int]: + if not (isinstance(x, (Tensor, float, int)) and isinstance(y, (Tensor, float, int))): + raise ValueError('Only support add/sub/mul/div between hidet.Tensor, float, and int. got {} and {}'.format(type(x), type(y))) + if isinstance(x, (float, int)): + x = convert(x) + if isinstance(y, (float, int)): + y = convert(y) + x_scalar = len(x.shape) == 0 and x.storage is not None + y_scalar = len(y.shape) == 0 and y.storage is not None + if x_scalar and y_scalar: + return tensor_tensor_op(x, y) + elif y_scalar: + return tensor_scalar_op(x, y.scalar()) + elif x_scalar: + return scalar_tensor_op(x.scalar(), y) + else: + return tensor_tensor_op(x, y) + + +def add(x: Union[Tensor, float, int], y: Union[Tensor, float, int]) -> Tensor: + return binary_arithmatic( + x, y, + lambda a, b: AddScalarOp(a, b).get_output(0), + lambda a, b: AddScalarOp(b, a).get_output(0), + lambda a, b: AddOp(a, b).get_output(0) + ) + + +def sub(x: Union[Tensor, float, int], y: Union[Tensor, float, int]) -> Tensor: + return binary_arithmatic( + x, y, + lambda a, b: SubScalarOp(a, b).get_output(0), + lambda a, b: RSubScalarOp(b, a).get_output(0), + lambda a, b: SubOp(a, b).get_output(0) + ) + + +def multiply(x: Union[Tensor, float, int], y: Union[Tensor, float, int]) -> Tensor: + return binary_arithmatic( + x, y, + lambda a, b: MultiplyScalarOp(a, b).get_output(0), + lambda a, b: MultiplyScalarOp(b, a).get_output(0), + lambda a, b: MultiplyOp(a, b).get_output(0) + ) + + +def divide(x: Union[Tensor, float, int], y: Union[Tensor, float, int]) -> Tensor: + return binary_arithmatic( + x, y, + lambda a, b: DivideScalarOp(a, b).get_output(0), + lambda a, b: RDivideScalarOp(b, a).get_output(0), + lambda a, b: DivideOp(a, b).get_output(0) + ) + + +def sqrt(x: Tensor) -> Tensor: + return SqrtOp(x).get_output(0) + + +def tanh(x: Tensor) -> Tensor: + return TanhOp(x).get_output(0) + + +def pow(x: Tensor, y: Tensor) -> Tensor: + return PowOp(x, y).get_output(0) + + +def erf(x: Tensor) -> Tensor: + return ErfOp(x).get_output(0) + + +def rsqrt(x: Tensor) -> Tensor: + return RsqrtOp(x).get_output(0) + + +def neg(x: Tensor) -> Tensor: + return NegOp(x).get_output(0) + + +def sin(x: Tensor) -> Tensor: + return SinOp(x).get_output(0) + + +def cos(x: Tensor) -> Tensor: + return CosOp(x).get_output(0) + + +def square(x: Tensor) -> Tensor: + return SquareOp(x).get_output(0) + + +def cube(x: Tensor) -> Tensor: + return CubeOp(x).get_output(0) + + +def equal(x: Tensor, y: Tensor) -> Tensor: + if x.dtype != y.dtype: + raise ValueError('Can only compare tensors with the same dtype, but got {} and {}'.format(x.dtype, y.dtype)) + return EqualOp(x, y).get_output(0) + + +def less(x: Tensor, y: Tensor) -> Tensor: + return LessOp(x, y).get_output(0) + + +def where(cond: Tensor, x: Tensor, y: Tensor) -> Tensor: + if cond.dtype != 'bool': + raise ValueError('The condition tensor must have dtype "bool", but got {}'.format(cond.dtype)) + return WhereOp(cond, x, y).get_output(0) diff --git a/python/hidet/tos/ops/definitions/conv2d/__init__.py b/python/hidet/tos/ops/definitions/conv2d/__init__.py new file mode 100644 index 0000000..737923a --- /dev/null +++ b/python/hidet/tos/ops/definitions/conv2d/__init__.py @@ -0,0 +1,7 @@ +from .conv2d import conv2d +from .conv2d import Conv2dOp +from .conv2d_winograd import conv2d_winograd, conv2d_winograd_image_transform, conv2d_winograd_filter_transform, conv2d_winograd_inverse_transform +from .conv2d_winograd import Conv2dWinogradInverseTransformOp, Conv2dWinogradFilterTransformOp, Conv2dWinogradImageTransformOp +from .conv2d_gemm import conv2d_gemm, conv2d_gemm_image_transform, conv2d_gemm_filter_transform, conv2d_gemm_inverse_transform +from .conv2d_gemm import Conv2dGemmImageTransformOp + diff --git a/python/hidet/tos/ops/definitions/conv2d/conv2d.py b/python/hidet/tos/ops/definitions/conv2d/conv2d.py new file mode 100644 index 0000000..c8a277c --- /dev/null +++ b/python/hidet/tos/ops/definitions/conv2d/conv2d.py @@ -0,0 +1,48 @@ +from typing import List, Union +from hidet.tos.ops.definitions.utils import Task, Operator, Tensor, compute, input_like, TensorNode, normalize_kernel, normalize_stride, normalize_padding, reduce + + +class Conv2dTask(Task): + def __init__(self, data: TensorNode, weight: TensorNode, stride: List[int], groups: int): + n, c, h, w = data.const_shape() + oc, wc, kx, ky = weight.const_shape() + sx, sy = stride + p, q = (h - kx) // sx + 1, (w - ky) // sy + 1 + if c % groups != 0 or oc % groups != 0: + raise ValueError('Conv2d expect the in_channels % groups == 0 and out_channels % groups == 0, \n' + 'but got in_channels, out_channels, groups: {}, {}, {}'.format(c, oc, groups)) + if wc * groups != c: + raise ValueError('Conv2d expect the weight has shape [out_channels, in_channels / groups, kx, ky], \n' + 'but got weight shape {}, in_channels {} and groups {}'.format([oc, wc, kx, ky], c, groups)) + out_group_size = oc // groups + output = compute( + name='out', + shape=[n, oc, p, q], + fcompute=lambda ni, oci, pi, qi: reduce( + shape=[wc, kx, ky], + fcompute=lambda wci, kxi, kyi: data[ni, (oci // out_group_size) * wc + wci, pi * sx + kxi, qi * sy + kyi] * weight[oci, wci, kxi, kyi], + reduce_type='sum' + ) + ) + super().__init__( + name='conv2d', + inputs=[data, weight], + outputs=[output], + ) + + +class Conv2dOp(Operator): + def __init__(self, x: Tensor, w: Tensor, stride: List[int], groups: int): + stride = normalize_stride(stride) + super().__init__( + inputs=[x, w], + task=Conv2dTask(input_like(x, 'x'), input_like(w, 'w'), stride, groups), + attributes={ + 'stride': stride, + 'groups': groups + } + ) + + +def conv2d(data: Tensor, weight: Tensor, stride: Union[int, List[int]], groups: int = 1) -> Tensor: + return Conv2dOp(data, weight, stride, groups).get_output(0) diff --git a/python/hidet/tos/ops/definitions/conv2d/conv2d_gemm.py b/python/hidet/tos/ops/definitions/conv2d/conv2d_gemm.py new file mode 100644 index 0000000..cd935c8 --- /dev/null +++ b/python/hidet/tos/ops/definitions/conv2d/conv2d_gemm.py @@ -0,0 +1,149 @@ +from typing import List + +from hidet.tos.ops.definitions.matmul.matmul import matmul +from hidet.tos.ops.definitions.utils import Task, Operator, Tensor, compute, input_like, TensorNode +from hidet.tos.ops.definitions.utils import normalize_kernel, normalize_stride +from .utils import infer_conv2d_shape + + +class Conv2dGemmImageTransformTask(Task): + def __init__(self, x: TensorNode, kernel: List[int], stride: List[int], groups: int): + n, c, h, w = x.const_shape() + kx, ky = kernel + sx, sy = stride + p, q = (h - kx) // sx + 1, (w - ky) // sy + 1 + if c % groups != 0: + raise ValueError('Conv2d expect in_channels % groups == 0, but got in_channels {} and groups {}'.format(c, groups)) + gc = c // groups # group channels + gemm_x = compute( + name='gemm_x', + shape=[groups, n * p * q, gc * kx * ky], + fcompute=lambda g, i, k: x[i // (p * q), g * gc + k // (kx * ky), i // q % p * sx + k // ky % kx, i % q * sy + k % ky], + scope=x.data_type.scope + ) + super().__init__( + name='conv2d_gemm_image_transform', + inputs=[x], + outputs=[gemm_x], + ) + + +# class Conv2dGemmFilterTransformTask(Task): +# def __init__(self, w: TensorNode, groups: int): +# oc, c, kx, ky = w.const_shape() +# if oc % groups != 0 or c % groups != 0: +# raise ValueError('Conv2d expects in_channels % groups == 0, out_channels % groups == 0, got {}, {}, {}'.format(c, oc, groups)) +# ogc = oc // groups # out group channels +# gemm_w = compute( +# name='gemm_w', +# shape=[groups, c * kx * ky, ogc], +# fcompute=lambda g, k, j: w[g * ogc + j, k // (kx * ky), k // ky % kx, k % ky], +# scope=w.data_type.scope +# ) +# super().__init__( +# name='conv2d_gemm_filter_transform', +# inputs=[w], +# outputs=[gemm_w] +# ) +# +# +# class Conv2dGemmInverseTransformTask(Task): +# def __init__(self, gemm_y: TensorNode, out_shape: List[int]): +# n, oc, p, q = out_shape +# y_shape = gemm_y.const_shape() # [groups, n * p * q, ogc] +# groups = y_shape[0] +# +# assert y_shape[-1] * y_shape[0] == oc +# ogc = oc // groups +# +# if tuple(y_shape) != (groups, n * p * q, ogc): +# raise ValueError('Conv2d gemm inverse transform expect input with shape {}, got {}'.format( +# (groups, n * p * q, ogc), gemm_y.const_shape())) +# +# y = compute( +# name='y', +# shape=[n, oc, p, q], +# fcompute=lambda i, j, r, s: gemm_y[j // ogc, i * (p * q) + r * q + s, j % ogc], +# scope=gemm_y.data_type.scope +# ) +# super().__init__( +# name='conv2d_gemm_inverse_transform', +# inputs=[gemm_y], +# outputs=[y], +# inverse_map={ +# gemm_y: lambda i, j: [i // (p * q), j, i // q % p, i % q] +# } +# ) + + +class Conv2dGemmImageTransformOp(Operator): + def __init__(self, x: Tensor, kernel, stride, groups): + kernel = normalize_kernel(kernel) + stride = normalize_stride(stride) + super().__init__( + inputs=[x], + task=Conv2dGemmImageTransformTask(input_like(x, 'x'), kernel, stride, groups), + attributes={ + 'kernel': kernel, + 'stride': stride + } + + ) + + +# +# class Conv2dGemmFilterTransformOp(Operator): +# def __init__(self, w: Tensor, groups): +# super().__init__( +# inputs=[w], +# task=Conv2dGemmFilterTransformTask(input_like(w, 'w'), groups) +# ) +# +# +# class Conv2dGemmInverseTransformOp(Operator): +# def __init__(self, gemm_y: Tensor, out_shape: List[int]): +# if len(out_shape) != 4: +# raise ValueError('Output shape expect with length 4, got {}'.format(out_shape)) +# super().__init__( +# inputs=[gemm_y], +# task=Conv2dGemmInverseTransformTask(input_like(gemm_y, 'gemm_y'), out_shape) +# ) +# + + +def conv2d_gemm_image_transform(x: Tensor, kernel: List[int], stride: List[int], groups: int = 1) -> Tensor: + return Conv2dGemmImageTransformOp(x, kernel, stride, groups).get_output(0) + + +def conv2d_gemm_filter_transform(w: Tensor, groups: int = 1) -> Tensor: + # weight shape: [oc, c, kx, ky] + # output shape: [groups, c * kx * ky, ogc] where ogc = oc // groups + oc, c, kx, ky = w.shape + if oc % groups != 0: + raise ValueError('invalid conv2d groups {} for out channels {}'.format(groups, oc)) + ogc = oc // groups + w = w.reshape([groups, ogc, c, kx, ky]) # [groups, ogc, c, kx, ky] + w = w.rearrange([[0], [2, 3, 4], [1]]) # [groups, c * kx * ky, ogc] + return w + + +def conv2d_gemm_inverse_transform(gemm_y: Tensor, out_height, out_width) -> Tensor: + # gemm_y shape: [groups, n * p * q, ogc] + # output shape: [n, oc, p, q] where oc = groups * ogc + p, q = out_height, out_width + groups, npq, ogc = gemm_y.shape + assert npq % (p * q) == 0 + n = npq // (p * q) + y = gemm_y.reshape([groups, n, p, q, ogc]) + y = y.rearrange([[1], [0, 4], [2], [3]]) + return y + + +def conv2d_gemm(data: Tensor, weight: Tensor, stride, groups: int = 1) -> Tensor: + gemm_x = conv2d_gemm_image_transform(data, kernel=weight.shape[2:], stride=stride, groups=groups) + gemm_w = conv2d_gemm_filter_transform(weight, groups=groups) + gemm_y = matmul(gemm_x, gemm_w) + + y_shape = infer_conv2d_shape(data.shape, weight.shape, stride, groups) + y = conv2d_gemm_inverse_transform(gemm_y, out_height=y_shape[2], out_width=y_shape[3]) + return y diff --git a/python/hidet/tos/ops/definitions/conv2d/conv2d_winograd.py b/python/hidet/tos/ops/definitions/conv2d/conv2d_winograd.py new file mode 100644 index 0000000..edf7e51 --- /dev/null +++ b/python/hidet/tos/ops/definitions/conv2d/conv2d_winograd.py @@ -0,0 +1,216 @@ +from functools import lru_cache +from typing import List, Tuple + +import numpy as np + +from hidet.ir.expr import const_tensor, Constant, cast +from hidet.tos.ops.definitions.matmul.matmul import matmul +from hidet.tos.ops.definitions.transform import flatten, reshape +from hidet.tos.ops.definitions.utils import Tensor, Operator, Task, TensorNode, input_like, compute, reduce, normalize_kernel + +""" +Winograd convolution, see https://arxiv.org/pdf/1509.09308.pdf +""" + + +@lru_cache(maxsize=32) +def winograd_transform_matrices(m: int, r: int) -> Tuple[Constant, Constant, Constant]: + if m == 2 and r == 3: + G = np.array( + [[1, 0, 0], + [1 / 2, 1 / 2, 1 / 2], + [1 / 2, -1 / 2, 1 / 2], + [0, 0, 1]] + ).astype(np.float32) + BT = np.array( + [[1, 0, -1, 0], + [0, 1, 1, 0], + [0, -1, 1, 0], + [0, 1, 0, -1]] + ).astype(np.float32) + AT = np.array( + [[1, 1, 1, 0], + [0, 1, -1, -1]] + ).astype(np.float32) + return const_tensor(G), const_tensor(BT), const_tensor(AT) + else: + raise NotImplementedError('winograd transform matrices: m = {}, r = {}'.format(m, r)) + + +class Conv2dWinogradImageTransformTask(Task): + def __init__(self, x: TensorNode, kernel: List[int], ms: List[int]): + assert len(kernel) == 2 and len(x.const_shape()) == 4 + n, c, h, w = x.const_shape() + rx, ry = kernel + mx, my = ms # output size per tile + oh, ow = h - rx + 1, w - ry + 1 # output size of image + nh, nw = (oh + mx - 1) // mx, (ow + my - 1) // my # number of tiles on each image dimension + p = n * nh * nw # number of tiles per channel + alpha_x, alpha_y = mx + rx - 1, my + ry - 1 + tile = compute( + name='tile', + shape=[c, p, alpha_x, alpha_y], + fcompute=lambda cc, pp, ax, ay: x[pp // (nh * nw), cc, (pp // nw) % nh * mx + ax, pp % nw * my + ay] + ) + BH = winograd_transform_matrices(mx, rx)[1] + BW = winograd_transform_matrices(my, ry)[1] + dtype = x.data_type.scalar_type + y = compute( + name='y', + shape=[alpha_x, alpha_y, c, p], + fcompute=lambda ax, ay, cc, pp: reduce( + shape=[alpha_x, alpha_y], + fcompute=lambda kx, ky: cast(BH[ax, kx], dtype) * tile[cc, pp, kx, ky] * cast(BW[ay, ky], dtype), + reduce_type='sum' + ) + ) + super().__init__( + name='conv2d_winograd_image_transform', + inputs=[x], + outputs=[y] + ) + + +class Conv2dWinogradFilterTransformTask(Task): + def __init__(self, w: TensorNode, ms: List[int]): + assert len(w.const_shape()) == 4 + oc, c, rx, ry = w.const_shape() + mx, my = ms + alpha_x, alpha_y = mx + rx - 1, my + ry - 1 + GH = winograd_transform_matrices(mx, rx)[0] + GW = winograd_transform_matrices(my, ry)[0] + dtype = w.data_type.scalar_type + y = compute( + name='y', + shape=[alpha_x, alpha_y, oc, c], + fcompute=lambda ax, ay, occ, cc: reduce( + shape=[rx, ry], + fcompute=lambda kx, ky: cast(GH[ax, kx], dtype) * w[occ, cc, kx, ky] * cast(GW[ay, ky], dtype), + reduce_type='sum' + ) + ) + super().__init__( + name='conv2d_winograd_filter_transform', + inputs=[w], + outputs=[y] + ) + + +class Conv2dWinogradInverseTransformTask(Task): + def __init__(self, y: TensorNode, input_shape, kernel, ms): + assert len(y.const_shape()) == 4 + alpha_x, alpha_y, oc, p = y.const_shape() + n, c, h, w = input_shape + rx, ry = kernel + mx, my = ms + oh, ow = h - rx + 1, w - ry + 1 # output size of image + nh, nw = (oh + mx - 1) // mx, (ow + my - 1) // my # number of tiles on each image dimension + AH = winograd_transform_matrices(mx, rx)[2] + AW = winograd_transform_matrices(my, ry)[2] + dtype = y.data_type.scalar_type + inverse = compute( + name='inverse', + shape=[mx, my, oc, p], + fcompute=lambda mxx, myy, occ, pp: reduce( + shape=[alpha_x, alpha_y], + fcompute=lambda kx, ky: cast(AH[mxx, kx], dtype) * y[kx, ky, occ, pp] * cast(AW[myy, ky], dtype), + reduce_type='sum' + ) + ) + output = compute( + name='output', + shape=[n, oc, oh, ow], + fcompute=lambda nn, occ, ohh, oww: inverse[ohh % mx, oww % my, occ, nn * (nh * nw) + (ohh // mx) * nw + (oww // my)], + ) + super().__init__( + name='conv2d_winograd_inverse_transform', + inputs=[y], + outputs=[output] + ) + + +class Conv2dWinogradImageTransformOp(Operator): + def __init__(self, x: Tensor, kernel, ms): + if len(x.shape) != 4: + raise NotImplementedError('Current only support winograd conv2d') + kernel = normalize_kernel(kernel, dim=2) + assert len(ms) == 2 + super().__init__( + inputs=[x], + task=Conv2dWinogradImageTransformTask(input_like(x, 'x'), kernel, ms), + attributes={ + 'kernel': kernel, + 'ms': ms + } + + ) + + +class Conv2dWinogradFilterTransformOp(Operator): + def __init__(self, w: Tensor, ms): + assert len(ms) == 2 + super().__init__( + inputs=[w], + task=Conv2dWinogradFilterTransformTask(input_like(w, 'w'), ms), + attributes={ + 'ms': ms + } + + ) + + +class Conv2dWinogradInverseTransformOp(Operator): + def __init__(self, y: Tensor, input_shape, kernel, ms): + kernel = normalize_kernel(kernel, dim=2) + super().__init__( + inputs=[y], + task=Conv2dWinogradInverseTransformTask(input_like(y, 'y'), input_shape, kernel, ms), + attributes={ + 'input_shape': input_shape, + 'kernel': kernel, + 'ms': ms + } + + ) + + +def conv2d_winograd_image_transform(x: Tensor, kernel, ms) -> Tensor: + return Conv2dWinogradImageTransformOp(x, kernel, ms).get_output(0) + + +def conv2d_winograd_filter_transform(w: Tensor, ms) -> Tensor: + return Conv2dWinogradFilterTransformOp(w, ms).get_output(0) + + +def conv2d_winograd_inverse_transform(y: Tensor, input_shape, kernel, ms) -> Tensor: + return Conv2dWinogradInverseTransformOp(y, input_shape, kernel, ms).get_output(0) + + +def conv2d_winograd(x: Tensor, w: Tensor) -> Tensor: + assert len(x.shape) == 4 and len(w.shape) == 4 and x.shape[1] == w.shape[1] + r2m = { + 1: 1, + 3: 2 + } + for k in w.shape[2:]: + if k not in r2m: + raise NotImplementedError('Winograd convolution for kernel size {} has not been supported yet.'.format(k)) + + input_shape = x.shape + kernel = w.shape[2:] + ms = [r2m[r] for r in kernel] + alpha = [r + m - 1 for r, m in zip(kernel, ms)] + + # winograd transform + x = conv2d_winograd_image_transform(x, kernel, ms) # [alpha_x, alpha_y, ci, p] + w = conv2d_winograd_filter_transform(w, ms) # [alpha_x, alpha_y, co, ci] + + # product + x = flatten(x, start_dim=0, end_dim=2) # [alpha_x * alpha_y, ci, p] + w = flatten(w, start_dim=0, end_dim=2) # [alpha_x * alpha_y, co, ci] + y = matmul(w, x) # [alpha_x * alpha_y, co, p] + y = reshape(y, [alpha[0], alpha[1], y.shape[1], y.shape[2]]) # [alpha_x, alpha_y, co, p] + + # winograd inverse transform + y = conv2d_winograd_inverse_transform(y, input_shape, kernel, ms) # [n, oc, oh, ow] + return y diff --git a/python/hidet/tos/ops/definitions/conv2d/utils.py b/python/hidet/tos/ops/definitions/conv2d/utils.py new file mode 100644 index 0000000..536edc3 --- /dev/null +++ b/python/hidet/tos/ops/definitions/conv2d/utils.py @@ -0,0 +1,16 @@ +from typing import List, Union +from ..utils import normalize_stride + + +def infer_conv2d_shape(x_shape: List[int], w_shape: List[int], strides: Union[int, List[int]], groups: int) -> List[int]: + n, c, h, w = x_shape + oc, gc, kx, ky = w_shape + sx, sy = normalize_stride(strides) + if gc * groups != c: + raise ValueError('Conv2d: x has {} input channels, w has {} group channels, and groups={}'.format(c, gc, groups)) + if oc % groups != 0: + raise ValueError('Conv2d expects out_channels % groups == 0, got out_channels {} and groups {}'.format(oc, groups)) + p, q = (h - kx) // sx + 1, (w - ky) // sy + 1 + return [n, oc, p, q] + + diff --git a/python/hidet/tos/ops/definitions/image.py b/python/hidet/tos/ops/definitions/image.py new file mode 100644 index 0000000..8e39797 --- /dev/null +++ b/python/hidet/tos/ops/definitions/image.py @@ -0,0 +1,146 @@ +from typing import Optional, List + +from hidet.ir.expr import Expr, if_then_else, convert, cast, And +from hidet.ir import primitives as prim +from .utils import Task, Operator, Tensor, TensorNode, compute, input_like + + +# Acknowledgement: take TVM resize topi implementation as a reference + + +def get_origin_index(x: Expr, image_width: int, target_width: int, coordinate_transformation_mode: str) -> Expr: + scale = image_width / target_width + func_map = { + 'half_pixel': + lambda x: (x + 0.5) * scale - 0.5, + 'align_corners': + lambda x: x * ((image_width - 1) / (target_width - 1)), + 'asymmetric': + lambda x: x * scale, + 'pytorch_half_pixel': + lambda x: (x + 0.5) * scale if target_width > 1 else convert(0.0), + 'tf_half_pixel_for_nn': + lambda x: (x + 0.5) * scale + } + if coordinate_transformation_mode not in func_map: + raise ValueError('Unsupported coordinate transformation mode: {}, candidates: {}.'.format( + coordinate_transformation_mode, func_map.keys() + )) + return func_map[coordinate_transformation_mode](x) + + +def get_closest_index(x: Expr, rounding_method: str) -> Expr: + func_map = { + 'rounding_method': + lambda x: cast(prim.round(x), 'int32'), + 'round_prefer_floor': + lambda x: cast(prim.ceil(x - 0.5), 'int32'), + 'round_prefer_ceil': + lambda x: cast(prim.floor(x + 0.5), 'int32'), + 'floor': + lambda x: cast(prim.floor(x + 1e-5), 'int32'), # add epsilon (1e-5) to prevent gpu rounding error + 'ceil': + lambda x: cast(prim.ceil(x - 1e-5), 'int32') # sub epsilon (1e-5) to prevent gpu rounding error + } + if rounding_method not in func_map: + raise ValueError('Unsupported rounding_method: {}, candidates: {}'.format(rounding_method, func_map.keys())) + return func_map[rounding_method](x) + + +def get_2d_pixel(data: TensorNode, n, c, h, w) -> Expr: + height, width = data.const_shape()[2:] + h = prim.max(0, prim.min(height, h)) + w = prim.max(0, prim.min(width, w)) + return data[n, c, h, w] + + +def linear_interpolate(a, b, ratio): + return a * (1.0 - ratio) + b * ratio + + +def resize2d_nchw_compute(data: TensorNode, size: List[int], method: str, coordinate_transformation_mode, rounding_method, + roi, cubic_alpha, cubic_exclude, extrapolation_value): + image_size = data.const_shape()[2:] + target_size = size + + def fmap(n, c, h, w): + h = get_origin_index(h, image_size[0], target_size[0], coordinate_transformation_mode) + w = get_origin_index(w, image_size[1], target_size[1], coordinate_transformation_mode) + if method == 'nearest': + h = get_closest_index(h, rounding_method) + w = get_closest_index(w, rounding_method) + value = get_2d_pixel(data, n, c, h, w) + elif method == 'linear': + h_int = cast(prim.floor(h), 'int32') + w_int = cast(prim.floor(w), 'int32') + h_ratio = h - h_int + w_ratio = w - w_int + pixels = [[get_2d_pixel(data, n, c, h_int + i, w_int + j) for j in range(2)] for i in range(2)] + top = linear_interpolate(*pixels[0], w_ratio) + bottom = linear_interpolate(*pixels[1], w_ratio) + value = linear_interpolate(top, bottom, h_ratio) + elif method == 'cubic': + raise NotImplementedError(method) + else: + raise ValueError('Unsupported scaling method: {}, candidates: {}'.format( + method, ['nearest', 'linear', 'cubic'] + )) + if coordinate_transformation_mode == 'tf_half_pixel_for_nn': + value = if_then_else(And.join(0 <= h, h < image_size[0], 0 <= w, w < image_size[1]), value, extrapolation_value) + return value + + output_shape = data.const_shape()[:2] + list(target_size) + out = compute( + 'out', + shape=output_shape, + fcompute=fmap, + scope=data.data_type.scope + ) + return out + + +class Resize2dTask(Task): + def __init__(self, data: TensorNode, size: List[int], method: str, coordinate_transformation_mode, rounding_method, + roi, cubic_alpha, cubic_exclude, extrapolation_value): + out = resize2d_nchw_compute(data, size, method, coordinate_transformation_mode, rounding_method, roi, cubic_alpha, cubic_exclude, extrapolation_value) + super().__init__( + name='resize2d', + inputs=[data], + outputs=[out] + ) + + +class Resize2dOp(Operator): + supported_methods = ['nearest', 'linear', 'cubic'] + supported_coord_trans_mode = ['half_pixel', 'align_corners', 'asymmetric', 'pytorch_half_pixel', 'tf_half_pixel_for_nn', 'tf_crop_and_resize'] + supported_rounding_methods = ['round', 'floor', 'ceil'] + + def __init__(self, data, size: List[int], method: str, coordinate_transformation_mode: str, rounding_method: str, + roi: Optional, cubic_alpha: Optional, cubic_exclude: Optional, extrapolation_value: Optional): + if method not in self.supported_methods: + raise ValueError("Resize only support methods: {}, but got {}.".format(self.supported_methods, method)) + if coordinate_transformation_mode not in self.supported_coord_trans_mode: + raise ValueError("Resize only support coordinate transformation modes: {}, but got {}.".format( + self.supported_coord_trans_mode, coordinate_transformation_mode)) + if method == 'nearest' and rounding_method not in self.supported_rounding_methods: + raise ValueError("Resize only support rounding methods: {}, but got {}.".format(self.supported_rounding_methods, rounding_method)) + if len(size) != 2: + raise ValueError('Resize2d expect size has 2 elements (height, width), got {}'.format(size)) + + super().__init__( + inputs=[data], + task=Resize2dTask(input_like(data, 'data'), size, method, coordinate_transformation_mode, rounding_method, roi, cubic_alpha, cubic_exclude, extrapolation_value), + attributes={ + 'method': method, + 'coordinate_transformation_mode': coordinate_transformation_mode, + 'rounding_method': rounding_method, + 'roi': roi, + 'cubic_alpha': cubic_alpha, + 'cubic_exclude': cubic_exclude, + 'extrapolation_value': extrapolation_value + } + ) + + +def resize2d(data: Tensor, size: List[int], method: str, coordinate_transformation_mode: str, rounding_method: str, roi: Optional, cubic_alpha: Optional, cubic_exclude: Optional, extrapolation_value: Optional) -> Tensor: + return Resize2dOp(data, size, method, coordinate_transformation_mode, rounding_method, roi, cubic_alpha, cubic_exclude, extrapolation_value).get_output(0) diff --git a/python/hidet/tos/ops/definitions/matmul/__init__.py b/python/hidet/tos/ops/definitions/matmul/__init__.py new file mode 100644 index 0000000..9d9832d --- /dev/null +++ b/python/hidet/tos/ops/definitions/matmul/__init__.py @@ -0,0 +1,4 @@ +from .matmul import matmul +from .matmul import MatmulOp +from .parallel_k_matmul import parallel_k_batched_matmul + diff --git a/python/hidet/tos/ops/definitions/matmul/matmul.py b/python/hidet/tos/ops/definitions/matmul/matmul.py new file mode 100644 index 0000000..e1bf5a0 --- /dev/null +++ b/python/hidet/tos/ops/definitions/matmul/matmul.py @@ -0,0 +1,170 @@ +from hidet.ir.func import IRModule +from hidet.tos.ops.definitions.utils import Task, Operator, Tensor, TensorNode, compute, reduce, input_like +from hidet.tos.ops.definitions.arithmatic import broadcast_shape +from hidet.tos.ops.definitions.transform import broadcast, unsqueeze, transpose +from hidet.ffi import cuda + + +class MatmulTask(Task): + def __init__(self, a: TensorNode, b: TensorNode, mma: str = 'simt', ta=False, tb=False, tc=False): + batch_size, m_size, k_size = a.const_shape() + batch_size, k_size, n_size = b.const_shape() + self.batch_size: int = batch_size + self.m_size: int = m_size + self.k_size: int = k_size + self.n_size: int = n_size + self.mma = mma + c = compute( + name='c', + shape=[batch_size, m_size, n_size], + fcompute=lambda r, i, j: reduce( + shape=[k_size], + fcompute=lambda k: a[r, i, k] * b[r, k, j], + reduce_type='sum' + ), + scope='global' + ) + super().__init__( + name='matmul', + inputs=[a, b], + outputs=[c], + attributes={ + 'batch_size': batch_size, + 'm_size': m_size, + 'n_size': n_size, + 'k_size': k_size, + 'mma': mma, + 'ta': ta, + 'tb': tb, + 'tc': tc + } + ) + + def implement_cuda(self) -> IRModule: + from hidet.tos.ops.schedules.cuda.matmul import batched_matmul_cuda_schedule_default, batched_matmul_cuda_schedule_wmma + if self.mma == 'simt' or self.mma == 'default': + return batched_matmul_cuda_schedule_default(self) + elif self.mma.startswith('wmma'): + return batched_matmul_cuda_schedule_wmma(self) + else: + raise ValueError('Can not recognize mma type {}, candidates: {}'.format(self.mma, ['simt', 'wmma'])) + + def fast_implement(self, space_level: int) -> bool: + return space_level == 0 + + +class MatmulOp(Operator): + def __init__(self, a: Tensor, b: Tensor, algo, mma: str = 'simt', ta=False, tb=False, tc=False): + if not (len(a.shape) == len(b.shape) == 3 and a.shape[0] == b.shape[0] and a.shape[2] == b.shape[1]): + raise ValueError('Matrix multiplication expect tensor A and B with shape [B, M, K] and [B, K, N]' + + ', got {} and {}'.format(a.shape, b.shape)) + task = MatmulTask(input_like(a, 'a'), input_like(b, 'b'), mma, ta, tb, tc) + super().__init__( + inputs=[a, b], + task=task, + attributes={ + 'algo': algo, + 'mma': mma, + 'ta': ta, + 'tb': tb, + 'tc': tc + } + ) + + +def matmul(a: Tensor, b: Tensor, algo: str = 'default', mma: str = 'default', ta=False, tb=False, tc=False) -> Tensor: + """ + Batched matrix multiplication. + + Parameters + ---------- + a: Tensor + The lhs operand with shape [batch_size, m_size, k_size]. + + b: Tensor + The rhs operand with shape [batch_size, k_size, n_size]. + + algo: str + The algorithm to use. There are two algorithms: + - 'direct': + Direct matrix multiplication. + - 'parallel_k': + Matrix multiplication also parallel on k dimension. + - 'default': + Choose one of above algorithms automatically. + + mma: str + The matrix-multiplication-accumulate (mma) in warp level: + - 'simt': + Use cuda core to do the warp-level mma (simt stands for single-instruction-multiple-threads). + - 'wmma_f16_f16', 'wmma_f16_f32', 'wmma_bf16_f32', 'wmma_tf32_f32': + Use warp level matrix multiplication accumulate instruction. Tensor core is used in these instructions. + Here 'wmma_ta_tb' indicates the matrix a and b will be converted into data type ta, do the computation + and accumulated with data type tb in the underlying kernel. + - 'wmma': + Choose one of wmma instruction automatically. + - 'default': + Choose in 'simt', 'wmma' automatically. + + See also: + https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-instructions + + ta: bool + Whether to transpose matrix A. + + tb: bool + Whether to transpose matrix B. + + tc: bool + Whether to transpose matrix C. + + Returns + ------- + c: Tensor + The result tensor of matrix multiplication. + """ + if len(a.shape) < 2 or len(b.shape) < 2: + raise ValueError('Current only support matrix multiplication with two matrices whose rank >= 2.') + + if len(a.shape) == 2 and len(b.shape) == 2: + if ta: + aa = a.transpose([-1, -2]).barrier().transpose([-1, -2]) + else: + aa = a + if tb: + bb = b.transpose([-1, -2]).barrier().transpose([-1, -2]) + else: + bb = b + aa = aa.unsqueeze(0) + bb = bb.unsqueeze(0) + cc = batched_matmul(aa, bb, algo, mma, ta, tb, tc) + cc = cc.squeeze(0) + if tc: + cc = cc.transpose([-1, -2]).barrier().transpose([-1, -2]) + return cc + else: + if ta: + aa = a.transpose([-1, -2]).barrier().transpose([-1, -2]) + else: + aa = a + if tb: + bb = b.transpose([-1, -2]).barrier().transpose([-1, -2]) + else: + bb = b + stack_shape = broadcast_shape(aa.shape[:-2], bb.shape[:-2]) + aa = broadcast(aa, shape=stack_shape + a.shape[-2:]).flatten(end_dim=-2) + bb = broadcast(bb, shape=stack_shape + b.shape[-2:]).flatten(end_dim=-2) + cc = batched_matmul(aa, bb, algo, mma, ta, tb, tc) + if tc: + cc = cc.transpose([-1, -2]).barrier().transpose([-1, -2]) + return cc + + +def batched_matmul(a: Tensor, b: Tensor, algo: str = 'default', mma: str = 'default', ta=True, tb=False, tc=False) -> Tensor: + mma_candidates = ['default', 'simt', 'wmma', 'wmma_f16_f16', 'wmma_f16_f32', 'wmma_bf16_f32', 'wmma_tf32_f32'] + algo_candidates = ['default', 'direct', 'parallel_k'] + if mma not in mma_candidates: + raise ValueError('Can not recognize mma {}, candidates: {}'.format(mma, mma_candidates)) + if algo not in algo_candidates: + raise ValueError('Can not recognize algorithm {}, candidates: {}'.format(algo, algo_candidates)) + return MatmulOp(a, b, algo, mma, ta, tb, tc).get_output(0) diff --git a/python/hidet/tos/ops/definitions/matmul/parallel_k_matmul.py b/python/hidet/tos/ops/definitions/matmul/parallel_k_matmul.py new file mode 100644 index 0000000..2d62c9f --- /dev/null +++ b/python/hidet/tos/ops/definitions/matmul/parallel_k_matmul.py @@ -0,0 +1,121 @@ +import warnings +from .matmul import matmul, Tensor, cuda +from hidet.utils.py import gcd, factor +from hidet.utils import hidet_cache_file + + +def parallel_k_nparts(batch_size, m_size, n_size, k_size) -> int: + predefined_rules = { + # some predefined rules used in important models + (1, 128, 3072, 768): 3, + (1, 128, 2304, 768): 3, + (1, 128, 768, 768): 6, + (1, 289, 192, 1120): 10, + (1, 64, 384, 1152): 12, + (1, 289, 192, 1344): 14, + (1, 49, 160, 960): 15, + (1, 49, 320, 960): 15, + (1, 196, 64, 384): 16, + (1, 196, 256, 2304): 16, + (1, 289, 128, 768): 16, + (1, 289, 96, 864): 16, + (1, 289, 128, 896): 16, + (1, 128, 768, 3072): 16, + } + if k_size < 384: + # for small k, the overhead of splitting k is too large compared with the benefits + return 1 + elif m_size == 1 or n_size == 1: + return 16 + elif (batch_size, m_size, n_size, k_size) in predefined_rules: + return predefined_rules[(batch_size, m_size, n_size, k_size)] + else: + # we hope to run multiple waves of thread blocks (e.g., 5) + estimate_thread_blocks = batch_size * ((m_size + 63) // 64) * ((n_size + 63) // 64) + num_multi_processors = cuda.device_property(cuda.PropertyMultiProcessorCount) + if estimate_thread_blocks * 8 <= num_multi_processors * 5: + nparts = 8 + elif estimate_thread_blocks * 4 <= num_multi_processors * 5: + nparts = 4 + elif estimate_thread_blocks * 2 <= num_multi_processors * 5: + nparts = 2 + else: + nparts = 1 + return nparts + + +def parallel_k_batched_matmul(a: Tensor, b: Tensor, mma: str = 'default', nparts=None) -> Tensor: + k_size = a.shape[-1] + batch_size, m_size, n_size = a.shape[0], a.shape[1], b.shape[2] + + if nparts is None: + nparts = parallel_k_nparts(batch_size, m_size, n_size, k_size) + + nparts = gcd(nparts, k_size) + + # print('parallel_k of batched matmul {}x{}x{}x{} used factor {}'.format(batch_size, m_size, n_size, k_size, nparts)) + + if nparts == 1: + # warnings.warn('Parallel k matmul use nparts=1, fall back to direct matmul.') + return matmul(a, b, algo='direct', mma=mma) + else: + a = a.reshape([batch_size, m_size, nparts, k_size // nparts]).rearrange([[0, 2], [1], [3]]) # [batch_size * nparts, m_size, k_size // nparts] + b = b.reshape([batch_size, nparts, k_size // nparts, n_size]).rearrange([[0, 1], [2], [3]]) # [batch_size * nparts, k_size // nparts, n_size] + c = matmul(a, b, algo='direct', mma=mma).reshape([batch_size, nparts, m_size, n_size]).sum(1) + return c + + # if nparts is None: + # if use_parallel_k(batch_size, m_size, n_size, k_size): + # if nparts is None: + # nparts = parallel_k_nparts(batch_size, m_size, n_size, k_size) + # else: + # nparts = gcd(nparts, k_size) + # if nparts == 1: + # return batched_matmul(a, b, algo='direct', mma=mma) + # else: + # else: + # warnings.warn('Please use use_parallel_k to check whether we should use parallel_k matmul first. Falling back to direct algorithm.') + # return batched_matmul(a, b, algo='direct', mma=mma) + # pass + # else: + # + + +def parallel_k_batched_matmul_search(a: Tensor, b: Tensor, mma: str = 'default') -> Tensor: + import numpy as np + k_size = a.shape[-1] + batch_size, m_size, n_size = a.shape[0], a.shape[1], b.shape[2] + + factors = [v for v in factor(k_size) if v <= 16] + + best_nparts = None + best_nparts_latency = 1e9 + + if len(factors) > 1: + # print('searching batch_matmul {}x{}x{}x{} parallel k factors: {}'.format(batch_size, m_size, n_size, k_size, factors)) + candidate_latencies = [] + for nparts in factors: + num_trials = 1000 + if nparts == 1: + c = matmul(a, b, algo='direct', mma=mma) + latency = float(np.median(c.op.latency(number=num_trials))) + if latency < best_nparts_latency: + best_nparts = nparts + best_nparts_latency = latency + candidate_latencies.append(latency) + else: + aa = a.reshape([batch_size, m_size, nparts, k_size // nparts]).rearrange([[0, 2], [1], [3]]) # [batch_size * nparts, m_size, k_size // nparts] + bb = b.reshape([batch_size, nparts, k_size // nparts, n_size]).rearrange([[0, 1], [2], [3]]) # [batch_size * nparts, k_size // nparts, n_size] + cc = matmul(aa, bb, algo='direct', mma=mma) + c1 = cc.reshape([batch_size, nparts, m_size, n_size]) + c2 = c1.sum(1) + latency = cc.op.latency(number=num_trials) + c2.op.latency(number=num_trials) + if latency < best_nparts_latency: + best_nparts = nparts + best_nparts_latency = latency + candidate_latencies.append(latency) + # print('candidate latencies: {}, choose factor {}'.format(['{:.3f}'.format(v * 1000) for v in candidate_latencies], best_nparts)) + else: + assert len(factors) == 1 + best_nparts = factors[0] + return parallel_k_batched_matmul(a, b, mma, best_nparts) diff --git a/python/hidet/tos/ops/definitions/norm.py b/python/hidet/tos/ops/definitions/norm.py new file mode 100644 index 0000000..1133c21 --- /dev/null +++ b/python/hidet/tos/ops/definitions/norm.py @@ -0,0 +1,63 @@ +from typing import List +from .utils import Tensor +from .arithmatic import sqrt, square +from .reduce import reduce_mean, reduce_var + + +def normalize(x: Tensor, dims: List[int], epsilon: float = 1e-5) -> Tensor: + x = x - x.mean(dims, keep_dim=True) + variance = square(x).mean(dims, keep_dim=True) + return x * (variance + epsilon).rsqrt() + + +def batch_norm_infer(x: Tensor, running_mean: Tensor, running_var: Tensor, epsilon=1e-5, axis=1) -> Tensor: + assert len(x.shape) == 4 and axis == 1 + assert len(running_mean.shape) == 1 and len(running_var.shape) == 1 + assert x.shape[1] == running_mean.shape[0] == running_var.shape[0] + running_mean = running_mean.unsqueeze([0, 2, 3]) # [1, c, 1, 1] + running_var = running_var.unsqueeze([0, 2, 3]) + return (x - running_mean) * (running_var + epsilon).rsqrt() + + +def instance_norm(x: Tensor, axis: int = 1, epsilon: float = 1e-5) -> Tensor: + """ + Instance norm. + + Parameters + ---------- + x: Tensor + The data to be normalized. + axis: int + The axis of channel dimension. + epsilon: float + The epsilon added to variance. + + Returns + ------- + ret: Tensor + The normalized tensor. + """ + dims = [dim for dim in range(2, len(x.shape)) if dim != axis] + return normalize(x, dims=dims, epsilon=epsilon) + + +def layer_norm(x: Tensor, num_last_dims: int = 1, epsilon: float = 1e-5) -> Tensor: + """ + Layer norm. + + Parameters + ---------- + x: Tensor + The data to be normalized. + num_last_dims: int + The number of dimensions to be normalized, starting from the end dimension of x. + epsilon: float + The epsilon added to variance. + + Returns + ------- + ret: Tensor + The normalized tensor. + """ + dims = list(range(len(x.shape) - num_last_dims, len(x.shape))) + return normalize(x, dims=dims, epsilon=epsilon) diff --git a/python/hidet/tos/ops/definitions/pool.py b/python/hidet/tos/ops/definitions/pool.py new file mode 100644 index 0000000..f730e9f --- /dev/null +++ b/python/hidet/tos/ops/definitions/pool.py @@ -0,0 +1,82 @@ +from typing import Union, Sequence +from hidet.ir.expr import convert + +from .utils import Task, Operator, Tensor, TensorNode, compute, reduce, inline_compute, input_like, normalize_stride, normalize_kernel, normalize_padding + + +class Pool2dTask(Task): + def __init__(self, x: TensorNode, kernel, strides, padding, reduce_type: str): + assert reduce_type in ['max', 'avg'] + kernel = normalize_kernel(kernel) + strides = normalize_stride(strides) + padding = normalize_padding(padding) + batch_size, channels, height, width = x.const_shape() + out_height = (height + padding[0] + padding[2] - kernel[0]) // strides[0] + 1 + out_width = (width + padding[1] + padding[3] - kernel[1]) // strides[1] + 1 + pad_value = convert(0.0 if reduce_type == 'avg' else -1e30, dtype=x.data_type.scalar_type) + pad = compute( + name='pad', + shape=[batch_size, channels, height + 2 * padding[0], width + 2 * padding[1]], + fcompute=lambda n, c, h, w: x.protect_read(indices=[n, c, h - padding[0], w - padding[1]], default_value=pad_value), + scope=x.data_type.scope + ) + y = compute( + name='y', + shape=[batch_size, channels, out_height, out_width], + fcompute=lambda n, c, h, w: reduce( + shape=[kernel[0], kernel[1]], + fcompute=lambda rx, ry: pad[n, c, h * strides[0] + rx, w * strides[1] + ry], + reduce_type=reduce_type + ), + scope='global' + ) + y = inline_compute(y) + super().__init__( + name='{}_pool2d'.format(reduce_type), + inputs=[x], + outputs=[y] + ) + + +class MaxPool2dOp(Operator): + def __init__(self, + input: Tensor, + kernel: Union[int, Sequence[int]], + stride: Union[int, Sequence[int]], + padding: Union[int, Sequence[int]] + ): + super().__init__( + inputs=[input], + task=Pool2dTask(input_like(input, 'x'), kernel, stride, padding, reduce_type='max'), + attributes={ + 'kernel': kernel, + 'stride': stride, + 'padding': padding + } + ) + + +class AvgPool2dOp(Operator): + def __init__(self, + input: Tensor, + kernel: Union[int, Sequence[int]], + stride: Union[int, Sequence[int]], + padding: Union[int, Sequence[int]] + ): + super().__init__( + inputs=[input], + task=Pool2dTask(input_like(input, 'x'), kernel, stride, padding, reduce_type='avg'), + attributes={ + 'kernel': kernel, + 'stride': stride, + 'padding': padding + } + ) + + +def max_pool2d(input: Tensor, kernel, stride, padding) -> Tensor: + return MaxPool2dOp(input, kernel, stride, padding).get_output(0) + + +def avg_pool2d(input: Tensor, kernel, stride, padding) -> Tensor: + return AvgPool2dOp(input, kernel, stride, padding).get_output(0) diff --git a/python/hidet/tos/ops/definitions/reduce.py b/python/hidet/tos/ops/definitions/reduce.py new file mode 100644 index 0000000..c49f9d3 --- /dev/null +++ b/python/hidet/tos/ops/definitions/reduce.py @@ -0,0 +1,111 @@ +from typing import List, Union + +from .arithmatic import square +from .utils import Task, Operator, Tensor, TensorNode, IRModule, compute, reduce, input_like, normalize_dim + + +class ReduceTask(Task): + def __init__(self, x: TensorNode, dims: List[int], keep_dim: bool, reduce_type: str, accumulate_dtype: str = 'float32'): + x_shape = x.const_shape() + y_shape = [] + for i in range(len(x_shape)): + if i in dims: + if keep_dim: + y_shape.append(1) + else: + y_shape.append(x_shape[i]) + + def fcompute(*indices): + def reduce_fcompute(*reduce_indices): + x_indices = [] + p = 0 + q = 0 + for i in range(len(x_shape)): + if i not in dims: + x_indices.append(indices[p]) + p += 1 + else: + x_indices.append(reduce_indices[q]) + q += 1 + if keep_dim: + p += 1 + assert p == len(indices) and q == len(reduce_indices) + return x[x_indices] + + reduce_shape = [x_shape[i] for i in dims] + return reduce(shape=reduce_shape, fcompute=reduce_fcompute, + reduce_type=reduce_type, accumulate_dtype=accumulate_dtype) + + y = compute(name='y', shape=y_shape, fcompute=fcompute, scope='global') + + self.dims: List[int] = dims + self.keep_dim: bool = keep_dim + self.reduce_type: str = reduce_type + + super().__init__( + name='reduce_{}'.format(reduce_type), + inputs=[x], + outputs=[y], + attributes={ + 'dims': dims, + 'keep_dim': keep_dim, + 'reduce_type': reduce_type, + 'accumulate_dtype': accumulate_dtype + } + ) + + def implement_cuda(self) -> IRModule: + from ..schedules import cuda_schedule_reduce_by_default, cuda_schedule_reduce_by_warp_reduce + rank = len(self.inputs[0].const_shape()) + if rank - 1 in self.dims: + # reduce over last dimension + return cuda_schedule_reduce_by_warp_reduce(self) + else: + # last dimension has not been reduced + return cuda_schedule_reduce_by_default(self) + + def fast_implement(self, space_level: int) -> bool: + return True + + +class ReduceMeanOp(Operator): + def __init__(self, x: Tensor, dims: List[int], keep_dim: bool = False): + dims = normalize_dim(dims, rank=len(x.shape)) + super().__init__( + inputs=[x], + task=ReduceTask(input_like(x, 'x'), dims, keep_dim, 'avg'), + attributes={ + 'dims': dims, + 'keep_dim': keep_dim + } + ) + + +class ReduceSumOp(Operator): + def __init__(self, x: Tensor, dims: List[int], keep_dim: bool = False): + dims = normalize_dim(dims, rank=len(x.shape)) + super().__init__( + inputs=[x], + task=ReduceTask(input_like(x, 'x'), dims, keep_dim, 'sum'), + attributes={ + 'dims': dims, + 'keep_dim': keep_dim + } + ) + + +def reduce_mean(x: Tensor, dims: Union[int, List[int]], keep_dim: bool = False) -> Tensor: + if isinstance(dims, int): + dims = [dims] + return ReduceMeanOp(x, dims, keep_dim).get_output(0) + + +def reduce_sum(x: Tensor, dims: Union[int, List[int]], keep_dim: bool = False) -> Tensor: + if isinstance(dims, int): + dims = [dims] + return ReduceSumOp(x, dims, keep_dim).get_output(0) + + +def reduce_var(x: Tensor, dims: Union[int, List[int]], keep_dim: bool = False) -> Tensor: + x = x - x.mean(dims=dims, keep_dim=True) + return square(x).mean(dims=dims, keep_dim=keep_dim) diff --git a/python/hidet/tos/ops/definitions/softmax.py b/python/hidet/tos/ops/definitions/softmax.py new file mode 100644 index 0000000..106eedf --- /dev/null +++ b/python/hidet/tos/ops/definitions/softmax.py @@ -0,0 +1,79 @@ +from hidet.ir.func import IRModule +from .utils import Task, Operator, Tensor, TensorNode, compute, input_like, normalize_dim, reduce +from hidet.ir import primitives as prim + + +class SoftmaxTask(Task): + def __init__(self, x: TensorNode, axis: int): + self.x_shape = x.const_shape() + self.axis = axis + + shape = x.const_shape() + axis_extent = shape[axis] + reduced_shape = shape[:axis] + shape[axis+1:] + + # max value + max_value = compute( + name='max_value', + shape=reduced_shape, + fcompute=lambda *indices: reduce( + shape=[axis_extent], + fcompute=lambda k: x[indices[:axis] + (k,) + indices[axis:]], + reduce_type='max' + ) + ) + + # exp + exp_value = compute( + name='exp_value', + shape=shape, + fcompute=lambda *indices: prim.exp(x[indices] - max_value[indices[:axis] + indices[axis+1:]]) + ) + + # sum + sum_value = compute( + name='sum_value', + shape=reduced_shape, + fcompute=lambda *indices: reduce( + shape=[axis_extent], + fcompute=lambda k: exp_value[indices[:axis] + (k,) + indices[axis:]], + reduce_type='sum' + ) + ) + + # out + out = compute( + name='out', + shape=shape, + fcompute=lambda *indices: exp_value[indices] / sum_value[indices[:axis] + indices[axis+1:]] + ) + super().__init__( + name='softmax', + inputs=[x], + outputs=[out] + ) + + def implement_cuda(self) -> IRModule: + from hidet.tos.ops.schedules import softmax_cuda_schedule + return softmax_cuda_schedule(self) + + def fast_implement(self, space_level: int) -> bool: + return True + + +class SoftmaxOp(Operator): + def __init__(self, + x: Tensor, + axis: int = 1): + axis = normalize_dim(axis, len(x.shape)) + super().__init__( + inputs=[x], + task=SoftmaxTask(input_like(x, 'x'), axis), + attributes={ + 'axis': axis + } + ) + + +def softmax(x: Tensor, axis=1) -> Tensor: + return SoftmaxOp(x, axis).get_output(0) diff --git a/python/hidet/tos/ops/definitions/special.py b/python/hidet/tos/ops/definitions/special.py new file mode 100644 index 0000000..aafe343 --- /dev/null +++ b/python/hidet/tos/ops/definitions/special.py @@ -0,0 +1,45 @@ +from .utils import Task, Operator, Tensor, TensorNode, compute, input_like + + +# todo: add GraphInput and GraphOutput special operators here. + +class BarrierTask(Task): + def __init__(self, x: TensorNode): + y = compute( + name='y', + shape=x.const_shape(), + fcompute=lambda *indices: x[indices] + ) + super().__init__( + name='barrier', + inputs=[x], + outputs=[y] + ) + + +class BarrierOp(Operator): + def __init__(self, x: Tensor): + super().__init__( + inputs=[x], + task=BarrierTask(input_like(x, 'x')), + ) + + +def barrier(x: Tensor) -> Tensor: + """ + Barrier operator is an identity operator and return the same tensor as input. During + graph-level optimizations, this operator prevents the fusion of producer and consumer + of the input tensor and output tensor, respectively. This operator will be eliminated + at the end of graph-level optimizations. + + Parameters + ---------- + x: Tensor + The input tensor. + + Returns + ------- + y: Tensor + The output tensor. + """ + return BarrierOp(x).get_output(0) diff --git a/python/hidet/tos/ops/definitions/transform.py b/python/hidet/tos/ops/definitions/transform.py new file mode 100644 index 0000000..51cf522 --- /dev/null +++ b/python/hidet/tos/ops/definitions/transform.py @@ -0,0 +1,721 @@ +from typing import List, Optional, Union, Sequence +import functools + +from hidet.ir.expr import Expr, And, if_then_else, convert +from hidet.ir.layout import DataLayout, RowMajorLayout, ColumnMajorLayout +from hidet.ir.utils import index_deserialize, index_serialize +from hidet.utils import prod +from .utils import Task, InverseMap, Operator, Tensor, TensorNode, compute, input_like, normalize_dim + + +def same_shape(shape_a: List[int], shape_b: List[int]) -> bool: + return len(shape_a) == len(shape_b) and all(a == b for a, b in zip(shape_a, shape_b)) + + +class ReshapeTask(Task): + def __init__(self, x: TensorNode, y_shape: List[int]): + x_shape = x.const_shape() + if not prod(x_shape) == prod(y_shape): + raise ValueError('Can not reshape {} to {} because they have different number ' + 'of elements: {} vs {}'.format(x_shape, y_shape, prod(x_shape), prod(y_shape))) + if not isinstance(x.data_type.layout, RowMajorLayout): + raise NotImplementedError('currently, only support row major layout. Please use ' + '.contiguous() to transfer the given tensor into row major layout first.') + + def index_map(dst_indices, src_shape, dst_shape): + src_groups = [] + dst_groups = [] + i, j = 0, 0 + while i < len(src_shape) and j < len(dst_shape): + src_group = [i] + dst_group = [j] + x_size, y_size = src_shape[i], dst_shape[j] + i += 1 + j += 1 + while x_size != y_size: + if x_size < y_size: + x_size *= src_shape[i] + src_group.append(i) + i += 1 + else: + y_size *= dst_shape[j] + dst_group.append(j) + j += 1 + src_groups.append(src_group) + dst_groups.append(dst_group) + if i < len(src_shape): + assert prod(src_shape[i:]) == 1 + src_groups.append(list(range(i, len(src_shape)))) + dst_groups.append([]) + if j < len(dst_shape): + assert prod(dst_shape[j:]) == 1 + src_groups.append([]) + dst_groups.append(list(range(j, len(dst_shape)))) + src_indices = [] + for src_group, dst_group in zip(src_groups, dst_groups): + x_group_shape = [src_shape[r] for r in src_group] + y_group_shape = [dst_shape[r] for r in dst_group] + y_group_indices = [dst_indices[r] for r in dst_group] + x_group_indices = index_deserialize(index_serialize(y_group_indices, y_group_shape), x_group_shape) + src_indices.extend(x_group_indices) + assert len(src_indices) == len(src_shape) + return src_indices + + def inverse_map(*x_indices): + return index_map(x_indices, src_shape=y_shape, dst_shape=x_shape) + + y = compute( + name='y', + shape=y_shape, + fcompute=lambda *indices: x[index_map(indices, src_shape=x_shape, dst_shape=y_shape)], + scope='global', + ) + super().__init__( + name='reshape', + inputs=[x], + outputs=[y], + inverse_map={x: InverseMap.from_lambda(inverse_map, num_args=len(x_shape))} + ) + + +class RearrangeTask(Task): + def __init__(self, x: TensorNode, plan: List[List[int]]): + x_shape = x.const_shape() + y_shape = [prod([x_shape[i] for i in dims]) for dims in plan] + + def index_split(total_index, dim_sizes: List[int]) -> List: + bases = [prod(dim_sizes[i + 1:]) for i in range(len(dim_sizes))] + return [(total_index // base) % dim for dim, base in zip(dim_sizes, bases)] + + def fcompute(*y_indices): + x_indices = [None for _ in range(len(x_shape))] + for i, y_index in enumerate(y_indices): + dims = plan[i] + if len(dims) == 0: + # this new dimension has size 1 + continue + else: + split_indices = index_split(total_index=y_index, dim_sizes=[x_shape[k] for k in dims]) + for j, x_index in zip(dims, split_indices): + x_indices[j] = x_index + for i, x_index in enumerate(x_indices): + if x_index is None: + if x_shape[i] != 1: + msg = 'Rearrange plan {} on tensor {} leave non-one dimension {} not been accessed'.format(plan, x_shape, i) + raise ValueError(msg) + else: + x_indices[i] = 0 + return x[x_indices] + + y = compute( + name='y', + shape=y_shape, + fcompute=fcompute + ) + + def inverse_map(*x_indices): + y_indices = [] + for dims in plan: + cnt = convert(0) + for dim in dims: + cnt = cnt * x_shape[dim] + x_indices[dim] + y_indices.append(cnt) + return y_indices + + super().__init__( + name='rearrange', + inputs=[x], + outputs=[y], + inverse_map={x: InverseMap.from_lambda(inverse_map, len(x_shape))} + ) + + +class ConcatTask(Task): + def __init__(self, inputs: List[TensorNode], axis: int): + shapes = [t.const_shape() for t in inputs] + n = len(shapes) + assert n > 0 + for i in range(1, n): + if len(shapes[0]) != len(shapes[i]): + raise ValueError('Concat: all shapes must have the same rank, got {}'.format(shapes)) + if any(a != b for j, (a, b) in enumerate(zip(shapes[0], shapes[i])) if j != axis): + raise ValueError('Concat: all tensors must have the same shape except axis dimension, got {}, axis {}'.format(shapes, axis)) + rank = len(shapes[0]) + out_shape = [shapes[0][i] if i != axis else sum(shapes[j][i] for j in range(n)) for i in range(rank)] + + def fmap(*indices): + pre_sum = [sum([shapes[j][axis] for j in range(i)]) for i in range(n + 1)] + value = inputs[-1][indices[:axis] + (indices[axis] - pre_sum[-2],) + indices[axis + 1:]] + for i, input in reversed(list(zip(range(n - 1), inputs[:n - 1]))): + input_i_value = inputs[i][indices[:axis] + (indices[axis] - pre_sum[i],) + indices[axis + 1:]] + value = if_then_else(indices[axis] < pre_sum[i + 1], input_i_value, value) + return value + + out = compute( + name='out', + shape=out_shape, + fcompute=lambda *indices: fmap(*indices) + ) + + super().__init__( + name='concat', + inputs=inputs, + outputs=[out] + ) + + +class TakeTask(Task): + def __init__(self, data: TensorNode, indices: TensorNode, axis=0): + data_shape = data.const_shape() + indices_shape = indices.const_shape() + output_shape = data_shape[:axis] + indices_shape + data_shape[axis + 1:] + assert 0 <= axis < len(data_shape) + + def fmap(*output_indices): + indices_indices = output_indices[axis: axis + len(indices_shape)] + data_indices = output_indices[:axis] + (indices[indices_indices],) + output_indices[axis + len(indices_shape):] + return data[data_indices] + + output = compute( + name='output', + shape=output_shape, + fcompute=lambda *output_indices: fmap(*output_indices), + scope='global' + ) + super().__init__( + name='take', + inputs=[data, indices], + outputs=[output] + ) + + +class StridedSliceTask(Task): + def __init__(self, data: TensorNode, starts: List[Optional[int]], ends: List[Optional[int]], axes: List[int], strides: List[int]): + assert len(starts) == len(ends) == len(axes) == len(strides) + if len(axes) != len(set(axes)): + raise ValueError('Duplicated axes in slice, axes: {}'.format(axes)) + data_shape = data.const_shape() + output_shape = list(data_shape) + axis2info = {} + for axis, start, end, stride in zip(axes, starts, ends, strides): + if stride == 0: + raise NotImplementedError('Stride can not be 0 in slicing: starts {} ends {} axes {} strides {}.'.format(starts, ends, axes, strides)) + if stride > 0: + output_shape[axis] = (end - start + stride - 1) // stride + else: + output_shape[axis] = (start - end + (-stride) - 1) // (-stride) + if output_shape[axis] <= 0: + raise NotImplementedError('Slice result can not be: starts {} ends {} axes {} strides {}'.format(starts, ends, axes, strides)) + axis2info[axis] = (start, end, stride) + + def fmap(indices): + data_indices = [] + for axis, index in enumerate(indices): + if axis in axis2info: + start, end, stride = axis2info[axis] + data_indices.append(start + index * stride) + else: + data_indices.append(index) + return data[data_indices] + + out = compute( + 'out', + shape=output_shape, + fcompute=lambda *indices: fmap(indices), + scope=data.data_type.scope + ) + super().__init__( + name='slice', + inputs=[data], + outputs=[out] + ) + + +def can_broadcast(src_shape: List[int], dst_shape: List[int]) -> bool: + if len(dst_shape) < len(src_shape): + return False + src_shape = [1 for _ in range(len(dst_shape) - len(src_shape))] + src_shape + for a, b in zip(src_shape, dst_shape): + if a != 1 and a != b: + return False + return True + + +class BroadcastTask(Task): + def __init__(self, data: TensorNode, shape: List[int]): + data_shape = data.const_shape() + if not can_broadcast(data_shape, shape): + raise ValueError('Can not broadcast a tensor with shape {} to {}'.format(data_shape, shape)) + + def fmap(*indices): + expanded = len(shape) - len(data_shape) + indices = indices[expanded:] + indices = [v if data_shape[i] != 1 else 0 for i, v in enumerate(indices)] + return data[indices] + + out = compute( + 'out', + shape=shape, + fcompute=fmap, + scope=data.data_type.scope + ) + super().__init__( + name='broadcast', + inputs=[data], + outputs=[out] + ) + + +class PadTask(Task): + def __init__(self, data: TensorNode, pads: List[int], value: float): + shape = data.const_shape() + rank = len(shape) + assert rank * 2 == len(pads) + out_shape = [a + b + c for a, b, c in zip(pads[:rank], shape, pads[rank:])] + + value = convert(value, dtype=data.data_type.scalar_type.name) + + def fmap(*indices): + indices = [idx - beg for idx, beg in zip(indices, pads[:rank])] + cond = And.join_list([And(0 <= idx, idx < shape[i]) for i, idx in enumerate(indices)]) + return if_then_else(cond, data[indices], value) + + out = compute( + 'out', + shape=out_shape, + fcompute=fmap, + scope=data.data_type.scope + ) + super().__init__( + name='pad', + inputs=[data], + outputs=[out] + ) + + +class TileTask(Task): + def __init__(self, data: TensorNode, repeats: List[int]): + shape = data.const_shape() + assert len(shape) == len(repeats) + out_shape = [a * b for a, b in zip(shape, repeats)] + + def fmap(*indices): + indices = [idx % shape[i] for i, idx in enumerate(indices)] + return data[indices] + + out = compute( + name='out', + shape=out_shape, + fcompute=fmap, + scope=data.data_type.scope, + ) + super().__init__( + name='tile', + inputs=[data], + outputs=[out] + ) + + +class ReshapeOp(Operator): + def __init__(self, x: Tensor, shape): + shape = self.normalize_shape(x.shape, shape) + task = ReshapeTask(input_like(x, 'x'), shape) + super().__init__( + inputs=[x], + task=task, + attributes={ + 'shape': shape + } + + ) + + @staticmethod + def normalize_shape(origin_shape: List[int], shape: List[int]): + # [1, 3, 224, 224], [1, -1, 224, 0] => [1, 3, 224, 224] + shape = list(shape) + for i in range(len(shape)): + if shape[i] == 0: + if i >= len(origin_shape): + raise ValueError('0 is used outside original shape: origin {} target {}'.format(origin_shape, shape)) + shape[i] = origin_shape[i] + size = prod(origin_shape) + cnt = sum([1 for v in shape if v == -1]) + if cnt == 0: + if prod(shape) != size: + raise ValueError('Reshape: given shape has different size with input tensor: shape {} and size {}'.format(shape, size)) + return shape + elif cnt == 1: + remain_size = prod([v for v in shape if v != -1]) + if size % remain_size != 0: + raise ValueError('Given shape is incompatible with input tensor: shape {} and size {}'.format(shape, size)) + return [v if v != -1 else size // remain_size for v in shape] + else: + raise ValueError('Can not infer the shape when there are multiple -1: {}'.format(shape)) + + +class RearrangeOp(Operator): + def __init__(self, x: Tensor, plan: List[List[int]]): + super().__init__( + inputs=[x], + task=RearrangeTask(input_like(x, 'x'), plan=plan), + attributes={ + 'plan': plan + } + + ) + + +class SqueezeOp(Operator): + def __init__(self, x: Tensor, dims: List[int]): + super().__init__( + inputs=[x], + task=RearrangeTask(input_like(x, 'x'), plan=[[i] for i in range(len(x.shape)) if i not in dims]), + attributes={ + 'dims': dims + } + + ) + + def imperative_run(self, inputs: Optional[List[Tensor]] = None) -> List[Tensor]: + x = inputs[0] if inputs else self.inputs[0] + if isinstance(x.layout, (RowMajorLayout, ColumnMajorLayout)): + shape = self.task.outputs[0].const_shape() + layout = x.layout.__class__(shape) + return [Tensor(shape=shape, dtype=x.dtype, device=x.device, storage=x.storage, layout=layout, trace=None)] + else: + return Operator.imperative_run(self, inputs) + + +class UnsqueezeOp(Operator): + def __init__(self, x: Tensor, dims: List[int]): + dims = list(dims) + plan = [] + c = 0 + for i in range(len(x.shape) + len(dims)): + if i in dims: + plan.append([]) + else: + plan.append([c]) + c += 1 + assert c == len(x.shape) + super().__init__( + inputs=[x], + task=RearrangeTask(input_like(x, 'x'), plan=plan), + attributes={ + 'dims': dims + } + + ) + + def imperative_run(self, inputs: Optional[List[Tensor]] = None) -> List[Tensor]: + x = inputs[0] if inputs else self.inputs[0] + if isinstance(x.layout, (RowMajorLayout, ColumnMajorLayout)): + shape = self.task.outputs[0].const_shape() + layout = x.layout.__class__(shape) + return [Tensor(shape=shape, dtype=x.dtype, device=x.device, storage=x.storage, layout=layout, trace=None)] + else: + return Operator.imperative_run(self, inputs) + + +class FlattenOp(Operator): + def __init__(self, x: Tensor, start_dim: int, end_dim: int): + rank = len(x.shape) + start_dim = normalize_dim(start_dim, rank) + end_dim = normalize_dim(end_dim, rank) + assert 0 <= start_dim < end_dim <= rank + dims = list(range(len(x.shape))) + plan = [[v] for v in dims[:start_dim]] + [dims[start_dim: end_dim]] + [[v] for v in dims[end_dim:]] + super().__init__( + inputs=[x], + task=RearrangeTask(input_like(x, 'x'), plan=plan), + attributes={ + 'start_dim': start_dim, + 'end_dim': end_dim + } + + ) + + +class TransposeOp(Operator): + def __init__(self, x: Tensor, axes: Optional[List[int]] = None): + if axes and len(axes) != len(x.shape): + raise ValueError('Transpose tensor with shape {} expect a permutation of axes with length {}, got {}'.format(x.shape, len(x.shape), axes)) + if axes is None: + axes = list(reversed(range(len(x.shape)))) + plan = [[v] for v in axes] + super().__init__( + inputs=[x], + task=RearrangeTask(input_like(x, 'x'), plan), + attributes={ + 'axes': axes + } + + ) + + +class CastOp(Operator): + def __init__(self, x: Tensor, dtype: str): + from hidet.ir.expr import Cast + from .arithmatic import UnaryElementwiseTask + super().__init__( + inputs=[x], + task=UnaryElementwiseTask('cast', input_like(x, 'x'), op=lambda v: Cast(v, dtype)), + attributes={ + 'dtype': dtype + } + + ) + + +class ConcatOp(Operator): + def __init__(self, *tensors: Tensor, axis: int): + tensors = list(tensors) + if len(tensors) == 0: + raise ValueError('Concat requires at least one tensor, 0 given.') + axis = normalize_dim(axis, len(tensors[0].shape)) + super().__init__( + inputs=tensors, + task=ConcatTask([input_like(tensor, 'x{}'.format(idx)) for idx, tensor in enumerate(tensors)], axis=axis), + attributes={ + 'axis': axis + } + ) + + +class TakeOp(Operator): + def __init__(self, data: Tensor, indices: Tensor, axis: int): + super().__init__( + inputs=[data, indices], + task=TakeTask(input_like(data, 'data'), input_like(indices, 'indices'), axis=axis), + attributes={ + 'axis': axis + } + ) + + +class StridedSliceOp(Operator): + def __init__(self, data: Tensor, starts: List[int], ends: List[int], axes: Optional[List[int]] = None, strides: Optional[List[int]] = None): + starts, ends, axes, strides = self.normalize(data.shape, starts, ends, axes, strides) + task = StridedSliceTask(input_like(data, 'data'), starts, ends, axes, strides) + super().__init__( + inputs=[data], + task=task, + attributes={ + 'starts': starts, + 'ends': ends, + 'axes': axes, + 'strides': strides + } + + ) + + @staticmethod + def normalize(shape, starts, ends, axes: Optional[List[int]], strides: Optional[List[int]]): + # follow: https://github.com/onnx/onnx/blob/main/docs/Operators.md#slice to normalize + rank = len(shape) + if axes is None: + axes = [i for i in range(len(starts))] + axes = normalize_dim(axes, rank) + if strides is None: + strides = [1 for _ in range(len(starts))] + shape = [shape[i] for i in axes] + assert len(shape) == len(starts) == len(ends) == len(axes) == len(strides) + for i in range(len(axes)): + starts[i] = starts[i] + shape[i] if starts[i] < 0 else starts[i] + ends[i] = ends[i] + shape[i] if ends[i] < 0 else ends[i] + if strides[i] > 0: + starts[i] = max(0, min(shape[i], starts[i])) + ends[i] = max(0, min(shape[i], ends[i])) + else: + starts[i] = max(0, min(shape[i] - 1, starts[i])) + ends[i] = max(-1, min(shape[i] - 1, ends[i])) + return starts, ends, axes, strides + + +class BroadcastOp(Operator): + def __init__(self, data: Tensor, shape: List[int]): + super().__init__( + inputs=[data], + task=BroadcastTask(input_like(data, 'data'), shape), + attributes={ + 'shape': shape + } + ) + + +class PadOp(Operator): + def __init__(self, data: Tensor, pads: List[int], mode: str = 'constant', value: float = 0.0): + if len(pads) < len(data.shape) * 2: + assert len(pads) % 2 == 0, 'The pads must have even number of elements.' + half = len(pads) // 2 + extra = [0 for _ in range(len(data.shape) - half)] + pads = extra + pads[:half] + extra + pads[half:] + if mode != 'constant': + raise NotImplementedError("Padding mode '{}' has not been implemented yet.".format(mode)) + super().__init__( + inputs=[data], + task=PadTask(input_like(data, 'data'), pads, value), + attributes={ + 'pads': pads, + 'mode': mode, + 'value': value + } + + ) + + +class TileOp(Operator): + def __init__(self, data: Tensor, repeats: List[int]): + if len(repeats) != len(data.shape): + raise ValueError("The length of 'repeats' parameter of Tile operator expects to have the " + "same length as data shape. shape: {}, repeats: {}".format(data.shape, repeats)) + super().__init__( + inputs=[data], + task=TileTask(input_like(data, 'data'), repeats), + attributes={ + 'repeats': repeats + } + ) + + +def reshape(x: Tensor, shape) -> Tensor: + if same_shape(x.shape, shape): + return x + return ReshapeOp(x, shape).get_output(0) + + +def rearrange(x: Tensor, plan: List[List[int]]) -> Tensor: + """ + Rearrange a tensor. This task is a general task of squeeze, unsqueeze, flatten, and perm. + + Parameters + ---------- + x: Tensor + The input tensor. + + plan: List[List[int]] + The rearrange plan. + + Returns + ------- + ret: Tensor + The task to conduct rearrangement. + + Examples + -------- + squeeze([1, 1, 2, 3], dims=[0, 1]) = rearrange([1, 1, 2, 3], plan=[[2], [3]]) => Tensor([2, 3]) + unsqueeze([2, 3], dims=[0, 1]) = rearrange([2, 3], plan=[[], [], [0], [1]]) => Tensor([1, 1, 2, 3]) + flatten([2, 3, 4, 5], start_dim=1, end_dim=2) = rearrange([2, 3, 4, 5], plan=[[0], [1, 2], [3]]) => Tensor([2, 12, 5]) + """ + if not isinstance(plan, (list, tuple)) or any(not isinstance(v, (list, tuple)) for v in plan): + raise ValueError('plan should be List[List[int]], but got: {}'.format(plan)) + return RearrangeOp(x, plan).get_output(0) + + +def squeeze(x: Tensor, dims: Union[int, Sequence[int]]) -> Tensor: + if isinstance(dims, int): + dims = [dims] + if len(dims) == 0: + return x + return SqueezeOp(x, dims).get_output(0) + + +def unsqueeze(x: Tensor, dims: Union[int, Sequence[int]]) -> Tensor: + if isinstance(dims, int): + dims = [dims] + if len(dims) == 0: + return x + return UnsqueezeOp(x, dims).get_output(0) + + +def flatten(x: Tensor, start_dim=0, end_dim=None) -> Tensor: + start_dim = normalize_dim(start_dim, len(x.shape)) + end_dim = normalize_dim(end_dim, len(x.shape)) + if start_dim + 1 >= end_dim: + return x + return FlattenOp(x, start_dim, end_dim).get_output(0) + + +def transpose(x: Tensor, axes: Optional[List[int]] = None) -> Tensor: + rank = len(x.shape) + if axes is None: + axes = list(reversed(range(rank))) + axes = [normalize_dim(dim, rank) for dim in axes] + dims = [] + i = 0 + for j in range(rank): + if j in axes: + dims.append(axes[i]) + i += 1 + else: + dims.append(j) + return TransposeOp(x, dims).get_output(0) + + +def concat(tensors: List[Tensor], axis: int) -> Tensor: + return ConcatOp(*tensors, axis=axis).get_output(0) + + +def cast(x: Tensor, dtype: str) -> Tensor: + if x.dtype == dtype: + return x + return CastOp(x, dtype).get_output(0) + + +def take(data: Tensor, indices: Tensor, axis: int = 0) -> Tensor: + return TakeOp(data, indices, axis).get_output(0) + + +def strided_slice(data: Tensor, starts: List[int], ends: List[int], axes: Optional[List[int]] = None, strides: Optional[List[int]] = None) -> Tensor: + return StridedSliceOp(data, starts, ends, axes, strides).get_output(0) + + +def broadcast(data: Tensor, shape) -> Tensor: + if same_shape(data.shape, shape): + return data + return BroadcastOp(data, shape).get_output(0) + + +def pad(data: Tensor, pads: List[int], mode: str = 'constant', value: float = 0.0) -> Tensor: + if all(p == 0 for p in pads): + return data + return PadOp(data, pads, mode, value).get_output(0) + + +def conv_pad(data: Tensor, pads: List[int]) -> Tensor: + from .utils import normalize_padding + pads = normalize_padding(pads, dim=len(data.shape) - 2) + return pad(data, pads) + + +def tile(data: Tensor, repeats: List[int]) -> Tensor: + """ + Tile a tensor. See https://numpy.org/doc/stable/reference/generated/numpy.tile.html. + + Parameters + ---------- + data: Tensor + The input tensor to be tiled. + repeats: List[int] + A list of integers to represent the number of repeats for each dimension. + Must have len(repeats) == len(data.shape). + + Returns + ------- + ret: Tensor + The tiled tensor, with shape [a * b for a, b in zip(data.shape, repeats)]. + """ + return TileOp(data, repeats).get_output(0) + + +def split(data: Tensor, axis: int, parts: List[int]) -> List[Tensor]: + axis = normalize_dim(axis, len(data.shape)) + if sum(parts) != data.shape[axis]: + raise ValueError('split operator expects the sum(parts) parameter equals the the extent of given axis' + ', but got shape {}, axis {} and parts {}'.format(data.shape, axis, parts)) + outputs = [] + for i in range(len(parts)): + start = sum(parts[:i]) + end = start + parts[i] + outputs.append(strided_slice(data, starts=[start], ends=[end], axes=[axis], strides=[1])) + return outputs diff --git a/python/hidet/tos/ops/definitions/utils.py b/python/hidet/tos/ops/definitions/utils.py new file mode 100644 index 0000000..a12b514 --- /dev/null +++ b/python/hidet/tos/ops/definitions/utils.py @@ -0,0 +1,97 @@ +from typing import Tuple, List, Union, Sequence, Optional +from hidet.ir.layout import DataLayout +from hidet.ir.expr import Var +from hidet.ir.type import TensorType, tensor_type, ScalarType +from hidet.ir.task import Task, InverseMap +from hidet.ir.func import IRModule +from hidet.tos.operator import Operator, Tensor +from hidet.ir.dialects.compute import TensorNode, tensor_input, compute, reduce + +from hidet.ir.functors import inline_compute + + +def input_like(tensor: Tensor, name: str) -> TensorNode: + # todo: make scope and device consistent + device2scope = { + 'cuda': 'global', + 'cpu': 'host' + } + return tensor_input(name, tensor.dtype, tensor.shape, device2scope[tensor.device], tensor.layout) + + +def normalize_stride(stride: Union[int, Sequence[int]], dim=2) -> List[int]: + if isinstance(stride, int): + return [stride for _ in range(dim)] + elif isinstance(stride, (list, tuple)): + if len(stride) == 1: + return stride * dim + elif len(stride) == dim: + return stride + raise ValueError('Stride must be an integer or a list of integer with length 1 or {}, but got {}'.format(dim, stride)) + + +def normalize_kernel(kernel: Union[int, Sequence[int]], dim=2) -> List[int]: + if isinstance(kernel, int): + return [kernel for _ in range(dim)] + elif isinstance(kernel, (list, tuple)): + if len(kernel) == 1: + return kernel * dim + elif len(kernel) == dim: + return kernel + raise ValueError('Kernel size must be an integer or a list of integer with length 1 or {}, but got {}'.format(dim, kernel)) + + +def normalize_padding(padding: Union[int, Sequence[int]], dim=2) -> List[int]: + if isinstance(padding, int): + return [padding for _ in range(dim * 2)] + elif isinstance(padding, (list, tuple)): + if len(padding) == 1: + return list(padding * (2 * dim)) + elif len(padding) == dim: + return list(padding + padding) + elif len(padding) == dim * 2: + return list(padding) + raise ValueError('Padding must be an integer or a list of integer with length 1, {}, or {}, but got {}'.format(dim, dim * 2, padding)) + + +def normalize_dim(dim: Optional[Union[int, Sequence[int]]], rank: int) -> Union[int, List[int]]: + """ + normalize a dim from [-rank, rank] or None to [0, rank]. + """ + if isinstance(dim, (list, tuple)): + return [normalize_dim(d, rank) for d in dim] + else: + original_dim = dim + if dim is None: + dim = rank + if dim < 0: + dim += rank + if not (0 <= dim <= rank): + raise ValueError('Given dim {} is not a valid dim for rank {}'.format(original_dim, rank)) + return dim + + +def normalize_index(index: Optional[int], dim_size, default) -> int: + """ + normalize an index from [-oo, oo] or None to [0, dim_size] + """ + if index is None: + return default + elif index < 0: + return max(index + dim_size, 0) + elif 0 <= index <= dim_size: + return index + else: + return dim_size + + +def resolve_out_dtype(input_dtypes: List[Union[ScalarType, str]]) -> str: + if len(input_dtypes) == 0: + raise ValueError('Expect at least one input dtype to resolve the output dtype.') + def combine(lhs_dtype, rhs_dtype) -> str: + pass + out_dtype = input_dtypes[0] + for input_dtype in input_dtypes[1:]: + out_dtype = ScalarType.resolve_out_dtype(out_dtype, input_dtype) + return out_dtype.name + diff --git a/python/hidet/tos/ops/schedules/__init__.py b/python/hidet/tos/ops/schedules/__init__.py new file mode 100644 index 0000000..d3635bf --- /dev/null +++ b/python/hidet/tos/ops/schedules/__init__.py @@ -0,0 +1,8 @@ +from . import cpu +from . import cuda + +from .cpu import generic_cpu_schedule +from .cuda import generic_cuda_schedule + +from .cuda.softmax import softmax_cuda_schedule +from .cuda.reduce import cuda_schedule_reduce_by_default, cuda_schedule_reduce_by_warp_reduce diff --git a/python/hidet/tos/ops/schedules/common.py b/python/hidet/tos/ops/schedules/common.py new file mode 100644 index 0000000..c33393d --- /dev/null +++ b/python/hidet/tos/ops/schedules/common.py @@ -0,0 +1,233 @@ +from __future__ import annotations +from typing import Mapping + +from hidet.ir.dialects.compute import TensorNode, ScalarNode +from hidet.ir.builders import StmtBuilder +from hidet.ir.expr import * +from hidet.ir.functors import infer_type, ExprRewriter, rewrite +from hidet.ir.stmt import ForStmt, BufferStoreStmt, AssignStmt +from hidet.ir.task import Task +from hidet.utils import prod + + +class NotSupportedError(Exception): + def __init__(self, obj: object, msg: str = ""): + self.obj = obj + self.msg = msg + + +class Schedule: + def keys(self) -> List[Tuple[str, Union[int, float, str]]]: + raise NotImplementedError() + + def derived_keys(self) -> List[Tuple[str, Union[int, float, str]]]: + raise NotImplementedError() + + +class LoopExpander(ExprRewriter): + def __init__(self, input_map): + super().__init__() + self.sb = StmtBuilder() + self.input_map = input_map + self.new_buffer_map = {} + + def expand(self, e): + value = self.visit(e) + return self.sb.finish(), value, self.new_buffer_map + + # def visit_TensorInput(self, e: TensorNode): + # return self.input_map[e] + # + # def visit_ScalarInput(self, e: ScalarNode): + # return self.input_map[e] + # + + def visit_TensorNode(self, e: TensorNode): + if e.grid_compute is None: + # input tensor + return self.input_map[e] + grid_compute = e.grid_compute + # declare output buffer when needed + if e in self.input_map: + buf = self.input_map[e] + else: + buf = Var(e.name, e.data_type) + self.new_buffer_map[e] = buf + + shape, axes, value = grid_compute.shape, grid_compute.axes, grid_compute.value + # tensor compute loops + for i in range(len(shape)): + self.sb.enter_body(ForStmt(axes[i], shape[i])) + + # at the innermost loop body + expr = self.visit(grid_compute.value) + self.sb.append(BufferStoreStmt(buf, axes, expr)) + + # exit loop scope + for i in range(len(shape)): + self.sb.exit_body() + + return buf + + def visit_ScalarNode(self, e: ScalarNode): + if e.reduce_compute is None: + # input scalar + return self.input_map[e] + + rc = e.reduce_compute + shape, axes, value = rc.shape, rc.axes, rc.value + # declare accumulator + acc = scalar_var(e.name, infer_type(value)) + self.new_buffer_map[e] = acc + + # init accumulator + self.sb += AssignStmt(acc, rc.init_const(rc.reduce_type, e.data_type.name)) + + # reduction loops + for i in range(len(shape)): + self.sb.enter_body(ForStmt(axes[i], shape[i])) + + # at the innermost loop body + expr = self.visit(value) + self.sb += AssignStmt(acc, rc.combine(rc.reduce_type, acc, expr)) + + # exit loop scope + for i in range(len(shape)): + self.sb.exit_body() + + # finalize + acc = rc.finalize(rc.reduce_type, acc, prod(shape)) + + # if e is in the input buffer, we should write it back + if e in self.input_map: + input_var = self.input_map[e] + self.sb += AssignStmt(input_var, acc) + + return acc + + +def expand_loop(expr: Expr, input_map: Mapping[Union[ScalarNode, TensorNode], Var]): + """ + Generate statements to calculate the expression. + + The expression may contain TensorCompute and ReduceCompute sub-expressions. + After expand, the stmt will not have ScalarInput, TensorInput, TensorCompute and ReduceCompute anymore. + + The returned new_buffer_map is a mapping from ReduceCompute and TensorCompute sub-expressions to + new allocated buffers used to conduct the computation. + + For example, the following expr: + compute([3, 3], (i, j) -> reduce_sum(A[i, k] * B[k, j], axis=k)) where k = axis(3) + will be expanded to + for i in range(3): + for j in range(3): + s = 0 + for k in range(3): + s += A[i, k] * B[k, j] + C[i, j] = s + + If C is in input_map, then the mapped var is used directly. Otherwise, a new tensor var is created to store the results + and returned in new_buffer_map. We only reuse tensor in input_map. + """ + expander = LoopExpander(input_map) + stmt, value, new_buffer_map = expander.expand(expr) + return stmt, value, new_buffer_map + + +class VirtualTensor: + """ + A virtual tensor map index to a value + VirtualTensor can be used to abstract an expression to a tensor. + Support indexing and slicing. + + For example, considering this expression: 0 <= i && i < 32 ? A[i] : 0.0, we can construct a + virtual tensor A = VirtualTensor(fmap=lambda i: 0<=i && i<32 ? A[i] : 0.0); + Then we can access A[i] and slice A[1:]. + """ + + def __init__(self, fmap): + self.fmap = fmap + + def __getitem__(self, item): + if not isinstance(item, (list, tuple)): + item = [item] + if any(isinstance(v, slice) for v in item): + starts = [] + indices = [] + for v in item: + if isinstance(v, slice): + starts.append(v.start if v.start else 0) + indices.append(None) + else: + starts.append(None) + indices.append(v) + + def fmap(*slice_indices): + assert len(indices) == len([v for v in starts if v is not None]) + orig_indices = [] + cur = 0 + for i in range(len(starts)): + if starts[i] is not None: + orig_indices.append(slice_indices[cur] + starts[i]) + cur += 1 + else: + orig_indices.append(indices[i]) + return self.__getitem__(orig_indices) + + return VirtualTensor(fmap) + else: + return self.fmap(*item) + + @staticmethod + def from_indexed_value(indices: Sequence[Var], value: Expr) -> VirtualTensor: + def fmap(*actual_indices): + if len(actual_indices) != len(indices): + raise ValueError('Expect {} number of indices, got {}.'.format(len(indices), len(actual_indices))) + return rewrite(value, {a: b for a, b in zip(indices, actual_indices)}) + + return VirtualTensor(fmap) + + +def params_from_task(task: Task) -> List[Var]: + return [Var(param.name, param.data_type) for param in task.inputs + task.outputs] + + +# def params_from_task(task: Task) -> List[Var]: +# return [Var(param.name, param.data_type) for param in task.parameters] +# +# +# def inputs_from_task(task: Task, params: List[Var]) -> List[Union[VirtualTensor, Var]]: +# inputs = [] +# param2var = {param: var for param, var in zip(task.parameters, params)} +# for input in task.inputs: +# if input in task.prologues: +# prologue = task.prologues[input] +# value = rewrite(prologue.value, param2var) +# inputs.append(VirtualTensor.from_indexed_value(prologue.indices, value)) +# else: +# assert input in param2var +# inputs.append(param2var[input]) +# return inputs +# +# +# def outputs_from_task(task: Task, params: List[Var]) -> List[Var]: +# outputs = [] +# param2var = {param: var for param, var in zip(task.parameters, params)} +# for output in task.outputs: +# assert output in param2var +# outputs.append(param2var[output]) +# return outputs +# +# +# def write_output(buf: Var, indices: List[Var], value: Expr, task: Task, params: List[Var]) -> BufferStoreStmt: +# param2var = {param: var for param, var in zip(task.parameters, params)} +# var2param = {var: param for param, var in zip(task.parameters, params)} +# param = var2param[buf] +# if param in task.epilogues: +# epilogue = task.epilogues[param] +# rmap = param2var +# rmap.update({a: b for a, b in zip(epilogue.indices, indices)}) +# value = rewrite(epilogue.value, rmap) +# return BufferStoreStmt(buf, indices, value) +# else: +# return BufferStoreStmt(buf, indices, value) diff --git a/python/hidet/tos/ops/schedules/cpu/__init__.py b/python/hidet/tos/ops/schedules/cpu/__init__.py new file mode 100644 index 0000000..646a339 --- /dev/null +++ b/python/hidet/tos/ops/schedules/cpu/__init__.py @@ -0,0 +1 @@ +from .generic_cpu import generic_cpu_schedule diff --git a/python/hidet/tos/ops/schedules/cpu/generic_cpu.py b/python/hidet/tos/ops/schedules/cpu/generic_cpu.py new file mode 100644 index 0000000..b783b5a --- /dev/null +++ b/python/hidet/tos/ops/schedules/cpu/generic_cpu.py @@ -0,0 +1,17 @@ +from hidet.tos.ops.schedules.common import expand_loop +from hidet.ir.dialects.lowlevel import VoidType +from hidet.ir.expr import Var +from hidet.ir.func import IRModule, Function +from hidet.ir.task import Task + + +def generic_cpu_schedule(task: Task, space_level: int = 0) -> IRModule: + assert len(task.outputs) == 0 + func_param_vars = [Var(param.name, param.data_type) for param in task.parameters] + input_map = {p: v for p, v in zip(task.parameters, func_param_vars)} + body, _, new_buffer_map = expand_loop(task.outputs[0], input_map) + func_locals = list(new_buffer_map.values()) + func = Function(task.name + '.host', kind='host_kernel', params=func_param_vars, body=body, ret_type=VoidType(), + local_vars=func_locals, local_const_vars=[]) + module = IRModule({func.name: func}) + return module diff --git a/python/hidet/tos/ops/schedules/cuda/__init__.py b/python/hidet/tos/ops/schedules/cuda/__init__.py new file mode 100644 index 0000000..9c4e5c9 --- /dev/null +++ b/python/hidet/tos/ops/schedules/cuda/__init__.py @@ -0,0 +1,5 @@ +from . import matmul + +from .generic_cuda import generic_cuda_schedule +from .softmax import softmax_cuda_schedule +from .reduce import cuda_schedule_reduce_by_default diff --git a/python/hidet/tos/ops/schedules/cuda/common.py b/python/hidet/tos/ops/schedules/cuda/common.py new file mode 100644 index 0000000..9620c2c --- /dev/null +++ b/python/hidet/tos/ops/schedules/cuda/common.py @@ -0,0 +1,125 @@ +from typing import List, Optional, Sequence, Tuple +import warnings +import os +from hidet.ir.func import IRModule +from hidet.ir.builders import StmtBuilder +from hidet.ir.primitives import active_mask, shfl_down_sync, shfl_sync +from hidet.ir.stmt import AssignStmt, Stmt +from hidet.ir.task import Task +from hidet.utils import gcd, prod +from hidet.ir.layout import TaskLayout, DataLayout, row_map, repeat_map, grid_map, row_layout, local_layout +from hidet.tos.ops.schedules.common import NotSupportedError + + +def warp_reduce(v, op) -> Stmt: + """ + Reduce over the threads in a warp. + + Parameters + ---------- + v: Var + The value to reduce. It must be a variable. + op: + An binary operator to represent the reducing operator, must be communicative and associative. + + Returns + ------- + ret: Stmt + A block statement to finish the reduction. After reduction, the value in each thread in the warp + has the reduced value. + """ + sb = StmtBuilder() + with sb.let('mask', active_mask()) as mask: + for delta in [16, 8, 4, 2, 1]: + sb += AssignStmt(v, op(v, shfl_down_sync(mask, v, delta=delta))) + sb += AssignStmt(v, shfl_sync(mask, v, src_lane=0)) + return sb.finish() + + +def _get_shapes(task_shape: Sequence[int], num_workers=32, perm: Optional[Sequence[int]] = None) -> Tuple[List[int], List[int]]: + rank = len(task_shape) + + if prod(task_shape) % num_workers != 0: + raise NotSupportedError('Number of workers must be a divisor of total number of tasks, ' + 'task shape {} and number workers {}.'.format(task_shape, num_workers)) + + # can not have duplicated dimensions + if len(set(perm)) != len(perm): + raise NotSupportedError('Duplicated ranks in perm: {}'.format(perm)) + + if len(perm) != rank: + raise NotSupportedError('Length of perm {} does not match task_shape {}'.format(perm, task_shape)) + + if len(set(perm) - set(range(rank))) != 0: + raise NotSupportedError('perm should be a permutation of {}, got {}'.format(list(range(rank)), perm)) + + # first fill in grid_shape + grid_shape = [0 for _ in range(rank)] + for i in reversed(range(rank)): + dim = perm.index(i) + factor = gcd(task_shape[dim], num_workers) + grid_shape[dim] = factor + num_workers //= factor + assert num_workers == 1 + + # remaining tasks are repeated by workers + repeat_shape = [] + for dim in range(rank): + repeat_shape.append(task_shape[dim] // grid_shape[dim]) + + return grid_shape, repeat_shape + + +def get_task_map(task_shape: Sequence[int], num_workers=32, perm: Sequence[int] = None) -> TaskLayout: + """ + Get a task map that maps a collection of workers to a task domain with given shape. The returned + task map is composed of repeat shape and grid shape. We first determine the size of each dimension + in the grid shape, then fill the repeat shape accordingly. + + It follows the following steps to construct the task map. + 1. Normalize the order of dimensions. The last dimension in the order is continuous regarding + worker index. + 2. Following the order of dimension, determine the grid shape. + 2. Fill the repeat shape. + + Parameters + ---------- + task_shape: Sequence[int] + The shape of the task domain. + num_workers: int + The number of workers. + perm: Optional[Sequence[int]] + todo: finish this. + + Returns + ------- + ret: TaskLayout + The task mapping that maps given number of workers to given task domain. + + Examples + -------- + + >>> get_task_map([4, 4], num_workers=2, perm=[0, 1]) + [[0 1 0 1] + [0 1 0 1] + [0 1 0 1] + [0 1 0 1]] + + >>> get_task_map([4, 4], num_workers=2, perm=[1, 0]) + [[0 0 0 0] + [1 1 1 1] + [0 0 0 0] + [1 1 1 1]] + """ + grid_shape, repeat_shape = _get_shapes(task_shape, num_workers, perm) + + task_map = repeat_map(*repeat_shape) * grid_map(grid_shape, order=perm) + return task_map + + +def get_transfer_task_map(task_shape: Sequence[int], num_workers=32, order: Optional[Sequence[int]] = None) -> Tuple[TaskLayout, DataLayout]: + grid_shape, repeat_shape = _get_shapes(task_shape, num_workers, order) + + task_map = repeat_map(*repeat_shape) * grid_map(grid_shape, order=order) + data_layout = row_layout(*repeat_shape) * local_layout(*grid_shape) + return task_map, data_layout diff --git a/python/hidet/tos/ops/schedules/cuda/generic_cuda.py b/python/hidet/tos/ops/schedules/cuda/generic_cuda.py new file mode 100644 index 0000000..7b761d5 --- /dev/null +++ b/python/hidet/tos/ops/schedules/cuda/generic_cuda.py @@ -0,0 +1,47 @@ +from hidet.tos.ops.schedules.common import expand_loop +from hidet.ir import IRModule +from hidet.ir.builders import FunctionBuilder, StmtBuilder +from hidet.ir.dialects.compute import TensorNode +from hidet.ir.expr import Var +from hidet.ir.functors import rewrite +from hidet.ir.layout import TaskLayout +from hidet.ir.primitives import block_idx, thread_idx +from hidet.ir.stmt import BufferStoreStmt +from hidet.ir.task import Task +from hidet.ir.functors import inline_compute + +from ..common import params_from_task + + +def generic_cuda_schedule(task: Task) -> IRModule: + computation: TensorNode = inline_compute(task.outputs[0], reduce_limit=16) + block_size = 512 + task_shape = computation.const_shape() + task_layout = TaskLayout.row_major(task_shape) + num_blocks = (task_layout.num_workers + block_size - 1) // block_size + + with FunctionBuilder(name=task.name + '_grid', grid_dim=num_blocks, block_dim=block_size, kind='cuda_kernel', label='generic implementer') as fb: + # params + params = params_from_task(task) + param_map = {param: var for param, var in zip(task.inputs + task.outputs, params)} + fb.extend_params(params) + scalar_value = rewrite(computation.grid_compute.value, param_map) # replace TensorInput to function parameter + assert len(task.outputs) == 1 + out = param_map[task.outputs[0]] + # body + sb = StmtBuilder() + worker_idx = block_idx() * block_size + thread_idx() + with sb.if_then(worker_idx < task_layout.num_workers): + with sb.for_task(worker_index=worker_idx, task_layout=task_layout) as tasks: + buffer_map = {} + for axes_values in tasks: + remap = {axis: value for axis, value in zip(computation.grid_compute.axes, axes_values)} + stmt, value, new_buffer_map = expand_loop(rewrite(scalar_value, remap), input_map=buffer_map) + buffer_map.update(new_buffer_map) + sb += stmt + sb += BufferStoreStmt(out, axes_values, value) + fb.extend_local_vars(list(buffer_map.values())) + fb.set_body(sb.finish()) + func = fb.get() + return IRModule(funcs={func.name: func}, task=task) + diff --git a/python/hidet/tos/ops/schedules/cuda/matmul/__init__.py b/python/hidet/tos/ops/schedules/cuda/matmul/__init__.py new file mode 100644 index 0000000..134aa66 --- /dev/null +++ b/python/hidet/tos/ops/schedules/cuda/matmul/__init__.py @@ -0,0 +1,3 @@ +from .bmm import batched_matmul_cuda_schedule_default +from .bmm_wb import batched_matmul_cuda_schedule_wb +from .bmm_wmma import batched_matmul_cuda_schedule_wmma diff --git a/python/hidet/tos/ops/schedules/cuda/matmul/bmm.py b/python/hidet/tos/ops/schedules/cuda/matmul/bmm.py new file mode 100644 index 0000000..6cf2b9c --- /dev/null +++ b/python/hidet/tos/ops/schedules/cuda/matmul/bmm.py @@ -0,0 +1,391 @@ +from typing import List, Tuple, Union, Optional + +import os +from hidet.ir.builders import FunctionBuilder, StmtBuilder +from hidet.ir.dialects.lowlevel import TensorPointerType, PointerType +from hidet.ir.expr import Var, And, Equal, Cast, if_then_else, convert, Expr +from hidet.ir.func import IRModule +from hidet.ir.functors import simplify_to_int +from hidet.ir.layout import TaskLayout, DataLayout, StridesLayout +from hidet.ir.primitives import syncthreads, thread_idx, block_idx +from hidet.ir.stmt import AssignStmt, BufferStoreStmt, IfStmt +from hidet.ir.type import scalar_type, tensor_type, ScalarType +from hidet.ir.task import TaskContext +from hidet.utils import cuda +from hidet.tos.ops.definitions.matmul.matmul import MatmulTask +from hidet.tos.ops.schedules.resolve import resolve_ir_modules +from hidet.tos.ops.schedules.common import params_from_task, Schedule, NotSupportedError + + +""" +pseudo code of matmul with double buffering +========= +assume block_k % task_k == 0 and warp_k % block_k == 0 +gmem[0] -> smem[0] +sync +smem[0, 0] -> regs[0] +sync +for k0 in range(task_k / block_k - 1): + for k1 in range(block_k / warp_k): + if k1 == 0: + smem[k0 % 2, k1+1] -> regs[(k1 + 1) % 2] + gmem[k0 + 1] -> smem[(k0 + 1) % 2] + regs[k1 % 2] -> acc regs + elif 0 < k1 < block_k / warp_k - 1: + smem[k0 % 2, k1+1] -> regs[(k1 + 1) % 2] + regs[k1 % 2] -> acc regs + else k1 == block_k / warp_k - 1: + sync + smem[(k0 + 1) % 2, 0] -> regs[(k1 + 1) % 2] + regs[k1 % 2] -> acc regs +k0 = task_k / block_k - 1 +for k1 in range(block_k / warp_k): + if k1 == 0: + smem[k0 % 2, k1+1] -> regs[(k1 + 1) % 2] + regs[k1 % 2] -> acc regs + elif 0 < k1 < block_k / warp_k - 1: + smem[k0 % 2, k1+1] -> regs[(k1 + 1) % 2] + regs[k1 % 2] -> acc regs + else k1 == block_k / warp_k - 1: + regs[k1 % 2] -> acc regs +sync +write back +""" + + +class MatmulSchedule(Schedule): + def __init__( + self, + block_warps_k=8, + warp_k=1, + block_warps=(4, 2), + warp_outer=(2, 2), + atom_layout=TaskLayout.row_major([4, 8]), + atom_layout_name='row_4x8', + warp_inner=(4, 4), + dtype='float32' + ): + self.block_warps_k = block_warps_k + self.warp_k = warp_k + self.block_warps = block_warps + self.warp_outer = warp_outer + self.atom_layout = atom_layout + self.warp_inner = warp_inner + self.atom_layout_name = atom_layout_name + + # sanity check + row_major = TaskLayout.row_major + full_layout = TaskLayout.full_layout + warp_outer_layout = full_layout(warp_outer) + warp_inner_layout = full_layout(warp_inner) + warp_layout = warp_outer_layout * atom_layout * warp_inner_layout + block_warps_layout = row_major(block_warps) + block_layout = block_warps_layout * warp_layout + block_k = block_warps_k * warp_k + atom_shape = atom_layout.task_shape + block_shape = block_layout.task_shape + warp_size = 32 + block_size = block_layout.num_workers + self.check(atom_layout.num_workers == 32, "atom layout should have exactly 32 workers, corresponding to 32 threads in a warp") + self.check(block_warps_k % 2 == 0, "double buffering requires that block_k/warp_k is divisible by 2") + if block_k <= warp_size: + self.check(warp_size % block_k == 0, f"transfer from gmem to smem requires block_k ({block_k}) is divisible by warp_size ({warp_size})") + self.check(block_shape[0] % (block_size // block_k) == 0 and block_shape[1] % (block_size // block_k) == 0, + f"transfer of matrix A/B from gmem to regs requirement. block_shape ({block_shape}) block_size ({block_size}) block_k ({block_k}) block_size / block_k ({block_size / block_k})") + else: + self.check(block_k % warp_size == 0, "transfer from gmem to smem requires warp_size is divisible by block_k") + raise NotSupportedError(self, "Will support later") + + # derived data layouts + local_layout = DataLayout.local + row_major = DataLayout.row_major + col_major = DataLayout.column_major + self.regs_a_layout = local_layout((block_warps[0], 1)) * col_major((warp_outer[0], warp_k)) * local_layout((atom_shape[0], 1)) * row_major((warp_inner[0], 1)) + self.regs_b_layout = local_layout((1, block_warps[1])) * row_major((warp_k, warp_outer[1])) * local_layout((1, atom_shape[1])) * row_major((1, warp_inner[1])) + self.regs_c_layout = local_layout(block_warps) * row_major(warp_outer) * local_layout(atom_shape) * row_major(warp_inner) + if block_k <= warp_size: + self.regs_a_ldg_layout = local_layout((block_size // block_k, block_k)) * row_major((block_shape[0] // (block_size // block_k), 1)) + self.regs_b_ldg_layout = row_major((1, block_shape[1] // (block_size // block_k))) * local_layout((block_k, block_size // block_k)) + else: + raise NotSupportedError(self) + reserved_regs = 48 # number of reserved registers for intermediate results + used_num_regs_per_thread = self.regs_a_layout.size + self.regs_b_layout.size + self.regs_c_layout.size + self.regs_a_ldg_layout.size + self.regs_b_ldg_layout.size + reserved_regs + used_num_regs_per_thread = (used_num_regs_per_thread + 7) // 8 * 8 # the number of registers allocated to each thread is a multiple of 8. + resident_blocks = cuda.max_num_regs_per_sm() // (used_num_regs_per_thread * block_size) + + max_smem_bytes_per_block = min(cuda.max_smem_bytes_per_sm() // resident_blocks, cuda.max_smem_bytes_per_block()) // 128 * 128 + + # derived task layouts + row_major = TaskLayout.row_major + full_layout = TaskLayout.full_layout + self.block_warps_layout = block_warps_layout + self.warp_layout = warp_layout + self.block_layout = block_layout + if block_k <= warp_size: + lines = block_size // block_k + self.a_g2s_layout = row_major([lines, block_k]) * full_layout([block_shape[0] // lines, 1]) + self.b_g2s_layout = full_layout([1, block_shape[1] // lines]) * row_major([block_k, lines]) + else: + raise NotSupportedError(self) + self.a_s2r_layout = (self.block_warps_layout * full_layout([warp_outer[0], warp_k]) * atom_layout * full_layout([warp_inner[0], warp_k])).projection({1: 0}) + self.b_s2r_layout = (self.block_warps_layout * full_layout([warp_k, warp_outer[1]]) * atom_layout * full_layout([warp_k, warp_inner[1]])).projection({0: 0}) + + # derived constants + used_smem_bytes_per_block = (block_shape[0] + block_shape[1]) * block_k * 2 * ScalarType(dtype).nbytes() # 2 for double buffering, 4 for number of bytes per float32 + self.check(used_smem_bytes_per_block <= max_smem_bytes_per_block, f"Used shared memory ({used_smem_bytes_per_block} bytes) exceeded the maximum ({max_smem_bytes_per_block} bytes)") + self.block_size = block_size + self.block_shape = block_layout.task_shape + self.block_k = block_k + self.warp_shape = warp_layout.task_shape + self.warp_k = warp_k + self.used_num_regs_per_thread = used_num_regs_per_thread + self.used_smem_bytes_per_block = used_smem_bytes_per_block + # we muse use dynamic shared memory when we use more than 48 KiBytes shared memory + # see Appendix 'Compute Capability' in CUDA C Programming Guide + self.use_dynamic_smem = (used_smem_bytes_per_block > 48 * 1024) + self.min_thread_blocks = resident_blocks + + self.check(used_num_regs_per_thread <= cuda.max_num_regs_per_thread(), f'register used {used_num_regs_per_thread} exceeds maximum {cuda.max_num_regs_per_thread()}') + self.check(used_num_regs_per_thread * block_size <= cuda.max_num_regs_per_block(), f'echo block can only have {cuda.max_num_regs_per_block()} registers, but this schedule requires {used_num_regs_per_thread * block_size} registers') + + def keys(self) -> List[Tuple[str, Union[int, float, str]]]: + return [ + ('bwx', self.block_warps[0]), + ('bwy', self.block_warps[1]), + ('wox', self.warp_outer[0]), + ('woy', self.warp_outer[1]), + ('atom', self.atom_layout_name), + ('wix', self.warp_inner[0]), + ('wiy', self.warp_inner[1]), + ('bk', self.block_k), + ('wk', self.warp_k), + ('mtb', self.min_thread_blocks) + ] + + def derived_keys(self) -> List[Tuple[str, Union[int, float, str]]]: + return [ + ('bx', self.block_shape[0]), + ('by', self.block_shape[1]), + ('regs', self.used_num_regs_per_thread), + ('smem', self.used_smem_bytes_per_block), + ] + + def __str__(self): + return 'overall_{}x{}x{}_blcok_warps_{}x{}_outer_{}_{}_middle_{}x{}_inner_{}x{}_warpk_{}_atom_{}_min_blocks_{}'.format( + *self.block_layout.task_shape, self.block_warps_k * self.warp_k, *self.block_warps, *self.warp_outer, *self.atom_layout.task_shape, *self.warp_inner, + self.warp_k, self.atom_layout_name, self.min_thread_blocks + ) + + def check(self, cond, msg: str = ""): + if not cond: + raise NotSupportedError(self, msg) + + @staticmethod + def schedules(space_level: int = 0): + settings = [] + if space_level == 0: + settings.append(MatmulSchedule()) + elif space_level == 1: + for inner_m, inner_n in [[4, 4], [4, 8], [8, 4]]: + for outer_m, outer_n in [[1, 1], [1, 2], [2, 1], [2, 2]]: + for block_warps_k, warp_k in [[8, 1]]: + for block_warps_m, block_warps_n in [[1, 1], [1, 2], [2, 2], [2, 4]]: + for name, atom_layout in [('row_4x8', TaskLayout.row_major((4, 8)))]: + try: + settings.append(MatmulSchedule( + block_warps_k=block_warps_k, + warp_k=warp_k, + block_warps=[block_warps_m, block_warps_n], + warp_outer=[outer_m, outer_n], + atom_layout=atom_layout, + atom_layout_name=name, + warp_inner=[inner_m, inner_n] + )) + except NotSupportedError as e: + pass + elif space_level == 2: + grid = TaskLayout.row_major + for inner_m, inner_n in [[4, 4]]: + for outer_m, outer_n in [[1, 1], [1, 2], [2, 1], [2, 2], [1, 3], [3, 1], [2, 3], [3, 2], [3, 3]]: + for block_warps_k, warp_k in [[4, 1], [8, 1]]: + for block_warps_m, block_warps_n in [[1, 1], [1, 2], [2, 1], [2, 2], [2, 4], [4, 2]]: + for name, atom_layout in [ + ('row_4x8', grid((4, 8))), + ('custom_4x8', grid((2, 1)) * grid((1, 8)) * grid((2, 1))), + ('row_2x16', grid((2, 16))), + # ('row_1x32', grid((1, 32))), + ]: + try: + settings.append(MatmulSchedule( + block_warps_k=block_warps_k, + warp_k=warp_k, + block_warps=[block_warps_m, block_warps_n], + warp_outer=[outer_m, outer_n], + atom_layout=atom_layout, + atom_layout_name=name, + warp_inner=[inner_m, inner_n] + )) + except NotSupportedError as e: + # print() + # print(e.msg) + pass + else: + raise NotImplementedError() + return settings + + +def batched_matmul_cuda_schedule_default(task: MatmulTask) -> IRModule: + ctx = TaskContext.current() + all_schedules = MatmulSchedule.schedules(space_level=ctx.space_level) + default_resolve_out_dir = os.path.join('./outs/resolve', task.name, 'batched_matmul_default_{}x{}x{}x{}'.format(task.batch_size, task.m_size, task.k_size, task.n_size)) + resolve_out_dir = ctx.resolve_out_dir if ctx.resolve_out_dir else default_resolve_out_dir + ir_modules = [] + for schedule in all_schedules: + ir_modules.append(batched_matmul_cuda_with_given_schedule(task, schedule)) + return resolve_ir_modules( + ir_modules=ir_modules, + schedules=all_schedules, + output_dir=resolve_out_dir, + parallel=True, + verbose=True + ) + + +def batched_matmul_cuda_with_given_schedule(task: MatmulTask, schedule: MatmulSchedule) -> IRModule: + ir_module = IRModule(task=task) + sch = schedule + + a_dtype = task.inputs[0].data_type.scalar_type + b_dtype = task.inputs[1].data_type.scalar_type + c_dtype = task.outputs[0].data_type.scalar_type + + batch_size = task.batch_size + m_size, k_size, n_size = task.m_size, task.k_size, task.n_size + + m_tile_size, n_tile_size = sch.block_shape + m_tiles = (m_size + m_tile_size - 1) // m_tile_size + n_tiles = (n_size + n_tile_size - 1) // n_tile_size + grid_blocks_layout: TaskLayout = TaskLayout.row_major([m_tiles, n_tiles]) + + # define function + with FunctionBuilder( + name=task.name + '.grid', + kind='cuda_kernel', + grid_dim=(grid_blocks_layout.num_workers, batch_size), + block_dim=sch.block_size, + dynamic_smem_bytes=sch.used_smem_bytes_per_block if sch.use_dynamic_smem else 0, + min_blocks=sch.min_thread_blocks, + label=str(sch)) as fb: + sb = StmtBuilder() + + # declare params + params = params_from_task(task) + gmem_a, gmem_b, gmem_c = params + fb.extend_params(params) + + # declare local variables + smem_a = Var('smem_a', TensorPointerType('shared', a_dtype, layout=StridesLayout.from_shape([2, sch.block_shape[0], sch.block_k], perm=[0, 2, 1]))) + smem_b = Var('smem_b', TensorPointerType('shared', b_dtype, layout=StridesLayout.from_shape([2, sch.block_k, sch.block_shape[1]], perm=[0, 1, 2]))) + if sch.use_dynamic_smem: + # 'extern __shared__ uint8_t smem_storage[];' in c code + smem_storage = Var('smem_storage', PointerType(base_type=scalar_type('uint8'), specifiers=['extern', '__shared__'], use_bracket=True)) + else: + smem_storage = Var('smem_storage', tensor_type('shared', dtype='uint8', shape=[sch.used_smem_bytes_per_block])) + smem_a_bytes = simplify_to_int(smem_a.type.tensor_type.storage_bytes()) + fb.extend_local_vars([smem_a, smem_b, smem_storage]) + sb += AssignStmt(smem_a, Cast(~smem_storage[0], PointerType(a_dtype))) + sb += AssignStmt(smem_b, Cast(~(smem_storage[smem_a_bytes]), PointerType(b_dtype))) + + # declare a, b, c registers + regs_a = Var('regs_A', tensor_type('register', a_dtype, layout=[2] + schedule.regs_a_layout)) + regs_b = Var('regs_B', tensor_type('register', b_dtype, layout=[2] + schedule.regs_b_layout)) + regs_c = Var('regs_C', tensor_type('register', c_dtype, layout=schedule.regs_c_layout)) + regs_a_ldg = Var('regs_A_ldg', tensor_type(scope='register', dtype=a_dtype, layout=schedule.regs_a_ldg_layout)) + regs_b_ldg = Var('regs_B_ldg', tensor_type(scope='register', dtype=b_dtype, layout=schedule.regs_b_ldg_layout)) + fb.extend_local_vars([regs_a, regs_b, regs_c, regs_a_ldg, regs_b_ldg]) + + a_default_value = convert(0.0, a_dtype) + b_default_value = convert(0.0, b_dtype) + acc_default_value = convert(0.0, c_dtype) + + with sb.lets(['bi', 'bj'], grid_blocks_layout(block_idx())[0]) as (bi, bj): + block_k_tiles = (k_size + sch.block_k - 1) // sch.block_k + first_k_tile = k_size - (block_k_tiles - 1) * sch.block_k + block_offset = [idx * dim for idx, dim in zip([bi, bj], sch.block_shape)] + # transfer first tile + sb += copy(gmem_a[block_idx('y'), block_offset[0]:, :], regs_a_ldg, schedule.a_g2s_layout, src_predicate=lambda i, k: And.join(block_offset[0] + i < m_size, k < first_k_tile), default_value=a_default_value) + sb += copy(regs_a_ldg, smem_a[0], layout=schedule.a_g2s_layout) + sb += copy(gmem_b[block_idx('y'), :, block_offset[1]:], regs_b_ldg, schedule.b_g2s_layout, src_predicate=lambda k, j: And.join(k < first_k_tile, block_offset[1] + j < n_size), default_value=b_default_value) + sb += copy(regs_b_ldg, smem_b[0], layout=schedule.b_g2s_layout) + sb += syncthreads() + sb += copy(smem_a[0], regs_a[0], schedule.a_s2r_layout) + sb += copy(smem_b[0], regs_b[0], schedule.b_s2r_layout) + sb += syncthreads() + # init regs c + sb += init(regs_c, acc_default_value, schedule.block_layout) + with sb.for_loop('k0', block_k_tiles - 1) as k0: + block_offset_k = k0 * sch.block_k + first_k_tile + with sb.for_loop('k1', sch.block_warps_k) as k1: + with sb.if_then(Equal(k1, sch.block_warps_k - 1)): + sb += copy(regs_a_ldg, smem_a[(k0 + 1) % 2], schedule.a_g2s_layout) + sb += copy(regs_b_ldg, smem_b[(k0 + 1) % 2], schedule.b_g2s_layout) + sb += syncthreads() + sb += copy(smem_a[(k0 + 1) % 2], regs_a[(k1 + 1) % 2], schedule.a_s2r_layout) + sb += copy(smem_b[(k0 + 1) % 2], regs_b[(k1 + 1) % 2], schedule.b_s2r_layout) + with sb.otherwise(): + sb += copy(smem_a[k0 % 2, :, k1 + 1:], regs_a[(k1 + 1) % 2], schedule.a_s2r_layout) + sb += copy(smem_b[k0 % 2, k1 + 1:, :], regs_b[(k1 + 1) % 2], schedule.b_s2r_layout) + with sb.if_then(Equal(k1, 0)): + sb += copy(gmem_a[block_idx('y'), block_offset[0]:, block_offset_k:], regs_a_ldg, schedule.a_g2s_layout, src_predicate=lambda i, _: block_offset[0] + i < m_size, default_value=a_default_value) + sb += copy(gmem_b[block_idx('y'), block_offset_k:, block_offset[1]:], regs_b_ldg, schedule.b_g2s_layout, src_predicate=lambda _, j: block_offset[1] + j < n_size, default_value=b_default_value) + sb += mma(regs_a[k1 % 2], regs_b[k1 % 2], regs_c, schedule) + with sb.let('block_k_tile', block_k_tiles - 1) as k0: + with sb.for_loop('warp_k_tile', sch.block_warps_k) as k1: + with sb.if_then(k1 < sch.block_warps_k - 1): + sb += copy(smem_a[k0 % 2, :, k1 + 1:], regs_a[(k1 + 1) % 2], schedule.a_s2r_layout) + sb += copy(smem_b[k0 % 2, k1 + 1:, :], regs_b[(k1 + 1) % 2], schedule.b_s2r_layout) + sb += mma(regs_a[k1 % 2], regs_b[k1 % 2], regs_c, schedule) + sb += copy(src=regs_c, dst=gmem_c[block_idx('y'), block_offset[0]:, block_offset[1]:], layout=schedule.block_layout, + dst_predicate=lambda i, j: And(block_offset[0] + i < m_size, block_offset[1] + j < n_size)) + # set body + fb.set_body(sb.finish()) + + func = fb.get() + ir_module.add(func.name, func) + return ir_module + + +def init(dst, init_value, layout): + sb = StmtBuilder() + for indices in layout(thread_idx()): + sb += BufferStoreStmt(dst, indices, init_value) + return sb.finish() + + +def copy(src, dst, layout, src_predicate=None, dst_predicate=None, default_value: Optional[Union[Expr, float]] = 0.0): + sb = StmtBuilder() + for indices in layout(thread_idx()): + value = src.__getitem__(indices) + if src_predicate: + value = if_then_else(src_predicate(*indices), value, default_value) + stmt = BufferStoreStmt(dst, indices, value) + if dst_predicate: + stmt = IfStmt(dst_predicate(*indices), stmt) + sb += stmt + return sb.finish() + + +def mma(a, b, c, schedule): + layout = schedule.block_layout + sb = StmtBuilder() + for i, j in layout(thread_idx()): + for k in range(schedule.warp_k): + sb += BufferStoreStmt(c, [i, j], c[i, j] + a[i, k] * b[k, j]) + return sb.finish() + + +if __name__ == '__main__': + schedules = MatmulSchedule.schedules(space_level=2) + # print(len(schedules)) + for sch in schedules: + print(sch) diff --git a/python/hidet/tos/ops/schedules/cuda/matmul/bmm_wb.py b/python/hidet/tos/ops/schedules/cuda/matmul/bmm_wb.py new file mode 100644 index 0000000..f8dce6c --- /dev/null +++ b/python/hidet/tos/ops/schedules/cuda/matmul/bmm_wb.py @@ -0,0 +1,427 @@ +import itertools +from typing import List, Tuple, Union, Optional + +import os +from hidet.ir.builders import FunctionBuilder, StmtBuilder +from hidet.ir.dialects.lowlevel import TensorPointerType, PointerType +from hidet.ir.expr import Var, And, Equal, Cast, if_then_else, convert, Expr +from hidet.ir.func import IRModule +from hidet.ir.functors import simplify_to_int +from hidet.ir.layout import TaskLayout, DataLayout, StridesLayout +from hidet.ir.primitives import syncthreads, thread_idx, block_idx +from hidet.ir.stmt import AssignStmt, BufferStoreStmt, IfStmt +from hidet.ir.type import scalar_type, tensor_type +from hidet.ir.task import TaskContext +from hidet.utils import cuda, factor, prod +from hidet.tos.ops.definitions.matmul.matmul import MatmulTask +from hidet.tos.ops.schedules.resolve import resolve_ir_modules +from hidet.tos.ops.schedules.common import params_from_task, Schedule, NotSupportedError + + +""" +pseudo code of matmul with double buffering +========= +assume block_k % task_k == 0 and warp_k % block_k == 0 +gmem[0] -> smem[0] +sync +smem[0, 0] -> regs[0] +sync +for k0 in range(task_k / block_k - 1): + for k1 in range(block_k / warp_k): + if k1 == 0: + smem[k0 % 2, k1+1] -> regs[(k1 + 1) % 2] + gmem[k0 + 1] -> smem[(k0 + 1) % 2] + regs[k1 % 2] -> acc regs + elif 0 < k1 < block_k / warp_k - 1: + smem[k0 % 2, k1+1] -> regs[(k1 + 1) % 2] + regs[k1 % 2] -> acc regs + else k1 == block_k / warp_k - 1: + sync + smem[(k0 + 1) % 2, 0] -> regs[(k1 + 1) % 2] + regs[k1 % 2] -> acc regs +k0 = task_k / block_k - 1 +for k1 in range(block_k / warp_k): + if k1 == 0: + smem[k0 % 2, k1+1] -> regs[(k1 + 1) % 2] + regs[k1 % 2] -> acc regs + elif 0 < k1 < block_k / warp_k - 1: + smem[k0 % 2, k1+1] -> regs[(k1 + 1) % 2] + regs[k1 % 2] -> acc regs + else k1 == block_k / warp_k - 1: + regs[k1 % 2] -> acc regs +sync +write back +""" + + +class CustomTaskLayout(TaskLayout): + def __init__(self): + super().__init__(num_workers=32, task_shape=(4, 8), worker2task=self._work2task) + + @staticmethod + def _work2task(w): + return [(w // 16 * 2 + w % 2, w // 2 % 8)] + + +class MatmulSchedule(Schedule): + def __init__(self, + block_warps_k=8, + warp_k=1, + block_warps=(4, 2), + warp_outer=(2, 2), + atom_layout=CustomTaskLayout(), + atom_layout_name='custom_4x8', + warp_inner=(4, 4)): + self.block_warps_k = block_warps_k + self.warp_k = warp_k + self.block_warps = block_warps + self.warp_outer = warp_outer + self.atom_layout = atom_layout + self.warp_inner = warp_inner + self.atom_layout_name = atom_layout_name + + # sanity check + row_major = TaskLayout.row_major + full_layout = TaskLayout.full_layout + warp_outer_layout = full_layout(warp_outer) + warp_inner_layout = full_layout(warp_inner) + warp_layout = warp_outer_layout * atom_layout * warp_inner_layout + block_warps_layout = row_major(block_warps) + block_layout = block_warps_layout * warp_layout + block_k = block_warps_k * warp_k + atom_shape = atom_layout.task_shape + block_shape = block_layout.task_shape + warp_size = 32 + block_size = block_layout.num_workers + self.check(atom_layout.num_workers == 32, "atom layout should have exactly 32 workers, corresponding to 32 threads in a warp") + self.check(block_warps_k % 2 == 0, "double buffering requires that block_k/warp_k is divisible by 2") + if block_k <= warp_size: + self.check(warp_size % block_k == 0, f"transfer from gmem to smem requires block_k ({block_k}) is divisible by warp_size ({warp_size})") + # todo: consider removing the following two constraints by adding bound-checking in source-template + self.check(block_shape[0] % (block_size // block_k) == 0 and block_shape[1] % (block_size // block_k) == 0, + f"transfer of matrix A/B from gmem to regs requirement. block_shape ({block_shape}) block_size ({block_size}) block_k ({block_k}) block_size / block_k ({block_size / block_k})") + else: + self.check(block_k % warp_size == 0, "transfer from gmem to smem requires warp_size is divisible by block_k") + raise NotSupportedError(self, "Will support later") + + # derived data layouts + local_layout = DataLayout.local + row_major = DataLayout.row_major + col_major = DataLayout.column_major + self.regs_a_layout = local_layout((block_warps[0], 1)) * col_major((warp_outer[0], warp_k)) * local_layout((atom_shape[0], 1)) * row_major((warp_inner[0], 1)) + self.regs_b_layout = local_layout((1, block_warps[1])) * row_major((warp_k, warp_outer[1])) * local_layout((1, atom_shape[1])) * row_major((1, warp_inner[1])) + self.regs_c_layout = local_layout(block_warps) * row_major(warp_outer) * local_layout(atom_shape) * row_major(warp_inner) + if block_k <= warp_size: + self.regs_a_ldg_layout = local_layout((block_size // block_k, block_k)) * row_major((block_shape[0] // (block_size // block_k), 1)) + self.regs_b_ldg_layout = row_major((1, block_shape[1] // (block_size // block_k))) * local_layout((block_k, block_size // block_k)) + else: + raise NotSupportedError(self) + reserved_regs = 16 + used_num_regs_per_thread = self.regs_a_layout.size + self.regs_b_layout.size + self.regs_c_layout.size + self.regs_a_ldg_layout.size + self.regs_b_ldg_layout.size + reserved_regs + used_num_regs_per_thread = (used_num_regs_per_thread + 7) // 8 * 8 # the number of registers allocated to each thread is a multiple of 8. + self.check(used_num_regs_per_thread <= cuda.max_num_regs_per_thread(), + f'register used {used_num_regs_per_thread} exceeds maximum {cuda.max_num_regs_per_thread()}') + self.check(used_num_regs_per_thread * block_size <= cuda.max_num_regs_per_block(), + f'echo block can only have {cuda.max_num_regs_per_block()} registers, but this schedule requires {used_num_regs_per_thread * block_size} registers') + resident_blocks = cuda.max_num_regs_per_sm() // (used_num_regs_per_thread * block_size) + + max_smem_bytes_per_block = min(cuda.max_smem_bytes_per_sm() // resident_blocks, cuda.max_smem_bytes_per_block()) // 128 * 128 + + # derived task layouts + row_major = TaskLayout.row_major + full_layout = TaskLayout.full_layout + self.block_warps_layout = block_warps_layout + self.warp_layout = warp_layout + self.block_layout = block_layout + if block_k <= warp_size: + lines = block_size // block_k + self.a_g2s_layout = row_major([lines, block_k]) * full_layout([block_shape[0] // lines, 1]) + self.b_g2s_layout = full_layout([1, block_shape[1] // lines]) * row_major([block_k, lines]) + else: + raise NotSupportedError(self) + self.a_s2r_layout = (self.block_warps_layout * full_layout([warp_outer[0], warp_k]) * atom_layout * full_layout([warp_inner[0], warp_k])).projection({1: 0}) + self.b_s2r_layout = (self.block_warps_layout * full_layout([warp_k, warp_outer[1]]) * atom_layout * full_layout([warp_k, warp_inner[1]])).projection({0: 0}) + + pairs = [] + for a, b in itertools.product(factor(warp_outer[0]), factor(warp_outer[1])): + used_smem_bytes = prod((block_warps_layout * full_layout([a, b]) * atom_layout * warp_inner_layout).task_shape) * 4 # 4 types per float32, todo: update when support other data type + if used_smem_bytes > max_smem_bytes_per_block: + continue + pairs.append((a, b)) + self.check(len(pairs) > 0, "Can not find a write-back config") + pair = max(pairs, key=lambda p: p[0] * p[1]) + self.c_warp_r2s_layout = full_layout(pair) * atom_layout * warp_inner_layout + c_wb_shape = self.c_warp_r2s_layout.task_shape + if warp_size <= c_wb_shape[1]: + self.check(c_wb_shape[1] % warp_size == 0, f"C write back alignment requirement, warp_size = {warp_size}, c_wb_shape = {c_wb_shape}") + self.c_warp_s2g_layout = full_layout([c_wb_shape[0], c_wb_shape[1] // warp_size]) * row_major([1, warp_size]) + else: + self.check(warp_size % c_wb_shape[1] == 0 and c_wb_shape[0] % (warp_size // c_wb_shape[1]), f"C write back alignment requirement, warp_size = {warp_size}, c_wb_shape = {c_wb_shape}") + lines = warp_size // c_wb_shape[1] + self.c_warp_s2g_layout = full_layout([c_wb_shape[0] // lines, 1]) * row_major([lines, c_wb_shape[1]]) + + # derived constants + used_smem_bytes_per_block = max((block_shape[0] + block_shape[1]) * block_k * 2 * 4, # 2 for double buffering, 4 for number of bytes per float32 + prod((block_warps_layout * self.c_warp_r2s_layout).task_shape) * 4) # 4 for number of bytes per float32 + self.check(used_smem_bytes_per_block <= max_smem_bytes_per_block, f"Used shared memory ({used_smem_bytes_per_block} bytes) exceeded the maximum ({max_smem_bytes_per_block} bytes)") + self.block_size = block_size + self.block_shape = block_layout.task_shape + self.block_k = block_k + self.warp_shape = warp_layout.task_shape + self.warp_k = warp_k + self.c_wb_outer = [a // b for a, b in zip(warp_outer, pair)] + self.c_wb_shape = c_wb_shape + self.used_num_regs_per_thread = used_num_regs_per_thread + self.used_smem_bytes_per_block = used_smem_bytes_per_block + # self.used_smem_bytes_per_block = 2048 * 4 + # we muse use dynamic shared memory when we use more than 48 KiBytes shared memory + # see Appendix 'Compute Capability' in CUDA C Programming Guide + self.use_dynamic_smem = (used_smem_bytes_per_block > 48 * 1024) + self.min_thread_blocks = resident_blocks + # self.use_dynamic_smem = False + + def keys(self) -> List[Tuple[str, Union[int, float, str]]]: + return [ + ('bwx', self.block_warps[0]), + ('bwy', self.block_warps[1]), + ('wox', self.warp_outer[0]), + ('woy', self.warp_outer[1]), + ('atom', self.atom_layout_name), + ('wix', self.warp_inner[0]), + ('wiy', self.warp_inner[1]), + ('bk', self.block_k), + ('wk', self.warp_k), + ('mtb', self.min_thread_blocks) + ] + + def derived_keys(self) -> List[Tuple[str, Union[int, float, str]]]: + return [ + ('bx', self.block_shape[0]), + ('by', self.block_shape[1]), + ('regs', self.used_num_regs_per_thread), + ('smem', self.used_smem_bytes_per_block), + ] + + def __str__(self): + return 'overall_{}x{}x{}_blcok_warps_{}x{}_outer_{}_{}_middle_{}x{}_inner_{}x{}_warpk_{}_atom_{}_min_blocks_{}'.format( + *self.block_layout.task_shape, self.block_warps_k * self.warp_k, *self.block_warps, *self.warp_outer, *self.atom_layout.task_shape, *self.warp_inner, + self.warp_k, self.atom_layout_name, self.min_thread_blocks + ) + + def check(self, cond, msg: str = ""): + if not cond: + raise NotSupportedError(self, msg) + + @staticmethod + def schedules(space_level: int = 0): + settings = [] + if space_level == 0: + settings.append(MatmulSchedule()) + elif space_level == 1: + for inner_m, inner_n in [[4, 4], [4, 8], [8, 4]]: + for outer_m, outer_n in [[1, 1], [1, 2], [2, 1], [2, 2]]: + for block_warps_k, warp_k in [[8, 1]]: + for block_warps_m, block_warps_n in [[1, 1], [1, 2], [2, 2], [2, 4]]: + for name, atom_layout in [('row_4x8', TaskLayout.row_major((4, 8)))]: + try: + settings.append(MatmulSchedule( + block_warps_k=block_warps_k, + warp_k=warp_k, + block_warps=[block_warps_m, block_warps_n], + warp_outer=[outer_m, outer_n], + atom_layout=atom_layout, + atom_layout_name=name, + warp_inner=[inner_m, inner_n] + )) + except NotSupportedError as e: + pass + elif space_level == 2: + for inner_m, inner_n in [[4, 4]]: + for outer_m, outer_n in [[1, 1], [1, 2], [2, 1], [2, 2], [1, 3], [3, 1], [2, 3], [3, 2], [3, 3]]: + for block_warps_k, warp_k in [[4, 1], [8, 1]]: + for block_warps_m, block_warps_n in [[1, 1], [1, 2], [2, 1], [2, 2], [2, 4], [4, 2]]: + for name, atom_layout in [('row_4x8', TaskLayout.row_major((4, 8))), ('custom_4x8', CustomTaskLayout())]: + try: + settings.append(MatmulSchedule( + block_warps_k=block_warps_k, + warp_k=warp_k, + block_warps=[block_warps_m, block_warps_n], + warp_outer=[outer_m, outer_n], + atom_layout=atom_layout, + atom_layout_name=name, + warp_inner=[inner_m, inner_n] + )) + except NotSupportedError as e: + print() + print(e.obj) + print(e.msg) + pass + else: + raise NotImplementedError() + return settings + + +def batched_matmul_cuda_schedule_wb(task: MatmulTask) -> IRModule: + ctx = TaskContext.current() + schedules = MatmulSchedule.schedules(space_level=ctx.space_level) + default_resolve_out_dir = os.path.join('./outs/resolve', task.name, 'batched_matmul_{}x{}x{}x{}'.format(task.batch_size, task.m_size, task.k_size, task.n_size)) + resolve_out_dir = ctx.resolve_out_dir if ctx.resolve_out_dir else default_resolve_out_dir + ir_modules = [] + for schedule in schedules: + ir_modules.append(batched_matmul_cuda_with_given_schedule(task, schedule)) + return resolve_ir_modules( + ir_modules=ir_modules, + schedules=schedules, + output_dir=resolve_out_dir, + parallel=True, + verbose=True + ) + + +def batched_matmul_cuda_with_given_schedule(task: MatmulTask, schedule: MatmulSchedule) -> IRModule: + ir_module = IRModule(task=task) + sch = schedule + + a_dtype = task.inputs[0].data_type.scalar_type + b_dtype = task.inputs[1].data_type.scalar_type + c_dtype = task.outputs[0].data_type.scalar_type + + batch_size = task.batch_size + m_size, k_size, n_size = task.m_size, task.k_size, task.n_size + + m_tile_size, n_tile_size = sch.block_shape + m_tiles = (m_size + m_tile_size - 1) // m_tile_size + n_tiles = (n_size + n_tile_size - 1) // n_tile_size + grid_blocks_layout: TaskLayout = TaskLayout.row_major([m_tiles, n_tiles]) + + # define function + with FunctionBuilder( + name=task.name + '.grid', + kind='cuda_kernel', + grid_dim=(grid_blocks_layout.num_workers, batch_size), + block_dim=sch.block_size, + dynamic_smem_bytes=sch.used_smem_bytes_per_block if sch.use_dynamic_smem else 0, + min_blocks=sch.min_thread_blocks, + label=str(sch)) as fb: + sb = StmtBuilder() + + # declare params + params = params_from_task(task) + gmem_a, gmem_b, gmem_c = params + fb.extend_params(params) + + # declare local variables + smem_a = Var('smem_a', TensorPointerType('shared', a_dtype, layout=StridesLayout.from_shape([2, sch.block_shape[0], sch.block_k], perm=[0, 2, 1]))) + smem_b = Var('smem_b', TensorPointerType('shared', b_dtype, layout=StridesLayout.from_shape([2, sch.block_k, sch.block_shape[1]], perm=[0, 1, 2]))) + smem_c = Var('smem_c', TensorPointerType('shared', c_dtype, layout=StridesLayout.row_major((sch.block_warps_layout * sch.c_warp_s2g_layout).task_shape))) + if sch.use_dynamic_smem: + # 'extern __shared__ uint8_t smem_storage[];' in c code + smem_storage = Var('smem_storage', PointerType(base_type=scalar_type('uint8'), specifiers=['extern', '__shared__'], use_bracket=True)) + else: + smem_storage = Var('smem_storage', tensor_type('shared', dtype='uint8', shape=[sch.used_smem_bytes_per_block])) + smem_A_bytes = simplify_to_int(smem_a.type.tensor_type.storage_bytes()) + fb.extend_local_vars([smem_a, smem_b, smem_c, smem_storage]) + sb += AssignStmt(smem_a, Cast(~smem_storage[0], PointerType(a_dtype))) + sb += AssignStmt(smem_b, Cast(~(smem_storage[smem_A_bytes]), PointerType(b_dtype))) + sb += AssignStmt(smem_c, Cast(~(smem_storage[0]), PointerType(c_dtype))) + + # declare a, b, c registers + regs_a = Var('regs_A', tensor_type('register', a_dtype, layout=StridesLayout.row_major([2]) + schedule.regs_a_layout)) + regs_b = Var('regs_B', tensor_type('register', b_dtype, layout=StridesLayout.row_major([2]) + schedule.regs_b_layout)) + regs_c = Var('regs_C', tensor_type('register', c_dtype, layout=schedule.regs_c_layout)) + regs_a_ldg = Var('regs_A_ldg', tensor_type(scope='register', dtype=a_dtype, layout=schedule.regs_a_ldg_layout)) + regs_b_ldg = Var('regs_B_ldg', tensor_type(scope='register', dtype=b_dtype, layout=schedule.regs_b_ldg_layout)) + fb.extend_local_vars([regs_a, regs_b, regs_c, regs_a_ldg, regs_b_ldg]) + + a_default_value = convert(0.0, a_dtype) + b_default_value = convert(0.0, b_dtype) + acc_default_value = convert(0.0, c_dtype) + + with sb.lets(['bi', 'bj'], grid_blocks_layout(block_idx())[0]) as (bi, bj): + block_k_tiles = (k_size + sch.block_k - 1) // sch.block_k + first_k_tile = k_size - (block_k_tiles - 1) * sch.block_k + block_offset = [idx * dim for idx, dim in zip([bi, bj], sch.block_shape)] + # transfer first tile + sb += copy(gmem_a[block_idx('y'), block_offset[0]:, :], regs_a_ldg, schedule.a_g2s_layout, src_predicate=lambda i, k: And.join(block_offset[0] + i < m_size, k < first_k_tile), default_value=a_default_value) + sb += copy(regs_a_ldg, smem_a[0], layout=schedule.a_g2s_layout) + sb += copy(gmem_b[block_idx('y'), :, block_offset[1]:], regs_b_ldg, schedule.b_g2s_layout, src_predicate=lambda k, j: And.join(k < first_k_tile, block_offset[1] + j < n_size), default_value=b_default_value) + sb += copy(regs_b_ldg, smem_b[0], layout=schedule.b_g2s_layout) + sb += syncthreads() + sb += copy(smem_a[0], regs_a[0], schedule.a_s2r_layout) + sb += copy(smem_b[0], regs_b[0], schedule.b_s2r_layout) + sb += syncthreads() + # init regs c + sb += init(regs_c, acc_default_value, schedule.block_layout) + with sb.for_loop('k0', block_k_tiles - 1) as k0: + block_offset_k = k0 * sch.block_k + first_k_tile + with sb.for_loop('k1', sch.block_warps_k) as k1: + with sb.if_then(Equal(k1, sch.block_warps_k - 1)): + sb += copy(regs_a_ldg, smem_a[(k0 + 1) % 2], schedule.a_g2s_layout) + sb += copy(regs_b_ldg, smem_b[(k0 + 1) % 2], schedule.b_g2s_layout) + sb += syncthreads() + sb += copy(smem_a[(k0 + 1) % 2], regs_a[(k1 + 1) % 2], schedule.a_s2r_layout) + sb += copy(smem_b[(k0 + 1) % 2], regs_b[(k1 + 1) % 2], schedule.b_s2r_layout) + with sb.otherwise(): + sb += copy(smem_a[k0 % 2, :, k1 + 1:], regs_a[(k1 + 1) % 2], schedule.a_s2r_layout) + sb += copy(smem_b[k0 % 2, k1 + 1:, :], regs_b[(k1 + 1) % 2], schedule.b_s2r_layout) + with sb.if_then(Equal(k1, 0)): + sb += copy(gmem_a[block_idx('y'), block_offset[0]:, block_offset_k:], regs_a_ldg, schedule.a_g2s_layout, src_predicate=lambda i, _: block_offset[0] + i < m_size, default_value=a_default_value) + sb += copy(gmem_b[block_idx('y'), block_offset_k:, block_offset[1]:], regs_b_ldg, schedule.b_g2s_layout, src_predicate=lambda _, j: block_offset[1] + j < n_size, default_value=b_default_value) + sb += mma(regs_a[k1 % 2], regs_b[k1 % 2], regs_c, schedule) + with sb.let('block_k_tile', block_k_tiles - 1) as k0: + with sb.for_loop('warp_k_tile', sch.block_warps_k) as k1: + with sb.if_then(k1 < sch.block_warps_k - 1): + sb += copy(smem_a[k0 % 2, :, k1 + 1:], regs_a[(k1 + 1) % 2], schedule.a_s2r_layout) + sb += copy(smem_b[k0 % 2, k1 + 1:, :], regs_b[(k1 + 1) % 2], schedule.b_s2r_layout) + sb += mma(regs_a[k1 % 2], regs_b[k1 % 2], regs_c, schedule) + with sb.for_loop('i', sch.c_wb_outer[0]) as i: + with sb.for_loop('j', sch.c_wb_outer[1]) as j: + warp_indices = sch.block_warps_layout(thread_idx() // 32)[0] + regs_warp_offset = [wid * wdim + pid * pdim for wid, wdim, pid, pdim in zip(warp_indices, sch.warp_layout.task_shape, [i, j], sch.c_wb_shape)] + smem_warp_offset = [idx * dim for idx, dim in zip(warp_indices, sch.c_wb_shape)] + gmem_warp_offset = [bo + ro for bo, ro in zip(block_offset, regs_warp_offset)] + sb += syncthreads() + sb += copy(src=regs_c[regs_warp_offset[0]:, regs_warp_offset[1]:], dst=smem_c[smem_warp_offset[0]:, smem_warp_offset[1]:], layout=schedule.c_warp_r2s_layout) + sb += syncthreads() + sb += copy(src=smem_c[smem_warp_offset[0]:, smem_warp_offset[1]:], dst=gmem_c[block_idx('y'), gmem_warp_offset[0]:, gmem_warp_offset[1]:], layout=schedule.c_warp_s2g_layout, + dst_predicate=lambda ii, jj: And.join(gmem_warp_offset[0] + ii < m_size, gmem_warp_offset[1] + jj < n_size)) + # set body + fb.set_body(sb.finish()) + + func = fb.get() + ir_module.add(func.name, func) + return ir_module + + +def init(dst, init_value, layout): + sb = StmtBuilder() + for indices in layout(thread_idx()): + sb += BufferStoreStmt(dst, indices, init_value) + return sb.finish() + + +def copy(src, dst, layout, src_predicate=None, dst_predicate=None, default_value: Optional[Union[Expr, float]] = 0.0): + sb = StmtBuilder() + for indices in layout(thread_idx()): + value = src.__getitem__(indices) + if src_predicate: + value = if_then_else(src_predicate(*indices), value, default_value) + stmt = BufferStoreStmt(dst, indices, value) + if dst_predicate: + stmt = IfStmt(dst_predicate(*indices), stmt) + sb += stmt + return sb.finish() + + +def mma(a, b, c, schedule): + layout = schedule.block_layout + sb = StmtBuilder() + for i, j in layout(thread_idx()): + for k in range(schedule.warp_k): + sb += BufferStoreStmt(c, [i, j], c[i, j] + a[i, k] * b[k, j]) + return sb.finish() + + +if __name__ == '__main__': + schedules = MatmulSchedule.schedules(space_level=1) + print(len(schedules)) diff --git a/python/hidet/tos/ops/schedules/cuda/matmul/bmm_wmma.py b/python/hidet/tos/ops/schedules/cuda/matmul/bmm_wmma.py new file mode 100644 index 0000000..a307266 --- /dev/null +++ b/python/hidet/tos/ops/schedules/cuda/matmul/bmm_wmma.py @@ -0,0 +1,11 @@ +import os +from typing import List, Tuple, Union, Optional + +from hidet.ir.func import IRModule +from hidet.tos.ops.definitions.matmul.matmul import MatmulTask + + +def batched_matmul_cuda_schedule_wmma(task: MatmulTask) -> IRModule: + raise NotImplementedError() + + diff --git a/python/hidet/tos/ops/schedules/cuda/reduce.py b/python/hidet/tos/ops/schedules/cuda/reduce.py new file mode 100644 index 0000000..4d0a2ae --- /dev/null +++ b/python/hidet/tos/ops/schedules/cuda/reduce.py @@ -0,0 +1,165 @@ +import functools +from typing import List + +from hidet.ir import IRModule +from hidet.ir.builders import FunctionBuilder, StmtBuilder +from hidet.ir.expr import scalar_var, if_then_else, tensor_var, const_like, convert, Expr, And, cast +from hidet.ir.layout import TaskLayout +from hidet.ir.primitives import block_idx, thread_idx +from hidet.ir.dialects.compute import ReduceCompute +from hidet.ir.stmt import AssignStmt, BufferStoreStmt +from hidet.ir.utils import index_deserialize +from hidet.tos.ops.definitions.reduce import ReduceTask +from hidet.tos.ops.schedules.common import params_from_task +from .common import warp_reduce +from hidet.utils import prod + + +def merge_indices(grid_indices: List[Expr], reduce_indices: List[Expr], reduce_dims: List[int]) -> List[Expr]: + indices = [] + grid_indices = list(reversed(grid_indices)) + reduce_indices = list(reversed(reduce_indices)) + for i in range(len(grid_indices) + len(reduce_indices)): + if i in reduce_dims: + indices.append(reduce_indices.pop()) + else: + indices.append(grid_indices.pop()) + return indices + + +def cuda_schedule_reduce_by_warp_reduce(task: ReduceTask) -> IRModule: + x, y = task.inputs[0], task.outputs[0] + + shape: List[int] = x.const_shape() + dims = task.dims + + grid_shape = [v for i, v in enumerate(shape) if i not in dims] + reduce_shape = [shape[i] for i in dims] + + grid_layout = TaskLayout.row_major(task_shape=grid_shape) + + warp_size = 32 + reduce_extent = prod(reduce_shape) + warp_extent = (reduce_extent + warp_size - 1) // warp_size + block_layout = TaskLayout.full_layout([warp_extent]) * TaskLayout.row_major([warp_size]) + + x_dtype = task.inputs[0].data_type.scalar_type + accumulate_dtype = task.attributes['accumulate_dtype'] + + with FunctionBuilder( + name=task.name + '_grid', + kind='cuda_kernel', + grid_dim=grid_layout.num_workers, + block_dim=block_layout.num_workers, + label='reduce schedule' + ) as fb: + # params + params = params_from_task(task) + x, y = params + fb.extend_params(params) + + # local variables + rv = scalar_var('rv', accumulate_dtype) # rv stands for reduce value + fb.extend_local_vars([rv]) + + # get reduce functors + reduce_type = task.reduce_type + init_value = ReduceCompute.init_const(reduce_type=reduce_type, data_type=accumulate_dtype) + combine = functools.partial(ReduceCompute.combine, reduce_type) + finalize = functools.partial(ReduceCompute.finalize, reduce_type) + + # body + sb = StmtBuilder() + grid_indices = grid_layout.worker2task(block_idx())[0] + + # get the reduced value along reduce dimensions + sb += AssignStmt(rv, init_value) + for r, in block_layout.worker2task(thread_idx()): + with sb.if_then(r < reduce_extent): + reduce_indices = index_deserialize(r, shape=reduce_shape) + input_indices = merge_indices(grid_indices, reduce_indices, reduce_dims=task.dims) + sb += AssignStmt(rv, combine(rv, x[input_indices])) + + sb += warp_reduce(rv, op=combine) + sb += AssignStmt(rv, finalize(acc=rv, size=reduce_extent)) + + # write back + for r, in block_layout.worker2task(thread_idx()): + with sb.if_then(r < reduce_extent): + reduce_indices = index_deserialize(r, shape=reduce_shape) + with sb.if_then(And.join_list([reduce_index.equals(0) for reduce_index in reduce_indices])): + reduce_indices = [convert(0) for _ in task.dims] + if task.keep_dim: + output_indices = merge_indices(grid_indices, reduce_indices, reduce_dims=task.dims) + else: + output_indices = grid_indices + sb += BufferStoreStmt(y, output_indices, cast(rv, x_dtype)) + + fb.set_body(sb.finish()) + func = fb.get() + return IRModule(funcs={func.name: func}, task=task) + + +def cuda_schedule_reduce_by_default(task: ReduceTask) -> IRModule: + x, y = task.inputs[0], task.outputs[0] + + shape: List[int] = x.const_shape() + dims = task.dims + + remain_shape = [v for i, v in enumerate(shape) if i not in dims] + reduce_shape = [shape[i] for i in dims] + reduce_extent = prod(reduce_shape) + + block_size = 256 + remain_layout = TaskLayout.row_major(remain_shape) + reduce_layout = TaskLayout.full_layout(reduce_shape) + + grid_size = (remain_layout.num_workers + block_size - 1) // block_size + + x_dtype = task.inputs[0].data_type.scalar_type + accumulate_dtype = task.attributes['accumulate_dtype'] + + with FunctionBuilder( + name=task.name + '_grid', + kind='cuda_kernel', + grid_dim=grid_size, + block_dim=block_size, + label='reduce schedule' + ) as fb: + # params + params = params_from_task(task) + x, y = params + fb.extend_params(params) + + # local variables + rv = scalar_var('rv', accumulate_dtype) # rv stands for reduce value + fb.extend_local_vars([rv]) + + # get reduce functors + reduce_type = task.reduce_type + init_value = ReduceCompute.init_const(reduce_type=reduce_type, data_type=accumulate_dtype) + combine = functools.partial(ReduceCompute.combine, reduce_type) + finalize = functools.partial(ReduceCompute.finalize, reduce_type) + + # body + sb = StmtBuilder() + remain_indices = remain_layout.worker2task(thread_idx() + block_idx() * block_size)[0] + with sb.if_then(And.join_list([remain_index < remain_shape[i] for i, remain_index in enumerate(remain_indices)])): + # get the reduced value along reduce dimensions + sb += AssignStmt(rv, init_value) + for reduce_indices in reduce_layout.worker2task(0): + input_indices = merge_indices(remain_indices, reduce_indices, reduce_dims=task.dims) + sb += AssignStmt(rv, combine(rv, x[input_indices])) + sb += AssignStmt(rv, finalize(acc=rv, size=reduce_extent)) + + # write back + reduce_indices = [convert(0) for _ in reduce_shape] + if task.keep_dim: + output_indices = merge_indices(remain_indices, reduce_indices, reduce_dims=task.dims) + else: + output_indices = remain_indices + sb += BufferStoreStmt(y, output_indices, cast(rv, x_dtype)) + + fb.set_body(sb.finish()) + func = fb.get() + return IRModule(funcs={func.name: func}, task=task) diff --git a/python/hidet/tos/ops/schedules/cuda/softmax.py b/python/hidet/tos/ops/schedules/cuda/softmax.py new file mode 100644 index 0000000..686aaf2 --- /dev/null +++ b/python/hidet/tos/ops/schedules/cuda/softmax.py @@ -0,0 +1,75 @@ +from typing import List + +from hidet.ir import IRModule +from hidet.ir.builders import FunctionBuilder, StmtBuilder +from hidet.ir.expr import scalar_var, if_then_else, tensor_var, const_like, convert +from hidet.ir.layout import TaskLayout +from hidet.ir.primitives import block_idx, thread_idx +from hidet.ir import primitives as prim +from hidet.ir.stmt import AssignStmt, BufferStoreStmt +from hidet.tos.ops.definitions.softmax import SoftmaxTask +from hidet.tos.ops.schedules.common import params_from_task +from .common import warp_reduce + + +def softmax_cuda_schedule(task: SoftmaxTask) -> IRModule: + shape: List[int] = task.x_shape + axis = task.axis + + other_shape = shape[:axis] + shape[axis+1:] + grid_layout = TaskLayout.row_major(task_shape=other_shape) + + warp_size = 32 + reduce_extent = shape[axis] + outer_extent = (reduce_extent + warp_size - 1) // warp_size + block_layout = TaskLayout.full_layout([outer_extent]) * TaskLayout.row_major([warp_size]) + + x_dtype = task.inputs[0].data_type.scalar_type + + with FunctionBuilder( + name=task.name + '_grid', + kind='cuda_kernel', + grid_dim=grid_layout.num_workers, + block_dim=block_layout.num_workers, + label='softmax schedule' + ) as fb: + # params + params = params_from_task(task) + x, y = params + fb.extend_params(params) + + # local variables + buf = tensor_var('buf', shape=[outer_extent], scope='register', dtype=x_dtype) + rv = scalar_var('rv', x_dtype) # rv stands for reduce value + fb.extend_local_vars([rv, buf]) + + # body + sb = StmtBuilder() + + # get the max value along c dimension + sb += AssignStmt(rv, convert(-1e30, x_dtype)) + other_indices = grid_layout.worker2task(block_idx())[0] + for r, in block_layout.worker2task(thread_idx()): + with sb.if_then(r < reduce_extent): + sb += BufferStoreStmt(buf, [r], x[other_indices[:axis] + (r,) + other_indices[axis:]]) + sb += AssignStmt(rv, prim.max(rv, buf[r])) + sb += warp_reduce(rv, prim.max) + + # calculate exp(v-max) + for r, in block_layout.worker2task(thread_idx()): + sb += AssignStmt(buf[r], prim.exp(buf[r] - rv)) + + # calculate sum(exp(v-max)) + sb += AssignStmt(rv, convert(0.0, x_dtype)) + for r, in block_layout.worker2task(thread_idx()): + sb += AssignStmt(rv, rv + if_then_else(r < reduce_extent, buf[r], convert(0.0, x_dtype))) + sb += warp_reduce(rv, lambda a, b: a + b) + + # calculate exp(v-max) / sum(exp(vv-max)) + for r, in block_layout.worker2task(thread_idx()): + with sb.if_then(r < reduce_extent): + sb += BufferStoreStmt(y, other_indices[:axis] + (r,) + other_indices[axis:], buf[r] / rv) + + fb.set_body(sb.finish()) + func = fb.get() + return IRModule(funcs={func.name: func}, task=task) diff --git a/python/hidet/tos/ops/schedules/resolve.py b/python/hidet/tos/ops/schedules/resolve.py new file mode 100644 index 0000000..c638de9 --- /dev/null +++ b/python/hidet/tos/ops/schedules/resolve.py @@ -0,0 +1,126 @@ +import os +import time +from typing import List +import numpy as np + +from hidet.ir.type import TensorType +from hidet.ir.expr import Constant +from hidet.ir.func import IRModule +from hidet.ir.task import Task +from hidet.utils import TableBuilder, strict_zip +from hidet.tos.tensor import randn, zeros, ones, Tensor +from hidet.backend import BuildInstance, batch_build_ir_modules +from .common import Schedule + + +def dummy_inputs_from_task(task: Task) -> List[Tensor]: + """ + Create dummy inputs values for given task. + + Parameters + ---------- + task: Task + The task to generate dummy inputs for. + + Returns + ------- + ret: List[Tensor] + The dummy input tensors. + """ + inputs = [] + for idx, param in enumerate(task.parameters): + param_type = param.data_type + + if not isinstance(param_type, TensorType): + raise ValueError('Currently, only support create dummy scalar inputs.') + if any(not isinstance(s, Constant) for s in param_type.shape): + raise ValueError('Currently, only support create dummy values for static tensor inputs.') + dtype = param_type.scalar_type.name + scope = param_type.scope.name + shape = [int(s) for s in param_type.shape] + scope2device = { + 'global': 'cuda', + 'host': 'cpu' + } + device = scope2device[scope] + if dtype in ['float32', 'float16', 'bfloat16']: + x = randn(shape, dtype, device=device, layout=param_type.layout) + elif dtype in ['int64', 'int32', 'int8', 'uint64', 'uint32', 'uint8']: + x = zeros(shape, dtype, device=device, layout=param_type.layout) + elif dtype == 'bool': + x = ones(shape, dtype, device=device, layout=param_type.layout) + else: + raise ValueError('Currently do not support generate random array for data type {}'.format(dtype)) + inputs.append(x) + return inputs + + +def resolve_ir_modules(ir_modules: List[IRModule], schedules: List[Schedule], output_dir: str, parallel: bool = True, verbose: bool = True) -> IRModule: + """ + Resolve the ir modules of the same task by comparing the latency of each kernel. + + Parameters + ---------- + ir_modules: List[IRModule] + The ir modules to resolve. + schedules: List[Schedule] + The schedules corresponding to each ir module. The order of schedules must be consistent with ir modules'. + output_dir: str + The output directory to store the summary and lowered source code of each ir module. + parallel: bool + Whether to parallelize the building. Default True. + verbose: bool + Whether to show the progress of parallel building. + Returns + ------- + ret: IRModule + The best ir module we can find. + """ + if len(ir_modules) == 0: + raise ValueError('Require at least one ir module.') + if len(ir_modules) == 1: + return ir_modules[0] + if len(schedules) != len(ir_modules): + raise ValueError('The number of ir modules and schedules does not match.') + if any(ir_module.task != ir_modules[0].task for ir_module in ir_modules): + raise ValueError('Require all ir modules are from the same task.') + build_instances = [BuildInstance(ir_module=ir_module, + output_dir=os.path.join(output_dir, 'resolve', str(idx)), + keep_ir=False, + nvcc_keep=False, + verbose=False) for idx, ir_module in enumerate(ir_modules)] + compiled_funcs = batch_build_ir_modules(build_instances, parallel=parallel, verbose=verbose) + dummy_inputs = dummy_inputs_from_task(ir_modules[0].task) + best_latency = 1e30 + best_ir_module = None + latencies = [] + time.sleep(5.0) + i = 0 + for ir_module, compiled_func in strict_zip(ir_modules, compiled_funcs): + # print(schedules[i]) + i += 1 + if compiled_func: + repeat_latency = compiled_func.profile(*dummy_inputs, warmup=5, number=10, repeat=3) + latency = float(np.median(repeat_latency)) + else: + # this ir module failed in building, skip + latency = 1e30 + # print(latency) + latencies.append(latency) + if best_latency > latency: + best_latency = latency + best_ir_module = ir_module + if best_ir_module is None: + raise ValueError('All ir modules are failed in building.') + + with TableBuilder(headers=['idx'] + [v[0] for v in (schedules[0].keys() + schedules[0].derived_keys())] + ['latency']) as tb: + rows = [] + for idx, (schedule, latency) in enumerate(zip(schedules, latencies)): + row = [idx] + [v[1] for v in schedule.keys() + schedule.derived_keys()] + [latency] + rows.append(row) + rows = sorted(rows, key=lambda v: v[-1]) + for row in rows: + tb += row + with open(os.path.join(output_dir, 'summary.txt'), 'w') as f: + f.write(str(tb)) + return best_ir_module diff --git a/python/hidet/tos/tensor.py b/python/hidet/tos/tensor.py new file mode 100644 index 0000000..a180c81 --- /dev/null +++ b/python/hidet/tos/tensor.py @@ -0,0 +1,385 @@ +from __future__ import annotations + +import ctypes +from functools import partial +from typing import List, Optional, Tuple, Sequence, Union + +import numpy as np + +from hidet.ffi import cuda, cuda_kernels +from hidet.ir.layout import DataLayout +from hidet.ir.layout.data_layout import RowMajorLayout +from hidet.runtime import Storage +from hidet.utils import prod + + +def convert(v): + if isinstance(v, (float, int)): + dtype_map = { + float: 'float32', + int: 'int64' + } + return full(shape=[1], fill_value=v, dtype=dtype_map[type(v)]) + elif isinstance(v, Tensor): + return v + else: + raise NotImplementedError() + + +class Tensor: + def __init__(self, + shape: Sequence[int], + dtype: str, + device: str, + storage: Optional[Storage], + layout: DataLayout = None, + trace: Optional[Tuple['Operator', int]] = None): + from hidet.tos.operator import Operator + self.shape = [int(v) for v in shape] + self.dtype = str(dtype) + self.device = device + self.storage = storage + self.layout = layout if layout else DataLayout.row_major(shape) + self.trace: Optional[Tuple[Operator, int]] = trace + + def __neg__(self) -> Tensor: + from .ops import neg + return neg(self) + + def __add__(self, other) -> Tensor: + from .ops import add + return add(self, other) + + def __radd__(self, other): + from .ops import add + return add(other, self) + + def __sub__(self, other) -> Tensor: + from .ops import sub + return sub(self, other) + + def __rsub__(self, other): + from .ops import sub + return sub(other, self) + + def __mul__(self, other) -> Tensor: + from .ops import multiply + return multiply(self, other) + + def __rmul__(self, other): + from .ops import multiply + return multiply(other, self) + + def __truediv__(self, other) -> Tensor: + from .ops import divide + return divide(self, other) + + def __str__(self): + head = self.signature() + if self.storage: + array_str = str(self.cpu().numpy()) + return '{}\n{}'.format(head, array_str) + else: + return head + ' with empty storage' + + def __getitem__(self, item): + from hidet.tos.ops import strided_slice + if not isinstance(item, tuple): + item = tuple([item]) + rank = len(self.shape) + if all(not isinstance(v, slice) for v in item) and len(item) == rank: + # element access + return strided_slice(self, starts=list(item), ends=[v + 1 for v in item]).numpy().flatten()[0] + else: + while len(item) < rank: + item = item + (slice(None, None, None),) + starts, ends, steps = [], [], [] + squeeze_dims = [] + for dim, v in enumerate(item): + if isinstance(v, int): + squeeze_dims.append(dim) + starts.append(v) + ends.append(v + 1) + steps.append(1) + else: + assert isinstance(v, slice) + starts.append(v.start if v.start is not None else 0) + ends.append(v.stop if v.stop is not None else self.shape[dim]) + steps.append(v.step if v.step is not None else 1) + sliced = strided_slice(self, starts, ends, strides=steps).squeeze(squeeze_dims) + return sliced + + def __iter__(self): + raise TypeError('hidet.Tensor does not support iteration.') + + def __getstate__(self): + if self.storage: + data = self.detach().numpy() + else: + data = None + + return { + 'shape': self.shape, + 'dtype': self.dtype, + 'device': self.device, + 'data': data, + 'layout': self.layout, + 'trace': self.trace + } + + def __setstate__(self, state): + data = state['data'] + if data is not None: + assert isinstance(data, np.ndarray) + tensor = from_numpy(data) + if state['device'] == 'cuda': + tensor = tensor.cuda() + storage = tensor.storage + else: + storage = None + + self.shape = state['shape'] + self.dtype = state['dtype'] + self.device = state['device'] + self.storage = storage + self.layout = state['layout'] + self.trace = state['trace'] + + def signature(self) -> str: + return "Tensor(shape={}, dtype='{}', device='{}')".format(self.shape, self.dtype, self.device) + + @property + def nbytes(self): + return prod(self.shape) * dtype_bytes(self.dtype) + + @property + def op(self): + return self.trace[0] if self.trace else None + + def scalar(self) -> Union[float, int]: + if len(self.shape) != 0: + raise ValueError('Can not convert a Tensor with shape {} to a scalar.'.format(self.shape)) + value = self.numpy().tolist() + assert isinstance(value, (int, float)) + return value + + def contiguous(self): + if isinstance(self.layout, RowMajorLayout): + return self + return self.reshape(self.shape) + + def reshape(self, shape: Sequence[int]): + from .ops import reshape + return reshape(self, shape) + + def squeeze(self, dims: Union[int, Sequence[int]]): + from .ops import squeeze + return squeeze(self, dims) + + def unsqueeze(self, dims: Union[int, Sequence[int]]): + from .ops import unsqueeze + return unsqueeze(self, dims) + + def rearrange(self, plan: List[List[int]]): + from .ops import rearrange + return rearrange(self, plan) + + def flatten(self, start_dim=0, end_dim=None): + from .ops import flatten + return flatten(self, start_dim, end_dim) + + def transpose(self, axes: Optional[Sequence[int]]): + from .ops import transpose + return transpose(self, axes) + + def barrier(self) -> Tensor: + from .ops import barrier + return barrier(self) + + def sum(self, dims: Union[int, List[int]], keep_dim: bool = False): + from .ops import reduce_sum + return reduce_sum(self, dims=dims, keep_dim=keep_dim) + + def mean(self, dims: Union[int, List[int]], keep_dim: bool = False): + from .ops import reduce_mean + return reduce_mean(self, dims=dims, keep_dim=keep_dim) + + def rsqrt(self): + from .ops import rsqrt + return rsqrt(self) + + def cast(self, dtype): + from .ops import cast + return cast(self, dtype) + + def cpu(self): + if self.device == 'cpu': + return self + else: + if self.trace is None: + return Tensor(self.shape, self.dtype, 'cpu', self.storage.cpu() if self.storage else None, self.layout) + else: + raise ValueError('Please use .detach() to detach a trace variable first.') + + def cuda(self): + if self.device == 'cuda': + return self + else: + if self.trace is None: + return Tensor(self.shape, self.dtype, 'cuda', self.storage.cuda() if self.storage else None, self.layout) + else: + raise ValueError('Please use .detach() to detach a trace variable first.') + + def detach(self): + if self.trace is None: + return self + else: + return Tensor( + shape=self.shape, + dtype=self.dtype, + device=self.device, + storage=self.storage, + layout=self.layout, + trace=None + ) + + def numpy(self) -> np.ndarray: + if self.device != 'cpu': + return self.cpu().numpy() + # convert if this tensor is not in row major layout + storage = self.contiguous().storage + + # because numpy does not support bfloat16, we convert it into float32 + if self.dtype == 'bfloat16': + return self.cast('float32').numpy() + else: + array = storage.as_array(num_elements=prod(self.shape), dtype=self.dtype) + return array.reshape(self.shape) + + +def dtype_bytes(dtype: str): + bytes_dict = { + 'float32': 4, + 'bfloat16': 2, + 'float16': 2, + 'int32': 4, + 'int64': 8, + 'uint8': 1, + 'bool': 1 + } + return bytes_dict[dtype] + + +def empty(shape: Sequence[int], dtype: str = 'float32', device: str = 'cuda', layout: Optional[DataLayout] = None) -> Tensor: + num_bytes = prod(shape) * dtype_bytes(dtype) + storage = Storage.new(device, num_bytes) + return Tensor(shape, dtype, device, storage, layout) + + +def symbol(shape: Sequence[int], dtype: str = 'float32', device: str = 'cuda', layout: Optional[DataLayout] = None) -> Tensor: + return Tensor(shape, dtype, device, None, layout) + + +def zeros(shape: Sequence[int], dtype: str = 'float32', device: str = 'cuda', layout: Optional[DataLayout] = None) -> Tensor: + tensor = empty(shape, dtype, device, layout) + cuda.memset_async(tensor.storage.addr, tensor.nbytes, value=0) + return tensor + + +def ones(shape: Sequence[int], dtype: str = 'float32', device: str = 'cuda', layout: Optional[DataLayout] = None) -> Tensor: + value_map = { + 'float32': 1.0, + 'int32': 1, + 'int64': 1 + } + if dtype in value_map: + return full(shape, value_map[dtype], dtype, device, layout) + else: + if dtype in ['float16', 'bool']: + f32_tensor = ones(shape, 'float32', device, layout) + return f32_tensor.cast(dtype) + else: + raise NotImplementedError('Not implemented ones for dtype {}, please create a float32 tensor and cast to this type'.format(dtype)) + + +def full(shape: Sequence[int], fill_value, dtype: str = 'float32', device: str = 'cuda', layout: Optional[DataLayout] = None) -> Tensor: + tensor = empty(shape, dtype, device, layout) + cuda_kernels.fill_value(tensor.storage.addr, tensor.nbytes, value=fill_value, dtype=dtype) + return tensor + + +def randn(shape: Sequence[int], dtype: str = 'float32', mean: float = 0.0, stddev: float = 1.0, device: str = 'cuda', layout: Optional[DataLayout] = None) -> Tensor: + tensor = empty(shape, dtype, device, layout) + if dtype == 'float32': + cuda.generate_normal(tensor.storage.addr, num_elements=prod(tensor.shape), mean=mean, stddev=stddev) + else: + float32_tensor = randn_like(tensor, dtype='float32') + return float32_tensor.cast(dtype=dtype) + # raise NotImplementedError('Currently do not support generate random array for data type {}'.format(dtype)) + return tensor + + +def _tensor_like(constructor, data, shape, dtype, device, layout): + shape = data.shape if shape is None else shape + dtype = data.dtype if dtype is None else dtype + device = data.device if device is None else device + layout = data.layout if layout is None else layout + return constructor(shape=shape, dtype=dtype, device=device, layout=layout) + + +def empty_like(data: Tensor, shape: Optional[Sequence[int]] = None, dtype: Optional[str] = None, device: Optional[str] = None, layout: Optional[DataLayout] = None) -> Tensor: + return _tensor_like(empty, data, shape, dtype, device, layout) + + +def symbol_like(data: Tensor, shape: Optional[Sequence[int]] = None, dtype: Optional[str] = None, device: Optional[str] = None, layout: Optional[DataLayout] = None) -> Tensor: + return _tensor_like(symbol, data, shape, dtype, device, layout) + + +def zeros_like(data: Tensor, shape: Optional[Sequence[int]] = None, dtype: Optional[str] = None, device: Optional[str] = None, layout: Optional[DataLayout] = None) -> Tensor: + return _tensor_like(zeros, data, shape, dtype, device, layout) + + +def ones_like(data: Tensor, shape: Optional[Sequence[int]] = None, dtype: Optional[str] = None, device: Optional[str] = None, layout: Optional[DataLayout] = None) -> Tensor: + return _tensor_like(ones, data, shape, dtype, device, layout) + + +def full_like(data: Tensor, fill_value, shape: Optional[Sequence[int]] = None, dtype: Optional[str] = None, device: Optional[str] = None, layout: Optional[DataLayout] = None) -> Tensor: + return _tensor_like(partial(full, fill_value=fill_value), data, shape, dtype, device, layout) + + +def randn_like(data: Tensor, shape: Optional[Sequence[int]] = None, dtype: Optional[str] = None, device: Optional[str] = None, layout: Optional[DataLayout] = None) -> Tensor: + return _tensor_like(randn, data, shape, dtype, device, layout) + + +def void_pointer_to_uint64(p): + ret = ctypes.cast(ctypes.addressof(p), ctypes.POINTER(ctypes.c_uint64)).contents + return ret.value + + +def from_numpy(array: np.ndarray) -> Tensor: + dtype_convert = { + np.dtype(np.float32): 'float32', + np.dtype(np.int64): 'int64', + np.dtype(np.int32): 'int32', + np.dtype(np.float16): 'float16', + np.dtype(np.bool): 'bool', + np.dtype(np.uint8): 'uint8' + } + if array.dtype not in dtype_convert: + raise NotImplementedError("Do not support convert np.ndarray with data type '{}'.".format(array.dtype)) + tensor = empty(shape=array.shape, dtype=dtype_convert[array.dtype], device='cpu') + cuda.memcpy_async(src_addr=void_pointer_to_uint64(array.ctypes.data_as(ctypes.c_void_p)), + dst_addr=tensor.storage.addr, + num_bytes=tensor.nbytes, + kind=cuda.HostToHost) + cuda.device_synchronize() + return tensor + + +def array(obj: Union[List, Tuple, np.ndarray, Tensor]) -> Tensor: + if isinstance(obj, np.ndarray): + return from_numpy(obj) + elif isinstance(obj, Tensor): + return obj + else: + return from_numpy(np.array(obj)) diff --git a/python/hidet/tos/transforms/__init__.py b/python/hidet/tos/transforms/__init__.py new file mode 100644 index 0000000..e2c5a26 --- /dev/null +++ b/python/hidet/tos/transforms/__init__.py @@ -0,0 +1,35 @@ +from hidet.tos.ir import FlowGraph + +from .base import GraphPass, PassContext, logger +from .instruments import GraphPassInstrument, SaveGraphInstrument, ProfileInstrument +from .fold_const import fold_const_pass +from .pattern_transform import pattern_transform_pass +from .automatic_mix_precision import automatic_mix_precision_pass +from .resolve_mma import resolve_mma_pass +from .resolve_variant import resolve_variant_pass +from .fuse_unary_elementwise import fuse_unary_elementwise_pass +from .fuse_epilogue import fuse_epilogue_pass +from .fuse_prologue import fuse_prologue_pass +from .eliminate_barrier import eliminate_barrier_pass + + +def optimize(graph: FlowGraph) -> FlowGraph: + passes = [ + fold_const_pass(), + pattern_transform_pass(), + automatic_mix_precision_pass(), + resolve_variant_pass(), + resolve_mma_pass(), + fuse_unary_elementwise_pass(), + fuse_epilogue_pass(), + fuse_prologue_pass(), + eliminate_barrier_pass() + ] + ctx = PassContext.current() + for inst in ctx.instruments: + inst.before_all_passes(graph) + for optimize_pass in passes: + graph = optimize_pass(graph) + for inst in reversed(ctx.instruments): + inst.after_all_passes(graph) + return graph.update_nodes() diff --git a/python/hidet/tos/transforms/automatic_mix_precision.py b/python/hidet/tos/transforms/automatic_mix_precision.py new file mode 100644 index 0000000..efe49d7 --- /dev/null +++ b/python/hidet/tos/transforms/automatic_mix_precision.py @@ -0,0 +1,72 @@ +from typing import List +from hidet.tos.ir.functors import GraphRewriter +from hidet.tos.ir.graph import FlowGraph, Operator, Tensor +from hidet.tos.ops import definitions as defs +from hidet.tos import ops +from hidet.ir.type import ScalarType +from .base import GraphPass +from hidet.utils import strict_zip, same_list + + +class DefaultTransformStrategy: + always = [ + # defs.Conv2dOp, defs.MatmulOp, defs.AddOp, defs.SubOp, defs.MultiplyOp, defs.DivideOp, defs.PadOp + ] + never = [ + # defs.ErfOp, defs.PowOp, + # defs.ReduceSumOp, defs.ReduceMeanOp # disable because we can accumulate with higher precision + ] + + +class AutoMixPrecisionRewriter(GraphRewriter): + def __init__(self, target_dtype: str): + super().__init__() + assert ScalarType(target_dtype).is_float() + self.target_dtype = target_dtype + self.policy = DefaultTransformStrategy() + + @staticmethod + def cast_float(x: Tensor, target_dtype: str) -> Tensor: + if ScalarType(x.dtype).is_float(): + return ops.cast(x, target_dtype) + return x + + def visit_Operator(self, op: Operator): + recv_inputs: List[Tensor] = [self(v) for v in op.inputs] + + # if type(op) in self.policy.always: + # decision = 'always' + if type(op) in self.policy.never: + decision = 'never' + else: + decision = 'always' + + casted_inputs = [] + for orig_input, recv_input in strict_zip(op.inputs, recv_inputs): + if decision == 'always': + casted_inputs.append(self.cast_float(recv_input, self.target_dtype)) + elif decision == 'never': + casted_inputs.append(self.cast_float(recv_input, orig_input.dtype)) + else: + casted_inputs.append(recv_input) + if same_list(casted_inputs, op.inputs): + return op + else: + updated_outputs = op.reforward(casted_inputs) + for original, updated in zip(op.outputs, updated_outputs): + self.memo[original] = updated + + +class AutoMixPrecisionPass(GraphPass): + def process_graph(self, graph: FlowGraph) -> FlowGraph: + target_dtype = self.current_context().configs['precision'] + if target_dtype is None: + return graph + else: + rewriter = AutoMixPrecisionRewriter(target_dtype) + graph = rewriter(graph) + return graph + + +def automatic_mix_precision_pass() -> GraphPass: + return AutoMixPrecisionPass() diff --git a/python/hidet/tos/transforms/base.py b/python/hidet/tos/transforms/base.py new file mode 100644 index 0000000..b681085 --- /dev/null +++ b/python/hidet/tos/transforms/base.py @@ -0,0 +1,205 @@ +from __future__ import annotations +from typing import List, Sequence, Optional, Dict, Any +import logging + +from hidet.tos.ir.graph import FlowGraph +from .instruments import GraphPassInstrument + +logger = logging.Logger(name='hidet.tos.transforms', level=logging.INFO) +logger.addHandler(logging.StreamHandler()) + + +class PassContext: + _stack: List['PassContext'] = [] + + def __init__(self): + self.instruments: List[GraphPassInstrument] = [] + self.configs: Dict[str, Any] = { + # target precision: + # [None, 'float16', 'bfloat16', 'float32'] + 'precision': None, + + # target reduce precision: + # [None, 'float16', 'float32'] + 'reduce_precision': None, + + # mma primitive: + # ['simt', 'wmma', 'mma'] + 'mma': 'simt', + + # parallel k + # ['default', 'disabled', 2, 4, ...] + 'parallel_k': 'default', + + # print lower details + 'verbose': False + } + + def __enter__(self) -> PassContext: + self._stack.append(self) + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + popped = self._stack.pop() + assert popped == self + + @classmethod + def current(cls): + """ + Get the current pass context. + + Returns + ------- + ret: PassContext + The current pass context. + """ + if len(cls._stack) == 0: + cls._stack.append(PassContext()) + return cls._stack[-1] + + def set_precision(self, dtype: Optional[str] = None) -> PassContext: + """ + Set the target precision to use as the output of most operators. To retain the accuracy, + some operators will still use the original data type. + + Parameters + ---------- + dtype: Optional[str] + The target dtype to mix the precision of the model. Candidates: + - None + Do not mix the precision. + - 'float16' + Convert the model into float16 data type. + - 'bfloat16' + Convert the model into bfloat16 data type. + - 'float32' + Convert the model into float32 data type. + """ + self.configs['precision'] = dtype + return self + + def set_reduce_precision(self, dtype: Optional[str] = None) -> PassContext: + """ + Set the target precision used for accumulation results. Operators like reduce_mean, reduce_avg, + matrix multiplication and convolution will reduce along some dimensions. We might want to use a + data type with more precision to accumulate the results for more accuracy. + + Parameters + ---------- + dtype: Optional[str] + The target dtype to use for accumulation. + - None + Use the same as inputs of operators. + - 'float16' + Use 'float16' to accumulate. Only valid when set_precision('float16') has been used. + - 'float32' + Use 'float32' to accumulate. + """ + self.configs['reduce_precision'] = dtype + return self + + def set_verbose(self) -> PassContext: + """ + Allow each graph level passes to print detailed information related to its lowering and optimization. + """ + self.configs['verbose'] = True + return self + + def set_mma(self, mma: str) -> PassContext: + """ + Specify the matrix-multiply-accumulate (mma) computation primitives used in matrix multiplication and + convolution. + Parameters + ---------- + mma: str + The mma computation primitive to use. Candidates: + - 'simt' + Use cuda cores. + - 'wmma' + Use wmma instructions. + - 'mma' + Use mma instructions (not supported yet). + """ + self.configs['mma'] = mma + return self + + def set_parallel_k(self, disabled=False, default=False, search=False, nparts: Optional[int] = None): + """ + Set the strategy to parallel on reduction dimension for matrix multiplication and convolution. + + Only one of the three parameters should be specified. + + Parameters + ---------- + disabled: bool + Disable the parallelization on reduction dimension. + + default: bool + Allow hidet to figure our the parallel factor. + + search: bool + Whether to search the k. + + nparts: Optional[int] + Use a fixed factor. + """ + if sum([disabled, default, search, nparts is not None]) > 1: + raise ValueError('Only one of parameters should be set.') + if disabled: + self.configs['parallel_k'] = 'disabled' + if default: + self.configs['parallel_k'] = 'default' + if search: + self.configs['parallel_k'] = 'search' + if nparts is not None: + self.configs['parallel_k'] = nparts + + def save_graph_instrument(self, out_dir) -> PassContext: + """ + Save the computation graph after each pass to given output directory. + + Parameters + ---------- + out_dir: str + The directory to save graph. + """ + from .instruments.save_graph_instrument import SaveGraphInstrument + self.instruments.append(SaveGraphInstrument(out_dir)) + return self + + def profile_pass_instrument(self, log_file: Optional[str] = None, print_stdout: bool = False) -> PassContext: + """ + Profile the time of each pass. + + Parameters + ---------- + log_file: Optional[str] + When given, write the elapsed time for each pass to this file. + + print_stdout: bool + Whether to print the elapsed time for each pass to standard output. + """ + from .instruments.profile_instrument import ProfileInstrument + self.instruments.append(ProfileInstrument(log_file, print_stdout)) + return self + + +class GraphPass: + def __init__(self): + self.name = self.__class__.__name__ + + def __call__(self, graph: FlowGraph) -> FlowGraph: + ctx = PassContext.current() + for inst in ctx.instruments: + inst.before_pass(self.name, graph) + graph = self.process_graph(graph) + for inst in reversed(ctx.instruments): + inst.after_pass(self.name, graph) + return graph + + @staticmethod + def current_context() -> PassContext: + return PassContext.current() + + def process_graph(self, graph: FlowGraph) -> FlowGraph: + raise NotImplementedError() diff --git a/python/hidet/tos/transforms/common.py b/python/hidet/tos/transforms/common.py new file mode 100644 index 0000000..369cbaa --- /dev/null +++ b/python/hidet/tos/transforms/common.py @@ -0,0 +1,5 @@ +def concat_op_name(lhs: str, rhs: str) -> str: + # lhs = lhs[5:] if lhs.startswith('Fused') else lhs + # rhs = rhs[5:] if rhs.startswith('Fused') else rhs + # return 'Fused{}{}'.format(lhs, rhs) + return '{} {}'.format(lhs, rhs) diff --git a/python/hidet/tos/transforms/eliminate_barrier.py b/python/hidet/tos/transforms/eliminate_barrier.py new file mode 100644 index 0000000..bed3d1e --- /dev/null +++ b/python/hidet/tos/transforms/eliminate_barrier.py @@ -0,0 +1,26 @@ +from hidet.tos.ir import FlowGraph, Operator, Tensor, GraphRewriter +from hidet.tos.transforms import GraphPass + +from .utils import is_barrier + + +class EliminateBarrierRewriter(GraphRewriter): + def visit_Operator(self, op: Operator): + inputs = [self(x) for x in op.inputs] + + if is_barrier(op): + outputs = inputs + for original, updated in zip(op.outputs, outputs): + self.memo[original] = updated + else: + return GraphRewriter.visit_Operator(self, op) + + +class EliminateBarrierPass(GraphPass): + def process_graph(self, graph: FlowGraph) -> FlowGraph: + rewriter = EliminateBarrierRewriter() + return rewriter(graph) + + +def eliminate_barrier_pass(): + return EliminateBarrierPass() diff --git a/python/hidet/tos/transforms/fold_const.py b/python/hidet/tos/transforms/fold_const.py new file mode 100644 index 0000000..c90a325 --- /dev/null +++ b/python/hidet/tos/transforms/fold_const.py @@ -0,0 +1,30 @@ +from hidet.tos.ir import FlowGraph, Operator, Tensor, GraphRewriter +from hidet.tos.transforms import GraphPass +from hidet import utils + + +class FoldConstantRewriter(GraphRewriter): + def visit_Operator(self, op: Operator): + inputs = [self(input) for input in op.inputs] + if all(input.storage is not None for input in inputs): + outputs = Operator.imperative_run(op, inputs) + for original, updated in zip(op.outputs, outputs): + self.memo[original] = updated + return None + else: + if utils.same_list(inputs, op.inputs): + return + else: + updated_outputs = op.reforward(inputs) + for original, updated in zip(op.outputs, updated_outputs): + self.memo[original] = updated + + +class FoldConstantPass(GraphPass): + def process_graph(self, graph: FlowGraph) -> FlowGraph: + rewriter = FoldConstantRewriter() + return rewriter(graph) + + +def fold_const_pass(): + return FoldConstantPass() diff --git a/python/hidet/tos/transforms/fuse_epilogue.py b/python/hidet/tos/transforms/fuse_epilogue.py new file mode 100644 index 0000000..8753550 --- /dev/null +++ b/python/hidet/tos/transforms/fuse_epilogue.py @@ -0,0 +1,160 @@ +from .base import GraphPass, PassContext, logger +from hidet.ir.expr import Var, var, TensorElement +from hidet.tos.ir import FlowGraph, Operator +from hidet.ir.dialects.compute import TensorNode, GridCompute +from hidet.tos.ir.functors import clone, analyze_usage +from hidet.ir.task import Task, Epilogue, is_injective_task, is_unary_injective_task, is_elementwise_task +from hidet.ir.functors import rewrite, collect +from hidet.utils import py +from .utils import is_barrier + + +def try_fuse(graph: FlowGraph, usage) -> bool: + for u_op in graph.nodes: + if len(u_op.task.prologues) + len(u_op.task.epilogues) > 0: + continue + if is_barrier(u_op): + continue + if not is_injective_task(u_op.task) or len(u_op.outputs) != 1: + # u_op must be an elementwise op with a single output (may have multiple inputs) + continue + for u_input_idx, u_input in enumerate(u_op.inputs): + if u_input.trace is None: + # skip graph input and const tensor + continue + if len(usage[u_input]) > 1: + # intermediate tensor can not be used by other operators. + continue + v_op, v_output_index = u_input.trace + if v_op is None: + continue + assert isinstance(v_op, Operator) + if is_barrier(v_op): + continue + + if is_injective_task(v_op.task): + # we do not fuse an injective_op v with an injective_op u by taking u as epilogue. + continue + + # v_op -> u_op ==> v_op (with u_op as epilogue) + v_task = v_op.task + u_task = u_op.task + u_output = u_task.outputs[0] + v_output = v_task.outputs[v_output_index] + u_task_input = u_task.inputs[u_input_idx] + + if u_task_input not in u_task.inverse_map: + # u_op must be invertible regards to enumerated input that wants to be fused along. + continue + + parameters = v_task.parameters.copy() + + # fetch reverse map + imap = u_op.task.inverse_map[u_task_input] + + existed_epilogue = v_task.epilogues[v_output] if v_output in v_task.epilogues else None + # input indices for the input of u task, use existed epilogue's when possible + if existed_epilogue: + indices = existed_epilogue.indices + rmap = {a: b for a, b in zip(imap.axes, existed_epilogue.out_indices)} + else: + indices = [var('i') for _ in range(len(u_input.shape))] + rmap = {a: b for a, b in zip(imap.axes, indices)} + dest_indices = [rewrite(dest_index_expr, rmap) for dest_index_expr in imap.indices] + + # prepare the TensorElement in the u task's expression + grid_compute = u_output.grid_compute + tensor_elements = [te for te in collect(grid_compute.value, TensorElement) if te.base is u_task_input] + if len(tensor_elements) == 0: + raise ValueError('Encountered a task whose output has not accessed its input.') + if len(tensor_elements) > 1: # accessed input twice, we do not fuse in this case + continue + te = tensor_elements.pop() + + # prepare the value of u task's output + rmap = {a: b for a, b in zip(grid_compute.axes, dest_indices)} + if existed_epilogue: + orig_value = existed_epilogue.orig_value + rmap[te] = existed_epilogue.value + else: + orig_value = Var('orig_value', v_output.data_type.scalar_type) + rmap[te] = orig_value + value = rewrite(grid_compute.value, rmap) + + # prepare the parameters and epilogue + new_output_node = TensorNode( + name=u_output.name, + data_type=u_output.data_type, + grid_compute=GridCompute( + shape=grid_compute.shape, + axes=grid_compute.axes, + value=rewrite(grid_compute.value, {u_task_input: v_output}) + ) + ) + input_params = parameters[:len(v_op.inputs)] + output_params = parameters[len(v_op.inputs):] + v_task_extra_inputs = [input for input in u_task.inputs if input is not u_task_input] + v_extra_inputs = [input_tensor for input_tensor in u_op.inputs if input_tensor is not u_input] + input_params.extend(v_task_extra_inputs) + if existed_epilogue: + existed_output_in_param = existed_epilogue.out_tensor + extra_inputs = existed_epilogue.extra_inputs + v_task_extra_inputs + else: + existed_output_in_param = v_output + extra_inputs = v_task_extra_inputs + output_params = [p if p is not existed_output_in_param else new_output_node for p in output_params] + parameters = input_params + output_params + epilogue = Epilogue( + extra_inputs=extra_inputs, + indices=indices, + orig_value=orig_value, + value=value, + out_indices=dest_indices, + out_tensor=new_output_node + ) + + # prepare the fused op + epilogues = v_task.epilogues.copy() + epilogues.update({v_output: epilogue}) + outputs = v_op.outputs[:v_output_index] + [u_op.outputs[0]] + v_op.outputs[v_output_index + 1:] + + task = v_task.copy() + task.name = '{}_{}'.format(v_task.name, u_task.name) + task.epilogues = epilogues + task.parameters = parameters + + fused_op = Operator( + inputs=v_op.inputs + v_extra_inputs, + task=task, + outputs=outputs, + name=v_op.name, + attributes=v_op.attrs + ) + fused_op.outputs[v_output_index].trace = (fused_op, 0) + + # update graph.nodes + graph.nodes = [node if node is not u_op else fused_op for node in graph.nodes if node is not v_op] + + if PassContext.current().configs['verbose']: + logger.info('Fused epilogue {} {}'.format(py.color_text(v_op.name, idx=1), py.color_text(u_op.name, idx=2))) + logger.debug('{}'.format(task)) + return True + return False + + +class FuseEpiloguePass(GraphPass): + + def process_graph(self, graph: FlowGraph) -> FlowGraph: + graph = clone(graph) + usage = analyze_usage(graph) + graph.update_nodes() + + while True: + success = try_fuse(graph, usage) + if not success: + break + return graph + + +def fuse_epilogue_pass() -> GraphPass: + return FuseEpiloguePass() diff --git a/python/hidet/tos/transforms/fuse_prologue.py b/python/hidet/tos/transforms/fuse_prologue.py new file mode 100644 index 0000000..3cc77e4 --- /dev/null +++ b/python/hidet/tos/transforms/fuse_prologue.py @@ -0,0 +1,179 @@ +from typing import List +from .base import GraphPass, PassContext, logger +from hidet.ir.expr import Var, var, TensorElement +from hidet.tos.ir import FlowGraph, Operator, Tensor +from hidet.ir.dialects.compute import TensorNode, GridCompute +from hidet.tos.ir.functors import clone, analyze_usage +from hidet.ir.task import Task, Prologue, Epilogue, is_injective_task, is_unary_injective_task +from hidet.ir.functors import rewrite, collect +from hidet.utils import prod, strict_zip, py +from .common import concat_op_name +from .utils import is_barrier + + +def update_params(task: Task, op: Operator, op_input: Tensor, task_input: TensorNode, op_extra_inputs: List[Tensor], task_extra_inputs: List[TensorNode]): + """ + Update parameters of task and operator. + + Parameters + ---------- + task: Task + The task to update. + op: Operator + The operator to update. + op_input: Tensor + The original operator input to remove. + task_input: TensorNode + The original task input to remove. + task_extra_inputs: List[TensorNode] + The extra task inputs to add. + op_extra_inputs: List[Tensor] + The extra operator inputs to add. + """ + task_param_inputs = task.parameters[:len(op.inputs)] + task_param_outputs = task.parameters[len(op.inputs):] + + # remove original input + op.inputs = [v for v in op.inputs if v is not op_input] + task_param_inputs = [v for v in task_param_inputs if v is not task_input] + + # add extra inputs + op.inputs.extend(op_extra_inputs) + task_param_inputs.extend(task_extra_inputs) + + # update task parameters + task.parameters = task_param_inputs + task_param_outputs + + +def try_fuse(graph: FlowGraph, usage) -> bool: + # for u_op in graph.nodes: + for u_op in reversed(graph.nodes): + if is_barrier(u_op): + continue + # v_op -> u_op => (v_op as prologue) u_op + u_task = u_op.task + # if is_injective_task(u_task) and not is_sink_op(u_op, usage): + # if is_injective_task(u_task): + # continue + for i, u_input in enumerate(u_op.inputs): + if len(usage[u_input]) > 1: + # should not fuse op that has been used by multiple times + continue + + v_op = u_input.op + if v_op is None: + # u_input is a graph input + continue + if is_barrier(v_op): + continue + + if len(v_op.outputs) != 1: + # only fuse op with single output + continue + + if not is_injective_task(v_op.task): + # only fuse injective op + continue + + u_task_input: TensorNode = u_task.parameters[i] + + v_task = v_op.task + if len(v_task.prologues) + len(v_task.epilogues) > 0: + # todo: add support for these cases. + continue + v_task_output = v_task.outputs[0] + + task = None + if u_task_input in u_task.inputs: + # u_input is an input of original task + prologue = Prologue( + extra_inputs=v_task.inputs, + indices=v_task_output.grid_compute.axes, + value=v_task_output.grid_compute.value + ) + task = u_task.copy() + task.prologues[u_task_input] = prologue + else: + # u_input is used in a prologue or epilogue + for original_task_input, existed_prologue in u_task.prologues.items(): + if u_task_input in existed_prologue.extra_inputs: + # u_input is used in an existing prologue + tensor_elements: List[TensorElement] = collect(existed_prologue.value, TensorElement) + gc = v_task_output.grid_compute + rmap = {te: rewrite(gc.value, {a: b for a, b in strict_zip(gc.axes, te.indices)}) + for te in tensor_elements if te.base is u_task_input} + value = rewrite(existed_prologue.value, rmap) + filtered_extra_inputs = [extra_input for extra_input in existed_prologue.extra_inputs + if extra_input is not u_task_input] + prologue = Prologue( + extra_inputs=filtered_extra_inputs + v_task.inputs, + indices=existed_prologue.indices, + value=value + ) + task = u_task.copy() + task.prologues[original_task_input] = prologue + + for original_task_output, existed_epilogue in u_task.epilogues.items(): + if u_task_input in existed_epilogue.extra_inputs: + # u_input is used in an existing epilogue + tensor_elements: List[TensorElement] = collect(existed_epilogue.value, TensorElement) + gc = v_task_output.grid_compute + rmap = {te: rewrite(gc.value, {a: b for a, b in strict_zip(gc.axes, te.indices)}) + for te in tensor_elements if te.base is u_task_input} + value = rewrite(existed_epilogue.value, rmap) + filtered_extra_inputs = [inp for inp in existed_epilogue.extra_inputs if inp is not u_task_input] + epilogue = Epilogue( + extra_inputs= filtered_extra_inputs + v_task.inputs, + indices=existed_epilogue.indices, + orig_value=existed_epilogue.orig_value, + value=value, + out_indices=existed_epilogue.out_indices, + out_tensor=existed_epilogue.out_tensor + ) + task = u_task.copy() + task.epilogues[original_task_output] = epilogue + + if task is None: + raise ValueError('Input {} has not been used in task.'.format(u_task_input)) + + task.name = '{}_{}'.format(v_task.name, u_task.name) + update_params( + task=task, + op=u_op, + op_input=u_input, + task_input=u_task_input, + op_extra_inputs=v_op.inputs, + task_extra_inputs=v_task.inputs, + ) + u_op.task = task + if PassContext.current().configs['verbose']: + logger.info('Fused prologue {} {}'.format(py.color_text(v_op.name, idx=1), py.color_text(u_op.name, idx=2))) + logger.debug('u_task') + logger.debug(u_task) + logger.debug('v_task') + logger.debug(v_task) + logger.debug('fused_task') + logger.debug(task) + graph.nodes.remove(v_op) + + return True + + return False + + +class FuseProloguePass(GraphPass): + + def process_graph(self, graph: FlowGraph) -> FlowGraph: + graph = clone(graph) + usage = analyze_usage(graph) + graph.update_nodes() + + while True: + success = try_fuse(graph, usage) + if not success: + break + return graph + + +def fuse_prologue_pass() -> GraphPass: + return FuseProloguePass() diff --git a/python/hidet/tos/transforms/fuse_unary_elementwise.py b/python/hidet/tos/transforms/fuse_unary_elementwise.py new file mode 100644 index 0000000..dc5efc9 --- /dev/null +++ b/python/hidet/tos/transforms/fuse_unary_elementwise.py @@ -0,0 +1,84 @@ +from .base import GraphPass, PassContext, logger +from hidet.tos.ir import FlowGraph, Operator +from hidet.tos.ir.functors import clone, analyze_usage +from hidet.ir.task import Task, is_unary_injective_task +from hidet.ir.functors import rewrite +from hidet.utils import py + +from .common import concat_op_name +from .utils import is_barrier + + +class FuseUnaryElementwise(GraphPass): + """ + Fuse all consecutive unary elementwise operators together. + """ + def process_graph(self, graph: FlowGraph) -> FlowGraph: + graph = clone(graph) + usage = analyze_usage(graph) + graph.update_nodes() + + while True: + success = False + for u_op in graph.nodes: + if is_barrier(u_op): + continue + if not is_unary_injective_task(u_op.task): + continue + if len(usage[u_op.inputs[0]]) > 1: + # intermediate tensor can not be used by other operators. + continue + v_op = u_op.inputs[0].op + if v_op is None: + continue + if is_barrier(v_op): + continue + if not is_unary_injective_task(v_op.task): + continue + + # create fused op + # x --(v_op)--> y --(u_op)--> z + x = v_op.task.inputs[0] + y = v_op.task.outputs[0] + z = rewrite(u_op.task.outputs[0], {u_op.task.inputs[0]: y}) + if v_op.task.inverse_map and u_op.task.inverse_map: + inverse_map = {x: list(v_op.task.inverse_map.values())[0] + list(u_op.task.inverse_map.values())[0]} + elif v_op.task.inverse_map is not None: + # if v_op is invertible but u_op is not invertible, we do not fuse them + continue + else: + inverse_map = None + fused_op = Operator( + inputs=v_op.inputs, + task=Task( + name='{}_{}'.format(v_op.task.name, u_op.task.name), + inputs=[x], + outputs=[z], + inverse_map=inverse_map + ), + outputs=u_op.outputs, + name=concat_op_name(v_op.name, u_op.name), + attributes={**v_op.attrs, **u_op.attrs} + ) + fused_op.outputs[0].trace = (fused_op, 0) + + # update graph.nodes + graph.nodes = [node if node is not u_op else fused_op for node in graph.nodes if node is not v_op] + success = True + + # log + if PassContext.current().configs['verbose']: + logger.info('Fused elementwise {} {}'.format(py.color_text(v_op.name, idx=1), py.color_text(u_op.name, idx=2))) + logger.debug('front op') + logger.debug(v_op.task) + logger.debug('back op') + logger.debug(u_op.task) + logger.debug('fused task') + logger.debug(fused_op.task) + if not success: + break + return graph + + +def fuse_unary_elementwise_pass() -> GraphPass: + return FuseUnaryElementwise() diff --git a/python/hidet/tos/transforms/graph_patterns/__init__.py b/python/hidet/tos/transforms/graph_patterns/__init__.py new file mode 100644 index 0000000..e1a62c2 --- /dev/null +++ b/python/hidet/tos/transforms/graph_patterns/__init__.py @@ -0,0 +1,10 @@ +from typing import List +from .base import TensorPattern, OperatorPattern, GraphPattern, MatchDict, Usage, graph_pattern_match +from .arithmatic_patterns import arithmatic_patterns +from .transform_patterns import transform_patterns +from .conv2d_patterns import conv2d_patterns +from .matmul_patterns import matmul_patterns + + +def all_graph_patterns() -> List[GraphPattern]: + return arithmatic_patterns() + transform_patterns() + conv2d_patterns() + matmul_patterns() diff --git a/python/hidet/tos/transforms/graph_patterns/arithmatic_patterns.py b/python/hidet/tos/transforms/graph_patterns/arithmatic_patterns.py new file mode 100644 index 0000000..8c741ea --- /dev/null +++ b/python/hidet/tos/transforms/graph_patterns/arithmatic_patterns.py @@ -0,0 +1,78 @@ +from typing import List, Optional, Dict, Union + +from hidet.tos.ir.graph import Operator, Tensor +from .base import GraphPattern, TensorPattern, OperatorPattern, MatchDict + + +# class GraphConstructor: +# """ +# Construct the new subgraph according the matched subgraph and target graph. +# """ +# def __init__(self, matched): +# self.memo = {} +# self.matched = matched +# self.new_operators = [] +# +# def visit(self, obj: Union[TensorPattern, OperatorPattern]): +# if obj in self.memo: +# return self.memo[obj] +# if isinstance(obj, OperatorPattern): +# ret = self.visit_OperatorPattern(obj) +# elif isinstance(obj, TensorPattern): +# ret = self.visit_TensorPattern(obj) +# else: +# raise ValueError() +# self.memo[obj] = ret +# return ret +# +# def visit_TensorPattern(self, t: TensorPattern) -> Tensor: +# if t.trace is None: +# # input in pattern +# return self.matched[t] +# else: +# op, idx = t.trace +# return self.visit(op).get_output(idx) +# +# def visit_OperatorPattern(self, t: OperatorPattern) -> Operator: +# inputs = [self.visit(x) for x in t.inputs] +# op = t.op_cls(*inputs) +# self.new_operators.append(op) +# return op + + +class ArithmaticGraphPattern(GraphPattern): + def __init__(self, name, fsrc, fdst): + super().__init__(name) + x, y = TensorPattern.tensors(2, is_symbolic=True) # can not be const + a, b = TensorPattern.tensors(2, is_const=True) # can not be symbolic + self.x = x + self.y = y + self.a = a + self.b = b + self.src = fsrc(x, y, a, b) + self.fdst = fdst + + def source(self) -> List[TensorPattern]: + return [self.src] + + def target(self, matched: MatchDict) -> Optional[List[TensorPattern]]: + x, y, a, b = [matched[v] if v in matched else None for v in [self.x, self.y, self.a, self.b]] + return [self.fdst(x, y, a, b)] + # constructor = GraphConstructor(matched) + # return [constructor.visit(self.tgt)] + + +def arithmatic_patterns() -> List[GraphPattern]: + # # tensors can be used as pattern inputs + # x, y, z = TensorPattern.tensors(3, is_symbolic=True) # can not be const + # a, b, c = TensorPattern.tensors(3, is_const=True) # can not be symbolic + # + # (source, target) pattern pairs + pairs = [ + ['a + x => x + a', lambda x, y, a, b: a + x, lambda x, y, a, b: x + a], + ['x - a => x + (-a)', lambda x, y, a, b: x - a, lambda x, y, a, b: x + (-a)], + ['(x + a) + b => x + (a + b)', lambda x, y, a, b: (x + a) + b, lambda x, y, a, b: x + (a + b)], + ['(x + a) * b => x * b + a * b', lambda x, y, a, b: (x + a) * b, lambda x, y, a, b: x * b + a * b], + ['(x + a) + (y + b) => (x + y) + (a + b)', lambda x, y, a, b: (x + a) + (y + b), lambda x, y, a, b: (x + y) + (a + b)], + ] + return [ArithmaticGraphPattern(name, src, tgt) for name, src, tgt in pairs] diff --git a/python/hidet/tos/transforms/graph_patterns/base.py b/python/hidet/tos/transforms/graph_patterns/base.py new file mode 100644 index 0000000..7464661 --- /dev/null +++ b/python/hidet/tos/transforms/graph_patterns/base.py @@ -0,0 +1,242 @@ +from __future__ import annotations +from typing import List, Optional, Dict, Any, Union, Tuple, Type, Set +from hidet.tos.ir.graph import FlowGraph, Operator, Tensor +from hidet.tos.transforms import GraphPass, PassContext +from hidet.tos import ops +from hidet import tos + + +class TensorPattern: + def __init__(self, is_const=False, is_symbolic=False, trace=None): + self.is_const: bool = is_const + self.is_symbolic: bool = is_symbolic + assert not (is_const and is_symbolic), 'Can not be const and symbolic at the same time' + self.trace: Optional[Tuple[OperatorPattern, int]] = trace + self.uses: List[Tuple[OperatorPattern, int]] = [] + + def __repr__(self): + if self.trace is None: + if self.is_const: + return 'c' + if self.is_symbolic: + return 's' + return 'v' + else: + op, idx = self.trace + op_str = str(op) + if len(op.outputs) == 1: + return op_str + else: + return '{}[{}]'.format(op_str, idx) + + def __add__(self, other): + return OperatorPattern(ops.definitions.arithmatic.AddOp, inputs=[self, other]).outputs[0] + + def __sub__(self, other): + return OperatorPattern(ops.definitions.arithmatic.SubOp, inputs=[self, other]).outputs[0] + + def __mul__(self, other): + return OperatorPattern(ops.definitions.arithmatic.MultiplyOp, inputs=[self, other]).outputs[0] + + def __neg__(self): + return OperatorPattern(ops.definitions.arithmatic.NegOp, inputs=[self]).outputs[0] + + def op(self) -> Optional[OperatorPattern]: + if self.trace is None: + return None + else: + return self.trace[0] + + def add_use(self, op: OperatorPattern, idx: int): + self.uses.append((op, idx)) + + @staticmethod + def tensor(is_const=False, is_symbolic=False): + return TensorPattern(is_const, is_symbolic) + + @staticmethod + def tensors(num, is_const=False, is_symbolic=False): + return [TensorPattern(is_const, is_symbolic) for _ in range(num)] + + +class OperatorPattern: + def __init__(self, op_cls, inputs, num_outputs=1): + self.op_cls = op_cls + self.inputs: List[TensorPattern] = inputs + self.outputs = [TensorPattern(is_symbolic=True, trace=(self, idx)) for idx in range(num_outputs)] + + for idx, input_tensor in enumerate(self.inputs): + input_tensor.add_use(self, idx) + + def __repr__(self): + input_items = [str(v) for v in self.inputs] + unary_ops = { + ops.definitions.arithmatic.NegOp: '-' + } + binary_ops = { + ops.definitions.arithmatic.AddOp: '+', + ops.definitions.arithmatic.SubOp: '-', + ops.definitions.arithmatic.MultiplyOp: '*' + } + if self.op_cls in unary_ops: + return '({}{})'.format(unary_ops[self.op_cls], input_items[0]) + elif self.op_cls in binary_ops: + return '({} {} {})'.format(input_items[0], binary_ops[self.op_cls], input_items[1]) + else: + return '{}({})'.format(self.op_cls.__name__[:-2], ', '.join(input_items)) + + +MatchDict = Dict[Union[TensorPattern, OperatorPattern], Union[Tensor, Operator]] + + +class GraphPattern: + def __init__(self, name): + self.name = name + + def source(self) -> List[TensorPattern]: + """ + The output tensors in the source template graph to match in the computation graph. + """ + raise NotImplementedError() + + def target(self, matched: MatchDict) -> Optional[List[Tensor]]: + """ + The output tensors in the target sub-graph used to replace the matched pattern. + Return None means failed to generate the target sub-graph, and we should not do the transformation. + """ + raise NotImplementedError() + + +def op_pattern(op_cls: Type[Operator], input_patterns: List[TensorPattern], num_outputs=1) -> Union[TensorPattern, List[TensorPattern]]: + op = OperatorPattern(op_cls, input_patterns, num_outputs) + if num_outputs == 1: + return op.outputs[0] + else: + return op.outputs + + +Usage = Dict[Tensor, List[Tuple[Optional[Operator], int]]] + + +class NotMatchedException(Exception): + pass + + +class PatternMatcher: + """ + PatternMatcher matches a pattern to a subgraph in a larger graph. + + It starts from a tensor, or an operator, and tries to match the subgraph spanned from the start point. + + The spanning rules: + 1. A tensor spans to its producing operator and its consuming operators (i.e., uses). + 2. An operator spans to its input and output tensors. + + The matching rules: + 1. For tensor: + a) check the storage requirement (e.g., constant and symbolic) + b) check the output index in the producer's output array + 2. For operator: + a) check the operator type. + + Because the operator also spans to its outputs, as long as the pattern is connected, we only need to start + from a single tensor or operator. + """ + + def __init__(self, usage: Usage): + self.matched = {} + self.reverse_matched = {} + self.usage: Usage = usage + + @staticmethod + def check(cond: bool, msg=""): + if not cond: + raise NotMatchedException(msg) + + def match(self, pattern, target): + key = pattern if not isinstance(pattern, list) else id(pattern) + if key in self.matched: + self.check(target is self.matched[key], 'tried to match a pattern to two different objects') + # pattern has been matched to a different target + return + self.matched[key] = target + self.reverse_matched[target] = key + if isinstance(pattern, (list, tuple)): + self.match_Sequence(pattern, target) + elif isinstance(pattern, TensorPattern): + self.match_TensorPattern(pattern, target) + elif isinstance(pattern, OperatorPattern): + self.match_OperatorPattern(pattern, target) + else: + raise NotImplementedError() + + def match_Sequence(self, pattern, target): + self.check(isinstance(target, (list, tuple)), 'target should be tuple or list') + self.check(len(pattern) == len(target), 'sequence length does not match') + for a, b in zip(pattern, target): + self.match(a, b) + + def match_TensorPattern(self, pattern: TensorPattern, target): + self.check(isinstance(target, Tensor), "expect target with type 'Tensor'") + if pattern.is_const: + self.check(target.storage is not None, 'requires const tensor') + return + if pattern.is_symbolic: + self.check(target.storage is None, 'requires symbolic tensor') + + # spans to its inputs + if pattern.trace: + self.check(target.trace is not None) + self.check(pattern.trace[1] == target.trace[1]) + self.match(pattern.trace[0], target.trace[0]) + + # spans to its uses + desire_uses: List[Tuple[OperatorPattern, int]] = pattern.uses + actual_uses: List[Tuple[Optional[Operator], int]] = self.usage[target] + for desire_use in desire_uses: + desire_operator, desire_index = desire_use + if desire_operator in self.matched: + # this desire operator in pattern has been spanned + continue + spanned = False + for actual_use in actual_uses: + actual_operator, actual_index = actual_use + if actual_operator in self.reverse_matched: + # this actual operator has been matched + continue + if type(actual_operator) != desire_operator.op_cls: + continue + self.match(desire_operator, actual_operator) + spanned = True + break + self.check(spanned, "A usage of input tensor has not been spanned.") + + def match_OperatorPattern(self, pattern: OperatorPattern, target: Operator): + self.check(isinstance(target, pattern.op_cls), "expect target with type 'Operator'") + self.check(pattern.op_cls is target.__class__, 'operator cls does not match') + assert len(pattern.inputs) == len(target.inputs) and len(pattern.outputs) == len(target.outputs) + for a, b in zip(pattern.inputs, target.inputs): + self.match(a, b) + for a, b in zip(pattern.outputs, target.outputs): + self.match(a, b) + + +def graph_pattern_match(pattern: TensorPattern, target: Tensor, usage: Usage) -> Optional[MatchDict]: + # peek for early stop, only for performance + if pattern.trace is None: + if target.trace is not None: + return None + if (pattern.is_const and target.storage is None) or (pattern.is_symbolic and target.storage is not None): + return None + return {pattern: target} + if pattern.trace and target.trace and pattern.trace[0].op_cls is not target.trace[0].__class__: + return None + + # formal match + matcher = PatternMatcher(usage) + try: + matcher.match(pattern, target) + return matcher.matched + except NotMatchedException: + return None + diff --git a/python/hidet/tos/transforms/graph_patterns/conv2d_patterns.py b/python/hidet/tos/transforms/graph_patterns/conv2d_patterns.py new file mode 100644 index 0000000..a1b9b8e --- /dev/null +++ b/python/hidet/tos/transforms/graph_patterns/conv2d_patterns.py @@ -0,0 +1,32 @@ +from typing import List, Optional, Dict, Union + +from hidet.tos import ops +from hidet.tos.ir.graph import Tensor +from hidet.tos.ops.definitions.conv2d import Conv2dOp +from .base import GraphPattern, TensorPattern, MatchDict, op_pattern + + +class Conv2dScalePattern(GraphPattern): + def __init__(self): + super().__init__('conv2d(x, w) * scale => conv2d(x, w * scale)') + self.x = TensorPattern.tensor() + self.w = TensorPattern.tensor(is_const=True) + self.scale = TensorPattern.tensor(is_const=True) + self.y = op_pattern(Conv2dOp, [self.x, self.w]) + self.z = self.y * self.scale + + def source(self) -> List[TensorPattern]: + return [self.z] + + def target(self, matched: MatchDict) -> Optional[List[Tensor]]: + x, w, y, scale = [matched[v] for v in [self.x, self.w, self.y, self.scale]] + if not (scale.shape[0] == scale.shape[2] == scale.shape[3] == 1): + return None + attrs = y.op.attrs + return [ops.conv2d(x, w * scale.squeeze([0]).unsqueeze([3]), stride=attrs['stride'], groups=attrs['groups'])] + + +def conv2d_patterns() -> List[GraphPattern]: + return [ + Conv2dScalePattern() + ] diff --git a/python/hidet/tos/transforms/graph_patterns/matmul_patterns.py b/python/hidet/tos/transforms/graph_patterns/matmul_patterns.py new file mode 100644 index 0000000..4f7f277 --- /dev/null +++ b/python/hidet/tos/transforms/graph_patterns/matmul_patterns.py @@ -0,0 +1,120 @@ +from typing import List, Optional, Dict, Union + +from hidet.tos import ops +from hidet.tos.ir.graph import Operator, Tensor +from hidet.tos.ops.definitions.matmul.matmul import MatmulOp +from .base import GraphPattern, TensorPattern, OperatorPattern, MatchDict, op_pattern + + +class MatmulRightScalePattern(GraphPattern): + def __init__(self): + super().__init__('matmul(x, c1) * c2 => matmul(x, c1 * c2)') + self.x = TensorPattern.tensor() + self.c1 = TensorPattern.tensor(is_const=True) + self.c2 = TensorPattern.tensor(is_const=True) + self.y = op_pattern(MatmulOp, [self.x, self.c1]) * self.c2 + + def source(self) -> List[TensorPattern]: + return [self.y] + + def target(self, matched: MatchDict) -> Optional[List[Tensor]]: + x, c1, c2, y = [matched[v] for v in [self.x, self.c1, self.c2, self.y]] + if len(c2.shape) >= 2 and c2.shape[-2] != 1: # c2 should have shape [., 1, .] + return None + return [ops.matmul(x, c1 * c2)] + + +class MatmulLeftScalePattern(GraphPattern): + def __init__(self): + super().__init__('matmul(c1, x) * c2 ==> matmul(c1 * c2, x)') + self.c1 = TensorPattern.tensor(is_const=True) + self.x = TensorPattern.tensor(is_symbolic=True) + self.c2 = TensorPattern.tensor(is_const=True) + self.y = op_pattern(MatmulOp, [self.c1, self.x]) * self.c2 + + def source(self) -> List[TensorPattern]: + return [self.y] + + def target(self, matched: MatchDict) -> Optional[List[Tensor]]: + c1, x, c2, y = [matched[v] for v in [self.c1, self.x, self.c2, self.y]] + if len(c2.shape) > 0 and c2.shape[-1] != 1: # c2 should have shape [., ., 1] + return None + return [ops.matmul(c1 * c2, x)] + + +class TwoMatmulFusionPattern(GraphPattern): + def __init__(self): + super().__init__('matmul(x, c1)|matmul(x, c2) ==> matmul(x, concat(c1, c2)) followed by split') + self.x = TensorPattern.tensor(is_symbolic=True) + self.c1 = TensorPattern.tensor(is_const=True) + self.c2 = TensorPattern.tensor(is_const=True) + self.y1 = op_pattern(MatmulOp, [self.x, self.c1]) + self.y2 = op_pattern(MatmulOp, [self.x, self.c2]) + + def source(self) -> List[TensorPattern]: + return [self.y1, self.y2] + + def target(self, matched: MatchDict) -> Optional[List[Tensor]]: + x, c1, c2, y1, y2 = [matched[t] for t in [self.x, self.c1, self.c2, self.y1, self.y2]] + c = ops.concat([c1, c2], axis=2) + y = ops.matmul(x, c) + new_y1, new_y2 = ops.split(y, axis=2, parts=[y1.shape[2], y2.shape[2]]) + return [new_y1, new_y2] + + +class ThreeMatmulFusionPattern(GraphPattern): + def __init__(self): + super().__init__('matmul(x, c1)|matmul(x, c2)|matmul(x, c3) ==> matmul(x, concat(c1, c2, c3)) followed by split') + self.x = TensorPattern.tensor(is_symbolic=True) + self.c1 = TensorPattern.tensor(is_const=True) + self.c2 = TensorPattern.tensor(is_const=True) + self.c3 = TensorPattern.tensor(is_const=True) + self.y1 = op_pattern(MatmulOp, [self.x, self.c1]) + self.y2 = op_pattern(MatmulOp, [self.x, self.c2]) + self.y3 = op_pattern(MatmulOp, [self.x, self.c3]) + + def source(self) -> List[TensorPattern]: + return [self.y1, self.y2, self.y3] + + def target(self, matched: MatchDict) -> Optional[List[Tensor]]: + x, c1, c2, c3, y1, y2, y3 = [matched[t] for t in [self.x, self.c1, self.c2, self.c3, self.y1, self.y2, self.y3]] + c = ops.concat([c1, c2, c3], axis=2) + y = ops.matmul(x, c) + new_y1, new_y2, new_y3 = ops.split(y, axis=2, parts=[y1.shape[2], y2.shape[2], y3.shape[2]]) + return [new_y1, new_y2, new_y3] + + +class ThreeMatmulBiasFusionPattern(GraphPattern): + def __init__(self): + super().__init__('3 branches of matmul(x, branch c) + branch b ==> matmul(x, c) + b followed by split') + self.x = TensorPattern.tensor(is_symbolic=True) + self.c1 = TensorPattern.tensor(is_const=True) + self.c2 = TensorPattern.tensor(is_const=True) + self.c3 = TensorPattern.tensor(is_const=True) + self.b1 = TensorPattern.tensor(is_const=True) + self.b2 = TensorPattern.tensor(is_const=True) + self.b3 = TensorPattern.tensor(is_const=True) + self.y1 = op_pattern(MatmulOp, [self.x, self.c1]) + self.b1 + self.y2 = op_pattern(MatmulOp, [self.x, self.c2]) + self.b2 + self.y3 = op_pattern(MatmulOp, [self.x, self.c3]) + self.b3 + + def source(self) -> List[TensorPattern]: + return [self.y1, self.y2, self.y3] + + def target(self, matched: MatchDict) -> Optional[List[Tensor]]: + x, c1, c2, c3, b1, b2, b3, y1, y2, y3 = [matched[t] for t in [self.x, self.c1, self.c2, self.c3, self.b1, self.b2, self.b3, self.y1, self.y2, self.y3]] + c = ops.concat([c1, c2, c3], axis=2) + b = ops.concat([b1, b2, b3], axis=-1) + y = ops.matmul(x, c) + b + new_y1, new_y2, new_y3 = ops.split(y, axis=2, parts=[y1.shape[2], y2.shape[2], y3.shape[2]]) + return [new_y1, new_y2, new_y3] + + +def matmul_patterns() -> List[GraphPattern]: + return [ + MatmulRightScalePattern(), + MatmulLeftScalePattern(), + ThreeMatmulBiasFusionPattern(), + ThreeMatmulFusionPattern(), + TwoMatmulFusionPattern(), + ] diff --git a/python/hidet/tos/transforms/graph_patterns/transform_patterns.py b/python/hidet/tos/transforms/graph_patterns/transform_patterns.py new file mode 100644 index 0000000..045f8e7 --- /dev/null +++ b/python/hidet/tos/transforms/graph_patterns/transform_patterns.py @@ -0,0 +1,112 @@ +from typing import List, Optional + +from hidet.tos import ops +from hidet.tos.ir.graph import Tensor +from hidet.tos.ops.definitions.transform import ReshapeOp, SqueezeOp, StridedSliceOp +from .base import GraphPattern, TensorPattern, MatchDict, op_pattern +from hidet.utils import prod + + +def reverse_reshape_dim(orig_shape, new_shape, new_axis) -> Optional[int]: + pre_sum = prod(new_shape[:new_axis]) + cnt = 1 + for i, extent in enumerate(orig_shape): + if cnt == pre_sum: + if len(orig_shape) == i: + return None + elif orig_shape[i] == new_shape[new_axis]: + return i + else: + return None + elif cnt > pre_sum: + return None + else: + cnt *= extent + return None + + +class ReshapeScalePattern(GraphPattern): + def __init__(self): + super().__init__('reshape(x) * scale') + self.x = TensorPattern.tensor(is_symbolic=True) + self.scale = TensorPattern.tensor(is_const=True) + self.y = op_pattern(ReshapeOp, [self.x]) + self.z = self.y * self.scale + + def source(self) -> List[TensorPattern]: + return [self.z] + + def target(self, matched: MatchDict) -> Optional[List[Tensor]]: + x, scale, y, z = [matched[v] for v in [self.x, self.scale, self.y, self.z]] + if len(scale.shape) < len(y.shape): + diff_dims = len(y.shape) - len(scale.shape) + scale = scale.unsqueeze(dims=list(range(diff_dims))) + scale_dims = [i for i, dim in enumerate(scale.shape) if dim != 1] + if len(scale_dims) == 0: + return [ops.reshape(x * ops.flatten(scale), shape=y.shape)] + elif len(scale_dims) == 1: + dim = reverse_reshape_dim(x.shape, y.shape, scale_dims[0]) + if dim is None: + return None + scale = ops.flatten(scale).unsqueeze([i for i in range(len(x.shape)) if i != dim]) + return [ops.reshape(x * scale, shape=y.shape)] + else: + return None + + +class ReshapeBiasPattern(GraphPattern): + def __init__(self): + super().__init__('reshape(x) + bias') + self.x = TensorPattern.tensor(is_symbolic=True) + self.bias = TensorPattern.tensor(is_const=True) + self.y = op_pattern(ReshapeOp, [self.x]) + self.z = self.y + self.bias + + def source(self) -> List[TensorPattern]: + return [self.z] + + def target(self, matched: MatchDict) -> Optional[List[Tensor]]: + x, bias, y, z = [matched[v] for v in [self.x, self.bias, self.y, self.z]] + if len(bias.shape) < len(y.shape): + diff_dims = len(y.shape) - len(bias.shape) + bias = bias.unsqueeze(dims=list(range(diff_dims))) + scale_dims = [i for i, dim in enumerate(bias.shape) if dim != 1] + if len(scale_dims) == 0: + return [ops.reshape(x + ops.flatten(bias), shape=y.shape)] + elif len(scale_dims) == 1: + dim = reverse_reshape_dim(x.shape, y.shape, scale_dims[0]) + if dim is None: + return None + bias = ops.flatten(bias).unsqueeze([i for i in range(len(x.shape)) if i != dim]) + return [ops.reshape(x + bias, shape=y.shape)] + else: + return None + + +class SqueezeMultiplyPattern(GraphPattern): + def __init__(self): + super().__init__('squeeze(x) * c => squeeze(x * c)') + self.x = TensorPattern.tensor() + self.c = TensorPattern.tensor(is_const=True) + self.s = op_pattern(SqueezeOp, [self.x]) + self.y = self.s * self.c + + def source(self) -> List[TensorPattern]: + return [self.y] + + def target(self, matched: MatchDict) -> Optional[List[Tensor]]: + x, c, s, y = matched[self.x], matched[self.c], matched[self.s], matched[self.y] + dims = s.op.attrs['dims'] + if len(c.shape) < len(y.shape): + c = c.unsqueeze(list(range(len(y.shape) - len(c.shape)))) + c = c.unsqueeze(dims) # now, c has the same shape as x + return [ops.squeeze(x * c, dims=dims)] + + +def transform_patterns() -> List[GraphPattern]: + return [ + ReshapeScalePattern(), + ReshapeBiasPattern(), + SqueezeMultiplyPattern() + ] + diff --git a/python/hidet/tos/transforms/instruments/__init__.py b/python/hidet/tos/transforms/instruments/__init__.py new file mode 100644 index 0000000..884f737 --- /dev/null +++ b/python/hidet/tos/transforms/instruments/__init__.py @@ -0,0 +1,3 @@ +from .base import GraphPassInstrument +from .profile_instrument import ProfileInstrument +from .save_graph_instrument import SaveGraphInstrument diff --git a/python/hidet/tos/transforms/instruments/base.py b/python/hidet/tos/transforms/instruments/base.py new file mode 100644 index 0000000..41df3aa --- /dev/null +++ b/python/hidet/tos/transforms/instruments/base.py @@ -0,0 +1,19 @@ +from hidet.tos.ir.graph import FlowGraph + + +class GraphPassInstrument: + def before_all_passes(self, graph: FlowGraph): + pass + + def before_pass(self, pass_name: str, graph: FlowGraph): + pass + + def after_pass(self, pass_name: str, graph: FlowGraph): + pass + + def after_all_passes(self, graph: FlowGraph): + pass + + + + diff --git a/python/hidet/tos/transforms/instruments/profile_instrument.py b/python/hidet/tos/transforms/instruments/profile_instrument.py new file mode 100644 index 0000000..9ec8798 --- /dev/null +++ b/python/hidet/tos/transforms/instruments/profile_instrument.py @@ -0,0 +1,37 @@ +import os +import time +from typing import Optional, Dict + +from hidet import utils +from hidet.tos.ir.graph import FlowGraph + +from .base import GraphPassInstrument + + +class ProfileInstrument(GraphPassInstrument): + def __init__(self, log_file: Optional[str] = None, print_stdout: bool = False): + if log_file: + dirname = os.path.dirname(log_file) + os.makedirs(dirname, exist_ok=True) + self.log_file = log_file + self.print_stdout = print_stdout + self.start_time: Dict[str, float] = {} + + def before_all_passes(self, graph: FlowGraph): + if self.log_file: + # clear file contents + with open(self.log_file, 'w'): + pass + + def before_pass(self, pass_name: str, graph: FlowGraph): + self.start_time[pass_name] = time.time() + if self.print_stdout: + print('{:>50} started...'.format(pass_name)) + + def after_pass(self, pass_name: str, graph: FlowGraph): + elapsed_time = time.time() - self.start_time[pass_name] + if self.log_file: + with open(self.log_file, 'a') as f: + f.write('{:>50} {:.3f} seconds\n'.format(pass_name, elapsed_time)) + if self.print_stdout: + print('{:>50} {} seconds'.format(pass_name, utils.py.green(elapsed_time, '{:.3f}'))) diff --git a/python/hidet/tos/transforms/instruments/save_graph_instrument.py b/python/hidet/tos/transforms/instruments/save_graph_instrument.py new file mode 100644 index 0000000..4c25336 --- /dev/null +++ b/python/hidet/tos/transforms/instruments/save_graph_instrument.py @@ -0,0 +1,29 @@ +import os + +from hidet import utils +from hidet.tos.ir.graph import FlowGraph + +from .base import GraphPassInstrument + + +class SaveGraphInstrument(GraphPassInstrument): + def __init__(self, out_dir: str): + self.out_dir = out_dir + self.index = 0 + os.makedirs(out_dir, exist_ok=True) + + def before_all_passes(self, graph: FlowGraph): + # first clean all json starting with indices + for fname in os.listdir(self.out_dir): + fpath = os.path.join(self.out_dir, fname) + parts = fname.split('_') + if os.path.isfile(fpath) and len(parts) > 1 and parts[0].isdigit() and fname.endswith('.json'): + os.remove(fpath) + with open(os.path.join(self.out_dir, '0_Origin.json'), 'w') as f: + utils.netron.dump(graph, f) + self.index += 1 + + def after_pass(self, pass_name: str, graph: FlowGraph): + with open(os.path.join(self.out_dir, '{}_{}.json'.format(self.index, pass_name)), 'w') as f: + utils.netron.dump(graph, f) + self.index += 1 diff --git a/python/hidet/tos/transforms/pattern_transform.py b/python/hidet/tos/transforms/pattern_transform.py new file mode 100644 index 0000000..0db9402 --- /dev/null +++ b/python/hidet/tos/transforms/pattern_transform.py @@ -0,0 +1,122 @@ +from typing import List, Optional, Dict, Tuple, Set, Union + +import hidet.tos.ops.definitions +from hidet import tos +from hidet.tos.ir.graph import FlowGraph, Operator, Tensor +from hidet.tos.transforms import GraphPass, PassContext +from hidet.tos.ir.functors import analyze_usage, graph_collect +from .fold_const import fold_const_pass +from .graph_patterns import GraphPattern, TensorPattern, OperatorPattern, MatchDict, Usage, graph_pattern_match +from .graph_patterns import all_graph_patterns +from hidet.utils import strict_zip + + +class PatternTransformPass(GraphPass): + """ + A pattern transform can be conducted only if + 1. The pattern source matched the actual tensor and its spanned subregion. + 2. The intermediate tensor in the matched region should not be used. Only the output tensors can be used + by not matched operators in original graph. + For example, if pattern a -> b -> c matched x -> y -> z. We need to make sure y has not been + used by other operators in the original graph. + + Time complexity of this implementation: O(num_applies * num_operators * num_patterns * pattern_size). + """ + max_num_transforms = 1000 + + def process_graph(self, graph: FlowGraph) -> FlowGraph: + graph = tos.ir.functors.clone(graph) + graph_patterns = all_graph_patterns() + fold_const = fold_const_pass() + for t in range(self.max_num_transforms): + updated, graph = self.try_transform(graph, graph_patterns) + graph = fold_const.process_graph(graph) + if not updated: + return graph + print('Exceeded maximum number of transforms {}, stop early.'.format(self.max_num_transforms)) + return graph + + @staticmethod + def match_pattern(graph_pattern: GraphPattern, start_tensor: Tensor, usage: Usage) -> Optional[MatchDict]: + source_output_tensors = graph_pattern.source() + + matched = graph_pattern_match(source_output_tensors[0], target=start_tensor, usage=usage) + if matched is None: + return None + + for source_output_tensor in source_output_tensors: + if source_output_tensor not in matched: + raise NotImplementedError('The source pattern is not connected. Current we do not support disconnected patterns.') + + return matched + + @staticmethod + def check_usage_requirement(matched: MatchDict, usage: Usage, graph_pattern: GraphPattern) -> bool: + source_output_pattern_tensors: List[TensorPattern] = graph_pattern.source() + + # actual tensor -> pattern tensor + tensor_map: Dict[Tensor, TensorPattern] = {v: k for k, v in matched.items() if isinstance(v, Tensor)} + + # find out all inner tensors (all matched tensors that are not matched by output tensors, nor input tensors) + inner_tensors: List[Tensor] = [] + for actual_tensor in tensor_map: + pattern_tensor = tensor_map[actual_tensor] + # input tensor in pattern + if pattern_tensor.trace is None: + continue + # output tensor in pattern + if pattern_tensor in source_output_pattern_tensors: + continue + + # check whether all inner tensors are only used by matched operators + matched_operators: Set[Operator] = {v for v in matched.values() if isinstance(v, Operator)} + for inner_tensor in inner_tensors: + uses: List[Tuple[Optional[Operator], int]] = usage[inner_tensor] + if any(use[0] not in matched_operators for use in uses): + # used by not matched operator + return False + return True + + @staticmethod + def try_transform(graph: FlowGraph, graph_patterns: List[GraphPattern]) -> Tuple[bool, FlowGraph]: + patterns: List[GraphPattern] = graph_patterns + usage: Usage = analyze_usage(graph) + all_tensors: List[Tensor] = graph_collect(graph, Tensor) + + for graph_pattern in patterns: + # print(graph_pattern.name) + for start_tensor in all_tensors: + # condition 1 + matched = PatternTransformPass.match_pattern(graph_pattern, start_tensor, usage) + if matched is None: + continue + + # condition 2 + success = PatternTransformPass.check_usage_requirement(matched, usage, graph_pattern) + if not success: + continue + + # generate target subgraph + target_output_tensors: Optional[List[Tensor]] = graph_pattern.target(matched) + if target_output_tensors is None: + # matched graph pattern can not be applied to this subgraph + continue + + # apply the graph transform + if PassContext.current().configs['verbose']: + print('Applying transform: {}'.format(graph_pattern.name)) + source_output_pattern_tensors = graph_pattern.source() + source_output_tensors = [matched[t] for t in source_output_pattern_tensors] + for source_tensor, target_tensor in strict_zip(source_output_tensors, target_output_tensors): + for use in usage[source_tensor]: + op, idx = use + if op is None: + graph.outputs[idx] = target_tensor + else: + op.inputs[idx] = target_tensor + return True, graph + return False, graph + + +def pattern_transform_pass() -> GraphPass: + return PatternTransformPass() diff --git a/python/hidet/tos/transforms/resolve_mma.py b/python/hidet/tos/transforms/resolve_mma.py new file mode 100644 index 0000000..d3a0ed1 --- /dev/null +++ b/python/hidet/tos/transforms/resolve_mma.py @@ -0,0 +1,79 @@ +from typing import List, Optional +import warnings +from .base import GraphPass, PassContext +from hidet.tos.ir import FlowGraph, Operator, Tensor, GraphRewriter +from hidet.tos.ops.definitions import MatmulOp +from hidet.tos import ops +from hidet.ir.type import max_float_dtype +from hidet.tos.ops.definitions.matmul.matmul import batched_matmul + + +class ResolveMmaRewriter(GraphRewriter): + def visit_Operator(self, op: Operator): + if isinstance(op, MatmulOp): + a: Tensor = self(op.inputs[0]) + b: Tensor = self(op.inputs[1]) + mma_type: str = PassContext.current().configs['mma'] + reduce_dtype: Optional[str] = PassContext.current().configs['reduce_precision'] + op_mma, ta, tb, tc = [op.attrs[name] for name in ['mma', 'ta', 'tb', 'tc']] + if op_mma == 'default': + if mma_type == 'wmma': + mma_dtype = self.get_mma_dtype(a.dtype, b.dtype) + mma_acc_dtype = self.get_mma_acc_dtype(mma_dtype, reduce_dtype) + mma = 'wmma_{}_{}'.format(mma_dtype, mma_acc_dtype) + elif mma_type == 'mma': + raise NotImplementedError() + elif mma_type == 'simt': + mma = 'simt' + elif mma_type.startswith('wmma_'): + mma = mma_type + else: + raise ValueError('Can not recognize mma_type {}'.format(mma_type)) + else: + mma = op_mma + self.memo[op.outputs[0]] = batched_matmul(a, b, algo=op.attrs['algo'], mma=mma, ta=ta, tb=tb, tc=tc) + else: + return GraphRewriter.visit_Operator(self, op) + + @staticmethod + def get_mma_dtype(a_dtype: str, b_dtype: str): + dtype = max_float_dtype([a_dtype, b_dtype]) + if dtype not in ['float16', 'bfloat16', 'float32']: + raise ValueError('Can not recognize data type {} as input data type of matrix multiplication.'.format(dtype)) + return { + 'float16': 'f16', + 'bfloat16': 'bf16', + 'float32': 'tf32' + }[dtype] + + @staticmethod + def get_mma_acc_dtype(mma_dtype: str, acc_dtype: Optional[str]): + if mma_dtype == 'f16': + if acc_dtype is None: + return 'f32' + elif acc_dtype == 'float16': + return 'f16' + elif acc_dtype == 'float32': + return 'f32' + else: + raise ValueError() + elif mma_dtype == 'bf16': + if acc_dtype != 'float32' and acc_dtype is not None: + warnings.warn('bfloat16 only support float32 accumulation in wmma instruction, but got {}. float32 is used.'.format(acc_dtype)) + return 'f32' + elif mma_dtype == 'tf32': + if acc_dtype != 'float32' and acc_dtype is not None: + warnings.warn('tfloat32 only support float32 accumulation in wmma instruction, but got {}. float32 is used.'.format(acc_dtype)) + return 'f32' + else: + raise ValueError('Can not recognize mma_dtype {}'.format(mma_dtype)) + + +class ResolveMmaPass(GraphPass): + def process_graph(self, graph: FlowGraph) -> FlowGraph: + rewriter = ResolveMmaRewriter() + return rewriter(graph) + + +def resolve_mma_pass() -> GraphPass: + return ResolveMmaPass() diff --git a/python/hidet/tos/transforms/resolve_variant.py b/python/hidet/tos/transforms/resolve_variant.py new file mode 100644 index 0000000..f359eec --- /dev/null +++ b/python/hidet/tos/transforms/resolve_variant.py @@ -0,0 +1,56 @@ +from typing import Type, List +from .base import GraphPass +from .resolve_variant_rules import ResolveRule, Conv2dResolveRule, MatmulResolveRule +from hidet.tos.ir import FlowGraph, GraphRewriter, Tensor, Operator +from hidet.utils import strict_zip, same_list + + +class ResolveVariantRewriter(GraphRewriter): + def __init__(self, rule: ResolveRule): + super().__init__() + self.rule = rule + + def visit_Operator(self, op: Operator): + op_cls = self.rule.op_cls() + if not isinstance(op, op_cls): + return GraphRewriter.visit_Operator(self, op) + inputs = [self(x) for x in op.inputs] + if same_list(inputs, op.inputs): + resolve_op = op + else: + updated_outputs = op.reforward(inputs) + resolve_op = updated_outputs[0].op + outs = self.rule.resolve(resolve_op) + + if outs is None: + # keep the original operator + # we still need to update memo in case inputs changed + for original, updated in strict_zip(op.outputs, resolve_op.outputs): + assert original not in self.memo + self.memo[original] = updated + else: + # update output of resolved operator + for original, updated in strict_zip(op.outputs, outs): + assert original not in self.memo + self.memo[original] = updated + + +class ResolveVariantPass(GraphPass): + def process_graph(self, graph: FlowGraph) -> FlowGraph: + rule_seq: List[ResolveRule] = [ + Conv2dResolveRule(), + MatmulResolveRule() + ] + for rule in rule_seq: + resolver = ResolveVariantRewriter(rule) + while True: + updated_graph = resolver(graph) + if updated_graph is graph: + break + else: + graph = updated_graph + return graph + + +def resolve_variant_pass() -> GraphPass: + return ResolveVariantPass() diff --git a/python/hidet/tos/transforms/resolve_variant_rules/__init__.py b/python/hidet/tos/transforms/resolve_variant_rules/__init__.py new file mode 100644 index 0000000..7168c74 --- /dev/null +++ b/python/hidet/tos/transforms/resolve_variant_rules/__init__.py @@ -0,0 +1,3 @@ +from .base import ResolveRule +from .conv2d_rule import Conv2dResolveRule +from .matmul_rule import MatmulResolveRule diff --git a/python/hidet/tos/transforms/resolve_variant_rules/base.py b/python/hidet/tos/transforms/resolve_variant_rules/base.py new file mode 100644 index 0000000..278f5ef --- /dev/null +++ b/python/hidet/tos/transforms/resolve_variant_rules/base.py @@ -0,0 +1,23 @@ +from typing import Type, List, Optional +from hidet.tos.ir import FlowGraph, GraphRewriter, Tensor, Operator + + +class ResolveRule: + def op_cls(self) -> Type[Operator]: + raise NotImplementedError() + + def resolve(self, op: Operator) -> Optional[List[Tensor]]: + """ + Parameters + ---------- + op: Operator + The operator to be resolved. + + Returns + ------- + ret: Optional[List[Tensor]] + None - indicates the operator has not been resolved, keep the original operator. + List[Tensor] - the output of resolved operators. + """ + raise NotImplementedError() + diff --git a/python/hidet/tos/transforms/resolve_variant_rules/conv2d_rule.py b/python/hidet/tos/transforms/resolve_variant_rules/conv2d_rule.py new file mode 100644 index 0000000..b5b6899 --- /dev/null +++ b/python/hidet/tos/transforms/resolve_variant_rules/conv2d_rule.py @@ -0,0 +1,29 @@ +from typing import List, Type, Optional +from hidet.tos.ir import Operator, Tensor +from hidet.tos import ops +from hidet.tos.ops.definitions import Conv2dOp +from .base import ResolveRule + + +class Conv2dResolveRule(ResolveRule): + def __init__(self, enable_winograd=False): + self.enable_winograd = enable_winograd + + def op_cls(self) -> Type[Operator]: + return Conv2dOp + + def resolve(self, op: Operator) -> Optional[List[Tensor]]: + assert isinstance(op, Conv2dOp) + stride = ops.utils.normalize_stride(op.attrs['stride']) + groups = op.attrs['groups'] + if groups != 1: + return None + data, weight = op.inputs + kernel_size = weight.shape[2:] + if self.enable_winograd and tuple(stride) == (1, 1) and tuple(kernel_size) == (3, 3): + # winograd algorithm + out = ops.conv2d_winograd(data, weight) + else: + # implicit gemm algorithm + out = ops.conv2d_gemm(data, weight, stride) + return [out] diff --git a/python/hidet/tos/transforms/resolve_variant_rules/matmul_rule.py b/python/hidet/tos/transforms/resolve_variant_rules/matmul_rule.py new file mode 100644 index 0000000..920e61b --- /dev/null +++ b/python/hidet/tos/transforms/resolve_variant_rules/matmul_rule.py @@ -0,0 +1,40 @@ +from typing import List, Type, Optional +from hidet.tos.ir import Operator, Tensor +from hidet.tos.transforms.base import PassContext +from hidet.tos.ops.definitions.matmul.matmul import matmul +from hidet.tos.ops.definitions.matmul.parallel_k_matmul import parallel_k_batched_matmul, parallel_k_nparts, parallel_k_batched_matmul_search +from hidet.tos.ops.definitions import MatmulOp +from .base import ResolveRule + + +def op_use_parallel_k(op: MatmulOp) -> bool: + a, b = op.inputs + batch_size, m_size, k_size = a.shape + n_size = b.shape[2] + return parallel_k_nparts(batch_size, m_size, n_size, k_size) != 1 + + +class MatmulResolveRule(ResolveRule): + def op_cls(self) -> Type[Operator]: + return MatmulOp + + def resolve(self, op: Operator) -> Optional[List[Tensor]]: + assert isinstance(op, MatmulOp) + a, b = op.inputs + if op.attrs['algo'] == 'default': + parallel_k = PassContext.current().configs['parallel_k'] + if parallel_k == 'disabled': + return [matmul(a, b, 'direct', mma=op.attrs['mma'])] + elif parallel_k == 'default': + if op_use_parallel_k(op): + return [parallel_k_batched_matmul(a, b, mma=op.attrs['mma'])] + else: + return [matmul(a, b, 'direct', mma=op.attrs['mma'])] + elif parallel_k == 'search': + return [parallel_k_batched_matmul_search(a, b, mma=op.attrs['mma'])] + elif isinstance(parallel_k, int): + return [parallel_k_batched_matmul(a, b, mma=op.attrs['mma'], nparts=parallel_k)] + else: + raise ValueError('Can not recognize parallel_k config: {}'.format(parallel_k)) + + return None diff --git a/python/hidet/tos/transforms/utils.py b/python/hidet/tos/transforms/utils.py new file mode 100644 index 0000000..36d347e --- /dev/null +++ b/python/hidet/tos/transforms/utils.py @@ -0,0 +1,6 @@ +from hidet.tos.ir import Operator + + +def is_barrier(op: Operator): + from hidet.tos.ops.definitions.special import BarrierOp + return isinstance(op, BarrierOp) diff --git a/python/hidet/transforms/__init__.py b/python/hidet/transforms/__init__.py new file mode 100644 index 0000000..418c914 --- /dev/null +++ b/python/hidet/transforms/__init__.py @@ -0,0 +1,67 @@ +from hidet.ir.func import IRModule + +from .base import Pass, FunctionPass, FunctionBodyPass, SequencePass, RepeatFunctionPass, PassContext +from .instruments import PassInstrument, SaveIRInstrument, ProfileInstrument + +from .apply_prologue_epilogue import apply_prologue_epilogue_pass +from .flatten_tensor_slice import flatten_tensor_slice_pass +from .flatten_tensor_index import flatten_tensor_index_pass +from .generate_packed_func import generate_packed_func_pass +from .import_primitive_functions import import_primitive_functions_pass +from .simplify_stmt import simplify_stmt_pass +from .expand_let_expr import expand_let_expr_pass +from .resolve_generic_primitive_function import resolve_primitive_func_pass +from .add_explicit_cast import add_explicit_cast_pass +from .explicit_unroll_for_stmt import explicit_unroll_for_stmt_pass +from .inline_let_stmt import inline_let_stmt_pass +from .common_subexpression_elimination import common_subexpression_elimination_pass, chain_seq_stmt_using_let_stmt_pass +from .build_let_stmt import build_let_stmt_pass +from .rule_based_simplifier import rule_based_simplify_pass +from .simplify_stmt import simplify_stmt_pass +from .squeeze_let_stmt import squeeze_let_stmt_pass +from .uplift_let_stmt import uplift_let_stmt_pass +from .precompute_condition import precompute_condition_pass +from .normalize_const_tensor import normalize_const_tensor_pass + + +def lower(ir_module: IRModule) -> IRModule: + transforms = [ + # necessary passes + flatten_tensor_slice_pass(), + apply_prologue_epilogue_pass(), + generate_packed_func_pass(), + normalize_const_tensor_pass(), + flatten_tensor_index_pass(), + resolve_primitive_func_pass(), + import_primitive_functions_pass(), + resolve_primitive_func_pass(), + import_primitive_functions_pass(), + add_explicit_cast_pass(), + + # simplification + expand_let_expr_pass(), + inline_let_stmt_pass(inline_all=True), + rule_based_simplify_pass(), + simplify_stmt_pass(), + + # common sub-expression elimination + # build_let_stmt_pass(), + # uplift_let_stmt_pass(), + # common_subexpression_elimination_pass(), + # inline_let_stmt_pass(inline_factor=1), + + # optimization (precompute condition) + # precompute_condition_pass(), + + # necessary pass + ] + + ctx = PassContext.current() + for instrument in ctx.instruments: + instrument.before_all_passes(ir_module) + for transform in transforms: + ir_module = transform(ir_module) + for instrument in ctx.instruments: + instrument.after_all_passes(ir_module) + + return ir_module diff --git a/python/hidet/transforms/add_explicit_cast.py b/python/hidet/transforms/add_explicit_cast.py new file mode 100644 index 0000000..12f4bbc --- /dev/null +++ b/python/hidet/transforms/add_explicit_cast.py @@ -0,0 +1,149 @@ +from .base import FunctionBodyPass, Pass +from hidet.ir.functors import StmtExprRewriter, TypeInfer, TypeFunctor +from hidet.ir.dialects.lowlevel import TensorPointerType, PointerType, ReferenceType, VoidType +from hidet.ir.stmt import Stmt, AssignStmt, BufferStoreStmt +from hidet.ir.expr import Expr, Cast, Add, Sub, Multiply, Div, FloorDiv, BinaryOp, cast +from hidet.ir.type import ScalarType, TypeNode, TensorType +from hidet.utils import same_list + + +class TypeNotMatch(Exception): + def __init__(self, a, b, msg=""): + super().__init__() + self.a = a + self.b = b + self.msg = msg + + +class TypeChecker: + def visit(self, a: TypeNode, b: TypeNode): + if isinstance(a, ScalarType): + return self.visit_ScalarType(a, b) + elif isinstance(a, TensorType): + return self.visit_TensorType(a, b) + elif isinstance(a, PointerType): + return self.visit_PointerType(a, b) + elif isinstance(a, TensorPointerType): + return self.visit_TensorPointerType(a, b) + elif isinstance(a, ReferenceType): + return self.visit_ReferenceType(a, b) + elif isinstance(a, VoidType): + return self.visit_VoidType(a, b) + else: + raise ValueError('Can not recognize type {}'.format(a)) + + @staticmethod + def check(a, b, cond, msg=""): + if not cond: + raise TypeNotMatch(a, b, msg) + + def visit_ScalarType(self, a: ScalarType, b: TypeNode): + self.check(a, b, isinstance(b, ScalarType)) + assert isinstance(b, ScalarType) + self.check(a, b, a.name == b.name) + + def visit_TensorType(self, a: TensorType, b: TypeNode): + self.check(a, b, isinstance(b, TensorType)) + assert isinstance(b, TensorType) + self.visit(a.scalar_type, b.scalar_type) + # todo: check data layout and shape + + def visit_PointerType(self, a: PointerType, b: TypeNode): + self.check(a, b, isinstance(b, PointerType)) + assert isinstance(b, PointerType) + self.visit(a.base_type, b.base_type) + + def visit_TensorPointerType(self, a: TensorPointerType, b: TypeNode): + self.check(a, b, isinstance(b, TensorPointerType)) + assert isinstance(b, TensorPointerType) + self.visit(a.tensor_type, b.tensor_type) + + def visit_ReferenceType(self, a: ReferenceType, b: TypeNode): + self.check(a, b, isinstance(b, ReferenceType)) + assert isinstance(b, ReferenceType) + self.visit(a.base_type, b.base_type) + + def visit_VoidType(self, a: VoidType, b: TypeNode): + self.check(a, b, isinstance(b, VoidType)) + + +def same_type(a: TypeNode, b: TypeNode) -> bool: + try: + TypeChecker().visit(a, b) + return True + except TypeNotMatch: + return False + + +class AddExplicitCastRewriter(StmtExprRewriter): + def __init__(self): + super().__init__() + self.type_infer = TypeInfer() + + @staticmethod + def convert(source_type: TypeNode, target_type: TypeNode, source_value: Expr) -> Expr: + if isinstance(source_type, ScalarType) and isinstance(target_type, ScalarType): + # because there is no implicit conversion function between bfloat16 and float16 + # in the underlying cuda c library, we use 'float32' as a bridge type + has_float16 = 'float16' in [source_type.name, target_type.name] + has_bfloat16 = 'bfloat16' in [source_type.name, target_type.name] + if has_float16 and has_bfloat16: + return Cast(Cast(source_value, 'float32'), target_type) + if same_type(source_type, target_type): + return source_value + else: + return Cast(source_value, target_type) + + def visit_Binary(self, e: BinaryOp): + if isinstance(e, (Add, Sub, Multiply, Div)): + a, b = self(e.a), self(e.b) + a_dtype: ScalarType = self.type_infer(a) + b_dtype: ScalarType = self.type_infer(b) + op = e.__class__ + if a_dtype > b_dtype: + return op(a, cast(b, a_dtype)) + elif a_dtype < b_dtype: + return op(cast(a, b_dtype), b) + else: + return StmtExprRewriter.visit_Binary(self, e) + else: + return StmtExprRewriter.visit_Binary(self, e) + + def visit_Cast(self, e: Cast): + expr = self(e.expr) + source_type = self.type_infer(expr) + target_type = e.target_type + return self.convert(source_type, target_type, expr) + + def visit_AssignStmt(self, stmt: AssignStmt): + value = self(stmt.value) + var = self(stmt.var) + source_type = self.type_infer(value) + target_type = self.type_infer(var) + return AssignStmt(var, self.convert(source_type, target_type, value)) + + def visit_BufferStoreStmt(self, stmt: BufferStoreStmt): + value = self(stmt.value) + buf = self(stmt.buf) + indices = self(stmt.indices) + source_type = self.type_infer(value) + buffer_type = self.type_infer(buf) + if isinstance(buffer_type, TensorType): + target_type = buffer_type.scalar_type + elif isinstance(buffer_type, TensorPointerType): + target_type = buffer_type.tensor_type.scalar_type + elif isinstance(buffer_type, PointerType): + target_type = buffer_type.base_type + else: + raise ValueError('Can not recognize the buffer type: {}'.format(buffer_type)) + return BufferStoreStmt(buf, indices, self.convert(source_type, target_type, source_value=value)) + + +class AddExplicitCastPass(FunctionBodyPass): + def process_body(self, stmt: Stmt) -> Stmt: + rewriter = AddExplicitCastRewriter() + return rewriter(stmt) + + +def add_explicit_cast_pass() -> Pass: + return AddExplicitCastPass() diff --git a/python/hidet/transforms/apply_prologue_epilogue.py b/python/hidet/transforms/apply_prologue_epilogue.py new file mode 100644 index 0000000..b04dd39 --- /dev/null +++ b/python/hidet/transforms/apply_prologue_epilogue.py @@ -0,0 +1,111 @@ +from typing import Dict, List +from hidet.ir.expr import Var, TensorElement +from hidet.ir.stmt import BufferStoreStmt +from hidet.ir.func import Function, IRModule +from hidet.ir.dialects.compute import TensorNode +from hidet.ir.functors import collect, rewrite, inline_compute +from .base import Pass + + +class ApplyPrologueEpiloguePass(Pass): + def process_module(self, ir_module: IRModule) -> IRModule: + if len(ir_module.functions) != 1: + raise ValueError('apply_prologue_epilogue_pass should run first before generate_packed_func pass.') + func = list(ir_module.functions.values())[0] + task = ir_module.task + + if task is None: + return ir_module + + if not (len(func.params) == len(task.inputs) + len(task.outputs)): + raise ValueError('The parameters of function should be the same as the sum of task inputs and outputs.') + num_inputs = len(task.inputs) + input2var: Dict[TensorNode, Var] = {a: b for a, b in zip(task.inputs, func.params[:num_inputs])} + output2var: Dict[TensorNode, Var] = {a: b for a, b in zip(task.outputs, func.params[num_inputs:])} + + param_vars = [Var(param.name, param.data_type) for param in task.parameters] + param2var = {a: b for a, b in zip(task.parameters, param_vars)} + + body = func.body + + # update func parameters + rmap = {} + for idx, t in enumerate(task.inputs + task.outputs): + if t in param2var: + rmap[func.params[idx]] = param2var[t] + body = rewrite(body, rmap) + + # apply prologues + for input_node, input_var in input2var.items(): + if input_node not in task.prologues: + continue + prologue = task.prologues[input_node] + prologue_value = inline_compute(prologue.value, reduce_limit=-1) + + # the following collect assumes that there is no nested tensor elements for the same tensor, such as A[A[1, 2], 3] + tensor_elements: List[TensorElement] = collect(body, TensorElement) + prologue_rewrite_map = {} + for te in tensor_elements: + if te.base is not input_var: + continue + rmap = {} + for extra_input in prologue.extra_inputs: + if extra_input not in param2var: + msg = 'Prologue used tensor {} that has not defined in task parameters. Task:\n{}'.format( + extra_input, task + ) + raise ValueError(msg) + rmap[extra_input] = param2var[extra_input] + for index_var, index_value in zip(prologue.indices, te.indices): + rmap[index_var] = index_value + prologue_expr = rewrite(prologue_value, rmap) + prologue_rewrite_map[te] = prologue_expr + body = rewrite(body, prologue_rewrite_map) + + # apply epilogues + for output_node, output_var in output2var.items(): + if output_node not in task.epilogues: + continue + epilogue = task.epilogues[output_node] + + # first check the usage of output var in TensorElement + tensor_elements: List[TensorElement] = collect(body, TensorElement) + if any(te.base is output_var for te in tensor_elements): + raise NotImplementedError('Currently do not support read from output tensor.') + + # todo: support nested cases + buffer_stores: List[BufferStoreStmt] = collect(body, BufferStoreStmt) + epilogue_rewrite_map = {} + epilogue_value = inline_compute(epilogue.value, reduce_limit=-1) + for bs in buffer_stores: + if bs.buf is not output_var: + continue + rmap = {epilogue.orig_value: bs.value} + for extra_input in epilogue.extra_inputs: + if extra_input not in param2var: + raise ValueError('Epilogue used tensor {} that has not defined in task parameters.'.format(extra_input)) + rmap[extra_input] = param2var[extra_input] + for index_var, index_value in zip(epilogue.indices, bs.indices): + rmap[index_var] = index_value + epilogue_expr = rewrite(epilogue_value, rmap) + if epilogue.out_indices and epilogue.out_tensor: + out_index_exprs = [rewrite(out_index_expr, rmap) for out_index_expr in epilogue.out_indices] + if epilogue.out_tensor not in param2var: + raise ValueError('Epilogue used a output tensor that has not defined in task parameters.'.format(epilogue.out_tensor)) + out_tensor = param2var[epilogue.out_tensor] + epilogue_rewrite_map[bs] = BufferStoreStmt(out_tensor, out_index_exprs, epilogue_expr) + else: + epilogue_rewrite_map[bs] = BufferStoreStmt(bs.buf, bs.indices, epilogue_expr) + body = rewrite(body, epilogue_rewrite_map) + + if body is func.body: + return ir_module + else: + func = Function(func.name, params=param_vars, body=body, ret_type=func.ret_type, kind=func.kind, + local_vars=func.local_vars, local_const_vars=func.local_const_vars, extern_vars=func.extern_vars, attrs=func.attrs) + ir_module = IRModule(funcs={func.name: func}, task=ir_module.task, global_vars=ir_module.global_vars) + return ir_module + + +def apply_prologue_epilogue_pass() -> Pass: + return ApplyPrologueEpiloguePass() diff --git a/python/hidet/transforms/base.py b/python/hidet/transforms/base.py new file mode 100644 index 0000000..bae01c1 --- /dev/null +++ b/python/hidet/transforms/base.py @@ -0,0 +1,100 @@ +from typing import List, Optional +from hidet.ir.stmt import Stmt +from hidet.ir.func import IRModule, Function + +from .instruments import PassInstrument + + +class PassContext: + stack: List['PassContext'] = [] + + def __init__(self, instruments: Optional[List[PassInstrument]] = None, verbose: bool = False): + self.instruments = instruments + self.verbose = verbose + + @classmethod + def current(cls): + return cls.stack[-1] + + def __enter__(self): + self.stack.append(self) + + def __exit__(self, exc_type, exc_val, exc_tb): + assert len(self.stack) > 0 and self.stack[-1] is self + self.stack.pop() + + +PassContext.stack.append(PassContext()) + + +class Pass: + def __init__(self, name=None): + self.name = name if name else self.__class__.__name__ + + def __call__(self, ir_module: IRModule) -> IRModule: + ctx = PassContext.current() + for instrument in ctx.instruments: + instrument.before_pass(self.name, ir_module) + ir_module = self.process_module(ir_module) + for instrument in ctx.instruments: + instrument.after_pass(self.name, ir_module) + return ir_module + + def process_module(self, ir_module: IRModule) -> IRModule: + new_funcs = {} + for name, func in ir_module.functions.items(): + new_funcs[name] = self.process_func(func) + if all(new_funcs[name] is ir_module.functions[name] for name in new_funcs): + return ir_module + else: + return IRModule(funcs=new_funcs, task=ir_module.task, global_vars=ir_module.global_vars) + + def process_func(self, func: Function) -> Function: + return func + + +class SequencePass(Pass): + def __init__(self, passes: List[Pass], name=None): + super().__init__(name) + self.passes = passes + + def process_module(self, ir_module: IRModule) -> IRModule: + for p in self.passes: + ir_module = p(ir_module) + return ir_module + + +class FunctionPass(Pass): + def process_func(self, func: Function) -> Function: + raise NotImplementedError() + + +class FunctionBodyPass(FunctionPass): + def process_func(self, func: Function) -> Function: + body = self.process_body(func.body) + if body is func.body: + return func + else: + return Function(func.name, func.params, body, func.ret_type, kind=func.kind, local_vars=func.local_vars, + local_const_vars=func.local_const_vars, extern_vars=func.extern_vars, attrs=func.attrs) + + def process_body(self, stmt: Stmt) -> Stmt: + raise NotImplementedError() + + +class RepeatFunctionPass(FunctionPass): + def __init__(self, passes: List[FunctionPass], repeat_limit=10, name=None): + super().__init__(name) + assert all(isinstance(p, FunctionPass) for p in passes) + self.passes = passes + self.repeat_limit = repeat_limit + + def process_func(self, func: Function) -> Function: + for i in range(self.repeat_limit): + orig_func = func + for p in self.passes: + func = p.process_func(func) + if orig_func is func: + return func + print(f"Exceeded: {i} {self.name} on {func.name}") + return func diff --git a/python/hidet/transforms/build_let_stmt.py b/python/hidet/transforms/build_let_stmt.py new file mode 100644 index 0000000..a17133b --- /dev/null +++ b/python/hidet/transforms/build_let_stmt.py @@ -0,0 +1,159 @@ +import contextlib +from hidet.ir.expr import * +from hidet.ir.stmt import * +from hidet.ir.functors import StmtExprRewriter, StmtRewriter, same_list, TypeInfer +from hidet.ir.builders import StmtBuilder +from hidet.transforms.base import FunctionBodyPass + + +class StmtContext: + def __init__(self, rewriter: 'BuildLetStmtRewriter'): + self.rewriter = rewriter + + def __enter__(self): + self.rewriter.exit_stack_list.append(contextlib.ExitStack()) + self.rewriter.exit_stack = self.rewriter.exit_stack_list[-1] + self.rewriter.exit_stack.__enter__() + + def __exit__(self, exc_type, exc_val, exc_tb): + exit_stack = self.rewriter.exit_stack_list.pop() + exit_stack.__exit__(exc_type, exc_val, exc_tb) + + +class BuildLetStmtRewriter(StmtExprRewriter): + def __init__(self): + super().__init__(use_memo=False) + self.exit_stack_list = [] + self.exit_stack: Optional[contextlib.ExitStack] = None + self.sb: Optional[StmtBuilder] = None + self.type_infer = TypeInfer() + + def build(self, stmt): + self.sb = StmtBuilder() + with StmtContext(self): + self(stmt) + return self.sb.finish() + + def visit_Binary(self, e: BinaryOp): + etype = self.type_infer(e) + if isinstance(e, (Add, Sub, Multiply, Div, FloorDiv, Mod)) and (isinstance(etype, ScalarType) and etype.name == 'int32'): + return self.exit_stack.enter_context(self.sb.let('v', StmtExprRewriter.visit_Binary(self, e))) + else: + return StmtExprRewriter.visit_Binary(self, e) + + def visit_Let(self, e: Let): + self.exit_stack.enter_context(self.sb.let(e.var, self(e.value))) + return self(e.body) + + def visit_Var(self, e: Var): + return e + + def visit_Constant(self, e: Constant): + return e + + def visit_EvaluateStmt(self, stmt: EvaluateStmt): + with StmtContext(self): + self.sb += StmtExprRewriter.visit_EvaluateStmt(self, stmt) + + def visit_BufferStoreStmt(self, stmt: BufferStoreStmt): + with StmtContext(self): + self.sb += StmtExprRewriter.visit_BufferStoreStmt(self, stmt) + + def visit_AssignStmt(self, stmt: AssignStmt): + with StmtContext(self): + self.sb += StmtExprRewriter.visit_AssignStmt(self, stmt) + + def visit_LetStmt(self, stmt: LetStmt): + with StmtContext(self): + bind_vars = stmt.bind_vars + bind_values = [self(value) for value in stmt.bind_values] + with self.sb.lets(bind_vars=bind_vars, values=bind_values): + self(stmt.body) + + def visit_ForStmt(self, stmt: ForStmt): + with StmtContext(self): + loop_var = self.visit_expr(stmt.loop_var) + extent = self.visit_expr(stmt.extent) + with self.sb.for_loop(loop_var, extent, unroll=stmt.unroll): + self.visit(stmt.body) + + def visit_IfStmt(self, stmt: IfStmt): + with StmtContext(self): + cond = self.visit_expr(stmt.cond) + with self.sb.if_then(cond): + self.visit(stmt.then_body) + if stmt.else_body: + with self.sb.otherwise(): + self.visit(stmt.else_body) + + def visit_AssertStmt(self, stmt: AssertStmt): + with StmtContext(self): + self.sb += StmtExprRewriter.visit_AssertStmt(self, stmt) + + def visit_ReturnStmt(self, stmt: ReturnStmt): + with StmtContext(self): + self.sb += StmtExprRewriter.visit_ReturnStmt(self, stmt) + + def visit_AsmStmt(self, stmt: AsmStmt): + with self.exit_stack: + input_exprs = [self.visit_expr(e) for e in stmt.input_exprs] + output_exprs = [self.visit_expr(e) for e in stmt.output_exprs] + self.sb += AsmStmt(stmt.template_string, list(zip(stmt.output_labels, output_exprs)), + list(zip(stmt.input_labels, input_exprs)), stmt.is_volatile) + + def visit_BlackBoxStmt(self, stmt: BlackBoxStmt): + with StmtContext(self): + exprs = [self.visit_expr(e) for e in stmt.exprs] + self.sb += BlackBoxStmt(stmt.template_string, *exprs) + + def visit_SeqStmt(self, stmt: SeqStmt): + for s in stmt.seq: + self.visit(s) + + +class SqueezeLetStmtRewriter(StmtRewriter): + def visit_LetStmt(self, stmt: LetStmt): + cur = StmtRewriter.visit_LetStmt(self, stmt) + + bind_vars = [] + bind_values = [] + while isinstance(cur, LetStmt): + bind_vars.extend(cur.bind_vars) + bind_values.extend(cur.bind_values) + cur = cur.body + if same_list(bind_vars, stmt.bind_vars) and same_list(bind_values, stmt.bind_values) and cur is stmt.body: + return stmt + else: + return LetStmt(bind_vars, bind_values, cur) + + def visit_SeqStmt(self, stmt: SeqStmt): + seq = [self(s) for s in stmt.seq] + if len(seq) == 0: + return stmt + body = seq[-1] + for s in reversed(seq[:-1]): + body = join_stmt(s, body) + if isinstance(body, SeqStmt) and same_list(body.seq, stmt.seq): + return stmt + else: + return body + + +def join_stmt(lhs: Stmt, rhs: Stmt): + if isinstance(lhs, LetStmt): + return LetStmt(lhs.bind_vars, lhs.bind_values, join_stmt(lhs.body, rhs)) + else: + lhs_seq = lhs.seq if isinstance(lhs, SeqStmt) else [lhs] + rhs_seq = rhs.seq if isinstance(rhs, SeqStmt) else [rhs] + return SeqStmt(list(lhs_seq) + list(rhs_seq)) + + +class BuildLetStmtPass(FunctionBodyPass): + def process_body(self, stmt: Stmt) -> Stmt: + stmt_builder = BuildLetStmtRewriter() + squeezer = SqueezeLetStmtRewriter() + return squeezer(stmt_builder.build(stmt)) + + +def build_let_stmt_pass(): + return BuildLetStmtPass() diff --git a/python/hidet/transforms/common/__init__.py b/python/hidet/transforms/common/__init__.py new file mode 100644 index 0000000..8718205 --- /dev/null +++ b/python/hidet/transforms/common/__init__.py @@ -0,0 +1,3 @@ +from . import scope + +from .scope import Scope, ScopeStack, FuncStmtExprRewriterWithScope diff --git a/python/hidet/transforms/common/scope.py b/python/hidet/transforms/common/scope.py new file mode 100644 index 0000000..e171328 --- /dev/null +++ b/python/hidet/transforms/common/scope.py @@ -0,0 +1,134 @@ +from typing import List, Dict, Optional, ContextManager + +from hidet.ir.type import ScalarType, FuncType +from hidet.ir.expr import Expr, Var, BitwiseAnd, LeftShift, BitwiseOr +from hidet.ir.functors import collect +from hidet.ir.stmt import LetStmt, ForStmt +from hidet.ir.func import Function +from hidet.ir.functors import FuncStmtExprRewriter + + +class Scope: + """ + Every variable (i.e., parameter variable, local variable, loop variable, let variable) much be declared or defined + in a scope. Parameter, local and loop variable should be declared, because we should not move it place. Every + let variable should be defined (with their value). + """ + + def __init__(self, stack, scope_stmt): + self.stack: 'ScopeStack' = stack + self.scope_stmt = scope_stmt + self.level = None + self.parent: Optional['Scope'] = None + self.declare_vars: List[Var] = [] + self.defined_vars: List[Var] = [] + self.var2value: Dict[Var, Optional[Expr]] = {} + self.defined_predicates: List[List[Expr]] = [] + self.predicate_vars: List[Var] = [] + + def __enter__(self): + scopes = self.stack.scopes + self.parent = scopes[0] if len(scopes) > 0 else None + self.level = len(scopes) + scopes.append(self) + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + scope = self.stack.scopes.pop() + assert scope is self + + def declare(self, var: Var): + # declare a variable at current scope + self.declare_vars.append(var) + self.var2value[var] = None + assert var not in self.stack.var2scope + self.stack.var2scope[var] = self + + def define(self, var: Var, value: Expr): + self.defined_vars.append(var) + self.var2value[var] = value + assert var not in self.stack.var2scope + self.stack.var2scope[var] = self + + def define_predicate(self, predicate: Expr) -> Expr: + if len(self.defined_predicates) == 0 or len(self.defined_predicates[-1]) == 32: + var = Var('p', type=ScalarType('uint32')) + self.defined_predicates.append([]) + self.predicate_vars.append(var) + self.stack.var2scope[var] = self + self.defined_predicates[-1].append(predicate) + mask = 1 << (len(self.defined_predicates[-1]) - 1) + return BitwiseAnd(self.predicate_vars[-1], mask) + + def wrap(self, body): + # wrap the body with defined variables at current scope + bind_vars = self.defined_vars + bind_values = [self.var2value[var] for var in bind_vars] + for p_var, p_exprs in zip(self.predicate_vars, self.defined_predicates): + bind_vars.append(p_var) + bind_values.append(BitwiseOr.join_list([LeftShift(p, idx) for idx, p in enumerate(p_exprs)])) + if len(bind_vars) > 0: + ret = LetStmt(bind_vars, bind_values, body) + else: + ret = body + for var in self.defined_vars + self.declare_vars: + del self.stack.var2scope[var] + return ret + + +class ScopeStack: + def __init__(self): + self.scopes = [] + self.var2scope: Dict[Var, Scope] = {} + + def find_scope_for_expr(self, expr) -> 'Scope': + used_vars = collect(expr, Var) + levels = [self.var2scope[used_var].level for used_var in used_vars if not isinstance(used_var.type, FuncType)] + max_level = max(levels) + return self.scopes[max_level] + + def new_scope(self, scope_stmt=None): + return Scope(self, scope_stmt) + + def current(self) -> Scope: + assert len(self.scopes) > 0 + return self.scopes[-1] + + +class FuncStmtExprRewriterWithScope(FuncStmtExprRewriter): + def __init__(self, use_memo=False): + super().__init__(use_memo=use_memo) + self.scope_stack = ScopeStack() + + def new_scope(self, stmt=None) -> ContextManager[Scope]: + return self.scope_stack.new_scope(stmt) + + def scope_to_define(self, expr: Expr) -> Scope: + return self.scope_stack.find_scope_for_expr(expr) + + def visit_Function(self, func: Function): + with self.new_scope(None) as scope: + for extern_var in func.extern_vars: + scope.declare(extern_var) + for param in func.params: + scope.declare(param) + for local_var in func.local_vars: + scope.declare(local_var) + for local_const_var, _ in func.local_const_vars: + scope.declare(local_const_var) + body = scope.wrap(self.visit(func.body)) + return Function(func.name, func.params, body, func.ret_type, kind=func.kind, local_vars=func.local_vars, + local_const_vars=func.local_const_vars, extern_vars=func.extern_vars, attrs=func.attrs) + + def visit_ForStmt(self, stmt: ForStmt): + with self.new_scope(stmt) as scope: + self.visit(stmt.extent) + scope.declare(stmt.loop_var) + body = scope.wrap(self.visit(stmt.body)) + return ForStmt(stmt.loop_var, stmt.extent, stmt.unroll, body) + + def visit_LetStmt(self, stmt: LetStmt): + with self.new_scope(stmt) as scope: + for var, value in zip(stmt.bind_vars, stmt.bind_values): + scope.define(var, self.visit(value)) + return scope.wrap(self.visit(stmt.body)) diff --git a/python/hidet/transforms/common_subexpression_elimination.py b/python/hidet/transforms/common_subexpression_elimination.py new file mode 100644 index 0000000..d1157c6 --- /dev/null +++ b/python/hidet/transforms/common_subexpression_elimination.py @@ -0,0 +1,111 @@ +from typing import ContextManager +from contextlib import ExitStack +from hidet.transforms.base import FunctionBodyPass, SequencePass, RepeatFunctionPass +from hidet.ir.functors import StmtRewriter, StmtExprRewriter, same_list +from hidet.ir.expr import Expr, Var, Constant, convert, Call +from hidet.ir.stmt import Stmt, SeqStmt, EvaluateStmt, IfStmt, LetStmt +from hidet.ir.func import IRModule +from hidet.ir.functors import ExprHash, rewrite +from hidet.ir.builders import FunctionBuilder, StmtBuilder + + +def join_stmt(lhs: Stmt, rhs: Stmt): + if isinstance(lhs, LetStmt): + return LetStmt(lhs.bind_vars, lhs.bind_values, join_stmt(lhs.body, rhs)) + else: + lhs_seq = lhs.seq if isinstance(lhs, SeqStmt) else [lhs] + rhs_seq = rhs.seq if isinstance(rhs, SeqStmt) else [rhs] + return SeqStmt(list(lhs_seq) + list(rhs_seq)) + + +class ChainSeqStmtUsingLetStmtRewriter(StmtRewriter): + def visit_SeqStmt(self, stmt: SeqStmt): + seq = [self(s) for s in stmt.seq] + if len(seq) == 0: + return stmt + body = seq[-1] + for s in reversed(seq[:-1]): + body = join_stmt(s, body) + if isinstance(body, SeqStmt) and same_list(body.seq, stmt.seq): + return stmt + else: + return body + + +class ChainSeqStmtUsingLetStmtPass(FunctionBodyPass): + def process_body(self, stmt: Stmt) -> Stmt: + ret = ChainSeqStmtUsingLetStmtRewriter()(stmt) + return ret + + +class Value2VarContext(ContextManager): + def __init__(self, rewriter, value_hash, var): + self.rewriter = rewriter + self.value_hash = value_hash + self.var = var + + def __enter__(self): + assert self.value_hash not in self.rewriter.value2var + self.rewriter.value2var[self.value_hash] = self.var + + def __exit__(self, exc_type, exc_val, exc_tb): + self.rewriter.value2var.pop(self.value_hash) + + +class CommonSubexpressionEliminationRewriter(StmtExprRewriter): + def __init__(self): + super().__init__(use_memo=False) + # value hash -> let var + self.expr_hash = ExprHash() + self.value2var = {} + self.replace_var = {} + + def visit(self, obj): + if isinstance(obj, Expr): + hash_value = self.expr_hash(obj) + if hash_value in self.value2var: + # TODO: add a structural equivalence check, now we assume (hash(A) == hash(B) => A == B) + return self.value2var[hash_value] + return StmtExprRewriter.visit(self, obj) + + def visit_LetStmt(self, stmt: LetStmt): + with ExitStack() as stack: + bind_vars = [] + bind_values = [] + for bind_var, bind_value in zip(stmt.bind_vars, stmt.bind_values): + updated_value = self(bind_value) + self.expr_hash.memo[bind_var] = self.expr_hash(updated_value) + if isinstance(updated_value, (Var, Constant)): + self.replace_var[bind_var] = updated_value + else: + value_hash = self.expr_hash(updated_value) + stack.enter_context(Value2VarContext(self, value_hash, bind_var)) + bind_vars.append(bind_var) + bind_values.append(updated_value) + body = self(stmt.body) + if same_list(bind_vars, stmt.bind_vars) and same_list(bind_values, stmt.bind_values) and body is stmt.body: + return stmt + else: + if len(bind_vars) == 0: + return body + else: + return LetStmt(bind_vars, bind_values, body) + + def visit_Var(self, e: Var): + if e in self.replace_var: + return self.replace_var[e] + else: + return e + + +class CommonSubexpressionEliminationPass(FunctionBodyPass): + def process_body(self, stmt: Stmt) -> Stmt: + return CommonSubexpressionEliminationRewriter()(stmt) + + +def chain_seq_stmt_using_let_stmt_pass(): + return ChainSeqStmtUsingLetStmtPass() + + +def common_subexpression_elimination_pass(): + return CommonSubexpressionEliminationPass() diff --git a/python/hidet/transforms/expand_let_expr.py b/python/hidet/transforms/expand_let_expr.py new file mode 100644 index 0000000..8d0d34f --- /dev/null +++ b/python/hidet/transforms/expand_let_expr.py @@ -0,0 +1,87 @@ +from hidet.ir.expr import Let, var +from hidet.ir.stmt import * +from hidet.ir.functors import StmtExprRewriter +from hidet.transforms import Pass, FunctionBodyPass + + +def wrapper(stmt_visitor): + def wrapped_visitor(self, stmt): + self.stmt_stack.append([]) + self.memo.clear() # do not cache exprs between different statements, so the let expr will always generate let stmt. + updated_stmt = stmt_visitor(self, stmt) + let_stmts = self.stmt_stack.pop() + if len(let_stmts) == 0: + return updated_stmt + else: + bind_vars, bind_values = [], [] + for let in let_stmts: + bind_vars.extend(let.bind_vars) + bind_values.extend(let.bind_values) + return LetStmt(bind_vars, bind_values, updated_stmt) + + return wrapped_visitor + + +class LetExprExpander(StmtExprRewriter): + + def __init__(self): + super().__init__() + self.stmt_stack = [] + + def expand(self, stmt): + assert isinstance(stmt, Stmt) + return self.visit(stmt) + + def visit_Let(self, e: Let): + var = self(e.var) + value = self(e.value) + self.stmt_stack[-1].append(LetStmt(var, value)) + return self(e.body) + + @wrapper + def visit_EvaluateStmt(self, stmt: EvaluateStmt): + return StmtExprRewriter.visit_EvaluateStmt(self, stmt) + + @wrapper + def visit_BufferStoreStmt(self, stmt: BufferStoreStmt): + return StmtExprRewriter.visit_BufferStoreStmt(self, stmt) + + @wrapper + def visit_AssignStmt(self, stmt: AssignStmt): + return StmtExprRewriter.visit_AssignStmt(self, stmt) + + @wrapper + def visit_LetStmt(self, stmt: LetStmt): + return StmtExprRewriter.visit_LetStmt(self, stmt) + + @wrapper + def visit_ForStmt(self, stmt: ForStmt): + return StmtExprRewriter.visit_ForStmt(self, stmt) + + @wrapper + def visit_IfStmt(self, stmt: IfStmt): + return StmtExprRewriter.visit_IfStmt(self, stmt) + + @wrapper + def visit_AssertStmt(self, stmt: AssertStmt): + return StmtExprRewriter.visit_AssertStmt(self, stmt) + + @wrapper + def visit_AsmStmt(self, stmt: AsmStmt): + return StmtExprRewriter.visit_AsmStmt(self, stmt) + + @wrapper + def visit_BlackBoxStmt(self, stmt: BlackBoxStmt): + return StmtExprRewriter.visit_BlackBoxStmt(self, stmt) + + +class ExpandLetExprPass(FunctionBodyPass): + def process_body(self, stmt: Stmt) -> Stmt: + expander = LetExprExpander() + stmt = expander.expand(stmt) + return stmt + + +def expand_let_expr_pass() -> Pass: + return ExpandLetExprPass() + diff --git a/python/hidet/transforms/explicit_unroll_for_stmt.py b/python/hidet/transforms/explicit_unroll_for_stmt.py new file mode 100644 index 0000000..9ec4186 --- /dev/null +++ b/python/hidet/transforms/explicit_unroll_for_stmt.py @@ -0,0 +1,39 @@ +from hidet.ir.expr import Constant, convert +from hidet.ir.stmt import Stmt, ForStmt, SeqStmt +from hidet.ir.functors import StmtRewriter, rewrite, clone +from hidet.transforms.base import FunctionBodyPass +from hidet.transforms.rule_based_simplifier import ConstExprSimplifier + + +class ExplicitUnrollForStmtRewriter(StmtRewriter): + _unroll_threshold = 16 + + def __init__(self): + super().__init__() + self.const_expr_simplifier = ConstExprSimplifier() + + def visit_ForStmt(self, stmt: ForStmt): + extent = self.const_expr_simplifier(self.visit_expr(stmt.extent)) + body = self(stmt.body) + if isinstance(extent, Constant) and isinstance(extent.value, int) and extent.value <= self._unroll_threshold: + unrolled_body = [] + for i in range(extent.value): + unrolled_body.append(clone(rewrite(body, {stmt.loop_var: convert(i)}))) + return SeqStmt(seq=unrolled_body) + else: + if extent is stmt.extent and body is stmt.body: + return stmt + else: + return ForStmt(stmt.loop_var, extent, stmt.unroll, body) + + +class ExplicitUnrollForStmtPass(FunctionBodyPass): + def process_body(self, stmt: Stmt) -> Stmt: + rewriter = ExplicitUnrollForStmtRewriter() + ret = rewriter(stmt) + # print(ret) + return ret + + +def explicit_unroll_for_stmt_pass(): + return ExplicitUnrollForStmtPass() diff --git a/python/hidet/transforms/flatten_tensor_index.py b/python/hidet/transforms/flatten_tensor_index.py new file mode 100644 index 0000000..c4f3b02 --- /dev/null +++ b/python/hidet/transforms/flatten_tensor_index.py @@ -0,0 +1,71 @@ +from typing import List, Union, Callable, Any +from hidet.ir.type import TensorType, tensor_type +from hidet.ir.expr import Var, TensorElement, TensorSlice, Constant +from hidet.ir.stmt import BufferStoreStmt +from hidet.ir.func import Function +from hidet.ir.functors import simplify_to_int, FuncStmtExprRewriter +from hidet.ir.dialects.lowlevel import PointerType, TensorPointerType +from hidet.transforms import Pass +from hidet.ir.layout import StridesLayout, DataLayout + + +class FlattenTensorAccessRewriter(FuncStmtExprRewriter): + # flatten all high-dimension tensor access + # A = int[3, 4] + # TensorElement: A[2, 1] ==> A[2 * 4 + 1] + # BufferStoreStmt: A[2, 1] = 3 ==> A[2 * 4 + 1] = 3 + def visit_Function(self, func: Function): + const_local_vars = [v for v, _ in func.local_const_vars] + for var in func.params + func.local_vars + const_local_vars: + if isinstance(var.type, TensorType): + size = simplify_to_int(var.type.layout.size) + self.memo[var] = Var(var.hint, tensor_type(var.type.scope, var.type.scalar_type, [size], DataLayout.row_major([size]))) + elif isinstance(var.type, TensorPointerType): + self.memo[var] = var + body = self(func.body) + params = [self(p) for p in func.params] + local_vars = [self(v) for v in func.local_vars] + local_const_vars = [(self(v), value) for v, value in func.local_const_vars] + return Function(func.name, params, body, func.ret_type, kind=func.kind, local_vars=local_vars, + local_const_vars=local_const_vars, extern_vars=func.extern_vars, attrs=func.attrs) + + @staticmethod + def get_layout(e) -> Callable[..., Any]: + if isinstance(e, Var): + if isinstance(e.type, TensorType): + return e.type.layout + elif isinstance(e.type, TensorPointerType): + return e.type.tensor_type.layout + elif isinstance(e.type, PointerType): + return StridesLayout(shape=[0], strides=[1]) + elif isinstance(e, Constant) and isinstance(e.data_type, TensorType): + return e.data_type.layout + raise ValueError("Can not infer layout from '{}'".format(type(e))) + + def visit_TensorElement(self, e: TensorElement): + var = self(e.base) + indices = [self(i) for i in e.indices] + layout = self.get_layout(e.base) + global_index = layout(*indices) + return TensorElement(var, [global_index]) + + def visit_BufferStoreStmt(self, stmt: BufferStoreStmt): + var = self(stmt.buf) + indices = [self(i) for i in stmt.indices] + value = self(stmt.value) + layout = self.get_layout(stmt.buf) + global_index = layout(indices) + return BufferStoreStmt(var, [global_index], value) + + def visit_TensorSlice(self, e: TensorSlice): + raise ValueError('there should not be any tensor slice after flattening tensor slice. got\n{}'.format(e)) + + +class FlattenTensorIndexPass(Pass): + def process_func(self, func: Function) -> Function: + flatten_index = FlattenTensorAccessRewriter() + return flatten_index(func) + + +def flatten_tensor_index_pass(): + return FlattenTensorIndexPass() diff --git a/python/hidet/transforms/flatten_tensor_slice.py b/python/hidet/transforms/flatten_tensor_slice.py new file mode 100644 index 0000000..53d9d55 --- /dev/null +++ b/python/hidet/transforms/flatten_tensor_slice.py @@ -0,0 +1,86 @@ +from hidet.ir.expr import TensorElement, TensorSlice +from hidet.ir.expr import TensorElement, TensorSlice +from hidet.ir.func import Function +from hidet.ir.functors import FuncStmtExprRewriter +from hidet.ir.stmt import BufferStoreStmt +from hidet.transforms import Pass + + +def concat_slices(lhs_indices, lhs_starts, lhs_ends, rhs_indices, rhs_starts=None, rhs_ends=None): + if rhs_starts is None: + rhs_starts = [None] * len(rhs_indices) + if rhs_ends is None: + rhs_ends = [None] * len(rhs_indices) + assert len(lhs_indices) == len(lhs_starts) == len(lhs_ends) + assert len(rhs_indices) == len(rhs_starts) == len(rhs_ends) + indices = [] + starts = [] + ends = [] + i = 0 + for index, start, end in zip(lhs_indices, lhs_starts, lhs_ends): + if index is not None: + indices.append(index) + starts.append(None) + ends.append(None) + else: + assert i < len(rhs_indices) + if rhs_indices[i] is not None: + indices.append(start + rhs_indices[i] if start else rhs_indices[i]) + starts.append(None) + elif rhs_starts[i] is not None: + indices.append(None) + starts.append(start + rhs_starts[i] if start else rhs_starts[i]) + else: + indices.append(None) + starts.append(None) + # we ignore the end because we do not allow tensor-wise op. + # end is only used for bound-checking, which is left in future. + ends.append(None) + i += 1 + assert i == len(rhs_indices) + return indices, starts, ends + + +class FlattenTensorSliceRewriter(FuncStmtExprRewriter): + # eliminate all TensorSlice + # (A[:, 3])[2] will be converted to A[2, 3] and the slice op A[:, 3] will be eliminated. + def visit_TensorSlice(self, e: TensorSlice): + base = self.visit(e.base) + if isinstance(base, TensorSlice): + e_indices = [self.visit(i) if i else None for i in e.indices] + e_starts = [self.visit(s) if s else None for s in e.starts] + e_ends = [self.visit(e) if e else None for e in e.ends] + indices, starts, ends = concat_slices(base.indices, base.starts, base.ends, e_indices, e_starts, e_ends) + return TensorSlice(base.base, indices, starts, ends) + else: + return FuncStmtExprRewriter.visit_TensorSlice(self, e) + + def visit_TensorElement(self, e: TensorElement): + base = self.visit(e.base) + if isinstance(base, TensorSlice): + e_indices = [self.visit(idx) for idx in e.indices] + indices, starts, ends = concat_slices(base.indices, base.starts, base.ends, e_indices) + assert not any(idx is None for idx in indices) + return TensorElement(base.base, indices) + else: + return FuncStmtExprRewriter.visit_TensorElement(self, e) + + def visit_BufferStoreStmt(self, stmt: BufferStoreStmt): + base = self.visit(stmt.buf) + stmt_indices = [self.visit(idx) for idx in stmt.indices] + if isinstance(base, TensorSlice): + indices, starts, ends = concat_slices(base.indices, base.starts, base.ends, stmt_indices) + assert not any(idx is None for idx in indices) + return BufferStoreStmt(base.base, indices, self.visit(stmt.value)) + else: + return FuncStmtExprRewriter.visit_BufferStoreStmt(self, stmt) + + +class FlattenTensorSlicePass(Pass): + def process_func(self, func: Function) -> Function: + flatten_slice = FlattenTensorSliceRewriter() + return flatten_slice(func) + + +def flatten_tensor_slice_pass() -> Pass: + return FlattenTensorSlicePass() diff --git a/python/hidet/transforms/generate_packed_func.py b/python/hidet/transforms/generate_packed_func.py new file mode 100644 index 0000000..b36e4ab --- /dev/null +++ b/python/hidet/transforms/generate_packed_func.py @@ -0,0 +1,79 @@ +from typing import Optional +from hidet.ffi import ArgType +from hidet.ir.type import ScalarType, TensorType +from hidet.ir.expr import Var, Call, Equal, Cast +from hidet.ir.stmt import AssertStmt, SeqStmt, EvaluateStmt +from hidet.ir.func import IRModule, Function +from hidet.ir.functors import astext, simplify_to_int +from hidet.ir.dialects.lowlevel import VoidType, PointerType, Dereference, TensorPointerType +# from hidet.ir.task import Grid, Host +from hidet.ir.builders import FunctionBuilder, StmtBuilder +from hidet.transforms import Pass +from hidet.ir.primitives import set_kernel_max_dynamic_smem_bytes + + +class GeneratePackedFuncPass(Pass): + def process_module(self, ir_module: IRModule) -> IRModule: + new_ir_module = IRModule(task=ir_module.task) + for func in ir_module.functions.values(): + new_ir_module.add(func.name, func) + if func.kind not in ['cuda_kernel', 'host_kernel']: + # only generate packed func for entry function + continue + if func.get_attr('packed_func', None) is not None: + # this function itself is a packed function + continue + if any(f.get_attr('packed_func', None) is func for f in ir_module.functions.values()): + # the packed function for current function has existed, skip + continue + packed_func = self.generate_packed_func(func, ir_module.lookup_var(func.name)) + new_ir_module.add(packed_func.name, packed_func) + return new_ir_module + + def generate_packed_func(self, func: Function, func_global_var: Var) -> Function: + assert isinstance(func.ret_type, VoidType) + assert isinstance(func.name, str) and (func.name.endswith('_grid') or func.name.endswith('_host')) + packed_name = func.name[:-5] + with FunctionBuilder(name=packed_name, kind='packed_func', attrs={'packed_func': func_global_var}) as fb: + # params + p_num_args = Var('num_args', ScalarType('int32')) + p_arg_types = Var('arg_types', PointerType(ScalarType('int32'))) + p_args = Var('args', PointerType(PointerType(VoidType()))) + fb.extend_params([p_num_args, p_arg_types, p_args]) + + # body + sb = StmtBuilder() + sb += AssertStmt(Equal(p_num_args, len(func.params)), "expect {} args".format(len(func.params))) + func_args = [] + for idx, param in enumerate(func.params): + assert isinstance(param, Var) + if isinstance(param.type, ScalarType): + if param.type.name == 'int32': + code = ArgType.INT32 + elif param.type.name == 'float32': + code = ArgType.FLOAT32 + else: + raise NotImplementedError() + func_args.append(Dereference(Cast(p_args[idx], PointerType(param.type)))) + elif isinstance(param.type, (TensorPointerType, TensorType)): + code = ArgType.POINTER + if isinstance(param.type, TensorType): + dtype = param.type.scalar_type + else: + dtype = param.type.tensor_type.scalar_type + func_args.append(Cast(p_args[idx], PointerType(dtype))) + elif isinstance(param.type, PointerType): + code = ArgType.POINTER + func_args.append(Cast(p_args[idx], param.type)) + else: + raise NotImplementedError() + sb += AssertStmt(Equal(p_arg_types[idx], code), "The {} th arg should be {}".format(idx, astext(param.type))) + if func.kind == 'cuda_kernel' and func.get_attr('cuda_dynamic_smem_bytes', 0) > 48 * 1024: + sb += set_kernel_max_dynamic_smem_bytes(func_global_var, func.attrs['cuda_dynamic_smem_bytes']) + sb += Call(func_global_var, func_args) + fb.set_body(sb.finish()) + return fb.get() + + +def generate_packed_func_pass(): + return GeneratePackedFuncPass() diff --git a/python/hidet/transforms/import_primitive_functions.py b/python/hidet/transforms/import_primitive_functions.py new file mode 100644 index 0000000..99f3008 --- /dev/null +++ b/python/hidet/transforms/import_primitive_functions.py @@ -0,0 +1,38 @@ +from typing import List +from hidet.ir.expr import Call +from hidet.ir.func import IRModule, Function +from hidet.ir.functors import collect +from hidet.ir.primitives import is_primitive_function, lookup_primitive_function +from hidet.transforms import Pass + + +class ImportPrimitiveFunctionPass(Pass): + def process_module(self, ir_module: IRModule) -> IRModule: + used_primitive_funcs = set() + for func in ir_module.functions.values(): + calls: List[Call] = collect(func.body, Call) + for call in calls: + callee_name: str = call.func_var.hint + if is_primitive_function(callee_name): + used_primitive_funcs.add(callee_name) + + primitive_funcs: List[Function] = [] + for func_name in used_primitive_funcs: + entry = lookup_primitive_function(func_name) + if entry.function is not None: + primitive_funcs.append(entry.function) + + if len(primitive_funcs) == 0: + return ir_module + else: + new_ir_module = IRModule(task=ir_module.task) + for func_name, func in ir_module.functions.items(): + new_ir_module.add(func_name, func) + for func in primitive_funcs: + if func.name not in new_ir_module.functions: + new_ir_module.add(func.name, func) + return new_ir_module + + +def import_primitive_functions_pass() -> Pass: + return ImportPrimitiveFunctionPass() diff --git a/python/hidet/transforms/inline_let_stmt.py b/python/hidet/transforms/inline_let_stmt.py new file mode 100644 index 0000000..a8d885b --- /dev/null +++ b/python/hidet/transforms/inline_let_stmt.py @@ -0,0 +1,103 @@ +from typing import Mapping +from collections import defaultdict + +from hidet.ir.expr import Var, Expr, Constant, Mod, Add, Sub +from hidet.ir.functors import StmtExprRewriter, StmtExprVisitor, rewrite, same_list +from hidet.ir.stmt import Stmt, LetStmt +from hidet.transforms import Pass, FunctionBodyPass, RepeatFunctionPass + + +class LetVarRefAnalyzer(StmtExprVisitor): + def __init__(self): + super().__init__(use_memo=False) + self.usage_count = None + self.var2value = None + + def analyze(self, expr): + self.usage_count = defaultdict(int) + self.var2value = {} + self.visit(expr) + + def visit(self, obj): + if isinstance(obj, Var): + self.usage_count[obj] += 1 + return StmtExprVisitor.visit(self, obj) + + def visit_LetStmt(self, stmt: LetStmt): + for bind_var, bind_value in zip(stmt.bind_vars, stmt.bind_values): + self.var2value[bind_var] = bind_value + self.visit(bind_value) + self.visit(stmt.body) + + +class NaiveLetStmtInlineRewriter(StmtExprRewriter): + def __init__(self, inline_factor=1, inline_all=False): + super().__init__() + self.inline_factor = inline_factor + self.inline_all = inline_all + self.usage_count = None + self.var2value = None + + def eliminate(self, stmt): + self.memo.clear() + # count the usage number and let var to its value + analyzer = LetVarRefAnalyzer() + analyzer.analyze(stmt) + self.usage_count, self.var2value = analyzer.usage_count, analyzer.var2value + # inline + return self.visit(stmt) + + def should_inline(self, var, expr) -> bool: + if isinstance(expr, (Var, Constant)): + # let v1 = v2 + # let v1 = constant + return True + elif self.usage_count[var] <= self.inline_factor or self.inline_all: + # let v1 = expr and v1 is only used with in self.inline_factor times + return True + elif isinstance(expr, (Add, Sub)) and (isinstance(expr.a, Constant) or isinstance(expr.b, Constant)): + # let v1 = expr + constant + return True + return False + + def visit_LetStmt(self, stmt: LetStmt): + bind_vars = [] + bind_values = [] + for bind_var, bind_value in zip(stmt.bind_vars, stmt.bind_values): + updated_value = self(bind_value) + if self.should_inline(bind_var, updated_value): + self.memo[bind_var] = updated_value + else: + bind_vars.append(bind_var) + bind_values.append(updated_value) + body = self(stmt.body) + if same_list(bind_vars, stmt.bind_vars) and same_list(bind_values, stmt.bind_values) and body is stmt.body: + return stmt + else: + if len(bind_vars) > 0: + return LetStmt(bind_vars, bind_values, body) + else: + return body + + +class InlineNaiveLetStmtPass(FunctionBodyPass): + def __init__(self, inline_factor=1, inline_all=False): + super().__init__() + self.inline_factor = inline_factor + self.inline_all = inline_all + + def process_body(self, stmt: Stmt) -> Stmt: + eliminator = NaiveLetStmtInlineRewriter(self.inline_factor, self.inline_all) + return eliminator.eliminate(stmt) + + +def inline_let_stmt_pass(inline_factor=1, inline_all=False) -> Pass: + if inline_all: + return InlineNaiveLetStmtPass(inline_factor, inline_all) + else: + return RepeatFunctionPass( + name='InlineLetStmtPass', + passes=[ + InlineNaiveLetStmtPass(inline_factor, inline_all) + ], + repeat_limit=10) diff --git a/python/hidet/transforms/instruments/__init__.py b/python/hidet/transforms/instruments/__init__.py new file mode 100644 index 0000000..bfc8f1c --- /dev/null +++ b/python/hidet/transforms/instruments/__init__.py @@ -0,0 +1,3 @@ +from .base import PassInstrument +from .profile_instrument import ProfileInstrument +from .save_ir_instrument import SaveIRInstrument diff --git a/python/hidet/transforms/instruments/base.py b/python/hidet/transforms/instruments/base.py new file mode 100644 index 0000000..0473940 --- /dev/null +++ b/python/hidet/transforms/instruments/base.py @@ -0,0 +1,17 @@ +from hidet.ir.func import IRModule + + +class PassInstrument: + def before_all_passes(self, ir_module: IRModule): + pass + + def before_pass(self, pass_name: str, ir_module: IRModule): + pass + + def after_pass(self, pass_name: str, ir_module: IRModule): + pass + + def after_all_passes(self, ir_module: IRModule): + pass + + diff --git a/python/hidet/transforms/instruments/profile_instrument.py b/python/hidet/transforms/instruments/profile_instrument.py new file mode 100644 index 0000000..398934e --- /dev/null +++ b/python/hidet/transforms/instruments/profile_instrument.py @@ -0,0 +1,37 @@ +import os +import time +from typing import Optional, Dict + +from hidet import utils +from hidet.ir.func import IRModule + +from .base import PassInstrument + + +class ProfileInstrument(PassInstrument): + def __init__(self, log_file: Optional[str] = None, print_stdout: bool = False): + if log_file: + dirname = os.path.dirname(log_file) + os.makedirs(dirname, exist_ok=True) + self.log_file = log_file + self.print_stdout = print_stdout + self.start_time: Dict[str, float] = {} + + def before_all_passes(self, ir_module: IRModule): + if self.log_file: + # clear file contents + with open(self.log_file, 'w'): + pass + + def before_pass(self, pass_name: str, ir_module: IRModule): + self.start_time[pass_name] = time.time() + if self.print_stdout: + print('{:>50} started...'.format(pass_name)) + + def after_pass(self, pass_name: str, ir_module: IRModule): + elapsed_time = time.time() - self.start_time[pass_name] + if self.log_file: + with open(self.log_file, 'a') as f: + f.write('{:>50} {:.3f} seconds\n'.format(pass_name, elapsed_time)) + if self.print_stdout: + print('{:>50} {} seconds'.format(pass_name, utils.py.green(elapsed_time, '{:.3f}'))) diff --git a/python/hidet/transforms/instruments/save_ir_instrument.py b/python/hidet/transforms/instruments/save_ir_instrument.py new file mode 100644 index 0000000..1f039a5 --- /dev/null +++ b/python/hidet/transforms/instruments/save_ir_instrument.py @@ -0,0 +1,27 @@ +import os + +from hidet.ir.func import IRModule +from .base import PassInstrument + + +class SaveIRInstrument(PassInstrument): + def __init__(self, out_dir: str): + self.out_dir = out_dir + self.index = 0 + os.makedirs(out_dir, exist_ok=True) + + def before_all_passes(self, ir_module: IRModule): + # first clean all json starting with indices + for fname in os.listdir(self.out_dir): + fpath = os.path.join(self.out_dir, fname) + parts = fname.split('_') + if os.path.isfile(fpath) and len(parts) > 1 and parts[0].isdigit() and fname.endswith('.txt'): + os.remove(fpath) + with open(os.path.join(self.out_dir, '0_Origin.txt'), 'w') as f: + f.write(str(ir_module)) + self.index += 1 + + def after_pass(self, pass_name: str, ir_module: IRModule): + with open(os.path.join(self.out_dir, '{}_{}.txt'.format(self.index, pass_name)), 'w') as f: + f.write(str(ir_module)) + self.index += 1 diff --git a/python/hidet/transforms/normalize_const_tensor.py b/python/hidet/transforms/normalize_const_tensor.py new file mode 100644 index 0000000..5eba357 --- /dev/null +++ b/python/hidet/transforms/normalize_const_tensor.py @@ -0,0 +1,25 @@ +from typing import List +from hidet.ir import Function, Constant, TensorType, Var, var, tensor_var +from hidet.transforms.base import FunctionPass, Pass +from hidet.ir.functors import collect, rewrite + + +class NormalizeConstTensorPass(FunctionPass): + def process_func(self, func: Function) -> Function: + consts: List[Constant] = collect(func.body, Constant) + tensor_consts = [const for const in consts if isinstance(const.data_type, TensorType)] + body = func.body + local_const_vars = func.local_const_vars + for tensor_const in tensor_consts: + pair = (Var('const', tensor_const.data_type), tensor_const) + body = rewrite(body, {pair[1]: pair[0]}) + local_const_vars.append(pair) + if body is func.body: + return func + else: + return Function(func.name, func.params, body, func.ret_type, kind=func.kind, local_vars=func.local_vars, local_const_vars=local_const_vars, + extern_vars=func.extern_vars, attrs=func.attrs) + + +def normalize_const_tensor_pass() -> Pass: + return NormalizeConstTensorPass() diff --git a/python/hidet/transforms/precompute_condition.py b/python/hidet/transforms/precompute_condition.py new file mode 100644 index 0000000..37bdf61 --- /dev/null +++ b/python/hidet/transforms/precompute_condition.py @@ -0,0 +1,56 @@ +from hidet.ir.expr import IfThenElse +from hidet.ir.stmt import Stmt, IfStmt, ForStmt +from hidet.ir.func import Function + +from .base import FunctionPass +from .common import FuncStmtExprRewriterWithScope + + +class PrecomputeConditionRewriter(FuncStmtExprRewriterWithScope): + def __init__(self): + super().__init__(use_memo=False) + + @staticmethod + def scope_in_loop(scope) -> bool: + while scope is not None: + if isinstance(scope.scope_stmt, ForStmt): + return True + scope = scope.parent + return False + + def should_precompute(self, cond) -> bool: + """ + we only precompute when the following two conditions holds: + 1. current scope is in a loop. + 2. the used variables are defined not in any loop + """ + return self.scope_in_loop(self.scope_stack.current()) and not self.scope_in_loop(self.scope_to_define(cond)) + + def visit_IfStmt(self, stmt: IfStmt): + if self.should_precompute(stmt.cond): + # we can precompute the predicate + scope = self.scope_to_define(stmt.cond) + cond = scope.define_predicate(stmt.cond) + then_body = self.visit(stmt.then_body) + else_body = self.visit(stmt.else_body) if stmt.else_body else None + return IfStmt(cond, then_body, else_body) + else: + return FuncStmtExprRewriterWithScope.visit_IfStmt(self, stmt) + + def visit_IfThenElse(self, e: IfThenElse): + if self.should_precompute(e.cond): + scope = self.scope_to_define(e.cond) + cond = scope.define_predicate(e.cond) + return IfThenElse(cond, e.then_expr, e.else_expr) + else: + return FuncStmtExprRewriterWithScope.visit_IfThenElse(self, e) + + +class PrecomputeConditionPass(FunctionPass): + def process_func(self, func: Function) -> Function: + rewriter = PrecomputeConditionRewriter() + return rewriter(func) + + +def precompute_condition_pass(): + return PrecomputeConditionPass() diff --git a/python/hidet/transforms/resolve_generic_primitive_function.py b/python/hidet/transforms/resolve_generic_primitive_function.py new file mode 100644 index 0000000..a86321d --- /dev/null +++ b/python/hidet/transforms/resolve_generic_primitive_function.py @@ -0,0 +1,73 @@ +from typing import List + +import hidet.ir.primitives.base.funcs +from hidet.ir.type import ScalarType +from hidet.ir.stmt import Stmt +from hidet.ir.expr import Call, Expr, Add, Sub, Multiply, Div, BinaryOp, cast +from hidet.ir.func import IRModule, Function +from hidet.ir.functors import collect, StmtExprRewriter, infer_type, TypeInfer +from hidet.ir.primitives import is_primitive_function, lookup_primitive_function +from hidet.transforms import Pass, FunctionBodyPass +from hidet.utils.py import green + + +def resolve_dtype(arg_dtypes: List[ScalarType]) -> ScalarType: + return hidet.ir.primitives.base.funcs.type_infer_func(arg_dtypes) + + +def cast_args(args: List[Expr], arg_dtypes: List[ScalarType], target_dtype: ScalarType) -> List[Expr]: + casted_args = [] + for arg, arg_dtype in zip(args, arg_dtypes): + if arg_dtype.name != target_dtype.name: + casted_args.append(cast(arg, target_dtype)) + else: + casted_args.append(arg) + return casted_args + + +class ResolveGenericPrimitiveFuncRewriter(StmtExprRewriter): + def __init__(self): + super().__init__() + self.type_infer = TypeInfer() + + def visit_Call(self, e: Call): + if is_primitive_function(e.func_var.hint): + entry = lookup_primitive_function(e.func_var.hint) + if entry.generic: + args = [self(arg) for arg in e.args] + arg_types = [infer_type(arg) for arg in args] + resolved_dtype = resolve_dtype(arg_types) + if resolved_dtype.name not in entry.dispatch_dtype_rules: + msg = 'Can not dispatch generic primitive function {} to dtype {}'.format(green(entry.name), green(resolved_dtype)) + raise NotImplementedError(msg) + dispatched_func_key = entry.dispatch_dtype_rules[resolved_dtype.name] + dispatched_func_entry = lookup_primitive_function(key=dispatched_func_key) + casted_args = cast_args(args, arg_types, resolved_dtype) + return Call(dispatched_func_entry.var, casted_args) + + return StmtExprRewriter.visit_Call(self, e) + + def visit_Binary(self, e: BinaryOp): + lhs = self.visit(e.a) + rhs = self.visit(e.b) + lhs_dtype = self.type_infer(lhs) + rhs_dtype = self.type_infer(rhs) + if lhs_dtype.name != rhs_dtype.name: + dtype = resolve_dtype([lhs_dtype, rhs_dtype]) + lhs, rhs = cast_args([lhs, rhs], [lhs_dtype, rhs_dtype], dtype) + if lhs is e.a and rhs is e.b: + return e + else: + return e.__class__(lhs, rhs) + else: + return StmtExprRewriter.visit_Binary(self, e) + + +class ResolveGenericPrimitiveFuncPass(FunctionBodyPass): + def process_body(self, stmt: Stmt) -> Stmt: + rewriter = ResolveGenericPrimitiveFuncRewriter() + return rewriter.visit(stmt) + + +def resolve_primitive_func_pass(): + return ResolveGenericPrimitiveFuncPass() diff --git a/python/hidet/transforms/rule_based_simplifier.py b/python/hidet/transforms/rule_based_simplifier.py new file mode 100644 index 0000000..f7458af --- /dev/null +++ b/python/hidet/transforms/rule_based_simplifier.py @@ -0,0 +1,249 @@ +import operator +from typing import Dict, Optional, List +from itertools import product + +from hidet.ir.dialects.pattern import AnyExpr, match +from hidet.ir.expr import Add, convert, Sub, Multiply, FloorDiv, Mod, LessThan, LessEqual, Equal, BinaryOp, And, IfThenElse, Or, Div, Constant +from hidet.ir.expr import Constant, Expr, Var, Cast, cast +from hidet.ir.functors import FuncStmtExprRewriter +from hidet.ir.functors import StmtExprRewriter, ExprVisitor +from hidet.ir.functors import rewrite, ExprHash +from hidet.transforms.base import FunctionPass +from hidet.utils import prod, repeat_until_converge +from hidet.ir.stmt import LetStmt, ForStmt +from hidet.ir.func import Function +from hidet.ir.analyzers import BoundAnalyzer, BoundInfo + + +def any_expr(allow_const): + if allow_const: + return AnyExpr() + else: + return AnyExpr(exclude_cls=Constant) + + +def any_constant(): + return Constant(value=None) + + +def c_div(a, b): + if isinstance(a, int) and isinstance(b, int): + return a // b + else: + return a / b + + +class ConstExprSimplifier(StmtExprRewriter): + op_dict = { + Add: operator.add, + Sub: operator.sub, + Multiply: operator.mul, + Div: c_div, + Mod: operator.mod, + LessThan: operator.lt, + LessEqual: operator.le, + Equal: operator.eq, + } + + def visit_Binary(self, e: BinaryOp): + e = StmtExprRewriter.visit_Binary(self, e) + if e.a.is_const() and e.b.is_const() and e.__class__ in self.op_dict: + assert isinstance(e.a, Constant) and isinstance(e.b, Constant) + op = self.op_dict[e.__class__] + c = op(e.a.const().value, e.b.const().value) + if isinstance(c, bool): + return Constant(c, 'bool') + else: + return Constant(c, max(e.a.data_type, e.b.data_type)) + return e + + def visit_And(self, e: And): + e = StmtExprRewriter.visit_Binary(self, e) + a_val = e.a.const().value if e.a.is_const() else None + b_val = e.b.const().value if e.b.is_const() else None + if a_val and b_val: + return convert(True) + elif a_val is False or b_val is False: + return convert(False) + elif a_val: + return e.b + elif b_val: + return e.a + else: + return e + + +class RuleBasedSimplifier(FuncStmtExprRewriter): + _enumerate_limit = 256 + + def __init__(self): + super().__init__() + self.analyzer = BoundAnalyzer() + self.bound: Dict[Expr, BoundInfo] = self.analyzer.bound + self.const_expr_simplifier = ConstExprSimplifier() + e1, e2 = any_expr(allow_const=False), any_expr(allow_const=False) + c1, c2 = any_constant(), any_constant() + ec1, ec2 = any_expr(allow_const=True), any_expr(allow_const=True) + zero = convert(0) + one = convert(1) + self.args = {e1, e2, c1, c2, ec1, ec2} + self.patterns = [ + (e1 + zero, e1), + (e1 - zero, e1), + (e1 * one, e1), + (e1 * zero, zero), + (e1 // one, e1), + # add + ((c1 + e1) + e2, (e1 + e2) + c1), + ((e1 + c1) + c2, e1 + (c1 + c2)), + ((c1 - e1) + e2, (e2 - e1) + c1), + ((e1 - c1) + e2, (e1 + e2) - c1), + # sub + ((c1 + e1) - e2, (e1 - e2) + c1), + (e1 - (c1 + e2), (e1 - e2) - c1), + ((c1 - e1) - e2, c1 - (e1 + e2)), + ((e1 - c1) - e2, (e1 - e2) - c1), + (e1 - (c1 - e2), (e1 + e2) - c1), + (e1 - (e2 - c1), (e1 - e2) + c1), + ((e1 - c1) - c2, e1 - (c1 + c2)), + # mul + ((e1 + c1) * c2, c1 * c2 + e1 * c2), + ((c1 - e1) * c2, c1 * c2 - e1 * c2), + ((e1 - c1) * c2, e1 * c2 - c1 * c2), + ((e1 * c1) * c2, e1 * (c1 * c2)), + # div + (((e1 * c1) + (e2 % c1)) // c1, e1), + ((e1 // c1) // c2, e1 // (c1 * c2)), + ((e1 * c1) // c1, e1), + ((e1 * c1 + e2) // c1, e1 + e2 // c1), + # mod + ((e1 * c1 + e2) % c1, e2 % c1), + ((e1 % c1) % c1, e1 % c1), + # comparison + (e1 + c1 < c2, e1 < c2 - c1), + (e1 - c1 < c2, e1 < c1 + c2), + (c1 <= e1 - c2, c1 + c2 <= e1), + (c1 <= e1 + c2, c1 - c2 <= e1), + # and/or + (And(ec1, True), ec1), + (And(ec1, False), convert(False)), + (Or(ec1, True), convert(True)), + (Or(ec1, False), ec1), + # if then else + (IfThenElse(True, ec1, ec2), ec1), + (IfThenElse(False, ec1, ec2), ec2), + ] + self.bound_patterns = [ + # ((pattern_args, pattern_func, target_args, target_func) + ((ec1, ec2, c1), (ec1, ec2, c1), lambda ec1, ec2, c1: (ec1 + ec2) // c1, lambda ec1, ec2, c1: ec1 // c1 + ec2 // c1), + ((ec1, ec2, c1), (ec1, ec2, c1), lambda ec1, ec2, c1: (ec1 + ec2) % c1, lambda ec1, ec2, c1: ec1 % c1 + ec2 % c1), + ((ec1, c1), (ec1,), lambda ec1, c1: ec1 % c1, lambda ec1: ec1), + ((ec1, c1, c2), (ec1, c2), lambda ec1, c1, c2: (ec1 % c1) % c2, lambda ec1, c2: ec1 % c2) + ] + + def apply_rule(self, e): + for idx, (pattern, target) in enumerate(self.patterns): + if pattern.__class__ is not e.__class__: + continue + mapping, msg = match(pattern, e) + if mapping: + # print('apply rule ', pattern, target, 'on', e) + mapping = {a: b for a, b in mapping.items() if a in self.args} + ret = rewrite(target, rewrite_map=mapping) + return ret + return e + + def apply_bound_aware_rule(self, e): + for idx, (pattern_args, target_args, pattern_func, target_func) in enumerate(self.bound_patterns): + pattern = pattern_func(*pattern_args) + if pattern.__class__ is not e.__class__: + continue + mapping, msg = match(pattern, e) + if mapping: + mapping = {a: b for a, b in mapping.items() if a in self.args} + self.analyzer(e) + arg_candidates = {arg: self.bound[mapping[arg]].candidate_set() for arg in mapping.keys()} + if any(can_set is None for can_set in arg_candidates.values()): + continue + if prod([len(can_set) for can_set in arg_candidates.values()]) > self._enumerate_limit: + continue + sorted_can_sets = [] + for pattern_arg in pattern_args: + sorted_can_sets.append(arg_candidates[pattern_arg]) + target_arg_index = [] + for target_arg in target_args: + for i in range(len(pattern_args)): + if pattern_args[i] is target_arg: + target_arg_index.append(i) + break + for args in product(*sorted_can_sets): + t_args = [args[i] for i in target_arg_index] + if pattern_func(*args) != target_func(*t_args): + break + else: + target = target_func(*target_args) + ret = rewrite(target, rewrite_map=mapping) + return ret + return e + + def visit(self, obj): + if obj in self.memo: + return self.memo[obj] + self.analyzer(obj) + if obj in self.bound and self.bound[obj].value is not None and not isinstance(obj, Constant): + return convert(self.bound[obj].value) + cur = FuncStmtExprRewriter.visit(self, obj) + if isinstance(cur, Expr): + while True: + orig_obj = cur + cur = self.apply_rule(cur) + cur = self.const_expr_simplifier(cur) + cur = self.apply_bound_aware_rule(cur) + cur = self.const_expr_simplifier(cur) + if orig_obj is cur: + break + self.memo[obj] = cur + return cur + + def visit_Mod(self, e: Mod): + ua, ub = self.bound[e.a], self.bound[e.b] + if ua.is_zero() or ua < ub: + return self(e.a) + return FuncStmtExprRewriter.visit_Mod(self, e) + + def visit_LessThan(self, e: LessThan): + ua, ub = self.bound[e.a], self.bound[e.b] + if ua < ub: + return convert(True) + if ub <= ua: + return convert(False) + return FuncStmtExprRewriter.visit_LessThan(self, e) + + def visit_LessEqual(self, e: LessEqual): + ua, ub = self.bound[e.a], self.bound[e.b] + if ua <= ub: + return convert(True) + if ub < ua: + return convert(False) + return FuncStmtExprRewriter.visit_LessEqual(self, e) + + def visit_Equal(self, e: Equal): + ua, ub = self.bound[e.a], self.bound[e.b] + if ua <= ub <= ua: + return convert(True) + if ua < ub or ub < ua: + return convert(False) + return FuncStmtExprRewriter.visit_Equal(self, e) + + def visit_Function(self, func: Function): + return FuncStmtExprRewriter.visit_Function(self, func) + + +class RuleBasedSimplifyPass(FunctionPass): + def process_func(self, func: Function) -> Function: + simplifier = RuleBasedSimplifier() + return repeat_until_converge(simplifier, func) + + +def rule_based_simplify_pass(): + return RuleBasedSimplifyPass() diff --git a/python/hidet/transforms/simplify_stmt.py b/python/hidet/transforms/simplify_stmt.py new file mode 100644 index 0000000..6fc5647 --- /dev/null +++ b/python/hidet/transforms/simplify_stmt.py @@ -0,0 +1,37 @@ +from hidet.ir import Stmt +from hidet.ir.expr import is_one, is_zero, is_true, is_false, convert +from hidet.ir.stmt import IfStmt, ForStmt, SeqStmt +from hidet.ir.functors import StmtExprRewriter +from hidet.transforms.base import FunctionBodyPass + + +class StatementSimplifier(StmtExprRewriter): + def visit_IfStmt(self, stmt: IfStmt): + if is_true(stmt.cond): + then_body = self(stmt.then_body) + return then_body + elif is_false(stmt.cond): + if stmt.else_body: + return self(stmt.else_body) + else: + return SeqStmt([]) + else: + return StmtExprRewriter.visit_IfStmt(self, stmt) + + def visit_ForStmt(self, stmt: ForStmt): + if is_zero(stmt.extent): + return SeqStmt([]) + elif is_one(stmt.extent): + self.memo[stmt.loop_var] = convert(0) + return self(stmt.body) + else: + return StmtExprRewriter.visit_ForStmt(self, stmt) + + +class SimplifyStmtPass(FunctionBodyPass): + def process_body(self, stmt: Stmt) -> Stmt: + return StatementSimplifier()(stmt) + + +def simplify_stmt_pass(): + return SimplifyStmtPass() diff --git a/python/hidet/transforms/squeeze_let_stmt.py b/python/hidet/transforms/squeeze_let_stmt.py new file mode 100644 index 0000000..1757487 --- /dev/null +++ b/python/hidet/transforms/squeeze_let_stmt.py @@ -0,0 +1,27 @@ +from hidet.ir.stmt import Stmt, LetStmt +from hidet.ir.functors import StmtRewriter, same_list +from .base import FunctionBodyPass + + +class SqueezeLetStmtRewriter(StmtRewriter): + def visit_LetStmt(self, stmt: LetStmt): + bind_vars = [] + bind_values = [] + cur = stmt + while isinstance(cur, LetStmt): + bind_vars.extend(cur.bind_vars) + bind_vars.extend(cur.bind_values) + cur = cur.body + if same_list(bind_vars, stmt.bind_vars) and same_list(bind_values, stmt.bind_values) and cur is stmt.body: + return stmt + else: + return LetStmt(bind_vars, bind_values, stmt) + + +class SqueezeLetStmtPass(FunctionBodyPass): + def process_body(self, stmt: Stmt) -> Stmt: + return SqueezeLetStmtRewriter()(stmt) + + +def squeeze_let_stmt_pass(): + return SqueezeLetStmtPass() diff --git a/python/hidet/transforms/uplift_let_stmt.py b/python/hidet/transforms/uplift_let_stmt.py new file mode 100644 index 0000000..b89e041 --- /dev/null +++ b/python/hidet/transforms/uplift_let_stmt.py @@ -0,0 +1,24 @@ +from hidet.ir.func import Function +from hidet.ir.stmt import LetStmt +from .base import FunctionPass +from .common import FuncStmtExprRewriterWithScope + + +class UpliftLetStmtRewriter(FuncStmtExprRewriterWithScope): + def visit_LetStmt(self, stmt: LetStmt): + with self.new_scope(stmt) as scope: + for var, value in zip(stmt.bind_vars, stmt.bind_values): + value = self.visit(value) + scope_to_define = self.scope_to_define(value) + scope_to_define.define(var, value) + return scope.wrap(self.visit(stmt.body)) + + +class UpliftLetStmtPass(FunctionPass): + def process_func(self, func: Function) -> Function: + rewriter = UpliftLetStmtRewriter() + return rewriter.visit(func) + + +def uplift_let_stmt_pass(): + return UpliftLetStmtPass() diff --git a/python/hidet/transforms/vectorize_load_store.py b/python/hidet/transforms/vectorize_load_store.py new file mode 100644 index 0000000..2972990 --- /dev/null +++ b/python/hidet/transforms/vectorize_load_store.py @@ -0,0 +1,140 @@ +from typing import List, Type, Optional +import contextlib +from hidet.transforms import Pass +from hidet.ir.type import TensorType +from hidet.ir.expr import Expr, Var, TensorElement +from hidet.ir.stmt import Stmt, SeqStmt, BufferStoreStmt, AssignStmt, EvaluateStmt +from hidet.ir.func import Function +from hidet.ir.functors import StmtRewriter, equal, same_list +from hidet.ir.dialects.lowlevel import Address +from hidet.ir.primitives import lds128, sts128 + + +class Vectorizer: + @property + def num_stmts(self) -> int: + raise NotImplementedError() + + @property + def stmt_cls(self) -> Type[Stmt]: + raise NotImplementedError() + + def vectorize(self, seq: List[Stmt]) -> Optional[Stmt]: + raise NotImplementedError() + + @staticmethod + def is_greater_one(lhs: Expr, rhs: Expr) -> bool: + return equal(lhs + 1, rhs) + + @staticmethod + def is_contiguous(seq: List[Expr]) -> bool: + for i in range(len(seq) - 1): + if not Vectorizer.is_greater_one(seq[i], seq[i + 1]): + return False + return True + + +class CudaLds128Vectorizer(Vectorizer): + @property + def num_stmts(self) -> int: + return 4 + + @property + def stmt_cls(self) -> Type[Stmt]: + return BufferStoreStmt + + def vectorize(self, seq: List[BufferStoreStmt]) -> Optional[Stmt]: + with contextlib.suppress(AssertionError): + assert len(seq) == 4 + assert all(isinstance(s, BufferStoreStmt) for s in seq) + assert all(isinstance(s.value, TensorElement) for s in seq) + dst_vars: List[Var] = [s.buf for s in seq] + src_vars: List[Var] = [s.value.base for s in seq] + assert all(isinstance(v.type, TensorType) and v.type.scope.name == 'shared' for v in src_vars) + assert all(isinstance(v.type, TensorType) and v.type.scope.name == 'register' for v in dst_vars) + assert all(len(s.value.indices) == 1 for s in seq) + shared_indices = [s.value.indices[0] for s in seq] + assert self.is_contiguous(shared_indices) + regs = [TensorElement(s.buf, s.indices) for s in seq] + smem_addr = Address(seq[0].value) + return EvaluateStmt(lds128(regs[0], regs[1], regs[2], regs[3], smem_addr)) + return None + + +class CudaSts128Vectorizer(Vectorizer): + @property + def num_stmts(self) -> int: + return 4 + + @property + def stmt_cls(self) -> Type[Stmt]: + return BufferStoreStmt + + def vectorize(self, seq: List[BufferStoreStmt]) -> Optional[Stmt]: + with contextlib.suppress(AssertionError): + assert len(seq) == 4 + assert all(isinstance(s, BufferStoreStmt) for s in seq) + assert all(isinstance(s.value, TensorElement) for s in seq) + dst_vars: List[Var] = [s.buf for s in seq] + src_vars: List[Var] = [s.value.base for s in seq] + assert all(isinstance(v.type, TensorType) and v.type.scope.name == 'register' for v in src_vars) + assert all(isinstance(v.type, TensorType) and v.type.scope.name == 'shared' for v in dst_vars) + assert all(len(s.indices) == 1 for s in seq) + shared_indices = [s.indices[0] for s in seq] + assert self.is_contiguous(shared_indices) + regs = [s.value for s in seq] + smem_addr = Address(TensorElement(seq[0].buf, seq[0].indices)) + return EvaluateStmt(sts128(regs[0], regs[1], regs[2], regs[3], smem_addr)) + return None + + +class StmtVectorizer(StmtRewriter): + def __init__(self): + super().__init__() + self.vectorizers: List[Vectorizer] = [ + CudaLds128Vectorizer(), + CudaSts128Vectorizer() + ] + + def visit_SeqStmt(self, stmt: SeqStmt): + seq: List[Stmt] = [self(s) for s in stmt.seq] + new_seq = [] + n = len(seq) + i = 0 + while i < n: + success = False + for vectorizer in self.vectorizers: + cls = vectorizer.stmt_cls + m = vectorizer.num_stmts + if i + m - 1 >= n: + continue + if not all(isinstance(s, cls) for s in seq[i: i+m]): + continue + new_stmt = vectorizer.vectorize(seq[i: i+m]) + if new_stmt: + new_seq.append(new_stmt) + i += m + success = True + break + if not success: + new_seq.append(seq[i]) + i += 1 + if same_list(new_seq, stmt.seq): + return stmt + else: + return SeqStmt(new_seq) + + +class VectorizeLoadStorePass(Pass): + def process_func(self, func: Function) -> Function: + vectorizer = StmtVectorizer() + body = vectorizer(func.body) + if body is func.body: + return func + else: + return Function(func.name, func.params, body, func.ret_type, local_vars=func.local_vars, + local_const_vars=func.local_const_vars, extern_vars=func.extern_vars, attrs=func.attrs) + + +def vectorize_load_store_pass(): + return VectorizeLoadStorePass() diff --git a/python/hidet/utils/__init__.py b/python/hidet/utils/__init__.py new file mode 100644 index 0000000..f72c101 --- /dev/null +++ b/python/hidet/utils/__init__.py @@ -0,0 +1,13 @@ +from . import doc +from . import cuda +from . import namer +from . import py +from . import netron +from . import nvtx_utils +from . import transformers_utils + +from .py import prod, Timer, repeat_until_converge, COLORS, get_next_file_index, factor, HidetProfiler, TableBuilder, line_profile, same_list, strict_zip, initialize, gcd, lcm, error_tolerance +from .nvtx_utils import nvtx_annotate +from .git_utils import hidet_cache_dir, hidet_cache_file, hidet_set_cache_root +from .net_utils import download +from .profile_utils import tracer diff --git a/python/hidet/utils/cuda.py b/python/hidet/utils/cuda.py new file mode 100644 index 0000000..8fd3881 --- /dev/null +++ b/python/hidet/utils/cuda.py @@ -0,0 +1,267 @@ +import os +import subprocess +from functools import lru_cache +from subprocess import PIPE +from typing import List, Optional, Union + + +def max_smem_bytes_per_sm(cc=None): + legacy = True + if legacy: + return 48 * 1024 + else: + if cc is None: + cc = query_compute_capability() + data = { + (6, 0): 64, + (6, 1): 96, + (6, 2): 64, + (7, 0): 96, + (7, 2): 96, + (7, 5): 64, + (8, 0): 164, + (8, 6): 100, + (8, 7): 164 + } + return data[cc] * 1024 + + +def max_smem_bytes_per_block(cc=None): + legacy = True + if legacy: + return 48 * 1024 + else: + if cc is None: + cc = query_compute_capability() + data = { + (6, 0): 48, + (6, 1): 48, + (6, 2): 48, + (7, 0): 96, + (7, 2): 96, + (7, 5): 64, + (8, 0): 163, + (8, 6): 99, + (8, 7): 163 + } + return data[cc] * 1024 + + +def max_num_regs_per_thread(): + return 255 + + +def max_num_regs_per_block(cc=None): + if cc is None: + cc = query_compute_capability() + data = { + (6, 0): 64, + (6, 1): 64, + (6, 2): 32, + (7, 0): 64, + (7, 2): 64, + (7, 5): 64, + (8, 0): 64, + (8, 6): 64 + } + return data[cc] * 1024 + + +def max_num_regs_per_sm(cc=None): + return 64 * 1024 + + +@lru_cache(maxsize=128) +def query_compute_capability(): + major, minor = query_gpu('compute_cap').split('.') + return int(major), int(minor) + + +def device_synchronize(): + from hidet.ffi.cuda_api import CudaAPI + CudaAPI.device_synchronize() + + +def preferred_gpu_clock(): + use_max = True + if use_max: + return query_gpu_max_clock() + base_clocks = { + 'NVIDIA GeForce RTX 3070 Laptop GPU': 1560, + 'Tesla V100-SXM2-16GB': 1530, + 'Tesla T4': 1250, + } + name = query_gpu('gpu_name') + if name in base_clocks: + return base_clocks[name] + else: + print('running on a new device: {}'.format(name)) + print('please set the base clock at {}:preferred_gpu_clock()', __file__) + return int(query_gpu_max_clock() * 0.8) + + +def lock_gpu_clock(clock: Optional[int] = None): + if clock is None: + clock = preferred_gpu_clock() + command = f'sudo -S nvidia-smi --lock-gpu-clocks={clock}' + print(f"Running '{command}'...") + subprocess.run(command.split(), check=True) + + +def reset_gpu_clock(): + command = 'sudo -S nvidia-smi --reset-gpu-clocks' + print(f"Running '{command}'") + subprocess.run(command.split(), check=True) + + +def query_gpu_current_clock() -> int: + return int(query_gpu('clocks.current.graphics')) + + +def query_gpu_temperature() -> int: + return int(query_gpu('temperature.gpu')) + + +def query_device_name(short=False) -> str: + full_name = query_gpu('name') + if short: + short_name_dict = { + 'NVIDIA GeForce RTX 3070 Laptop GPU': 'RTX3070L', + 'NVIDIA GeForce RTX 3090': 'RTX3090', + 'Tesla V100-SXM2-16GB': 'V100', + 'Tesla T4': 'T4', + } + ret = short_name_dict[full_name] if full_name in short_name_dict else full_name + else: + ret = full_name + return ret + + +def query_arch() -> str: + arch2name = { + (2, 0): 'Fermi', + (3, 0): 'Kepler', + (3, 5): 'Kepler', + (3, 7): 'Kepler', + (5, 0): 'Maxwell', + (5, 2): 'Maxwell', + (5, 3): 'Maxwell', + (6, 0): 'Pascal', + (6, 1): 'Pascal', + (6, 2): 'Pascal', + (7, 0): 'Volta', + (7, 2): 'Volta', + (7, 5): 'Turing', + (8, 0): 'Ampere', + (8, 6): 'Ampere' + } + return arch2name[query_compute_capability()] + + +def query_clocks_throttle_reason() -> str: + # see 'nvidia-smi --help' and 'nvml.h' for more information + bitmask = int(query_gpu('clocks_throttle_reasons.active'), base=16) + bit2reason = { + 1: 'gpu_idle', + 2: 'app_clock_setting', + 4: 'sw_power_cap', + 8: 'hw_slowdown', + 16: 'sync_boost', + 32: 'sw_thermal_slowdown', + 64: 'hw_thermal_slowdown', + 128: 'hw_power_brake_slowdown', + 256: 'display_clock_setting', + } + if bitmask == 0: + return 'no' + else: + reasons = [] + for bit, reason in bit2reason.items(): + if (bitmask & bit) != 0: + reasons.append(reason) + if len(reasons) == 0: + raise NotImplementedError() + return "/".join(reasons) + + +def query_gpu(names: Union[List[str], str]): + if not isinstance(names, (list, tuple)): + names = [names] + result = subprocess.run(f'nvidia-smi -i 0 --query-gpu={",".join(names)} --format=csv,noheader,nounits'.split(), + stdin=PIPE, stdout=PIPE, check=True) + results = [s.strip() for s in result.stdout.decode('utf-8').split(',')] + if len(results) == 1: + return results[0] + else: + return results + + +def query_gpu_max_clock() -> int: + return int(query_gpu('clocks.max.sm')) + + +def lock_memory_clock(clock: int): + command = f'sudo -S nvidia-smi --lock-memory-clocks={clock}' + print(f"Running '{command}'...") + subprocess.run(command.split(), check=True) + + +def reset_memory_clock(): + command = 'sudo -S nvidia-smi --reset-memory-clocks' + print(f"Running '{command}'") + subprocess.run(command.split(), check=True) + + +def query_memory_current_clock() -> int: + return int(query_gpu('clocks.current.memory')) + + +def query_memory_max_clock() -> int: + return int(query_gpu('clocks.max.memory')) + + +def query_persistent_mode() -> bool: + result = subprocess.run('nvidia-smi -pm 1'.split(), stdin=PIPE, stdout=PIPE) + return result.returncode == 0 + + +def turn_on_persistent_mode(): + result = subprocess.run('nvidia-smi -pm 1'.split(), stdin=PIPE, stdout=PIPE) + if result.returncode != 0: + # the persistent mode is disabled, use sudo -S to turn it on, passwd is required from shell + command = 'sudo -S nvidia-smi -pm 1' + print(f"Running '{command}' to turn on persistent mode...") + subprocess.run(command.split(), check=True) + + +class BenchmarkContext: + def __init__(self, lock_clock=True): + self.lock_clock = lock_clock + + def __enter__(self): + if self.lock_clock: + # sm clock; to make result more stable (trying to avoid gpu throttle) + turn_on_persistent_mode() + lock_gpu_clock() + lock_memory_clock(query_memory_max_clock()) + + def __exit__(self, exc_type, exc_val, exc_tb): + if self.lock_clock: + reset_memory_clock() + reset_gpu_clock() + + @staticmethod + def get_bench_ratio(clock_ratio=None): + ratio = os.environ.get('HIDET_BENCH_RATIO') + if ratio: + ratio = float(ratio) + elif clock_ratio is not None: + ratio = clock_ratio + else: + default_clock_ratio = 1.0 + ratio = default_clock_ratio + return min(max(float(ratio), 0.1), 1.0) + + +if __name__ == '__main__': + print(query_compute_capability()) diff --git a/python/hidet/utils/doc.py b/python/hidet/utils/doc.py new file mode 100644 index 0000000..f9467d8 --- /dev/null +++ b/python/hidet/utils/doc.py @@ -0,0 +1,79 @@ +from typing import List + + +def doc_join(seq: List, sep): + doc = Doc() + for i in range(len(seq)): + if i != 0: + doc += sep + doc += seq[i] + return doc + + +class NewLineToken: + def __init__(self, indent=0): + self.indent = indent + + def __str__(self): + return '\n' + ' ' * self.indent + + +class Doc: + default_indent = 2 + + def __init__(self): + self.docs = [] + + def append(self, doc): + if isinstance(doc, list): + for item in doc: + self.append(item) + elif isinstance(doc, Doc): + self.docs.extend(doc.docs) + elif isinstance(doc, str): + self.docs.append(doc) + else: + raise NotImplementedError() + + def indent(self, inc=None): + if inc is None: + inc = self.default_indent + doc = Doc() + for token in self.docs: + if isinstance(token, NewLineToken): + doc.docs.append(NewLineToken(indent=token.indent + inc)) + else: + doc.docs.append(token) + return doc + + def __add__(self, other): + doc = Doc() + doc.docs = [token for token in self.docs] + doc += other + return doc + + def __radd__(self, other): + doc = Doc() + doc.docs = [] + doc.append(other) + doc.append(self) + return doc + + def __iadd__(self, other): + self.append(other) + return self + + def __str__(self): + return "".join(str(s) for s in self.docs) + + +class NewLine(Doc): + def __init__(self, indent=0): + super().__init__() + self.docs.append(NewLineToken(indent)) + + +class Text(Doc): + def __init__(self, s): + super().__init__() + self.docs.append(s) diff --git a/python/hidet/utils/git_utils.py b/python/hidet/utils/git_utils.py new file mode 100644 index 0000000..1f5c9e2 --- /dev/null +++ b/python/hidet/utils/git_utils.py @@ -0,0 +1,106 @@ +from typing import List +import os +import git +import functools +import datetime +import logging + + +logger = logging.Logger(__name__) +logger.setLevel(logging.INFO) +logger.addHandler(logging.StreamHandler()) + + +def get_repo_sha(short=False): + """ + Get the current commit (i.e., HEAD) sha hash. + + Parameters + ---------- + short: bool, default False + Whether get a short version of hash. + + Returns + ------- + ret: str + The commit sha hash. + """ + repo = git.Repo(search_parent_directories=True) + sha = repo.head.object.hexsha + if short: + return sha[:7] + else: + return sha + + +def get_repo_commit_date(strftime='%Y-%m-%d') -> str: + """ + Get the commit date time of current commit (i.e., HEAD). + + Parameters + ---------- + strftime: str, default '%Y-%m-%d' + The format of the date time. The default format will return a date time like '2023-03-24'. + Others: %H: hour, %M: minutes + + Returns + ------- + ret: str + The commit date time in given format. + """ + repo = git.Repo(search_parent_directories=True) + commit = repo.head + committed_date = commit.commit.committed_date + dt = datetime.datetime.fromtimestamp(committed_date) + return str(dt.strftime(strftime)) + + +@functools.lru_cache(maxsize=1) +def repo_root() -> str: + """ + Get the root directory of current git repository. + + Returns + ------- + ret: str + The root directory. + """ + repo = git.Repo(search_parent_directories=True) + return repo.working_dir + + +_hidet_cache_root_dir = os.path.join(repo_root(), '.hidet_cache') +os.makedirs(_hidet_cache_root_dir, exist_ok=True) + + +def hidet_set_cache_root(root_dir: str): + global _hidet_cache_root_dir + root_dir = os.path.abspath(os.path.expanduser(root_dir)) + if not os.path.exists(root_dir): + os.makedirs(root_dir) + if not os.path.isdir(root_dir): + raise ValueError('Expect {} to be a directory.'.format(root_dir)) + _hidet_cache_root_dir = root_dir + logger.info('Hidet cache root dir: {}'.format(root_dir)) + + +def hidet_cache_dir(category='./') -> str: + root = _hidet_cache_root_dir + if category == './': + ret = root + else: + ret = os.path.join(root, category) + os.makedirs(ret, exist_ok=True) + return ret + + +def hidet_cache_file(*items: str) -> str: + root_dir = hidet_cache_dir('./') + ret_path = os.path.join(root_dir, *items) + os.makedirs(os.path.dirname(ret_path), exist_ok=True) + return ret_path + + +if __name__ == '__main__': + print(repo_root()) + diff --git a/python/hidet/utils/info.py b/python/hidet/utils/info.py new file mode 100644 index 0000000..ddf845b --- /dev/null +++ b/python/hidet/utils/info.py @@ -0,0 +1,6 @@ +def float_type_min_value(): + return -3.4e38 + + +def float_type_max_value(): + return 3.4e38 diff --git a/python/hidet/utils/namer.py b/python/hidet/utils/namer.py new file mode 100644 index 0000000..a87b1ef --- /dev/null +++ b/python/hidet/utils/namer.py @@ -0,0 +1,54 @@ +from collections import defaultdict + + +class Namer: + def __init__(self): + self.name_id_clock = defaultdict(int) + self.obj_name = {} + self.clear() + + def __call__(self, x): + return self.get_name(x) + + def clear(self): + self.name_id_clock.clear() + self.obj_name.clear() + # add keywords in target language + keywords = [ + 'const' + ] + for kw in keywords: + self.name_id_clock[kw] = 0 + + def get_name(self, e, hint=None): + from hidet.ir.expr import Var + from hidet.ir.dialects.compute import ScalarNode, TensorNode + from hidet.tos.tensor import Tensor + if e in self.obj_name: + return self.obj_name[e] + if hint: + orig_name = hint + elif isinstance(e, Var) and e.hint is not None: + orig_name = e.hint + elif isinstance(e, (ScalarNode, TensorNode)): + orig_name = e.name + else: + alias = { + ScalarNode: 'scalar', + TensorNode: 'tensor', + Var: 'v', + Tensor: 'x' + } + orig_name = alias[type(e)] if type(e) in alias else type(e).__name__ + + if orig_name in self.name_id_clock: + name = orig_name + while name in self.name_id_clock: + self.name_id_clock[orig_name] += 1 + name = orig_name + '_' + str(self.name_id_clock[orig_name]) + else: + self.name_id_clock[orig_name] = 0 + name = orig_name + + self.obj_name[e] = name + return name diff --git a/python/hidet/utils/net_utils.py b/python/hidet/utils/net_utils.py new file mode 100644 index 0000000..9728f5e --- /dev/null +++ b/python/hidet/utils/net_utils.py @@ -0,0 +1,56 @@ +from typing import Optional +import os +import sys +import shutil +import tempfile +import urllib.parse +import urllib.request +from tqdm import tqdm + +import hidet + + +def download(url: str, file_name: Optional[str] = None, progress: bool = True) -> str: + if file_name is None: + parts = urllib.parse.urlparse(url) + file_name = os.path.basename(parts.path) + cached_file = os.path.join(hidet.utils.hidet_cache_dir(), file_name) + if not os.path.exists(cached_file): + sys.stderr.write('Downloading: "{}" to {}\n'.format(url, cached_file)) + download_url_to_file(url, cached_file, progress=progress) + return cached_file + + +def download_url_to_file(url, dst, progress=True): + # modified based on PyTorch + file_size = None + req = urllib.request.Request(url, headers={"User-Agent": ""}) + u = urllib.request.urlopen(req) + meta = u.info() + if hasattr(meta, 'getheaders'): + content_length = meta.getheaders("Content-Length") + else: + content_length = meta.get_all("Content-Length") + if content_length is not None and len(content_length) > 0: + file_size = int(content_length[0]) + + dst = os.path.expanduser(dst) + dst_dir = os.path.dirname(dst) + os.makedirs(dst_dir, exist_ok=True) + f = tempfile.NamedTemporaryFile(delete=False, dir=dst_dir) + + try: + with tqdm(total=file_size, disable=not progress, unit='B', unit_scale=True, unit_divisor=1024) as pbar: + while True: + buffer = u.read(8192) + if len(buffer) == 0: + break + f.write(buffer) + pbar.update(len(buffer)) + + f.close() + shutil.move(f.name, dst) + finally: + f.close() + if os.path.exists(f.name): + os.remove(f.name) diff --git a/python/hidet/utils/netron.py b/python/hidet/utils/netron.py new file mode 100644 index 0000000..d6edef6 --- /dev/null +++ b/python/hidet/utils/netron.py @@ -0,0 +1,212 @@ +from typing import List, Union +import json +from collections import defaultdict + + +class Model: + def __init__(self, graph, description="", author="", company="", license="", domain="", source=""): + self.graphs: List[Graph] = [graph] + self.description: str = description + self.author: str = author + self.company: str = company + self.license: str = license + self.domain: str = domain + self.source: str = source + self.format: str = 'netron' + + def export(self): + return { + 'graphs': [graph.export() for graph in self.graphs], + 'description': self.description, + 'author': self.author, + 'company': self.company, + 'license': self.license, + 'domain': self.domain, + 'source': self.source, + 'format': self.format + } + + +class Graph: + def __init__(self, inputs, outputs, nodes, name=""): + self.inputs: List[Parameter] = inputs + self.outputs: List[Parameter] = outputs + self.nodes: List[Node] = nodes + self.name: str = name + + def export(self): + return { + 'name': self.name, + 'inputs': [param.export() for param in self.inputs], + 'outputs': [param.export() for param in self.outputs], + 'nodes': [node.export() for node in self.nodes] + } + + +class Parameter: + def __init__(self, name, argument, visible=True): + self.name: str = name + self.arguments: List[Argument] = [argument] + self.visible: bool = visible + + def export(self): + return { + 'name': self.name, + 'arguments': [arg.export() for arg in self.arguments], + 'visible': self.visible + } + + +class Argument: + def __init__(self, name, data_type, shape: Union[str, List[int]], has_initializer=False, scalar_value=None): + self.name: str = name + self.data_type: str = data_type + self.shape: Union[str, List[int]] = shape + self.has_initializer: bool = has_initializer + self.scalar_value = scalar_value + + def export(self): + ret = { + 'name': self.name, + 'type': { + "string": '{}{}'.format(self.data_type, self.shape), + "shape": {'dimensions': self.shape}, + "dataType": self.data_type + } + } + if self.has_initializer: + ret['initializer'] = {'kind': 'Initializer'} + if len(self.shape) == 0 and self.scalar_value is not None: + ret['initializer']['value'] = str(self.scalar_value) + else: + ret['initializer']['value'] = '<>' + return ret + + +class Node: + # category influence the color in netron + categories = { + 'layer': ['Conv2d', 'Matmul'], + 'constant': [], + 'activation': ['Relu'], + 'pool': ['MaxPool2d', 'AvgPool2d'], + 'normalization': [], + 'dropout': [], + 'transform': ['Squeeze', 'Unsqueeze', 'Add', 'Sub', 'Multiply', 'Rsqrt'], + 'custom': [], + } + + def __init__(self, name, type_name, inputs, outputs, attributes, category=None, description=''): + self.name: str = name + self.type_name: str = type_name + self.inputs: List[Parameter] = inputs + self.outputs: List[Parameter] = outputs + self.attributes: List[Attribute] = attributes + self.description: Union[List[str], str] = description.split('\n') + self.category = category + if self.category is None: + if self.type_name.startswith('Fused') or ' ' in self.type_name: + # fused op, use the color of 'dropout' + self.category = 'dropout' + elif self.type_name.startswith('Conv2dGemm') or self.type_name.startswith('Conv2dWinograd'): + self.category = 'custom' + else: + for cat, ops in self.categories.items(): + if type_name in ops: + self.category = cat + break + + def export(self): + return { + 'name': self.name, + 'type': { + 'name': self.type_name, + 'category': self.category + }, + 'inputs': [param.export() for param in self.inputs], + 'outputs': [param.export() for param in self.outputs], + 'attributes': [attr.export() for attr in self.attributes], + 'description': self.description + } + + +class Attribute: + def __init__(self, name, type_name: str, value: str, visible=True, description=""): + self.name: str = name + self.type_name: str = type_name + self.value: str = value + self.visible: bool = visible + self.description: str = description + + def export(self): + return { + 'name': self.name, + 'type': self.type_name, + 'value': self.value, + 'visible': self.visible, + 'description': self.description + } + + +def type_string_of(value): + if isinstance(value, (list, tuple)): + if len(value) > 0: + return 'Sequence[{}]'.format(type(value[0]).__name__) + else: + return 'Sequence[]' + else: + return str(type(value).__name__) + + +def dump(flow_graph, fp): + from hidet import FlowGraph + assert isinstance(flow_graph, FlowGraph) + flow_graph.update_nodes() + tensor2argument = {} + node2idx = defaultdict(int) + + inputs = [] + outputs = [] + nodes = [] + for idx, tensor in enumerate(flow_graph.inputs): + name = 'input:{}'.format(idx) + argument = Argument(name, data_type=tensor.dtype, shape=tensor.shape, has_initializer=False) + tensor2argument[tensor] = argument + inputs.append(Parameter(name, argument)) + + constant_cnt = 0 + for node in flow_graph.nodes: + node_type = node.name + node2idx[node_type] += 1 + node_name = '{}{}'.format(node_type, node2idx[node_type]) + for idx, tensor in enumerate(node.inputs): + if tensor.storage is None: # not a constant + continue + if tensor in tensor2argument: # constant shared by multiple nodes + continue + name = 'const:{}'.format(constant_cnt) + constant_cnt += 1 + scalar_value = str(tensor.cpu().numpy()) if len(tensor.shape) == 0 and tensor.storage else None + tensor2argument[tensor] = Argument(name, data_type=tensor.dtype, shape=tensor.shape, has_initializer=True, scalar_value=scalar_value) + for idx, tensor in enumerate(node.outputs): + name = '{}:{}'.format(node_name, idx) + tensor2argument[tensor] = Argument(name, data_type=tensor.dtype, shape=tensor.shape, has_initializer=False) + nodes.append(Node( + name=node_name, + type_name=node_type, + inputs=[Parameter(str(idx), tensor2argument[tensor]) for idx, tensor in enumerate(node.inputs)], + outputs=[Parameter(str(idx), tensor2argument[tensor]) for idx, tensor in enumerate(node.outputs)], + attributes=[ + Attribute(name, type_string_of(value), str(value)) for name, value in node.attrs.items() + ], + description="{}".format(str(node.task)) + )) + for idx, tensor in enumerate(flow_graph.outputs): + outputs.append(Parameter('output:{}'.format(idx), tensor2argument[tensor])) + graph = Graph(inputs, outputs, nodes, name="") + model = Model(graph, source='Hidet', description='Converted from FlowGraph') + + json.dump(model.export(), fp, indent=2) + + + diff --git a/python/hidet/utils/nvtx_utils.py b/python/hidet/utils/nvtx_utils.py new file mode 100644 index 0000000..a208074 --- /dev/null +++ b/python/hidet/utils/nvtx_utils.py @@ -0,0 +1,17 @@ +import nvtx + +nvtx_annotate = nvtx.annotate + + +class CudaProfileContext: + def __enter__(self): + from hidet.ffi import cuda + cuda.start_profiler() + + def __exit__(self, exc_type, exc_val, exc_tb): + from hidet.ffi import cuda + cuda.stop_profiler() + + +def enable_cuda_profile(): + return CudaProfileContext() diff --git a/python/hidet/utils/ort_utils.py b/python/hidet/utils/ort_utils.py new file mode 100644 index 0000000..a05beee --- /dev/null +++ b/python/hidet/utils/ort_utils.py @@ -0,0 +1,62 @@ +from time import time +from typing import Dict, List + +import onnxruntime as ort + +import hidet +from hidet import Tensor +from hidet.ffi import cuda + + +def create_ort_session(onnx_model_path, provider='CUDAExecutionProvider') -> ort.InferenceSession: + session = ort.InferenceSession(onnx_model_path, providers=[provider]) + session.disable_fallback() + return session + + +def _prepare_io_binding(session: ort.InferenceSession, inputs: Dict[str, Tensor]) -> ort.IOBinding: + input_values: Dict[str, ort.OrtValue] = { + name: ort.OrtValue.ortvalue_from_numpy(tensor.numpy(), device_type='cuda') for name, tensor in inputs.items() + } + output_names = [output.name for output in session.get_outputs()] + io_binding = session.io_binding() + for name, value in input_values.items(): + io_binding.bind_ortvalue_input(name, value) + for name in output_names: + io_binding.bind_output(name, device_type='cuda') + return io_binding + + +def ort_inference(session: ort.InferenceSession, inputs: Dict[str, Tensor]) -> Dict[str, Tensor]: + io_binding = _prepare_io_binding(session, inputs) + session.run_with_iobinding(iobinding=io_binding) + outputs = {output_node.name: hidet.array(value.numpy()).cuda() + for output_node, value in zip(session.get_outputs(), io_binding.get_outputs())} + return outputs + + +def ort_benchmark(session: ort.InferenceSession, dummy_inputs: Dict[str, Tensor], warmup=10, number=10, repeat=10) -> List[float]: + io_binding = _prepare_io_binding(session, dummy_inputs) + for i in range(warmup): + session.run_with_iobinding(iobinding=io_binding) + results = [] + for i in range(repeat): + cuda.device_synchronize() + start_time = time() + for j in range(number): + session.run_with_iobinding(iobinding=io_binding) + cuda.device_synchronize() + end_time = time() + results.append((end_time - start_time) * 1000 / number) + return results + + +if __name__ == '__main__': + model_path = hidet.utils.hidet_cache_file('onnx', 'resnet50-v1-7.onnx') + session = create_ort_session(model_path) + inputs = { + 'data': hidet.randn([1, 3, 224, 224]) + } + outputs = ort_inference(session, inputs) + print(outputs) + print(ort_benchmark(session, inputs)) diff --git a/python/hidet/utils/profile_utils.py b/python/hidet/utils/profile_utils.py new file mode 100644 index 0000000..fa1cb0d --- /dev/null +++ b/python/hidet/utils/profile_utils.py @@ -0,0 +1,110 @@ +from typing import Dict, Any, Optional, List, ContextManager +from contextlib import nullcontext +from time import time_ns +import json + + +# See also: https://docs.google.com/document/d/1CvAClvFfyA5R-PhYUmn5OOQtYMH4h6I0nSsKchNAySU +# Color names: https://github.com/catapult-project/catapult/blob/main/tracing/tracing/base/color_scheme.html#L29-L72 +class TraceEvent: + def __init__(self, name, category, event_type, time_stamp, pid, tid, args: Dict[str, Any]): + self.name = name + self.category = category + self.event_type = event_type + self.time_stamp = time_stamp + self.pid = pid + self.tid = tid + self.args = args if args else {} + + def export(self) -> Dict: + event = { + 'name': self.name, + 'cat': self.category, + 'ph': self.event_type, + 'ts': self.time_stamp / 1000.0, + 'pid': self.pid, + 'tid': self.tid, + 'args': {k: str(v) for k, v in self.args.items()} + } + return event + + +class CpuTraceEvent(TraceEvent): + def __init__(self, name, category, event_type, tid, args): + super().__init__(name, category, event_type, time_ns(), pid=0, tid=tid, args=args) + + +class CudaTraceEvent(TraceEvent): + anchor_cuda_event = None + anchor_cuda_event_host_time = None + + def __init__(self, name, category, event_type, tid, args): + super().__init__(name, category, event_type, None, pid=0, tid=tid, args=args) + from hidet.runtime import cuda_event_pool + if CudaTraceEvent.anchor_cuda_event is None: + from hidet.ffi.cuda_api import cuda + CudaTraceEvent.anchor_cuda_event = cuda_event_pool.new_event() + cuda.device_synchronize() + CudaTraceEvent.anchor_cuda_event.record_on() + CudaTraceEvent.anchor_cuda_event_host_time = time_ns() + self.cuda_event = cuda_event_pool.new_event() + self.cuda_event.record_on() + + def export(self) -> Dict: + self.time_stamp = self.cuda_event.elapsed_time_since(self.anchor_cuda_event) * 1000000.0 + self.anchor_cuda_event_host_time + return TraceEvent.export(self) + + +class TraceContext: + def __init__(self, tracer, name, category, args, trace_cuda=False): + self.tracer: Tracer = tracer + self.name = name + self.category = category + self.args = args + self.trace_cuda = trace_cuda + + def __enter__(self): + self.tracer.events.append(CpuTraceEvent(self.name, self.category, 'B', 0, self.args)) + if self.trace_cuda: + self.tracer.events.append(CudaTraceEvent(self.name, self.category, 'B', 1, self.args)) + + def __exit__(self, exc_type, exc_val, exc_tb): + if self.trace_cuda: + self.tracer.events.append(CudaTraceEvent(self.name, self.category, 'E', 1, self.args)) + self.tracer.events.append(CpuTraceEvent(self.name, self.category, 'E', 0, self.args)) + + +class Tracer: + def __init__(self): + self.events: List[TraceEvent] = [] + self.tracing: bool = False + + def export(self) -> Dict: + from hidet.ffi.cuda_api import cuda + cuda.device_synchronize() # sync cuda events in trace + ret = { + 'traceEvents': [event.export() for event in self.events], + 'displayTimeUnit': 'ns' + } + self.clear() + return ret + + def dump(self, f): + json.dump(self.export(), f) + + def clear(self): + from hidet.ffi.cuda_api import cuda + cuda.device_synchronize() # sync cuda events in trace + self.events.clear() + + def turn_on(self, turn_on=True): + self.tracing = turn_on + + def profile(self, name: str, category: str = 'python', args: Optional[Dict[str, Any]] = None, trace_cuda=False) -> ContextManager: + if self.tracing: + return TraceContext(self, name, category, args, trace_cuda) + else: + return nullcontext() + + +tracer = Tracer() diff --git a/python/hidet/utils/py.py b/python/hidet/utils/py.py new file mode 100644 index 0000000..60d2b27 --- /dev/null +++ b/python/hidet/utils/py.py @@ -0,0 +1,390 @@ +from typing import TypeVar, Iterable, Tuple, List, Union +import numpy as np +import cProfile +import contextlib +import io +import itertools +import os +import pstats +import time +from tabulate import tabulate +from typing import Callable, MutableMapping, Sequence + + +def prod(seq: Sequence): + if len(seq) == 0: + return 1 + else: + c = seq[0] + for i in range(1, len(seq)): + c = c * seq[i] + return c + + +TypeA = TypeVar('TypeA') +TypeB = TypeVar('TypeB') + + +def strict_zip(a: Sequence[TypeA], b: Sequence[TypeB]) -> Iterable[Tuple[TypeA, TypeB]]: + if len(a) != len(b): + raise ValueError('Expect two sequence have the same length in zip, got length {} and {}.'.format(len(a), len(b))) + return zip(a, b) + + +class COLORS: + HEADER = '\033[95m' + OKBLUE = '\033[94m' + OKCYAN = '\033[96m' + OKGREEN = '\033[92m' + WARNING = '\033[93m' + FAIL = '\033[91m' + ENDC = '\033[0m' + BOLD = '\033[1m' + UNDERLINE = '\033[4m' + + +def green(v, fmt='{}'): + return COLORS.OKGREEN + fmt.format(v) + COLORS.ENDC + + +def cyan(v, fmt='{}'): + return COLORS.OKCYAN + fmt.format(v) + COLORS.ENDC + + +def blue(v, fmt='{}'): + return COLORS.OKBLUE + fmt.format(v) + COLORS.ENDC + + +def red(v, fmt='{}'): + return COLORS.WARNING + fmt.format(v) + COLORS.ENDC + + +def color(v, fmt='{}', fg='default', bg='default'): + fg_code = { + "black": 30, + "red": 31, + "green": 32, + "yellow": 33, + "blue": 34, + "magenta": 35, + "cyan": 36, + "white": 37, + "default": 39, + } + bg_code = { + "black": 40, + "red": 41, + "green": 42, + "yellow": 43, + "blue": 44, + "magenta": 45, + "cyan": 46, + "white": 47, + "default": 49 + } + return '\033[{};{}m{}\033[0m'.format(fg_code[fg], bg_code[bg], fmt.format(v)) + + +def color_table(): + fg_names = ["default", "black", "red", "green", "yellow", "blue", "magenta", "cyan", "white"] + bg_names = ["default", "black", "red", "green", "yellow", "blue", "magenta", "cyan", "white"] + print('{:>10} {:>10} {:<10}'.format('fg', 'bg', 'text')) + for bg in bg_names: + for fg in fg_names: + print('{:>10} {:>10} {}'.format(fg, bg, color('sample text', fg=fg, bg=bg))) + + +def color_rgb(v, fg, fmt='{}'): + return '\033[38;2;{};{};{}m{}\033[0m'.format(fg[0], fg[1], fg[2], v) + + +def color_text(v, fmt='{}', idx: int = 0): + if idx == 0: + return fmt.format(v) + colors = { + 1: (153, 96, 52), + 2: (135, 166, 73) + } + return color_rgb(v, colors[idx], fmt=fmt) + + +def nocolor(s: str) -> str: + for name, value in COLORS.__dict__.items(): + if isinstance(value, str) and value[0] == '\033': + s = s.replace(value, '') + return s + + +class Timer: + def __init__(self, msg=None, file=None, verbose=True, stdout=True): + self.start_time = None + self.end_time = None + self.msg = msg + self.stdout = stdout + self.verbose = verbose + self.file = file + + def __enter__(self): + self.start_time = time.time() + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + self.end_time = time.time() + if self.msg is not None and self.verbose: + msg = '{} {}'.format(self.msg, green(self.time2str(self.end_time - self.start_time))) + if self.stdout: + print(msg) + if self.file: + if isinstance(self.file, str): + with open(self.file, 'w') as f: + f.write(nocolor(msg)) + else: + self.file.write(msg + '\n') + + def elapsed_seconds(self) -> float: + return self.end_time - self.start_time + + def time2str(self, seconds: float) -> str: + if seconds < 1: + return '{:.1f} {}'.format(seconds * 1000, 'ms') + elif seconds < 60: + return '{:.1f} {}'.format(seconds, 'seconds') + elif seconds < 60 * 60: + return '{:.1f} {}'.format(seconds / 60, 'minutes') + else: + return '{:.1f} {}'.format(seconds / 60 / 60, 'hours') + + +# class DictCustomKey(MutableMapping, dict): +# def __init__(self, hash_func: Callable[[object], int]): +# super().__init__() +# self.hash_func = hash_func +# +# def __delitem__(self, v): +# return dict.__delitem__(self, self.hash_func(v)) +# +# def __len__(self) -> int: +# return dict.__len__(self) +# +# def __iter__(self): +# return dict.__iter__(self) +# +# def __getitem__(self, item): +# return dict.__getitem__(self, self.hash_func(item)) +# +# def __setitem__(self, key, value): +# return dict.__setitem__(self, self.hash_func(key), value) +# + +def repeat_until_converge(func, obj, limit=None): + i = 0 + while True: + i += 1 + orig_obj = obj + obj = func(obj) + if obj is orig_obj: + return obj + if limit is not None and i >= limit: + return obj + + +def get_next_file_index(dirname: str) -> int: + indices = set() + for fname in os.listdir(dirname): + parts = fname.split('_') + with contextlib.suppress(ValueError): + indices.add(int(parts[0])) + for idx in itertools.count(0): + if idx not in indices: + return idx + + +def factor(n): + """ + example: + factor(12) => [1, 2, 3, 4, 6, 12] + """ + i = 1 + ret = [] + while i * i <= n: + if n % i == 0: + ret.append(i) + if i * i != n: + ret.append(n // i) + i += 1 + return list(sorted(ret)) + + +def same_list(lhs, rhs, use_equal=False): + assert isinstance(lhs, (tuple, list)) and isinstance(rhs, (tuple, list)) + if len(lhs) != len(rhs): + return False + for l, r in zip(lhs, rhs): + if use_equal: + if l != r: + return False + else: + if l is not r: + return False + return True + + +class HidetProfiler: + def __init__(self, display_on_exit=True): + self.pr = cProfile.Profile() + self.display_on_exit = display_on_exit + + def __enter__(self): + self.pr.enable() + + def __exit__(self, exc_type, exc_val, exc_tb): + self.pr.disable() + if self.display_on_exit: + print(self.result()) + + def result(self): + s = io.StringIO() + ps = pstats.Stats(self.pr, stream=s).sort_stats('cumulative') + ps.print_stats() + return str(s.getvalue()) + + +class TableRowContext: + def __init__(self, tb): + self.tb = tb + self.row = [] + + def __iadd__(self, other): + if isinstance(other, (tuple, list)): + self.row.extend(other) + else: + self.row.append(other) + + def append(self, other): + self.row.append(other) + + def extend(self, other): + self.row.extend(other) + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + self.tb.rows.append(self.row) + + +class TableBuilder: + def __init__(self, headers=tuple(), tablefmt='simple', floatfmt='.3f'): + self.headers = list(headers) + self.rows = [] + self.tablefmt = tablefmt + self.floatfmt = floatfmt + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + pass + + def __iadd__(self, row): + self.rows.append(row) + return self + + def __str__(self): + return str(tabulate(self.rows, self.headers, tablefmt=self.tablefmt, floatfmt=self.floatfmt)) + + def new_row(self) -> TableRowContext: + return TableRowContext(tb=self) + + def extend_header(self, column_names): + self.headers.extend(column_names) + + +def line_profile(): + from line_profiler_pycharm import profile + return profile + + +def initialize(*args, **kwargs): + """ + Decorate an initialization function. After decorating with this function, the initialization function will be called + after the definition. + + Parameters + ---------- + args: + The positional arguments of initializing. + kwargs: + The keyword arguments of initializing. + + Returns + ------- + ret: + A decorator that will call given function with args and kwargs, + and return None (to prevent this function to be called again). + """ + def decorator(f): + f(*args, **kwargs) + return None + return decorator + + +def gcd(a: int, b: int) -> int: + """ + Get the greatest common divisor of non-negative integers a and b. + + Parameters + ---------- + a: int + The lhs operand. + b: int + The rhs operand. + + Returns + ------- + ret: int + The greatest common divisor. + """ + assert a >= 0 and b >= 0 + return a if b == 0 else gcd(b, a % b) + + +def lcm(a: int, b: int) -> int: + """ + Get the least common multiple of non-negative integers a and b. + Parameters + ---------- + a: int + The lhs operand. + b: int + The rhs operand. + + Returns + ------- + ret: int + The least common multiple. + """ + return a // gcd(a, b) * b + + +def error_tolerance(a: Union[np.ndarray, 'Tensor'], b: Union[np.ndarray, 'Tensor']) -> float: + from hidet.tos import Tensor + if isinstance(a, Tensor): + a = a.numpy() + if isinstance(b, Tensor): + b = b.numpy() + lf = 0.0 + rg = 9.0 + for step in range(20): + mid = (lf + rg) / 2.0 + if np.allclose(a, b, rtol=mid, atol=mid): + rg = mid + else: + lf = mid + return (lf + rg) / 2.0 + + +if __name__ == '__main__': + # color_table() + print(color_text('sample', idx=1)) + print(color_text('sample', idx=2)) diff --git a/python/hidet/utils/tensorrt_utils.py b/python/hidet/utils/tensorrt_utils.py new file mode 100644 index 0000000..1219b78 --- /dev/null +++ b/python/hidet/utils/tensorrt_utils.py @@ -0,0 +1,267 @@ +import datetime +from typing import List, Optional, Dict, Tuple +from collections import OrderedDict +from hashlib import sha256 +import json +import os +import time +import numpy as np +import tensorrt as trt +import hidet +from hidet.ffi import cuda +from hidet import Tensor, randn, empty +from hidet.utils import hidet_cache_dir, nvtx_annotate + + +class Profiler(trt.IProfiler): + def __init__(self): + super().__init__() + self.layer2latency: Dict[str, float] = OrderedDict() + + def report_layer_time(self, layer_name, ms): + self.layer2latency[layer_name] = ms + + def export_trace(self): + from hidet.utils.profile_utils import TraceEvent + events = [] + current_time = 0 + for layer, latency in self.layer2latency.items(): + events.append(TraceEvent(layer, 'op', 'B', current_time * 1000000, 0, 0, {'name': layer})) + current_time += latency + events.append(TraceEvent(layer, 'op', 'E', current_time * 1000000, 0, 0, {'name': layer})) + return { + 'traceEvents': [event.export() for event in events], + 'displayTimeUnit': 'ns' + } + + +class Logger(trt.ILogger): + def __init__(self, log_file: Optional[str] = None, print_out_level: str = 'INFO'): + super().__init__() + self.log_file = log_file + self.print_out_level = print_out_level + self.opened_file = None + self.level_id = { + 'INTERNAL_ERROR': 0, + 'ERROR': 1, + 'WARNING': 2, + 'INFO': 3, + 'VERBOSE': 4 + } + if self.log_file: + self.opened_file = open(self.log_file, 'w') + + def log(self, severity: trt.ILogger.Severity, msg: str): + severity2name = { + trt.ILogger.INTERNAL_ERROR: 'INTERNAL_ERROR', + trt.ILogger.ERROR: 'ERROR', + trt.ILogger.WARNING: 'WARNING', + trt.ILogger.INFO: 'INFO', + trt.ILogger.VERBOSE: 'VERBOSE' + } + severity_name = severity2name[severity] + msg = '{} {} {}\n'.format(datetime.datetime.now().strftime('%Y-%m-%d-%H:%M:%S'), severity_name, msg) + self.opened_file.write(msg) + # self.opened_file.flush() + if self.level_id[self.print_out_level] >= self.level_id[severity_name] >= self.level_id['WARNING']: + print(msg) + if severity_name in ['INTERNAL_ERROR', 'ERROR']: + raise RuntimeError('TensorRT ' + msg) + + def __del__(self): + if self.opened_file: + self.opened_file.close() + + +def milo_bytes(MiB): + return MiB << 20 + + +def create_engine_from_onnx( + onnx_model_path: str, + workspace_bytes: int = 512 << 20, + input_shapes: Optional[Dict[str, List[int]]] = None, + use_tf32: bool = False, + use_fp16: bool = False +) -> trt.ICudaEngine: + cache_dir = hidet_cache_dir('trt_engine') + os.makedirs(cache_dir, exist_ok=True) + model_name = os.path.basename(onnx_model_path).split('.')[0] + shape_hash = tuple((name, tuple(shape)) for name, shape in sorted(input_shapes.items(), key=lambda item: item[0])) + shape_hash_suffix = sha256(str(shape_hash).encode()).hexdigest()[:6] + engine_name = '{}{}{}_ws{}_{}.engine'.format(model_name, '_tf32' if use_tf32 else '', '_fp16' if use_fp16 else '', workspace_bytes // (1 << 20), shape_hash_suffix) + engine_path = os.path.join(cache_dir, engine_name) + + # logger = trt.Logger(min_severity=trt.Logger.ERROR) # use WARNINGS when needed + + if os.path.exists(engine_path): + # load the engine directly + logger = Logger(engine_path + '.log', print_out_level='ERROR') + runtime = trt.Runtime(logger) + with open(engine_path, 'rb') as f: + serialized_engine = f.read() + engine = runtime.deserialize_cuda_engine(serialized_engine) + else: + build_logger = Logger(engine_path + '.build.log', print_out_level='ERROR') + builder = trt.Builder(build_logger) + # parse onnx model + network: trt.INetworkDefinition = builder.create_network(1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)) + onnx_parser = trt.OnnxParser(network, build_logger) + success = onnx_parser.parse_from_file(onnx_model_path) + for idx in range(onnx_parser.num_errors): + print(onnx_parser.get_error(idx)) + if not success: + raise Exception('Failed parse onnx model in tensorrt onnx parser.') + + # set configs of the network builder + config: trt.IBuilderConfig = builder.create_builder_config() + # config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, workspace_bytes) + config.max_workspace_size = workspace_bytes + # allow us to inspect the engine, see https://docs.nvidia.com/deeplearning/tensorrt/developer-guide/index.html#engine-inspector + config.profiling_verbosity = trt.ProfilingVerbosity.DETAILED + # whether allow tf32/, see https://docs.nvidia.com/deeplearning/tensorrt/developer-guide/index.html#tf32-inference-c + if use_tf32: + config.set_flag(trt.BuilderFlag.TF32) + else: + config.clear_flag(trt.BuilderFlag.TF32) + if use_fp16: + config.set_flag(trt.BuilderFlag.FP16) + else: + config.clear_flag(trt.BuilderFlag.FP16) + # force to use the precision in network definition, see https://docs.nvidia.com/deeplearning/tensorrt/developer-guide/index.html#layer-level-control + config.set_flag(trt.BuilderFlag.OBEY_PRECISION_CONSTRAINTS) + + # optimization profiles required by dynamic inputs + profile: trt.IOptimizationProfile = builder.create_optimization_profile() + # assert len(inputs_shape) == network.num_inputs, 'Expect {} number of input shapes'.format(network.num_inputs) + for i in range(network.num_inputs): + tensor: trt.ITensor = network.get_input(i) + if any(v == -1 for v in tensor.shape): + if input_shapes is None or tensor.name not in input_shapes: + raise Exception("Found dynamic input: {}{}, " + "please specify input_shapes as the target shape.".format(tensor.name, list(tensor.shape))) + opt_shape = input_shapes[tensor.name] + profile.set_shape(tensor.name, min=opt_shape, opt=opt_shape, max=opt_shape) + config.add_optimization_profile(profile) + + # build engine + supported = builder.is_network_supported(network, config) + if not supported: + raise Exception('Network is not supported by TensorRT.') + engine: trt.ICudaEngine = builder.build_engine(network, config) + + if engine is None: + raise Exception('Can not build network with given config.') + + # save engine + serialized_engine = builder.build_serialized_network(network, config) + with open(engine_path, 'wb') as f: + f.write(serialized_engine) + return engine + + +dtype_map = { + trt.DataType.INT32: 'int32', + trt.DataType.FLOAT: 'float32', +} + + +def _prepare_buffer(engine: trt.ICudaEngine, inputs: Dict[str, Tensor]) -> Tuple[Dict[str, Tensor], Dict[str, Tensor], List[int]]: + inputs = inputs.copy() + outputs = {} + buffers = [] + for i in range(engine.num_bindings): + name = engine.get_binding_name(i) + if engine.binding_is_input(i): + dtype: trt.DataType = engine.get_binding_dtype(i) + if name not in inputs: + raise ValueError("TensorRT engine requires input '{}', but only received inputs: {}.".format(name, list(inputs.keys()))) + if dtype != inputs[name].dtype: + inputs[name] = hidet.tos.ops.cast(inputs[name], dtype_map[dtype]) + buffers.append(inputs[name].storage.addr) + else: + shape = engine.get_binding_shape(i) + dtype: trt.DataType = engine.get_binding_dtype(i) + output = hidet.empty(shape, dtype_map[dtype], device='cuda') + outputs[name] = output + buffers.append(output.storage.addr) + return inputs, outputs, buffers + + +def engine_inference(engine: trt.ICudaEngine, inputs: Dict[str, Tensor]) -> Dict[str, Tensor]: + # prepare inputs and outputs + inputs, outputs, buffers = _prepare_buffer(engine, inputs) + + # inference + context: trt.IExecutionContext = engine.create_execution_context() + context.execute_async_v2(buffers, 0) + cuda.device_synchronize() + return outputs + + +def engine_benchmark(engine: trt.ICudaEngine, dummy_inputs: Dict[str, Tensor], warmup: int = 3, number: int = 5, repeat: int = 5) -> List[float]: + inputs, outputs, buffers = _prepare_buffer(engine, dummy_inputs) + context: trt.IExecutionContext = engine.create_execution_context() + results = [] + with nvtx_annotate('warmup'): + for i in range(warmup): + context.execute_async_v2(buffers, 0) + cuda.device_synchronize() + for i in range(repeat): + with nvtx_annotate(f'repeat {i}'): + cuda.device_synchronize() + start_time = time.time() + for j in range(number): + context.execute_async_v2(buffers, 0) + cuda.device_synchronize() + end_time = time.time() + results.append((end_time - start_time) * 1000 / number) + return results + + +def engine_inspect(engine: trt.ICudaEngine) -> Dict: + inspector: trt.EngineInspector = engine.create_engine_inspector() + layer_information = {} + for i in range(engine.num_layers): + layer_information['layer_{}'.format(i)] = json.loads(str(inspector.get_layer_information(i, trt.LayerInformationFormat.JSON))) + # engine_information = json.loads(str(inspector.get_engine_information(trt.LayerInformationFormat.JSON))) + return { + 'layers': layer_information, + # 'engine': engine_information + } + + +def engine_profiler(engine: trt.ICudaEngine, dummy_inputs: Dict[str, Tensor]) -> Dict: + # prepare inputs and outputs + inputs, outputs, buffers = _prepare_buffer(engine, dummy_inputs) + context: trt.IExecutionContext = engine.create_execution_context() + profiler = Profiler() + context.profiler = profiler + context.execute_v2(buffers) + cuda.device_synchronize() + return profiler.export_trace() + + +if __name__ == '__main__': + # onnx_model_path = os.path.join(hidet_cache_dir('onnx'), 'resnet50-v1-7.onnx') + onnx_model_path = os.path.join(hidet_cache_dir('onnx'), 'bert-base-uncased.onnx') + batch_size = 1 + seq_length = 512 + vocab_size = 30522 + input_ids = np.random.randint(0, vocab_size, [batch_size, seq_length], dtype=np.int64) + attention_mask = np.ones(shape=[batch_size, seq_length], dtype=np.int64) + token_type_ids = np.zeros(shape=[batch_size, seq_length], dtype=np.int64) + + # onnx + inputs = { + 'input_ids': hidet.array(input_ids).cuda(), + 'attention_mask': hidet.array(attention_mask).cuda(), + 'token_type_ids': hidet.array(token_type_ids).cuda() + } + engine = create_engine_from_onnx(onnx_model_path, input_shapes={ + key: tensor.shape for key, tensor in inputs.items() + }) + outputs = engine_inference(engine, inputs) + results = engine_benchmark(engine, inputs) + print(results) + diff --git a/python/hidet/utils/torch_utils.py b/python/hidet/utils/torch_utils.py new file mode 100644 index 0000000..eaff681 --- /dev/null +++ b/python/hidet/utils/torch_utils.py @@ -0,0 +1,47 @@ +import sys +import os +import subprocess +from hidet.utils import hidet_cache_file + + +def export_torchvision_model_as_onnx(model_name: str, output_path: str, skip_existed: bool = True): + if skip_existed and os.path.exists(output_path): + return + os.makedirs(os.path.dirname(output_path), exist_ok=True) + import torchvision + import torch + if model_name == 'resnet50': + model = torchvision.models.resnet50(pretrained=True).cuda() + input_shape = [1, 3, 224, 224] + elif model_name == 'inception_v3': + model = torchvision.models.inception_v3(pretrained=True, transform_input=False, aux_logits=True).cuda() + input_shape = [1, 3, 299, 299] + elif model_name == 'mobilenet_v2': + model = torchvision.models.mobilenet_v2(pretrained=True).cuda() + input_shape = [1, 3, 224, 224] + else: + raise NotImplementedError(model_name) + + model.eval() + dummy_input = torch.randn(*input_shape, device='cuda') + input_names = ['data'] + output_names = ['output'] + torch.onnx.export( + model=model, + args=dummy_input, + f=output_path, + training=torch.onnx.TrainingMode.PRESERVE, + input_names=input_names, + output_names=output_names, + do_constant_folding=False, + dynamic_axes={ + 'data': {0: 'batch_size'}, + 'output': {0: 'batch_size'} + } + ) + + +if __name__ == '__main__': + names = ['resnet50', 'inception_v3', 'mobilenet_v2'] + for name in names: + export_torchvision_model_as_onnx(name, hidet_cache_file('onnx', f'{name}.onnx')) diff --git a/python/hidet/utils/transformers_utils.py b/python/hidet/utils/transformers_utils.py new file mode 100644 index 0000000..4060ead --- /dev/null +++ b/python/hidet/utils/transformers_utils.py @@ -0,0 +1,37 @@ +import sys +import os +import subprocess + + +def export_transformer_model_as_onnx(model_name: str, output_path: str, feature='default', skip_exists=True): + """ + Export a model from transformers package. + + Parameters + ---------- + model_name: str + The model name. + output_path: str + The output path. + feature: str + The feature of the exported model. + skip_exists: bool + Skip export if target exists. Default True. + + Examples + -------- + Call export_transformer_model_as_onnx() will download (when needed) the requested model and export it to an onnx model. + The function will return '{output_dir}/bert-base-uncased.onnx', which can be load by onnx package. + """ + if skip_exists and os.path.exists(output_path): + return + temp_dir = '/tmp/hidet' + command = '{} -m transformers.onnx --model {} --feature {} {}'.format(sys.executable, model_name, feature, temp_dir) + print("Running '{}'".format(command)) + subprocess.run(command.split(), check=True) + os.rename(os.path.join(temp_dir, 'model.onnx'), output_path) + print('Model saved at: {}'.format(output_path)) + + +if __name__ == '__main__': + export_transformer_model_as_onnx(model_name='bert-base-uncased', output_path='./bert.onnx') diff --git a/python/hidet/utils/tvm_utils.py b/python/hidet/utils/tvm_utils.py new file mode 100644 index 0000000..77639e1 --- /dev/null +++ b/python/hidet/utils/tvm_utils.py @@ -0,0 +1,234 @@ +import os +import time +from hashlib import sha256 +from typing import Optional, Dict, List + +import numpy as np +import onnx + +import hidet +import tvm +import tvm.relay.backend.executor_factory +from hidet import Tensor +from hidet.ffi import cuda +from hidet.utils import Timer, hidet_cache_dir, hidet_cache_file +from tvm import relay +from tvm.contrib import graph_executor +from tvm.contrib.graph_executor import GraphModule + + +def dump_code(graph_factory: tvm.relay.backend.executor_factory.ExecutorFactoryModule, out_dir): + runtime_module: tvm.runtime.Module = graph_factory.get_lib() + runtime_cuda_module = runtime_module.imported_modules[0] + os.makedirs(out_dir, exist_ok=True) + with open(os.path.join(out_dir, 'tvm_host.cpp'), 'w') as f: + f.write(runtime_module.get_source()) + with open(os.path.join(out_dir, 'tvm_cuda.cu'), 'w') as f: + f.write(runtime_cuda_module.get_source()) + + +def dump_relay_cuda_code(ir_module, params=None, out_dir: str = './outs', opt_level=3): + import tvm.relay + import tvm.target + with tvm.transform.PassContext(opt_level=opt_level): + graph_module = tvm.relay.build(ir_module, target='cuda', target_host=tvm.target.Target('c'), params=params) + # graph_module = tvm.relay.build(ir_module, target='cuda') + dump_code(graph_module, out_dir) + + +def autotvm_tune(ir_module: tvm.ir.IRModule, params: Dict[str, tvm.nd.NDArray], target: tvm.target.Target, out_dir: str, tuner_name='ga', num_trial=1000) -> None: + from tvm import autotvm + import tvm.contrib.graph_executor + lib_path = os.path.join(out_dir, 'lib.so') + + log_file = os.path.join(out_dir, 'records.json') + if not os.path.exists(log_file): + tasks: List[autotvm.task.Task] = autotvm.task.extract_from_program(ir_module, params, target=target) + with open(os.path.join(out_dir, 'tasks.txt'), 'w') as f: + for task_idx, task in enumerate(tasks): + f.write('task {}\n{}\n\n'.format(task_idx, task)) + + temp_log_file = log_file + '.tmp' + open(temp_log_file, 'a').close() # in case no tunable operators + with Timer(msg='AutoTVM tuning of {} tasks'.format(len(tasks)), file=os.path.join(out_dir, 'tuning_time.txt')): + for task_idx, task in enumerate(tasks): + if tuner_name == 'xgb': + tuner = autotvm.tuner.XGBTuner(task) + elif tuner_name == 'ga': + tuner = autotvm.tuner.GATuner(task) + else: + raise ValueError(tuner_name) + num_trial = min(num_trial, len(task.config_space)) + tuner.tune( + n_trial=num_trial, + measure_option=autotvm.measure_option( + builder=autotvm.LocalBuilder(timeout=10), + runner=autotvm.LocalRunner(number=20, repeat=3, timeout=4, min_repeat_ms=150), + ), + callbacks=[ + autotvm.callback.progress_bar(num_trial, f'[Task {task_idx:>2}/{len(tasks):<2}]'), + autotvm.callback.log_to_file(temp_log_file) + ] + ) + autotvm.record.pick_best(temp_log_file, log_file) + # os.remove(temp_log_file) + + with autotvm.apply_history_best(log_file): + with tvm.transform.PassContext(opt_level=3): + lib = relay.build(ir_module, target=target, params=params) + lib.export_library(lib_path) + dump_code(lib, out_dir) + + +def ansor_tune(ir_module: tvm.ir.IRModule, params: Dict[str, tvm.nd.NDArray], target: tvm.target.Target, out_dir: str, num_trial_per_task=800): + from tvm import auto_scheduler + log_file = os.path.join(out_dir, 'records.json') + lib_path = os.path.join(out_dir, 'lib.so') + + pair = auto_scheduler.extract_tasks(ir_module, params, target) + tasks: List[auto_scheduler.SearchTask] = pair[0] + task_weights: List[int] = pair[1] + + if not os.path.exists(log_file): + with open(os.path.join(out_dir, 'tasks.txt'), 'w') as f: + for task_idx, task in enumerate(tasks): + f.write('task {} (key {})\n{}\n\n'.format(task_idx, task.workload_key, task.compute_dag)) + temp_log_file = log_file + '.temp' + tune_option = auto_scheduler.TuningOptions( + num_measure_trials=num_trial_per_task * len(tasks), + measure_callbacks=[ + auto_scheduler.RecordToFile(temp_log_file) + ] + ) + tuner = auto_scheduler.TaskScheduler(tasks, task_weights, callbacks=[ + auto_scheduler.task_scheduler.PrintTableInfo(), + auto_scheduler.task_scheduler.LogEstimatedLatency(os.path.join(out_dir, 'estimated_latency.csv')) + ]) + with Timer(msg='Ansor tuning of {} tasks'.format(len(tasks)), file=os.path.join(out_dir, 'tuning_time.txt')): + tuner.tune(tune_option) + os.rename(temp_log_file, log_file) + with auto_scheduler.ApplyHistoryBest(log_file): + with tvm.transform.PassContext(opt_level=3, config={'relay.backend.use_auto_scheduler': True}): + lib = relay.build(ir_module, target, params=params) + lib.export_library(lib_path) + dump_code(lib, out_dir) + + +def build_ir_module(ir_module: tvm.ir.IRModule, params: Dict[str, tvm.nd.NDArray], target: tvm.target.Target, out_dir: str): + lib_path = os.path.join(out_dir, 'lib.so') + with tvm.transform.PassContext(opt_level=3): + lib = relay.build(ir_module, target, params=params) + lib.export_library(lib_path) + dump_code(lib, out_dir) + + +def tvm_graph_module_from_onnx(onnx_model_path: str, input_shapes: Optional[Dict[str, List[int]]], tune_autotvm=False, tune_ansor=False, tune_trial_per_task=800) -> GraphModule: + # determine output dir + if tune_autotvm and tune_ansor: + raise ValueError('Can not tune network with ansor and autotvm at the same time.') + if tune_autotvm: + tuner_name = 'autotvm' + elif tune_ansor: + tuner_name = 'ansor' + else: + tuner_name = 'notune' + model_name = os.path.basename(onnx_model_path).rsplit('.', 1)[0] + cache_dir = hidet_cache_dir(category='tvm_cache') + hash_key = (onnx_model_path + str(input_shapes) + str(tune_trial_per_task)) + out_dir = os.path.join(cache_dir, f'{model_name}_{tuner_name}_{sha256(hash_key.encode()).hexdigest()[:6]}') + os.makedirs(out_dir, exist_ok=True) + + lib_path = os.path.join(out_dir, 'lib.so') + if not os.path.exists(lib_path): + onnx_model = onnx.load_model(onnx_model_path) + ir_module, params = relay.frontend.from_onnx(onnx_model, input_shapes, dtype='float32') + target = tvm.target.cuda(arch='sm_{}{}'.format(*hidet.utils.cuda.query_compute_capability())) + with open(os.path.join(out_dir, 'relay_model.txt'), 'w') as f: + f.write(str(ir_module)) + with open(os.path.join(out_dir, 'model_info.txt'), 'w') as f: + lines = [ + 'model: {}'.format(onnx_model_path), + 'inputs: {}'.format(str(input_shapes)), + 'ansor: {}'.format(tune_ansor), + 'autotvm: {}'.format(tune_autotvm), + 'trial per task: {}'.format(tune_trial_per_task) + ] + f.write('\n'.join(lines)) + if tune_autotvm: + autotvm_tune(ir_module, params, target, out_dir=out_dir, num_trial=tune_trial_per_task) + elif tune_ansor: + ansor_tune(ir_module, params, target, out_dir=out_dir, num_trial_per_task=tune_trial_per_task) + else: + build_ir_module(ir_module, params, target, out_dir=out_dir) + assert os.path.exists(lib_path), 'Failed to generate lib for model {}.'.format(onnx_model_path) + lib = tvm.runtime.load_module(lib_path) + device = tvm.cuda() + gmod = graph_executor.GraphModule(lib['default'](device)) + return gmod + + +def tvm_inference(gmod: GraphModule, inputs: Dict[str, Tensor]) -> List[Tensor]: + # currently, TVM does not support get output by name, thus return a list of outputs + for name, tensor in inputs.items(): + gmod.set_input(name, value=tvm.nd.array(tensor.cpu().numpy())) + gmod.run() + outputs = [] + for i in range(gmod.get_num_outputs()): + output: tvm.nd.NDArray = gmod.get_output(i) + outputs.append(hidet.array(output.numpy()).cuda()) + return outputs + + +def tvm_benchmark(gmod: GraphModule, dummy_inputs: Dict[str, Tensor], warmup=10, number=10, repeat=10) -> List[float]: + for name, tensor in dummy_inputs.items(): + gmod.set_input(name, value=tvm.nd.array(tensor.cpu().numpy())) + for i in range(warmup): + gmod.run() + results = [] + for i in range(repeat): + cuda.device_synchronize() + start_time = time.time() + for j in range(number): + gmod.run() + cuda.device_synchronize() + end_time = time.time() + results.append((end_time - start_time) * 1000 / number) + return results + + +def tvm_commit() -> str: + info = tvm.support.libinfo() + return info['GIT_COMMIT_HASH'] + + +if __name__ == '__main__': + print(tvm_commit()) + # # onnx_model_path = os.path.join(hidet_cache_dir('onnx'), 'resnet50-v1-7.onnx') + # # dummy_inputs = { + # # 'data': hidet.randn([1, 3, 224, 224], device='cuda') + # # } + # onnx_model_path = hidet_cache_file('onnx', 'bert-base-uncased.onnx') + # batch_size = 1 + # seq_length = 512 + # vocab_size = 30522 + # input_ids = np.random.randint(0, vocab_size, [batch_size, seq_length], dtype=np.int64) + # attention_mask = np.ones(shape=[batch_size, seq_length], dtype=np.int64) + # token_type_ids = np.zeros(shape=[batch_size, seq_length], dtype=np.int64) + # dummy_inputs = { + # 'input_ids': hidet.array(input_ids).cuda(), + # 'attention_mask': hidet.array(attention_mask).cuda(), + # 'token_type_ids': hidet.array(token_type_ids).cuda() + # } + # input_shapes = {key: tensor.shape for key, tensor in dummy_inputs.items()} + # # from hidet.tos.frontend import from_onnx + # # hidet_model = from_onnx(onnx_model_path) + # # hidet_output = hidet_model(*dummy_inputs.values()) + # # + # gmod = tvm_graph_module_from_onnx(onnx_model_path, input_shapes) + # tvm_output = tvm_inference(gmod, dummy_inputs)[0] + # + # # np.testing.assert_allclose(actual=hidet_output.cpu().numpy(), desired=tvm_output.cpu().numpy(), rtol=1e-5, atol=1e-5) + # + # # print(tvm_benchmark(gmod, dummy_inputs)) + + diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..e3ac913 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,49 @@ +################################################################################ +# Necessary packages +################################################################################ +gitpython +numpy +sympy + +# used for query available memory +psutil + +# used for print table +tabulate + +# python tests +pytest + +# show progress bar +tqdm + +# used to annotate the scope of events in host process, which can be visualized +# in Nsight System. +nvtx + +# for onnx frontend +onnx==1.10.2 + +################################################################################ +# Optional packages +################################################################################ +--extra-index-url https://download.pytorch.org/whl/cu115 +torch==1.11 +torchvision==0.12 + +# for language model converting +transformers==4.19.2 +transformers[onnx] + +# for onnx runtime baseline +onnxruntime-gpu==1.11.1 + +# for tensor rt baseline +--extra-index-url https://pypi.ngc.nvidia.com +nvidia-tensorrt==8.2.5.1 + +# for tvm tuning +decorator +xgboost==1.5.0 +tornado +cloudpickle diff --git a/src/hidet/cuda_api.cpp b/src/hidet/cuda_api.cpp new file mode 100644 index 0000000..bff76af --- /dev/null +++ b/src/hidet/cuda_api.cpp @@ -0,0 +1,251 @@ +#include +#include +#include +#include +#include +#include +#include +#include +#include + +struct CurandContext { + curandGenerator_t generator{}; + CurandContext() { + unsigned long long seed = time(nullptr) ^ clock(); + CURAND_CALL(curandCreateGenerator(&generator, CURAND_RNG_PSEUDO_DEFAULT)); + CURAND_CALL(curandSetPseudoRandomGeneratorSeed(generator, seed)); + } + + static CurandContext* global() { + static CurandContext ctx; + return &ctx; + } +}; + +DLL void hidet_cuda_mem_info(uint64_t *free, uint64_t *total) { + API_BEGIN(); + CUDA_CALL(cudaMemGetInfo(free, total)); + API_END(); +} + +DLL uint64_t hidet_cuda_malloc_async(uint64_t bytes) { + API_BEGIN(); + void *ptr; + cudaError_t status = cudaMallocAsync(&ptr, bytes, nullptr); + if(status == cudaErrorMemoryAllocation) { + // out of memory + return 0; + } + CUDA_CALL(status); + return reinterpret_cast(ptr); + API_END(0); +} + +DLL uint64_t hidet_cuda_malloc_host(uint64_t bytes) { + API_BEGIN(); + void* ptr; + CUDA_CALL(cudaMallocHost(&ptr, bytes)); + return reinterpret_cast(ptr); + API_END(0); +} + +DLL void hidet_cuda_free_async(uint64_t addr) { + API_BEGIN(); + CUDA_CALL(cudaFreeAsync(reinterpret_cast(addr), nullptr)); + API_END(); +} + +DLL void hidet_cuda_free_host(uint64_t addr) { + API_BEGIN(); + CUDA_CALL(cudaFreeHost(reinterpret_cast(addr))); +// auto status = cudaFreeHost(reinterpret_cast(addr)); +// if(status != cudaSuccess) { +// fprintf(stderr, "Can not free host memory %p\n", reinterpret_cast(addr)); +// } + API_END(); +} + +DLL void hidet_cuda_memset_async(uint64_t addr, uint64_t bytes, uint8_t value) { + API_BEGIN(); + CUDA_CALL(cudaMemsetAsync(reinterpret_cast(addr), value, bytes, nullptr)); + API_END(); +} + +DLL void hidet_cuda_memcpy_async(uint64_t src, uint64_t dst, uint64_t bytes, uint32_t kind) { + API_BEGIN(); + /*! + * kind: + * cudaMemcpyHostToHost = 0, + * cudaMemcpyHostToDevice = 1, + * cudaMemcpyDeviceToHost = 2, + * cudaMemcpyDeviceToDevice = 3, + */ + CUDA_CALL(cudaMemcpyAsync(reinterpret_cast(dst), reinterpret_cast(src), bytes, cudaMemcpyKind(kind), nullptr)); + API_END(); +} + +DLL void hidet_cuda_device_synchronize() { + API_BEGIN(); + CUDA_CALL(cudaDeviceSynchronize()); + API_END(); +} + +DLL void hidet_curand_generate_uniform(uint64_t addr, uint64_t size) { + API_BEGIN(); + CURAND_CALL(curandGenerateUniform(CurandContext::global()->generator, reinterpret_cast(addr), size)); + API_END(); +} + +DLL void hidet_curand_generate_normal(uint64_t addr, uint64_t size, float mean, float stddev) { + API_BEGIN(); + // This function only support to generate even number of random numbers. We work around this limitation by up round to a multiple of 2. + // this usually will not trigger error because the memory allocation on cuda is 256 bytes aligned. + if(size & 1) { + size += 1; + } + CURAND_CALL(curandGenerateNormal(CurandContext::global()->generator, reinterpret_cast(addr), size, mean, stddev)); + API_END(); +} + +DLL void hidet_cuda_mem_pool_trim_to(uint64_t min_bytes_to_keep) { + API_BEGIN(); + cudaMemPool_t pool; + CUDA_CALL(cudaDeviceGetDefaultMemPool(&pool, 0)); + CUDA_CALL(cudaMemPoolTrimTo(pool, min_bytes_to_keep)); + API_END(); +} + +DLL uint64_t hidet_cuda_stream_create() { + API_BEGIN(); + cudaStream_t stream; + CUDA_CALL(cudaStreamCreate(&stream)); + return reinterpret_cast(stream); + API_END(0); +} + + +DLL void hidet_cuda_stream_destroy(uint64_t stream) { + API_BEGIN(); + CUDA_CALL(cudaStreamDestroy(reinterpret_cast(stream))); + API_END(); +} + +DLL void hidet_cuda_stream_synchronize(uint64_t stream) { + API_BEGIN(); + CUDA_CALL(cudaStreamSynchronize(reinterpret_cast(stream))); + API_END(); +} + +DLL uint64_t hidet_cuda_event_create() { + API_BEGIN(); + cudaEvent_t event; + CUDA_CALL(cudaEventCreate(&event)); + return reinterpret_cast(event); + API_END(0); +} + +DLL void hidet_cuda_event_destroy(uint64_t handle) { + API_BEGIN(); + auto event = reinterpret_cast(handle); + CUDA_CALL(cudaEventDestroy(event)); + API_END(); +} + +DLL float hidet_cuda_event_elapsed_time(uint64_t start, uint64_t end) { + API_BEGIN(); + float latency; + CUDA_CALL(cudaEventElapsedTime(&latency, reinterpret_cast(start), reinterpret_cast(end))); + return latency; + API_END(0.0); +} + +DLL void hidet_cuda_event_record(uint64_t event_handle, uint64_t stream_handle) { + API_BEGIN(); + CUDA_CALL(cudaEventRecord(reinterpret_cast(event_handle), reinterpret_cast(stream_handle))); + API_END(); +} + + +DLL uint64_t hidet_cuda_graph_create() { + API_BEGIN(); + cudaGraph_t graph; + CUDA_CALL(cudaGraphCreate(&graph, 0)); + return reinterpret_cast(graph); + API_END(0); +} + +DLL void hidet_cuda_graph_destroy(uint64_t handle) { + API_BEGIN(); + CUDA_CALL(cudaGraphDestroy(reinterpret_cast(handle))); + API_END(); +} + +DLL void hidet_cuda_stream_begin_capture(uint64_t stream_handle) { + API_BEGIN(); + CUDA_CALL(cudaStreamBeginCapture(reinterpret_cast(stream_handle), cudaStreamCaptureModeThreadLocal)); + API_END(); +} + +DLL uint64_t hidet_cuda_stream_end_capture(uint64_t stream_handle) { + API_BEGIN(); + cudaGraph_t graph; + CUDA_CALL(cudaStreamEndCapture(reinterpret_cast(stream_handle), &graph)); + return reinterpret_cast(graph); + API_END(0); +} + +DLL uint64_t hidet_cuda_graph_instantiate(uint64_t graph_handle) { + API_BEGIN(); + auto graph = reinterpret_cast(graph_handle); + cudaGraphExec_t graph_exec; + CUDA_CALL(cudaGraphInstantiate(&graph_exec, graph, nullptr, nullptr, 0)); + return reinterpret_cast(graph_exec); + API_END(0); +} + +DLL void hidet_cuda_graph_exec_launch(uint64_t graph_exec_handle, uint64_t stream_handle) { + API_BEGIN(); + CUDA_CALL(cudaGraphLaunch(reinterpret_cast(graph_exec_handle), + reinterpret_cast(stream_handle))); + API_END(); +} + +DLL void hidet_cuda_graph_exec_destroy(uint64_t graph_exec_handle) { + API_BEGIN(); + CUDA_CALL(cudaGraphExecDestroy(reinterpret_cast(graph_exec_handle))); + API_END(); +} + +DLL void hidet_cuda_profiler_start() { + API_BEGIN(); + CUDA_CALL(cudaProfilerStart()); + API_END(); +} + +DLL void hidet_cuda_profiler_stop() { + API_BEGIN(); + CUDA_CALL(cudaProfilerStop()); + API_END(); +} + +DLL uint64_t hidet_cuda_get_device_property(uint64_t device_id, const char *property_name) { + API_BEGIN(); + static bool queried = false; + static cudaDeviceProp prop{}; + if(!queried) { + CUDA_CALL(cudaGetDeviceProperties(&prop, device_id)); + } + + std::string name(property_name); + if(name == "multiProcessorCount") { + return prop.multiProcessorCount; + } else if(name == "major") { + return prop.major; + } else if(name == "minor") { + return prop.minor; + } else { + std::cout << "Can not recognize property name: " << name << std::endl; + return 0; + } + API_END(0); +} diff --git a/src/hidet/cuda_kernels.cu b/src/hidet/cuda_kernels.cu new file mode 100644 index 0000000..08a70ae --- /dev/null +++ b/src/hidet/cuda_kernels.cu @@ -0,0 +1,28 @@ +#include "hidet/cuda_utils.h" + +template +static __global__ void fill_value_kernel(dtype* dst, uint64_t num_elements, dtype fill_value) { + auto stride = gridDim.x * blockDim.x; + for(auto idx = blockIdx.x * blockDim.x + threadIdx.x; idx < num_elements; idx += stride) { + dst[idx] = fill_value; + } +} + +template +static void fill_value_generic(uint64_t addr, uint64_t num_elements, dtype fill_value) { + dim3 block(512); + dim3 grid((num_elements + block.x - 1) / block.x); + fill_value_kernel<<>>(reinterpret_cast(addr), num_elements, fill_value); +} + +DLL void hidet_cuda_fill_value_float32(uint64_t addr, uint64_t num_elements, float fill_value) { + fill_value_generic(addr, num_elements, fill_value); +} + +DLL void hidet_cuda_fill_value_int32(uint32_t addr, uint64_t num_elements, int32_t fill_value) { + fill_value_generic(addr, num_elements, fill_value); +} + +DLL void hidet_cuda_fill_value_int64(uint64_t addr, uint64_t num_elements, int64_t fill_value) { + fill_value_generic(addr, num_elements, fill_value); +} diff --git a/src/hidet/logging.cpp b/src/hidet/logging.cpp new file mode 100644 index 0000000..de85df2 --- /dev/null +++ b/src/hidet/logging.cpp @@ -0,0 +1,26 @@ +#include + + +ErrorState* ErrorState::global() { + static thread_local ErrorState instance; + return &instance; +} + +DLL void hidet_set_last_error(const char *msg) { + ErrorState* state = ErrorState::global(); + if(state->has_error) { + fprintf(stderr, "Warning: hidet error state has been override: %s\n", state->error_msg.c_str()); + } + state->has_error = true; + state->error_msg = msg; +} + +DLL const char * hidet_get_last_error() { + ErrorState* state = ErrorState::global(); + if(state->has_error) { + state->has_error = false; + return state->error_msg.c_str(); + } else { + return nullptr; + } +} diff --git a/src/hidet/packedfunc.cpp b/src/hidet/packedfunc.cpp new file mode 100644 index 0000000..11eecad --- /dev/null +++ b/src/hidet/packedfunc.cpp @@ -0,0 +1,49 @@ +#include +#include + +#include +#include +#include +#include + +extern "C" { + +DLL void CallPackedFunc(PackedFunc func, void** args) { + auto f = PackedFunc_t(func.func_pointer); + f(func.num_args, func.arg_types, args); +} + +DLL void ProfilePackedFunc(PackedFunc func, void** args, int warmup, int number, int repeat, float* results) { + cudaEvent_t start, end; + CUDA_CALL(cudaEventCreate(&start)); + CUDA_CALL(cudaEventCreate(&end)); + + for(int i = 0; i < warmup; i++) { + CallPackedFunc(func, args); + } + + for(int i = 0; i < repeat; i++) { + CUDA_CALL(cudaDeviceSynchronize()); + CUDA_CALL(cudaEventRecord(start)); + for(int j = 0; j < number; j++) { + CallPackedFunc(func, args); + } + CUDA_CALL(cudaEventRecord(end)); + CUDA_CALL(cudaDeviceSynchronize()); + CUDA_CALL(cudaEventElapsedTime(results + i, start, end)); // results[i] in milliseconds. + } + + CUDA_CALL(cudaEventDestroy(start)); + CUDA_CALL(cudaEventDestroy(end)); +} + +DLL void packed_func_sample(int num_args, int *arg_types, void** args) { + assert(num_args == 1); + int type_code = arg_types[0]; + assert(type_code == INT32); + int* arg = static_cast(args[0]); + printf("hello, world!\n%d\n", *arg); +} + +} + diff --git a/src/hidet/runtime/cuda_context.cpp b/src/hidet/runtime/cuda_context.cpp new file mode 100644 index 0000000..494a291 --- /dev/null +++ b/src/hidet/runtime/cuda_context.cpp @@ -0,0 +1,14 @@ +#include + +CudaContext *CudaContext::global() { + static thread_local CudaContext instance; + return &instance; +} + +DLL void set_cuda_stream(cudaStream_t stream) { + CudaContext::global()->stream = stream; +} + +DLL cudaStream_t get_cuda_stream() { + return CudaContext::global()->stream; +}