Skip to content

Commit

Permalink
Fix JuvixTree type unification (#2972)
Browse files Browse the repository at this point in the history
* Closes #2954 
* The problem was that the type validation algorithm was too strict for
higher-order functions with a dynamic (unknown) target.
  • Loading branch information
lukaszcz authored Aug 27, 2024
1 parent 9c980d1 commit eb5b2e4
Show file tree
Hide file tree
Showing 6 changed files with 174 additions and 123 deletions.
2 changes: 1 addition & 1 deletion src/Juvix/Compiler/Asm/Extra/Memory.hs
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,7 @@ checkValueStack' loc tab tys mem = do
mapM_
( \(ty, idx) -> do
let ty' = fromJust $ topValueStack idx mem
unless (isSubtype' ty' ty) $
unless (isSubtype ty' ty) $
throw $
AsmError loc $
"type mismatch on value stack cell "
Expand Down
238 changes: 119 additions & 119 deletions src/Juvix/Compiler/Tree/Extra/Type.hs
Original file line number Diff line number Diff line change
Expand Up @@ -39,127 +39,127 @@ curryType ty = case typeArgs ty of
in foldr (\tyarg ty'' -> mkTypeFun [tyarg] ty'') (typeTarget ty') tyargs

isSubtype :: Type -> Type -> Bool
isSubtype ty1 ty2 = case (ty1, ty2) of
(TyDynamic, _) -> True
(_, TyDynamic) -> True
(TyConstr TypeConstr {..}, TyInductive TypeInductive {..}) ->
_typeConstrInductive == _typeInductiveSymbol
(TyConstr c1, TyConstr c2) ->
c1 ^. typeConstrInductive == c2 ^. typeConstrInductive
&& c1 ^. typeConstrTag == c2 ^. typeConstrTag
&& all (uncurry isSubtype) (zip (c1 ^. typeConstrFields) (c2 ^. typeConstrFields))
(TyFun t1, TyFun t2) ->
let l1 = toList (t1 ^. typeFunArgs)
l2 = toList (t2 ^. typeFunArgs)
r1 = t1 ^. typeFunTarget
r2 = t2 ^. typeFunTarget
in length l1 == length l2 && all (uncurry isSubtype) (zip l2 l1) && isSubtype r1 r2
(TyInteger (TypeInteger l1 u1), TyInteger (TypeInteger l2 u2)) ->
checkBounds (>=) l1 l2 && checkBounds (<=) u1 u2
where
checkBounds :: (Integer -> Integer -> Bool) -> Maybe Integer -> Maybe Integer -> Bool
checkBounds _ Nothing Nothing = True
checkBounds _ Nothing (Just _) = False
checkBounds _ (Just _) Nothing = True
checkBounds cmp (Just x) (Just y) = cmp x y
(TyBool {}, TyBool {}) -> True
(TyString, TyString) -> True
(TyField, TyField) -> True
(TyByteArray, TyByteArray) -> True
(TyUnit, TyUnit) -> True
(TyVoid, TyVoid) -> True
(TyInductive {}, TyInductive {}) -> ty1 == ty2
(TyUnit, _) -> False
(_, TyUnit) -> False
(TyVoid, _) -> False
(_, TyVoid) -> False
(TyInteger {}, _) -> False
(_, TyInteger {}) -> False
(TyString, _) -> False
(_, TyString) -> False
(TyField, _) -> False
(_, TyField) -> False
(TyByteArray, _) -> False
(_, TyByteArray) -> False
(TyBool {}, _) -> False
(_, TyBool {}) -> False
(TyFun {}, _) -> False
(_, TyFun {}) -> False
(_, TyConstr {}) -> False

isSubtype' :: Type -> Type -> Bool
isSubtype' ty1 ty2
-- The guard is to ensure correct behaviour with dynamic type targets. E.g.
-- `A -> B -> C -> D` should be a subtype of `(A, B) -> *`.
| tgt1 == TyDynamic || tgt2 == TyDynamic =
isSubtype
(curryType ty1)
(curryType ty2)
where
tgt1 = typeTarget (uncurryType ty1)
tgt2 = typeTarget (uncurryType ty2)
isSubtype' ty1 ty2 =
isSubtype ty1 ty2
isSubtype ty1 ty2 =
let (ty1', ty2') =
if
| typeTarget (uncurryType ty1) == TyDynamic || typeTarget (uncurryType ty2) == TyDynamic ->
(curryType ty1, curryType ty2)
| otherwise ->
(ty1, ty2)
in case (ty1', ty2') of
(TyDynamic, _) -> True
(_, TyDynamic) -> True
(TyConstr TypeConstr {..}, TyInductive TypeInductive {..}) ->
_typeConstrInductive == _typeInductiveSymbol
(TyConstr c1, TyConstr c2) ->
c1 ^. typeConstrInductive == c2 ^. typeConstrInductive
&& c1 ^. typeConstrTag == c2 ^. typeConstrTag
&& all (uncurry isSubtype) (zip (c1 ^. typeConstrFields) (c2 ^. typeConstrFields))
(TyFun t1, TyFun t2) ->
let l1 = toList (t1 ^. typeFunArgs)
l2 = toList (t2 ^. typeFunArgs)
r1 = t1 ^. typeFunTarget
r2 = t2 ^. typeFunTarget
in length l1 == length l2 && all (uncurry isSubtype) (zip l2 l1) && isSubtype r1 r2
(TyInteger (TypeInteger l1 u1), TyInteger (TypeInteger l2 u2)) ->
checkBounds (>=) l1 l2 && checkBounds (<=) u1 u2
where
checkBounds :: (Integer -> Integer -> Bool) -> Maybe Integer -> Maybe Integer -> Bool
checkBounds _ Nothing Nothing = True
checkBounds _ Nothing (Just _) = False
checkBounds _ (Just _) Nothing = True
checkBounds cmp (Just x) (Just y) = cmp x y
(TyBool {}, TyBool {}) -> True
(TyString, TyString) -> True
(TyField, TyField) -> True
(TyByteArray, TyByteArray) -> True
(TyUnit, TyUnit) -> True
(TyVoid, TyVoid) -> True
(TyInductive {}, TyInductive {}) -> ty1 == ty2
(TyUnit, _) -> False
(_, TyUnit) -> False
(TyVoid, _) -> False
(_, TyVoid) -> False
(TyInteger {}, _) -> False
(_, TyInteger {}) -> False
(TyString, _) -> False
(_, TyString) -> False
(TyField, _) -> False
(_, TyField) -> False
(TyByteArray, _) -> False
(_, TyByteArray) -> False
(TyBool {}, _) -> False
(_, TyBool {}) -> False
(TyFun {}, _) -> False
(_, TyFun {}) -> False
(_, TyConstr {}) -> False

unifyTypes :: forall t e r. (Members '[Error TreeError, Reader (Maybe Location), Reader (InfoTable' t e)] r) => Type -> Type -> Sem r Type
unifyTypes ty1 ty2 = case (ty1, ty2) of
(TyDynamic, x) -> return x
(x, TyDynamic) -> return x
(TyInductive TypeInductive {..}, TyConstr TypeConstr {..})
| _typeInductiveSymbol == _typeConstrInductive ->
return ty1
(TyConstr {}, TyInductive {}) -> unifyTypes @t @e ty2 ty1
(TyConstr c1, TyConstr c2)
| c1 ^. typeConstrInductive == c2 ^. typeConstrInductive
&& c1 ^. typeConstrTag == c2 ^. typeConstrTag -> do
flds <- zipWithM (unifyTypes @t @e) (c1 ^. typeConstrFields) (c2 ^. typeConstrFields)
return $ TyConstr (set typeConstrFields flds c1)
(TyConstr c1, TyConstr c2)
| c1 ^. typeConstrInductive == c2 ^. typeConstrInductive ->
return $ TyInductive (TypeInductive (c1 ^. typeConstrInductive))
(TyFun t1, TyFun t2)
| length (t1 ^. typeFunArgs) == length (t2 ^. typeFunArgs) -> do
let args1 = toList (t1 ^. typeFunArgs)
args2 = toList (t2 ^. typeFunArgs)
tgt1 = t1 ^. typeFunTarget
tgt2 = t2 ^. typeFunTarget
args <- zipWithM (unifyTypes @t @e) args1 args2
tgt <- unifyTypes @t @e tgt1 tgt2
return $ TyFun (TypeFun (nonEmpty' args) tgt)
(TyInteger (TypeInteger l1 u1), TyInteger (TypeInteger l2 u2)) ->
return $ TyInteger (TypeInteger (unifyBounds min l1 l2) (unifyBounds max u1 u2))
where
unifyBounds :: (Integer -> Integer -> Integer) -> Maybe Integer -> Maybe Integer -> Maybe Integer
unifyBounds _ Nothing _ = Nothing
unifyBounds _ _ Nothing = Nothing
unifyBounds f (Just x) (Just y) = Just (f x y)
(TyBool {}, TyBool {})
| ty1 == ty2 -> return ty1
(TyString, TyString) -> return TyString
(TyField, TyField) -> return TyField
(TyByteArray, TyByteArray) -> return TyByteArray
(TyUnit, TyUnit) -> return TyUnit
(TyVoid, TyVoid) -> return TyVoid
(TyInductive {}, TyInductive {})
| ty1 == ty2 -> return ty1
(TyUnit, _) -> err
(_, TyUnit) -> err
(TyVoid, _) -> err
(_, TyVoid) -> err
(TyInteger {}, _) -> err
(_, TyInteger {}) -> err
(TyString, _) -> err
(_, TyString) -> err
(TyField, _) -> err
(_, TyField) -> err
(TyByteArray, _) -> err
(_, TyByteArray) -> err
(TyBool {}, _) -> err
(_, TyBool {}) -> err
(TyFun {}, _) -> err
(_, TyFun {}) -> err
(TyInductive {}, _) -> err
(_, TyConstr {}) -> err
unifyTypes ty1 ty2 =
let (ty1', ty2') =
if
| typeTarget (uncurryType ty1) == TyDynamic || typeTarget (uncurryType ty2) == TyDynamic ->
(curryType ty1, curryType ty2)
| otherwise ->
(ty1, ty2)
in case (ty1', ty2') of
(TyDynamic, x) -> return x
(x, TyDynamic) -> return x
(TyInductive TypeInductive {..}, TyConstr TypeConstr {..})
| _typeInductiveSymbol == _typeConstrInductive ->
return ty1
(TyConstr {}, TyInductive {}) -> unifyTypes @t @e ty2 ty1
(TyConstr c1, TyConstr c2)
| c1 ^. typeConstrInductive == c2 ^. typeConstrInductive
&& c1 ^. typeConstrTag == c2 ^. typeConstrTag -> do
flds <- zipWithM (unifyTypes @t @e) (c1 ^. typeConstrFields) (c2 ^. typeConstrFields)
return $ TyConstr (set typeConstrFields flds c1)
(TyConstr c1, TyConstr c2)
| c1 ^. typeConstrInductive == c2 ^. typeConstrInductive ->
return $ TyInductive (TypeInductive (c1 ^. typeConstrInductive))
(TyFun t1, TyFun t2)
| length (t1 ^. typeFunArgs) == length (t2 ^. typeFunArgs) -> do
let args1 = toList (t1 ^. typeFunArgs)
args2 = toList (t2 ^. typeFunArgs)
tgt1 = t1 ^. typeFunTarget
tgt2 = t2 ^. typeFunTarget
args <- zipWithM (unifyTypes @t @e) args1 args2
tgt <- unifyTypes @t @e tgt1 tgt2
return $ TyFun (TypeFun (nonEmpty' args) tgt)
(TyInteger (TypeInteger l1 u1), TyInteger (TypeInteger l2 u2)) ->
return $ TyInteger (TypeInteger (unifyBounds min l1 l2) (unifyBounds max u1 u2))
where
unifyBounds :: (Integer -> Integer -> Integer) -> Maybe Integer -> Maybe Integer -> Maybe Integer
unifyBounds _ Nothing _ = Nothing
unifyBounds _ _ Nothing = Nothing
unifyBounds f (Just x) (Just y) = Just (f x y)
(TyBool {}, TyBool {})
| ty1 == ty2 -> return ty1
(TyString, TyString) -> return TyString
(TyField, TyField) -> return TyField
(TyByteArray, TyByteArray) -> return TyByteArray
(TyUnit, TyUnit) -> return TyUnit
(TyVoid, TyVoid) -> return TyVoid
(TyInductive {}, TyInductive {})
| ty1 == ty2 -> return ty1
(TyUnit, _) -> err
(_, TyUnit) -> err
(TyVoid, _) -> err
(_, TyVoid) -> err
(TyInteger {}, _) -> err
(_, TyInteger {}) -> err
(TyString, _) -> err
(_, TyString) -> err
(TyField, _) -> err
(_, TyField) -> err
(TyByteArray, _) -> err
(_, TyByteArray) -> err
(TyBool {}, _) -> err
(_, TyBool {}) -> err
(TyFun {}, _) -> err
(_, TyFun {}) -> err
(TyInductive {}, _) -> err
(_, TyConstr {}) -> err
where
err :: Sem r a
err = do
Expand Down
8 changes: 6 additions & 2 deletions test/Tree/Asm/Base.hs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ module Tree.Asm.Base where
import Asm.Run.Base qualified as Asm
import Base
import Juvix.Compiler.Asm.Translation.FromTree qualified as Asm
import Juvix.Compiler.Tree.Pipeline qualified as Tree
import Juvix.Compiler.Tree.Translation.FromSource
import Juvix.Data.PPOutput

Expand All @@ -18,5 +19,8 @@ treeAsmAssertion mainFile expectedFile step = do
Left err -> assertFailure (prettyString err)
Right tabIni -> do
step "Translate"
let tab = Asm.fromTree tabIni
Asm.asmRunAssertion' tab expectedFile step
case run $ runError @JuvixError $ Tree.toAsm tabIni of
Left err -> assertFailure (prettyString (fromJuvixError @GenericError err))
Right tab -> do
let tab' = Asm.fromTree tab
Asm.asmRunAssertion' tab' expectedFile step
7 changes: 6 additions & 1 deletion test/Tree/Eval/Positive.hs
Original file line number Diff line number Diff line change
Expand Up @@ -239,5 +239,10 @@ tests =
"Test040: ByteArray"
$(mkRelDir ".")
$(mkRelFile "test040.jvt")
$(mkRelFile "out/test040.out")
$(mkRelFile "out/test040.out"),
PosTest
"Test041: Type unification"
$(mkRelDir ".")
$(mkRelFile "test041.jvt")
$(mkRelFile "out/test041.out")
]
1 change: 1 addition & 0 deletions tests/Tree/positive/out/test041.out
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
0
41 changes: 41 additions & 0 deletions tests/Tree/positive/test041.jvt
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
type Foldable {
mkFoldable : ((* → * → *) → * → * → *) → Foldable;
}

type Box {
mkBox : * → Box;
}

function lambda_16(integer, integer) : integer;
function lambda_18((integer, integer) → integer, integer, Box) : integer;
function foldableBoxintegerI() : Foldable;
function go_17(integer) : integer;
function main() : integer;

function lambda_16(_X : integer, _X' : integer) : integer {
_X'
}

function lambda_18(f : (integer, integer) → integer, ini : integer, _X : Box) : integer {
case[Box](_X) {
mkBox: save {
call[go_17](tmp[0].mkBox[0])
}
}
}

function foldableBoxintegerI() : Foldable {
alloc[mkFoldable](calloc[lambda_18]())
}

function go_17(x' : integer) : integer {
x'
}

function main() : integer {
case[Foldable](call[foldableBoxintegerI]()) {
mkFoldable: save {
ccall(tmp[0].mkFoldable[0], calloc[lambda_16](), 0, alloc[mkBox](0))
}
}
}

0 comments on commit eb5b2e4

Please sign in to comment.