Skip to content

Commit

Permalink
feat: Add fanin and fanout operators (#36)
Browse files Browse the repository at this point in the history
[See discussion here
](CQCL-DEV/brat#312)

I've added tests that we can use these operators as a function (i.e.
`[\/](x, y, z)`) and a bunch more tests, as well as some drive-by error
formatting
  • Loading branch information
croyzor authored Oct 14, 2024
1 parent e05a15e commit 025ee06
Show file tree
Hide file tree
Showing 34 changed files with 237 additions and 9 deletions.
79 changes: 73 additions & 6 deletions brat/Brat/Checker.hs
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@ checkInputs tm@(WC fc _) (o:overs) (u:unders) = localFC fc $ do
addRowContext _ as bs (Err fc (TypeMismatch tm _ _))
= Err fc $ TypeMismatch tm (showRow as) (showRow bs)
addRowContext _ _ _ e = e
checkInputs tm [] unders = typeErr $ "No overs but unders: " ++ show unders ++ " for " ++ show tm
checkInputs tm [] unders = typeErr $ "No overs but unders: " ++ showRow unders ++ " for " ++ show tm

checkOutputs :: (CheckConstraints m k, ?my :: Modey m)
=> WC (Term Syn k)
Expand All @@ -160,7 +160,7 @@ checkOutputs tm@(WC fc _) (u:unders) (o:overs) = localFC fc $ do
addRowContext _ as bs (Err fc (TypeMismatch tm _ _))
= Err fc $ TypeMismatch tm (showRow as) (showRow bs)
addRowContext _ _ _ e = e
checkOutputs tm [] overs = typeErr $ "No unders but overs: " ++ show overs ++ " for " ++ show tm
checkOutputs tm [] overs = typeErr $ "No unders but overs: " ++ showRow overs ++ " for " ++ show tm

checkThunk :: (CheckConstraints m UVerb, EvMode m)
=> Modey m
Expand All @@ -171,9 +171,13 @@ checkThunk :: (CheckConstraints m UVerb, EvMode m)
checkThunk m name cty tm = do
((dangling, _), ()) <- let ?my = m in makeBox name cty $
\(thOvers, thUnders) -> do
(((), ()), (emptyOvers, emptyUnders)) <- check tm (thOvers, thUnders)
ensureEmpty "thunk leftovers" emptyOvers
ensureEmpty "thunk leftunders" emptyUnders
(((), ()), leftovers) <- check tm (thOvers, thUnders)
case leftovers of
([], []) -> pure ()
([], unders) -> err (ThunkLeftUnders (showRow unders))
-- If there are leftovers and leftunders, complain about the leftovers
-- Until we can report multiple errors!
(overs, _) -> err (ThunkLeftOvers (showRow overs))
pure dangling

check :: (CheckConstraints m k
Expand Down Expand Up @@ -252,7 +256,7 @@ check' (Lambda c@(WC abstFC abst, body) cs) (overs, unders) = do
solve ?my >>=
(solToEnv . snd)
(((), synthOuts), ((), ())) <- localEnv env $ check body ((), ())
pure synthOuts
pure synthOuts

sig <- mkSig usedOvers synthOuts
patOuts <- checkClauses sig usedOvers ((fst c, WC (fcOf body) (Emb body)) :| cs)
Expand Down Expand Up @@ -485,6 +489,69 @@ check' (Simple tm) ((), ((hungry, ty):unders)) = do
R0 (RPr ("value", vty) R0)
wire (dangling, vty, hungry)
pure (((), ()), ((), unders))
check' FanOut ((p, ty):overs, ()) = do
ty <- eval S0 (binderToValue ?my ty)
case ty of
TVec elTy n
| VNum n <- n
, Just n <- numValIsConstant n ->
if n < 0
then err (InternalError $ "Vector of negative length (" ++ show n ++ ")")
else do
wires <- fanoutNodes ?my n (p, valueToBinder ?my ty) elTy
pure (((), wires), (overs, ()))
| otherwise -> typeErr $ "Can't fanout a Vec with non-constant length: " ++ show n
_ -> typeErr "Fanout ([/\\]) only applies to Vec"
where
fanoutNodes :: Modey m -> Integer -> (Src, BinderType m) -> Val Z -> Checking [(Src, BinderType m)]
fanoutNodes _ 0 _ _ = pure []
fanoutNodes my n (dangling, ty) elTy = do
(_, [(hungry, _)], [danglingHead, danglingTail], _) <- anext "fanoutNodes" (Selector (plain "cons")) (S0, Some (Zy :* S0))
(RPr ("value", binderToValue my ty) R0)
((RPr ("head", elTy) (RPr ("tail", TVec elTy (VNum (nConstant (n - 1)))) R0)) :: Ro m Z Z)
-- Wire the input into the selector node
wire (dangling, binderToValue my ty, hungry)
(danglingHead:) <$> fanoutNodes my (n - 1) danglingTail elTy

check' FanIn (overs, ((tgt, ty):unders)) = do
ty <- eval S0 (binderToValue ?my ty)
case ty of
TVec elTy n
| VNum n <- n
, Just n <- numValIsConstant n ->
if n < 0
then err (InternalError $ "Vector of negative length (" ++ show n ++ ")")
else faninNodes ?my n (tgt, valueToBinder ?my ty) elTy overs >>= \case
Just overs -> pure (((), ()), (overs, unders))
Nothing -> typeErr ("Not enough inputs to make a vector of size " ++ show n)
| otherwise -> typeErr $ "Can't fanout a Vec with non-constant length: " ++ show n
_ -> typeErr "Fanin ([\\/]) only applies to Vec"
where
faninNodes :: Modey m
-> Integer -- The number of things left to pack up
-> (Tgt, BinderType m) -- The place to wire the resulting vector to
-> Val Z -- Element type
-> [(Src, BinderType m)] -- Overs
-> Checking (Maybe [(Src, BinderType m)]) -- Leftovers
faninNodes my 0 (tgt, ty) elTy overs = do
(_, _, [(dangling, _)], _) <- anext "nil" (Constructor (plain "nil")) (S0, Some (Zy :* S0))
(R0 :: Ro m Z Z)
(RPr ("value", TVec elTy (VNum nZero)) R0)
wire (dangling, binderToValue my ty, tgt)
pure (Just overs)
faninNodes _ _ _ _ [] = pure Nothing
faninNodes my n (hungry, ty) elTy ((over, overTy):overs) = do
let k = case my of
Kerny -> Dollar []
Braty -> Star []
typeEq (show FanIn) k elTy (binderToValue my overTy)
let tailTy = TVec elTy (VNum (nConstant (n - 1)))
(_, [(hungryHead, _), (hungryTail, tailTy)], [(danglingResult, _)], _) <- anext "faninNodes" (Constructor (plain "cons")) (S0, Some (Zy :* S0))
((RPr ("head", elTy) (RPr ("tail", tailTy) R0)) :: Ro m Z Z)
(RPr ("value", binderToValue my ty) R0)
wire (over, elTy, hungryHead)
wire (danglingResult, binderToValue ?my ty, hungry)
faninNodes my (n - 1) (hungryTail, tailTy) elTy overs
check' Identity ((this:leftovers), ()) = pure (((), [this]), (leftovers, ()))
check' tm _ = error $ "check' " ++ show tm

Expand Down
2 changes: 2 additions & 0 deletions brat/Brat/Elaborator.hs
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,8 @@ elaborate' FIdentity = pure $ SomeRaw' RIdentity
-- We catch underscores in the top-level elaborate so this case
-- should never be triggered
elaborate' FUnderscore = Left (dumbErr (InternalError "Unexpected '_'"))
elaborate' FFanOut = pure $ SomeRaw' RFanOut
elaborate' FFanIn = pure $ SomeRaw' RFanIn

elabBody :: FBody -> FC -> Either Error (FunBody Raw Noun)
elabBody (FClauses cs) fc = ThunkOf . WC fc . Clauses <$> traverse elab1Clause cs
Expand Down
7 changes: 7 additions & 0 deletions brat/Brat/Error.hs
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,10 @@ data ErrorMsg
| WrongModeForType String
-- TODO: Add file context here
| CompilingHoles [String]
-- For thunks which don't address enough inputs, or produce enough outputs.
-- The argument is the row of unused connectors
| ThunkLeftOvers String
| ThunkLeftUnders String

instance Show ErrorMsg where
show (TypeErr x) = "Type error: " ++ x
Expand Down Expand Up @@ -164,6 +168,9 @@ instance Show ErrorMsg where
show (CompilingHoles hs) = unlines ("Can't compile file with remaining holes": indent hs)
where
indent = fmap (" " ++)

Check warning on line 170 in brat/Brat/Error.hs

View workflow job for this annotation

GitHub Actions / build

• The Monomorphism Restriction applies to the binding for ‘indent’

Check warning on line 170 in brat/Brat/Error.hs

View workflow job for this annotation

GitHub Actions / build

• The Monomorphism Restriction applies to the binding for ‘indent’

Check warning on line 170 in brat/Brat/Error.hs

View workflow job for this annotation

GitHub Actions / build

• The Monomorphism Restriction applies to the binding for ‘indent’
show (ThunkLeftOvers overs) = "Expected function to address all inputs, but " ++ overs ++ " wasn't used"
show (ThunkLeftUnders unders) = "Expected function to return additional values of type: " ++ unders


data Error = Err { fc :: Maybe FC
, msg :: ErrorMsg
Expand Down
5 changes: 4 additions & 1 deletion brat/Brat/Eval.hs
Original file line number Diff line number Diff line change
Expand Up @@ -256,7 +256,10 @@ eqWorker tm lvkz (TypeFor _ []) (SSum m0 stk0 rs0) (SSum m1 stk1 rs1)
Just rs -> traverse eqVariant rs <&> sequence_
where
eqVariant (Some r0, Some r1) = eqRowTest m0 tm lvkz (stk0,r0) (stk1,r1) <&> dropRight
eqWorker tm _ _ v0 v1 = pure . Left $ TypeMismatch tm (show v0) (show v1)
eqWorker tm _ _ s0 s1 = do
v0 <- quote Zy s0
v1 <- quote Zy s1
pure . Left $ TypeMismatch tm (show v0) (show v1)

-- Type rows have bot0,bot1 dangling de Bruijn indices, which we instantiate with
-- de Bruijn levels. As we go under binders in these rows, we add to the scope's
Expand Down
4 changes: 2 additions & 2 deletions brat/Brat/Lexer/Flat.hs
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ space = (many $ (satisfy isSpace >> return ()) <|> comment) >> return ()
comment = string "--" *> ((printChar `manyTill` lookAhead (void newline <|> void eof)) >> return ())

tok :: Lexer Tok
tok = ( try (char '(' $> LParen)
tok = try (char '(' $> LParen)
<|> try (char ')' $> RParen)
<|> try (char '{' $> LBrace)
<|> try (char '}' $> RBrace)
Expand All @@ -62,6 +62,7 @@ tok = ( try (char '(' $> LParen)
<|> try (Number <$> number)
<|> try (string "+" $> Plus)
<|> try (string "/" $> Slash)
<|> try (string "\\" $> Backslash)
<|> try (string "^" $> Caret)
<|> try (string "->") $> Arrow
<|> try (string "=>") $> FatArrow
Expand Down Expand Up @@ -89,7 +90,6 @@ tok = ( try (char '(' $> LParen)
<|> try (K <$> try keyword)
<|> try qualified
<|> Ident <$> ident
)
where
float :: Lexer Double
float = label "float literal" $ do
Expand Down
2 changes: 2 additions & 0 deletions brat/Brat/Lexer/Token.hs
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ data Tok
| Plus
| Minus
| Asterisk
| Backslash
| Slash
| Caret
| Hash
Expand Down Expand Up @@ -80,6 +81,7 @@ instance Show Tok where
show Plus = "+"
show Minus = "-"
show Asterisk = "*"
show Backslash = "\\"
show Slash = "/"
show Caret = "^"
show Hash = "#"
Expand Down
5 changes: 5 additions & 0 deletions brat/Brat/Parser.hs
Original file line number Diff line number Diff line change
Expand Up @@ -487,12 +487,17 @@ expr' p = choice $ (try . getParser <$> enumFrom p) ++ [atomExpr]
Nothing -> unWC expr
Just rest -> FJuxt expr rest

fanout = square (FFanOut <$ match Slash <* match Backslash)
fanin = square (FFanIn <$ match Backslash <* match Slash)

-- Expressions which don't contain juxtaposition or operators
atomExpr :: Parser Flat
atomExpr = simpleExpr <|> round expr
where
simpleExpr = FHole <$> hole
<|> try (FSimple <$> simpleTerm)
<|> try fanout
<|> try fanin
<|> vec
<|> cthunk
<|> try (match DotDot $> FPass)
Expand Down
2 changes: 2 additions & 0 deletions brat/Brat/Syntax/Concrete.hs
Original file line number Diff line number Diff line change
Expand Up @@ -43,5 +43,7 @@ data Flat
| FKernel RawKType
| FUnderscore
| FPass
| FFanOut
| FFanIn
| FIdentity
deriving Show
5 changes: 5 additions & 0 deletions brat/Brat/Syntax/Core.hs
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,8 @@ data Term :: Dir -> Kind -> Type where
C :: CType' (PortName, KindOr (Term Chk Noun)) -> Term Chk Noun
-- Kernel types
K :: CType' (PortName, Term Chk Noun) -> Term Chk Noun
FanOut :: Term Syn UVerb
FanIn :: Term Chk UVerb

deriving instance Eq (Term d k)

Expand Down Expand Up @@ -128,6 +130,9 @@ instance Show (Term d k) where

show (C f) = "{" ++ show f ++ "}"
show (K (ss :-> ts)) = "{" ++ showSig ss ++ " -o " ++ showSig ts ++ "}"
show FanOut = "[/\\]"
show FanIn = "[\\/]"


-- Wrap a term in brackets if its `precedence` is looser than `n`
bracket :: Precedence -> WC (Term d k) -> String
Expand Down
6 changes: 6 additions & 0 deletions brat/Brat/Syntax/Raw.hs
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,8 @@ data Raw :: Dir -> Kind -> Type where
RFn :: RawCType -> Raw Chk Noun
-- Kernel types
RKernel :: RawKType -> Raw Chk Noun
RFanOut :: Raw Syn UVerb
RFanIn :: Raw Chk UVerb

class Dirable d where
dir :: Raw d k -> Diry d
Expand Down Expand Up @@ -123,6 +125,8 @@ instance Show (Raw d k) where
show (RCon c xs) = "Con(" ++ show c ++ "(" ++ show xs ++ "))"
show (RFn cty) = show cty
show (RKernel cty) = show cty
show RFanOut = "[/\\]"
show RFanIn = "[\\/]"

type Desugar = StateT Namespace (ReaderT (RawEnv, Bwd UserName) (Except Error))

Expand Down Expand Up @@ -238,6 +242,8 @@ instance (Kindable k) => Desugarable (Raw d k) where
desugar' (RCon c arg) = Con c <$> desugar arg
desugar' (RFn cty) = C <$> desugar' cty
desugar' (RKernel cty) = K <$> desugar' cty
desugar' RFanOut = pure FanOut
desugar' RFanIn = pure FanIn

instance Desugarable ty => Desugarable (PortName, ty) where
type Desugared (PortName, ty) = (PortName, Desugared ty)
Expand Down
8 changes: 8 additions & 0 deletions brat/Brat/Syntax/Value.hs
Original file line number Diff line number Diff line change
Expand Up @@ -592,3 +592,11 @@ copyable (VApp _ _) = Just False
copyable (TVec elem _) = copyable elem
copyable TBit = Just True
copyable _ = Nothing

stkLen :: Stack Z t tot -> Ny tot
stkLen S0 = Zy
stkLen (zx :<< _) = Sy (stkLen zx)

numValIsConstant :: NumVal (VVar Z) -> Maybe Integer
numValIsConstant (NumValue up Constant0) = pure up
numValIsConstant _ = Nothing
4 changes: 4 additions & 0 deletions brat/Brat/Unelaborator.hs
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,8 @@ unelab dy _ (Lambda (abs,rhs) cs) = FLambda ((abs, unelab dy Nouny <$> rhs) :| (
unelab _ _ (Con c args) = FCon c (unelab Chky Nouny <$> args)
unelab _ _ (C (ss :-> ts)) = FFn (toRawRo ss :-> toRawRo ts)
unelab _ _ (K cty) = FKernel $ fmap (\(p, ty) -> Named p (toRaw ty)) cty
unelab _ _ FanIn = FFanIn
unelab _ _ FanOut = FFanOut

-- This is needed for concrete terms which embed a type as a list of `Raw` things
toRaw :: Term d k -> Raw d k
Expand All @@ -61,6 +63,8 @@ toRaw (Lambda (abs,rhs) cs) = RLambda (abs, toRaw <$> rhs) (second (fmap toRaw)
toRaw (Con c args) = RCon c (toRaw <$> args)
toRaw (C (ss :-> ts)) = RFn (toRawRo ss :-> toRawRo ts)
toRaw (K cty) = RKernel $ (\(p, ty) -> Named p (toRaw ty)) <$> cty
toRaw FanIn = RFanIn
toRaw FanOut = RFanOut

toRawRo :: [(PortName, KindOr (Term Chk Noun))] -> [TypeRowElem (KindOr RawVType)]
toRawRo = fmap (\(p, bty) -> Named p (second toRaw bty))
30 changes: 30 additions & 0 deletions brat/examples/fanout.brat
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
open import lib.kernel (CX)

fanout(Vec(Nat, 3)) -> Nat, Nat, Nat
fanout = { [/\] }

swap(X :: $, Y :: $) -> { X, Y -o Y, X }
swap(_, _) = { x, y => y, x }

test(Vec(Qubit, 2)) -o Vec(Qubit, 2)
test = { [/\]; CX; [\/] }

as_fn :: Vec(Nat, 3)
as_fn = [\/](fanout([1,2,3]))

swap_vec(Vec(Qubit, 2)) -o Vec(Qubit, 2)
swap_vec(qs) = [\/](swap(Qubit, Qubit)([/\](qs)))

pack_mid(Nat, Nat, Nat, Nat) -> Nat, Vec(Nat, 2), Nat
pack_mid = { (x=>x), [\/], .. }

ext "" f :: { Vec(Qubit, 2) -o Bit }

g(Qubit, Qubit) -o Bit
g = [\/]; f

poly(X :: *) -> { Vec(X, 3) -> X, X, X }
poly(_) = { [/\] }

poly2(X :: $) -> { X, X, X -o Vec(X, 3) }
poly2(_) = { [\/] }
1 change: 1 addition & 0 deletions brat/test/Test/Compile/Hugr.hs
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ nonCompilingExamples = (expectedCheckingFails ++ expectedParsingFails ++
,"qft"
,"test"
,"vector"
,"fanout" -- Contains Selectors
-- Conjecture: These examples don't compile because number patterns in type
-- signatures causes `kindCheck` to call `abstract`, creating "Selector"
-- nodes, which we don't attempt to compile because we want to get rid of them
Expand Down
2 changes: 2 additions & 0 deletions brat/test/golden/error/fanin-diff-types.brat
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
f(Qubit, Bit) -o Vec(Qubit, 2)
f = { [\/] }
9 changes: 9 additions & 0 deletions brat/test/golden/error/fanin-diff-types.brat.golden
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
Error in test/golden/error/fanin-diff-types.brat@FC {start = Pos {line = 2, col = 5}, end = Pos {line = 2, col = 13}}:
f = { [\/] }
^^^^^^^^

Type mismatch when checking [\/]
Expected: Qubit
But got: Bit


2 changes: 2 additions & 0 deletions brat/test/golden/error/fanin-dynamic-length.brat
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
f(n :: #) -> { Qubit, Qubit -o Vec(Qubit, n) }
f(n) = { [\/] }
6 changes: 6 additions & 0 deletions brat/test/golden/error/fanin-dynamic-length.brat.golden
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
Error in test/golden/error/fanin-dynamic-length.brat@FC {start = Pos {line = 2, col = 8}, end = Pos {line = 2, col = 16}}:
f(n) = { [\/] }
^^^^^^^^

Type error: Can't fanout a Vec with non-constant length: VPar Ex checking_check_defs_1_f_f.box_2_lambda_fake_source 0

2 changes: 2 additions & 0 deletions brat/test/golden/error/fanin-list.brat
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
f(Bool, Bool) -> List(Bool)
f = { [\/] }
6 changes: 6 additions & 0 deletions brat/test/golden/error/fanin-list.brat.golden
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
Error in test/golden/error/fanin-list.brat@FC {start = Pos {line = 2, col = 5}, end = Pos {line = 2, col = 13}}:
f = { [\/] }
^^^^^^^^

Type error: Fanin ([\/]) only applies to Vec

2 changes: 2 additions & 0 deletions brat/test/golden/error/fanin-not-enough-overs.brat
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
f(Nat, Nat) -> Vec(Nat, 3)
f = { [\/] }
6 changes: 6 additions & 0 deletions brat/test/golden/error/fanin-not-enough-overs.brat.golden
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
Error in test/golden/error/fanin-not-enough-overs.brat@FC {start = Pos {line = 2, col = 5}, end = Pos {line = 2, col = 13}}:
f = { [\/] }
^^^^^^^^

Type error: Not enough inputs to make a vector of size 3

2 changes: 2 additions & 0 deletions brat/test/golden/error/fanin-too-many-overs.brat
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
f(x :: Nat, y :: Nat, z :: Nat) -> Vec(Nat, 2)
f = { [\/] }
6 changes: 6 additions & 0 deletions brat/test/golden/error/fanin-too-many-overs.brat.golden
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
Error in test/golden/error/fanin-too-many-overs.brat@FC {start = Pos {line = 2, col = 5}, end = Pos {line = 2, col = 13}}:
f = { [\/] }
^^^^^^^^

Expected function to address all inputs, but (z :: Nat) wasn't used

2 changes: 2 additions & 0 deletions brat/test/golden/error/fanout-diff-types.brat
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
f(Vec(Qubit, 2)) -o Qubit, Bit
f = { [/\] }
Loading

0 comments on commit 025ee06

Please sign in to comment.