Skip to content

Commit

Permalink
overload eval for evaluating multiple tensors at once (#384)
Browse files Browse the repository at this point in the history
  • Loading branch information
joelberkeley authored Jan 27, 2024
1 parent 63a512e commit bded472
Show file tree
Hide file tree
Showing 16 changed files with 342 additions and 129 deletions.
2 changes: 1 addition & 1 deletion backend/VERSION
Original file line number Diff line number Diff line change
@@ -1 +1 @@
0.0.9
0.0.10
85 changes: 55 additions & 30 deletions backend/src/tensorflow/compiler/xla/literal.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ limitations under the License.

#include "literal.h"
#include "shape.h"
#include "shape_util.h"

extern "C" {
Literal* Literal_new(Shape& shape) {
Expand All @@ -31,61 +32,85 @@ extern "C" {
}

template <typename NativeT>
NativeT Literal_Get(Literal& lit, int* indices) {
NativeT Literal_Get(Literal& lit, int* multi_index, int multi_index_len, ShapeIndex& shape_index) {
xla::Literal& lit_ = reinterpret_cast<xla::Literal&>(lit);
int64_t rank = lit_.shape().rank();
int64_t multi_index[rank];
std::copy(indices, indices + rank, multi_index);
return lit_.Get<NativeT>(absl::Span<const int64_t>(multi_index, rank));
int64_t multi_index_[multi_index_len];
std::copy(multi_index, multi_index + multi_index_len, multi_index_);
auto multi_index_span = absl::Span<const int64_t>(multi_index_, multi_index_len);
auto& shape_index_ = reinterpret_cast<xla::ShapeIndex&>(shape_index);
return lit_.Get<NativeT>(multi_index_span, shape_index_);
};

template <typename NativeT>
void Literal_Set(Literal& lit, int* indices, NativeT value) {
void Literal_Set(
Literal& lit, int* multi_index, int multi_index_len, ShapeIndex& shape_index, NativeT value
) {
xla::Literal& lit_ = reinterpret_cast<xla::Literal&>(lit);
int64_t rank = lit_.shape().rank();
int64_t multi_index[rank];
std::copy(indices, indices + rank, multi_index);
lit_.Set<NativeT>(absl::Span<const int64_t>(multi_index, rank), value);
int64_t multi_index_[multi_index_len];
std::copy(multi_index, multi_index + multi_index_len, multi_index_);
auto multi_index_span = absl::Span<const int64_t>(multi_index_, multi_index_len);
auto& shape_index_ = reinterpret_cast<xla::ShapeIndex&>(shape_index);
lit_.Set<NativeT>(multi_index_span, shape_index_, value);
};

extern "C" {
int Literal_Get_bool(Literal& lit, int* indices) {
return (int) Literal_Get<bool>(lit, indices);
int Literal_Get_bool(
Literal& lit, int* multi_index, int multi_index_len, ShapeIndex& shape_index
) {
return (int) Literal_Get<bool>(lit, multi_index, multi_index_len, shape_index);
}

int Literal_Get_int32_t(Literal& lit, int* indices) {
return Literal_Get<int32_t>(lit, indices);
int Literal_Get_int32_t(
Literal& lit, int* multi_index, int multi_index_len, ShapeIndex& shape_index
) {
return Literal_Get<int32_t>(lit, multi_index, multi_index_len, shape_index);
}

int Literal_Get_uint32_t(Literal& lit, int* indices) {
return (int) Literal_Get<uint32_t>(lit, indices);
int Literal_Get_uint32_t(
Literal& lit, int* multi_index, int multi_index_len, ShapeIndex& shape_index
) {
return (int) Literal_Get<uint32_t>(lit, multi_index, multi_index_len, shape_index);
}

int Literal_Get_uint64_t(Literal& lit, int* indices) {
return (int) Literal_Get<uint64_t>(lit, indices);
int Literal_Get_uint64_t(
Literal& lit, int* multi_index, int multi_index_len, ShapeIndex& shape_index
) {
return (int) Literal_Get<uint64_t>(lit, multi_index, multi_index_len, shape_index);
}

double Literal_Get_double(Literal& lit, int* indices) {
return Literal_Get<double>(lit, indices);
double Literal_Get_double(
Literal& lit, int* multi_index, int multi_index_len, ShapeIndex& shape_index
) {
return Literal_Get<double>(lit, multi_index, multi_index_len, shape_index);
}

void Literal_Set_bool(Literal& lit, int* indices, int value) {
Literal_Set<bool>(lit, indices, (bool) value);
void Literal_Set_bool(
Literal& lit, int* multi_index, int multi_index_len, ShapeIndex& shape_index, int value
) {
Literal_Set<bool>(lit, multi_index, multi_index_len, shape_index, (bool) value);
}

void Literal_Set_int32_t(Literal& lit, int* indices, int value) {
Literal_Set<int32_t>(lit, indices, value);
void Literal_Set_int32_t(
Literal& lit, int* multi_index, int multi_index_len, ShapeIndex& shape_index, int value
) {
Literal_Set<int32_t>(lit, multi_index, multi_index_len, shape_index, value);
}

void Literal_Set_uint32_t(Literal& lit, int* indices, int value) {
Literal_Set<uint32_t>(lit, indices, (uint32_t) value);
void Literal_Set_uint32_t(
Literal& lit, int* multi_index, int multi_index_len, ShapeIndex& shape_index, int value
) {
Literal_Set<uint32_t>(lit, multi_index, multi_index_len, shape_index, (uint32_t) value);
}

void Literal_Set_uint64_t(Literal& lit, int* indices, int value) {
Literal_Set<uint64_t>(lit, indices, (uint64_t) value);
void Literal_Set_uint64_t(
Literal& lit, int* multi_index, int multi_index_len, ShapeIndex& shape_index, int value
) {
Literal_Set<uint64_t>(lit, multi_index, multi_index_len, shape_index, (uint64_t) value);
}

void Literal_Set_double(Literal& lit, int* indices, double value) {
Literal_Set<double>(lit, indices, value);
void Literal_Set_double(
Literal& lit, int* multi_index, int multi_index_len, ShapeIndex& shape_index, double value
) {
Literal_Set<double>(lit, multi_index, multi_index_len, shape_index, value);
}
}
27 changes: 20 additions & 7 deletions backend/src/tensorflow/compiler/xla/literal.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ See the License for the specific language governing permissions and
limitations under the License.
*/
#include "shape.h"
#include "shape_util.h"

extern "C" {
struct Literal;
Expand All @@ -22,11 +23,23 @@ extern "C" {

void Literal_delete(Literal* lit);

int Literal_Get_bool(Literal& lit, int* indices);
int Literal_Get_int(Literal& lit, int* indices);
double Literal_Get_double(Literal& lit, int* indices);

void Literal_Set_bool(Literal& lit, int* indices, int value);
void Literal_Set_int(Literal& lit, int* indices, int value);
void Literal_Set_double(Literal& lit, int* indices, double value);
int Literal_Get_bool(
Literal& lit, int* multi_index, int multi_index_len, ShapeIndex& shape_index
);
int Literal_Get_int(
Literal& lit, int* multi_index, int multi_index_len, ShapeIndex& shape_index
);
double Literal_Get_double(
Literal& lit, int* multi_index, int multi_index_len, ShapeIndex& shape_index
);

void Literal_Set_bool(
Literal& lit, int* multi_index, int multi_index_len, ShapeIndex& shape_index, int value
);
void Literal_Set_int(
Literal& lit, int* multi_index, int multi_index_len, ShapeIndex& shape_index, int value
);
void Literal_Set_double(
Literal& lit, int* multi_index, int multi_index_len, ShapeIndex& shape_index, double value
);
}
16 changes: 16 additions & 0 deletions backend/src/tensorflow/compiler/xla/shape_util.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,24 @@ limitations under the License.
#include "tensorflow/compiler/xla/shape_util.h"

#include "shape.h"
#include "shape_util.h"

extern "C" {
ShapeIndex* ShapeIndex_new() {
return reinterpret_cast<ShapeIndex*>(new xla::ShapeIndex());
}
void ShapeIndex_delete(ShapeIndex* s) {
delete reinterpret_cast<xla::ShapeIndex*>(s);
}

void ShapeIndex_push_back(ShapeIndex& shape_index, int value) {
reinterpret_cast<xla::ShapeIndex&>(shape_index).push_back(value);
}

void ShapeIndex_push_front(ShapeIndex& shape_index, int value) {
reinterpret_cast<xla::ShapeIndex&>(shape_index).push_front(value);
}

Shape* MakeShape(int primitive_type, int* shape, int rank) {
int64_t shape64[rank];
std::copy(shape, shape + rank, shape64);
Expand Down
7 changes: 7 additions & 0 deletions backend/src/tensorflow/compiler/xla/shape_util.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,5 +16,12 @@ limitations under the License.
#include "shape.h"

extern "C" {
struct ShapeIndex;

ShapeIndex* ShapeIndex_new();
void ShapeIndex_delete(ShapeIndex* s);
void ShapeIndex_push_back(ShapeIndex& shape_index, int value);
void ShapeIndex_push_front(ShapeIndex& shape_index, int value);

Shape* MakeShape(int primitive_type, int* shape, int rank);
}
7 changes: 3 additions & 4 deletions src/Compiler/Eval.idr
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ interpret xlaBuilder (MkFn params root env) = do
set posInGraph param

interpretE : Expr -> Builder XlaOp
interpretE (FromLiteral {dtype} lit) = constantLiteral xlaBuilder !(write {dtype} lit)
interpretE (FromLiteral {dtype} lit) = constantLiteral xlaBuilder !(write {dtype} [] lit)
interpretE (Arg x) = get x
interpretE (Tuple xs) = tuple xlaBuilder !(traverse get xs)
interpretE (GetTupleElement idx x) = getTupleElement !(get x) idx
Expand Down Expand Up @@ -234,12 +234,11 @@ toString f = do
pure $ opToString xlaBuilder root

export covering
execute : PrimitiveRW dtype a => Fn 0 -> {shape : _} -> ErrIO $ Literal shape a
execute : Fn 0 -> ErrIO Literal
execute f = do
xlaBuilder <- mkXlaBuilder "root"
computation <- compile xlaBuilder f
gpuStatus <- validateGPUMachineManager
platform <- if ok gpuStatus then gpuMachineManager else getPlatform "Host"
client <- getOrCreateLocalClient platform
lit <- executeAndTransfer client computation
pure (read {dtype} lit)
executeAndTransfer client computation
34 changes: 25 additions & 9 deletions src/Compiler/LiteralRW.idr
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,8 @@ limitations under the License.
module Compiler.LiteralRW

import Compiler.Xla.TensorFlow.Compiler.Xla.XlaData
import Compiler.Xla.TensorFlow.Compiler.Xla.Literal
import public Compiler.Xla.TensorFlow.Compiler.Xla.Literal
import Compiler.Xla.TensorFlow.Compiler.Xla.ShapeUtil
import Literal
import Util

Expand All @@ -39,21 +40,36 @@ indexed = go shape []
go (0 :: _) _ = []
go (S d :: ds) idxs = concat $ map (\i => go ds (snoc idxs i)) (range (S d))

export
public export
interface Primitive dtype => LiteralRW dtype ty where
set : Literal -> List Nat -> ty -> IO ()
get : Literal -> List Nat -> ty
set : Literal -> List Nat -> ShapeIndex -> ty -> IO ()
get : Literal -> List Nat -> ShapeIndex -> ty

export
write : (HasIO io, LiteralRW dtype a) => {shape : _} -> Literal shape a -> io Literal
write xs = liftIO $ do
write : HasIO io =>
LiteralRW dtype a =>
{shape : _} ->
List Nat ->
Literal shape a ->
io Literal
write idxs xs = liftIO $ do
literal <- allocLiteral {dtype} shape
sequence_ [| (\idxs => set {dtype} literal idxs) indexed xs |]
shapeIndex <- allocShapeIndex
traverse_ (pushBack shapeIndex) idxs
sequence_ [| (\idxs => set {dtype} literal idxs shapeIndex) indexed xs |]
pure literal

export
read : LiteralRW dtype a => Literal -> {shape : _} -> Literal shape a
read lit = map (get {dtype} lit) indexed
read : HasIO io =>
LiteralRW dtype a =>
{shape : _} ->
List Nat ->
Literal ->
io $ Literal shape a
read idxs lit = do
shapeIndex <- allocShapeIndex
traverse_ (pushBack shapeIndex) idxs
pure $ map (\mIdx => get {dtype} lit mIdx shapeIndex) (indexed {shape})

export
LiteralRW PRED Bool where
Expand Down
20 changes: 10 additions & 10 deletions src/Compiler/Xla/Prim/TensorFlow/Compiler/Xla/Literal.idr
Original file line number Diff line number Diff line change
Expand Up @@ -29,40 +29,40 @@ prim__delete : AnyPtr -> PrimIO ()

export
%foreign (libxla "Literal_Set_bool")
prim__literalSetBool : GCAnyPtr -> GCPtr Int -> Int -> PrimIO ()
prim__literalSetBool : GCAnyPtr -> GCPtr Int -> Int -> GCAnyPtr -> Int -> PrimIO ()

export
%foreign (libxla "Literal_Get_bool")
literalGetBool : GCAnyPtr -> GCPtr Int -> Int
literalGetBool : GCAnyPtr -> GCPtr Int -> Int -> GCAnyPtr -> Int

export
%foreign (libxla "Literal_Set_double")
prim__literalSetDouble : GCAnyPtr -> GCPtr Int -> Double -> PrimIO ()
prim__literalSetDouble : GCAnyPtr -> GCPtr Int -> Int -> GCAnyPtr -> Double -> PrimIO ()

export
%foreign (libxla "Literal_Get_double")
literalGetDouble : GCAnyPtr -> GCPtr Int -> Double
literalGetDouble : GCAnyPtr -> GCPtr Int -> Int -> GCAnyPtr -> Double

export
%foreign (libxla "Literal_Set_int32_t")
prim__literalSetInt32t : GCAnyPtr -> GCPtr Int -> Int -> PrimIO ()
prim__literalSetInt32t : GCAnyPtr -> GCPtr Int -> Int -> GCAnyPtr -> Int -> PrimIO ()

export
%foreign (libxla "Literal_Get_int32_t")
literalGetInt32t : GCAnyPtr -> GCPtr Int -> Int
literalGetInt32t : GCAnyPtr -> GCPtr Int -> Int -> GCAnyPtr -> Int

export
%foreign (libxla "Literal_Set_uint32_t")
prim__literalSetUInt32t : GCAnyPtr -> GCPtr Int -> Bits32 -> PrimIO ()
prim__literalSetUInt32t : GCAnyPtr -> GCPtr Int -> Int -> GCAnyPtr -> Bits32 -> PrimIO ()

export
%foreign (libxla "Literal_Get_uint32_t")
literalGetUInt32t : GCAnyPtr -> GCPtr Int -> Bits32
literalGetUInt32t : GCAnyPtr -> GCPtr Int -> Int -> GCAnyPtr -> Bits32

export
%foreign (libxla "Literal_Set_uint64_t")
prim__literalSetUInt64t : GCAnyPtr -> GCPtr Int -> Bits64 -> PrimIO ()
prim__literalSetUInt64t : GCAnyPtr -> GCPtr Int -> Int -> GCAnyPtr -> Bits64 -> PrimIO ()

export
%foreign (libxla "Literal_Get_uint64_t")
literalGetUInt64t : GCAnyPtr -> GCPtr Int -> Bits64
literalGetUInt64t : GCAnyPtr -> GCPtr Int -> Int -> GCAnyPtr -> Bits64
16 changes: 16 additions & 0 deletions src/Compiler/Xla/Prim/TensorFlow/Compiler/Xla/ShapeUtil.idr
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,22 @@ import System.FFI

import Compiler.Xla.Prim.Util

export
%foreign (libxla "ShapeIndex_new")
prim__shapeIndexNew : PrimIO AnyPtr

export
%foreign (libxla "ShapeIndex_delete")
prim__shapeIndexDelete : AnyPtr -> PrimIO ()

export
%foreign (libxla "ShapeIndex_push_back")
prim__shapeIndexPushBack : GCAnyPtr -> Int -> PrimIO ()

export
%foreign (libxla "ShapeIndex_push_front")
prim__shapeIndexPushFront : GCAnyPtr -> Int -> PrimIO ()

export
%foreign (libxla "MakeShape")
prim__mkShape : Int -> GCPtr Int -> Int -> PrimIO AnyPtr
Loading

0 comments on commit bded472

Please sign in to comment.