Skip to content

Commit

Permalink
feat: Vectorisation - of and vectorised application (#47)
Browse files Browse the repository at this point in the history
  • Loading branch information
croyzor authored Oct 31, 2024
1 parent be2b27c commit c8039aa
Show file tree
Hide file tree
Showing 28 changed files with 377 additions and 28 deletions.
108 changes: 105 additions & 3 deletions brat/Brat/Checker.hs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@ module Brat.Checker (checkBody
) where

import Control.Arrow (first)
import Control.Monad (foldM)
import Control.Exception (assert)
import Control.Monad (foldM, zipWithM)
import Control.Monad.Freer
import Data.Bifunctor (second)
import Data.Functor (($>), (<&>))
Expand All @@ -17,6 +18,7 @@ import Data.List.NonEmpty (NonEmpty(..))
import qualified Data.List.NonEmpty as NE
import qualified Data.Map as M
import Data.Maybe (fromJust)
import Data.Traversable (for)
import Data.Type.Equality ((:~:)(..))
import Prelude hiding (filter)

Expand Down Expand Up @@ -385,8 +387,12 @@ check' (Arith op l r) ((), u@(hungry, ty):unders) = case (?my, ty) of
check' (fun :$: arg) (overs, unders) = do
((ins, outputs), ((), leftUnders)) <- check fun ((), unders)
((argIns, ()), (leftOvers, argUnders)) <- check arg (overs, ins)
ensureEmpty "leftover function args" argUnders
pure ((argIns, outputs), (leftOvers, leftUnders))
if null argUnders
then pure ((argIns, outputs), (leftOvers, leftUnders))
else typeErr $ unwords ["Expected function", show fun
,"to consume all of its arguments (" ++ show arg ++ ")\n"
,"but found leftovers:", showRow argUnders
]
check' (Let abs x y) conn = do
(((), dangling), ((), ())) <- check x ((), ())
env <- abstractAll dangling (unWC abs)
Expand Down Expand Up @@ -553,6 +559,102 @@ check' FanIn (overs, ((tgt, ty):unders)) = do
wire (danglingResult, binderToValue ?my ty, hungry)
faninNodes my (n - 1) (hungryTail, tailTy) elTy overs
check' Identity ((this:leftovers), ()) = pure (((), [this]), (leftovers, ()))
check' (Of n e) ((), unders) = case ?my of
Kerny -> typeErr $ "`of` not supported in kernel contexts"
Braty -> do
-- TODO: Our expectations about Id nodes in compilation might need updated?
(_, [(natUnder,Left k)], [(natOver, _)], _) <- anext "Of_len" Id (S0, Some (Zy :* S0))
(REx ("value", Nat) R0)
(REx ("value", Nat) R0)
([n], leftovers) <- kindCheck [(natUnder, k)] (unWC n)
defineSrc natOver n
ensureEmpty "" leftovers
case diry @d of
-- Get the largest prefix of unders whose types are vectors of the right length
Chky -> getVecs n unders >>= \case
-- If none of the unders have the right type, we should fail
([], [], _) -> let expected = if null unders then "empty row" else showRow unders in
typeErr $ unlines ["Got: Vector of length " ++ show n
,"Expected: " ++ expected]
(elemUnders, vecUnders, rightUnders) -> do
(Some (_ :* stk)) <- rowToRo ?my [ (portName tgt, tgt, Right ty) | (tgt, ty) <- elemUnders ] S0
case stk of
S0 -> do
(repConns, tgtMap) <- mkReplicateNodes n elemUnders
let (lenIns, repUnders, repOvers) = unzip3 repConns
-- Wire the length into all the replicate nodes
for lenIns $ \(tgt, _) -> do
wire (natOver, kindType Nat, tgt)
defineTgt tgt n
(((), ()), ((), elemRightUnders)) <- check e ((), repUnders)
-- If `elemRightUnders` isn't empty, it means we were too greedy
-- in the call to getVecs, so we should work out which elements of
-- the original unders weren't used, and make sure they prefix the
-- unders returned from here.
let unusedVecTgts :: [Tgt] = (fromJust . flip lookup tgtMap . fst) <$> elemRightUnders
let (usedVecUnders, unusedVecUnders) = splitAt (length vecUnders - length unusedVecTgts) vecUnders
-- Wire up the outputs of the replicate nodes to the _used_ vec
-- unders. The remainder of the replicate nodes don't get used.
-- (their inputs live in `elemRightUnders`)
assert (length repOvers >= length usedVecUnders) $ do
zipWithM (\(dangling, _) (hungry, ty) -> wire (dangling, ty, hungry)) repOvers usedVecUnders
pure (((), ()), ((), (second Right <$> unusedVecUnders) ++ rightUnders))

_ -> localFC (fcOf e) $ typeErr "No type dependency allowed when using `of`"
Syny -> do
(((), outputs), ((), ())) <- check e ((), ())
Some (_ :* stk) <- rowToRo ?my [(portName src, src, ty) | (src, ty) <- outputs] S0
case stk of
S0 -> do
-- Use of `outputs` and the map returned here are nonsensical, but we're
-- ignoring the map anyway
outputs <- getVals outputs
(conns, _) <- mkReplicateNodes n outputs
let (lenIns, elemIns, vecOuts) = unzip3 conns
for lenIns $ \(tgt,_) -> do
wire (natOver, kindType Nat, tgt)
defineTgt tgt n
zipWithM (\(dangling, ty) (hungry, _) -> wire (dangling, ty, hungry)) outputs elemIns
pure (((), vecOuts), ((), ()))
_ -> localFC (fcOf e) $ typeErr "No type dependency allowed when using `of`"
where
getVals :: [(t, BinderType Brat)] -> Checking [(t, Val Z)]
getVals [] = pure []
getVals ((t, Right ty):rest) = ((t, ty):) <$> getVals rest
getVals ((_, Left _):_) = localFC (fcOf e) $ typeErr "No type dependency allowed when using `of`"

mkReplicateNodes :: forall t
. ToEnd t
=> Val Z
-> [(t, Val Z)] -- The unders from getVec, only used for building the map
-> Checking ([((Tgt, BinderType Brat) -- The Tgt for the vector length
,(Tgt, BinderType Brat) -- The Tgt for the element
,(Src, BinderType Brat) -- The vectorised element output
)]
,[(Tgt, t)] -- A map from element tgts to the original vector tgts
)
mkReplicateNodes _ [] = pure ([], [])
mkReplicateNodes len ((t, ty):unders) = do
let weakTy = changeVar (Thinning (ThDrop ThNull)) ty
(_, [lenUnder, repUnder], [repOver], _) <- anext "replicate" Replicate (S0, Some (Zy :* S0))
(REx ("n", Nat) (RPr ("elem", weakTy) R0)) -- the type of e
(RPr ("vec", TVec weakTy (VApp (VInx VZ) B0)) R0) -- a vector of e's of length n??
(conns, tgtMap) <- mkReplicateNodes len unders
pure ((lenUnder, repUnder, repOver):conns, ((fst repUnder), t):tgtMap)

getVecs :: Val Z -- The length of vectors we're looking for
-> [(Tgt, BinderType Brat)]
-> Checking ([(Tgt, Val Z)] -- element types for which we need vecs of the given length
,[(Tgt, Val Z)] -- The vector type unders which we'll wire to
,[(Tgt, BinderType Brat)] -- Rightunders
)
getVecs len ((tgt, Right ty@(TVec el n)):unders) = eqTest "" Nat len n >>= \case
Left _ -> pure ([], [], (tgt, Right ty):unders)
Right () -> do
(elems, unders, rightUnders) <- getVecs len unders
pure ((tgt, el):elems, (tgt, ty):unders, rightUnders)
getVecs _ unders = pure ([], [], unders)

check' tm _ = error $ "check' " ++ show tm


Expand Down
110 changes: 104 additions & 6 deletions brat/Brat/Checker/Helpers.hs
Original file line number Diff line number Diff line change
Expand Up @@ -258,20 +258,21 @@ getThunks :: Modey m
,Overs m UVerb
)
getThunks _ [] = pure ([], [], [])
getThunks Braty row@((src, Right ty):rest) = eval S0 ty >>= \case
(VFun Braty (ss :->> ts)) -> do
getThunks Braty row@((src, Right ty):rest) = ((src,) <$> eval S0 ty) >>= vectorise >>= \case
(src, VFun Braty (ss :->> ts)) -> do
(node, unders, overs, _) <- let ?my = Braty in
anext "" (Eval (end src)) (S0, Some (Zy :* S0)) ss ts
(nodes, unders', overs') <- getThunks Braty rest
pure (node:nodes, unders <> unders', overs <> overs')
(VFun _ _) -> err $ ExpectedThunk (showMode Braty) (showRow row)
-- These shouldn't happen
(_, VFun _ _) -> err $ ExpectedThunk (showMode Braty) (showRow row)
v -> typeErr $ "Force called on non-thunk: " ++ show v
getThunks Kerny row@((src, Right ty):rest) = eval S0 ty >>= \case
(VFun Kerny (ss :->> ts)) -> do
getThunks Kerny row@((src, Right ty):rest) = ((src,) <$> eval S0 ty) >>= vectorise >>= \case
(src, VFun Kerny (ss :->> ts)) -> do
(node, unders, overs, _) <- let ?my = Kerny in anext "" (Splice (end src)) (S0, Some (Zy :* S0)) ss ts
(nodes, unders', overs') <- getThunks Kerny rest
pure (node:nodes, unders <> unders', overs <> overs')
(VFun _ _) -> err $ ExpectedThunk (showMode Kerny) (showRow row)
(_, VFun _ _) -> err $ ExpectedThunk (showMode Kerny) (showRow row)
v -> typeErr $ "Force called on non-(kernel)-thunk: " ++ show v
getThunks Braty ((src, Left (Star args)):rest) = do
(node, unders, overs) <- case bwdStack (B0 <>< args) of
Expand All @@ -283,6 +284,103 @@ getThunks Braty ((src, Left (Star args)):rest) = do
pure (node:nodes, unders <> unders', overs <> overs')
getThunks m ro = err $ ExpectedThunk (showMode m) (showRow ro)

