Skip to content

Commit

Permalink
Arithmetic simplification (#2454)
Browse files Browse the repository at this point in the history
Simplifies arithmetic expressions in the Core optimization phase,
changing e.g. `(x - 1) + 1` to `x`. Such expressions appear as a result
of compiling pattern matching on natural numbers.
  • Loading branch information
lukaszcz authored Oct 23, 2023
1 parent c1c2a06 commit cdfb35a
Show file tree
Hide file tree
Showing 5 changed files with 89 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import Juvix.Compiler.Core.Transformation.Optimize.FilterUnreachable
import Juvix.Compiler.Core.Transformation.Optimize.Inlining
import Juvix.Compiler.Core.Transformation.Optimize.LambdaFolding
import Juvix.Compiler.Core.Transformation.Optimize.LetFolding
import Juvix.Compiler.Core.Transformation.Optimize.SimplifyArithmetic
import Juvix.Compiler.Core.Transformation.Optimize.SimplifyComparisons
import Juvix.Compiler.Core.Transformation.Optimize.SimplifyIfs
import Juvix.Compiler.Core.Transformation.Optimize.SpecializeArgs
Expand Down Expand Up @@ -53,7 +54,8 @@ optimize' CoreOptions {..} tab =

doSimplification :: Int -> InfoTable -> InfoTable
doSimplification n =
simplifyIfs' (_optOptimizationLevel <= 1)
simplifyArithmetic
. simplifyIfs' (_optOptimizationLevel <= 1)
. simplifyComparisons
. caseFolding
. casePermutation
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
module Juvix.Compiler.Core.Transformation.Optimize.SimplifyArithmetic (simplifyArithmetic) where

import Juvix.Compiler.Core.Extra
import Juvix.Compiler.Core.Transformation.Base

convertNode :: Node -> Node
convertNode = dmap go
where
go :: Node -> Node
go node = case node of
NBlt BuiltinApp {..}
| _builtinAppOp == OpIntAdd,
[NBlt blt', n] <- _builtinAppArgs,
blt' ^. builtinAppOp == OpIntSub,
[x, m] <- blt' ^. builtinAppArgs,
m == n ->
x
NBlt BuiltinApp {..}
| _builtinAppOp == OpIntSub,
[NBlt blt', n] <- _builtinAppArgs,
blt' ^. builtinAppOp == OpIntAdd,
[x, m] <- blt' ^. builtinAppArgs ->
if
| m == n ->
x
| x == n ->
m
| otherwise ->
node
NBlt BuiltinApp {..}
| _builtinAppOp == OpIntAdd,
[n, NBlt blt'] <- _builtinAppArgs,
blt' ^. builtinAppOp == OpIntSub,
[x, m] <- blt' ^. builtinAppArgs,
m == n ->
x
NBlt BuiltinApp {..}
| _builtinAppOp == OpIntAdd || _builtinAppOp == OpIntSub,
[x, NCst (Constant _ (ConstInteger 0))] <- _builtinAppArgs ->
x
NBlt BuiltinApp {..}
| _builtinAppOp == OpIntAdd,
[NCst (Constant _ (ConstInteger 0)), x] <- _builtinAppArgs ->
x
NBlt BuiltinApp {..}
| _builtinAppOp == OpIntMul,
[_, c@(NCst (Constant _ (ConstInteger 0)))] <- _builtinAppArgs ->
c
NBlt BuiltinApp {..}
| _builtinAppOp == OpIntMul,
[c@(NCst (Constant _ (ConstInteger 0))), _] <- _builtinAppArgs ->
c
NBlt BuiltinApp {..}
| _builtinAppOp == OpIntMul,
[x, NCst (Constant _ (ConstInteger 1))] <- _builtinAppArgs ->
x
NBlt BuiltinApp {..}
| _builtinAppOp == OpIntMul,
[NCst (Constant _ (ConstInteger 1)), x] <- _builtinAppArgs ->
x
_ -> node

simplifyArithmetic :: InfoTable -> InfoTable
simplifyArithmetic = mapAllNodes convertNode
7 changes: 6 additions & 1 deletion test/Compilation/Positive.hs
Original file line number Diff line number Diff line change
Expand Up @@ -382,5 +382,10 @@ tests =
"Test064: Constant folding"
$(mkRelDir ".")
$(mkRelFile "test064.juvix")
$(mkRelFile "out/test064.out")
$(mkRelFile "out/test064.out"),
posTest
"Test065: Arithmetic simplification"
$(mkRelDir ".")
$(mkRelFile "test065.juvix")
$(mkRelFile "out/test065.out")
]
1 change: 1 addition & 0 deletions tests/Compilation/positive/out/test065.out
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
42
15 changes: 15 additions & 0 deletions tests/Compilation/positive/test065.juvix
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
-- Arithmetic simplification
module test065;

import Stdlib.Prelude open;

{-# inline: false #-}
f (x : Int) : Int :=
(x + fromNat 1 - fromNat 1) * fromNat 1
+ fromNat 0 * x
+ (fromNat 10 + (x - fromNat 10))
+ (fromNat 10 + x - fromNat 10)
+ (fromNat 11 + (fromNat 11 - x))
+ fromNat 1 * x * fromNat 0 * fromNat 1;

main : Int := f (fromNat 10);

0 comments on commit cdfb35a

Please sign in to comment.