Skip to content

Commit

Permalink
Added lots of index sets to prelude.
Browse files Browse the repository at this point in the history
  • Loading branch information
duvenaud committed Sep 24, 2023
1 parent f456031 commit 8575c17
Show file tree
Hide file tree
Showing 2 changed files with 109 additions and 11 deletions.
10 changes: 9 additions & 1 deletion examples/ctc.dx
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ instance Ix(FenceAndPosts n) given (n|Ix)
False -> Posts $ unsafe_from_ordinal (intdiv2 o)

instance NonEmpty(FenceAndPosts n) given (n|Ix)
first_ix = unsafe_from_ordinal 0
pass

instance Eq(FenceAndPosts a) given (a|Ix|Eq)
def (==)(x, y) = case x of
Expand Down Expand Up @@ -220,3 +220,11 @@ or the paper.
sum for i:(Fin 3=>Vocab).
ls_to_f $ ctc blank logits i
> 0.5653746


'One major advantage of Dex is its parallelism-preserving autodiff.
The original CTC paper, and most CUDA implementations, used hand-written
reverse-mode derivatives. Dex should be able to
prodice an efficient one automatically. Let's check:

-- grad (\logits. ls_to_f $ ctc blank logits labels) logits
110 changes: 100 additions & 10 deletions lib/prelude.dx
Original file line number Diff line number Diff line change
Expand Up @@ -874,27 +874,27 @@ instance Ix(Maybe a) given (a|Ix)
True -> Nothing

interface NonEmpty(n|Ix)
first_ix : n
pass

instance NonEmpty(())
first_ix = unsafe_from_ordinal(0)
pass

instance NonEmpty(Bool)
first_ix = unsafe_from_ordinal 0
pass

instance NonEmpty((a,b)) given (a|NonEmpty, b|NonEmpty)
first_ix = unsafe_from_ordinal 0
pass

instance NonEmpty(Either(a,b)) given (a|NonEmpty, b|Ix)
first_ix = unsafe_from_ordinal 0
pass

-- The below instance is valid, but causes "multiple candidate dictionaries"
-- errors if both Left and Right are NonEmpty.
-- instance NonEmpty (a|b) given {a b} [Ix a, NonEmpty b]
-- first_ix = unsafe_from_ordinal _ 0
-- pass

instance NonEmpty(Maybe a) given (a|Ix)
first_ix = unsafe_from_ordinal 0
pass

'## Fencepost index sets

Expand Down Expand Up @@ -924,11 +924,14 @@ def right_fence(p:Post n) -> Maybe n given (n|Ix) =
then Nothing
else Just $ unsafe_from_ordinal ix

def first_ix() ->> n given (n|NonEmpty) =
unsafe_from_ordinal(0)

def last_ix() ->> n given (n|NonEmpty) =
unsafe_from_ordinal(unsafe_i_to_n(n_to_i(size n) - 1))

instance NonEmpty(Post n) given (n|Ix)
first_ix = unsafe_from_ordinal(n=Post n, 0)
pass

def scan(
init:a,
Expand Down Expand Up @@ -1704,7 +1707,7 @@ def from_ordinal(i:Nat) -> n given (n|Ix) =
False -> error $ from_ordinal_error(i, size n)

-- TODO: should this be called `from_ordinal`?
def to_ix(i:Nat) -> Maybe n given (n|Ix) =
def to_ix(i:Nat) -> Maybe n given (n|Ix) =
case i < size n of
True -> Just $ unsafe_from_ordinal i
False -> Nothing
Expand Down Expand Up @@ -2266,6 +2269,93 @@ instance Subset(b, Either(a,b)) given (a|Data, b|Data)
Left( x) -> error "Can't project Left branch to Right branch"
Right(x) -> x

instance Subset(n=>a, n=>b) given (n|Ix, a|Data, b|Data) (Subset a b)
def inject'(xs) = for i. inject xs[i]
def project'(xs') =
xs = for i. project xs'[i]
case any_sat(is_nothing, xs) of
True -> Nothing
False -> Just $ each xs from_just
def unsafe_project'(xs') =
xs = for i. project xs'[i]
case any_sat(is_nothing, xs) of
True -> error "Couldn't project table."
False -> each xs from_just

-- add instance for subset n=>a m=>a given subset n m

instance Subset(List a, List b) given (a|Data, b|Data) (Subset a b)
def inject'(xs') =
AsList(n, xs) = xs'
AsList(n, inject xs)
def project'(l) =
AsList(n, tab) = l
case project tab of
Nothing -> Nothing
Just xs -> Just AsList(n, xs)
def unsafe_project'(l) =
AsList(n, tab) = l
case project tab of
Nothing -> error "Couldn't project list."
Just xs -> AsList(n, xs)

'### All but Last Index set
All the indices of the original index set except the last one.

struct AllButLast(n:Nat, a|Ix) =
val : a

instance Ix(AllButLast n a) given (n:Nat, a|Ix|Data)
def size'() = (size a) -| n
def ordinal(i) = ordinal i.val
def unsafe_from_ordinal(o) = AllButLast $ unsafe_from_ordinal o

instance Subset(AllButLast n a, a) given (n:Nat, a|Ix)
def inject'(x) = x.val
def project'(x) = case (ordinal x) < ((size a) -| n) of
True -> Just (AllButLast x)
False -> Nothing
def unsafe_project'(x) = (AllButLast x)

instance Eq(AllButLast n a) given (n:Nat, a|Eq|Ix)
def (==)(x, y) = x.val == y.val

def unsafe_increment(i:n) -> n given (n|Ix) = from_ordinal (ordinal i + 1)
def next(i: AllButLast 1 n) -> n given (n|Ix) = unsafe_increment i.val
def get_next_m(tab:n=>a, i:AllButLast m n) -> List a given (n|Ix, m:Nat, a) =
-- The list returned always has size (n - m), but can't spell that yet.
to_list $ for j:(Fin m). tab[unsafe_from_ordinal (ordinal i + ordinal j)]

'### Fence and Inner Posts
A custom datatype and index set
that interleaves the elements of a table with another set
of values representing all the spaces in between those elements,
not including the 2 ends.

data FenceAndInnerPosts(n|Ix) =
Fence(n)
InnerPost(AllButLast 1 n)

instance Ix(FenceAndInnerPosts n) given (n|Ix)
def size'() = 2 * size n -| 1
def ordinal(i) = case i of
Fence j -> 2 * ordinal j
InnerPost j -> 2 * ordinal j + 1
def unsafe_from_ordinal(o) =
case is_odd o of
False -> Fence $ unsafe_from_ordinal (intdiv2 o)
True -> InnerPost $ unsafe_from_ordinal (intdiv2 o)

instance Eq(FenceAndInnerPosts a) given (a|Ix|Eq)
def (==)(x, y) = case x of
Fence x -> case y of
Fence y -> x == y
InnerPost y -> False
InnerPost x -> case y of
Fence y -> False
InnerPost y -> x == y


'### Index set for tables

def int_to_reversed_digits(k:Nat) -> a=>b given (a|Ix, b|Ix) =
Expand All @@ -2291,7 +2381,7 @@ instance Ix(a=>b) given (a|Ix, b|Ix)
def unsafe_from_ordinal(i) = int_to_reversed_digits i

instance NonEmpty(a=>b) given (a|Ix, b|NonEmpty)
first_ix = unsafe_from_ordinal 0
pass

'### Stack

Expand Down

0 comments on commit 8575c17

Please sign in to comment.