-- The type given here should be normalised
vecLayers :: Val Z -> Checking ([(Src, NumVal (VVar Z))] -- The sizes of the vector layers
,Some (Modey :* Flip CTy Z) -- The function type at the end
)
vecLayers (TVec ty (VNum n)) = do
src <- mkStaticNum n
(layers, fun) <- vecLayers ty
pure ((src, n):layers, fun)
vecLayers (VFun my cty) = pure ([], Some (my :* Flip cty))
vecLayers ty = typeErr $ "Expected a function or vector of functions, got " ++ show ty

mkStaticNum :: NumVal (VVar Z) -> Checking Src
mkStaticNum n@(NumValue c gro) = do
(_, [], [(constSrc,_)], _) <- next "const" (Const (Num (fromIntegral c))) (S0, Some (Zy :* S0)) R0 (RPr ("value", TNat) R0)
src <- case gro of
Constant0 -> pure constSrc
StrictMonoFun sm -> do
(_, [(lhs,_),(rhs,_)], [(src,_)], _) <- next "add_const" (ArithNode Add) (S0, Some (Zy :* S0))
(RPr ("lhs", TNat) (RPr ("rhs", TNat) R0))
(RPr ("value", TNat) R0)
smSrc <- mkStrictMono sm
wire (constSrc, TNat, lhs)
wire (smSrc, TNat, rhs)
pure src
defineSrc src (VNum n)
pure src
where
mkStrictMono :: StrictMono (VVar Z) -> Checking Src
mkStrictMono (StrictMono k mono) = do
(_, [], [(constSrc,_)], _) <- next "2^k" (Const (Num (2^k))) (S0, Some (Zy :* S0)) R0 (RPr ("value", TNat) R0)
(_, [(lhs,_),(rhs,_)], [(src,_)], _) <- next "mult_const" (ArithNode Mul) (S0, Some (Zy :* S0))
(RPr ("lhs", TNat) (RPr ("rhs", TNat) R0))
(RPr ("value", TNat) R0)
monoSrc <- mkMono mono
wire (constSrc, TNat, lhs)
wire (monoSrc, TNat, rhs)
pure src

