Skip to content

Commit

Permalink
Merge pull request #3 from skolai/wheels
Browse files Browse the repository at this point in the history
Build manylinux wheels
  • Loading branch information
daskol authored Jul 26, 2023
2 parents 8ce2935 + eaa3e74 commit 940706b
Show file tree
Hide file tree
Showing 24 changed files with 677 additions and 132 deletions.
27 changes: 26 additions & 1 deletion CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,35 @@ if (USE_CUDA)
enable_language(CUDA)
endif ()

set(TORCH_CUDA_ARCH_LIST "Auto" CACHE STRING "Target CUDA archetecture" FORCE)
set(TORCH_CUDA_ARCH_LIST "Common" CACHE STRING "Target CUDA archetecture")

# Locate Torch; it is a hard requirement for every target of the project.
find_package(Torch REQUIRED)

message(STATUS "FewBit: Torch version detected: ${Torch_VERSION}")

# Workaround for Torch 1.10: somewhere between the CMake scripts shipped with
# Torch and CMake itself some CUDA flags are not propagated, so every entry of
# CUDA_NVCC_FLAGS is appended to CMAKE_CUDA_FLAGS by hand. Torch releases
# older than 1.10 are rejected; nothing is done for 1.11 and newer.
if(Torch_VERSION VERSION_LESS "1.10")
  message(FATAL_ERROR "Torch version lesser than 1.10 is not supported.")
elseif(Torch_VERSION VERSION_LESS "1.11")
  foreach(nvcc_flag ${CUDA_NVCC_FLAGS})
    # Entries are joined with single spaces below, so an entry that itself
    # contains a space is treated as a configuration error.
    string(FIND "${nvcc_flag}" " " space_index)
    if(space_index GREATER -1)
      message(FATAL_ERROR
        "Found spaces in CUDA_NVCC_FLAGS entry '${nvcc_flag}'")
    endif()
    string(APPEND CMAKE_CUDA_FLAGS " ${nvcc_flag}")
  endforeach()
endif()

# Set common C++ standard requirements.
# These are directory-wide defaults: C++17 for C++ sources and C++14 for CUDA
# sources. The *_STANDARD_REQUIRED switches make an unavailable standard a
# configure-time error instead of a silent fallback to an older one.
set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CXX_STANDARD_REQUIRED TRUE)

set(CMAKE_CUDA_STANDARD 14)
set(CMAKE_CUDA_STANDARD_REQUIRED TRUE)

# Add the directory of this CMakeLists.txt to the include path of every target
# defined in this directory and below it.
include_directories("${CMAKE_CURRENT_SOURCE_DIR}")

