diff --git a/src/Language/Futhark/TypeChecker/Constraints.hs b/src/Language/Futhark/TypeChecker/Constraints.hs index 1ad98997ad..b6d186b195 100644 --- a/src/Language/Futhark/TypeChecker/Constraints.hs +++ b/src/Language/Futhark/TypeChecker/Constraints.hs @@ -110,8 +110,6 @@ data TyVarInfo TyVarRecord Loc (M.Map Name Type) | -- | Must be a sum type with these fields. TyVarSum Loc (M.Map Name [Type]) - | -- | Must be a type that supports equality. - TyVarEql Loc deriving (Show, Eq) instance Pretty TyVarInfo where @@ -119,14 +117,12 @@ instance Pretty TyVarInfo where pretty (TyVarPrim _ pts) = "∈" <+> pretty pts pretty (TyVarRecord _ fs) = pretty $ Scalar $ Record fs pretty (TyVarSum _ cs) = pretty $ Scalar $ Sum cs - pretty (TyVarEql _) = "equality" instance Located TyVarInfo where locOf (TyVarFree loc _) = loc locOf (TyVarPrim loc _) = loc locOf (TyVarRecord loc _) = loc locOf (TyVarSum loc _) = loc - locOf (TyVarEql loc) = loc type TyVar = VName @@ -285,9 +281,6 @@ unifySharedFields reason bcs fs1 fs2 = forM_ (M.toList $ M.intersectionWith (,) fs1 fs2) $ \(f, (ts1, ts2)) -> solveEq reason (matchingField f <> bcs) ts1 ts2 -mustSupportEql :: Reason -> Type -> SolveM () -mustSupportEql _reason _t = pure () - scopeViolation :: Reason -> VName -> Type -> VName -> SolveM () scopeViolation reason v1 ty v2 = typeError (locOf reason) mempty $ @@ -440,8 +433,6 @@ subTyVar reason bcs v t = do indent 2 (pretty (Record fs1)) "with type" indent 2 (pretty t) - (Just (Right (TyVarUnsol (TyVarEql _))), _) -> - mustSupportEql reason t -- -- Internal error cases (Just (Right TyVarSol {}), _) -> @@ -481,10 +472,6 @@ unionTyVars reason bcs v t = do setInfo t (TyVarUnsol info) -- -- TyVarPrim cases - ( TyVarUnsol info@TyVarPrim {}, - TyVarEql {} - ) -> - setInfo t (TyVarUnsol info) ( TyVarUnsol (TyVarPrim _ v_pts), TyVarPrim t_loc t_pts ) -> @@ -533,10 +520,6 @@ unionTyVars reason bcs v t = do indent 2 (pretty (Sum cs1)) "with type" indent 2 (pretty (Scalar (Record fs))) - ( TyVarUnsol (TyVarSum _ cs1), - TyVarEql _ - ) -> - mapM_ (mapM_ (mustSupportEql reason)) cs1 -- -- TyVarRecord cases ( TyVarUnsol (TyVarRecord _ fs1), @@ -559,20 +542,6 @@ unionTyVars reason bcs v t = do indent 2 (pretty (Record fs1)) "with type" indent 2 (pretty (Scalar (Sum cs))) - ( TyVarUnsol (TyVarRecord _ fs1), - TyVarEql _ - ) -> - mapM_ (mustSupportEql reason) fs1 - -- - -- TyVarEql cases - (TyVarUnsol (TyVarEql _), TyVarPrim {}) -> - pure () - (TyVarUnsol (TyVarEql _), TyVarEql {}) -> - pure () - (TyVarUnsol (TyVarEql _), TyVarRecord _ fs) -> - mustSupportEql reason $ Scalar $ Record fs - (TyVarUnsol (TyVarEql _), TyVarSum _ cs) -> - mustSupportEql reason $ Scalar $ Sum cs -- -- Internal error cases (TyVarSol {}, _) -> @@ -750,21 +719,6 @@ solveTyVar (tv, (_, TyVarSum loc cs1)) = do "Must be a sum type with constructors" indent 2 (pretty (Scalar (Sum cs1))) Right _ -> pure () -solveTyVar (tv, (_, TyVarEql loc)) = do - tv_t <- lookupTyVar tv - case tv_t of - Left TyVarEql {} -> - typeError loc mempty $ - "Type is ambiguous (must be equality type)" - "Add a type annotation to disambiguate the type." - Left _ -> pure () - Right ty - | orderZero ty -> pure () - | otherwise -> - typeError loc mempty $ - "Type" - indent 2 (align (pretty ty)) - "does not support equality (may contain function)." solveTyVar (tv, (lvl, TyVarFree loc l)) = do tv_t <- lookupTyVar tv case tv_t of diff --git a/src/Language/Futhark/TypeChecker/Rank.hs b/src/Language/Futhark/TypeChecker/Rank.hs index 24254d7392..deb578facb 100644 --- a/src/Language/Futhark/TypeChecker/Rank.hs +++ b/src/Language/Futhark/TypeChecker/Rank.hs @@ -167,8 +167,6 @@ addTyVarInfo tv (_, TyVarRecord {}) = addConstraint $ rank tv ~==~ constant 0 addTyVarInfo tv (_, TyVarSum {}) = addConstraint $ rank tv ~==~ constant 0 -addTyVarInfo tv (_, TyVarEql {}) = - addConstraint $ rank tv ~==~ constant 0 mkLinearProg :: [Ct] -> TyVars -> LinearProg mkLinearProg cs tyVars = @@ -454,7 +452,6 @@ instance SubstRanks TyVarInfo where TyVarRecord loc <$> traverse substRanks fs substRanks (TyVarSum loc cs) = TyVarSum loc <$> (traverse . traverse) substRanks cs - substRanks tv@TyVarEql {} = pure tv instance SubstRanks (Int, TyVarInfo) where substRanks (lvl, tv) = (lvl,) <$> substRanks tv diff --git a/src/Language/Futhark/TypeChecker/Terms.hs b/src/Language/Futhark/TypeChecker/Terms.hs index 48f82ad5ea..b6009887ed 100644 --- a/src/Language/Futhark/TypeChecker/Terms.hs +++ b/src/Language/Futhark/TypeChecker/Terms.hs @@ -1143,6 +1143,14 @@ mustBeIrrefutable p = do "Refutable pattern not allowed here.\nUnmatched cases:" indent 2 (stack (map pretty ps')) +supportsEquality :: TypeBase dim u -> Bool +supportsEquality (Array _ _ t) = supportsEquality $ Scalar t +supportsEquality (Scalar Prim {}) = True +supportsEquality (Scalar TypeVar {}) = False +supportsEquality (Scalar (Record fs)) = all supportsEquality fs +supportsEquality (Scalar (Sum fs)) = all (all supportsEquality) fs +supportsEquality (Scalar Arrow {}) = False + -- | Traverse the expression, emitting warnings and errors for various -- problems: -- @@ -1163,6 +1171,12 @@ localChecks = void . check indent 2 (stack (map pretty ps')) check e@(AppExp (LetPat _ p _ _ _) _) = mustBeIrrefutable p *> recurse e + check e@(AppExp (BinOp (v, loc) _ (x, _) _ _) _) + | qualLeaf v == intrinsicVar "==" = + checkEquality loc (typeOf x) *> recurse e + check e@(Var v (Info t) loc) + | qualLeaf v == intrinsicVar "==" = + checkEquality loc t *> recurse e check e@(Lambda ps _ _ _ _) = mapM_ (mustBeIrrefutable . fmap toStruct) ps *> recurse e check e@(AppExp (LetFun _ (_, ps, _, _, _) _ _) _) = @@ -1188,6 +1202,13 @@ localChecks = void . check check e = recurse e recurse = astMap identityMapper {mapOnExp = check} + checkEquality loc t = + unless (supportsEquality t) $ + typeError loc mempty $ + "Comparing equality of values of type" + indent 2 (pretty t) + "which does not support equality." + bitWidth ty = 8 * intByteSize ty :: Int inBoundsI x (Signed t) = x >= -2 ^ (bitWidth t - 1) && x < 2 ^ (bitWidth t - 1) diff --git a/src/Language/Futhark/TypeChecker/Terms2.hs b/src/Language/Futhark/TypeChecker/Terms2.hs index 975b7ae9e1..679dc6317b 100644 --- a/src/Language/Futhark/TypeChecker/Terms2.hs +++ b/src/Language/Futhark/TypeChecker/Terms2.hs @@ -438,7 +438,7 @@ lookupVar loc qn@(QualName qs name) = do -- TODO - qualify type names, like in the old type checker. pure t' Just EqualityF -> do - argtype <- tyVarType Observe <$> newTyVarWith "t" (TyVarEql (locOf loc)) + argtype <- tyVarType Observe <$> newTyVarWith "t" (TyVarFree (locOf loc) Unlifted) pure $ foldFunType [argtype, argtype] $ RetType [] $ Scalar $ Prim Bool Just (OverloadedF ts pts rt) -> do argtype <- newTypeOverloaded loc "t" ts diff --git a/tests/types/inference-error9.fut b/tests/types/inference-error9.fut index 0c6886c881..6cead89cec 100644 --- a/tests/types/inference-error9.fut +++ b/tests/types/inference-error9.fut @@ -2,4 +2,4 @@ -- == -- error: equality -def main 't (x: t) = x == x +def f 't (x: t) = x == x