mkMono :: Monotone (VVar Z) -> Checking Src
mkMono (Linear (VPar (ExEnd e))) = pure (NamedPort e "mono")

Check warning on line 326 in brat/Brat/Checker/Helpers.hs

View workflow job for this annotation

GitHub Actions / build

Pattern match(es) are non-exhaustive

Check warning on line 326 in brat/Brat/Checker/Helpers.hs

View workflow job for this annotation

GitHub Actions / build

Pattern match(es) are non-exhaustive

Check warning on line 326 in brat/Brat/Checker/Helpers.hs

View workflow job for this annotation

GitHub Actions / build

Pattern match(es) are non-exhaustive
mkMono (Full sm) = do
(_, [], [(twoSrc,_)], _) <- next "2" (Const (Num 2)) (S0, Some (Zy :* S0)) R0 (RPr ("value", TNat) R0)
(_, [(lhs,_),(rhs,_)], [(powSrc,_)], _) <- next "2^" (ArithNode Pow) (S0, Some (Zy :* S0))
(RPr ("lhs", TNat) (RPr ("rhs", TNat) R0))
(RPr ("value", TNat) R0)
smSrc <- mkStrictMono sm
wire (twoSrc, TNat, lhs)
wire (smSrc, TNat, rhs)

(_, [], [(oneSrc,_)], _) <- next "1" (Const (Num 1)) (S0, Some (Zy :* S0)) R0 (RPr ("value", TNat) R0)
(_, [(lhs,_),(rhs,_)], [(src,_)], _) <- next "n-1" (ArithNode Sub) (S0, Some (Zy :* S0))
(RPr ("lhs", TNat) (RPr ("rhs", TNat) R0))
(RPr ("value", TNat) R0)
wire (powSrc, TNat, lhs)
wire (oneSrc, TNat, rhs)
pure src

