Skip to content

Commit

Permalink
Eliminate TyVarEql.
Browse files Browse the repository at this point in the history
  • Loading branch information
athas committed Dec 23, 2024
1 parent 6d541a0 commit 4f88fef
Show file tree
Hide file tree
Showing 5 changed files with 23 additions and 51 deletions.
46 changes: 0 additions & 46 deletions src/Language/Futhark/TypeChecker/Constraints.hs
Original file line number Diff line number Diff line change
Expand Up @@ -110,23 +110,19 @@ data TyVarInfo
TyVarRecord Loc (M.Map Name Type)
| -- | Must be a sum type with these fields.
TyVarSum Loc (M.Map Name [Type])
| -- | Must be a type that supports equality.
TyVarEql Loc
deriving (Show, Eq)

instance Pretty TyVarInfo where
pretty (TyVarFree _ l) = "free" <+> pretty l
pretty (TyVarPrim _ pts) = "" <+> pretty pts
pretty (TyVarRecord _ fs) = pretty $ Scalar $ Record fs
pretty (TyVarSum _ cs) = pretty $ Scalar $ Sum cs
pretty (TyVarEql _) = "equality"

instance Located TyVarInfo where
locOf (TyVarFree loc _) = loc
locOf (TyVarPrim loc _) = loc
locOf (TyVarRecord loc _) = loc
locOf (TyVarSum loc _) = loc
locOf (TyVarEql loc) = loc

type TyVar = VName

Expand Down Expand Up @@ -285,9 +281,6 @@ unifySharedFields reason bcs fs1 fs2 =
forM_ (M.toList $ M.intersectionWith (,) fs1 fs2) $ \(f, (ts1, ts2)) ->
solveEq reason (matchingField f <> bcs) ts1 ts2

mustSupportEql :: Reason -> Type -> SolveM ()
mustSupportEql _reason _t = pure ()

