From e3fc6c200d64a4a6dc54b38d91969b25235bd3d1 Mon Sep 17 00:00:00 2001 From: Joel Berkeley <16429957+joelberkeley@users.noreply.github.com> Date: Sat, 3 Jun 2023 23:27:02 +0100 Subject: [PATCH] wip --- src/Compiler/Eval.idr | 10 +- src/Compiler/Expr.idr | 16 ++- src/Compiler/Transform.idr | 37 ++++-- src/Tensor.idr | 256 ++++++++++++++++++++----------------- 4 files changed, 186 insertions(+), 133 deletions(-) diff --git a/src/Compiler/Eval.idr b/src/Compiler/Eval.idr index dc7105a92..f28601fb3 100644 --- a/src/Compiler/Eval.idr +++ b/src/Compiler/Eval.idr @@ -67,7 +67,7 @@ lookup n = do lift $ left (IndexErr "Tried to look up XlaOp at index \{show n} but none was found. Indices: \{show $ keys !get}") Just op => pure op -interpret : XlaBuilder -> Nat -> Env -> Computation XlaOp +interpret : XlaBuilder -> Nat -> Program -> Computation XlaOp buildSub : XlaBuilder -> String -> Fn arity -> Computation XlaComputation buildSub builder name (MkFn params i env) = do @@ -78,8 +78,8 @@ buildSub builder name (MkFn params i env) = do where - interpretParameter : XlaBuilder -> (Nat, Nat, ShapeAndType) -> Computation () - interpretParameter builder (position, i, MkShapeAndType shape dtype) = do + interpretParameter : XlaBuilder -> (Nat, Nat, FullShape) -> Computation () + interpretParameter builder (position, i, shape ### dtype) = do xlaShape <- mkShape {dtype} shape param <- parameter builder position xlaShape name put $ insert i param !get @@ -203,14 +203,14 @@ interpret builder root env = do interpretExpr (n, expr) = put (insert n !(enqueue builder expr) !get) export -toString : Nat -> Env -> EitherT Err IO String +toString : Nat -> Program -> EitherT Err IO String toString root env = do builder <- mkXlaBuilder "toString" xlaOp <- evalStateT empty (interpret builder root env) pure $ opToString builder xlaOp export -run : PrimitiveRW dtype a => Nat -> Env -> {shape : _} -> EitherT Err IO (Literal shape a) +run : PrimitiveRW dtype a => Nat -> Program -> {shape : _} -> EitherT 
Err IO (Literal shape a) run root env = do builder <- mkXlaBuilder "root" root <- evalStateT empty (interpret builder root env) diff --git a/src/Compiler/Expr.idr b/src/Compiler/Expr.idr index c7bfb9c20..1a11f77cd 100644 --- a/src/Compiler/Expr.idr +++ b/src/Compiler/Expr.idr @@ -25,9 +25,11 @@ import Primitive import Types import Util +infix 9 ### + public export -data ShapeAndType : Type where - MkShapeAndType : Shape -> (0 dtype : Type) -> Primitive dtype => ShapeAndType +data FullShape : Type where + (###) : Shape -> (0 dtype : Type) -> Primitive dtype => FullShape export new : Ref Nat @@ -39,12 +41,16 @@ new = do data Expr : Type where public export 0 -Env : Type -Env = SortedMap Nat Expr +ProgramShape : Type +ProgramShape = SortedMap Nat Shape + +public export 0 +Program : Type +Program = SortedMap Nat Expr public export data Fn : Nat -> Type where - MkFn : {arity : _} -> Vect arity (Nat, ShapeAndType) -> Nat -> Env -> Fn arity + MkFn : {arity : _} -> Vect arity (Nat, FullShape) -> Nat -> Program -> Fn arity public export data BinaryOp = diff --git a/src/Compiler/Transform.idr b/src/Compiler/Transform.idr index d8e0d3606..1a111fa76 100644 --- a/src/Compiler/Transform.idr +++ b/src/Compiler/Transform.idr @@ -42,13 +42,34 @@ record Acc where ||| Keys are indices of nodes in the original metadata : SortedMap Nat Value ||| the resulting graph - graph : Env - + graph : Program + +||| Traverse the `program` in sorted order. For each `Expr` in the graph, inspect the nodes it is +||| built from. Each node it is built from either +||| * does not exist in `program`. This means that it comes from the global scope, is therefore +||| constant with respect to the `vmap` argument, and we simply broadcast the value using the +||| shape extracted from `programShape`. +||| * exists in `program`, in which case ... +||| If a node is built from only constant nodes, it is also constant. +||| +||| @res A pointer to the return value of the original function. 
+||| @n The size of the vmap-ed dimension. +||| @param A pointer to the parameter in the `vmap`-ed function. +||| @arg A pointer to the argument to `vmap`. +||| @to The return shape of the function to vmap. +||| @localProgram The program to vmap. We vectorize the whole of this, so this should not include +||| the whole global program, just the local program containing all values dependent on the value +||| we vmap over. +||| @globalProgramShape The shape of the whole global program. export partial -vmap : (res, n, param, arg : Nat) -> (to : Shape) -> (graph : Env) -> Ref (Env, Nat) -vmap res n param arg to original = do +vmap : (res, n, param, arg : Nat) -> + (to : Shape) -> + (localProgram : Program) -> + (globalProgramShape : ProgramShape) -> + Ref (Program, Nat) +vmap res n param arg to localProgram globalProgramShape = do foo <- runEitherT $ do - acc <- recurse (toList original) impl (MkAcc empty empty) + acc <- recurse (toList localProgram) impl (MkAcc empty empty) case lookup res acc.metadata `or` idris_crash "\{show res} \{show (keys acc.metadata)}" of Var i => pure (acc.graph, i) Const => lift new <&> \j => (insert j (Broadcast to (n :: to) res) acc.graph, j) @@ -140,14 +161,14 @@ vmap res n param arg to original = do ||| @n The size of the extra dimensions we're mapping over. 
||| @arg The index of the argument to replace export covering -vmap : (res, n, arg : Nat) -> (unvmapped : Env) -> Expr -> Ref (Env, Nat) +vmap : (res, n, arg : Nat) -> (unvmapped : Program) -> Expr -> Ref (Program, Nat) vmap res n arg unvmapped expr = runStateT empty (impl expr) where - impl : Expr -> StateT Env Ref Nat + impl : Expr -> StateT Program Ref Nat - recurse : Shape -> Nat -> StateT Env Ref Nat + recurse : Shape -> Nat -> StateT Program Ref Nat recurse shape j = case lookup j unvmapped of Just expr => impl expr diff --git a/src/Tensor.idr b/src/Tensor.idr index c6d43eeaf..40b9529b7 100644 --- a/src/Tensor.idr +++ b/src/Tensor.idr @@ -47,17 +47,16 @@ import public Util ||| @dtype The element type. export data Tensor : (shape : Shape) -> (dtype : Type) -> Type where - MkTensor : {shape : _} -> Nat -> Env -> Tensor shape dtype + MkTensor : {shape : _} -> Nat -> ProgramShape -> Program -> Tensor shape dtype -end : Env -> Expr -> {shape : _} -> Ref $ Tensor shape dtype -end env expr = do - i <- new - pure $ MkTensor i (insert i expr env) +extend : ProgramShape -> Program -> Expr -> {shape : _} -> Ref $ Tensor shape dtype +extend progShape prog expr = new <&> \node => + MkTensor node (insert node shape progShape) (insert node expr prog) ||| Construct a `Tensor` from `Literal` data. export tensor : PrimitiveRW dtype a => {shape : _} -> Literal shape a -> Ref $ Tensor shape dtype -tensor lit = empty `end` FromLiteral {dtype} {shape} lit +tensor lit = extend empty empty (FromLiteral {dtype} {shape} lit) namespace F64 export @@ -83,8 +82,8 @@ namespace S32 ||| with e.g. `export TF_CPP_MIN_LOG_LEVEL=3`. 
export partial eval : PrimitiveRW dtype ty => Ref (Tensor shape dtype) -> IO (Literal shape ty) -eval x = let MkTensor n nodes = evalState 0 x in - runEitherT (run {dtype} n (traceVal nodes)) <&> \case +eval x = let MkTensor n _ program = evalState 0 x in + runEitherT (run {dtype} n (traceVal program)) <&> \case Right lit => lit Left err => idris_crash (show err) @@ -92,27 +91,27 @@ eval x = let MkTensor n nodes = evalState 0 x in ||| Useful for debugging. export partial Show (Ref $ Tensor shape dtype) where - show x = let MkTensor n nodes = evalState 0 x in - case unsafePerformIO $ runEitherT $ toString n nodes of + show x = let MkTensor n _ program = evalState 0 x in + case unsafePerformIO $ runEitherT $ toString n program of Right str => str ||| Bounds for numeric tensors. Will be infinite for floating point types. export [NonFinite] Primitive.Ord dtype => Bounded (Ref $ Tensor [] dtype) where - min = empty `end` MinValue {dtype} - max = empty `end` MaxValue {dtype} + min = extend empty empty $ MinValue {dtype} + max = extend empty empty $ MaxValue {dtype} ||| Finite bounds for numeric tensors. export [Finite] Primitive.Ord dtype => Bounded (Ref $ Tensor [] dtype) where - min = empty `end` MinFiniteValue {dtype} - max = empty `end` MaxFiniteValue {dtype} + min = extend empty empty $ MinFiniteValue {dtype} + max = extend empty empty $ MaxFiniteValue {dtype} ||| Cast the element type. For example, `castDtype (tensor {dtype=S32} [1, -2])` is ||| `tensor {dtype=F64} [1.0, -2.0]`. 
export castDtype : Primitive.Integral a => Tensor shape a -> Ref $ Tensor shape F64 -castDtype $ MkTensor i env = env `end` ConvertElementType {dtype=F64} i +castDtype $ MkTensor i progShape prog = extend progShape prog $ ConvertElementType {dtype=F64} i ----------------------------- structural operations ---------------------------- @@ -125,7 +124,7 @@ reshape : {auto 0 sizesEqual : product from = product to} -> Tensor from dtype -> Ref $ Tensor to dtype -reshape $ MkTensor {shape} i env = env `end` Reshape shape to i +reshape $ MkTensor {shape} i progShape prog = extend progShape prog $ Reshape shape to i ||| Add a dimension of length one at the specified `axis`. The new dimension will be at the ||| specified `axis` in the new `Tensor` (as opposed to the original `Tensor`). For example, @@ -137,7 +136,8 @@ expand : {auto 0 inBounds : axis `LTE` length shape} -> Tensor shape dtype -> Ref $ Tensor (insertAt axis 1 shape) dtype -expand axis $ MkTensor {shape = _} i env = env `end` Reshape shape (insertAt axis 1 shape) i +expand axis $ MkTensor {shape = _} i progShape prog = + extend progShape prog $ Reshape shape (insertAt axis 1 shape) i namespace Squeezable ||| A `Squeezable from to` constitutes proof that the shape `from` can be squeezed to the @@ -184,7 +184,7 @@ squeeze : {auto 0 shapesSqueezable : Squeezable from to} -> Tensor from dtype -> Ref $ Tensor to dtype -squeeze $ MkTensor {shape} i env = env `end` Reshape shape to i +squeeze $ MkTensor {shape} i progShape prog = extend progShape prog $ Reshape shape to i ||| A `SliceOrIndex d` is a valid slice or index into a dimension of size `d`. See `slice` for ||| details. 
@@ -340,13 +340,16 @@ slice : (at : MultiSlice shape) -> Tensor shape dtype -> Ref $ Tensor (slice at) dtype -slice at $ MkTensor i env = do +slice at $ MkTensor i progShape prog = do + -- handle program shapes j <- new - let env = insert j (Slice (mapd start (const 0) at) (mapd stop id at) (replicate (length shape) 1) i) env - (dynStartsIdxs, env) <- dynStarts [] env at + let sliced = Slice (mapd start (const 0) at) (mapd stop id at) (replicate (length shape) 1) i + prog = insert j sliced prog + progShape = insert j ?sliceFullShape progShape + (dynStartsIdxs, env) <- dynStarts [] prog at k <- new let env = insert k (DynamicSlice dynStartsIdxs (mapd size id at) j) env - env `end` Reshape (mapd size id at) (MultiSlice.slice at) k + extend progShape prog $ Reshape (mapd size id at) (MultiSlice.slice at) k where mapd : @@ -377,24 +380,24 @@ slice at $ MkTensor i env = do zero : Expr zero = FromLiteral {shape=[]} {dtype=U64} 0 - dynStarts : List Nat -> Env -> {shape : _} -> MultiSlice shape -> Ref (List Nat, Env) + dynStarts : List Nat -> Program -> {shape : _} -> MultiSlice shape -> Ref (List Nat, Program) dynStarts idxs env {shape} [] = f (length shape) (idxs, env) where - f : Nat -> (List Nat, Env) -> Ref (List Nat, Env) + f : Nat -> (List Nat, Program) -> Ref (List Nat, Program) f 0 (idxs, env) = pure (idxs, env) f (S k) (idxs, env) = do i <- new f k (i :: idxs, insert i zero env) - dynStarts idxs env (DynamicSlice (MkTensor i env') _ :: ds) = do - (idxs, env) <- dynStarts idxs env ds - pure (i :: idxs, mergeLeft env env') - dynStarts idxs env (DynamicIndex (MkTensor i env') :: ds) = do - (idxs, env) <- dynStarts idxs env ds - pure (i :: idxs, mergeLeft env env') - dynStarts idxs env (_ :: ds) = do - (idxs, env) <- dynStarts idxs env ds + dynStarts idxs prog (DynamicSlice (MkTensor i progShape' prog') _ :: ds) = do + (idxs, prog) <- dynStarts idxs prog ds + pure (i :: idxs, mergeLeft prog prog') + dynStarts idxs prog (DynamicIndex (MkTensor i progShape' prog') :: 
ds) = do + (idxs, prog) <- dynStarts idxs prog ds + pure (i :: idxs, mergeLeft prog prog') + dynStarts idxs prog (_ :: ds) = do + (idxs, prog) <- dynStarts idxs prog ds i <- new - pure (i :: idxs, insert i zero env) + pure (i :: idxs, insert i zero prog) ||| Concatenate two `Tensor`s along the specfied `axis`. For example, ||| `concat 0 !(tensor [[1, 2], [3, 4]]) !(tensor [[5, 6]])` and @@ -409,8 +412,8 @@ concat : {auto 0 inBounds : (InBounds axis s, InBounds axis s')} -> {auto 0 shapesConcatenable : deleteAt axis s = deleteAt axis s'} -> Ref $ Tensor (replaceAt axis (index axis s + index axis s') s) dtype -concat axis (MkTensor {shape = _} i env) (MkTensor {shape = _} i' env') = - mergeLeft env env' `end` Concat axis i i' +concat axis (MkTensor {shape = _} i progShape prog) (MkTensor {shape = _} i' progShape' prog') = + extend (mergeLeft progShape progShape') (mergeLeft prog prog') (Concat axis i i') ||| The diagonal of a matrix as a vector. For example, for ||| ``` @@ -422,7 +425,7 @@ concat axis (MkTensor {shape = _} i env) (MkTensor {shape = _} i' env') = ||| `diag !x` is `tensor [0, 4, 8]`. export diag : Primitive dtype => Tensor [n, n] dtype -> Ref (Tensor [n] dtype) -diag $ MkTensor i env = env `end` Diag i +diag $ MkTensor i progShape prog = extend progShape prog $ Diag i ||| Represents the upper- or lower-trinagular component of a matrix. public export @@ -444,14 +447,15 @@ data Triangle = Upper | Lower ||| ``` export triangle : Primitive dtype => Triangle -> Tensor [n, n] dtype -> Ref $ Tensor [n, n] dtype -triangle tri $ MkTensor i env = env `end` Triangle (case tri of Upper => False; Lower => True) i +triangle tri $ MkTensor i progShape prog = + extend progShape prog $ Triangle (case tri of Upper => False; Lower => True) i ||| Tranpose a matrix. For example, `(tensor [[1, 2], [3, 4]]).T` is `tensor [[1, 3], [2, 4]]`. 
export (.T) : Ref (Tensor [m, n] dtype) -> Ref (Tensor [n, m] dtype) x.T = do - MkTensor i env <- x - env `end` Transpose [1, 0] i + MkTensor i progShape prog <- x + extend progShape prog $ Transpose [1, 0] i ||| Transpose axes of a tensor. This is a more general version of `(.T)`, in which you can ||| transpose any number of axes in a tensor of arbitrary rank. The i'th axis in the resulting @@ -505,7 +509,7 @@ transpose : {auto 0 unique : Sorted Neq ordering} -> {auto 0 inBounds : All (flip InBounds shape) ordering} -> Ref $ Tensor (map (dflip List.index shape) ordering) dtype -transpose ordering $ MkTensor i env = env `end` Transpose ordering i +transpose ordering $ MkTensor i progShape prog = extend progShape prog $ Transpose ordering i ||| The identity tensor, with inferred shape and element type. For example, ||| ``` @@ -520,7 +524,7 @@ transpose ordering $ MkTensor i env = env `end` Transpose ordering i ||| ``` export identity : Primitive.Num dtype => {n : _} -> Ref $ Tensor [n, n] dtype -identity = empty `end` Identity {dtype} n +identity = extend empty empty $ Identity {dtype} n ||| A `DimBroadcastable from to` proves that a dimension of size `from` can be broadcast to a ||| dimension of size `to`. 
@@ -590,7 +594,7 @@ broadcast : {auto shapesOK : Broadcastable from to} -> Tensor from dtype -> Ref $ Tensor to dtype -broadcast $ MkTensor {shape=_} i env = env `end` Broadcast from to i +broadcast $ MkTensor {shape=_} i progShape prog = extend progShape prog $ Broadcast from to i %hint export @@ -616,12 +620,12 @@ fill xs = broadcast {shapesOK=scalarToAnyOk shape} !(tensor (Scalar xs)) ----------------------------- generic operations ---------------------------- -arg : Primitive dtype => {shape : _} -> Ref (Tensor shape dtype, Nat, ShapeAndType) +arg : Primitive dtype => {shape : _} -> Ref (Tensor shape dtype, Nat, FullShape) arg = do i <- new - pure (MkTensor i (singleton i (Arg i)), (i, MkShapeAndType shape dtype)) + pure (MkTensor i (singleton i shape) (singleton i (Arg i)), (i, shape ### dtype)) -lookup' : Nat -> Env -> Expr +lookup' : Nat -> Program -> Expr lookup' x env = case lookup x env of Just expr => expr Nothing => assert_total $ idris_crash "" @@ -642,11 +646,14 @@ vmap : Primitive a => (Tensor from a -> Ref $ Tensor to b) -> Tensor (n :: from) a -> Ref $ Tensor (n :: to) b -vmap f (MkTensor {shape=n :: from} i env) = do +vmap f (MkTensor {shape=n :: from} i progShape prog) = do + -- rather than separate Program and ProgramShape, just combine them and pass it separately to + -- Transform.vmap j <- new - MkTensor {shape = _} k unVmappedEnv <- f (MkTensor j (singleton j (Arg j))) - (vmappedEnv, l) <- vmap k n j i to unVmappedEnv - pure (MkTensor l (mergeLeft env vmappedEnv)) + MkTensor {shape = _} k unVmappedProgShape unVmappedProg <- + f (MkTensor j (singleton j []) (singleton j (Arg j))) + (vmappedProg, l) <- vmap k n j i to unVmappedProg progShape + pure (MkTensor l ?vmapProgShape (mergeLeft prog vmappedProg)) {- namespace Binary ||| `vmap` for mapping over binary functions. 
@@ -695,15 +702,17 @@ reduce : {auto 0 axesInBounds : All (flip InBounds shape) axes} -> Tensor shape dtype -> Ref $ Tensor (deleteAt axes shape) dtype -reduce axes $ MkTensor i xEnv = do +reduce axes $ MkTensor i xProgShape xProg = do (a0, p0) <- arg (a1, p1) <- arg let semigroupT : Monoid a -> Semigroup a semigroupT _ = %search - MkTensor j subEnv <- (<+>) @{semigroupT reducer} (pure a0) (pure a1) - MkTensor k neutralEnv <- neutral @{reducer} - mergeLeft xEnv neutralEnv `end` Reduce (MkFn [p0, p1] j subEnv) k axes i + MkTensor j subProgShape subProg <- (<+>) @{semigroupT reducer} (pure a0) (pure a1) + MkTensor k neutralProgShape neutralProg <- neutral @{reducer} + let progShape = mergeLeft xProgShape neutralProgShape + prog = mergeLeft xProg neutralProg + extend progShape prog $ Reduce (MkFn [p0, p1] j subProg) k axes i ||| Sort the elements of a `Tensor` along a specified `dimension` according to a scalar-wise ||| ordering. For sorting function `f`, elements are sorted such that for consecutive sorted @@ -723,11 +732,11 @@ sort : Tensor shape dtype -> {auto 0 dimInBounds : InBounds dimension shape} -> Ref $ Tensor shape dtype -sort comp dimension $ MkTensor i env = do +sort comp dimension $ MkTensor i progShape prog = do (a0, p0) <- arg (a1, p1) <- arg - MkTensor j subEnv <- comp (pure a0) (pure a1) - env `end` Sort (MkFn [p0, p1] j subEnv) dimension False [i] + MkTensor j subProgShape subProg <- comp (pure a0) (pure a1) + extend progShape prog $ Sort (MkFn [p0, p1] j subProg) dimension False [i] ||| Reverse elements along the specified axes. 
For example, for ||| ``` @@ -764,15 +773,17 @@ reverse : {auto 0 axesInBounds : All (flip InBounds shape) axes} -> Tensor shape dtype -> Ref $ Tensor shape dtype -reverse axes $ MkTensor i env = env `end` Reverse axes i +reverse axes $ MkTensor i progShape prog = extend progShape prog $ Reverse axes i ----------------------------- numeric operations ---------------------------- binaryRef : BinaryOp -> Ref (Tensor s a) -> Ref (Tensor s a') -> Ref (Tensor s a'') binaryRef op x x' = do - MkTensor {shape = _} i env <- x - MkTensor {shape = _} i' env' <- x' - mergeLeft env env' `end` BinaryElementwise {shape = s} op i i' + MkTensor {shape = _} i progShape prog <- x + MkTensor {shape = _} i' progShape' prog' <- x' + let progShape = mergeLeft progShape progShape' + prog = mergeLeft prog prog' + extend progShape prog $ BinaryElementwise {shape = s} op i i' ||| Element-wise equality. For example, `tensor [1, 2] /= tensor [1, 3]` is ||| `tensor [True, False]`. @@ -865,7 +876,8 @@ namespace Monoid neutral = fill False unary : UnaryOp -> Tensor s a -> Ref $ Tensor s a' -unary op $ MkTensor {shape = _} i env = env `end` UnaryElementwise {shape = s} op i +unary op $ MkTensor {shape = _} i progShape prog = + extend progShape prog $ UnaryElementwise {shape = s} op i ||| Element-wise boolean negation. For example, `not !(tensor [True, False])` is ||| `tensor [False, True]`. @@ -897,8 +909,10 @@ select : (onTrue : Tensor shape dtype) -> (onFalse : Tensor shape dtype) -> Ref $ Tensor shape dtype -select (MkTensor p pred) (MkTensor t true) (MkTensor f false) = - mergeLeft (mergeLeft pred true) false `end` Select p t f +select (MkTensor p predShapes pred) (MkTensor t trueShape true) (MkTensor f falseShapes false) = + let progShape = mergeLeft (mergeLeft predShapes trueShape) falseShapes + prog = mergeLeft (mergeLeft pred true) false + in extend progShape prog $ Select p t f ||| Use a scalar predicate to choose which of two functions to evaluate. 
If the predicte is truthy, ||| evaluate `onTrue` on the corresponding specified argument, otherwise evaluate `onFalse` on the @@ -928,13 +942,18 @@ cond : (onTrue : Tensor ts tt -> Ref $ Tensor shape dtype) -> Tensor ts tt -> (onFalse : Tensor fs ft -> Ref $ Tensor shape dtype) -> Tensor fs ft -> Ref $ Tensor shape dtype -cond (MkTensor pred envPred) onTrue (MkTensor true envTrue) onFalse (MkTensor false envFalse) = do +cond (MkTensor pred predProgShape predProg) onTrue + (MkTensor true trueProgShape trueProg) onFalse + (MkTensor false falseProgShape falseProg) = do (aTrue, pTrue) <- arg (aFalse, pFalse) <- arg - MkTensor lTrue subEnvTrue <- onTrue aTrue - MkTensor lFalse subEnvFalse <- onFalse aFalse - let env = mergeLeft (mergeLeft envPred envTrue) envFalse - env `end` Cond pred (MkFn [pTrue] lTrue subEnvTrue) true (MkFn [pFalse] lFalse subEnvFalse) false + MkTensor lTrue trueSubProgShape trueSubProg <- onTrue aTrue + MkTensor lFalse falseSubProgShape falseSubProg <- onFalse aFalse + let progShape = mergeLeft (mergeLeft predProgShape trueSubProgShape) falseProgShape + prog = mergeLeft (mergeLeft predProg trueProg) falseProg + expr = + Cond pred (MkFn [pTrue] lTrue trueSubProg) true (MkFn [pFalse] lFalse falseSubProg) false + extend progShape prog expr -- see https://www.python.org/dev/peps/pep-0465/#precedence-and-associativity infixl 9 @@ @@ -949,9 +968,9 @@ namespace Vector Ref (Tensor [S m] dtype) -> Ref (Tensor [] dtype) x @@ x' = do - MkTensor i env <- x - MkTensor i' env' <- x' - mergeLeft env env' `end` Dot i i' + MkTensor i progShape prog <- x + MkTensor i' progShape' prog' <- x' + extend (mergeLeft progShape progShape') (mergeLeft prog prog') (Dot i i') namespace Matrix ||| Matrix multiplication with a matrix or vector. 
Contraction is along the last axis of the first @@ -982,9 +1001,9 @@ namespace Matrix {auto 0 vectorTail : length tl `LTE` 1} -> Ref (Tensor (n :: tl) dtype) x @@ x' = do - MkTensor i env <- x - MkTensor i' env' <- x' - mergeLeft env env' `end` Dot i i' + MkTensor i progShape prog <- x + MkTensor i' progShape' prog' <- x' + extend (mergeLeft progShape progShape') (mergeLeft prog prog') (Dot i i') ||| Element-wise addition. For example, `tensor [1, 2] + tensor [3, 4]` is ||| `tensor [4, 6]`. @@ -1013,8 +1032,8 @@ namespace Monoid export negate : Primitive.Neg dtype => Ref (Tensor shape dtype) -> Ref (Tensor shape dtype) negate x = do - MkTensor {shape = _} i env <- x - env `end` UnaryElementwise {shape} Neg i + MkTensor {shape = _} i progShape prog <- x + extend progShape prog $ UnaryElementwise {shape} Neg i ||| Element-wise subtraction. For example, `tensor [3, 4] - tensor [4, 2]` is ||| `tensor [-1, 2]`. @@ -1045,7 +1064,7 @@ namespace Scalarwise Ref (Tensor (d :: ds) dtype) -> Ref (Tensor (d :: ds) dtype) l * r = do - MkTensor {shape=_ :: _} _ _ <- r + MkTensor {shape=_ :: _} _ _ _ <- r broadcast {shapesOK=scalarToAnyOk (d :: ds)} !l * r namespace Semigroup @@ -1082,7 +1101,7 @@ namespace Scalarwise Ref (Tensor [] dtype) -> Ref (Tensor (d :: ds) dtype) l / r = do - MkTensor {shape = _ :: _} _ _ <- l + MkTensor {shape = _ :: _} _ _ _ <- l l / broadcast {shapesOK=scalarToAnyOk (d :: ds)} !r ||| Element-wise division of natural numbers. For example, @@ -1242,9 +1261,11 @@ sqrt = unary Sqrt ||| `min !(tensor [-3, -1, 3]) !(tensor [-1, 0, 1])` is `tensor [-3, -1, 1]`. 
export min : Primitive.Ord dtype => Tensor shape dtype -> Tensor shape dtype -> Ref $ Tensor shape dtype -min (MkTensor {shape = _} i env) x'@(MkTensor i' env') = do - let x = MkTensor i env - op = mergeLeft env env' `end` BinaryElementwise {shape} Min i i' +min (MkTensor {shape = _} i progShape prog) x'@(MkTensor i' progShape' prog') = do + let x = MkTensor i progShape prog + progShape = mergeLeft progShape progShape' + prog = mergeLeft prog prog' + op = extend progShape prog $ BinaryElementwise {shape} Min i i' select !(pure x == pure x) !(select !(pure x' == pure x') !op x') x namespace Semigroup @@ -1265,9 +1286,11 @@ namespace Monoid ||| `max !(tensor [-3, -1, 3]) !(tensor [-1, 0, 1])` is `tensor [-1, 0, 3]`. export max : Primitive.Ord dtype => Tensor shape dtype -> Tensor shape dtype -> Ref $ Tensor shape dtype -max (MkTensor {shape = _} i env) x'@(MkTensor i' env') = do - let x = MkTensor i env - op = mergeLeft env env' `end` BinaryElementwise {shape} Max i i' +max (MkTensor {shape = _} i progShape prog) x'@(MkTensor i' progShape' prog') = do + let x = MkTensor i progShape prog + progShape = mergeLeft progShape progShape' + prog = mergeLeft prog prog' + op = extend progShape prog $ BinaryElementwise {shape} Max i i' select !(pure x == pure x) !(select !(pure x' == pure x') !op x') x namespace Semigroup @@ -1286,7 +1309,7 @@ namespace Monoid highlightNan : Primitive.Ord dtype => Bool -> Tensor [S n] dtype -> Ref $ Tensor [S n] dtype highlightNan minimize x with (x) - _ | (MkTensor {shape = _} _ _) = + _ | (MkTensor {shape = _} _ _ _) = cond !(reduce @{All} [0] !(pure x == pure x)) pure x extremizeNan x where @@ -1304,8 +1327,8 @@ highlightNan minimize x with (x) export argmin : Primitive.Ord dtype => Tensor [S n] dtype -> Ref $ Tensor [] U64 argmin x = do - MkTensor i env <- highlightNan True x - env `end` Argmin {out=U64} 0 i + MkTensor i progShape prog <- highlightNan True x + extend progShape prog $ Argmin {out=U64} 0 i ||| The first index of the maximum 
value in a vector. For example, ||| `argmax !(tensor [-1, 3, -2, -2, 3])` is `tensor 1`. If the vector contains NaN values, @@ -1313,8 +1336,8 @@ argmin x = do export argmax : Primitive.Ord dtype => Tensor [S n] dtype -> Ref $ Tensor [] U64 argmax x = do - MkTensor i env <- highlightNan False x - env `end` Argmax {out=U64} 0 i + MkTensor i progShape prog <- highlightNan False x + extend progShape prog $ Argmax {out=U64} 0 i ---------------------------- other ---------------------------------- @@ -1324,7 +1347,7 @@ argmax x = do ||| diagonal - will always be zero. export cholesky : Tensor [S n, S n] F64 -> Ref $ Tensor [S n, S n] F64 -cholesky $ MkTensor i env = triangle Lower !(env `end` Cholesky i) +cholesky $ MkTensor i progShape prog = triangle Lower !(extend progShape prog $ Cholesky i) infix 9 |\, \| @@ -1339,9 +1362,9 @@ namespace Matrix export (|\) : Ref (Tensor [m, m] F64) -> Ref (Tensor [m, n] F64) -> Ref (Tensor [m, n] F64) x |\ x' = do - MkTensor i env <- x - MkTensor i' env' <- x' - mergeLeft env env' `end` TriangularSolve i i' True + MkTensor i progShape prog <- x + MkTensor i' progShape' prog' <- x' + extend (mergeLeft progShape progShape') (mergeLeft prog prog') (TriangularSolve i i' True) ||| Solve the set of linear equations `a @@ x = b` for `x` where `a` is an upper-triangular ||| matrix. `a` is given by the upper-triangular elements of the first argument. Values in the @@ -1353,9 +1376,9 @@ namespace Matrix export (\|) : Ref (Tensor [m, m] F64) -> Ref (Tensor [m, n] F64) -> Ref (Tensor [m, n] F64) x \| x' = do - MkTensor i env <- x - MkTensor i' env' <- x' - mergeLeft env env' `end` TriangularSolve i i' False + MkTensor i progShape prog <- x + MkTensor i' progShape' prog' <- x' + extend (mergeLeft progShape progShape') (mergeLeft prog prog') (TriangularSolve i i' False) namespace Vector ||| Solve the set of linear equations `a @@ x = b` for `x` where `a` is a lower-triangular matrix. 
@@ -1368,8 +1391,8 @@ namespace Vector export (|\) : Ref (Tensor [m, m] F64) -> Ref (Tensor [m] F64) -> Ref (Tensor [m] F64) a |\ b = do - MkTensor {shape=[_]} i env <- b - squeeze !(a |\ expand 1 (MkTensor {shape = [m]} i env)) + MkTensor {shape=[_]} i progShape prog <- b + squeeze !(a |\ expand 1 (MkTensor {shape = [m]} i progShape prog)) ||| Solve the set of linear equations `a @@ x = b` for `x` where `a` is an upper-triangular ||| matrix. `a` is given by the upper-triangular elements of the first argument. Values in the @@ -1381,8 +1404,8 @@ namespace Vector export (\|) : Ref (Tensor [m, m] F64) -> Ref (Tensor [m] F64) -> Ref (Tensor [m] F64) a \| b = do - MkTensor {shape=[_]} i env <- b - squeeze !(a \| expand 1 (MkTensor {shape = [m]} i env)) + MkTensor {shape=[_]} i progShape prog <- b + squeeze !(a \| expand 1 (MkTensor {shape = [m]} i progShape prog)) ||| Sum the elements along the diagonal of the input. For example, ||| `trace !(tensor [[-1, 5], [1, 4]])` is `3`. @@ -1392,7 +1415,7 @@ trace : (Primitive.Num dtype, Prelude.Num a) => Tensor [S n, S n] dtype -> Ref (Tensor [] dtype) trace x with (x) - _ | MkTensor {shape=[_, _]} _ _ = reduce @{Sum} [0, 1] !(Tensor.(*) (pure x) identity) + _ | MkTensor {shape=[_, _]} _ _ _ = reduce @{Sum} [0, 1] !(Tensor.(*) (pure x) identity) ||| A `Rand a` produces a pseudo-random value of type `a` from a `Tensor [1] U64` state. ||| The state is updated each time a new value is generated. 
@@ -1431,17 +1454,19 @@ uniform : (key : Tensor [] U64) -> (bound, bound' : Tensor shape F64) -> Ref $ Rand $ Tensor shape F64 -uniform (MkTensor iKey envKey) bound bound' = do - minval@(MkTensor iMinval envMinval) <- min bound bound' - maxval@(MkTensor iMaxval envMaxval) <- max bound bound' +uniform (MkTensor iKey progShapeKey progKey) bound bound' = do + minval@(MkTensor iMinval progShapeMinval progMinval) <- min bound bound' + maxval@(MkTensor iMaxval progShapeMaxval progMaxval) <- max bound bound' let inf = broadcast !inf - let env = mergeLeft (mergeLeft envKey envMinval) envMaxval - pure $ ST $ \(MkTensor iState envState) => do + let progShape = mergeLeft (mergeLeft progShapeKey progShapeMinval) progShapeMaxval + prog = mergeLeft (mergeLeft progKey progMinval) progMaxval + pure $ ST $ \(MkTensor iState progShapeState progState) => do i <- new - let env = mergeLeft envState env - env = insert i (UniformFloatingPoint iKey iState iMinval iMaxval shape) env - state = env `end` GetTupleElement 1 i - value = env `end` GetTupleElement 0 i + let progShape = insert i ?progShapeNormalTupleUniform (mergeLeft progShapeState progShape) + prog = mergeLeft progState prog + prog = insert i (UniformFloatingPoint iKey iState iMinval iMaxval shape) prog + state = extend progShape prog $ GetTupleElement 1 i + value = extend progShape prog $ GetTupleElement 0 i -- workaround for XLA bug https://github.com/tensorflow/tensorflow/issues/56663 -- samples between -inf and 0 should be at -inf, but XLA produces nan -- similarly, samples in (inf, inf) should be at inf and respectively for -inf @@ -1468,10 +1493,11 @@ uniform (MkTensor iKey envKey) bound bound' = do ||| @key Determines the stream of generated samples. 
export normal : {shape : _} -> (key : Tensor [] U64) -> Rand $ Tensor shape F64 -normal $ MkTensor iKey envKey = - ST $ \(MkTensor iState envState) => do +normal $ MkTensor iKey progShapeKey progKey = + ST $ \(MkTensor iState progShapeState progState) => do i <- new - let env = insert i (NormalFloatingPoint iKey iState shape) $ mergeLeft envKey envState - state <- env `end` GetTupleElement 1 i - value <- env `end` GetTupleElement 0 i + let progShape = insert i ?progShapeNormalTupleNormal (mergeLeft progShapeKey progShapeState) + prog = insert i (NormalFloatingPoint iKey iState shape) $ mergeLeft progKey progState + state <- extend progShape prog $ GetTupleElement 1 i + value <- extend progShape prog $ GetTupleElement 0 i pure (state, value)