diff --git a/src/Juvix/Compiler/Core/Transformation/Optimize/SpecializeArgs.hs b/src/Juvix/Compiler/Core/Transformation/Optimize/SpecializeArgs.hs index 7dc3d0d585..f5ce4c00a5 100644 --- a/src/Juvix/Compiler/Core/Transformation/Optimize/SpecializeArgs.hs +++ b/src/Juvix/Compiler/Core/Transformation/Optimize/SpecializeArgs.hs @@ -28,16 +28,22 @@ isSpecializable tab node = -- | Check for `h a1 .. an` where `h` is an identifier explicitly marked for -- specialisation with `specialize: true`. isMarkedSpecializable :: InfoTable -> Node -> Bool -isMarkedSpecializable tab node = - let (h, _) = unfoldApps' node - in case h of - NIdt Ident {..} - | Just (PragmaSpecialise True) <- - lookupIdentifierInfo tab _identSymbol - ^. identifierPragmas . pragmasSpecialise -> - True - _ -> - False +isMarkedSpecializable tab = \case + NTyp TypeConstr {..} + | Just (PragmaSpecialise True) <- + lookupInductiveInfo tab _typeConstrSymbol + ^. inductivePragmas . pragmasSpecialise -> + True + node -> + let (h, _) = unfoldApps' node + in case h of + NIdt Ident {..} + | Just (PragmaSpecialise True) <- + lookupIdentifierInfo tab _identSymbol + ^. identifierPragmas . pragmasSpecialise -> + True + _ -> + False -- | Checks if an argument is passed without modification to recursive calls. isArgSpecializable :: InfoTable -> Symbol -> Int -> Bool @@ -87,6 +93,7 @@ convertNode = dmapLRM go goIdentApp :: BinderList Binder -> Ident -> [Node] -> Sem r Recur goIdentApp bl idt@Ident {..} args = do + args' <- mapM (dmapLRM' (bl, go)) args tab <- getInfoTable let ii = lookupIdentifierInfo tab _identSymbol pspec = ii ^. identifierPragmas . pragmasSpecialiseArgs @@ -97,16 +104,21 @@ convertNode = dmapLRM go (lams, body) = unfoldLambdas def argnames = map (^. lambdaLhsBinder . binderName) lams + -- arguments marked for specialisation with `specialize: true` + psargs0 = + map fst3 $ + filter (\(_, arg, ty) -> isMarkedSpecializable tab arg || isMarkedSpecializable tab ty) $ + zip3 [1 .. argsNum] args' tyargs + getArgIndex :: PragmaSpecialiseArg -> Maybe Int getArgIndex = \case SpecialiseArgNum i -> Just i SpecialiseArgNamed x -> fmap (+ 1) $ x `elemIndex` argnames if - | (isJust pspec || isJust pspecby) && length args == argsNum -> do - args' <- mapM (dmapLRM' (bl, go)) args + | (isJust pspec || isJust pspecby || not (null psargs0)) && length args == argsNum -> do let psargs1 = mapMaybe getArgIndex $ maybe [] (^. pragmaSpecialiseArgs) pspec psargs2 = maybe [] (map (+ 1) . mapMaybe (`elemIndex` argnames) . (^. pragmaSpecialiseBy)) pspecby - psargs = nubSort (psargs1 ++ psargs2) + psargs = nubSort (psargs0 ++ psargs1 ++ psargs2) -- assumption: all type variables are at the front let specargs0 = filter @@ -223,7 +235,7 @@ convertNode = dmapLRM go node'' <- lambdaLiftNode' True bl node' return $ End node'' | otherwise -> - return $ Recur $ mkApps' (NIdt idt) args + return $ End $ mkApps' (NIdt idt) args' -- assumption: all type arguments are substituted, so no binders in the type -- list refer to other elements in the list diff --git a/tests/Compilation/positive/test056.juvix b/tests/Compilation/positive/test056.juvix index 2bb06fc7c1..2981ae3b5d 100644 --- a/tests/Compilation/positive/test056.juvix +++ b/tests/Compilation/positive/test056.juvix @@ -37,6 +37,25 @@ funa : {A : Type} -> (A -> A) -> A -> A | (suc n) := f (go n); in go 10; +{-# specialize: true #-} +type Additive A := mkAdditive {add : A -> A -> A}; + +type Multiplicative A := + mkMultiplicative {mul : A -> A -> A}; + +addNat : Additive Nat := mkAdditive (+); + +{-# specialize: true #-} +mulNat : Multiplicative Nat := mkMultiplicative (*); + +{-# inline: false #-} +fadd {A} (a : Additive A) (x y : A) : A := + Additive.add a x y; + +{-# inline: false #-} +fmul {A} (m : Multiplicative A) (x y : A) : A := + Multiplicative.mul m x y; + main : Nat := sum (mymap λ {x := x + 3} (1 :: 2 :: 3 :: 4 :: nil)) + sum @@ -47,4 +66,6 @@ main : Nat := + myf 3 (*) 2 5 true + myf 1 (+) 2 0 false + myf' 7 (const (+)) 2 0 - + funa ((+) 1) 5; + + funa ((+) 1) 5 + + fadd addNat 1 2 + + fmul mulNat 1 2;