scopeViolation :: Reason -> VName -> Type -> VName -> SolveM ()
scopeViolation reason v1 ty v2 =
typeError (locOf reason) mempty $
Expand Down Expand Up @@ -440,8 +433,6 @@ subTyVar reason bcs v t = do
</> indent 2 (pretty (Record fs1))
</> "with type"
</> indent 2 (pretty t)
(Just (Right (TyVarUnsol (TyVarEql _))), _) ->
mustSupportEql reason t
--
-- Internal error cases
(Just (Right TyVarSol {}), _) ->
Expand Down Expand Up @@ -481,10 +472,6 @@ unionTyVars reason bcs v t = do
setInfo t (TyVarUnsol info)
--
-- TyVarPrim cases
( TyVarUnsol info@TyVarPrim {},
TyVarEql {}
) ->
setInfo t (TyVarUnsol info)
( TyVarUnsol (TyVarPrim _ v_pts),
TyVarPrim t_loc t_pts
) ->
Expand Down Expand Up @@ -533,10 +520,6 @@ unionTyVars reason bcs v t = do
</> indent 2 (pretty (Sum cs1))
</> "with type"
</> indent 2 (pretty (Scalar (Record fs)))
( TyVarUnsol (TyVarSum _ cs1),
TyVarEql _
) ->
mapM_ (mapM_ (mustSupportEql reason)) cs1
--
-- TyVarRecord cases
( TyVarUnsol (TyVarRecord _ fs1),
Expand All @@ -559,20 +542,6 @@ unionTyVars reason bcs v t = do
</> indent 2 (pretty (Record fs1))
</> "with type"
</> indent 2 (pretty (Scalar (Sum cs)))
( TyVarUnsol (TyVarRecord _ fs1),
TyVarEql _
) ->
mapM_ (mustSupportEql reason) fs1
--
-- TyVarEql cases
(TyVarUnsol (TyVarEql _), TyVarPrim {}) ->
pure ()
(TyVarUnsol (TyVarEql _), TyVarEql {}) ->
pure ()
(TyVarUnsol (TyVarEql _), TyVarRecord _ fs) ->
mustSupportEql reason $ Scalar $ Record fs
(TyVarUnsol (TyVarEql _), TyVarSum _ cs) ->
mustSupportEql reason $ Scalar $ Sum cs
--
-- Internal error cases
(TyVarSol {}, _) ->
Expand Down Expand Up @@ -750,21 +719,6 @@ solveTyVar (tv, (_, TyVarSum loc cs1)) = do
</> "Must be a sum type with constructors"
</> indent 2 (pretty (Scalar (Sum cs1)))
Right _ -> pure ()
solveTyVar (tv, (_, TyVarEql loc)) = do
tv_t <- lookupTyVar tv
case tv_t of
Left TyVarEql {} ->
typeError loc mempty $
"Type is ambiguous (must be equality type)"
</> "Add a type annotation to disambiguate the type."
Left _ -> pure ()
Right ty
| orderZero ty -> pure ()
| otherwise ->
typeError loc mempty $
"Type"
</> indent 2 (align (pretty ty))
</> "does not support equality (may contain function)."
solveTyVar (tv, (lvl, TyVarFree loc l)) = do
tv_t <- lookupTyVar tv
case tv_t of
Expand Down
3 changes: 0 additions & 3 deletions src/Language/Futhark/TypeChecker/Rank.hs
Original file line number Diff line number Diff line change
Expand Up @@ -167,8 +167,6 @@ addTyVarInfo tv (_, TyVarRecord {}) =
addConstraint $ rank tv ~==~ constant 0
addTyVarInfo tv (_, TyVarSum {}) =
addConstraint $ rank tv ~==~ constant 0
addTyVarInfo tv (_, TyVarEql {}) =
addConstraint $ rank tv ~==~ constant 0

mkLinearProg :: [Ct] -> TyVars -> LinearProg
mkLinearProg cs tyVars =
Expand Down Expand Up @@ -454,7 +452,6 @@ instance SubstRanks TyVarInfo where
TyVarRecord loc <$> traverse substRanks fs
substRanks (TyVarSum loc cs) =
TyVarSum loc <$> (traverse . traverse) substRanks cs
substRanks tv@TyVarEql {} = pure tv

instance SubstRanks (Int, TyVarInfo) where
substRanks (lvl, tv) = (lvl,) <$> substRanks tv
Expand Down
21 changes: 21 additions & 0 deletions src/Language/Futhark/TypeChecker/Terms.hs
Original file line number Diff line number Diff line change
Expand Up @@ -1143,6 +1143,14 @@ mustBeIrrefutable p = do
"Refutable pattern not allowed here.\nUnmatched cases:"
</> indent 2 (stack (map pretty ps'))

supportsEquality :: TypeBase dim u -> Bool
supportsEquality (Array _ _ t) = supportsEquality $ Scalar t
supportsEquality (Scalar Prim {}) = True
supportsEquality (Scalar TypeVar {}) = False
supportsEquality (Scalar (Record fs)) = all supportsEquality fs
supportsEquality (Scalar (Sum fs)) = all (all supportsEquality) fs
supportsEquality (Scalar Arrow {}) = False

-- | Traverse the expression, emitting warnings and errors for various
-- problems:
--
Expand All @@ -1163,6 +1171,12 @@ localChecks = void . check
</> indent 2 (stack (map pretty ps'))
check e@(AppExp (LetPat _ p _ _ _) _) =
mustBeIrrefutable p *> recurse e
check e@(AppExp (BinOp (v, loc) _ (x, _) _ _) _)
| qualLeaf v == intrinsicVar "==" =
checkEquality loc (typeOf x) *> recurse e
check e@(Var v (Info t) loc)
| qualLeaf v == intrinsicVar "==" =
checkEquality loc t *> recurse e
check e@(Lambda ps _ _ _ _) =
mapM_ (mustBeIrrefutable . fmap toStruct) ps *> recurse e
check e@(AppExp (LetFun _ (_, ps, _, _, _) _ _) _) =
Expand All @@ -1188,6 +1202,13 @@ localChecks = void . check
check e = recurse e
recurse = astMap identityMapper {mapOnExp = check}

checkEquality loc t =
unless (supportsEquality t) $
typeError loc mempty $
"Comparing equality of values of type"
</> indent 2 (pretty t)
</> "which does not support equality."

bitWidth ty = 8 * intByteSize ty :: Int

inBoundsI x (Signed t) = x >= -2 ^ (bitWidth t - 1) && x < 2 ^ (bitWidth t - 1)
Expand Down
2 changes: 1 addition & 1 deletion src/Language/Futhark/TypeChecker/Terms2.hs
Original file line number Diff line number Diff line change
Expand Up @@ -438,7 +438,7 @@ lookupVar loc qn@(QualName qs name) = do
-- TODO - qualify type names, like in the old type checker.
pure t'
Just EqualityF -> do
argtype <- tyVarType Observe <$> newTyVarWith "t" (TyVarEql (locOf loc))
argtype <- tyVarType Observe <$> newTyVarWith "t" (TyVarFree (locOf loc) Unlifted)
pure $ foldFunType [argtype, argtype] $ RetType [] $ Scalar $ Prim Bool
Just (OverloadedF ts pts rt) -> do
argtype <- newTypeOverloaded loc "t" ts
Expand Down
2 changes: 1 addition & 1 deletion tests/types/inference-error9.fut
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,4 @@
-- ==
-- error: equality

def main 't (x: t) = x == x
def f 't (x: t) = x == x

0 comments on commit 4f88fef

Please sign in to comment.