Skip to content

Commit

Permalink
wip
Browse files Browse the repository at this point in the history
  • Loading branch information
joelberkeley committed Jul 11, 2024
1 parent 7206d72 commit 9593c96
Show file tree
Hide file tree
Showing 12 changed files with 39 additions and 114 deletions.
15 changes: 5 additions & 10 deletions spidr/src/Tensor.idr
Original file line number Diff line number Diff line change
Expand Up @@ -132,9 +132,8 @@ namespace S32
fromInteger : Integer -> Tensor [] S32
fromInteger = tensor . Scalar . fromInteger

partial
try : Show e => EitherT e IO a -> IO a
try = eitherT (idris_crash . show) pure
try = eitherT (assert_total . idris_crash . show) pure

namespace Tag
||| Evaluate a `Tensor`, returning its value as a `Literal`. This function builds and executes the
Expand All @@ -143,7 +142,7 @@ namespace Tag
||| **Note:** Each call to `eval` will rebuild and execute the graph; multiple calls to `eval` on
||| different tensors, even if they are in the same computation, will be treated independently.
||| To efficiently evaluate multiple tensors at once, use `TensorList.Tag.eval`.
export covering -- is this true?
export covering
eval : Device -> PrimitiveRW dtype ty => Tag (Tensor shape dtype) -> IO (Literal shape ty)
eval device (MkTagT x) =
let (env, MkTensor root) = runState empty x
Expand All @@ -152,12 +151,8 @@ namespace Tag
[lit] <- execute device (MkFn [] root env) [shape]
read {dtype} [] lit

-- is it safe to use this within a `Tag` context? It's probably not with `unsafePerformIO`, but
-- that's kind of expected. What about w/o `unsafePerformIO`?
||| A convenience wrapper for `Tag.eval`, for use with a bare `Tensor`.
|||
||| **Note:** It is not safe to use
export covering -- is this true?
export covering
eval : Device -> PrimitiveRW dtype ty => Tensor shape dtype -> IO (Literal shape ty)
eval device x = eval device (pure x)

