Skip to content

Commit

Permalink
specialize: bool
Browse files Browse the repository at this point in the history
  • Loading branch information
lukaszcz committed Oct 3, 2023
1 parent b646b21 commit c568503
Show file tree
Hide file tree
Showing 2 changed files with 48 additions and 15 deletions.
40 changes: 26 additions & 14 deletions src/Juvix/Compiler/Core/Transformation/Optimize/SpecializeArgs.hs
Original file line number Diff line number Diff line change
Expand Up @@ -28,16 +28,22 @@ isSpecializable tab node =
-- | Check for `h a1 .. an` where `h` is an identifier explicitly marked for
-- specialisation with `specialize: true`.
isMarkedSpecializable :: InfoTable -> Node -> Bool
isMarkedSpecializable tab node =
let (h, _) = unfoldApps' node
in case h of
NIdt Ident {..}
| Just (PragmaSpecialise True) <-
lookupIdentifierInfo tab _identSymbol
^. identifierPragmas . pragmasSpecialise ->
True
_ ->
False
isMarkedSpecializable tab = \case
NTyp TypeConstr {..}
| Just (PragmaSpecialise True) <-
lookupInductiveInfo tab _typeConstrSymbol
^. inductivePragmas . pragmasSpecialise ->
True
node ->
let (h, _) = unfoldApps' node
in case h of
NIdt Ident {..}
| Just (PragmaSpecialise True) <-
lookupIdentifierInfo tab _identSymbol
^. identifierPragmas . pragmasSpecialise ->
True
_ ->
False

-- | Checks if an argument is passed without modification to recursive calls.
isArgSpecializable :: InfoTable -> Symbol -> Int -> Bool
Expand Down Expand Up @@ -87,6 +93,7 @@ convertNode = dmapLRM go

goIdentApp :: BinderList Binder -> Ident -> [Node] -> Sem r Recur
goIdentApp bl idt@Ident {..} args = do
args' <- mapM (dmapLRM' (bl, go)) args
tab <- getInfoTable
let ii = lookupIdentifierInfo tab _identSymbol
pspec = ii ^. identifierPragmas . pragmasSpecialiseArgs
Expand All @@ -97,16 +104,21 @@ convertNode = dmapLRM go
(lams, body) = unfoldLambdas def
argnames = map (^. lambdaLhsBinder . binderName) lams

-- arguments marked for specialisation with `specialize: true`
psargs0 =
map fst3 $
filter (\(_, arg, ty) -> isMarkedSpecializable tab arg || isMarkedSpecializable tab ty) $
zip3 [1 .. argsNum] args' tyargs

getArgIndex :: PragmaSpecialiseArg -> Maybe Int
getArgIndex = \case
SpecialiseArgNum i -> Just i
SpecialiseArgNamed x -> fmap (+ 1) $ x `elemIndex` argnames
if
| (isJust pspec || isJust pspecby) && length args == argsNum -> do
args' <- mapM (dmapLRM' (bl, go)) args
| (isJust pspec || isJust pspecby || not (null psargs0)) && length args == argsNum -> do
let psargs1 = mapMaybe getArgIndex $ maybe [] (^. pragmaSpecialiseArgs) pspec
psargs2 = maybe [] (map (+ 1) . mapMaybe (`elemIndex` argnames) . (^. pragmaSpecialiseBy)) pspecby
psargs = nubSort (psargs1 ++ psargs2)
psargs = nubSort (psargs0 ++ psargs1 ++ psargs2)
-- assumption: all type variables are at the front
let specargs0 =
filter
Expand Down Expand Up @@ -223,7 +235,7 @@ convertNode = dmapLRM go
node'' <- lambdaLiftNode' True bl node'
return $ End node''
| otherwise ->
return $ Recur $ mkApps' (NIdt idt) args
return $ End $ mkApps' (NIdt idt) args'

-- assumption: all type arguments are substituted, so no binders in the type
-- list refer to other elements in the list
Expand Down
23 changes: 22 additions & 1 deletion tests/Compilation/positive/test056.juvix
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,25 @@ funa : {A : Type} -> (A -> A) -> A -> A
| (suc n) := f (go n);
in go 10;

{-# specialize: true #-}
type Additive A := mkAdditive {add : A -> A -> A};

type Multiplicative A :=
mkMultiplicative {mul : A -> A -> A};

addNat : Additive Nat := mkAdditive (+);

{-# specialize: true #-}
mulNat : Multiplicative Nat := mkMultiplicative (*);

{-# inline: false #-}
fadd {A} (a : Additive A) (x y : A) : A :=
Additive.add a x y;

{-# inline: false #-}
fmul {A} (m : Multiplicative A) (x y : A) : A :=
Multiplicative.mul m x y;

main : Nat :=
sum (mymap λ {x := x + 3} (1 :: 2 :: 3 :: 4 :: nil))
+ sum
Expand All @@ -47,4 +66,6 @@ main : Nat :=
+ myf 3 (*) 2 5 true
+ myf 1 (+) 2 0 false
+ myf' 7 (const (+)) 2 0
+ funa ((+) 1) 5;
+ funa ((+) 1) 5
+ fadd addNat 1 2
+ fmul mulNat 1 2;

0 comments on commit c568503

Please sign in to comment.