Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve type inference for record projection #2172

Merged
merged 2 commits into from
Oct 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 53 additions & 5 deletions src/swarm-lang/Swarm/Language/Typecheck.hs
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ import Control.Effect.Catch (Catch, catchError)
import Control.Effect.Error (Error)
import Control.Effect.Reader
import Control.Effect.Throw
import Control.Lens ((^.))
import Control.Lens (view, (^.))
import Control.Lens.Indexed (itraverse)
import Control.Monad (forM_, void, when, (<=<), (>=>))
import Control.Monad.Free (Free (..))
Expand Down Expand Up @@ -444,6 +444,9 @@ data TypeErr
-- expected to have a certain type, but has a different type
-- instead.
Mismatch (Maybe Syntax) TypeJoin
| -- | Record type mismatch. The given term was expected to have a
-- record type, but has a different type instead.
MismatchRcd (Maybe Syntax) UType
| -- | Lambda argument type mismatch.
LambdaArgMismatch TypeJoin
| -- | Record field mismatch, i.e. based on the expected type we
Expand Down Expand Up @@ -479,6 +482,14 @@ instance PrettyPrec TypeErr where
, "From context, expected" <+> pprCode t <+> "to" <+> typeDescription Expected ty1 <> ","
, "but it" <+> typeDescription Actual ty2
]
MismatchRcd Nothing ty ->
"Type mismatch: expected a record type, but got" <+> ppr ty
MismatchRcd (Just t) ty ->
nest 2 . vcat $
[ "Type mismatch:"
, "From context, expected" <+> pprCode t <+> "to have a record type,"
, "but it" <+> typeDescription Actual ty
]
LambdaArgMismatch (getJoin -> (ty1, ty2)) ->
"Lambda argument has type annotation" <+> pprCode ty2 <> ", but expected argument type" <+> pprCode ty1
FieldsMismatch (getJoin -> (expFs, actFs)) ->
Expand All @@ -495,7 +506,7 @@ instance PrettyPrec TypeErr where
, reportBug
]
CantInferProj t ->
"Can't infer the type of a record projection:" <+> pprCode t
"In the record projection" <+> pprCode t <> ", can't infer whether the LHS has a record type. Try adding a type annotation."
UnknownProj x t ->
"Record does not have a field with name" <+> pretty x <> ":" <+> pprCode t
InvalidAtomic reason t ->
Expand Down Expand Up @@ -706,6 +717,42 @@ decomposeCmdTy
decomposeCmdTy = decomposeTyConApp1 TCCmd
decomposeDelayTy = decomposeTyConApp1 TCDelay

-- | Decompose a type which is expected to be a record type. There
-- are three possible outcomes:
--
-- * If the type is definitely a record type, return its mapping
-- from field names to types.
--
-- * If the type is definitely not a record type, throw a type error.
--
-- * Otherwise, return @Nothing@.
--
-- This is the best we can do, and different than the way the other
-- @decompose...Ty@ functions work, because we can't solve for record
-- types via unification.
decomposeRcdTy ::
( Has (Reader TDCtx) sig m
, Has (Reader TCStack) sig m
, Has (Throw ContextualTypeErr) sig m
) =>
Maybe Syntax ->
UType ->
m (Maybe (Map Var UType))
decomposeRcdTy ms = \case
ty@(UTyConApp tc as) -> case tc of
-- User-defined type: expand it
TCUser u -> do
ty2 <- expandTydef u as
decomposeRcdTy ms ty2
-- Any other type constructor application is definitely not a record type
_ -> throwTypeErr (maybe NoLoc (view sLoc) ms) $ MismatchRcd ms ty
-- Recursive type: expand it
UTyRec x t -> decomposeRcdTy ms (unfoldRec x t)
-- Record type
UTyRcd m -> pure (Just m)
-- With anything else (type variables, etc.) we're not sure
_ -> pure Nothing

-- | Decompose a type that is supposed to be the application of a
-- given type constructor to two type arguments. Also take the term
-- which is supposed to have that type, for use in error messages.
Expand Down Expand Up @@ -899,11 +946,12 @@ infer s@(CSyntax l t cs) = addLocToTypeErr l $ case t of
-- first anyway.
SProj t1 x -> do
t1' <- infer t1
case t1' ^. sType of
UTyRcd m -> case M.lookup x m of
mm <- decomposeRcdTy (Just t1) (t1' ^. sType)
case mm of
Just m -> case M.lookup x m of
Just xTy -> return $ Syntax' l (SProj t1' x) cs xTy
Nothing -> throwTypeErr l $ UnknownProj x (SProj t1 x)
_ -> throwTypeErr l $ CantInferProj (SProj t1 x)
Nothing -> throwTypeErr l $ CantInferProj (SProj t1 x)

-- See Note [Checking and inference for record literals]
SRcd m -> do
Expand Down
2 changes: 1 addition & 1 deletion src/swarm-lang/Swarm/Language/Types.hs
Original file line number Diff line number Diff line change
Expand Up @@ -736,7 +736,7 @@ tcArity tydefs =
-- Reducing types to WHNF
------------------------------------------------------------

-- | Reduce a type to weak head normal form, i.e. keep unfold type
-- | Reduce a type to weak head normal form, i.e. keep unfolding type
-- aliases and recursive types just until the top-level constructor
-- of the type is neither @rec@ nor an application of a type alias.
whnfType :: TDCtx -> Type -> Type
Expand Down
30 changes: 30 additions & 0 deletions test/unit/TestLanguagePipeline.hs
Original file line number Diff line number Diff line change
Expand Up @@ -330,6 +330,36 @@ testLanguagePipeline =
"(\\r:[x:Int, y:Int]. r.x) [x = 3, z = 5]"
"1:26: Field mismatch; record literal has:\n - Extra field(s) `z`\n - Missing field(s) `y`"
)
, testCase
"type mismatch with record projection"
( process
"\\x:Int. x.y"
"1:9: Type mismatch:\n From context, expected `x` to have a record type,\n but it actually has type `Int`"
)
, testCase
"inference failure with record projection"
( process
"\\x. x.y"
"1:5: In the record projection `x.y`, can't infer whether the LHS has a record type. Try adding a type annotation."
)
, testCase
"infer record projection with tydef"
(valid "tydef R = [x:Int] end; def f : R -> Int = \\r. r.x end")
, testCase
"infer record projection with nested tydef"
(valid "tydef B = [x:Int] end; tydef A = B end; def f : A -> Int = \\r. r.x end")
, testCase
"infer record projection with tydef and recursive type"
(valid "tydef S = rec s. [hd:Int, tl:s] end; def two : S -> Int = \\s. s.tl.hd end")
, testCase
"infer record projection with tydef - RBTree"
(valid "tydef Color = Bool end; tydef RBTree k v = rec b. Unit + [c: Color, k: k, v: v, l: b, r: b] end; def balanceLR : RBTree k v -> RBTree k v = \\ln. case ln undefined (\\ne. ne.r) end")
, testCase
"infer record projection with tydef - RBTree, error"
( process
"tydef Color = Bool end; tydef RBTree k v = rec b. Unit + [c: Color, k: k, v: v, l: b, r: b] end; def balanceLR : RBTree k v -> RBTree k v = \\ln. case ln.r undefined undefined end"
"1:151: Type mismatch:\n From context, expected `ln` to have a record type,\n but it actually has type `Unit +"
)
]
, testGroup
"type annotations"
Expand Down