Expand Down Expand Up @@ -185,7 +180,7 @@ namespace TensorList
||| ```
||| In contrast to `Tensor.eval` when called on multiple tensors, this function constructs and
||| compiles the graph just once.
export covering -- is this true?
export covering
eval : Device -> Tag (TensorList shapes tys) -> IO (All2 Literal shapes tys)
eval device (MkTagT xs) =
let (env, xs) = runState empty xs
Expand Down Expand Up @@ -215,7 +210,7 @@ namespace TensorList
readAll (MkTensor {dtype} _ :: ts) (l :: ls) = [| read {dtype} [] l :: readAll ts ls |]

||| A convenience wrapper for `TensorList.Tag.eval`, for use with a bare `TensorList`.
export covering -- is this true?
export covering
eval : Device -> TensorList shapes tys -> IO (All2 Literal shapes tys)
eval device xs = eval device (pure xs)

Expand Down
2 changes: 1 addition & 1 deletion test/runner/TestRunner.idr
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ import Unit.TestTensor
import Unit.TestLiteral
import Unit.TestUtil

export partial
export
run : Device -> IO ()
run device = test [
Utils.TestComparison.group
Expand Down
3 changes: 1 addition & 2 deletions test/runner/Unit/Model/TestKernel.idr
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@ import Model.Kernel
import Utils.Comparison
import Utils.Cases

partial
rbfMatchesTFP : Device => Property
rbfMatchesTFP = fixedProperty $ do
let lengthScale = tensor 0.4
Expand All @@ -42,7 +41,7 @@ rbfMatchesTFP = fixedProperty $ do
]
rbf lengthScale x x' ===# pure expected

export partial
export
group : Device => Group
group = MkGroup "Kernel" $ [
("rbf matches tfp", rbfMatchesTFP)
Expand Down
5 changes: 1 addition & 4 deletions test/runner/Unit/TestDistribution.idr
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@ import Distribution
import Utils.Comparison
import Utils.Cases

partial
gaussianUnivariatePDF : Device => Property
gaussianUnivariatePDF = property $ do
let doubles = literal [] doubles
Expand All @@ -37,15 +36,13 @@ gaussianUnivariatePDF = property $ do
univariate : Double -> Double -> Double -> Double
univariate x mean cov = exp (- (x - mean) * (x - mean) / (2 * cov)) / sqrt (2 * pi * cov)

partial
gaussianMultivariatePDF : Device => Property
gaussianMultivariatePDF = fixedProperty $ do
let mean = tensor [[-0.2], [0.3]]
cov = tensor [[[1.2], [0.5]], [[0.5], [0.7]]]
x = tensor [[1.1], [-0.5]]
pdf (MkGaussian mean cov) x ===# pure 0.016427375

partial
gaussianCDF : Device => Property
gaussianCDF = fixedProperty $ do
let gaussian = MkGaussian (tensor [[0.5]]) (tensor [[[1.44]]])
Expand All @@ -55,7 +52,7 @@ gaussianCDF = fixedProperty $ do
cdf gaussian (tensor [[0.5]]) ===# pure 0.5
cdf gaussian (tensor [[1.5]]) ===# pure 0.7976716

export partial
export
group : Device => Group
group = MkGroup "Distribution" $ [
("Gaussian univariate pdf", gaussianUnivariatePDF)
Expand Down
25 changes: 3 additions & 22 deletions test/runner/Unit/TestTensor.idr
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,6 @@ import Utils.Comparison
import Utils.Cases
import Utils.Proof

partial
tensorThenEval : Device => Property
tensorThenEval @{device} = withTests 20 . property $ do
shape <- forAll shapes
Expand All @@ -54,7 +53,6 @@ tensorThenEval @{device} = withTests 20 . property $ do
x <- forAll (literal shape bool)
x === unsafePerformIO (eval device $ pure $ tensor {dtype = PRED} x)

partial
evalTuple : Device => Property
evalTuple @{device} = property $ do
s0 <- forAll shapes
Expand Down Expand Up @@ -86,7 +84,6 @@ evalTuple @{device} = property $ do
x1' === x1
x2' === x2

partial
evalTupleNonTrivial : Device => Property
evalTupleNonTrivial @{device} = property $ do
let xs = do let y0 = tensor [1.0, -2.0, 0.4]
Expand All @@ -101,7 +98,6 @@ evalTupleNonTrivial @{device} = property $ do
v ==~ Scalar (exp (-2.0) + 3.0)
w ==~ [| exp [1.0, -2.0] |]

partial
canConvertAtXlaNumericBounds : Device => Property
canConvertAtXlaNumericBounds = fixedProperty $ do
let f64min : Literal [] Double = min @{Finite}
Expand Down Expand Up @@ -140,7 +136,6 @@ canConvertAtXlaNumericBounds = fixedProperty $ do
unsafeEval (tensor u64min == min') === True
unsafeEval (tensor u64max == max') === True

partial
boundedNonFinite : Device => Property
boundedNonFinite = fixedProperty $ do
let min' : Tensor [] S32 = Types.min @{NonFinite}
Expand All @@ -163,7 +158,6 @@ boundedNonFinite = fixedProperty $ do
unsafeEval {dtype = F64} (Types.min @{NonFinite}) === -inf
unsafeEval {dtype = F64} (Types.max @{NonFinite}) === inf

partial
iota : Device => Property
iota = withTests 20 . property $ do
init <- forAll shapes
Expand All @@ -190,7 +184,6 @@ iota = withTests 20 . property $ do

actual ===# castDtype rangeFull

partial
iotaExamples : Device => Property
iotaExamples = fixedProperty $ do
iota 0 ===# tensor {dtype = S32} [0, 1, 2, 3]
Expand All @@ -212,7 +205,6 @@ iotaExamples = fixedProperty $ do
[1.0, 1.0, 1.0, 1.0, 1.0],
[2.0, 2.0, 2.0, 2.0, 2.0]]

partial
show : Device => Property
show = fixedProperty $ do
let x : Tag $ Tensor [] S32 = pure 1
Expand Down Expand Up @@ -281,7 +273,6 @@ show = fixedProperty $ do
"""
-- x ===# pure 24 -- bug in XLA? https://github.com/openxla/xla/issues/14299

partial
cast : Device => Property
cast = property $ do
shape <- forAll shapes
Expand All @@ -298,7 +289,6 @@ cast = property $ do
let x : Tensor shape F64 = castDtype $ tensor {dtype = S32} lit
x ===# tensor (map (cast {to = Double}) lit)

partial
identity : Device => Property
identity = fixedProperty $ do
identity ===# tensor {dtype = S32} []
Expand All @@ -314,15 +304,15 @@ identity = fixedProperty $ do
]

namespace Vector
export partial
export
(@@) : Device => Property
(@@) = fixedProperty $ do
let l = tensor {dtype = S32} [-2, 0, 1]
r = tensor {dtype = S32} [3, 1, 2]
l @@ r ===# -4

namespace Matrix
export partial
export
(@@) : Device => Property
(@@) = fixedProperty $ do
let l = tensor {dtype = S32} [[-2, 0, 1], [1, 3, 4]]
Expand All @@ -333,7 +323,6 @@ namespace Matrix
r = tensor {dtype = S32} [[3, -1], [3, 2], [-1, -4]]
l @@ r ===# tensor [[ -7, -2], [ 8, -11]]

partial
dotGeneral : Device => Property
dotGeneral = fixedProperty $ do
dotGeneral [] [] [] [] 2 3 ===# tensor {dtype = S32} 6
Expand Down Expand Up @@ -390,23 +379,20 @@ dotGeneral = fixedProperty $ do
[1.1037, 1.5626]]]
dotGeneral [0] [0] [2] [1] l r ===# expected

partial
argmin : Device => Property
argmin = property $ do
d <- forAll dims
xs <- forAll (literal [S d] doubles)
let xs = tensor xs
(do pure $ slice [at !(argmin xs)] xs) ===# reduce [0] @{Min} xs

partial
argmax : Device => Property
argmax = property $ do
d <- forAll dims
xs <- forAll (literal [S d] doubles)
let xs = tensor xs
(do pure $ slice [at !(argmax xs)] xs) ===# reduce [0] @{Max} xs

partial
select : Device => Property
select = fixedProperty $ do
let onTrue = tensor {dtype = S32} 1
Expand All @@ -420,14 +406,12 @@ select = fixedProperty $ do
expected = tensor [[6, 1, 2], [3, 10, 11]]
select pred onTrue onFalse ===# expected

partial
erf : Device => Property
erf = fixedProperty $ do
let x = tensor [-1.5, -0.5, 0.5, 1.5]
expected = tensor [-0.96610516, -0.5204998, 0.5204998, 0.9661051]
erf x ===# expected

partial
cholesky : Device => Property
cholesky = fixedProperty $ do
let x = tensor [[1.0, 0.0], [2.0, 0.0]]
Expand All @@ -447,7 +431,6 @@ cholesky = fixedProperty $ do
]
cholesky x ===# expected

partial
triangularSolveResultAndInverse : Device => Property
triangularSolveResultAndInverse = fixedProperty $ do
let a = tensor [
Expand Down Expand Up @@ -478,7 +461,6 @@ triangularSolveResultAndInverse = fixedProperty $ do
actual ===# expected
a.T @@ actual ===# b

partial
triangularSolveIgnoresOppositeElems : Device => Property
triangularSolveIgnoresOppositeElems = fixedProperty $ do
let a = tensor [[1.0, 2.0], [3.0, 4.0]]
Expand All @@ -489,12 +471,11 @@ triangularSolveIgnoresOppositeElems = fixedProperty $ do
let aUpper = tensor [[1.0, 2.0], [0.0, 4.0]]
a \| b ===# aUpper \| b

partial
trace : Device => Property
trace = fixedProperty $
trace (tensor {dtype = S32} [[-1, 5], [1, 4]]) ===# pure 3

export partial
export
group : Device => Group
group = MkGroup "Tensor" $ [
("eval . tensor", tensorThenEval)
Expand Down
Loading

0 comments on commit 9593c96

Please sign in to comment.