add arbitrary rank, vectorized tensor dot dotGeneral (#379)
joelberkeley authored Jan 2, 2024
1 parent 3eee9f3 commit c1bde4a
Showing 17 changed files with 473 additions and 34 deletions.
2 changes: 1 addition & 1 deletion backend/VERSION
@@ -1 +1 @@
0.0.7
0.0.8
9 changes: 9 additions & 0 deletions backend/src/tensorflow/compiler/xla/client/xla_builder.cpp
@@ -24,6 +24,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/xla_data.pb.h"

#include "../literal.h"
#include "../xla_data.pb.h"
#include "xla_builder.h"
#include "xla_computation.h"

@@ -257,6 +258,14 @@ extern "C" {
return reinterpret_cast<XlaOp*>(new xla::XlaOp(res));
}

XlaOp* DotGeneral(XlaOp& lhs, XlaOp& rhs, DotDimensionNumbers& dimension_numbers) {
    auto& lhs_ = reinterpret_cast<xla::XlaOp&>(lhs);
    auto& rhs_ = reinterpret_cast<xla::XlaOp&>(rhs);
    auto& dimension_numbers_ = reinterpret_cast<xla::DotDimensionNumbers&>(dimension_numbers);
    xla::XlaOp res = xla::DotGeneral(lhs_, rhs_, dimension_numbers_);
    return reinterpret_cast<XlaOp*>(new xla::XlaOp(res));
}

XlaOp* TriangularSolve(
XlaOp& a, XlaOp& b, int left_side, int lower, int unit_diagonal, int transpose_a
) {
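For reference, XLA's rank-2 `Dot` (whose wrapper appears just above) is the special case of `DotGeneral` that contracts axis 1 of the left operand with axis 0 of the right operand and declares no batch axes. A minimal sketch against the upstream `xla::` client API, not part of this diff, assuming two rank-2 ops `lhs` and `rhs` are already in scope:

```cpp
// Sketch: a plain matrix product expressed through DotGeneral.
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"

xla::XlaOp dot_via_dot_general(xla::XlaOp lhs, xla::XlaOp rhs) {
    xla::DotDimensionNumbers dnums;
    dnums.add_lhs_contracting_dimensions(1);  // contract lhs axis 1 ...
    dnums.add_rhs_contracting_dimensions(0);  // ... with rhs axis 0
    return xla::DotGeneral(lhs, rhs, dnums);  // no batch dimensions declared
}
```

The new `DotGeneral` shim above forwards to the same `xla::DotGeneral` call through opaque pointers.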
2 changes: 2 additions & 0 deletions backend/src/tensorflow/compiler/xla/client/xla_builder.h
@@ -18,6 +18,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/client/xla_builder.h"

#include "../literal.h"
#include "../xla_data.pb.h"
#include "xla_computation.h"

extern "C" {
@@ -105,6 +106,7 @@ extern "C" {
XlaOp* Le(XlaOp& lhs, XlaOp& rhs);

XlaOp* Dot(XlaOp& lhs, XlaOp& rhs);
XlaOp* DotGeneral(XlaOp& lhs, XlaOp& rhs, DotDimensionNumbers& dimension_numbers);
XlaOp* TriangularSolve(
XlaOp& a, XlaOp& b, int left_side, int lower, int unit_diagonal, int transpose_a
);
56 changes: 56 additions & 0 deletions backend/src/tensorflow/compiler/xla/xla_data.pb.cpp
@@ -0,0 +1,56 @@
/*
Copyright 2024 Joel Berkeley
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "tensorflow/compiler/xla/xla_data.pb.h"

#include "xla_data.pb.h"

extern "C" {
DotDimensionNumbers* DotDimensionNumbers_new() {
    return reinterpret_cast<DotDimensionNumbers*>(new xla::DotDimensionNumbers());
}

void DotDimensionNumbers_delete(DotDimensionNumbers* dimension_numbers) {
    delete reinterpret_cast<xla::DotDimensionNumbers*>(dimension_numbers);
}

void DotDimensionNumbers_add_lhs_contracting_dimensions(
    DotDimensionNumbers& dimension_numbers, int dim
) {
    auto& dimension_numbers_ = reinterpret_cast<xla::DotDimensionNumbers&>(dimension_numbers);
    dimension_numbers_.add_lhs_contracting_dimensions(dim);
}

void DotDimensionNumbers_add_rhs_contracting_dimensions(
    DotDimensionNumbers& dimension_numbers, int dim
) {
    auto& dimension_numbers_ = reinterpret_cast<xla::DotDimensionNumbers&>(dimension_numbers);
    dimension_numbers_.add_rhs_contracting_dimensions(dim);
}

void DotDimensionNumbers_add_lhs_batch_dimensions(
    DotDimensionNumbers& dimension_numbers, int dim
) {
    auto& dimension_numbers_ = reinterpret_cast<xla::DotDimensionNumbers&>(dimension_numbers);
    dimension_numbers_.add_lhs_batch_dimensions(dim);
}

void DotDimensionNumbers_add_rhs_batch_dimensions(
    DotDimensionNumbers& dimension_numbers, int dim
) {
    auto& dimension_numbers_ = reinterpret_cast<xla::DotDimensionNumbers&>(dimension_numbers);
    dimension_numbers_.add_rhs_batch_dimensions(dim);
}
}
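Together these functions form the C ABI that the Idris bindings call into. A minimal usage sketch, not part of the commit: `batched_matmul` is a hypothetical helper, the include paths are illustrative, and the explicit `DotDimensionNumbers_delete` stands in for the finalizer the Idris side attaches with `onCollectAny`:

```cpp
// Sketch: batched matrix multiplication via the C wrapper functions above.
#include "client/xla_builder.h"  // C wrapper: XlaOp, DotGeneral
#include "xla_data.pb.h"         // C wrapper: DotDimensionNumbers_*

// lhs: [B, M, K], rhs: [B, K, N]  ->  result: [B, M, N]
XlaOp* batched_matmul(XlaOp& lhs, XlaOp& rhs) {
    DotDimensionNumbers* dnums = DotDimensionNumbers_new();
    DotDimensionNumbers_add_lhs_batch_dimensions(*dnums, 0);
    DotDimensionNumbers_add_rhs_batch_dimensions(*dnums, 0);
    DotDimensionNumbers_add_lhs_contracting_dimensions(*dnums, 2);
    DotDimensionNumbers_add_rhs_contracting_dimensions(*dnums, 1);
    XlaOp* result = DotGeneral(lhs, rhs, *dnums);
    DotDimensionNumbers_delete(dnums);  // the Idris bindings do this via a GC finalizer
    return result;
}
```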
40 changes: 40 additions & 0 deletions backend/src/tensorflow/compiler/xla/xla_data.pb.h
@@ -0,0 +1,40 @@
/*
Copyright 2024 Joel Berkeley
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "tensorflow/compiler/xla/xla_data.pb.h"

extern "C" {
struct DotDimensionNumbers;

DotDimensionNumbers* DotDimensionNumbers_new();

void DotDimensionNumbers_delete(DotDimensionNumbers* dimension_numbers);

void DotDimensionNumbers_add_lhs_contracting_dimensions(
DotDimensionNumbers& dimension_numbers, int dim
);

void DotDimensionNumbers_add_rhs_contracting_dimensions(
DotDimensionNumbers& dimension_numbers, int dim
);

void DotDimensionNumbers_add_lhs_batch_dimensions(
DotDimensionNumbers& dimension_numbers, int dim
);

void DotDimensionNumbers_add_rhs_batch_dimensions(
DotDimensionNumbers& dimension_numbers, int dim
);
}
1 change: 1 addition & 0 deletions spidr.ipkg
@@ -24,6 +24,7 @@ modules =
Compiler.Xla.Prim.TensorFlow.Compiler.Xla.Literal,
Compiler.Xla.Prim.TensorFlow.Compiler.Xla.Shape,
Compiler.Xla.Prim.TensorFlow.Compiler.Xla.ShapeUtil,
Compiler.Xla.Prim.TensorFlow.Compiler.Xla.XlaData,
Compiler.Xla.Prim.TensorFlow.Core.CommonRuntime.GPU.GPUInit,
Compiler.Xla.Prim.TensorFlow.Core.Platform.Status,
Compiler.Xla.Prim.Util,
7 changes: 7 additions & 0 deletions src/Compiler/Eval.idr
@@ -201,6 +201,13 @@ interpret xlaBuilder (MkFn params root env) = do
compFalse <- lift $ compile subBuilderF fFalse
conditional !(get pred) !(get true) compTrue !(get false) compFalse
interpretE (Dot l r) = dot !(get l) !(get r)
interpretE (DotGeneral lb rb lc rc l r) = do
  dimensionNumbers <- allocDotDimensionNumbers
  traverse_ (addLhsBatchDimensions dimensionNumbers) lb
  traverse_ (addRhsBatchDimensions dimensionNumbers) rb
  traverse_ (addLhsContractingDimensions dimensionNumbers) lc
  traverse_ (addRhsContractingDimensions dimensionNumbers) rc
  dotGeneral !(get l) !(get r) dimensionNumbers
interpretE (Cholesky x) = cholesky !(get x) True
interpretE (TriangularSolve a b lower) =
triangularSolve !(get a) !(get b) True lower False NoTranspose
1 change: 1 addition & 0 deletions src/Compiler/Expr.idr
@@ -149,6 +149,7 @@ data Expr : Type where
Select : Nat -> Nat -> Nat -> Expr
Cond : Nat -> Fn 1 -> Nat -> Fn 1 -> Nat -> Expr
Dot : Nat -> Nat -> Expr
DotGeneral : (lBatch, rBatch, lContract, rContract : List Nat) -> Nat -> Nat -> Expr
Cholesky : Nat -> Expr
TriangularSolve : Nat -> Nat -> Bool -> Expr
UniformFloatingPoint : Nat -> Nat -> Nat -> Nat -> Shape -> Expr
src/Compiler/Xla/Prim/TensorFlow/Compiler/Xla/Client/XlaBuilder.idr
@@ -125,6 +125,10 @@
%foreign (libxla "Dot")
prim__dot : GCAnyPtr -> GCAnyPtr -> PrimIO AnyPtr

export
%foreign (libxla "DotGeneral")
prim__dotGeneral : GCAnyPtr -> GCAnyPtr -> GCAnyPtr -> PrimIO AnyPtr

export
%foreign (libxla "TriangularSolve")
prim__triangularSolve : GCAnyPtr -> GCAnyPtr -> Int -> Int -> Int -> Int -> PrimIO AnyPtr
44 changes: 44 additions & 0 deletions src/Compiler/Xla/Prim/TensorFlow/Compiler/Xla/XlaData.idr
@@ -0,0 +1,44 @@
{--
Copyright 2024 Joel Berkeley
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
--}
module Compiler.Xla.Prim.TensorFlow.Compiler.Xla.XlaData

import System.FFI

import Compiler.Xla.Prim.Util

export
%foreign (libxla "DotDimensionNumbers_new")
prim__dotDimensionNumbersNew : PrimIO AnyPtr

export
%foreign (libxla "DotDimensionNumbers_delete")
prim__dotDimensionNumbersDelete : AnyPtr -> PrimIO ()

export
%foreign (libxla "DotDimensionNumbers_add_lhs_contracting_dimensions")
prim__addLhsContractingDimensions : GCAnyPtr -> Int -> PrimIO ()

export
%foreign (libxla "DotDimensionNumbers_add_rhs_contracting_dimensions")
prim__addRhsContractingDimensions : GCAnyPtr -> Int -> PrimIO ()

export
%foreign (libxla "DotDimensionNumbers_add_lhs_batch_dimensions")
prim__addLhsBatchDimensions : GCAnyPtr -> Int -> PrimIO ()

export
%foreign (libxla "DotDimensionNumbers_add_rhs_batch_dimensions")
prim__addRhsBatchDimensions : GCAnyPtr -> Int -> PrimIO ()
src/Compiler/Xla/TensorFlow/Compiler/Xla/Client/XlaBuilder.idr
@@ -22,6 +22,7 @@ import Compiler.Xla.TensorFlow.Compiler.Xla.Client.XlaComputation
import Compiler.Xla.TensorFlow.Compiler.Xla.XlaData
import Compiler.Xla.TensorFlow.Compiler.Xla.Literal
import Compiler.Xla.TensorFlow.Compiler.Xla.Shape
import Compiler.Xla.TensorFlow.Compiler.Xla.XlaData
import Compiler.Xla.Util
import Types
import Util
@@ -225,6 +226,13 @@
dot : HasIO io => XlaOp -> XlaOp -> io XlaOp
dot = binaryOp prim__dot

export
dotGeneral : HasIO io => XlaOp -> XlaOp -> DotDimensionNumbers -> io XlaOp
dotGeneral (MkXlaOp l) (MkXlaOp r) (MkDotDimensionNumbers dimensionNumbers) = do
  opPtr <- primIO $ prim__dotGeneral l r dimensionNumbers
  opPtr <- onCollectAny opPtr XlaOp.delete
  pure (MkXlaOp opPtr)

public export
data Transpose = NoTranspose | Transpose_ | Adjoint

38 changes: 38 additions & 0 deletions src/Compiler/Xla/TensorFlow/Compiler/Xla/XlaData.idr
@@ -15,6 +15,8 @@ limitations under the License.
--}
module Compiler.Xla.TensorFlow.Compiler.Xla.XlaData

import Compiler.Xla.Prim.TensorFlow.Compiler.Xla.XlaData

export
interface Primitive dtype where
xlaIdentifier : Int
@@ -60,3 +62,39 @@ export data F64 : Type where
export
Primitive F64 where
xlaIdentifier = 12

namespace Xla
  public export
  data DotDimensionNumbers : Type where
    MkDotDimensionNumbers : GCAnyPtr -> DotDimensionNumbers

export
delete : HasIO io => AnyPtr -> io ()
delete = primIO . prim__dotDimensionNumbersDelete

export
allocDotDimensionNumbers : HasIO io => io DotDimensionNumbers
allocDotDimensionNumbers = do
  ptr <- primIO prim__dotDimensionNumbersNew
  ptr <- onCollectAny ptr delete
  pure (MkDotDimensionNumbers ptr)

export
addLhsContractingDimensions : HasIO io => DotDimensionNumbers -> Nat -> io ()
addLhsContractingDimensions (MkDotDimensionNumbers dimension_numbers) n =
  primIO $ prim__addLhsContractingDimensions dimension_numbers (cast n)

export
addRhsContractingDimensions : HasIO io => DotDimensionNumbers -> Nat -> io ()
addRhsContractingDimensions (MkDotDimensionNumbers dimension_numbers) n =
  primIO $ prim__addRhsContractingDimensions dimension_numbers (cast n)

export
addLhsBatchDimensions : HasIO io => DotDimensionNumbers -> Nat -> io ()
addLhsBatchDimensions (MkDotDimensionNumbers dimension_numbers) n =
  primIO $ prim__addLhsBatchDimensions dimension_numbers (cast n)

export
addRhsBatchDimensions : HasIO io => DotDimensionNumbers -> Nat -> io ()
addRhsBatchDimensions (MkDotDimensionNumbers dimension_numbers) n =
  primIO $ prim__addRhsBatchDimensions dimension_numbers (cast n)
60 changes: 58 additions & 2 deletions src/Tensor.idr
@@ -518,9 +518,9 @@ transpose :
(ordering : List Nat) ->
Tensor shape dtype ->
{auto 0 lengths : length ordering = length shape} ->
{auto 0 unique : Sorted Neq ordering} ->
{auto 0 axesUnique : unique ordering = True} ->
{auto 0 inBounds : All (flip InBounds shape) ordering} ->
Graph $ Tensor (map (dflip List.index shape) ordering) dtype
Graph $ Tensor (multiIndex ordering shape) dtype
transpose ordering $ MkTensor x = addTensor $ Transpose ordering x

||| The identity tensor, with inferred shape and element type. For example,
@@ -976,6 +976,62 @@ namespace Matrix
MkTensor x' <- x'
addTensor $ Dot x x'

||| The output shape of a `dotGeneral` operation.
public export
contract : (lBatch, rBatch, lContract, rContract : List Nat) ->
           (ls, rs : Shape) ->
           {auto 0 lInBoundsBatch : All (flip InBounds ls) lBatch} ->
           {auto 0 rInBoundsBatch : All (flip InBounds rs) rBatch} ->
           {auto 0 lInBoundsContract : All (flip InBounds ls) lContract} ->
           {auto 0 rInBoundsContract : All (flip InBounds rs) rContract} ->
           Shape
contract lBatch rBatch lContract rContract ls rs =
  let lResultDims = deleteAt {inBounds = lInBoundsBatch ++ lInBoundsContract}
                             (lBatch ++ lContract) ls
      rResultDims = deleteAt {inBounds = rInBoundsBatch ++ rInBoundsContract}
                             (rBatch ++ rContract) rs
  in  multiIndex lBatch ls ++ lResultDims ++ rResultDims

||| Matrix multiplication.
|||
||| This is a much more general version of `(@@)`, in which you can specify any number of batch
||| and contracting axes. Matrix multiplication is done over each contracting axis.
||| The operation is vectorized over batch axes. For each contracting axis on the left-hand
||| operand, there is one contracting axis on the right-hand operand. These can be different axes
||| in each operand. The same is true for each batch axis.
|||
||| For example, we can vectorize over a typical rank-two matrix multiplication as follows: given
||| two input tensors
||| ```
||| let x : Tensor [3, 4, 5, 6] F64
||| y : Tensor [3, 4, 6, 7] F64
||| ```
||| we do
||| ```
||| let z : Graph $ Tensor [3, 4, 5, 7] F64 = dotGeneral [0, 1] [0, 1] [3] [2] x y
||| ```
||| Here, we vectorize over the first two axes `[0, 1]`, and do standard matrix multiplication
||| over the remaining axes by specifying axes 3 and 2 respectively as contracting axes. Notice
||| how each batch axis appears once at the start of the output shape, and the contracting axes
||| disappear. The remaining axes appear in order from left to right.
|||
||| Note that this API is something of a quick fix to bring general matrix multiplication to the
||| tensor API. It is not thoroughly tested, so expect it to change in the future.
export
dotGeneral : (lBatch, rBatch, lContract, rContract : List Nat) ->
             {auto 0 lUnique : unique (lBatch ++ lContract) = True} ->
             {auto 0 rUnique : unique (rBatch ++ rContract) = True} ->
             {auto 0 lInBoundsBatch : All (flip InBounds ls) lBatch} ->
             {auto 0 rInBoundsBatch : All (flip InBounds rs) rBatch} ->
             {auto 0 lInBoundsContract : All (flip InBounds ls) lContract} ->
             {auto 0 rInBoundsContract : All (flip InBounds rs) rContract} ->
             {auto 0 batchDimsEq : multiIndex lBatch ls = multiIndex rBatch rs} ->
             {auto 0 contractDimsEq : multiIndex lContract ls = multiIndex rContract rs} ->
             Tensor ls dtype ->
             Tensor rs dtype ->
             Graph $ Tensor (contract lBatch rBatch lContract rContract ls rs) dtype
dotGeneral lb rb lc rc (MkTensor x) (MkTensor y) = addTensor $ DotGeneral lb rb lc rc x y
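As a standalone check of the shape rule that `contract` encodes (batch dims taken from the left operand, then the left operand's remaining dims, then the right operand's remaining dims), here is a small sketch, not part of the commit and written in C++ rather than Idris, that recomputes the `[3, 4, 5, 6]` × `[3, 4, 6, 7]` → `[3, 4, 5, 7]` example from the docstring:

```cpp
#include <algorithm>
#include <cstdint>
#include <iostream>
#include <vector>

// Output shape of dotGeneral, mirroring `contract` above: batch dims taken from
// the lhs shape, then lhs dims with batch and contracting axes removed, then rhs
// dims with batch and contracting axes removed.
std::vector<int64_t> contract(const std::vector<size_t>& lBatch, const std::vector<size_t>& rBatch,
                              const std::vector<size_t>& lContract, const std::vector<size_t>& rContract,
                              const std::vector<int64_t>& ls, const std::vector<int64_t>& rs) {
    auto remaining = [](const std::vector<int64_t>& dims, const std::vector<size_t>& drop) {
        std::vector<int64_t> kept;
        for (size_t i = 0; i < dims.size(); ++i)
            if (std::find(drop.begin(), drop.end(), i) == drop.end()) kept.push_back(dims[i]);
        return kept;
    };
    std::vector<size_t> lDrop(lBatch), rDrop(rBatch);
    lDrop.insert(lDrop.end(), lContract.begin(), lContract.end());
    rDrop.insert(rDrop.end(), rContract.begin(), rContract.end());

    std::vector<int64_t> shape;
    for (size_t i : lBatch) shape.push_back(ls[i]);             // batch dims
    for (int64_t d : remaining(ls, lDrop)) shape.push_back(d);  // remaining lhs dims
    for (int64_t d : remaining(rs, rDrop)) shape.push_back(d);  // remaining rhs dims
    return shape;
}

int main() {
    // x : Tensor [3, 4, 5, 6] F64, y : Tensor [3, 4, 6, 7] F64, as in the docstring:
    //   dotGeneral [0, 1] [0, 1] [3] [2] x y
    for (int64_t d : contract({0, 1}, {0, 1}, {3}, {2}, {3, 4, 5, 6}, {3, 4, 6, 7}))
        std::cout << d << " ";                                  // prints: 3 4 5 7
    std::cout << "\n";
}
```

The `batchDimsEq` and `contractDimsEq` auto-implicits additionally require the paired batch and contracting axes to have equal sizes, which this sketch does not check.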

||| Element-wise addition. For example, `tensor [1, 2] + tensor [3, 4]` is
||| `tensor [4, 6]`.
export