vectorise :: (Src, Val Z) -> Checking (Src, Val Z)
vectorise (src, ty) = do
(layers, Some (my :* Flip cty)) <- vecLayers ty
modily my $ mkMapFuns (src, VFun my cty) layers
where
mkMapFuns :: (Src, Val Z) -- The input to the mapfun
-> [(Src, NumVal (VVar Z))] -- Remaining layers
-> Checking (Src, Val Z)
mkMapFuns over [] = pure over
mkMapFuns (valSrc, ty) ((lenSrc, len):layers) = do
(valSrc, ty@(VFun my cty)) <- mkMapFuns (valSrc, ty) layers
let weak1 = changeVar (Thinning (ThDrop ThNull))

Check warning on line 355 in brat/Brat/Checker/Helpers.hs

View workflow job for this annotation

GitHub Actions / build

• The Monomorphism Restriction applies to the binding for ‘weak1’

Check warning on line 355 in brat/Brat/Checker/Helpers.hs

View workflow job for this annotation

GitHub Actions / build

• The Monomorphism Restriction applies to the binding for ‘weak1’

Check warning on line 355 in brat/Brat/Checker/Helpers.hs

View workflow job for this annotation

GitHub Actions / build

• The Monomorphism Restriction applies to the binding for ‘weak1’
vecFun <- vectorisedFun len my cty
(_, [(lenTgt,_), (valTgt, _)], [(vectorSrc, Right vecTy)], _) <-
next "" MapFun (S0, Some (Zy :* S0))
(REx ("len", Nat) (RPr ("value", weak1 ty) R0))
(RPr ("vector", weak1 vecFun) R0)
defineTgt lenTgt (VNum len)
wire (lenSrc, kindType Nat, lenTgt)
wire (valSrc, ty, valTgt)
pure (vectorSrc, vecTy)