# Force to display colorised error messages.
Expand Down
16 changes: 16 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,21 @@ The latest release can be installed with the following command.
pip install -U https://github.com/SkoltechAI/fewbit.git
```

Another way to get FewBit is to install pre-built wheels from a custom PyPI
index. Assuming that the CUDA version is 11.7 and the desired PyTorch version
is 2.0.1, the command below downloads and installs the specified version of
PyTorch and the latest available FewBit.

```shell
pip install fewbit torch==2.0.1 \
--extra-index-url https://download.pytorch.org/whl/cu117 \
--extra-index-url https://mirror.daskol.xyz/pypi/cu117/pt2.0.1
```

Note that the URLs of the custom PyPI indices are built from the CUDA version
and the PyTorch version and can be adjusted manually (see [this page][7] for
the list of pre-built wheels).

### List of Activation Functions

The library supports the following activation functions.
Expand Down Expand Up @@ -176,3 +191,4 @@ Please cite the following papers if the library is used in an academic paper (ex
[4]: doc/fig/activations.svg
[5]: https://arxiv.org/abs/2201.13195
[6]: https://arxiv.org/abs/2202.00441
[7]: https://mirror.daskol.xyz/pypi/
56 changes: 56 additions & 0 deletions ci/build-wheel.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
#!/bin/bash
#
# Script for building binary distribution (wheel) in isolated environment
# (docker container).
#
# Usage: ci/build-wheel.sh <cuda-version> <python-version> <torch-version>
#
#   ci/build-wheel.sh 11.5 3.10 1.10
#
# The wheel is written to dist/cu<cuda-ordinal>/fewbit, where <cuda-ordinal>
# is the <major><minor> of the CUDA version without the dot (11.5 -> 115).
#

# Abort on the first failing command. Without this the exit status of the
# script is only that of the final sanity check, so a failed install or build
# step can go unnoticed by the caller.
set -euo pipefail

if [ $# -lt 3 ]; then
    echo "usage: $0 <cuda-version> <python-version> <torch-version>" >&2
    exit 1
fi

CUDA_VERSION=$1
# Keep only <major>.<minor> (11.5.1 -> 11.5), then drop the dot (11.5 -> 115).
CUDA_VERSION_SHORT=$(cut -f 1,2 -d . <<< "$CUDA_VERSION")
CUDA_ORDINAL=$(tr -d '.' <<< "$CUDA_VERSION_SHORT")

PYTHON_VERSION=$2
PYTHON="python${PYTHON_VERSION}"
# Interpreter location is derived from the ABI tag (3.10 -> cp310-cp310).
PYTHON_ABI="cp$(tr -d '.' <<< "$PYTHON_VERSION")"
PYTHON_PREFIX="/opt/python/${PYTHON_ABI}-${PYTHON_ABI}"

TORCH_VERSION=$3

echo "Build wheel in context of"
echo " CUDA $CUDA_VERSION"
echo " Python $PYTHON_VERSION"
echo " Torch $TORCH_VERSION"

PATH="${PYTHON_PREFIX}/bin:$PATH"
PYPI_EXTRA_INDEX="https://download.pytorch.org/whl/cu${CUDA_ORDINAL}/"

# Ad hoc solution to fix git repo.
git config --global --add safe.directory /workspace

"$PYTHON" -m pip install -U --extra-index-url "$PYPI_EXTRA_INDEX" \
    "auditwheel" \
    "setuptools" \
    "setuptools_scm>=3.4" \
    "wheel" \
    "numpy" \
    "torch==${TORCH_VERSION}+cu${CUDA_ORDINAL}"

"$PYTHON" setup.py \
    build_ext -i --cuda \
    bdist_wheel \
    -d "dist/cu${CUDA_ORDINAL}/fewbit" \
    -p manylinux2014_x86_64

# TODO Add check with auditwheel on manylinux2014 compliance.
# auditwheel show dist/cu$CUDA_ORDINAL/fewbit/fewbit-....whl

echo "Make sanity check and show package versions"
# The delimiter is quoted so that the shell passes the Python code verbatim.
"$PYTHON" - <<'END'
import sys
import torch as T
import fewbit
print('Python version is', sys.version)
print('PyTorch version is', T.version.__version__)
print('FewBit version is', fewbit.__version__)
END
39 changes: 39 additions & 0 deletions ci/build-wheels.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
#!/bin/bash
#
# Build wheels for every combination of the CUDA, Python and Torch versions
# listed below. Each combination is built by ci/build-wheel.sh inside of a
# docker container; the script stops at the first failed build.

# Extra arguments of `docker run`. An array keeps every argument intact even
# if a path contains spaces.
DOCKER_ARGS=()

# Use CACHE_DIR to cache downloaded wheels from PyPI.
if [ -n "${CACHE_DIR+x}" ]; then
    DOCKER_ARGS+=(-v "$CACHE_DIR/pip:/root/.cache/pip")
fi

# Request an interactive TTY only if stdin is a terminal; otherwise
# `docker run -t` refuses to start (e.g. on CI or under cron).
if [ -t 0 ]; then
    DOCKER_ARGS+=(-ti)
fi

# Use custom manylinux2014 image with CUDA support.
DOCKER_IMAGE=doge.skoltech.ru/manylinux2014_x86_64

# Build wheels across version matrix.
CUDA_VERSIONS=(10.2 11.1 11.3 11.5)
PYTHON_VERSIONS=(3.8 3.9 3.10)
TORCH_VERSIONS=(1.10.2 1.11.0)
for CUDA_VERSION in "${CUDA_VERSIONS[@]}"; do
    for PYTHON_VERSION in "${PYTHON_VERSIONS[@]}"; do
        for TORCH_VERSION in "${TORCH_VERSIONS[@]}"; do
            versions="$CUDA_VERSION $PYTHON_VERSION $TORCH_VERSION"

            # The image tag is the CUDA version; the repository is mounted
            # to /workspace, which is also the working directory.
            if ! docker run --rm \
                    "${DOCKER_ARGS[@]}" \
                    -v "$(pwd):/workspace" \
                    -w /workspace \
                    "$DOCKER_IMAGE:$CUDA_VERSION" \
                    /workspace/ci/build-wheel.sh \
                    "$CUDA_VERSION" "$PYTHON_VERSION" "$TORCH_VERSION"; then
                echo
                echo "ERROR Failed to build wheels for versions $versions."
                exit 1
            fi
        done
    done
done

echo
echo "OK All wheels are built!"
29 changes: 29 additions & 0 deletions ci/release/Dockerfile
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
# Builder image for FewBit wheels: CUDA development image plus a set of
# Python interpreters (3.8 - 3.11) with headers and venv support.
#
# CUDA_VERSION selects the tag of the base nvidia/cuda image.
ARG CUDA_VERSION=11.7.1

FROM nvidia/cuda:$CUDA_VERSION-devel-ubuntu20.04

LABEL maintainer="Daniel Bershatsky <d.bershatsky2@skoltech.ru>"

ENV TZ=Europe/Moscow

# Configure the time zone up front.
RUN ln -snf /usr/share/zoneinfo/$TZ /etc/localtime && \
    echo $TZ > /etc/timezone

# DEBIAN_FRONTEND must be exported: a plain assignment followed by `&&` only
# sets a shell variable which is not passed to the environment of apt-get.
# The variable is deliberately scoped to this RUN instruction.
RUN export DEBIAN_FRONTEND=noninteractive && \
    apt-get update && \
    apt-get install -y --no-install-recommends software-properties-common && \
    add-apt-repository -y ppa:deadsnakes/ppa && \
    apt-get update && \
    apt-get install -y --no-install-recommends \
        python3.8 \
        python3.8-dev \
        python3.8-venv \
        python3.9 \
        python3.9-dev \
        python3.9-venv \
        python3.10 \
        python3.10-dev \
        python3.10-venv \
        python3.11 \
        python3.11-dev \
        python3.11-venv && \
    rm -rf /var/lib/apt/lists/*
25 changes: 25 additions & 0 deletions ci/release/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
# FewBit: Release

Assume we want to build Python wheels of `fewbit` for CUDA 11.7. In this case
we should find the latest Docker image with that major and minor version (e.g.
nvidia/cuda:11.7.2-devel-ubuntu20.04). Then we should run the command below
with the CUDA version passed as a build argument. This and all other commands
are executed from the repo root.

```shell
docker build \
-f ci/release/Dockerfile \
-t github.com/skoltech-ai/fewbit/sandbox:cu117 \
--build-arg CUDA_VERSION=11.7.2 .
```

As soon as the builder image is ready, we can build a wheel for specific
versions of Python and Torch.

```shell
docker run --rm -ti \
-v $PWD:/usr/src/fewbit \
-w /usr/src/fewbit \
github.com/skoltech-ai/fewbit/sandbox:cu117 \
ci/release/build-wheel.sh 11.7 3.10 2.0.0 0.1.0
```
46 changes: 46 additions & 0 deletions ci/release/build-wheel.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
#!/bin/bash
# This script is supposed to be run in a docker container. It creates and
# activates a virtual env, installs core dependencies and builds a wheel.
#
# build-wheel.sh <cuda-version> <py-version> <torch-version> <fewbit-version>
#
# Defaults: CUDA 11.7, Python 3.10, FewBit 0.0.0. The Torch version is
# mandatory. The wheel is written to dist/cu<NNN>/pt<torch-version>/fewbit.

# Stop at the first failing command, including creation and activation of the
# virtual env: otherwise the build would go on with whatever Python is found.
set -e

CUDA_VERSION=${1:-11.7}
# Keep only <major>.<minor> (11.7.2 -> 11.7). Splitting on dots, unlike
# taking the first four characters, also works for single-digit majors.
CUDA_VERSION=$(cut -f 1,2 -d . <<< "$CUDA_VERSION")
PYTHON_VERSION=${2:-3.10}
if [[ -z "$3" ]]; then
    echo "Torch version is not specified" >&2
    exit 1
fi
TORCH_VERSION=$3
FEWBIT_VERSION=${4:-0.0.0}

# One virtual env per (Python, CUDA) pair.
VENV_DIR=".env/py${PYTHON_VERSION}/cu${CUDA_VERSION}"
"python${PYTHON_VERSION}" -m venv "$VENV_DIR"
. "$VENV_DIR/bin/activate"

# Trace the commands from here on.
set -x

pip3 install -U \
    --extra-index-url "https://download.pytorch.org/whl/cu${CUDA_VERSION/./}" \
    "auditwheel" \
    "setuptools" \
    "setuptools_scm>=3.4" \
    "wheel" \
    "numpy" \
    "torch==${TORCH_VERSION}+cu${CUDA_VERSION/./}"

# SETUPTOOLS_SCM_PRETEND_VERSION overrides the version detected by
# setuptools_scm with the requested one.
SETUPTOOLS_SCM_PRETEND_VERSION=$FEWBIT_VERSION python3 setup.py \
    build_ext -i --cuda \
    bdist_wheel \
    -d "dist/cu${CUDA_VERSION/./}/pt${TORCH_VERSION}/fewbit" \
    -p manylinux2014_x86_64

# TODO Add check with auditwheel on manylinux2014 compliance.
# auditwheel show dist/cu${CUDA_VERSION/./}/pt${TORCH_VERSION}/fewbit/fewbit-....whl

echo "Make sanity check and show package versions"
# The delimiter is quoted so that the shell passes the Python code verbatim.
python3 - <<'END'
import sys
import torch as T
import fewbit
print('Python version is', sys.version)
print('PyTorch version is', T.version.__version__)
print('FewBit version is', fewbit.__version__)
END
6 changes: 4 additions & 2 deletions fewbit/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
add_library("fewbit" SHARED)
set_property(TARGET "fewbit" PROPERTY POSITION_INDEPENDENT_CODE ON)
target_compile_features("fewbit" PRIVATE cxx_std_20)
set_target_properties("fewbit" PROPERTIES
CXX_STANDARD 20
CXX_STANDARD_REQUIRED ON
POSITION_INDEPENDENT_CODE ON)
target_link_libraries("fewbit"
PRIVATE fewbit-cpu
PUBLIC torch)
Expand Down
27 changes: 20 additions & 7 deletions fewbit/__init__.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,32 @@
# encoding: utf-8
# filename: __init__.py
"""Package fewbit provides a plenty optimized primitives for the bottleneck
layers of modern neural networks.
"""

import torch as T

from pathlib import Path
from os import getenv
from warnings import warn

try:
T.ops.load_library(Path(__file__).with_name('libfewbit.so'))
except Exception as e:
warn(f'Failed to load ops library: {e}.', RuntimeWarning)
finally:
del Path, T, warn
# This is feature toggle which enables or disable usage of native
# implementation or primitive operations. We assume that environment variable
# FEWBIT_NATIVE manages loading of native extension. If native extension is not
# loaded then fallback implementation is used.
if getenv('FEWBIT_NATIVE') not in ('0', 'no', 'false'):
try:
T.ops.load_library(Path(__file__).with_name('libfewbit.so'))
except Exception as e:
warn(f'Failed to load ops library: {e}.', RuntimeWarning)
finally:
del Path, T, warn

from . import functional # noqa: F401
from .approx import StepWiseFunction, approximate # noqa: F401
from .modules import * # noqa: F401,F403
from .util import map_module # noqa: F401

try:
from .version import version as __version__
except ImportError:
__version__ = None
4 changes: 2 additions & 2 deletions fewbit/approx.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ def __init__(self, borders, levels):
assert borders.size == levels.size + 1

self.borders = borders
self.card = self.borders.size - 1
self.card = levels.size
self.levels = levels
self.steps = levels.copy()
self.steps[1:] = levels[1:] - levels[:-1]
Expand Down Expand Up @@ -155,7 +155,7 @@ def initializer(rng, size):


def estimate_error(fn, fn_approx, dx):
es = np.empty(fn_approx.card + 1)
es = np.empty(fn_approx.card)
for i in range(fn_approx.card):
a, b = fn_approx.borders[i:i + 2]
nopoints = min(1024 ** 2, int((b - a) / dx))
Expand Down
14 changes: 14 additions & 0 deletions fewbit/compat.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
"""Module compat provides a few routines to manage compatibility between minor
versions of Python 3.
"""

from sys import version_info

if version_info < (3, 9):
    def removeprefix(self: str, prefix: str, /) -> str:
        """Backport of :py:meth:`str.removeprefix` for Python older than 3.9.

        :param self: String to strip the prefix from.
        :param prefix: Prefix to remove.
        :return: Copy of `self` without leading `prefix` if `self` starts
                 with it; otherwise an unchanged copy of `self`.
        :raises TypeError: If `prefix` is not a string.
        """
        # str.startswith accepts a tuple of prefixes, but len() of a tuple is
        # not the length of the matched prefix. Reject non-string prefixes as
        # the built-in str.removeprefix does.
        if not isinstance(prefix, str):
            raise TypeError('removeprefix arg must be str, not '
                            f'{type(prefix).__name__}')
        if self.startswith(prefix):
            return self[len(prefix):]
        else:
            return self[:]
else:
    # The built-in implementation is available since Python 3.9.
    removeprefix = str.removeprefix
9 changes: 6 additions & 3 deletions fewbit/cpu/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
add_executable("fewbit-cpu-codec-test" EXCLUDE_FROM_ALL)
target_compile_features("fewbit-cpu-codec-test" PRIVATE cxx_std_20)
set_target_properties("fewbit-cpu-codec-test" PROPERTIES
CXX_STANDARD 20
CXX_STANDARD_REQUIRED ON)
target_sources("fewbit-cpu-codec-test" PRIVATE codec.cc codec.h codec_test.cc)

add_library("fewbit-cpu" OBJECT)
set_target_properties("fewbit-cpu" PROPERTIES
POSITION_INDEPENDENT_CODE ON)
target_compile_features("fewbit-cpu" PRIVATE cxx_std_20)
CXX_STANDARD 20
CXX_STANDARD_REQUIRED ON
POSITION_INDEPENDENT_CODE ON)
target_link_libraries("fewbit-cpu" PUBLIC torch)
target_sources("fewbit-cpu"
PRIVATE
Expand Down
8 changes: 5 additions & 3 deletions fewbit/cpu/gelu.cc
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,12 @@ namespace fewbit {

std::tuple<torch::Tensor, torch::Tensor> Quantize(torch::Tensor const &inputs,
torch::Tensor const &bounds) {
#if __GNUC__ > 9
auto outputs = torch::gelu(inputs);
#else
// Add support for PyTorch 1.10.0a for NGC docker images. Namely, for docker
// image nvcr.io/nvidia/pytorch:21.10-py3.
#if TORCH_ALPHA
auto outputs = torch::gelu(inputs, true);
#else
auto outputs = torch::gelu(inputs);
#endif
auto codes = torch::searchsorted(bounds, inputs, true);

Expand Down
Loading

0 comments on commit 940706b

Please sign in to comment.