diff --git a/README.md b/README.md index 24a23ff5..428af3fc 100644 --- a/README.md +++ b/README.md @@ -29,6 +29,7 @@ A list of Examples contains how to use kernl with Pytorch. | **XNLI classification**: classification with / without optimizations (`Roberta` + `XNLI` classification task) | [link](https://github.com/ELS-RD/kernl/blob/main/tutorial/bert%20e2e.ipynb) | | **Text generation**: with/without optimizations (`T5`) | [link](https://github.com/ELS-RD/kernl/blob/main/tutorial/t5%20e2e.ipynb) | | **Transcription generation**: with/without optimizations (`Whisper`) | [link](https://github.com/ELS-RD/kernl/blob/main/experimental/whisper/speedup.ipynb) | +| **Llama version 2 optimization by kernel fusion | [link](https://github.com/ELS-RD/kernl/blob/main/experimental/llama-v2) | ## Installation diff --git a/experimental/llama-v2/.gitignore b/experimental/llama-v2/.gitignore new file mode 100755 index 00000000..d701b487 --- /dev/null +++ b/experimental/llama-v2/.gitignore @@ -0,0 +1,165 @@ +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ +.pytest_cache/ +cover/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +.pybuilder/ +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# pyenv +# For a library or package, you might want to ignore these files since the code is +# intended to run in multiple environments; otherwise, check them in: +# .python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. +# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. +#Pipfile.lock + +# poetry +# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. +# This is especially recommended for binary packages to ensure reproducibility, and is more +# commonly ignored for libraries. +# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control +#poetry.lock + +# pdm +# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. +#pdm.lock +# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it +# in version control. +# https://pdm.fming.dev/#use-with-ide +.pdm.toml + +# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm +__pypackages__/ + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ + +# pytype static type analyzer +.pytype/ + +# Cython debug symbols +cython_debug/ + +# PyCharm +# JetBrains specific template is maintained in a separate JetBrains.gitignore that can +# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore +# and can be added to the global gitignore or merged into this file. For a more nuclear +# option (not recommended) you can uncomment the following to ignore the entire idea folder. +.idea/ +llama-2-7b/ +llama-2-13b/ +triton/ +*.chk +*.model \ No newline at end of file diff --git a/experimental/llama-v2/CODE_OF_CONDUCT.md b/experimental/llama-v2/CODE_OF_CONDUCT.md new file mode 100644 index 00000000..cf9dc244 --- /dev/null +++ b/experimental/llama-v2/CODE_OF_CONDUCT.md @@ -0,0 +1,80 @@ +# Code of Conduct + +## Our Pledge + +In the interest of fostering an open and welcoming environment, we as +contributors and maintainers pledge to make participation in our project and +our community a harassment-free experience for everyone, regardless of age, body +size, disability, ethnicity, sex characteristics, gender identity and expression, +level of experience, education, socio-economic status, nationality, personal +appearance, race, religion, or sexual identity and orientation. + +## Our Standards + +Examples of behavior that contributes to creating a positive environment +include: + +* Using welcoming and inclusive language +* Being respectful of differing viewpoints and experiences +* Gracefully accepting constructive criticism +* Focusing on what is best for the community +* Showing empathy towards other community members + +Examples of unacceptable behavior by participants include: + +* The use of sexualized language or imagery and unwelcome sexual attention or +advances +* Trolling, insulting/derogatory comments, and personal or political attacks +* Public or private harassment +* Publishing others' private information, such as a physical or electronic +address, without explicit permission +* Other conduct which could reasonably be considered inappropriate in a +professional setting + +## Our Responsibilities + +Project maintainers are responsible for clarifying the standards of acceptable +behavior and are expected to take appropriate and fair corrective action in +response to any instances of unacceptable behavior. + +Project maintainers have the right and responsibility to remove, edit, or +reject comments, commits, code, wiki edits, issues, and other contributions +that are not aligned to this Code of Conduct, or to ban temporarily or +permanently any contributor for other behaviors that they deem inappropriate, +threatening, offensive, or harmful. + +## Scope + +This Code of Conduct applies within all project spaces, and it also applies when +an individual is representing the project or its community in public spaces. +Examples of representing a project or community include using an official +project e-mail address, posting via an official social media account, or acting +as an appointed representative at an online or offline event. Representation of +a project may be further defined and clarified by project maintainers. + +This Code of Conduct also applies outside the project spaces when there is a +reasonable belief that an individual's behavior may have a negative impact on +the project or its community. + +## Enforcement + +Instances of abusive, harassing, or otherwise unacceptable behavior may be +reported by contacting the project team at . All +complaints will be reviewed and investigated and will result in a response that +is deemed necessary and appropriate to the circumstances. The project team is +obligated to maintain confidentiality with regard to the reporter of an incident. +Further details of specific enforcement policies may be posted separately. + +Project maintainers who do not follow or enforce the Code of Conduct in good +faith may face temporary or permanent repercussions as determined by other +members of the project's leadership. + +## Attribution + +This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4, +available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html + +[homepage]: https://www.contributor-covenant.org + +For answers to common questions about this code of conduct, see +https://www.contributor-covenant.org/faq \ No newline at end of file diff --git a/experimental/llama-v2/CONTRIBUTING.md b/experimental/llama-v2/CONTRIBUTING.md new file mode 100644 index 00000000..5eb507d6 --- /dev/null +++ b/experimental/llama-v2/CONTRIBUTING.md @@ -0,0 +1,31 @@ +# Contributing to Llama +We want to make contributing to this project as easy and transparent as +possible. + +## Pull Requests +We actively welcome your pull requests. + +1. Fork the repo and create your branch from `main`. +2. If you've added code that should be tested, add tests. +3. If you've changed APIs, update the documentation. +4. Ensure the test suite passes. +5. Make sure your code lints. +6. If you haven't already, complete the Contributor License Agreement ("CLA"). + +## Contributor License Agreement ("CLA") +In order to accept your pull request, we need you to submit a CLA. You only need +to do this once to work on any of Meta's open source projects. + +Complete your CLA here: + +## Issues +We use GitHub issues to track public bugs. Please ensure your description is +clear and has sufficient instructions to be able to reproduce the issue. + +Meta has a [bounty program](https://www.facebook.com/whitehat/) for the safe +disclosure of security bugs. In those cases, please go through the process +outlined on that page and do not file a public issue. + +## License +By contributing to Llama, you agree that your contributions will be licensed +under the LICENSE file in the root directory of this source tree. \ No newline at end of file diff --git a/experimental/llama-v2/LICENSE b/experimental/llama-v2/LICENSE new file mode 100644 index 00000000..51089e27 --- /dev/null +++ b/experimental/llama-v2/LICENSE @@ -0,0 +1,126 @@ +LLAMA 2 COMMUNITY LICENSE AGREEMENT +Llama 2 Version Release Date: July 18, 2023 + +"Agreement" means the terms and conditions for use, reproduction, distribution and +modification of the Llama Materials set forth herein. + +"Documentation" means the specifications, manuals and documentation +accompanying Llama 2 distributed by Meta at ai.meta.com/resources/models-and- +libraries/llama-downloads/. + +"Licensee" or "you" means you, or your employer or any other person or entity (if +you are entering into this Agreement on such person or entity's behalf), of the age +required under applicable laws, rules or regulations to provide legal consent and that +has legal authority to bind your employer or such other person or entity if you are +entering in this Agreement on their behalf. + +"Llama 2" means the foundational large language models and software and +algorithms, including machine-learning model code, trained model weights, +inference-enabling code, training-enabling code, fine-tuning enabling code and other +elements of the foregoing distributed by Meta at ai.meta.com/resources/models-and- +libraries/llama-downloads/. + +"Llama Materials" means, collectively, Meta's proprietary Llama 2 and +Documentation (and any portion thereof) made available under this Agreement. + +"Meta" or "we" means Meta Platforms Ireland Limited (if you are located in or, if you +are an entity, your principal place of business is in the EEA or Switzerland) and Meta +Platforms, Inc. (if you are located outside of the EEA or Switzerland). + +By clicking "I Accept" below or by using or distributing any portion or element of the +Llama Materials, you agree to be bound by this Agreement. + +1. License Rights and Redistribution. + + a. Grant of Rights. You are granted a non-exclusive, worldwide, non- +transferable and royalty-free limited license under Meta's intellectual property or +other rights owned by Meta embodied in the Llama Materials to use, reproduce, +distribute, copy, create derivative works of, and make modifications to the Llama +Materials. + + b. Redistribution and Use. + + i. If you distribute or make the Llama Materials, or any derivative works +thereof, available to a third party, you shall provide a copy of this Agreement to such +third party. + ii. If you receive Llama Materials, or any derivative works thereof, from +a Licensee as part of an integrated end user product, then Section 2 of this +Agreement will not apply to you. + + iii. You must retain in all copies of the Llama Materials that you +distribute the following attribution notice within a "Notice" text file distributed as a +part of such copies: "Llama 2 is licensed under the LLAMA 2 Community License, +Copyright (c) Meta Platforms, Inc. All Rights Reserved." + + iv. Your use of the Llama Materials must comply with applicable laws +and regulations (including trade compliance laws and regulations) and adhere to the +Acceptable Use Policy for the Llama Materials (available at +https://ai.meta.com/llama/use-policy), which is hereby incorporated by reference into +this Agreement. + + v. You will not use the Llama Materials or any output or results of the +Llama Materials to improve any other large language model (excluding Llama 2 or +derivative works thereof). + +2. Additional Commercial Terms. If, on the Llama 2 version release date, the +monthly active users of the products or services made available by or for Licensee, +or Licensee's affiliates, is greater than 700 million monthly active users in the +preceding calendar month, you must request a license from Meta, which Meta may +grant to you in its sole discretion, and you are not authorized to exercise any of the +rights under this Agreement unless or until Meta otherwise expressly grants you +such rights. + +3. Disclaimer of Warranty. UNLESS REQUIRED BY APPLICABLE LAW, THE +LLAMA MATERIALS AND ANY OUTPUT AND RESULTS THEREFROM ARE +PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, +EITHER EXPRESS OR IMPLIED, INCLUDING, WITHOUT LIMITATION, ANY +WARRANTIES OF TITLE, NON-INFRINGEMENT, MERCHANTABILITY, OR +FITNESS FOR A PARTICULAR PURPOSE. YOU ARE SOLELY RESPONSIBLE +FOR DETERMINING THE APPROPRIATENESS OF USING OR REDISTRIBUTING +THE LLAMA MATERIALS AND ASSUME ANY RISKS ASSOCIATED WITH YOUR +USE OF THE LLAMA MATERIALS AND ANY OUTPUT AND RESULTS. + +4. Limitation of Liability. IN NO EVENT WILL META OR ITS AFFILIATES BE +LIABLE UNDER ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, TORT, +NEGLIGENCE, PRODUCTS LIABILITY, OR OTHERWISE, ARISING OUT OF THIS +AGREEMENT, FOR ANY LOST PROFITS OR ANY INDIRECT, SPECIAL, +CONSEQUENTIAL, INCIDENTAL, EXEMPLARY OR PUNITIVE DAMAGES, EVEN +IF META OR ITS AFFILIATES HAVE BEEN ADVISED OF THE POSSIBILITY OF +ANY OF THE FOREGOING. + +5. Intellectual Property. + + a. No trademark licenses are granted under this Agreement, and in +connection with the Llama Materials, neither Meta nor Licensee may use any name +or mark owned by or associated with the other or any of its affiliates, except as +required for reasonable and customary use in describing and redistributing the +Llama Materials. + + b. Subject to Meta's ownership of Llama Materials and derivatives made by or +for Meta, with respect to any derivative works and modifications of the Llama +Materials that are made by you, as between you and Meta, you are and will be the +owner of such derivative works and modifications. + + c. If you institute litigation or other proceedings against Meta or any entity +(including a cross-claim or counterclaim in a lawsuit) alleging that the Llama +Materials or Llama 2 outputs or results, or any portion of any of the foregoing, +constitutes infringement of intellectual property or other rights owned or licensable +by you, then any licenses granted to you under this Agreement shall terminate as of +the date such litigation or claim is filed or instituted. You will indemnify and hold +harmless Meta from and against any claim by any third party arising out of or related +to your use or distribution of the Llama Materials. + +6. Term and Termination. The term of this Agreement will commence upon your +acceptance of this Agreement or access to the Llama Materials and will continue in +full force and effect until terminated in accordance with the terms and conditions +herein. Meta may terminate this Agreement if you are in breach of any term or +condition of this Agreement. Upon termination of this Agreement, you shall delete +and cease use of the Llama Materials. Sections 3, 4 and 7 shall survive the +termination of this Agreement. + +7. Governing Law and Jurisdiction. This Agreement will be governed and +construed under the laws of the State of California without regard to choice of law +principles, and the UN Convention on Contracts for the International Sale of Goods +does not apply to this Agreement. The courts of California shall have exclusive +jurisdiction of any dispute arising out of this Agreement. + diff --git a/experimental/llama-v2/MODEL_CARD.md b/experimental/llama-v2/MODEL_CARD.md new file mode 100644 index 00000000..0a2718c1 --- /dev/null +++ b/experimental/llama-v2/MODEL_CARD.md @@ -0,0 +1,98 @@ +# **Model Details** + +Meta developed and released the Llama 2 family of large language models (LLMs), a collection of pretrained and fine-tuned generative text models ranging in scale from 7 billion to 70 billion parameters. Our fine-tuned LLMs, called Llama-2-Chat, are optimized for dialogue use cases. Llama-2-Chat models outperform open-source chat models on most benchmarks we tested, and in our human evaluations for helpfulness and safety, are on par with some popular closed-source models like ChatGPT and PaLM. + +**Model Developers** Meta + +**Variations** Llama 2 comes in a range of parameter sizes — 7B, 13B, and 70B — as well as pretrained and fine-tuned variations. + +**Input** Models input text only. + +**Output** Models generate text only. + +**Model Architecture** Llama 2 is an auto-regressive language model that uses an optimized transformer architecture. The tuned versions use supervised fine-tuning (SFT) and reinforcement learning with human feedback (RLHF) to align to human preferences for helpfulness and safety. + +||Training Data|Params|Content Length|GQA|Tokens|LR| +|---|---|---|---|---|---|---| +Llama 2|*A new mix of publicly available online data*|7B|4k|✗|2.0T|3.0 x 10-4 +Llama 2|*A new mix of publicly available online data*|13B|4k|✗|2.0T|3.0 x 10-4 +Llama 2|*A new mix of publicly available online data*|70B|4k|✔|2.0T|1.5 x 10-4 + +**Llama 2 family of models.** Token counts refer to pretraining data only. All models are trained with a global batch-size of 4M tokens. The 70B version uses Grouped-Query Attention (GQA) for improved inference scalability. + +**Model Dates** Llama 2 was trained between January 2023 and July 2023. + +**Status** This is a static model trained on an offline dataset. Future versions of the tuned models will be released as we improve model safety with community feedback. + +**License** A custom commercial license is available at: [https://ai.meta.com/resources/models-and-libraries/llama-downloads/](https://ai.meta.com/resources/models-and-libraries/llama-downloads/) + +**Research Paper** More information can be found in the paper "Llama-2: Open Foundation and Fine-tuned Chat Models", available at https://ai.meta.com/research/publications/llama-2-open-foundation-and-fine-tuned-chat-models/. + +**Where to send questions or comments about the model** Instructions on how to provide feedback or comments on the model can be found in the model [README](README.md). + +# **Intended Use** +**Intended Use Cases** Llama 2 is intended for commercial and research use in English. Tuned models are intended for assistant-like chat, whereas pretrained models can be adapted for a variety of natural language generation tasks. + +**Out-of-scope Uses** Use in any manner that violates applicable laws or regulations (including trade compliance laws). Use in languages other than English. Use in any other way that is prohibited by the Acceptable Use Policy and Licensing Agreement for Llama 2. + +# **Hardware and Software** +**Training Factors** We used custom training libraries, Meta's Research Super Cluster, and production clusters for pretraining. Fine-tuning, annotation, and evaluation were also performed on third-party cloud compute. + +**Carbon Footprint** Pretraining utilized a cumulative 3.3M GPU hours of computation on hardware of type A100-80GB (TDP of 350-400W). Estimated total emissions were 539 tCO2eq, 100% of which were offset by Meta’s sustainability program. + +||Time (GPU hours)|Power Consumption (W)|Carbon Emitted(tCO2eq)| +|---|---|---|---| +|Llama 2 7B|184320|400|31.22| +|Llama 2 13B|368640|400|62.44| +|Llama 2 70B|1720320|400|291.42| +|Total|3311616||539.00| + +**CO2 emissions during pretraining.** Time: total GPU time required for training each model. Power Consumption: peak power capacity per GPU device for the GPUs used adjusted for power usage efficiency. 100% of the emissions are directly offset by Meta's sustainability program, and because we are openly releasing these models, the pretraining costs do not need to be incurred by others. + +# **Training Data** +**Overview** Llama 2 was pretrained on 2 trillion tokens of data from publicly available sources. The fine-tuning data includes publicly available instruction datasets, as well as over one million new human-annotated examples. Neither the pretraining nor the fine-tuning datasets include Meta user data. + +**Data Freshness** The pretraining data has a cutoff of September 2022, but some tuning data is more recent, up to July 2023. + +# **Evaluation Results** + +In this section, we report the results for the Llama 1 and Llama 2 models on standard academic benchmarks. +For all the evaluations, we use our internal evaluations library. + +|Model|Size|Code|Commonsense Reasoning|World Knowledge|Reading Comprehension|Math|MMLU|BBH|AGI Eval| +|---|---|---|---|---|---|---|---|---|---| +|Llama 1|7B|14.1|60.8|46.2|58.5|6.95|35.1|30.3|23.9| +|Llama 1|13B|18.9|66.1|52.6|62.3|10.9|46.9|37.0|33.9| +|Llama 1|33B|26.0|70.0|58.4|67.6|21.4|57.8|39.8|41.7| +|Llama 1|65B|30.7|70.7|60.5|68.6|30.8|63.4|43.5|47.6| +|Llama 2|7B|16.8|63.9|48.9|61.3|14.6|45.3|32.6|29.3| +|Llama 2|13B|24.5|66.9|55.4|65.8|28.7|54.8|39.4|39.1| +|Llama 2|70B|**37.5**|**71.9**|**63.6**|**69.4**|**35.2**|**68.9**|**51.2**|**54.2**| + +**Overall performance on grouped academic benchmarks.** *Code:* We report the average pass@1 scores of our models on HumanEval and MBPP. *Commonsense Reasoning:* We report the average of PIQA, SIQA, HellaSwag, WinoGrande, ARC easy and challenge, OpenBookQA, and CommonsenseQA. We report 7-shot results for CommonSenseQA and 0-shot results for all other benchmarks. *World Knowledge:* We evaluate the 5-shot performance on NaturalQuestions and TriviaQA and report the average. *Reading Comprehension:* For reading comprehension, we report the 0-shot average on SQuAD, QuAC, and BoolQ. *MATH:* We report the average of the GSM8K (8 shot) and MATH (4 shot) benchmarks at top 1. + +|||TruthfulQA|Toxigen| +|---|---|---|---| +|Llama 1|7B|27.42|23.00| +|Llama 1|13B|41.74|23.08| +|Llama 1|33B|44.19|22.57| +|Llama 1|65B|48.71|21.77| +|Llama 2|7B|33.29|**21.25**| +|Llama 2|13B|41.86|26.10| +|Llama 2|70B|**50.18**|24.60| + +**Evaluation of pretrained LLMs on automatic safety benchmarks.** For TruthfulQA, we present the percentage of generations that are both truthful and informative (the higher the better). For ToxiGen, we present the percentage of toxic generations (the smaller the better). + + +|||TruthfulQA|Toxigen| +|---|---|---|---| +|Llama-2-Chat|7B|57.04|**0.00**| +|Llama-2-Chat|13B|62.18|**0.00**| +|Llama-2-Chat|70B|**64.14**|0.01| + +**Evaluation of fine-tuned LLMs on different safety datasets.** Same metric definitions as above. + +# **Ethical Considerations and Limitations** +Llama 2 is a new technology that carries risks with use. Testing conducted to date has been in English, and has not covered, nor could it cover all scenarios. For these reasons, as with all LLMs, Llama 2’s potential outputs cannot be predicted in advance, and the model may in some instances produce inaccurate, biased or other objectionable responses to user prompts. Therefore, before deploying any applications of Llama 2, developers should perform safety testing and tuning tailored to their specific applications of the model. + +Please see the Responsible Use Guide available at [https://ai.meta.com/llama/responsible-use-guide/](https://ai.meta.com/llama/responsible-use-guide/) diff --git a/experimental/llama-v2/README.md b/experimental/llama-v2/README.md new file mode 100755 index 00000000..4ab90f10 --- /dev/null +++ b/experimental/llama-v2/README.md @@ -0,0 +1,16 @@ +# LLama 2 optimization + +The purpose of this experiment is to improve Llama v2 performance by fusing kernels together. +Note that code should be run on top of Triton commit [69a806c](https://github.com/openai/triton/tree/69a806c745aa604fec6bd317628d3dc293aa1e46). +Main triton branch has some CPU overhead probably because of the add of AMD GPUs support (and some new mechanism to load the right backend). + +We measured on consumer grade GPU 3090 RTX a speed up from 30 to 54 tokens/sec for 7B model. +The purpose is not to get extreme perf, there are many easy things to do to get even better performances. +Also, think about replacing the triton jit part by a lighter launcher if you want to push perf higher. + + +We tried to keep Llama model code as close as possible to the original one. +In particular we removed all the multi GPU support code and replaced ut by classical local execution function (like Linear module). +It makes things more simple to run, less overhead for PyTorch benchmark, and at the end it is easier to understand. + +More details about what is done and how it works are in our article [here](). diff --git a/experimental/llama-v2/Responsible-Use-Guide.pdf b/experimental/llama-v2/Responsible-Use-Guide.pdf new file mode 100644 index 00000000..e65e5d1c Binary files /dev/null and b/experimental/llama-v2/Responsible-Use-Guide.pdf differ diff --git a/experimental/llama-v2/USE_POLICY.md b/experimental/llama-v2/USE_POLICY.md new file mode 100644 index 00000000..abbcc199 --- /dev/null +++ b/experimental/llama-v2/USE_POLICY.md @@ -0,0 +1,50 @@ +# Llama 2 Acceptable Use Policy + +Meta is committed to promoting safe and fair use of its tools and features, including Llama 2. If you access or use Llama 2, you agree to this Acceptable Use Policy (“Policy”). The most recent copy of this policy can be found at [ai.meta.com/llama/use-policy](http://ai.meta.com/llama/use-policy). + +## Prohibited Uses +We want everyone to use Llama 2 safely and responsibly. You agree you will not use, or allow others to use, Llama 2 to: + +1. Violate the law or others’ rights, including to: + 1. Engage in, promote, generate, contribute to, encourage, plan, incite, or further illegal or unlawful activity or content, such as: + 1. Violence or terrorism + 2. Exploitation or harm to children, including the solicitation, creation, acquisition, or dissemination of child exploitative content or failure to report Child Sexual Abuse Material + 3. Human trafficking, exploitation, and sexual violence + 4. The illegal distribution of information or materials to minors, including obscene materials, or failure to employ legally required age-gating in connection with such information or materials. + 5. Sexual solicitation + 6. Any other criminal activity + 2. Engage in, promote, incite, or facilitate the harassment, abuse, threatening, or bullying of individuals or groups of individuals + 3. Engage in, promote, incite, or facilitate discrimination or other unlawful or harmful conduct in the provision of employment, employment benefits, credit, housing, other economic benefits, or other essential goods and services + 4. Engage in the unauthorized or unlicensed practice of any profession including, but not limited to, financial, legal, medical/health, or related professional practices + 5. Collect, process, disclose, generate, or infer health, demographic, or other sensitive personal or private information about individuals without rights and consents required by applicable laws + 6. Engage in or facilitate any action or generate any content that infringes, misappropriates, or otherwise violates any third-party rights, including the outputs or results of any products or services using the Llama 2 Materials + 7. Create, generate, or facilitate the creation of malicious code, malware, computer viruses or do anything else that could disable, overburden, interfere with or impair the proper working, integrity, operation or appearance of a website or computer system + + + +2. Engage in, promote, incite, facilitate, or assist in the planning or development of activities that present a risk of death or bodily harm to individuals, including use of Llama 2 related to the following: + 1. Military, warfare, nuclear industries or applications, espionage, use for materials or activities that are subject to the International Traffic Arms Regulations (ITAR) maintained by the United States Department of State + 2. Guns and illegal weapons (including weapon development) + 3. Illegal drugs and regulated/controlled substances + 4. Operation of critical infrastructure, transportation technologies, or heavy machinery + 5. Self-harm or harm to others, including suicide, cutting, and eating disorders + 6. Any content intended to incite or promote violence, abuse, or any infliction of bodily harm to an individual + + + +3. Intentionally deceive or mislead others, including use of Llama 2 related to the following: + 1. Generating, promoting, or furthering fraud or the creation or promotion of disinformation + 2. Generating, promoting, or furthering defamatory content, including the creation of defamatory statements, images, or other content + 3. Generating, promoting, or further distributing spam + 4. Impersonating another individual without consent, authorization, or legal right + 5. Representing that the use of Llama 2 or outputs are human-generated + 6. Generating or facilitating false online engagement, including fake reviews and other means of fake online engagement +4. Fail to appropriately disclose to end users any known dangers of your AI system + +Please report any violation of this Policy, software “bug,” or other problems that could lead to a violation of this Policy through one of the following means: + +* Reporting issues with the model: [github.com/facebookresearch/llama](http://github.com/facebookresearch/llama) +* Reporting risky content generated by the model: [developers.facebook.com/llama_output_feedback](http://developers.facebook.com/llama_output_feedback) +* Reporting bugs and security concerns: [facebook.com/whitehat/info](http://facebook.com/whitehat/info) +* Reporting violations of the Acceptable Use Policy or unlicensed uses of Llama: [LlamaUseReport@meta.com](mailto:LlamaUseReport@meta.com) + diff --git a/experimental/llama-v2/__init__.py b/experimental/llama-v2/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/experimental/llama-v2/download.sh b/experimental/llama-v2/download.sh new file mode 100644 index 00000000..c62f4e20 --- /dev/null +++ b/experimental/llama-v2/download.sh @@ -0,0 +1,58 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement. + +read -p "Enter the URL from email: " PRESIGNED_URL +echo "" +read -p "Enter the list of models to download without spaces (7B,13B,70B,7B-chat,13B-chat,70B-chat), or press Enter for all: " MODEL_SIZE +TARGET_FOLDER="." # where all files should end up +mkdir -p ${TARGET_FOLDER} + +if [[ $MODEL_SIZE == "" ]]; then + MODEL_SIZE="7B,13B,70B,7B-chat,13B-chat,70B-chat" +fi + +echo "Downloading LICENSE and Acceptable Usage Policy" +wget ${PRESIGNED_URL/'*'/"LICENSE"} -O ${TARGET_FOLDER}"/LICENSE" +wget ${PRESIGNED_URL/'*'/"USE_POLICY.md"} -O ${TARGET_FOLDER}"/USE_POLICY.md" + +echo "Downloading tokenizer" +wget ${PRESIGNED_URL/'*'/"tokenizer.model"} -O ${TARGET_FOLDER}"/tokenizer.model" +wget ${PRESIGNED_URL/'*'/"tokenizer_checklist.chk"} -O ${TARGET_FOLDER}"/tokenizer_checklist.chk" +(cd ${TARGET_FOLDER} && md5sum -c tokenizer_checklist.chk) + +for m in ${MODEL_SIZE//,/ } +do + if [[ $m == "7B" ]]; then + SHARD=0 + MODEL_PATH="llama-2-7b" + elif [[ $m == "7B-chat" ]]; then + SHARD=0 + MODEL_PATH="llama-2-7b-chat" + elif [[ $m == "13B" ]]; then + SHARD=1 + MODEL_PATH="llama-2-13b" + elif [[ $m == "13B-chat" ]]; then + SHARD=1 + MODEL_PATH="llama-2-13b-chat" + elif [[ $m == "70B" ]]; then + SHARD=7 + MODEL_PATH="llama-2-70b" + elif [[ $m == "70B-chat" ]]; then + SHARD=7 + MODEL_PATH="llama-2-70b-chat" + fi + + echo "Downloading ${MODEL_PATH}" + mkdir -p ${TARGET_FOLDER}"/${MODEL_PATH}" + + for s in $(seq -f "0%g" 0 ${SHARD}) + do + wget ${PRESIGNED_URL/'*'/"${MODEL_PATH}/consolidated.${s}.pth"} -O ${TARGET_FOLDER}"/${MODEL_PATH}/consolidated.${s}.pth" + done + + wget ${PRESIGNED_URL/'*'/"${MODEL_PATH}/params.json"} -O ${TARGET_FOLDER}"/${MODEL_PATH}/params.json" + wget ${PRESIGNED_URL/'*'/"${MODEL_PATH}/checklist.chk"} -O ${TARGET_FOLDER}"/${MODEL_PATH}/checklist.chk" + echo "Checking checksums" + (cd ${TARGET_FOLDER}"/${MODEL_PATH}" && md5sum -c checklist.chk) +done + diff --git a/experimental/llama-v2/example_text_completion.py b/experimental/llama-v2/example_text_completion.py new file mode 100755 index 00000000..fa94df9c --- /dev/null +++ b/experimental/llama-v2/example_text_completion.py @@ -0,0 +1,134 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement. +import argparse + +from utils.config import Config +from llama import Llama + + +def main( + ckpt_dir: str, + tokenizer_path: str, + temperature: float, + top_p: float, + max_seq_len: int, + max_gen_len: int, + max_batch_size: int, +): + generator = Llama.build( + ckpt_dir=ckpt_dir, + tokenizer_path=tokenizer_path, + max_seq_len=max_seq_len, + max_batch_size=max_batch_size, + ) + + # prompts = [ + # For these prompts, the expected answer is the natural continuation of the prompt + # "I believe the meaning of life is", + # "Simply put, the theory of relativity states that ", + # """A brief message congratulating the team on the launch: + # + # Hi everyone, + # + # I just """, + # # Few shot prompt (providing a few examples before asking model to complete more); + # """Translate English to French: + # + # sea otter => loutre de mer + # peppermint => menthe poivrée + # plush girafe => girafe peluche + # cheese =>""", + # ] + prompts = ["I believe the meaning of life is"] * args.max_batch_size + for _ in range(2): # warmup + results, batched_token_timings = generator.text_completion( + prompts, + max_gen_len=max_gen_len, + temperature=temperature, + top_p=top_p, + ) + + for prompt, result in zip(prompts, results): + print(prompt) + print(f"> {result['generation']}") + print("\n==================================\n") + print(f"batch size: {len(prompts)}") + print(f"longer generated sequence in the batch: {len(batched_token_timings)}") + print(f"total inference time (fwd): {sum(batched_token_timings):.2f}") + print(f"average token timings: {sum(batched_token_timings) / len(batched_token_timings):.2f}") + print(f"min/max token timings: {min(batched_token_timings):.2f}/{max(batched_token_timings):.2f}") + print(f"token / sec: {max_batch_size * len(batched_token_timings) / sum(batched_token_timings):.2f}") + + +if __name__ == "__main__": + argparser = argparse.ArgumentParser() + + argparser.add_argument( + "--ckpt_dir", + type=str, + default="./llama-2-7b", + help="Path to the checkpoint directory", + ) + argparser.add_argument( + "--tokenizer_path", + type=str, + default="./tokenizer.model", + help="Path to the tokenizer model", + ) + argparser.add_argument( + "--temperature", + type=float, + default=0.6, + help="Temperature for sampling", + ) + argparser.add_argument( + "--top_p", + type=float, + default=0.9, + help="Top p for sampling", + ) + argparser.add_argument( + "--max_seq_len", + type=int, + default=512, + help="Maximum sequence length", + ) + argparser.add_argument( + "--max_gen_len", + type=int, + default=512, + help="Maximum generation length", + ) + argparser.add_argument( + "--max_batch_size", + type=int, + default=1, + help="Maximum batch size", + ) + argparser.add_argument( + "--enable_nvtx", + action="store_true", + help="Enable NVTX profiling", + ) + # enable use_triton in config + argparser.add_argument( + "--use_triton", + action="store_true", + help="Use Triton kernels instead of PyTorch implementation", + ) + + args = argparser.parse_args() + + config = Config() + config.set_nvtx(args.enable_nvtx) + config.set_use_triton(args.use_triton) + + main( + ckpt_dir=args.ckpt_dir, + tokenizer_path=args.tokenizer_path, + temperature=args.temperature, + top_p=args.top_p, + max_seq_len=args.max_seq_len, + max_gen_len=args.max_gen_len, + max_batch_size=args.max_batch_size, + ) diff --git a/experimental/llama-v2/kernel/__init__.py b/experimental/llama-v2/kernel/__init__.py new file mode 100755 index 00000000..e69de29b diff --git a/experimental/llama-v2/kernel/fused_kernel_attention.py b/experimental/llama-v2/kernel/fused_kernel_attention.py new file mode 100644 index 00000000..43fa9642 --- /dev/null +++ b/experimental/llama-v2/kernel/fused_kernel_attention.py @@ -0,0 +1,180 @@ +import math +from typing import Union + +import torch + +import triton +import triton.language as tl + +from fused_kernel_fp8 import f16_to_f8 +from pytorch_reference import attention_reference + + +# main changes +# 1. add start_position +# 2. manage is_causal +# 3. fix stride +# 4. add load masks +# 5. delete saves related to backward pass + + +@triton.jit +def _fwd_kernel( + Q, K, V, sm_scale, + Out, + stride_qz, stride_qh, stride_qm, stride_qk, + stride_kz, stride_kh, stride_kn, stride_kk, + stride_vz, stride_vh, stride_vk, stride_vn, + stride_oz, stride_oh, stride_om, stride_on, + N_HEAD, H, N_CTX, + start_position, # <- ADDED + IS_CAUSAL: tl.constexpr, # <- ADDED + BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_DMODEL: tl.constexpr, + USE_FP8: tl.constexpr, +): + start_m = tl.program_id(0) + + head_idx = tl.program_id(1) + batch_id = head_idx // N_HEAD + off_hz = head_idx % N_HEAD + + # initialize offsets + offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) + offs_n = tl.arange(0, BLOCK_N) + offs_d = tl.arange(0, BLOCK_DMODEL) + off_q = batch_id * stride_qz + off_hz * stride_qh + offs_m[:, None] * stride_qm + offs_d[None, :] * stride_qk # <- stride fixed + off_k = batch_id * stride_kz + off_hz * stride_kh + offs_n[:, None] * stride_kn + offs_d[None, :] * stride_kk # <- stride fixed + off_v = batch_id * stride_vz + off_hz * stride_vh + offs_n[:, None] * stride_vk + offs_d[None, :] * stride_vn # <- stride fixed + # Initialize pointers to Q, K, V + q_ptrs = Q + off_q + k_ptrs = K + off_k + v_ptrs = V + off_v + # initialize pointer to m and l + m_prev = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_prev = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) + # load q: it will stay in SRAM throughout + q = tl.load(q_ptrs, offs_m[:, None] < H, other=0.0) + # loop over k, v and update accumulator + block_n_end = N_CTX # <- ADDED (including the IF) + if IS_CAUSAL: + # in causal mode, we expect that BLOCK_M_SIZE == BLOCK_N_SIZE + # autotune will prune shapes not matching this rule + block_n_end = (start_m + 1) * BLOCK_N + start_position + for start_n in range(0, block_n_end, BLOCK_N): + block_n_offs = start_n + offs_n # <- ADDED + # -- compute qk ---- + k = tl.load(k_ptrs, block_n_offs[:, None] < N_CTX, 0.) + if USE_FP8: + k = k.to(tl.float8e5, bitcast=True) + k = k.to(tl.float16) + qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) + qk += tl.dot(q, tl.trans(k)) + qk = tl.where(offs_n[None, :] < N_CTX, qk, float("-inf")) # <- ADDED + qk *= sm_scale + if IS_CAUSAL: # <- ADDED + qk = tl.where(offs_m[:, None] >= (block_n_offs[None, :] + start_position), qk, float("-inf")) + + # compute new m + m_curr = tl.maximum(tl.max(qk, 1), m_prev) + # correct old l + l_prev *= tl.exp(m_prev - m_curr) + # attention weights + p = tl.exp(qk - m_curr[:, None]) + l_curr = tl.sum(p, 1) + l_prev + # rescale operands of matmuls + l_rcp = 1. / l_curr + p *= l_rcp[:, None] + acc *= (l_prev * l_rcp)[:, None] + # update acc + p = p.to(Q.dtype.element_ty) + v = tl.load(v_ptrs, block_n_offs[:, None] < N_CTX, 0.0) + if USE_FP8: + v = v.to(tl.float8e5, bitcast=True) + v = v.to(tl.float16) + acc += tl.dot(p, v) + # update m_i and l_i + l_prev = l_curr + m_prev = m_curr + # update pointers + k_ptrs += BLOCK_N * stride_kn + v_ptrs += BLOCK_N * stride_vk + # rematerialize offsets to save registers + start_m = tl.program_id(0) + offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) + + # initialize pointers to output + offs_d = tl.arange(0, BLOCK_DMODEL) + off_o = batch_id * stride_oz + off_hz * stride_oh + offs_m[:, None] * stride_om + offs_d[None, :] * stride_on + out_ptrs = Out + off_o + tl.store(out_ptrs, acc, offs_m[:, None] < H) + + +def triton_fa(q, k, v, sm_scale, is_causal, start_position): + assert q.dtype == torch.float16 + assert k.dtype == v.dtype and k.dtype in [torch.float16, torch.int8] + + BLOCK = 64 + # shape constraints + Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1] + assert Lq == Lk and Lk == Lv + assert Lk in {16, 32, 64, 128} + o = torch.empty_like(q) + num_warps = 4 if Lk <= 64 else 8 + batch, head_size, m_size, dhead = q.size() + grid = (triton.cdiv(m_size, BLOCK), head_size * batch) + n_size = k.size(2) + _fwd_kernel[grid]( + q, k, v, sm_scale, + o, + q.stride(0), q.stride(1), q.stride(2), q.stride(3), + k.stride(0), k.stride(1), k.stride(2), k.stride(3), + v.stride(0), v.stride(1), v.stride(2), v.stride(3), + o.stride(0), o.stride(1), o.stride(2), o.stride(3), + head_size, m_size, n_size, + start_position=start_position, + IS_CAUSAL=is_causal, + BLOCK_M=BLOCK, + BLOCK_N=BLOCK, + BLOCK_DMODEL=Lk, + USE_FP8=k.dtype == torch.int8, # USE_FP8 + num_warps=num_warps, + num_stages=2, + ) + + return o + + +xq = torch.randn([1, 16, 32, 128], dtype=torch.float16, device="cuda") +keys = torch.randn([1, 16, 32, 128], dtype=torch.float16, device="cuda") +values = torch.randn([1, 16, 32, 128], dtype=torch.float16, device="cuda") + +xq = xq.transpose(1, 2) +keys = keys.transpose(1, 2) +values = values.transpose(1, 2) + +scale = 1 / math.sqrt(128) +output_t = triton_fa(xq, keys, values, scale, True, 0) +output_p = attention_reference( + q = xq, + k= keys, + v=values, + output=torch.empty_like(output_t), + sm_scale=scale, + is_causal=True, +) + +assert torch.allclose(output_p, output_t, atol=1e-2) + + +xq_fp8 = f16_to_f8(xq, dtypes=tl.float8e5) +keys_fp8 = f16_to_f8(keys, dtypes=tl.float8e5) +values_fp8 = f16_to_f8(values, dtypes=tl.float8e5) + +output_t_fp8 = triton_fa(xq, keys_fp8, values_fp8, scale, True, 0) +assert torch.allclose(output_t_fp8, output_p, atol=5e-1) + +print("attention fp16", triton.testing.do_bench(lambda: triton_fa(xq, keys, values, scale, True, 0))) +print("attention fp8", triton.testing.do_bench(lambda: triton_fa(xq, keys_fp8, values_fp8, scale, True, 0))) + + diff --git a/experimental/llama-v2/kernel/fused_kernel_ff.py b/experimental/llama-v2/kernel/fused_kernel_ff.py new file mode 100644 index 00000000..e45a00e5 --- /dev/null +++ b/experimental/llama-v2/kernel/fused_kernel_ff.py @@ -0,0 +1,154 @@ +import torch + +import triton +import triton.language as tl + +from kernel.pytorch_reference import rms_norm_pytorch +from kernel.fused_kernel_fp8 import f16_to_f8, f8_to_f16 + + +@triton.jit +def ff_llama( + a_ptr, w1_ptr, w3_ptr, out_ptr, rms_w_ptr, + M, N, K, + stride_am, stride_ak, + stride_w1k, stride_w1n, + stride_w3k, stride_w3n, + stride_outm, stride_outn, + stride_rms_w, + USE_FP8: tl.constexpr, + EPS: tl.constexpr, + BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, +): + """ + w1 and w3 are weights (linear layers) + F.silu(w1(x)) * w3(x) + """ + pid = tl.program_id(axis=0) + pid_m = pid // tl.cdiv(N, BLOCK_SIZE_N) + pid_n = pid % tl.cdiv(N, BLOCK_SIZE_N) + + offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M + offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak) + w1_ptrs = w1_ptr + (offs_k[:, None] * stride_w1k + offs_bn[None, :] * stride_w1n) + w3_ptrs = w3_ptr + (offs_k[:, None] * stride_w3k + offs_bn[None, :] * stride_w3n) + acc1 = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + acc2 = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + rms_w_ptrs = rms_w_ptr + tl.arange(0, BLOCK_SIZE_K)[None, :] * stride_rms_w + a_sum = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_K), dtype=tl.float32) + for _ in range(0, tl.cdiv(K, BLOCK_SIZE_K)): + a = tl.load(a_ptrs) + a_sum += tl.math.pow(a.to(tl.float32), 2) + rms_w = tl.load(rms_w_ptrs) + if USE_FP8: + rms_w = rms_w.to(tl.float8e5, bitcast=True) + rms_w = rms_w.to(tl.float16) + a = a * rms_w + b = tl.load(w1_ptrs) + if USE_FP8: + b = b.to(tl.float8e5, bitcast=True) + b = b.to(tl.float32) + b = b.to(tl.float16) + acc1 += tl.dot(a, b) + c = tl.load(w3_ptrs) + if USE_FP8: + c = c.to(tl.float8e5, bitcast=True) + c = c.to(tl.float32) + c = c.to(tl.float16) + acc2 += tl.dot(a, c) + + a_ptrs += BLOCK_SIZE_K * stride_ak + w1_ptrs += BLOCK_SIZE_K * stride_w1k + w3_ptrs += BLOCK_SIZE_K * stride_w3k + + rms_w_ptrs += BLOCK_SIZE_K * stride_rms_w + + a_mean = tl.sum(a_sum, axis=1) / K + EPS + a_norm = tl.math.rsqrt(a_mean) + acc1 = acc1 * a_norm[:, None] + acc2 = acc2 * a_norm[:, None] + accumulator = (acc1 * tl.sigmoid(acc1)) * acc2 + + offs_outm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_outn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + out_ptrs = out_ptr + (stride_outm * offs_outm[:, None] + stride_outn * offs_outn[None, :]) + out_mask = (offs_outm[:, None] < M) & (offs_outn[None, :] < N) + tl.store(out_ptrs, accumulator, mask=out_mask) + + +def kernel_ff(x: torch.Tensor, w1: torch.Tensor, w3: torch.Tensor, rms_w: torch.Tensor) -> torch.Tensor: + assert x.dtype == torch.float16 + assert w1.dtype == w3.dtype == rms_w.dtype + assert w1.dtype in [torch.int8, torch.float16] + assert w1.shape == w3.shape + + w1_t = w1.t() + w3_t = w3.t() + + batch, seq_len, dim = x.shape + M, K = batch * seq_len, dim + + N = w1_t.shape[1] + assert K == w1_t.shape[0] + assert w1_t.shape == w3_t.shape + x_reshape = x.reshape(M, K) + out = torch.empty((M, N), dtype=x.dtype, device=x.device) + grid = lambda META: (triton.cdiv(META["M"], META["BLOCK_SIZE_M"]) * triton.cdiv(META["N"], META["BLOCK_SIZE_N"]),) + ff_llama[grid]( + x_reshape, w1_t, w3_t, out, rms_w, + M, N, K, + *x_reshape.stride(), + *w1_t.stride(), + *w3_t.stride(), + *out.stride(), + *rms_w.stride(), + USE_FP8=w1_t.dtype != torch.float16, + EPS=1e-6, + BLOCK_SIZE_M=16, BLOCK_SIZE_N=16, BLOCK_SIZE_K=64, + num_stages=2, num_warps=4 + ) + out = out.view(batch, seq_len, -1) + return out + + +x = torch.randn([1, 16, 4096], dtype=torch.float16, device="cuda") +# weights tends to be very small values +rms_w = torch.randn([4096], dtype=torch.float16, device="cuda") * 0.2 +w1_w = torch.randn([11008, 4096], dtype=torch.float16, device="cuda") * 0.2 +w3_w = torch.randn([11008, 4096], dtype=torch.float16, device="cuda") * 0.2 + + +x_norm_p = rms_norm_pytorch(x, rms_w, eps=1e-6) +w1_p = x_norm_p @ w1_w.t() +w1_silu_p = torch.nn.functional.silu(w1_p) +w3_p = x_norm_p @ w3_w.t() + + +def ff_pytorch(x: torch.Tensor, w1: torch.Tensor, w3: torch.Tensor, rms_w: torch.Tensor) -> torch.Tensor: + x_norm = rms_norm_pytorch(x, rms_w, eps=1e-6) + a = torch.nn.functional.silu(torch.matmul(x_norm, w1.t())) + b = torch.matmul(x_norm, w3.t()) + return a * b + + +output_triton = kernel_ff(x=x, w1=w1_w, w3=w3_w, rms_w=rms_w) +output_pytorch = ff_pytorch(x=x, w1=w1_w, w3=w3_w, rms_w=rms_w) + +assert torch.allclose(output_triton, w1_silu_p * w3_p, atol=1e-1), f"max diff: {torch.max(torch.abs(output_triton - w1_silu_p * w3_p))}" +assert torch.allclose(output_triton, output_pytorch, atol=1e-1), f"max diff: {torch.max(torch.abs(output_triton - output_pytorch))}" + +print("rms matmul silu mul triton", triton.testing.do_bench(lambda: kernel_ff(x=x, w1=w1_w, w3=w3_w, rms_w=rms_w))) +print("rms matmul silu mul pytorch", triton.testing.do_bench(lambda: ff_pytorch(x=x, w1=w1_w, w3=w3_w, rms_w=rms_w))) + +w1_w_fp8 = f16_to_f8(w1_w, dtypes=tl.float8e5) +w3_w_fp8 = f16_to_f8(w3_w, dtypes=tl.float8e5) +rms_w_fp8 = f16_to_f8(rms_w, dtypes=tl.float8e5) + +out_fp8 = kernel_ff(x=x, w1=w1_w_fp8, w3=w3_w_fp8, rms_w=rms_w_fp8) +# on very large tensors, it is expected that the error is large, we just check it is not crazy large +assert torch.allclose(out_fp8, w1_silu_p * w3_p, atol=10) + +print("rms matmul silu mul triton fp8", triton.testing.do_bench(lambda: kernel_ff(x=x, w1=w1_w_fp8, w3=w3_w_fp8, rms_w=rms_w_fp8))) \ No newline at end of file diff --git a/experimental/llama-v2/kernel/fused_kernel_fp8.py b/experimental/llama-v2/kernel/fused_kernel_fp8.py new file mode 100644 index 00000000..c6245751 --- /dev/null +++ b/experimental/llama-v2/kernel/fused_kernel_fp8.py @@ -0,0 +1,91 @@ +from typing import List, Optional + +import triton +import triton.language as tl +import torch + +torch.manual_seed(123) + + +def find_last_one_index(lst: List[int]) -> Optional[int]: + index = len(lst) - 1 + while index >= 0: + if lst[index] == 1: + return index + else: + index -= 1 + return None + + +def f8_to_f16(x, dtypes=tl.float8e5) -> torch.Tensor: + assert x.dtype == torch.int8, f"torch.int8 expected but got {x.dtype}" + assert "cuda" in str(x.device), f"CUDA tensors only but got {x.device}" + + @triton.jit + def kernel(Y, X, N, BLOCK_SIZE: tl.constexpr): + pid = tl.program_id(0) + offs = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + mask = offs < N + x = tl.load(X + offs, mask=mask) + tl.store(Y + offs, x, mask=mask) + tl.store(Y + offs, x, mask=mask) + + ret = torch.empty_like(x, dtype=torch.float16) + grid = lambda META: (triton.cdiv(x.numel(), META['BLOCK_SIZE']),) + numel = ret.untyped_storage().size() // ret.element_size() # manage cases where tensor is not contiguous, like ::2 + kernel[grid](ret, triton.reinterpret(x, dtypes), numel, BLOCK_SIZE=1024) + return ret + + +def f16_to_f8(x: torch.Tensor, dtypes=tl.float8e5) -> torch.Tensor: + assert x.dtype in [torch.float16, torch.float32] + assert "cuda" in str(x.device), f"CUDA tensors only but got {x.device}" + + @triton.jit + def kernel(Y, X, N, BLOCK_SIZE: tl.constexpr): + pid = tl.program_id(0) + offs = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + mask = offs < N + x = tl.load(X + offs, mask=mask) + tl.store(Y + offs, x, mask=mask) + + ret = torch.empty_like(x, dtype=torch.int8) + grid = lambda META: (triton.cdiv(x.numel(), META['BLOCK_SIZE']),) + numel = x.untyped_storage().size() // x.element_size() # manage cases where tensor is not contiguous, like ::2 + kernel[grid](triton.reinterpret(ret, dtypes), x, numel, BLOCK_SIZE=1024) + return ret + + +class FakeLinear(torch.nn.Module): + def __init__(self, weight: torch.Tensor): + super().__init__() + self.weight = weight + + def forward(self, *args, **kwargs): + raise Exception("should not be used") + + +def get_model_fp8(parent_module: torch.nn.Module) -> None: + from llama.model import Attention, RMSNorm + + for name, module in parent_module.named_children(): + if isinstance(module, Attention): + module.cache_v = f16_to_f8(module.cache_v) + module.cache_k = f16_to_f8(module.cache_k) + + if isinstance(module, RMSNorm) or isinstance(module, torch.nn.Linear): + if name not in ["norm", "wo", "w2", "output", "lm_head"]: + assert module.weight.abs().max() < 431 + weight_fp8 = f16_to_f8(module.weight.data) + m = FakeLinear(weight=weight_fp8) + setattr(parent_module, name, m) + + get_model_fp8(module) # Recursion for nested modules + + +for _ in range(100): + a = torch.randn((16, 128), dtype=torch.float16, device="cuda") + b = f16_to_f8(a, dtypes=tl.float8e4) + c = f8_to_f16(b, dtypes=tl.float8e4) + 1e-4 + + assert (a/c).abs().mean().item()-1 < 1e-1, f"{(a/c).abs().mean()}" diff --git a/experimental/llama-v2/kernel/fused_kernel_proj_qkv.py b/experimental/llama-v2/kernel/fused_kernel_proj_qkv.py new file mode 100644 index 00000000..e798cc11 --- /dev/null +++ b/experimental/llama-v2/kernel/fused_kernel_proj_qkv.py @@ -0,0 +1,435 @@ +import torch + +import triton +import triton.language as tl + +from kernel.pytorch_reference import rms_norm_pytorch, precompute_freqs_cis_pytorch, apply_rotary_emb_pytorch + +torch.manual_seed(1234) + + +@triton.jit +def rmsnorm_triton(x_ptr, rms_w_ptr, output_ptr, + stride_x_batch, stride_x_m, stride_x_k, + stride_rms_w, + stride_out_batch, stride_out_m, stride_out_k, + N_SIZE: tl.constexpr, eps: tl.constexpr, BLOCK_N_SIZE: tl.constexpr): + pid_batch = tl.program_id(0) + pid_m = tl.program_id(1) + + offs_m = pid_batch * stride_x_batch + pid_m * stride_x_m + block_N = tl.arange(0, BLOCK_N_SIZE) + var = tl.zeros((BLOCK_N_SIZE,), tl.float32) + for block_n_start_idx in range(0, N_SIZE, BLOCK_N_SIZE): + offs_n = block_n_start_idx + block_N + x_ptr_mask = offs_n < N_SIZE + x = tl.load(x_ptr + offs_m + offs_n * stride_x_k, mask=x_ptr_mask, other=0.0) + var += tl.math.pow(x.to(tl.float32), 2) + + var = tl.sum(var, axis=0) / N_SIZE + rstd = tl.math.rsqrt(var + eps) + + # multiply by weight and add bias + for block_n_start_idx in range(0, N_SIZE, BLOCK_N_SIZE): + offs_n = block_n_start_idx + block_N + x_ptr_mask = offs_n < N_SIZE + rms_w = tl.load(rms_w_ptr + offs_n * stride_rms_w, mask=x_ptr_mask) + + x = tl.load(x_ptr + offs_m + offs_n * stride_x_k, mask=x_ptr_mask, other=0.0).to(tl.float32) + x_hat = x * rstd + out = x_hat * rms_w + out_off = pid_batch * stride_out_batch + pid_m * stride_out_m + offs_n * stride_out_k + tl.store(output_ptr + out_off, out, mask=x_ptr_mask) + + +def rmsnorm_triton_wrapper(x, rms_w, eps=1e-6): + batch, M, K = x.shape + assert rms_w.shape[-1] == K + out = torch.empty_like(x) + rmsnorm_triton[(batch, M,)](x, rms_w, out, + *x.stride(), + *rms_w.stride(), + *out.stride(), + N_SIZE=K, eps=eps, BLOCK_N_SIZE=1024, + ) + return out + + +@triton.jit +def get_freq_multi_tokens(offs_cn, starting_idx, theta: tl.constexpr, NB_TOKENS: tl.constexpr): + DIM: tl.constexpr = 128 # in model, dim = self.params.dim // self.params.n_heads + freqs = offs_cn % DIM + freqs = freqs.to(tl.float32) / DIM + freqs = tl.math.pow(theta, freqs) + freqs = (tl.arange(0, NB_TOKENS) + starting_idx)[:, None] / freqs[None, :] + return tl.cos(freqs), tl.sin(freqs) + + +@triton.jit +def rbe_triton(x_ptr, out_ptr, + M, K, + stride_x_batch, stride_x_m, stride_x_n, + stride_out_batch, stride_out_m, stride_out_n, + start_token_position, + THETA: tl.constexpr, BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_K: tl.constexpr): + pid_batch = tl.program_id(axis=0) + pid = tl.program_id(axis=1) + pid_m = pid // tl.cdiv(K, BLOCK_SIZE_K) + pid_n = pid % tl.cdiv(K, BLOCK_SIZE_K) + + offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_n = pid_n * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K // 2) * 2 # take only even numbers + x_ptrs = x_ptr + (pid_batch * stride_x_batch + stride_x_m * offs_m[:, None] + stride_x_n * offs_n[None, :]) + x_real_mask = (offs_m[:, None] < M) & (offs_n[None, :] < K) + real = tl.load(x_ptrs, mask=x_real_mask, other=0.0) + x_imag_mask = (offs_m[:, None] < M) & (1 + offs_n[None, :] < K) + imag = tl.load(x_ptrs + 1, mask=x_imag_mask, other=0.0) + tl.debug_barrier() + start_block = start_token_position + pid_m * BLOCK_SIZE_M + cos, sin = get_freq_multi_tokens(offs_cn=offs_n, starting_idx=start_block, theta=THETA, NB_TOKENS=BLOCK_SIZE_M) + + out_real = real * cos - imag * sin + out_imag = real * sin + imag * cos + tl.debug_barrier() + out_ptrs = out_ptr + ( + pid_batch * stride_out_batch + stride_out_m * offs_m[:, None] + stride_out_n * offs_n[None, :]) + out_real_mask = (offs_m[:, None] < M) & (offs_n[None, :] < K) + tl.store(out_ptrs, out_real, mask=out_real_mask) + out_imag_mask = (offs_m[:, None] < M) & (1 + offs_n[None, :] < K) + tl.store(out_ptrs + 1, out_imag, mask=out_imag_mask) + + +def rbe_triton_wrapper(x: torch.Tensor, pos: int) -> torch.Tensor: + batch, M, K = x.shape + out = torch.empty_like(x) + grid = lambda META: ( + batch, triton.cdiv(META["M"], META["BLOCK_SIZE_M"]) * triton.cdiv(META["K"], META["BLOCK_SIZE_K"]),) + + rbe_triton[grid](x, out, + M, K, + *x.stride(), + *out.stride(), + start_token_position=pos, THETA=10000., BLOCK_SIZE_M=2, BLOCK_SIZE_K=1024) + return out + + +@triton.jit +def rms_matmul_rbe( + x_ptr, w_ptr, rms_w_ptr, out_ptr, + M, N, K, + stride_x_batch, stride_x_m, stride_x_k, + stride_w_k, stride_w_n, + stride_rms_w, + stride_out_batch, stride_out_m, stride_out_n, + start_token_position, + USE_FP8: tl.constexpr, + RBE_EPILOGUE: tl.constexpr, + THETA: tl.constexpr, + EPS: tl.constexpr, + BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, +): + """ + Prologue: RMS + Epilogue: nothing or Rotary embeddings + c = ROBE((rms(a) * rms_w) @ b) + """ + pid_batch = tl.program_id(axis=0) + pid = tl.program_id(axis=1) + pid_m = pid // tl.cdiv(N, BLOCK_SIZE_N) + pid_n = pid % tl.cdiv(N, BLOCK_SIZE_N) + + offs_m = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M + offs_n = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N + offs_k = tl.arange(0, BLOCK_SIZE_K) + x_ptrs = x_ptr + (pid_batch * stride_x_batch + offs_m[:, None] * stride_x_m + offs_k[None, :] * stride_x_k) + w_ptrs = w_ptr + (offs_k[:, None] * stride_w_k + offs_n[None, :] * stride_w_n) + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + rms_w_ptrs = rms_w_ptr + tl.arange(0, BLOCK_SIZE_K)[None, :] * stride_rms_w + x_sum = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_K), dtype=tl.float32) + for _ in range(0, tl.cdiv(K, BLOCK_SIZE_K)): + x = tl.load(x_ptrs) + x_sum += tl.math.pow(x.to(tl.float32), 2) + rms_w = tl.load(rms_w_ptrs) # TODO add an assert that rms_w is a multiple of BLOCK SIZE K + if USE_FP8: + rms_w = rms_w.to(tl.float8e5, bitcast=True) + rms_w = rms_w.to(tl.float16) + x = x * rms_w + w = tl.load(w_ptrs) # TODO add an assert that w is a multiple of BLOCK SIZE K + if USE_FP8: + w = w.to(tl.float8e5, bitcast=True) + w = w.to(tl.float32) + w = w.to(tl.float16) + accumulator += tl.dot(x, w) + x_ptrs += BLOCK_SIZE_K * stride_x_k + w_ptrs += BLOCK_SIZE_K * stride_w_k + rms_w_ptrs += BLOCK_SIZE_K * stride_rms_w + x_mean = tl.sum(x_sum, axis=1) / K + EPS + x_norm = tl.math.rsqrt(x_mean) + accumulator = accumulator * x_norm[:, None] + + offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + out_ptrs = out_ptr + ( + pid_batch * stride_out_batch + offs_m[:, None] * stride_out_m + offs_n[None, :] * stride_out_n) + out_mask = (offs_m[:, None] < M) & (offs_n[None, :] < N) + + if RBE_EPILOGUE: + tl.store(out_ptrs, accumulator, mask=out_mask) + tl.debug_barrier() + rbe_triton(out_ptr, out_ptr, M, N, stride_out_batch, stride_out_m, stride_out_n, stride_out_batch, stride_out_m, + stride_out_n, start_token_position, THETA, + BLOCK_SIZE_M, BLOCK_SIZE_N) + else: + tl.store(out_ptrs, accumulator, mask=out_mask) + + +def rms_matmul_rbe_wrapper(x: torch.Tensor, weight: torch.Tensor, rms_w: torch.Tensor, use_rbe: bool, start_pos: int, + n_heads: int, head_dim: int): + assert weight.dtype == rms_w.dtype + assert weight.dtype in [torch.float16, torch.int8] + batch, M, K = x.shape + weight_t = weight.t() + K_W, N = weight_t.shape + assert K == K_W + out = torch.empty((batch, M, N), dtype=weight_t.dtype, device=weight_t.device) # TODO replace by empty + out_ptr = triton.reinterpret(out, tl.float8e5 if out.dtype == torch.int8 else tl.float16) + + grid = lambda META: ( + batch, triton.cdiv(META["M"], META["BLOCK_SIZE_M"]) * triton.cdiv(META["N"], META["BLOCK_SIZE_N"])) + + rms_matmul_rbe[grid]( + x_ptr=x, + w_ptr=weight_t, rms_w_ptr=rms_w, out_ptr=out_ptr, + M=M, N=N, K=K, + stride_x_batch=x.stride(0), stride_x_m=x.stride(1), stride_x_k=x.stride(2), + stride_w_k=weight_t.stride(0), stride_w_n=weight_t.stride(1), + stride_rms_w=rms_w.stride(0), + stride_out_batch=out.stride(0), stride_out_m=out.stride(1), stride_out_n=out.stride(2), + start_token_position=start_pos, + USE_FP8=weight_t.dtype == torch.int8, + RBE_EPILOGUE=use_rbe, + THETA=10000., + EPS=1e-6, + BLOCK_SIZE_M=16, BLOCK_SIZE_N=64, BLOCK_SIZE_K=64, + num_stages=4, num_warps=4 + ) + out = out.view(batch, M, n_heads, head_dim) + return out + + +@triton.jit +def rms_matmul_rbe_qkv(x_ptr, + q_weight_ptr, k_weight_ptr, v_weight_ptr, + rms_w_ptr, + q_ptr, k_ptr, v_ptr, + M, N, K, + stride_x_batch, stride_x_m, stride_x_k, + stride_q_w_k, stride_q_w_n, + stride_k_w_k, stride_k_w_n, + stride_v_w_k, stride_v_w_n, + stride_rms_w, + stride_q_batch, stride_q_m, stride_q_n, + stride_k_batch, stride_k_m, stride_k_n, + stride_v_batch, stride_v_m, stride_v_n, + start_token_position, + USE_FP8: tl.constexpr, + THETA: tl.constexpr, + EPS: tl.constexpr, + BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr): + # q + rms_matmul_rbe( + x_ptr=x_ptr, + w_ptr=q_weight_ptr, rms_w_ptr=rms_w_ptr, out_ptr=q_ptr, + M=M, N=N, K=K, + stride_x_batch=stride_x_batch, stride_x_m=stride_x_m, stride_x_k=stride_x_k, + stride_w_k=stride_q_w_k, stride_w_n=stride_q_w_n, + stride_rms_w=stride_rms_w, + stride_out_batch=stride_q_batch, stride_out_m=stride_q_m, stride_out_n=stride_q_n, + start_token_position=start_token_position, + USE_FP8=USE_FP8, + RBE_EPILOGUE=True, + THETA=THETA, + EPS=EPS, + BLOCK_SIZE_M=BLOCK_SIZE_M, BLOCK_SIZE_N=BLOCK_SIZE_N, BLOCK_SIZE_K=BLOCK_SIZE_K, + ) + # k + rms_matmul_rbe( + x_ptr=x_ptr, + w_ptr=k_weight_ptr, rms_w_ptr=rms_w_ptr, out_ptr=k_ptr, + M=M, N=N, K=K, + stride_x_batch=stride_x_batch, stride_x_m=stride_x_m, stride_x_k=stride_x_k, + stride_w_k=stride_k_w_k, stride_w_n=stride_k_w_n, + stride_rms_w=stride_rms_w, + stride_out_batch=stride_k_batch, stride_out_m=stride_k_m, stride_out_n=stride_k_n, + start_token_position=start_token_position, + USE_FP8=USE_FP8, + RBE_EPILOGUE=True, + THETA=THETA, + EPS=EPS, + BLOCK_SIZE_M=BLOCK_SIZE_M, BLOCK_SIZE_N=BLOCK_SIZE_N, BLOCK_SIZE_K=BLOCK_SIZE_K, + ) + # v + rms_matmul_rbe( + x_ptr=x_ptr, + w_ptr=v_weight_ptr, rms_w_ptr=rms_w_ptr, out_ptr=v_ptr, + M=M, N=N, K=K, + stride_x_batch=stride_x_batch, stride_x_m=stride_x_m, stride_x_k=stride_x_k, + stride_w_k=stride_v_w_k, stride_w_n=stride_v_w_n, + stride_rms_w=stride_rms_w, + stride_out_batch=stride_v_batch, stride_out_m=stride_v_m, stride_out_n=stride_v_n, + start_token_position=start_token_position, + USE_FP8=USE_FP8, + RBE_EPILOGUE=False, + THETA=THETA, + EPS=EPS, + BLOCK_SIZE_M=BLOCK_SIZE_M, BLOCK_SIZE_N=BLOCK_SIZE_N, BLOCK_SIZE_K=BLOCK_SIZE_K, + ) + + +def rms_matmul_rbe_qkv_wrapper(x: torch.Tensor, + start_pos: int, + q_weight: torch.Tensor, k_weight: torch.Tensor, v_weight: torch.Tensor, + rms_w: torch.Tensor, + n_heads: int, head_dim: int, + k: torch.Tensor, + v: torch.Tensor, + eps: float = 1e-6, theta=10000.): + assert q_weight.shape == k_weight.shape == v_weight.shape + assert q_weight.dtype == k_weight.dtype == v_weight.dtype == rms_w.dtype + assert q_weight.dtype in [torch.float16, torch.int8] + batch, M, K = x.shape + + assert K == rms_w.shape[0] + + q_weight_t = q_weight.t() + k_weight_t = k_weight.t() + v_weight_t = v_weight.t() + K_W, N = q_weight_t.shape + assert K == K_W + q = torch.empty((batch, M, N), dtype=torch.float16, device=q_weight_t.device) + + k = k.view((batch, M, N)) + v = v.view((batch, M, N)) + assert k.dtype == k_weight.dtype + assert v.dtype == v_weight.dtype + + q_ptr = triton.reinterpret(q, tl.float16) + k_ptr = triton.reinterpret(k, tl.float8e5 if k.dtype == torch.int8 else tl.float16) + v_ptr = triton.reinterpret(v, tl.float8e5 if v.dtype == torch.int8 else tl.float16) + + grid = lambda META: ( + batch, triton.cdiv(META["M"], META["BLOCK_SIZE_M"]) * triton.cdiv(META["N"], META["BLOCK_SIZE_N"])) + + rms_matmul_rbe_qkv[grid]( + x_ptr=x, + q_weight_ptr=q_weight_t, k_weight_ptr=k_weight_t, v_weight_ptr=v_weight_t, + rms_w_ptr=rms_w, + q_ptr=q_ptr, k_ptr=k_ptr, v_ptr=v_ptr, + M=M, N=N, K=K, + stride_x_batch=x.stride(0), stride_x_m=x.stride(1), stride_x_k=x.stride(2), + stride_q_w_k=q_weight_t.stride(0), stride_q_w_n=q_weight_t.stride(1), + stride_k_w_k=k_weight_t.stride(0), stride_k_w_n=k_weight_t.stride(1), + stride_v_w_k=v_weight_t.stride(0), stride_v_w_n=v_weight_t.stride(1), + stride_rms_w=rms_w.stride(0), + stride_q_batch=q.stride(0), stride_q_m=q.stride(1), stride_q_n=q.stride(2), + stride_k_batch=k.stride(0), stride_k_m=k.stride(1), stride_k_n=k.stride(2), + stride_v_batch=v.stride(0), stride_v_m=v.stride(1), stride_v_n=v.stride(2), + start_token_position=start_pos, + USE_FP8=q_weight.dtype == torch.int8, + THETA=theta, + EPS=eps, + BLOCK_SIZE_M=16, BLOCK_SIZE_N=64, BLOCK_SIZE_K=64, + num_stages=4, num_warps=4 + ) + q = q.view(batch, M, n_heads, head_dim) + k = k.view(batch, M, n_heads, head_dim) + v = v.view(batch, M, n_heads, head_dim) + return q, k, v + + +batch, seq_len, heads, dim = [1, 16, 32, 128] + +embeddings_load = torch.randn([batch, seq_len, heads * dim], dtype=torch.float16, device="cuda") +rms_weights = torch.randn([heads * dim], dtype=torch.float16, device="cuda") * 0.2 +q_weights_load = torch.randn([heads * dim, heads * dim], dtype=torch.float16, device="cuda") * 0.2 + + +out_rms_triton = rmsnorm_triton_wrapper(x=embeddings_load, rms_w=rms_weights) +out_rms_pytorch = rms_norm_pytorch(x=embeddings_load, rms_w=rms_weights) +assert torch.allclose(out_rms_triton, out_rms_pytorch, atol=1e-1) +print("rmsnorm triton", triton.testing.do_bench(lambda: rmsnorm_triton_wrapper(x=embeddings_load, rms_w=rms_weights))) +print("rmsnorm pytorch", triton.testing.do_bench(lambda: rms_norm_pytorch(x=embeddings_load, rms_w=rms_weights))) + + +xq = out_rms_pytorch @ q_weights_load.t() +xq = xq.view(batch, seq_len, heads, dim) + + +xq_output_triton = out_rms_triton @ q_weights_load.t() +out_rbe_triton = rbe_triton_wrapper(xq_output_triton, pos=0).view(batch, seq_len, heads, dim) + +freq_cos, freq_sin = precompute_freqs_cis_pytorch(dim=128, end=seq_len) +out_rbe_pytorch = apply_rotary_emb_pytorch(x=xq, freq_cos=freq_cos, freq_sin=freq_sin).view(batch, seq_len, heads, dim) +assert torch.allclose(out_rbe_pytorch, out_rbe_triton, atol=1e-1), f"max diff: {torch.max(torch.abs(out_rbe_pytorch - out_rbe_triton))}" +print("rbe triton", triton.testing.do_bench(lambda: rbe_triton_wrapper(xq_output_triton, pos=0))) +print("rbe pytorch", triton.testing.do_bench(lambda: apply_rotary_emb_pytorch(x=xq, freq_cos=freq_cos, freq_sin=freq_sin))) + + +out_rms_matmul_rbe_triton = rms_matmul_rbe_wrapper(x=embeddings_load, start_pos=0, weight=q_weights_load, rms_w=rms_weights, + use_rbe=True, n_heads=32, head_dim=128).view(batch, seq_len, heads, dim) +assert torch.allclose(out_rms_matmul_rbe_triton, out_rbe_pytorch, atol=1e-1) + + +def get_rms_matmul_rbe_pytorch(): + a = rms_norm_pytorch(x=embeddings_load, rms_w=rms_weights) + b = a @ q_weights_load.t() + b = b.view(batch, seq_len, heads, dim) + apply_rotary_emb_pytorch(x=b, freq_cos=freq_cos, freq_sin=freq_sin) + + +print("rms_matmul_rbe triton", triton.testing.do_bench(lambda: rms_matmul_rbe_wrapper(x=embeddings_load, start_pos=0, weight=q_weights_load, rms_w=rms_weights, + use_rbe=True, n_heads=32, head_dim=128))) +print("rms_matmul_rbe pytorch", triton.testing.do_bench(get_rms_matmul_rbe_pytorch)) + +k = torch.empty((embeddings_load.shape[0], embeddings_load.shape[1], q_weights_load.shape[-1]), + dtype=q_weights_load.dtype, device=q_weights_load.device) +v = torch.empty_like(k) +out_rms_matmul_rbe_qkv, _, _ = rms_matmul_rbe_qkv_wrapper(x=embeddings_load, start_pos=0, + q_weight=q_weights_load, k_weight=q_weights_load, + v_weight=q_weights_load, rms_w=rms_weights, + k=k, v=v, + n_heads=32, + head_dim=128) + +assert torch.allclose(out_rms_matmul_rbe_qkv, out_rbe_pytorch, atol=1e-1) + +position = 5 +embeddings_load_1_token = embeddings_load[:, position:position + 1, :] +xq_cplx_rotated_loaded_1_token = out_rbe_pytorch[:, position:position + 1, ...] +k_1_token = k[:, position:position + 1, :] +v_1_token = v[:, position:position + 1, :] + +_, out_rms_matmul_rbe_qkv_1_token, _ = rms_matmul_rbe_qkv_wrapper(x=embeddings_load_1_token, start_pos=position, + q_weight=q_weights_load, k_weight=q_weights_load, + v_weight=q_weights_load, rms_w=rms_weights, + k=k_1_token, v=v_1_token, + n_heads=32, + head_dim=128) +assert torch.allclose(out_rms_matmul_rbe_qkv_1_token, xq_cplx_rotated_loaded_1_token, atol=1e-1) + + +def get_qkv_rms_matmul_rbe_pytorch(): + for i in range(3): + a = rms_norm_pytorch(x=embeddings_load, rms_w=rms_weights) + b = a @ q_weights_load.t() + b = b.view(batch, seq_len, heads, dim) + if i == 2: + apply_rotary_emb_pytorch(x=b, freq_cos=freq_cos, freq_sin=freq_sin) + +print("qkv rms_matmul_rbe triton", triton.testing.do_bench(lambda: rms_matmul_rbe_qkv_wrapper(x=embeddings_load, start_pos=0, + q_weight=q_weights_load, k_weight=q_weights_load, + v_weight=q_weights_load, rms_w=rms_weights, + k=k, v=v, + n_heads=32, + head_dim=128))) +print("qkv rms_matmul_rbe pytorch", triton.testing.do_bench(get_qkv_rms_matmul_rbe_pytorch)) \ No newline at end of file diff --git a/experimental/llama-v2/kernel/pytorch_reference.py b/experimental/llama-v2/kernel/pytorch_reference.py new file mode 100644 index 00000000..d98a21de --- /dev/null +++ b/experimental/llama-v2/kernel/pytorch_reference.py @@ -0,0 +1,104 @@ +import torch + +batch, seq_len, heads, dim = 1, 16, 32, 128 + + +def rms_norm_pytorch(x: torch.Tensor, rms_w: torch.Tensor, eps=1e-6) -> torch.Tensor: + x = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps) + return x * rms_w + + +def reshape_for_broadcast_pytorch(freqs_cis: torch.Tensor, x: torch.Tensor): + ndim = x.ndim + assert 0 <= 1 < ndim + assert freqs_cis.shape == (x.shape[1], x.shape[-1]) + shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)] + return freqs_cis.view(*shape) + + +def rbe_pytorch(input_tensor: torch.Tensor, weights: torch.Tensor, freqs_cis: torch.Tensor): + embenddings_rms_reshaped = input_tensor.view(-1, heads * dim) + proj = torch.nn.functional.linear(embenddings_rms_reshaped, weights) + proj_reshaped = proj.view(batch, -1, heads, dim) + proj_reshaped = torch.view_as_complex(proj_reshaped.float().view(*proj_reshaped.shape[:-1], -1, 2)) + freqs_cis_reshaped = reshape_for_broadcast_pytorch(freqs_cis, proj_reshaped) + out = torch.view_as_real(proj_reshaped * freqs_cis_reshaped).flatten(3) + return out.type_as(input_tensor) + + +def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor): + ndim = x.ndim + assert 0 <= 1 < ndim + assert freqs_cis.shape == (x.shape[1], x.shape[-1]), f"{freqs_cis.shape} != {(x.shape[1], x.shape[-1])}" + shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)] + return freqs_cis.view(*shape) + + +def precompute_freqs_cis_pytorch(dim: int, end: int, theta: float = 10000.0): + assert dim % 2 == 0 + + # Generate a sequence of numbers from 0 to dim in steps of 2 + sequence = torch.arange(0, dim, 2, dtype=torch.float32, device="cuda") + + # Keep only the first half of the sequence (in case dim is odd?) + # sequence = sequence[: (dim // 2)] + + # Calculate frequency values based on the sequence and theta + freqs = 1.0 / (theta ** (sequence / dim)) + + # Create a tensor of numbers from 0 to end, it represents the position ids + t = torch.arange(end, device=freqs.device) + + # Generate a table of frequency values + freqs = t[:, None] * freqs[None, :] # torch.outer(t, freqs).float() + + # Calculate cosine and sine values for the frequencies + # These can be considered as the real and imaginary parts of complex numbers + freqs_cos = torch.cos(freqs) + freqs_sin = torch.sin(freqs) + + # Return the cosine and sine values as two separate tensors + return freqs_cos, freqs_sin + + +def apply_rotary_emb_pytorch(x: torch.Tensor, freq_cos: torch.Tensor, freq_sin: torch.Tensor) -> torch.Tensor: + # Split x and x into real and imaginary parts + x_real = x[..., 0::2] + x_imag = x[..., 1::2] + + # Reshape freq_cos and freq_sin for broadcasting + freq_cos = reshape_for_broadcast(freq_cos, x_real).to(torch.float32) + freq_sin = reshape_for_broadcast(freq_sin, x_imag).to(torch.float32) + + # Perform the equivalent of complex multiplication + x_out_real = x_real * freq_cos - x_imag * freq_sin + x_out_imag = x_real * freq_sin + x_imag * freq_cos + + # Combine real and imaginary parts back into the original tensor + x_out = torch.stack((x_out_real, x_out_imag), dim=-1).flatten(-2) + + return x_out.type_as(x) + + +def pytorch_all(input_tensor: torch.Tensor, weights: torch.Tensor, rms_weights: torch.Tensor, freqs_cis: torch.Tensor): + embenddings_rms = rms_norm_pytorch(input_tensor, rms_weights) + return rbe_pytorch(embenddings_rms, weights, freqs_cis) + + +def attention_reference( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + output: torch.Tensor, + sm_scale: float, + is_causal: bool, +) -> torch.Tensor: + p = torch.matmul(q, k.transpose(2, 3)) * sm_scale + if is_causal: + m_size = q.size(2) + n_size = k.size(2) + M = torch.tril(torch.ones((m_size, n_size), device="cuda")) + p = torch.where(M == 0, float("-inf"), p) + p = torch.nn.functional.softmax(p, dim=-1) + ref_out = torch.matmul(p.to(v.dtype), v, out=output) + return ref_out \ No newline at end of file diff --git a/experimental/llama-v2/launch.sh b/experimental/llama-v2/launch.sh new file mode 100755 index 00000000..ede66eef --- /dev/null +++ b/experimental/llama-v2/launch.sh @@ -0,0 +1,3 @@ +source /home/geantvert/.local/share/virtualenvs/vanilla-llama/bin/activate +# /home/geantvert/.local/share/virtualenvs/vanilla-llama/bin/python /home/geantvert/workspace/vanilla-llama/benchmark.py +/home/geantvert/.local/share/virtualenvs/vanilla-llama/bin/python /home/geantvert/workspace/vanilla-llama/example_text_completion.py --max_seq_len 32 --max_gen_len 32 --enable_nvtx \ No newline at end of file diff --git a/experimental/llama-v2/llama/__init__.py b/experimental/llama-v2/llama/__init__.py new file mode 100755 index 00000000..354342dd --- /dev/null +++ b/experimental/llama-v2/llama/__init__.py @@ -0,0 +1,6 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement. + +from .generation import Llama +from .model import ModelArgs, Transformer +from .tokenizer import Tokenizer diff --git a/experimental/llama-v2/llama/generation.py b/experimental/llama-v2/llama/generation.py new file mode 100755 index 00000000..53aae952 --- /dev/null +++ b/experimental/llama-v2/llama/generation.py @@ -0,0 +1,309 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement. + +import json +# import os +# import sys +import time +from pathlib import Path +from typing import List, Literal, Optional, Tuple, TypedDict + +import torch +import torch.nn.functional as F +# from fairscale.nn.model_parallel.initialize import ( +# get_model_parallel_rank, +# initialize_model_parallel, +# model_parallel_is_initialized, +# ) + +from llama.model import ModelArgs, Transformer +from llama.tokenizer import Tokenizer + +Role = Literal["system", "user", "assistant"] + + +class Message(TypedDict): + role: Role + content: str + + +class CompletionPrediction(TypedDict, total=False): + generation: str + tokens: List[str] # not required + logprobs: List[float] # not required + + +class ChatPrediction(TypedDict, total=False): + generation: Message + tokens: List[str] # not required + logprobs: List[float] # not required + + +Dialog = List[Message] + +B_INST, E_INST = "[INST]", "[/INST]" +B_SYS, E_SYS = "<>\n", "\n<>\n\n" +DEFAULT_SYSTEM_PROMPT = """\ +You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature. + +If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.""" + + +class Llama: + @staticmethod + def build( + ckpt_dir: str, + tokenizer_path: str, + max_seq_len: int, + max_batch_size: int, + # model_parallel_size: Optional[int] = None, + ) -> "Llama": + # if not torch.distributed.is_initialized(): + # torch.distributed.init_process_group("nccl") + # if not model_parallel_is_initialized(): + # if model_parallel_size is None: + # model_parallel_size = int(os.environ.get("WORLD_SIZE", 1)) + # initialize_model_parallel(model_parallel_size) + + # local_rank = int(os.environ.get("LOCAL_RANK", 0)) + # torch.cuda.set_device(local_rank) + + # seed must be the same in all processes + torch.manual_seed(1) + + # if local_rank > 0: + # sys.stdout = open(os.devnull, "w") + + start_time = time.time() + checkpoints = sorted(Path(ckpt_dir).glob("*.pth")) + assert len(checkpoints) > 0, f"no checkpoint files found in {ckpt_dir}" + # assert model_parallel_size == len( + # checkpoints + # ), f"Loading a checkpoint for MP={len(checkpoints)} but world size is {model_parallel_size}" + # ckpt_path = checkpoints[get_model_parallel_rank()] + ckpt_path = checkpoints[0] + checkpoint = torch.load(ckpt_path, map_location="cpu") + with open(Path(ckpt_dir) / "params.json", "r") as f: + params = json.loads(f.read()) + + model_args: ModelArgs = ModelArgs( + max_seq_len=max_seq_len, + max_batch_size=max_batch_size, + **params, + ) + tokenizer = Tokenizer(model_path=tokenizer_path) + model_args.vocab_size = tokenizer.n_words + torch.set_default_tensor_type(torch.cuda.HalfTensor) + model = Transformer(model_args) + model.load_state_dict(checkpoint, strict=False) + print(f"Loaded in {time.time() - start_time:.2f} seconds") + + return Llama(model, tokenizer) + + def __init__(self, model: Transformer, tokenizer: Tokenizer): + self.model = model + self.tokenizer = tokenizer + + @torch.inference_mode() + def generate( + self, + prompt_tokens: List[List[int]], + max_gen_len: int, + temperature: float = 0.6, + top_p: float = 0.9, + logprobs: bool = False, + echo: bool = False, + ) -> Tuple[List[List[int]], List[float], Optional[List[List[float]]]]: + params = self.model.params + bsz = len(prompt_tokens) + assert bsz <= params.max_batch_size, (bsz, params.max_batch_size) + + min_prompt_len = min(len(t) for t in prompt_tokens) + max_prompt_len = max(len(t) for t in prompt_tokens) + assert max_prompt_len <= params.max_seq_len + total_len = min(params.max_seq_len, max_gen_len + max_prompt_len) + + pad_id = self.tokenizer.pad_id + tokens = torch.full((bsz, total_len), pad_id, dtype=torch.long, device="cuda") + for k, t in enumerate(prompt_tokens): + tokens[k, : len(t)] = torch.tensor(t, dtype=torch.long, device="cuda") + if logprobs: + token_logprobs = torch.zeros_like(tokens, dtype=torch.float) + + prev_pos = 0 + eos_reached = torch.tensor([False] * bsz, device="cuda") + input_text_mask = tokens != pad_id + batched_token_timings = list() # each timing measure is for a batch inference + for cur_pos in range(min_prompt_len, total_len): + torch.cuda.synchronize() + start = time.time() + logits = self.model.forward(tokens[:, prev_pos:cur_pos], prev_pos) + torch.cuda.synchronize() + batched_token_timings.append(time.time() - start) + if logprobs: + token_logprobs[:, prev_pos + 1 : cur_pos + 1] = -F.cross_entropy( + input=logits.transpose(1, 2), + target=tokens[:, prev_pos + 1 : cur_pos + 1], + reduction="none", + ignore_index=pad_id, + ) + if temperature > 0: + probs = torch.softmax(logits[:, -1] / temperature, dim=-1) + next_token = sample_top_p(probs, top_p) + else: + next_token = torch.argmax(logits[:, -1], dim=-1) + + next_token = next_token.reshape(-1) + # only replace token if prompt has already been generated + next_token = torch.where( + input_text_mask[:, cur_pos], tokens[:, cur_pos], next_token + ) + tokens[:, cur_pos] = next_token + eos_reached |= (~input_text_mask[:, cur_pos]) & ( + next_token == self.tokenizer.eos_id + ) + prev_pos = cur_pos + if all(eos_reached): + break + + if logprobs: + token_logprobs = token_logprobs.tolist() + out_tokens, out_logprobs = [], [] + for i, toks in enumerate(tokens.tolist()): + # cut to max gen len + start = 0 if echo else len(prompt_tokens[i]) + toks = toks[start : len(prompt_tokens[i]) + max_gen_len] + if logprobs: + probs = token_logprobs[i][start : len(prompt_tokens[i]) + max_gen_len] + # cut to eos tok if any + if self.tokenizer.eos_id in toks: + eos_idx = toks.index(self.tokenizer.eos_id) + toks = toks[:eos_idx] + probs = probs[:eos_idx] if logprobs else None + out_tokens.append(toks) + out_logprobs.append(probs) + return (out_tokens, batched_token_timings, out_logprobs if logprobs else None) + + def text_completion( + self, + prompts: List[str], + temperature: float = 0.6, + top_p: float = 0.9, + max_gen_len: Optional[int] = None, + logprobs: bool = False, + echo: bool = False, + ) -> Tuple[List[CompletionPrediction], List[float]]: + if max_gen_len is None: + max_gen_len = self.model.params.max_seq_len - 1 + prompt_tokens = [self.tokenizer.encode(x, bos=True, eos=False) for x in prompts] + generation_tokens, batched_token_timings, generation_logprobs = self.generate( + prompt_tokens=prompt_tokens, + max_gen_len=max_gen_len, + temperature=temperature, + top_p=top_p, + logprobs=logprobs, + echo=echo, + ) + if logprobs: + return [ + { + "generation": self.tokenizer.decode(t), + "tokens": [self.tokenizer.decode(x) for x in t], + "logprobs": logprobs_i, + } + for t, logprobs_i in zip(generation_tokens, generation_logprobs) + ], batched_token_timings + return [{"generation": self.tokenizer.decode(t)} for t in generation_tokens], batched_token_timings + + def chat_completion( + self, + dialogs: List[Dialog], + temperature: float = 0.6, + top_p: float = 0.9, + max_gen_len: Optional[int] = None, + logprobs: bool = False, + ) -> List[ChatPrediction]: + if max_gen_len is None: + max_gen_len = self.model.params.max_seq_len - 1 + prompt_tokens = [] + for dialog in dialogs: + if dialog[0]["role"] != "system": + dialog = [ + { + "role": "system", + "content": DEFAULT_SYSTEM_PROMPT, + } + ] + dialog + dialog = [ + { + "role": dialog[1]["role"], + "content": B_SYS + + dialog[0]["content"] + + E_SYS + + dialog[1]["content"], + } + ] + dialog[2:] + assert all([msg["role"] == "user" for msg in dialog[::2]]) and all( + [msg["role"] == "assistant" for msg in dialog[1::2]] + ), ( + "model only supports 'system', 'user' and 'assistant' roles, " + "starting with 'system', then 'user' and alternating (u/a/u/a/u...)" + ) + dialog_tokens: List[int] = sum( + [ + self.tokenizer.encode( + f"{B_INST} {(prompt['content']).strip()} {E_INST} {(answer['content']).strip()} ", + bos=True, + eos=True, + ) + for prompt, answer in zip( + dialog[::2], + dialog[1::2], + ) + ], + [], + ) + assert ( + dialog[-1]["role"] == "user" + ), f"Last message must be from user, got {dialog[-1]['role']}" + dialog_tokens += self.tokenizer.encode( + f"{B_INST} {(dialog[-1]['content']).strip()} {E_INST}", + bos=True, + eos=False, + ) + prompt_tokens.append(dialog_tokens) + + generation_tokens, generation_logprobs = self.generate( + prompt_tokens=prompt_tokens, + max_gen_len=max_gen_len, + temperature=temperature, + top_p=top_p, + logprobs=logprobs, + ) + if logprobs: + return [ + { + "generation": { + "role": "assistant", + "content": self.tokenizer.decode(t), + }, + "tokens": [self.tokenizer.decode(x) for x in t], + "logprobs": logprobs_i, + } + for t, logprobs_i in zip(generation_tokens, generation_logprobs) + ] + return [ + {"generation": {"role": "assistant", "content": self.tokenizer.decode(t)}} + for t in generation_tokens + ] + + +def sample_top_p(probs, p): + probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True) + probs_sum = torch.cumsum(probs_sort, dim=-1) + mask = probs_sum - probs_sort > p + probs_sort[mask] = 0.0 + probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True)) + next_token = torch.multinomial(probs_sort, num_samples=1) + next_token = torch.gather(probs_idx, -1, next_token) + return next_token diff --git a/experimental/llama-v2/llama/model.py b/experimental/llama-v2/llama/model.py new file mode 100755 index 00000000..bf67c3f8 --- /dev/null +++ b/experimental/llama-v2/llama/model.py @@ -0,0 +1,352 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement. + +# code commented below is to replace parallel exec by simpler local exec +# it is not strictly required but remove some small PyTorch overhead +# plus it makes code simpler to launch and think about + +import math +from dataclasses import dataclass +from typing import Any, Optional, Tuple + +# import fairscale.nn.model_parallel.initialize as fs_init +import torch +import torch.nn.functional as F +# from fairscale.nn.model_parallel.layers import ( +# ColumnParallelLinear, +# ParallelEmbedding, +# RowParallelLinear, +# ) +from torch import nn +from utils.config import Config +from utils.nvtx_fake import NoOpContextManager + +config = Config() + + +try: + import nvtx +except ImportError: + config.set_nvtx(False) + + +@dataclass +class ModelArgs: + dim: int = 4096 + n_layers: int = 32 + n_heads: int = 32 + n_kv_heads: Optional[int] = None + vocab_size: int = -1 # defined later by tokenizer + multiple_of: int = 256 # make SwiGLU hidden layer size multiple of large power of 2 + ffn_dim_multiplier: Optional[float] = None + norm_eps: float = 1e-5 + + max_batch_size: int = 32 + max_seq_len: int = 2048 + + +class RMSNorm(torch.nn.Module): + def __init__(self, dim: int, eps: float = 1e-6): + super().__init__() + self.eps = eps + self.weight = nn.Parameter(torch.ones(dim)) + + def _norm(self, x): + return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) + + def forward(self, x): + output = self._norm(x.float()).type_as(x) + return output * self.weight + + +def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0): + freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)) + t = torch.arange(end, device=freqs.device) # type: ignore + freqs = torch.outer(t, freqs).float() # type: ignore + freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64 + return freqs_cis + + +def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor): + ndim = x.ndim + assert 0 <= 1 < ndim + assert freqs_cis.shape == (x.shape[1], x.shape[-1]) + shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)] + return freqs_cis.view(*shape) + + +def apply_rotary_emb( + xq: torch.Tensor, + xk: torch.Tensor, + freqs_cis: torch.Tensor, +) -> Tuple[torch.Tensor, torch.Tensor]: + xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) + xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2)) + freqs_cis = reshape_for_broadcast(freqs_cis, xq_) + xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3) + xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3) + return xq_out.type_as(xq), xk_out.type_as(xk) + + +def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor: + """torch.repeat_interleave(x, dim=2, repeats=n_rep)""" + bs, slen, n_kv_heads, head_dim = x.shape + if n_rep == 1: + return x + return ( + x[:, :, :, None, :] + .expand(bs, slen, n_kv_heads, n_rep, head_dim) + .reshape(bs, slen, n_kv_heads * n_rep, head_dim) + ) + + +class Attention(nn.Module): + def __init__(self, args: ModelArgs): + super().__init__() + self.n_kv_heads = args.n_heads if args.n_kv_heads is None else args.n_kv_heads + model_parallel_size = 1 # fs_init.get_model_parallel_world_size() + self.n_local_heads = args.n_heads // model_parallel_size + self.n_local_kv_heads = self.n_kv_heads // model_parallel_size + self.n_rep = self.n_local_heads // self.n_local_kv_heads + self.head_dim = args.dim // args.n_heads + + self.wq = torch.nn.Linear( # ColumnParallelLinear( + args.dim, + args.n_heads * self.head_dim, + bias=False, + # gather_output=False, + # init_method=lambda x: x, + ) + self.wk = torch.nn.Linear( # ColumnParallelLinear( + args.dim, + self.n_kv_heads * self.head_dim, + bias=False, + # gather_output=False, + # init_method=lambda x: x, + ) + self.wv = torch.nn.Linear( # ColumnParallelLinear( + args.dim, + self.n_kv_heads * self.head_dim, + bias=False, + # gather_output=False, + # init_method=lambda x: x, + ) + self.wo = torch.nn.Linear( # RowParallelLinear( + args.n_heads * self.head_dim, + args.dim, + bias=False, + # input_is_parallel=True, + # init_method=lambda x: x, + ) + + self.cache_k = torch.zeros( + ( + args.max_batch_size, + args.max_seq_len, + self.n_local_kv_heads, + self.head_dim, + ) + ).cuda() + self.cache_v = torch.zeros( + ( + args.max_batch_size, + args.max_seq_len, + self.n_local_kv_heads, + self.head_dim, + ) + ).cuda() + + def forward( + self, + x: torch.Tensor, + x_norm: Optional[torch.Tensor], + start_pos: int, + freqs_cis: torch.Tensor, + mask: Optional[torch.Tensor], + rms_weights: torch.Tensor + ): + bsz, seqlen, _ = x.shape + + if config.get_use_triton(): + from kernel.fused_kernel_proj_qkv import rms_matmul_rbe_qkv_wrapper + with nvtx.annotate(message=f"RMS projection RBE (element W + matmul)", color="red"): + cache_k = self.cache_k[:bsz, start_pos: start_pos + seqlen] + cache_v = self.cache_v[:bsz, start_pos: start_pos + seqlen] + xq, xk, xv = rms_matmul_rbe_qkv_wrapper(x=x, start_pos=start_pos, q_weight=self.wq.weight, + k_weight=self.wk.weight, v_weight=self.wv.weight, + rms_w=rms_weights, n_heads=self.n_local_heads, + head_dim=self.head_dim, k=cache_k, v=cache_v) + else: + with nvtx.annotate(message=f"QKV proj", color="red"): + xq, xk, xv = self.wq(x_norm), self.wk(x_norm), self.wv(x_norm) + + xq = xq.view(bsz, seqlen, self.n_local_heads, self.head_dim) + xk = xk.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim) + xv = xv.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim) + with nvtx.annotate(message=f"RBE (element W)", color="blue"): + xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis) + + self.cache_k = self.cache_k.to(xq) + self.cache_v = self.cache_v.to(xq) + + self.cache_k[:bsz, start_pos: start_pos + seqlen] = xk + self.cache_v[:bsz, start_pos: start_pos + seqlen] = xv + + keys = self.cache_k[:bsz, : start_pos + seqlen] + values = self.cache_v[:bsz, : start_pos + seqlen] + + # repeat k/v heads if n_kv_heads < n_heads + keys = repeat_kv(keys, self.n_rep) # (bs, seqlen, n_local_heads, head_dim) + values = repeat_kv(values, self.n_rep) # (bs, seqlen, n_local_heads, head_dim) + + xq = xq.transpose(1, 2) # (bs, n_local_heads, seqlen, head_dim) + keys = keys.transpose(1, 2) + values = values.transpose(1, 2) + with nvtx.annotate(message=f"attention score computation", color="red"): + scores = torch.matmul(xq, keys.transpose(2, 3)) / math.sqrt(self.head_dim) + if mask is not None: + scores = scores + mask # (bs, n_local_heads, seqlen, cache_len + seqlen) + scores = F.softmax(scores.float(), dim=-1).type_as(xq) + with nvtx.annotate(message=f"attention score application", color="red"): + output = torch.matmul(scores, values) # (bs, n_local_heads, seqlen, head_dim) + output = output.transpose(1, 2).contiguous().view(bsz, seqlen, -1) + with nvtx.annotate(message=f"output projection", color="red"): + output = self.wo(output) + return output + + +class FeedForward(nn.Module): + def __init__( + self, + dim: int, + hidden_dim: int, + multiple_of: int, + ffn_dim_multiplier: Optional[float], + ): + super().__init__() + hidden_dim = int(2 * hidden_dim / 3) + # custom dim factor multiplier + if ffn_dim_multiplier is not None: + hidden_dim = int(ffn_dim_multiplier * hidden_dim) + hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of) + + self.w1 = torch.nn.Linear( # ColumnParallelLinear( + dim, hidden_dim, bias=False, # gather_output=False, init_method=lambda x: x + ) + self.w2 = torch.nn.Linear( # RowParallelLinear( + hidden_dim, dim, bias=False, # input_is_parallel=True, init_method=lambda x: x + ) + self.w3 = torch.nn.Linear( # ColumnParallelLinear( + dim, hidden_dim, bias=False, # gather_output=False, init_method=lambda x: x + ) + + def forward(self, x, x_norm, rms_weights): + if config.get_use_triton(): + from kernel.fused_kernel_ff import kernel_ff + with nvtx.annotate(message=f"FF1 (matmul)", color="red"): + silu_times_w3 = kernel_ff(x, self.w1.weight, self.w3.weight, rms_weights) + w2_out = self.w2(silu_times_w3) + return w2_out + else: + with nvtx.annotate(message=f"FFN (W1)", color="red"): + w1_out = self.w1(x_norm) + silu_out = F.silu(w1_out) + with nvtx.annotate(message=f"FFN (W3)", color="red"): + w3_out = self.w3(x_norm) + silu_times_w3 = silu_out * w3_out + with nvtx.annotate(message=f"FFN (W2)", color="red"): + w2_out = self.w2(silu_times_w3) + return w2_out + + +class TransformerBlock(nn.Module): + def __init__(self, layer_id: int, args: ModelArgs): + super().__init__() + self.n_heads = args.n_heads + self.dim = args.dim + self.head_dim = args.dim // args.n_heads + self.attention = Attention(args) + self.feed_forward = FeedForward( + dim=args.dim, + hidden_dim=4 * args.dim, + multiple_of=args.multiple_of, + ffn_dim_multiplier=args.ffn_dim_multiplier, + ) + self.layer_id = layer_id + self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps) + self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps) + + def forward( + self, + x: torch.Tensor, + start_pos: int, + freqs_cis: torch.Tensor, + mask: Optional[torch.Tensor], + ): + if config.get_use_triton(): + x_rms_norm = None + else: + with nvtx.annotate(message=f"RMS proj (element W)", color="blue"): + x_rms_norm = self.attention_norm(x) + + h = x + self.attention.forward( + x, x_rms_norm, start_pos, freqs_cis, mask, self.attention_norm.weight, + ) + if config.get_use_triton(): + h_rms_norm = None + else: + with nvtx.annotate(message=f"RMS FF (element W)", color="blue"): + h_rms_norm = self.ffn_norm(h) + out = h + self.feed_forward.forward(h, h_rms_norm, self.ffn_norm.weight) + return out + + +class Transformer(nn.Module): + def __init__(self, params: ModelArgs): + super().__init__() + self.params = params + self.vocab_size = params.vocab_size + self.n_layers = params.n_layers + + self.tok_embeddings = torch.nn.Embedding( # ParallelEmbedding( + params.vocab_size, params.dim, # init_method=lambda x: x + ) + + self.layers = torch.nn.ModuleList() + for layer_id in range(params.n_layers): + self.layers.append(TransformerBlock(layer_id, params)) + + self.norm = RMSNorm(params.dim, eps=params.norm_eps) + self.output = torch.nn.Linear( # ColumnParallelLinear( + params.dim, params.vocab_size, bias=False, # init_method=lambda x: x + ) + + self.freqs_cis = precompute_freqs_cis( + self.params.dim // self.params.n_heads, self.params.max_seq_len * 2 + ) + # to reduce overhead we disable nvtx context manager when not required + if not config.get_nvtx(): + nvtx.annotate = NoOpContextManager + + @torch.inference_mode() + def forward(self, tokens: torch.Tensor, start_pos: int): + _bsz, seqlen = tokens.shape + h = self.tok_embeddings(tokens) + self.freqs_cis = self.freqs_cis.to(h.device) + freqs_cis = self.freqs_cis[start_pos: start_pos + seqlen] + + mask = None + if seqlen > 1: + mask = torch.full( + (1, 1, seqlen, seqlen), float("-inf"), device=tokens.device + ) + mask = torch.triu(mask, diagonal=start_pos + 1).type_as(h) + + for layer in self.layers: + with nvtx.annotate(message=f"layer {layer.layer_id} (element W + matmul)", color="white"): + h = layer(h, start_pos, freqs_cis, mask) + h = self.norm(h) + with nvtx.annotate(message=f"output proj", color="red"): + output = self.output(h) + output = output.float() + return output diff --git a/experimental/llama-v2/llama/tokenizer.py b/experimental/llama-v2/llama/tokenizer.py new file mode 100755 index 00000000..e3af0111 --- /dev/null +++ b/experimental/llama-v2/llama/tokenizer.py @@ -0,0 +1,41 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement. + +import os +from logging import getLogger +from typing import List + +from sentencepiece import SentencePieceProcessor + + +logger = getLogger() + + +class Tokenizer: + def __init__(self, model_path: str): + # reload tokenizer + assert os.path.isfile(model_path), model_path + self.sp_model = SentencePieceProcessor(model_file=model_path) + logger.info(f"Reloaded SentencePiece model from {model_path}") + + # BOS / EOS token IDs + self.n_words: int = self.sp_model.vocab_size() + self.bos_id: int = self.sp_model.bos_id() + self.eos_id: int = self.sp_model.eos_id() + self.pad_id: int = self.sp_model.pad_id() + logger.info( + f"#words: {self.n_words} - BOS ID: {self.bos_id} - EOS ID: {self.eos_id}" + ) + assert self.sp_model.vocab_size() == self.sp_model.get_piece_size() + + def encode(self, s: str, bos: bool, eos: bool) -> List[int]: + assert type(s) is str + t = self.sp_model.encode(s) + if bos: + t = [self.bos_id] + t + if eos: + t = t + [self.eos_id] + return t + + def decode(self, t: List[int]) -> str: + return self.sp_model.decode(t) diff --git a/experimental/llama-v2/requirements.txt b/experimental/llama-v2/requirements.txt new file mode 100755 index 00000000..93a3a24d --- /dev/null +++ b/experimental/llama-v2/requirements.txt @@ -0,0 +1,5 @@ +torch +# fairscale +# fire +sentencepiece +nvtx diff --git a/experimental/llama-v2/setup.py b/experimental/llama-v2/setup.py new file mode 100755 index 00000000..57f86dcb --- /dev/null +++ b/experimental/llama-v2/setup.py @@ -0,0 +1,16 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement. + +from setuptools import find_packages, setup + + +def get_requirements(path: str): + return [l.strip() for l in open(path)] + + +setup( + name="llama", + version="0.0.1", + packages=find_packages(), + install_requires=get_requirements("requirements.txt"), +) diff --git a/experimental/llama-v2/utils/__init__.py b/experimental/llama-v2/utils/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/experimental/llama-v2/utils/config.py b/experimental/llama-v2/utils/config.py new file mode 100644 index 00000000..dc21ab5b --- /dev/null +++ b/experimental/llama-v2/utils/config.py @@ -0,0 +1,28 @@ +class Config: + _instance = None + + def __new__(cls, *args, **kwargs): + if cls._instance is None: + cls._instance = super(Config, cls).__new__(cls) + cls._instance.enable_nvtx = False + cls._instance.use_triton = False + + # perplexity + cls._instance.prefix = 10 + cls._instance.seq_len = 128 + cls._instance.num_samples = 20 + cls._instance.batch_size = 20 + + return cls._instance + + def set_nvtx(self, value: bool): + self.enable_nvtx = value + + def get_nvtx(self) -> bool: + return self.enable_nvtx + + def set_use_triton(self, value: bool): + self.use_triton = value + + def get_use_triton(self) -> bool: + return self.use_triton diff --git a/experimental/llama-v2/utils/nvtx_fake.py b/experimental/llama-v2/utils/nvtx_fake.py new file mode 100644 index 00000000..f5066517 --- /dev/null +++ b/experimental/llama-v2/utils/nvtx_fake.py @@ -0,0 +1,9 @@ +class NoOpContextManager: + def __init__(self, *args, **kwargs): + pass + + def __enter__(self): + return None + + def __exit__(self, exc_type, exc_val, exc_tb): + return False diff --git a/experimental/streamk-old/benchmark.py b/experimental/streamk-old/benchmark.py new file mode 100644 index 00000000..cd00923d --- /dev/null +++ b/experimental/streamk-old/benchmark.py @@ -0,0 +1,143 @@ +from typing import Optional + +import torch +import triton +import random + +import json + +from triton.runtime import driver + +from experimental.streamk.kernel import matmul + + +torch.manual_seed(123) +random.seed(123) + +device = torch.cuda.current_device() +total_sm = driver.utils.get_device_properties(device)["multiprocessor_count"] +print(f"total SMs: {total_sm}") + +# TODO restore disable two-tile so can support very little # SMs +# TODO dead lock with: +# m, n, k = 2304, 5632, 3328 +# matmul.set_debug(True) +# A = torch.randn(m, k, device="cuda", dtype=torch.float16) +# B = torch.randn(k, n, device="cuda", dtype=torch.float16) +# C = matmul.apply(A, B, 158, 128, 128, 32, False, 4, 4) +# exit(0) + +m, n, k = 1536, 1792, 6016 # some problem size to test +A = torch.randn(m, k, device="cuda", dtype=torch.float16) +B = torch.randn(k, n, device="cuda", dtype=torch.float16) + +matmul.set_debug(True) +C = matmul.apply(A, B, 158, 128, 128, 32, True, 4, 4) +matmul.set_debug(False) +expected = A @ B + +assert torch.allclose(C, expected, atol=1), f"max: {(C - expected).abs().max().item()}\n{C}\n{expected}" + +# for debugging, uncomment the following line +# exit(0) + +triton_ms = triton.testing.do_bench(lambda: torch.matmul(A, B)) +print("PyTorch", triton_ms) + +triton_ms = triton.testing.do_bench(lambda: matmul.apply(A, B, total_sm, 128, 128, 32, True, 4, 4)) +print(f"hybrid stream-k (grid={total_sm})", triton_ms) + +triton_ms = triton.testing.do_bench(lambda: matmul.apply(A, B, total_sm * 2, 128, 128, 32, True, 4, 4)) +print(f"hybrid stream-k (grid={total_sm * 2})", triton_ms) + +triton_ms = triton.testing.do_bench(lambda: matmul.apply(A, B, 0, 128, 128, 32, True, 4, 4)) +print("tile matmul (grid=0)", triton_ms) + +# --------------------------------------------------------------------------- +# Log-sampled benchmark +# --------------------------------------------------------------------------- + +# tried to reproduce the tests described in the paper +num_samples = 1000 # 32768 +step = 256 +values = ((torch.logspace(torch.tensor(step).log2(), torch.tensor(8192).log2(), num_samples, base=2) / step).round() * step).unique().tolist() +shapes = [(int(m), int(n), int(k)) for m in values for n in values for k in values] +shapes = random.sample(shapes, num_samples) +assert len(shapes) == num_samples +output: Optional[torch.Tensor] = None + + +def wrapper_matmul(*args, **kwargs): + global output + output = matmul.apply(*args, **kwargs) + return output + + +results = [] +for idx, (m, n, k) in enumerate(shapes): + # print progress bar + if idx % 10 == 0 and idx > 0: + speedups = [r["speedup"] for r in results] + print(f"{idx}/{num_samples} - average speedup: {sum(speedups) / len(speedups):.3f}") + + A = torch.randn(m, k, device="cuda", dtype=torch.float16) + B = torch.randn(k, n, device="cuda", dtype=torch.float16) + + expected = A @ B + pytorch_ms = triton.testing.do_bench(lambda: A @ B) + measures = list() + for two_tiles in [True]: # TODO reenable False when dead lock is fixed + nb_sm = [total_sm, total_sm * 2] + total_tile = (m // 128) * (n // 128) + if total_tile < total_sm * 2: + nb_sm.append(total_tile) + nb_sm += random.sample(range(2, total_sm * 2, 2), 10) + for sm in nb_sm: + triton_ms = triton.testing.do_bench(lambda: wrapper_matmul(A, B, sm, 128, 128, 32, two_tiles, 4, 4)) + max_disc = (output - expected).abs().max().item() + # large tolerance to accomodate for large K (rounding due to half precision), we just want to catch bugs. + assert max_disc <= 5., f"pb size: {m}x{n}x{k} - max discrepancy: {max_disc} - sm: {sm}, 2 tiles: {two_tiles}\n{output}\n{expected}" + info = { + "2 tiles": two_tiles, + "sm": sm, + "disc": max_disc, + "triton_ms": triton_ms, + } + measures.append(info) + best_triton_ms = min([m["triton_ms"] for m in measures]) + d = { + "m": m, + "n": n, + "k": k, + "triton": measures, + "pytorch_ms": pytorch_ms, + "speedup": pytorch_ms / best_triton_ms, + } + results.append(d) + measures = list() + +results.sort(key=lambda x: x["speedup"], reverse=False) + +# --------------------------------------------------------------------------- +# Benchmark export +# --------------------------------------------------------------------------- + +with open("./experimental/streamk/results.json", "w") as f: + json.dump(results, f, indent=4) + +# python -m experimental.streamk.benchmark + +# 32760/32768 - average speedup: 0.962 (A100) +# 990/1000 - average speedup: 1.060 (3090 RTX no while loop) +# 990/1000 - average speedup: 1.053 (3090 RTX with while loop) +# 990/1000 - average speedup: 1.063 (3090 RTX with while loop and 2 tiles disabled / enabled) + +# for profiling: +# (echo 'options nvidia "NVreg_RestrictProfilingToAdminUsers=0"') | sudo tee -a /etc/modprobe.d/RestrictedProfiling.conf >/dev/null +# sudo update-initramfs -u -k all +# cat /proc/driver/nvidia/params | grep RmProfilingAdminOnly +# sudo apt-get install zlib1g-dev +# for reproductible experiments +# sudo nvidia-smi -pm 1 -i 0 +# sudo nvidia-smi -i 0 -pl 350 # 400 for A100 +# sudo nvidia-smi -i 0 -lgc 1005 \ No newline at end of file diff --git a/experimental/streamk-old/kernel.py b/experimental/streamk-old/kernel.py new file mode 100644 index 00000000..f56e5972 --- /dev/null +++ b/experimental/streamk-old/kernel.py @@ -0,0 +1,260 @@ +import torch + +import triton +from triton import language as tl + + +@triton.jit() +def swizzle_tile(tile_id, + M, N, K, + BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, + GROUP_M: tl.constexpr + ): + grid_m = tl.cdiv(M, BLOCK_M) + grid_n = tl.cdiv(N, BLOCK_N) + # re-order program ID for better L2 performance + width = GROUP_M * grid_n + group_id = tile_id // width + group_size = tl.minimum(grid_m - group_id * GROUP_M, GROUP_M) + pid_m = group_id * GROUP_M + (tile_id % group_size) + pid_n = (tile_id % width) // group_size + return pid_m, pid_n + + +@triton.jit() +def linear_tile(tile_id, + M, N, K, + BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, + GROUP_M: tl.constexpr + ): + pid_m = tile_id // tl.cdiv(N, BLOCK_N) + pid_n = tile_id % tl.cdiv(N, BLOCK_N) + return pid_m, pid_n + + +@triton.jit() +def mac_loop(A, B, C, + M, N, K, + locks, + stride_am, stride_ak, stride_bk, stride_bn, stride_cm, stride_cn, + iters_per_tile, + start_iter, end_iter, + BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, + ACC_TYPE: tl.constexpr, GROUP_M: tl.constexpr): + + # where are we in the grid + tile_id = start_iter // iters_per_tile + if GROUP_M > 0: + pid_m, pid_n = swizzle_tile(tile_id, M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, GROUP_M) + else: + pid_m, pid_n = linear_tile(tile_id, M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, GROUP_M) + + rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + rk = tl.arange(0, BLOCK_K) + A = A + (rm[:, None] * stride_am + rk[None, :] * stride_ak) + BLOCK_K * stride_ak * (start_iter % iters_per_tile) + B = B + (rk[:, None] * stride_bk + rn[None, :] * stride_bn) + BLOCK_K * stride_bk * (start_iter % iters_per_tile) + acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE) + + for current_iter in range(start_iter, end_iter): + a = tl.load(A) + b = tl.load(B) + acc += tl.dot(a, b) + A += BLOCK_K * stride_ak + B += BLOCK_K * stride_bk + + if end_iter % iters_per_tile == 0: # last iteration of the tile always happens before its start on another SM + C_ = C + (rm[:, None] * stride_cm + rn[None, :] * stride_cn) # compute inside the if/else to avoid spilling! + tl.store(C_, acc) + if start_iter % iters_per_tile != 0: # only if tile has been partially processed + tl.atomic_xchg(locks + tile_id, 1) + else: + while tl.atomic_cas(locks + tile_id, 1, 1) != 1: + pass + C_ = C + (rm[:, None] * stride_cm + rn[None, :] * stride_cn) # compute inside the if/else to avoid spilling! + tl.atomic_add(C_, acc) + + +@triton.jit() +def first_wave( + A, B, C, + M, N, K, + locks, + stride_am, stride_ak, stride_bk, stride_bn, stride_cm, stride_cn, + total_full_tiles_streamk, total_partial_tiles_streamk, iters_per_tile, + BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, ACC_TYPE: tl.constexpr, + GROUP_M: tl.constexpr, +): + pid = tl.program_id(0) + start_iter = pid * total_full_tiles_streamk + tl.minimum(pid, total_partial_tiles_streamk) + last_iter = (pid + 1) * total_full_tiles_streamk + tl.minimum(pid + 1, total_partial_tiles_streamk) + + while start_iter < last_iter: + end_iter = tl.minimum(start_iter + (iters_per_tile - start_iter % iters_per_tile), last_iter) + mac_loop(A, B, C, + M, N, K, + locks, + stride_am, stride_ak, stride_bk, stride_bn, stride_cm, stride_cn, + iters_per_tile, + start_iter, end_iter, + BLOCK_M, BLOCK_N, BLOCK_K, ACC_TYPE, + GROUP_M, + ) + + start_iter = end_iter + + +@triton.jit() +def full_tiles( + A, B, C, + M, N, K, + stride_am, stride_ak, stride_bk, stride_bn, stride_cm, stride_cn, + total_tiles_streamk, + BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, ACC_TYPE: tl.constexpr, + GROUP_M: tl.constexpr, +): + # first wave has done more tiles than there are SMs, we adjust pid + tile_id = tl.program_id(0) + total_tiles_streamk + if GROUP_M > 0: + pid_m, pid_n = swizzle_tile(tile_id, M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, GROUP_M) + else: + pid_m, pid_n = linear_tile(tile_id, M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, GROUP_M) + + # do matrix multiplication + rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + rk = tl.arange(0, BLOCK_K) + # pointers + A = A + (rm[:, None] * stride_am + rk[None, :] * stride_ak) + B = B + (rk[:, None] * stride_bk + rn[None, :] * stride_bn) + acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE) + for k in range(0, tl.cdiv(K, BLOCK_K)): + a = tl.load(A) + b = tl.load(B) + acc += tl.dot(a, b) + A += BLOCK_K * stride_ak + B += BLOCK_K * stride_bk + acc = acc.to(tl.float16) # restore C.dtype.element_ty + # rematerialize rm and rn to save registers + rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + C = C + (rm[:, None] * stride_cm + rn[None, :] * stride_cn) + tl.store(C, acc) + + +class matmul(torch.autograd.Function): + + _debug = False + + @staticmethod + def set_debug(debug: bool): + matmul._debug = debug + + @staticmethod + def _call(a: torch.Tensor, b: torch.Tensor, total_programs_streamk: int, BLK_M: int, BLK_N: int, BLK_K: int, two_tiles: bool, num_stages: int, num_warps: int): + device = a.device + + assert a.is_contiguous() and b.is_contiguous(), "non-contiguous inputs are not supported" + # checks constraints + assert a.shape[1] == b.shape[0], "incompatible dimensions" + M, K = a.shape + _, N = b.shape + # accumulator types + ACC_TYPE = tl.float32 if a.dtype in [torch.float16, torch.bfloat16, torch.float32] else tl.int32 + # compute grid (work to do per SM on the first wave) + total_blocks_M = triton.cdiv(M, BLK_M) + total_blocks_N = triton.cdiv(N, BLK_N) + iters_per_tile = triton.cdiv(K, BLK_K) + GROUP_M = 8 # 0 to disable swizzling + total_tiles = total_blocks_M * total_blocks_N + + if total_programs_streamk > 0: # Stream-K + # last wave may occupy less than total_programs_streamk SMs + total_tiles_streamk = total_tiles % total_programs_streamk + # for two-tile Stream-K + data-parallel from original paper + if two_tiles and total_tiles - total_tiles_streamk > total_programs_streamk: + total_tiles_streamk += total_programs_streamk + # remaining tiles are computed using classical blocking + total_blocking_tiles = total_tiles - total_tiles_streamk + total_iters_streamk = total_tiles_streamk * iters_per_tile + # iterations related to full waves + total_full_tiles_streamk = total_iters_streamk // total_programs_streamk + # iterations related to last (partial) wave + total_partial_tiles_streamk = total_iters_streamk % total_programs_streamk + + else: # all tiles are computed using classical blocking + total_blocking_tiles = total_tiles + total_tiles_streamk = 0 + total_full_tiles_streamk = 0 + total_partial_tiles_streamk = 0 + total_iters_streamk = 0 + + if matmul._debug: + print(f"M,N,K={M},{N},{K} ; BLK_M,N,K={BLK_M},{BLK_N},{BLK_K}") + print(f"{total_blocks_M=} x {total_blocks_N=} = {total_tiles=}") + print(f"{total_tiles_streamk=} + {total_blocking_tiles=} = {total_tiles=}") + print(f"{total_programs_streamk=}") + print(f"{total_blocking_tiles=}") + print(f"{iters_per_tile=}") + print(f"{total_iters_streamk=}") + + # allocates output + c = torch.empty((M, N), device=device, dtype=a.dtype) + # allocates locks to sync work accross SMs + locks = torch.zeros((total_tiles_streamk,), device=device, dtype=torch.int32) + k1 = first_wave[(total_programs_streamk,)]( + a, + b, + c, + M, + N, + K, + locks, + a.stride(0), + a.stride(1), + b.stride(0), + b.stride(1), + c.stride(0), + c.stride(1), + total_full_tiles_streamk=total_full_tiles_streamk, + total_partial_tiles_streamk=total_partial_tiles_streamk, + iters_per_tile=iters_per_tile, + BLOCK_M=BLK_M, + BLOCK_N=BLK_N, + BLOCK_K=BLK_K, + ACC_TYPE=ACC_TYPE, + GROUP_M=GROUP_M, + num_stages=num_stages, + num_warps=num_warps, + ) + if matmul._debug: + print(f"{k1.n_regs} registers used, {k1.n_spills} spills") + k2 = full_tiles[(total_blocking_tiles,)]( + a, + b, + c, + M, + N, + K, + a.stride(0), + a.stride(1), + b.stride(0), + b.stride(1), + c.stride(0), + c.stride(1), + total_tiles_streamk=total_tiles_streamk, + BLOCK_M=BLK_M, + BLOCK_N=BLK_N, + BLOCK_K=BLK_K, + ACC_TYPE=ACC_TYPE, + GROUP_M=GROUP_M, + num_stages=num_stages, + num_warps=num_warps, + ) + if matmul._debug: + print(f"{k2.n_regs} registers used, {k2.n_spills} spills") + return c + + @staticmethod + def forward(ctx, a: torch.Tensor, b: torch.Tensor, grid: int, BLK_M=128, BLK_N=128, BLK_K=32, two_tiles=True, num_stages=3, num_warps=4): + return matmul._call(a=a, b=b, total_programs_streamk=grid, BLK_M=BLK_M, BLK_N=BLK_N, BLK_K=BLK_K, two_tiles=two_tiles, num_warps=num_warps, num_stages=num_stages) diff --git a/experimental/streamk-old/pred.py b/experimental/streamk-old/pred.py new file mode 100644 index 00000000..d25d0a00 --- /dev/null +++ b/experimental/streamk-old/pred.py @@ -0,0 +1,33 @@ +import torch +import torch.nn as nn + +from experimental.streamk.utils import TritonMeasure, Measure, get_features + +device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu') +model = nn.Linear(4, 1, bias=True).to(device) +model.load_state_dict(torch.load("./experimental/streamk/model.pt")) +model.eval() + +triton_measures = list() +for i in [82, 2*82]: + t = TritonMeasure( + two_tiles=True, + sm=i, + ) + triton_measures.append(t) + +m, n, k = 768, 4864, 8192 +blk_m, blk_n, blk_k = 128, 128, 32 + +features = get_features(triton_measures, total_tiles=(m // blk_m) * (n // blk_n), iters_per_tile=k // blk_k) +print(features) + +X_train_min = 0.0 +X_train_max = 791040.0 +X = torch.tensor(features, dtype=torch.float, device=device) +X_normalized = (X - X_train_min) / (X_train_max - X_train_min) +print(X_normalized.shape) + +with torch.inference_mode(): + y = model(X_normalized).squeeze() + print(y) diff --git a/experimental/streamk-old/train.py b/experimental/streamk-old/train.py new file mode 100644 index 00000000..f8b120a7 --- /dev/null +++ b/experimental/streamk-old/train.py @@ -0,0 +1,140 @@ +import json +import torch +import torch.nn as nn +import torch.optim as optim + +from experimental.streamk.utils import TritonMeasure, Measure, get_timings, get_features + +torch.manual_seed(123) + + +def from_dict_to_dataclass(data): + return Measure( + m=data['m'], + n=data['n'], + k=data['k'], + triton=[TritonMeasure( + two_tiles=triton_data['2 tiles'], + sm=triton_data['sm'], + disc=triton_data['disc'], + triton_ms=triton_data['triton_ms'] + ) for triton_data in data['triton']], + pytorch_ms=data['pytorch_ms'], + speedup=data['speedup'] + ) + + +blk_m, blk_n, blk_k = 128, 128, 32 +with open("./experimental/streamk/results.json") as f: + measure_json = json.load(f) + +to_skip = set() +data = list() +triton_timings = list() + + +for xp_measure in measure_json: + m = from_dict_to_dataclass(xp_measure) + total_tiles: int = m.number_of_tiles(blk_m, blk_n) + iters_per_tile: int = m.iter_per_tile(blk_k) + features = get_features(m.triton, total_tiles, iters_per_tile) + to_pred = get_timings(m.triton) + for f, t in zip(features, to_pred): + if tuple(f) in to_skip: + continue + to_skip.add(tuple(f)) + data.append((f, m.m, m.n, m.k, t)) + + +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") +X = torch.tensor([d[0] for d in data], dtype=torch.float, device=device) +y = torch.tensor([d[4] for d in data], dtype=torch.float, device=device) + +print(X[:5], y[:5]) +print(X.shape, y.shape) + +assert not torch.isnan(X).any() and not torch.isnan(y).any(), "Input data contains NaN values." +assert not torch.isinf(X).any() and not torch.isinf(y).any(), "Input data contains infinity values." + + +def custom_split(data, split_ratio=0.8): + sorted_data = sorted(data, key=lambda x: (x[1], x[2], x[3])) + num_samples = len(sorted_data) + split_idx = int(num_samples * split_ratio) + train_data = sorted_data[:split_idx] + eval_data = sorted_data[split_idx:] + return train_data, eval_data + + +train_data, eval_data = custom_split(data) + +X_train = torch.tensor([d[0] for d in train_data], dtype=torch.float, device=device) +y_train = torch.tensor([d[4] for d in train_data], dtype=torch.float, device=device) +X_eval = torch.tensor([d[0] for d in eval_data], dtype=torch.float, device=device) +y_eval = torch.tensor([d[4] for d in eval_data], dtype=torch.float, device=device) + + +# Normalize or standardize train and eval datasets using train dataset statistics +X_train_min, X_train_max = X_train.min(dim=0).values, X_train.max(dim=0).values +X_train_normalized = (X_train - X_train_min) / (X_train_max - X_train_min) +X_eval_normalized = (X_eval - X_train_min) / (X_train_max - X_train_min) +print(f"X_train_min: {X_train_min}, X_train_max: {X_train_max}") +# Model, loss, and optimizer +input_dim = X.shape[1] +# model = LinearRegression(input_dim).to(device) +model = nn.Linear(input_dim, 1, bias=True).to(device) +criterion = nn.MSELoss() +optimizer = optim.Adam(model.parameters(), lr=0.01) + +# Training with early stopping +num_epochs = 200 +batch_size = 50 +patience = 200 +num_train_samples = X_train_normalized.shape[0] + +best_eval_loss = float('inf') +epochs_since_last_improvement = 0 + +for epoch in range(num_epochs): + # Shuffle dataset + indices = torch.randperm(num_train_samples) + X_train_shuffled = X_train_normalized[indices] + y_train_shuffled = y_train[indices] + + for i in range(0, num_train_samples, batch_size): + X_batch = X_train_shuffled[i:i + batch_size] + y_batch = y_train_shuffled[i:i + batch_size] + + # Forward pass + outputs = model(X_batch) + loss = criterion(outputs.squeeze(), y_batch) + + # Backward and optimize + optimizer.zero_grad() + loss.backward() + optimizer.step() + + # Evaluate on the evaluation set + with torch.no_grad(): + eval_outputs = model(X_eval_normalized) + eval_loss = criterion(eval_outputs.squeeze(), y_eval) + + if eval_loss.item() < best_eval_loss: + best_eval_loss = eval_loss.item() + epochs_since_last_improvement = 0 + else: + epochs_since_last_improvement += 1 + + if epochs_since_last_improvement >= patience: + print(f"Early stopping at epoch {epoch + 1}") + break + + if (epoch + 1) % 10 == 0: + print(f'Epoch [{epoch + 1}/{num_epochs}], Train Loss: {loss.item():.4f}, Eval Loss: {eval_loss.item():.4f}') + +with torch.inference_mode(): + y_pred = model(X_eval_normalized).squeeze() + diff = y_pred - y_eval + print(f"Mean {diff.mean().item():.2f} ms, std {diff.std().item():.2f} ms, max {diff.max().item():.2f} ms, min {diff.min().item():.2f} ms") + +torch.save(model.state_dict(), "./experimental/streamk/model.pt") diff --git a/experimental/streamk-old/utils.py b/experimental/streamk-old/utils.py new file mode 100644 index 00000000..c698bc52 --- /dev/null +++ b/experimental/streamk-old/utils.py @@ -0,0 +1,59 @@ +from dataclasses import dataclass +from typing import Optional, List + + +@dataclass +class TritonMeasure: + two_tiles: bool + sm: int + disc: Optional[float] = None + triton_ms: Optional[float] = None + + +@dataclass +class Measure: + m: int + n: int + k: int + triton: List[TritonMeasure] + pytorch_ms: Optional[float] = None + speedup: Optional[float] = None + + def number_of_tiles(self, blk_m: int, blk_n: int) -> int: + return (self.m // blk_m) * (self.n // blk_n) + + def iter_per_tile(self, blk_k: int) -> int: + return self.k // blk_k + + def get_minimum_triton_measure(self) -> TritonMeasure: + return min(self.triton, key=lambda x: x.triton_ms) + + +def get_timings(measures: List[TritonMeasure]) -> List[float]: + xp_timings = list() + for triton in measures: + xp_timings.append(triton.triton_ms) + return xp_timings + + +def get_features(measures: List[TritonMeasure], total_tiles: int, iters_per_tile: int) -> List[List[float]]: + xp_features = list() + for triton in measures: + total_programs_streamk = triton.sm + total_tiles_streamk = total_tiles % total_programs_streamk + # for two-tile Stream-K + data-parallel from original paper + if triton.two_tiles and total_tiles - total_tiles_streamk > total_programs_streamk: + total_tiles_streamk += total_programs_streamk + # remaining tiles are computed using classical blocking + total_blocking_tiles = total_tiles - total_tiles_streamk + total_iters_streamk = total_tiles_streamk * iters_per_tile + + # values used for prediction + nb_sync_stream_k = triton.sm # there is 2 syncs per SM in stream-k + nb_store = total_blocking_tiles # there is 1 store per tile in blocking loop + nb_iter_stream_k = total_iters_streamk # includes loading + nb_iter_blocking = total_blocking_tiles * iters_per_tile # includes loading + + xp_features.append([nb_sync_stream_k, nb_iter_stream_k, nb_iter_blocking, nb_store]) + + return xp_features diff --git a/experimental/streamk-old/xp.py b/experimental/streamk-old/xp.py new file mode 100644 index 00000000..118314dd --- /dev/null +++ b/experimental/streamk-old/xp.py @@ -0,0 +1,37 @@ +import torch +import triton +import random + +from triton.runtime import driver + +from experimental.streamk.kernel import matmul + + +torch.manual_seed(123) +random.seed(123) + +device = torch.cuda.current_device() +total_sm = driver.utils.get_device_properties(device)["multiprocessor_count"] +print(f"total SMs: {total_sm}") + + +m, n, k = 1024, 256, 768 +A = torch.randn(m, k, device="cuda", dtype=torch.float16) +B = torch.randn(k, n, device="cuda", dtype=torch.float16) + + +triton_ms = triton.testing.do_bench(lambda: torch.matmul(A, B)) +print("PyTorch", triton_ms) + +triton_ms = triton.testing.do_bench(lambda: matmul.apply(A, B, total_sm, 128, 128, 32, True, 4, 4)) +print(f"hybrid stream-k (grid={total_sm})", triton_ms) + +triton_ms = triton.testing.do_bench(lambda: matmul.apply(A, B, total_sm * 2, 128, 128, 32, True, 4, 4)) +print(f"hybrid stream-k (grid={total_sm * 2})", triton_ms) + +triton_ms = triton.testing.do_bench(lambda: matmul.apply(A, B, 0, 128, 128, 32, True, 4, 4)) +print("tile matmul (grid=0)", triton_ms) + +for i in range(1, 82): + triton_ms = triton.testing.do_bench(lambda: matmul.apply(A, B, i, 128, 128, 32, True, 4, 4)) + print(f"hybrid stream-k (grid={i})", triton_ms)