vectorisedFun :: NumVal (VVar Z) -> Modey m -> CTy m Z -> Checking (Val Z)
vectorisedFun nv my (ss :->> ts) = do
(ss', ny) <- vectoriseRo True nv Zy ss
(ts', _) <- vectoriseRo False nv ny ts
pure $ modily my $ VFun my (ss' :->> ts')

-- We don't allow existentials in vectorised functions, so the boolean says
-- whether we are in the input row and can allow binding
vectoriseRo :: Bool -> NumVal (VVar Z) -> Ny i -> Ro m i j -> Checking (Ro m i j, Ny j)
vectoriseRo _ _ ny R0 = pure (R0, ny)
vectoriseRo True n ny (REx k ro) = do (ro', ny') <- vectoriseRo True n (Sy ny) ro
pure (REx k ro', ny')
vectoriseRo False _ _ (REx _ _) =
typeErr "Type variable binding not allowed in the output type of a vectorised function"
vectoriseRo b n ny (RPr (p, ty) ro) = do
(ro', ny') <- vectoriseRo b n ny ro
pure (RPr (p, TVec ty (VNum (changeVar (Thinning (thEmpty ny)) <$> n))) ro', ny')

binderToValue :: Modey m -> BinderType m -> Val Z
binderToValue Braty (Left k) = kindType k
binderToValue Braty (Right ty) = ty
Expand Down
6 changes: 6 additions & 0 deletions brat/Brat/Compile/Hugr.hs
Original file line number Diff line number Diff line change
Expand Up @@ -469,6 +469,12 @@ compileWithInputs parent name = gets compiled <&> M.lookup name >>= \case
pure dfgId
ArithNode op -> default_edges <$> compileArithNode parent op (snd $ head ins)
Selector _c -> error "Todo: selector"
Replicate -> default_edges <$> do
ins <- compilePorts ins
let [_, elemTy] = ins
outs <- compilePorts outs
let sig = FunctionType ins outs
addNode "Replicate" (OpCustom (CustomOp parent "BRAT" "Replicate" sig [TAType elemTy]))
x -> error $ show x ++ " should have been compiled outside of compileNode"

compileConstructor :: NodeId -> UserName -> UserName -> FunctionType -> Compile NodeId
Expand Down
7 changes: 7 additions & 0 deletions brat/Brat/Elaborator.hs
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ assertChk s@(WC _ r) = case dir r of
deepEmb (WC fc (a ::-:: b)) = WC fc (a ::-:: deepEmb b)
deepEmb (WC fc (RLambda c cs)) = WC fc (RLambda ((id *** deepEmb) c) cs)
deepEmb (WC fc (RLet abs a b)) = WC fc (RLet abs a (deepEmb b))
deepEmb (WC fc (ROf num exp)) = WC fc (ROf num (deepEmb exp))
-- We like to avoid RTypedTh because the body doesn't know whether it's Brat or Kernel
deepEmb (WC fc (RTypedTh bdy)) = WC fc (RTh (WC fc $ RForget $ deepEmb bdy))
deepEmb (WC fc a) = WC fc (REmb (WC fc a))
Expand Down Expand Up @@ -179,6 +180,12 @@ elaborate' (FAnnotation a ts) = do
a <- assertNoun a
pure $ SomeRaw' (a ::::: ts)
elaborate' (FInto a b) = elaborate' (FApp b a)
elaborate' (FOf n e) = do
SomeRaw n <- elaborate n
n <- assertNoun =<< assertChk n
SomeRaw e <- elaborate e
e <- assertNoun e
pure $ SomeRaw' (ROf n e)
elaborate' (FFn cty) = pure $ SomeRaw' (RFn cty)
elaborate' (FKernel sty) = pure $ SomeRaw' (RKernel sty)
elaborate' FIdentity = pure $ SomeRaw' RIdentity
Expand Down
2 changes: 2 additions & 0 deletions brat/Brat/Graph.hs
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,8 @@ data NodeType :: Mode -> Type where
Constructor :: UserName -> NodeType a
Selector :: UserName -> NodeType a -- TODO: Get rid of this in favour of pattern matching
ArithNode :: ArithOp -> NodeType Brat
Replicate :: NodeType Brat
MapFun :: NodeType a

deriving instance Show (NodeType a)

Expand Down
1 change: 1 addition & 0 deletions brat/Brat/Lexer/Flat.hs
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ keyword
<|> string "import" $> KImport
<|> string "let" $> KLet
<|> string "in" $> KIn
<|> string "of" $> KOf
) <* notFollowedBy identChar

identChar :: Lexer Char
Expand Down
2 changes: 2 additions & 0 deletions brat/Brat/Lexer/Token.hs
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,7 @@ data Keyword
| KImport
| KLet
| KIn
| KOf
deriving Eq

instance Show Keyword where
Expand All @@ -122,6 +123,7 @@ instance Show Keyword where
show KImport = "import"
show KLet = "let"
show KIn = "in"
show KOf = "of"

tokLen :: Tok -> Int
tokLen = length . show
Expand Down
9 changes: 9 additions & 0 deletions brat/Brat/Parser.hs
Original file line number Diff line number Diff line change
Expand Up @@ -367,6 +367,7 @@ cthunk = try bratFn <|> try kernel <|> thunk
; (left-assoc)
, & port-pull
-, ,- =,= =,_,= =%= (vector builders) (all right-assoc (for now!) and same precedence)
_of_ (right-assoc)
+ - (left-assoc)
* / (left-assoc)
^ (left-assoc)
Expand All @@ -386,6 +387,7 @@ expr' p = choice $ (try . getParser <$> enumFrom p) ++ [atomExpr]
PComp -> composition <?> "composition"
PJuxtPull -> pullAndJuxt <?> "juxtaposition"
PVecPat -> vectorBuild <?> "vector pattern"
POf -> ofExpr <?> "vectorisation"
PAddSub -> addSub <?> "addition or subtraction"
PMulDiv -> mulDiv <?> "multiplication or division"
PPow -> pow <?> "power"
Expand Down Expand Up @@ -414,6 +416,13 @@ expr' p = choice $ (try . getParser <$> enumFrom p) ++ [atomExpr]
pure (FCon c (mkJuxt (args ++ [rhs])))
Nothing -> pure (unWC lhs)

ofExpr :: Parser Flat
ofExpr = do
lhs <- withFC (subExpr POf)
optional (kmatch KOf) >>= \case
Nothing -> pure (unWC lhs)
Just () -> FOf lhs <$> (withFC ofExpr)

mkJuxt [x] = x

Check warning on line 426 in brat/Brat/Parser.hs

View workflow job for this annotation

GitHub Actions / build

Pattern match(es) are non-exhaustive

Check warning on line 426 in brat/Brat/Parser.hs

View workflow job for this annotation

GitHub Actions / build

Pattern match(es) are non-exhaustive

Check warning on line 426 in brat/Brat/Parser.hs

View workflow job for this annotation

GitHub Actions / build

Pattern match(es) are non-exhaustive
mkJuxt (x:xs) = let rest = mkJuxt xs in WC (FC (start (fcOf x)) (end (fcOf rest))) (FJuxt x rest)

Expand Down
2 changes: 2 additions & 0 deletions brat/Brat/Syntax/Common.hs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ module Brat.Syntax.Common (PortName,
Dir(..),
Kind(..),
Diry(..),
DIRY(..),
Kindy(..),
CType'(..),
Import(..),
Expand Down Expand Up @@ -208,6 +209,7 @@ data Precedence
| PComp
| PJuxtPull -- Juxtaposition has the same precedence as port pulling
| PVecPat
| POf
| PAddSub
| PMulDiv
| PPow
Expand Down
1 change: 1 addition & 0 deletions brat/Brat/Syntax/Concrete.hs
Original file line number Diff line number Diff line change
Expand Up @@ -46,4 +46,5 @@ data Flat
| FFanOut
| FFanIn
| FIdentity
| FOf ({- number :: -}WC Flat) {- of -} ({- expr -}WC Flat)
deriving Show
Loading

0 comments on commit c8039aa

Please sign in to comment.