From 41a81f44b578c4dd52767f05cd1946a6e22ec745 Mon Sep 17 00:00:00 2001 From: Dougal Date: Wed, 16 Aug 2023 14:22:05 -0400 Subject: [PATCH 01/41] Merge SolverM and InfererM as part of an effort to simplify Inference. --- src/lib/Err.hs | 4 + src/lib/Imp.hs | 2 +- src/lib/Inference.hs | 301 +++++++++++++++++-------------------------- src/lib/Name.hs | 8 +- src/lib/Subst.hs | 3 +- 5 files changed, 130 insertions(+), 188 deletions(-) diff --git a/src/lib/Err.hs b/src/lib/Err.hs index b575c048e..426250884 100644 --- a/src/lib/Err.hs +++ b/src/lib/Err.hs @@ -347,6 +347,10 @@ instance Alternative SearcherM where Just ans -> return $ Just ans Nothing -> m2 +instance Catchable SearcherM where + SearcherM (MaybeT m) `catchErr` handler = SearcherM $ MaybeT $ + m `catchErr` \errs -> runMaybeT $ runSearcherM' $ handler errs + instance Searcher SearcherM where () = (<|>) {-# INLINE () #-} diff --git a/src/lib/Imp.hs b/src/lib/Imp.hs index d74333f38..56cfa95de 100644 --- a/src/lib/Imp.hs +++ b/src/lib/Imp.hs @@ -1235,7 +1235,7 @@ buildImpFunction cc argHintsTys body = do return $ ImpFunction impFun $ Abs bs $ ImpBlock decls results buildImpNaryAbs - :: (SinkableE e, HasNamesE e, RenameE e, HoistableE e) + :: HasNamesE e => [(NameHint, IType)] -> (forall l. (Emits l, DExt n l) => [(Name ImpNameC l, BaseType)] -> SubstImpM i l (e l)) -> SubstImpM i n (Abs (Nest IBinder) (Abs (Nest ImpDecl) e) n) diff --git a/src/lib/Inference.hs b/src/lib/Inference.hs index cd28d7c5e..d4384ddb9 100644 --- a/src/lib/Inference.hs +++ b/src/lib/Inference.hs @@ -163,51 +163,32 @@ getInstanceType (InstanceDef className roleExpls bs params _) = liftEnvReaderM d -- === Inferer interface === -class ( MonadFail1 m, Fallible1 m, Catchable1 m, CtxReader1 m, Builder CoreIR m ) - => InfBuilder (m::MonadKind1) where - - -- XXX: we should almost always used the zonking `buildDeclsInf` , - -- except where it's not possible because the result isn't atom-substitutable, - -- such as the source map at the top level. - buildDeclsInfUnzonked - :: (SinkableE e, HoistableE e, RenameE e) - => EmitsInf n - => (forall l. (EmitsBoth l, DExt n l) => m l (e l)) - -> m n (Abs (Nest CDecl) e n) - - buildAbsInf - :: (SinkableE e, HoistableE e, RenameE e, SubstE AtomSubstVal e) - => EmitsInf n - => NameHint -> Explicitness -> CType n - -> (forall l. (EmitsInf l, DExt n l) => CAtomVar l -> m l (e l)) - -> m n (Abs CBinder e n) - buildAbsInfWithExpl - :: (InfBuilder m, SinkableE e, HoistableE e, RenameE e, SubstE AtomSubstVal e) + :: (HasNamesE e, SubstE AtomSubstVal e) => EmitsInf n => NameHint -> Explicitness -> CType n - -> (forall l. (EmitsInf l, DExt n l) => CAtomVar l -> m l (e l)) - -> m n (Abs (WithExpl CBinder) e n) + -> (forall l. (EmitsInf l, DExt n l) => CAtomVar l -> InfererM i l (e l)) + -> InfererM i n (Abs (WithExpl CBinder) e n) buildAbsInfWithExpl hint expl ty cont = do Abs b e <- buildAbsInf hint expl ty cont return $ Abs (WithAttrB expl b) e buildNaryAbsInfWithExpl - :: (Inferer m, SinkableE e, HoistableE e, RenameE e, SubstE AtomSubstVal e, Inferer m) + :: (HasNamesE e, SubstE AtomSubstVal e) => EmitsInf n => [Explicitness] -> EmptyAbs (Nest CBinder) n - -> (forall l. (EmitsInf l, DExt n l) => [CAtomVar l] -> m i l (e l)) - -> m i n (Abs (Nest (WithExpl CBinder)) e n) + -> (forall l. (EmitsInf l, DExt n l) => [CAtomVar l] -> InfererM i l (e l)) + -> InfererM i n (Abs (Nest (WithExpl CBinder)) e n) buildNaryAbsInfWithExpl expls bs cont = do Abs bs' e <- buildNaryAbsInf expls bs cont return $ Abs (zipAttrs expls bs') e buildNaryAbsInf - :: (SinkableE e, HoistableE e, RenameE e, SubstE AtomSubstVal e, Inferer m) + :: (HasNamesE e, SubstE AtomSubstVal e) => EmitsInf n => [Explicitness] -> EmptyAbs (Nest CBinder) n - -> (forall l. (EmitsInf l, DExt n l) => [CAtomVar l] -> m i l (e l)) - -> m i n (Abs (Nest CBinder) e n) + -> (forall l. (EmitsInf l, DExt n l) => [CAtomVar l] -> InfererM i l (e l)) + -> InfererM i n (Abs (Nest CBinder) e n) buildNaryAbsInf [] (Abs Empty UnitE) cont = getDistinct >>= \Distinct -> Abs Empty <$> cont [] buildNaryAbsInf (expl:expls) (Abs (Nest (b:>ty) bs) UnitE) cont = prependAbs <$> buildAbsInf (getNameHint b) expl ty \v -> do @@ -216,21 +197,12 @@ buildNaryAbsInf (expl:expls) (Abs (Nest (b:>ty) bs) UnitE) cont = buildNaryAbsInf _ _ _ = error "zip error" buildDeclsInf - :: (SubstE AtomSubstVal e, RenameE e, Solver m, InfBuilder m) - => (SinkableE e, HoistableE e) + :: (HasNamesE e, SubstE AtomSubstVal e) => EmitsInf n - => (forall l. (EmitsBoth l, DExt n l) => m l (e l)) - -> m n (Abs (Nest CDecl) e n) + => (forall l. (EmitsBoth l, DExt n l) => InfererM i l (e l)) + -> InfererM i n (Abs (Nest CDecl) e n) buildDeclsInf cont = buildDeclsInfUnzonked $ cont >>= zonk -type InfBuilder2 (m::MonadKind2) = forall i. InfBuilder (m i) - -class (SubstReader Name m, InfBuilder2 m, Solver2 m) - => Inferer (m::MonadKind2) where - liftSolverMInf :: EmitsInf o => SolverM o a -> m i o a - addDefault :: CAtomName o -> DefaultType -> m i o () - getDefaults :: m i o (Defaults o) - applyDefaults :: EmitsInf o => InfererM i o () applyDefaults = do defaults <- getDefaults @@ -429,8 +401,8 @@ extendOutMapWithConstraints env us ss (Constraints allCs) = case tryUnsnoc allCs (env'', us'', ss''') newtype InfererM (i::S) (o::S) (a:: *) = InfererM - { runInfererM' :: SubstReaderT Name (InplaceT InfOutMap InfOutFrag FallibleM) i o a } - deriving (Functor, Applicative, Monad, MonadFail, + { runInfererM' :: SubstReaderT Name (InplaceT InfOutMap InfOutFrag SearcherM) i o a } + deriving (Functor, Applicative, Monad, MonadFail, Alternative, Searcher, ScopeReader, Fallible, Catchable, CtxReader, SubstReader Name) liftInfererMSubst :: (Fallible2 m, SubstReader Name m, EnvReader2 m) @@ -440,7 +412,7 @@ liftInfererMSubst cont = do subst <- getSubst Distinct <- getDistinct (InfOutFrag REmpty _ _, result) <- - liftExcept $ runFallibleM $ runInplaceT (initInfOutMap env) $ + liftExcept $ liftM fromJust $ runSearcherM $ runInplaceT (initInfOutMap env) $ runSubstReaderT subst $ runInfererM' $ cont return result @@ -481,29 +453,34 @@ emitInfererM hint emission = do return $ AtomVar v $ getType emission {-# INLINE emitInfererM #-} -instance Solver (InfererM i) where - extendSolverSubst v ty = do - InfererM $ SubstReaderT $ lift $ - void $ extendTrivialInplaceT $ - InfOutFrag REmpty mempty (singleConstraint v ty) - {-# INLINE extendSolverSubst #-} - - zonk e = InfererM $ SubstReaderT $ lift do - Distinct <- getDistinct - solverOutMap <- getOutMapInplaceT - return $ zonkWithOutMap solverOutMap e - {-# INLINE zonk #-} - - emitSolver binding = emitInfererM (getNameHint @String "?") $ RightE binding - {-# INLINE emitSolver #-} +extendSolverSubst :: CAtomName n -> CAtom n -> InfererM i n () +extendSolverSubst v ty = do + InfererM $ SubstReaderT $ lift $ + void $ extendTrivialInplaceT $ + InfOutFrag REmpty mempty (singleConstraint v ty) +{-# INLINE extendSolverSubst #-} - solveLocal cont = do - Abs (InfOutFrag unsolvedInfVars _ _) result <- dceInfFrag =<< runLocalInfererM cont - case unRNest unsolvedInfVars of - Empty -> return result - Nest (b:>RightE (InfVarBound ty (ctx, desc))) _ -> addSrcContext ctx $ - throw TypeErr $ formatAmbiguousVarErr (binderName b) ty desc - _ -> error "shouldn't be possible" +zonk :: (SubstE AtomSubstVal e, SinkableE e) => e n -> InfererM i n (e n) +zonk e = InfererM $ SubstReaderT $ lift do + Distinct <- getDistinct + solverOutMap <- getOutMapInplaceT + return $ zonkWithOutMap solverOutMap e +{-# INLINE zonk #-} + +emitSolver :: EmitsInf n => SolverBinding n -> InfererM i n (CAtomVar n) +emitSolver binding = emitInfererM (getNameHint @String "?") $ RightE binding +{-# INLINE emitSolver #-} + +solveLocal :: HasNamesE e + => (forall l. (EmitsInf l, Ext n l, Distinct l) => InfererM i l (e l)) + -> InfererM i n (e n) +solveLocal cont = do + Abs (InfOutFrag unsolvedInfVars _ _) result <- dceInfFrag =<< runLocalInfererM cont + case unRNest unsolvedInfVars of + Empty -> return result + Nest (b:>RightE (InfVarBound ty (ctx, desc))) _ -> addSrcContext ctx $ + throw TypeErr $ formatAmbiguousVarErr (binderName b) ty desc + _ -> error "shouldn't be possible" formatAmbiguousVarErr :: CAtomName n -> CType n' -> InfVarDesc -> String formatAmbiguousVarErr infVar ty = \case @@ -516,39 +493,52 @@ formatAmbiguousVarErr infVar ty = \case MiscInfVar -> "Ambiguous type variable: " ++ pprint infVar ++ ": " ++ pprint ty -instance InfBuilder (InfererM i) where - buildDeclsInfUnzonked cont = do - InfererM $ SubstReaderT $ ReaderT \env -> do - Abs frag result <- locallyMutableInplaceT (do - Emits <- fabricateEmitsEvidenceM - EmitsInf <- fabricateEmitsInfEvidenceM - runSubstReaderT (sink env) $ runInfererM' cont) +-- XXX: we should almost always used the zonking `buildDeclsInf` , +-- except where it's not possible because the result isn't atom-substitutable, +-- such as the source map at the top level. +buildDeclsInfUnzonked + :: (SinkableE e, HoistableE e, RenameE e) + => EmitsInf n + => (forall l. (EmitsBoth l, DExt n l) => InfererM i l (e l)) + -> InfererM i n (Abs (Nest CDecl) e n) +buildDeclsInfUnzonked cont = do + InfererM $ SubstReaderT $ ReaderT \env -> do + Abs frag result <- locallyMutableInplaceT (do + Emits <- fabricateEmitsEvidenceM + EmitsInf <- fabricateEmitsInfEvidenceM + runSubstReaderT (sink env) $ runInfererM' cont) + (\d e -> return $ Abs d e) + extendInplaceT =<< hoistThroughDecls frag result + +buildAbsInf + :: (SinkableE e, HoistableE e, RenameE e, SubstE AtomSubstVal e) + => EmitsInf n + => NameHint -> Explicitness -> CType n + -> (forall l. (EmitsInf l, DExt n l) => CAtomVar l -> InfererM i l (e l)) + -> InfererM i n (Abs CBinder e n) +buildAbsInf hint expl ty cont = do + ab <- InfererM $ SubstReaderT $ ReaderT \env -> do + extendInplaceT =<< withFreshBinder hint ty \bWithTy@(b:>_) -> do + ab <- locallyMutableInplaceT (do + v <- sinkM $ binderVar bWithTy + extendInplaceTLocal (extendSynthCandidatesInf expl $ atomVarName v) do + EmitsInf <- fabricateEmitsInfEvidenceM + -- zonking is needed so that dceInfFrag works properly + runSubstReaderT (sink env) (runInfererM' $ cont v >>= zonk)) (\d e -> return $ Abs d e) - extendInplaceT =<< hoistThroughDecls frag result - - buildAbsInf hint expl ty cont = do - ab <- InfererM $ SubstReaderT $ ReaderT \env -> do - extendInplaceT =<< withFreshBinder hint ty \bWithTy@(b:>_) -> do - ab <- locallyMutableInplaceT (do - v <- sinkM $ binderVar bWithTy - extendInplaceTLocal (extendSynthCandidatesInf expl $ atomVarName v) do - EmitsInf <- fabricateEmitsInfEvidenceM - -- zonking is needed so that dceInfFrag works properly - runSubstReaderT (sink env) (runInfererM' $ cont v >>= zonk)) - (\d e -> return $ Abs d e) - ab' <- dceInfFrag ab - refreshAbs ab' \infFrag result -> do - case exchangeBs $ PairB b infFrag of - HoistSuccess (PairB infFrag' b') -> do - return $ withSubscopeDistinct b' $ - Abs infFrag' $ Abs b' result - HoistFailure vs -> do - throw EscapedNameErr $ (pprint vs) - ++ "\nFailed to exchange binders in buildAbsInf" - ++ "\n" ++ pprint infFrag - Abs b e <- return ab - ty' <- zonk ty - return $ Abs (b:>ty') e + ab' <- dceInfFrag ab + refreshAbs ab' \infFrag result -> do + case exchangeBs $ PairB b infFrag of + HoistSuccess (PairB infFrag' b') -> do + return $ withSubscopeDistinct b' $ + Abs infFrag' $ Abs b' result + HoistFailure vs -> do + throw EscapedNameErr $ (pprint vs) + ++ "\nFailed to exchange binders in buildAbsInf" + ++ "\n" ++ pprint infFrag + Abs b e <- return ab + ty' <- zonk ty + return $ Abs (b:>ty') e dceInfFrag :: (EnvReader m, EnvExtender m, Fallible1 m, RenameE e, HoistableE e) @@ -560,23 +550,19 @@ dceInfFrag ab@(Abs frag@(InfOutFrag bs _ _) e) = Abs frag' (Abs Empty e') -> return $ Abs frag' e' _ -> error "Shouldn't have any decls without `Emits` constraint" -instance Inferer InfererM where - liftSolverMInf m = InfererM $ SubstReaderT $ lift $ - liftBetweenInplaceTs (liftExcept . liftM fromJust . runSearcherM) id liftSolverOutFrag $ - runSolverM' m - {-# INLINE liftSolverMInf #-} - - addDefault v defaultType = - InfererM $ SubstReaderT $ lift $ - extendTrivialInplaceT $ InfOutFrag REmpty defaults mempty - where - defaults = case defaultType of - IntDefault -> Defaults (freeVarsE v) mempty - NatDefault -> Defaults mempty (freeVarsE v) +addDefault :: CAtomName o -> DefaultType ->InfererM i o () +addDefault v defaultType = + InfererM $ SubstReaderT $ lift $ + extendTrivialInplaceT $ InfOutFrag REmpty defaults mempty + where + defaults = case defaultType of + IntDefault -> Defaults (freeVarsE v) mempty + NatDefault -> Defaults mempty (freeVarsE v) - getDefaults = InfererM $ SubstReaderT $ lift do - InfOutMap _ _ defaults _ _ <- getOutMapInplaceT - return defaults +getDefaults :: InfererM i o (Defaults o) +getDefaults = InfererM $ SubstReaderT $ lift do + InfOutMap _ _ defaults _ _ <- getOutMapInplaceT + return defaults instance Builder CoreIR (InfererM i) where rawEmitDecl hint ann expr = do @@ -1800,7 +1786,7 @@ identifySuperclasses ab = do return $ Abs bs' e withUBinders - :: (EmitsInf o, HasNamesE e, SubstE AtomSubstVal e, SinkableE e) + :: (EmitsInf o, HasNamesE e, SubstE AtomSubstVal e) => UAnnExplBinders req i i' -> (forall o'. (EmitsInf o', DExt o o') => [CAtomVar o'] -> InfererM i' o' (e o')) -> InfererM i o (Abs (Nest (WithExpl CBinder)) e o) @@ -1815,7 +1801,7 @@ withUBinders bs cont = case bs of _ -> error "zip error" withConstraintBinders - :: (EmitsInf o, HasNamesE e, SubstE AtomSubstVal e, RenameE e, SinkableE e) + :: (EmitsInf o, HasNamesE e, SubstE AtomSubstVal e) => [UConstraint i] -> CAtomVar o -> (forall o'. (EmitsInf o', DExt o o') => InfererM i o' (e o')) @@ -1829,7 +1815,7 @@ withConstraintBinders (c:cs) v cont = do withConstraintBinders cs (sink v) cont withRoleUBinders - :: forall i i' o e req. (EmitsInf o, HasNamesE e, SubstE AtomSubstVal e, SinkableE e) + :: forall i i' o e req. (EmitsInf o, HasNamesE e, SubstE AtomSubstVal e) => UAnnExplBinders req i i' -> (forall o'. (EmitsInf o', DExt o o') => InfererM i' o' (e o')) -> InfererM i o (Abs (Nest (WithRoleExpl CBinder)) e o) @@ -1918,7 +1904,7 @@ checkULam (ULamExpr (_, lamBs) lamAppExpl lamEffs lamResultTy body) coreLamExpr piAppExpl expls' $ Abs bs' body' checkLamBinders - :: (EmitsInf o, SinkableE e, HoistableE e, SubstE AtomSubstVal e, RenameE e) + :: (EmitsInf o, HasNamesE e, SubstE AtomSubstVal e) => [Explicitness] -> Nest CBinder o any -> Nest UOptAnnBinder i i' -> (forall o'. (EmitsInf o', DExt o o') => [CAtomVar o'] -> InfererM i' o' (e o')) @@ -2056,14 +2042,14 @@ checkCasePat (WithSrcB pos pat) scrutineeTy cont = addSrcContext pos $ case pat bindLetPats ps args $ cont _ -> throw TypeErr $ "Case patterns must start with a data constructor or variant pattern" -inferParams :: (EmitsBoth o, HasNamesE e, SinkableE e, SubstE AtomSubstVal e) +inferParams :: (EmitsBoth o, HasNamesE e, SubstE AtomSubstVal e) => SourceName -> [RoleExpl] -> Abs (Nest CBinder) e o -> InfererM i o (TyConParams o, e o) inferParams sourceName roleExpls (Abs paramBs bodyTop) = do let expls = snd <$> roleExpls (params, e') <- go expls (Abs paramBs bodyTop) return (TyConParams expls params, e') where - go :: (EmitsBoth o, HasNamesE e, SinkableE e, SubstE AtomSubstVal e) + go :: (EmitsBoth o, HasNamesE e, SubstE AtomSubstVal e) => [Explicitness] -> Abs (Nest CBinder) e o -> InfererM i o ([CAtom o], e o) go [] (Abs Empty body) = return ([], body) go (expl:expls) (Abs (Nest (b:>ty) bs) body) = do @@ -2236,14 +2222,6 @@ newtype SolverSubst n = SolverSubst (M.Map (CAtomName n) (CAtom n)) instance Pretty (SolverSubst n) where pretty (SolverSubst m) = pretty $ M.toList m -class (CtxReader1 m, EnvReader m) => Solver (m::MonadKind1) where - zonk :: (SubstE AtomSubstVal e, SinkableE e) => e n -> m n (e n) - extendSolverSubst :: CAtomName n -> CAtom n -> m n () - emitSolver :: EmitsInf n => SolverBinding n -> m n (CAtomVar n) - solveLocal :: (SinkableE e, HoistableE e, RenameE e) - => (forall l. (EmitsInf l, Ext n l, Distinct l) => m l (e l)) - -> m n (e n) - type SolverOutMap = InfOutMap data SolverOutFrag (n::S) (l::S) = @@ -2285,10 +2263,7 @@ instance ExtOutMap InfOutMap SolverOutFrag where extendOutMap infOutMap outFrag = extendOutMap infOutMap $ liftSolverOutFrag outFrag -newtype SolverM (n::S) (a:: *) = - SolverM { runSolverM' :: InplaceT SolverOutMap SolverOutFrag SearcherM n a } - deriving (Functor, Applicative, Monad, MonadFail, Alternative, Searcher, - ScopeReader, Fallible, CtxReader) +type SolverM = InfererM VoidS liftSolverM :: EnvReader m => SolverM n a -> m n (Except a) liftSolverM cont = do @@ -2296,18 +2271,12 @@ liftSolverM cont = do Distinct <- getDistinct return do maybeResult <- runSearcherM $ runInplaceT (initInfOutMap env) $ - runSolverM' $ cont + runSubstReaderT (newSubst absurdNameFunction) $ runInfererM' cont case maybeResult of Nothing -> throw TypeErr "No solution" Just (_, result) -> return result {-# INLINE liftSolverM #-} -instance EnvReader SolverM where - unsafeGetEnv = SolverM do - InfOutMap env _ _ _ _ <- getOutMapInplaceT - return env - {-# INLINE unsafeGetEnv #-} - newtype SolverEmission (n::S) (l::S) = SolverEmission (BinderP (AtomNameC CoreIR) SolverBinding n l) instance ExtOutMap SolverOutMap SolverEmission where extendOutMap env (SolverEmission e) = env `extendOutMap` toEnvFrag e @@ -2315,52 +2284,16 @@ instance ExtOutFrag SolverOutFrag SolverEmission where extendOutFrag (SolverOutFrag es substs) (SolverEmission e) = withSubscopeDistinct e $ SolverOutFrag (RNest es e) (sink substs) -instance Solver SolverM where - extendSolverSubst v ty = SolverM $ - void $ extendTrivialInplaceT $ - SolverOutFrag REmpty (singleConstraint v ty) - {-# INLINE extendSolverSubst #-} - - zonk e = SolverM do - Distinct <- getDistinct - solverOutMap <- getOutMapInplaceT - return $ zonkWithOutMap solverOutMap $ sink e - {-# INLINE zonk #-} - - emitSolver binding = do - v <- SolverM $ freshExtendSubInplaceT (getNameHint @String "?") \b -> - (SolverEmission (b:>binding), binderName b) - toAtomVar v - {-# INLINE emitSolver #-} - - solveLocal cont = SolverM do - results <- locallyMutableInplaceT (do - Distinct <- getDistinct - EmitsInf <- fabricateEmitsInfEvidenceM - runSolverM' cont) (\d e -> return $ Abs d e) - Abs (SolverOutFrag unsolvedInfNames _) result <- return results - case unsolvedInfNames of - REmpty -> return result - _ -> case hoist unsolvedInfNames result of - HoistSuccess result' -> return result' - HoistFailure vs -> - throw TypeErr $ "Ambiguous type variables: " ++ pprint vs - {-# INLINE solveLocal #-} - -instance Unifier SolverM - -freshInferenceName :: (EmitsInf n, Solver m) => InfVarDesc -> Kind CoreIR n -> m n (CAtomVar n) +freshInferenceName :: EmitsInf n => InfVarDesc -> Kind CoreIR n -> InfererM i n (CAtomVar n) freshInferenceName desc k = do ctx <- srcPosCtx <$> getErrCtx emitSolver $ InfVarBound k (ctx, desc) {-# INLINE freshInferenceName #-} -freshSkolemName :: (EmitsInf n, Solver m) => Kind CoreIR n -> m n (CAtomVar n) +freshSkolemName :: EmitsInf n => Kind CoreIR n -> InfererM i n (CAtomVar n) freshSkolemName k = emitSolver $ SkolemBound k {-# INLINE freshSkolemName #-} -type Solver2 (m::MonadKind2) = forall i. Solver (m i) - emptySolverSubst :: SolverSubst n emptySolverSubst = SolverSubst mempty @@ -2428,9 +2361,7 @@ constrainEq t1 t2 = do ++ (case infVars of Empty -> "" _ -> "\n(Solving for: " ++ pprint (nestToList pprint infVars) ++ ")") - void $ addContext msg $ liftSolverMInf $ unify t1' t2' - -class (Alternative1 m, Searcher1 m, Fallible1 m, Solver m) => Unifier m + void $ addContext msg $ withSubst (newSubst absurdNameFunction) $ unify t1' t2' class (AlphaEqE e, SinkableE e, SubstE AtomSubstVal e) => Unifiable (e::E) where unifyZonked :: EmitsInf n => e n -> e n -> SolverM n () @@ -2585,19 +2516,19 @@ isSkolemName v = lookupEnv v >>= \case _ -> return False {-# INLINE isSkolemName #-} -freshType :: (EmitsInf n, Solver m) => m n (CType n) +freshType :: EmitsInf n => InfererM i n (CType n) freshType = TyVar <$> freshInferenceName MiscInfVar TyKind {-# INLINE freshType #-} -freshAtom :: (EmitsInf n, Solver m) => Type CoreIR n -> m n (CAtom n) +freshAtom :: EmitsInf n => Type CoreIR n -> InfererM i n (CAtom n) freshAtom t = Var <$> freshInferenceName MiscInfVar t {-# INLINE freshAtom #-} -freshEff :: (EmitsInf n, Solver m) => m n (EffectRow CoreIR n) +freshEff :: EmitsInf n => InfererM i n (EffectRow CoreIR n) freshEff = EffectRow mempty . EffectRowTail <$> freshInferenceName MiscInfVar EffKind {-# INLINE freshEff #-} -renameForPrinting :: (EnvReader m, HoistableE e, SinkableE e, RenameE e) +renameForPrinting :: (EnvReader m, HasNamesE e) => e n -> m n (Abs (Nest (AtomNameBinder CoreIR)) e n) renameForPrinting e = do infVars <- filterM isInferenceVar $ freeAtomVarsList e @@ -2889,7 +2820,7 @@ coreLamExpr appExpl expls ab = liftEnvReaderM do return $ CoreLamExpr (CorePiType appExpl expls bs' (EffTy effs' resultTy)) (LamExpr bs' body') withGivenBinders - :: (SinkableE e, RenameE e) => [Explicitness] -> Abs (Nest CBinder) e n + :: HasNamesE e => [Explicitness] -> Abs (Nest CBinder) e n -> (forall l. DExt n l => Nest CBinder n l -> e l -> SyntherM l a) -> SyntherM n a withGivenBinders explsTop (Abs bsTop e) contTop = @@ -3123,7 +3054,7 @@ buildBlockInf cont = do {-# INLINE buildBlockInf #-} buildBlockInfWithRecon - :: (EmitsInf n, RenameE e, HoistableE e, SinkableE e) + :: (EmitsInf n, HasNamesE e) => (forall l. (EmitsBoth l, DExt n l) => InfererM i l (e l)) -> InfererM i n (PairE CBlock (ReconAbs CoreIR e) n) buildBlockInfWithRecon cont = do diff --git a/src/lib/Name.hs b/src/lib/Name.hs index 94b14d98c..cdd5aa3e5 100644 --- a/src/lib/Name.hs +++ b/src/lib/Name.hs @@ -228,7 +228,7 @@ class SinkableB b => RenameB (b::B) where class (SinkableV v , forall c. Color c => RenameE (v c)) => RenameV (v::V) -type HasNamesE e = (RenameE e, HoistableE e) +type HasNamesE e = (RenameE e, SinkableE e, HoistableE e) type HasNamesB = RenameB instance RenameV Name @@ -3051,6 +3051,12 @@ toSubstPairs (UnsafeMakeSubst m) = data WithRenamer e i o where WithRenamer :: SubstFrag Name i i' o -> e i' -> WithRenamer e i o +instance Category UnitB where + id = UnitB + {-# INLINE id #-} + UnitB . UnitB = UnitB + {-# INLINE (.) #-} + instance Category (Nest b) where id = Empty {-# INLINE id #-} diff --git a/src/lib/Subst.hs b/src/lib/Subst.hs index 5b13ef624..bf5036346 100644 --- a/src/lib/Subst.hs +++ b/src/lib/Subst.hs @@ -302,7 +302,8 @@ instance (forall n. Monad (m n)) => Monad (SubstReaderT v m i o) where deriving instance (Monad1 m, MonadFail1 m) => MonadFail (SubstReaderT v m i o) deriving instance (Monad1 m, Alternative1 m) => Alternative (SubstReaderT v m i o) -deriving instance (Fallible1 m) => Fallible (SubstReaderT v m i o) +deriving instance Fallible1 m => Fallible (SubstReaderT v m i o) +deriving instance Searcher1 m => Searcher (SubstReaderT v m i o) deriving instance Catchable1 m => Catchable (SubstReaderT v m i o) deriving instance CtxReader1 m => CtxReader (SubstReaderT v m i o) From d3705b03bad0880ed4131042beee002f9d321113 Mon Sep 17 00:00:00 2001 From: Dougal Date: Thu, 17 Aug 2023 23:29:19 -0400 Subject: [PATCH 02/41] Defang type inference. So far the prelude and some of the libraries work but I haven't updated the tests or the examples. I'm going to keep working in this state as I add expressions-in-types. My goal here is to make the type inference pass simpler so that we can modify it more easily, particularly adding expressions in types. I also want the type inference process to be more legible to users. It should be easy to reason about what type inference can do and we should be able to provide feedback about what it did. The main principle I'm trying out is that type inference should be (1) local and (2) feed-forward. "Local" means I don't want inference variables spanning multiple expressions. "Feed-forward" means that type information should flow in the same direction as ordinary data - you get the type of an expression, given the types of its free variables, by looking at the types of its subexpressions. You don't figure out the type of a binder based on its occurrences. Here are the main changes to type inference. "Fully known" means "contains no inference variables, including in the definitions and types of its free variables". 1. All variables have fully-known types, and let-bound variables have fully-known definitions too. 2. Each UExpr is either processed "top-down" (i.e. checking it against a fully-known type) or "bottom-up" (inferring a fully-known type). There's no "partial checking" where you check a term against, say `?a -> ?b` where `?a` and `?b` are inference variables. 3. Inference variables only exist for the duration of a single (n-ary) function application. 4. Lambda-like binders (including `for` binders) must either be explicitly annotated or their type must be evident from the top-down type we're checking against. The latter case lets us still write things like `each xs \x. x + x` without annotating the `x` binder. The `\x` behaves more like a `let` binder than a lambda binder in this context. This is also the reason we can't adopt truly brutalist type inference and do *everything* bottom-up. Some things are better now: * ~1000 loc removed * "cheap reduction" is now well-typed because there are no inference variables in either the term we're reducing or its environment. * We now tend to ask "are these types equal?" rather than "what would these variables need to be for these types to be equal?". The latter question doesn't work so well when we move beyond syntactic equality and start talking about equivalence relations. * We can now reliably use features that require fully-known types, like dot method resolution. Other things look worse but I think they're actually better: * To support inference, functions should appear last in an argument list. For example, we should prefer `each` to its flipped version, `map` (I actually deleted `map` from the prelude). This is so that we see the table argument first, which tells us the element type, which tells us the type of the binder of the per-element function. But function arguments should go last anyway, for syntactic reasons. And even before this change, if you put a function argument first, it might type check but you wouldn't be able to use type-dispatch features (like dot method resolution). This just forces us to be more consistent. * No more defaults. I got rid of these because they require non-local inference. But I think they were confusing anyway for exactly that reason. Now you have to be explicit, writing `n = 10 :: Nat` or `n : Nat = 10` (we should standardize on one or the other) instead of `n = 10`. * `for i.` now usually requries an annotated `i`. But I think it ends up being more readable. And for the very common case where you're just using `i` to index into a single table, it's better to use `each` anyway. Other things are actually bad * `Nothing`. Sometimes we now have to write `Nothing :: Maybe Nat`. (We should be able to write it as `Nothing(a=Nat)` but that doesn't work yet.) * `0 + x` will fail now. You can write `x + 0` just fine, because the type of `0` is obtained from the type of `x` but if you want to write it in the other order you have to write `(0::Nat) + x` or something. It's frustrating because `+` is such an obviously symmetric operator. There's a similar thing with `case`, where the order of the cases influences whether inference succeeds or fails (it wasn't perfect before either - the order could still influence whether dot-notation would work). I'm not sure how to fix these. "Backwards" type inference is just really useful in handling polymorphic constants like `0` and `Nothing`. (In my opinion this is the *only* place it's useful.) I'm tempted to just say that we should get in the habit of annotating these constants or supplying explicit type arguments. Maybe `Nothing` should take its type argument explicitly so by default so you write `Nothing(Nat)` or `Nothing(_)` if you're sure you're in top-down checking mode. Then you can at least avoid the redundant `Maybe` in `Nothing :: Maybe Nat`. --- examples/raytrace.dx | 35 +- lib/diagram.dx | 37 +- lib/plot.dx | 46 +- lib/png.dx | 49 +- lib/prelude.dx | 536 +++--- src/lib/AbstractSyntax.hs | 200 ++- src/lib/Algebra.hs | 10 +- src/lib/Builder.hs | 28 +- src/lib/CheapReduction.hs | 15 +- src/lib/CheckType.hs | 3 - src/lib/ConcreteSyntax.hs | 23 +- src/lib/Core.hs | 17 +- src/lib/Err.hs | 217 +-- src/lib/Imp.hs | 9 +- src/lib/ImpToLLVM.hs | 2 +- src/lib/Inference.hs | 3399 ++++++++++++++----------------------- src/lib/Inference.hs-boot | 14 - src/lib/MTL1.hs | 125 +- src/lib/Name.hs | 76 +- src/lib/PPrint.hs | 41 +- src/lib/QueryType.hs | 4 +- src/lib/QueryTypePure.hs | 2 +- src/lib/Simplify.hs | 2 - src/lib/SourceRename.hs | 59 +- src/lib/Subst.hs | 71 +- src/lib/TopLevel.hs | 9 +- src/lib/Types/Core.hs | 35 +- src/lib/Types/Source.hs | 117 +- src/lib/Util.hs | 6 +- src/lib/Vectorize.hs | 27 +- 30 files changed, 2118 insertions(+), 3096 deletions(-) delete mode 100644 src/lib/Inference.hs-boot diff --git a/examples/raytrace.dx b/examples/raytrace.dx index eccfa5695..993d6203c 100644 --- a/examples/raytrace.dx +++ b/examples/raytrace.dx @@ -15,7 +15,7 @@ def Vec(n:Nat) -> Type = Fin n => Float def Mat(n:Nat, m:Nat) -> Type = Fin n => Fin m => Float def relu(x:Float) -> Float = max x 0.0 -def length(x: d=>Float) -> Float given (d|Ix) = sqrt $ sum for i. sq x[i] +def length(x: d=>Float) -> Float given (d|Ix) = sqrt $ sum for i:d. sq x[i] -- TODO: make a newtype for normal vectors def normalize(x: d=>Float) -> d=>Float given (d|Ix) = x / (length x) def directionAndLength(x: d=>Float) -> (d=>Float, Float) given (d|Ix) = @@ -68,7 +68,7 @@ def rotateZ(p:Vec 3, angle:Angle) -> Vec 3 = [c*px - s*py, s*px+c*py, pz] def sampleCosineWeightedHemisphere(normal: Vec 3, k:Key) -> Vec 3 = - [k1, k2] = split_key k + [k1, k2] = split_key(n=2, k) u1 = rand k1 u2 = rand k2 uu = normalize $ cross normal [0.0, 1.1, 1.1] @@ -152,21 +152,21 @@ def sdObject(pos:Position, obj:Object) -> Distance = Wall(nor, d) -> d + dot nor pos Block(blockPos, halfWidths, angle) -> pos' = rotateY (pos - blockPos) angle - length $ for i. max ((abs pos'[i]) - halfWidths[i]) 0.0 + length $ for i:(Fin 3). max ((abs pos'[i]) - halfWidths[i]) 0.0 Sphere(spherePos, r) -> pos' = pos - spherePos max (length pos' - r) 0.0 Light(squarePos, hw, _) -> pos' = pos - squarePos halfWidths = [hw, 0.01, hw] - length $ for i. max ((abs pos'[i]) - halfWidths[i]) 0.0 + length $ for i:(Fin 3). max ((abs pos'[i]) - halfWidths[i]) 0.0 def sdScene(scene:Scene n, pos:Position) -> (Object, Distance) given (n|Ix) = - (i, d) = minimum_by snd $ for i. (i, sdObject pos scene.objects[i]) + (i, d) = minimum_by(for i:n. (i, sdObject pos scene.objects[i]), snd) (scene.objects[i], d) def calcNormal(obj:Object, pos:Position) -> Direction = - normalize (grad (\pos. sdObject pos obj) pos) + grad(\p:Position. sdObject(p, obj)) pos | normalize data RayMarchResult = -- incident ray, surface normal, surface properties @@ -176,7 +176,7 @@ data RayMarchResult = HitNothing def raymarch(scene:Scene n, ray:Ray) -> RayMarchResult given (n|Ix) = - maxIters = 100 + maxIters : Nat = 100 tol = 0.01 startLength = 10.0 * tol -- trying to escape the current surface with_state (10.0 * tol) \rayLength. @@ -209,7 +209,7 @@ def rayDirectRadiance(scene:Scene n, ray:Ray) -> Radiance given (n|Ix) = HitObj(_, _) -> zero def sampleSquare(hw:Float, k:Key) -> Position = - [kx, kz] = split_key k + [kx, kz] : Fin 2 => Key = split_key k x = randuniform (- hw) hw kx z = randuniform (- hw) hw kz [x, 0.0, z] @@ -220,7 +220,7 @@ def sampleLightRadiance( inRay:Ray, k:Key) -> Radiance given (n|Ix) = yield_accum (AddMonoid Float) \radiance. - for i. case scene.objects[i] of + each scene.objects \obj. case obj of PassiveObject(_, _) -> () Light(lightPos, hw, _) -> (dirToLight, distToLight) = directionAndLength $ @@ -244,7 +244,7 @@ def trace(params:Params, scene:Scene n, initRay:Ray, k:Key) -> Color given (n|Ix if i == 0 then radiance += intensity -- TODO: scale etc Done () HitObj(incidentRay, osurf) -> - [k1, k2] = split_key $ hash k i + [k1, k2] = split_key(n=2, hash k i) lightRadiance = sampleLightRadiance scene osurf incidentRay k1 ray := sampleReflection osurf incidentRay k2 filter := surfaceFilter (get filter) osurf.surface @@ -265,27 +265,24 @@ def cameraRays(n:Nat, camera:Camera) -> Fin n => Fin n => ((Key) -> Ray) = pixHalfWidth = halfWidth / n_to_f n ys = reverse $ linspace (Fin n) (neg halfWidth) halfWidth xs = linspace (Fin n) (neg halfWidth) halfWidth - for i j. \key. - [kx, ky] = split_key key + for i:(Fin n) j:(Fin n). \key. + [kx, ky] = split_key(n=2, key) x = xs[j] + randuniform (-pixHalfWidth) pixHalfWidth kx y = ys[i] + randuniform (-pixHalfWidth) pixHalfWidth ky Ray(camera.pos, normalize [x, y, neg camera.sensorDist]) def takePicture(params:Params, scene:Scene m, camera:Camera) -> Image given (m|Ix) = - n = camera.numPix - rays = cameraRays n camera + rays = cameraRays camera.numPix camera rootKey = new_key 0 - image = for i j. + image = for i:(Fin camera.numPix) j:(Fin camera.numPix). pixKey = if params.shareSeed then rootKey else ixkey (ixkey rootKey i) j def sampleRayColor(k:Key) -> Color = - [k1, k2] = split_key k + [k1, k2] = split_key(n=2, k) trace params scene (rays[i,j] k1) k2 sampleAveraged sampleRayColor params.numSamples pixKey - MkImage _ _ $ image / mean (for ixs. - (i,j,k) = ixs - image[i,j,k]) + MkImage _ _ $ image / mean(flatten3D(image)) '## Define the scene and render it diff --git a/lib/diagram.dx b/lib/diagram.dx index 66beff1ff..46fffd400 100644 --- a/lib/diagram.dx +++ b/lib/diagram.dx @@ -35,15 +35,16 @@ struct GeomStyle = default_geom_style = GeomStyle Nothing (Just black) 1 -- TODO: consider sharing attributes among a set of objects for efficiency +Object : Type = (GeomStyle, Point, Geom) struct Diagram = - val : (List (GeomStyle, Point, Geom)) + val : (List Object) instance Monoid(Diagram) mempty = Diagram mempty def (<>)(d1, d2) = Diagram $ d1.val <> d2.val def concat_diagrams(diagrams:n=>Diagram) -> Diagram given (n|Ix) = - Diagram $ concat for i. diagrams[i].val + Diagram $ concat $ each diagrams \d. d.val -- TODO: arbitrary affine transformations. Our current representation of -- rectangles and circles means we can only do scale/flip/rotate90. @@ -54,8 +55,8 @@ def apply_transformation( d:Diagram ) -> Diagram = AsList(_, objs) = d.val - Diagram $ to_list for i. - (attr, p, geom) = objs[i] + Diagram $ to_list $ each objs \obj. + (attr, p, geom) = obj (attr, transformPoint p, transformGeom geom) def flip_y(d:Diagram) -> Diagram = @@ -92,8 +93,8 @@ def text(x:String) -> Diagram = singleton_default $ Text x def update_geom(update: (GeomStyle) -> GeomStyle, d:Diagram) -> Diagram = AsList(_, objs) = d.val - Diagram $ to_list for i. - ( attr, point, geoms) = objs[i] + Diagram $ to_list $ each objs \obj. + ( attr, point, geoms) = obj (update attr, point, geoms) -- TODO: these would be better if we had field-access-based ref projections, so we could @@ -149,7 +150,7 @@ def (<=>)(attr:String, val:b) -> String given (b|Show) = attr <.> "=" <.> quote (show val) def html_color(cs:HtmlColor) -> String = - "#" <> (concat $ for i. showHex cs[i]) + "#" <> (concat $ each cs showHex) def optional_html_color(c: Maybe HtmlColor) -> String = case c of @@ -166,7 +167,7 @@ def attr_string(attr:GeomStyle) -> String = def render_geom(attr:GeomStyle, p:Point, geom:Geom) -> String = -- For things that are solid. SVG says they have fill=stroke. solidAttr = GeomStyle attr.strokeColor attr.strokeColor attr.strokeWidth - groupEle = \attr s. tag_brackets_attr "g" (attr_string attr) s + groupEle = \attr:GeomStyle s:String. tag_brackets_attr "g" (attr_string attr) s case geom of PointGeom -> groupEle solidAttr $ self_closing_brackets $ @@ -188,7 +189,7 @@ def render_geom(attr:GeomStyle, p:Point, geom:Geom) -> String = "x" <=> (p.x - (w/2.0)) <.> "y" <=> (p.y - (h/2.0))) Text content -> - textEle = \s. tag_brackets_attr("text", + textEle = \s:String. tag_brackets_attr("text", ("x" <=> p.x <.> "y" <=> p.y <.> "text-anchor" <=> "middle" <.> -- horizontal center @@ -200,8 +201,8 @@ BoundingBox : Type = (Point, Point) @noinline def compute_bounds(d:Diagram) -> BoundingBox = - computeSubBound = \sel op. - \triple. + computeSubBound = \sel:((Point) -> Float) op:((Float) -> Float). + \triple:Object. (_, p, geom) = triple sel p + case geom of PointGeom -> 0.0 @@ -213,12 +214,12 @@ def compute_bounds(d:Diagram) -> BoundingBox = AsList(_, objs) = d.val ( Point( - minimum $ map (computeSubBound (\p. p.x) neg) objs, - minimum $ map (computeSubBound (\p. p.y) neg) objs + minimum $ each objs (computeSubBound (\p. p.x) neg), + minimum $ each objs (computeSubBound (\p. p.y) neg) ), Point( - maximum $ map (computeSubBound (\p. p.x) id) objs, - maximum $ map (computeSubBound (\p. p.y) id) objs + maximum $ each objs (computeSubBound (\p. p.x) id), + maximum $ each objs (computeSubBound (\p. p.y) id) ) ) @@ -235,11 +236,11 @@ def render_svg(d:Diagram, bounds:BoundingBox) -> String = <+> "height" <=> imgHeight <+> "viewBox" <=> (imgXMin <+> imgYMin <+> imgWidth <+> imgHeight)) tag_brackets_attr "svg" svgAttrStr $ - concat for i. - (attr, pos, geom) = objs[i] + concat $ each objs \obj. + (attr, pos, geom) = obj render_geom attr pos geom -render_scaled_svg = \d. render_svg d (compute_bounds d) +render_scaled_svg = \d:Diagram. render_svg d (compute_bounds d) '## Derived convenience methods and combinators diff --git a/lib/plot.dx b/lib/plot.dx index f5f92a028..189ca543b 100644 --- a/lib/plot.dx +++ b/lib/plot.dx @@ -15,6 +15,7 @@ struct ScaledData(n|Ix, a:Type) = scale : Scale a dat : n => a +-- TODO: bundle up the type params into a triple of types struct Plot(n|Ix, a:Type, b:Type, c:Type) = xs : ScaledData n a ys : ScaledData n b @@ -22,7 +23,7 @@ struct Plot(n|Ix, a:Type, b:Type, c:Type) = Color : Type = Fin 3 => Float -def apply_scale(s:Scale a, x:a) -> Maybe Float given (a) = s.mapping x +def apply_scale(s:Scale a, x:a) -> Maybe Float given (a:Type) = s.mapping x unit_type_scale : Scale(()) = Scale (\_. Just 0.0) (AsList _ [Singleton 0.0]) @@ -33,12 +34,12 @@ def project_unit_interval(x:Float) -> Maybe Float = unit_interval_scale : Scale Float = Scale (project_unit_interval) (AsList _ [Interval 0.0 1.0]) -def map_scale(s:Scale a, f: (b) -> a) -> Scale b given (a, b) = Scale (\x. s.mapping (f x)) s.range +def map_scale(s:Scale a, f: (b) -> a) -> Scale b given (a:Type, b:Type) = Scale (\x. s.mapping (f x)) s.range def float_scale(xmin:Float, xmax:Float) -> Scale Float = map_scale unit_interval_scale (\x. (x - xmin) / (xmax - xmin)) -def get_scaled(sd:ScaledData n a, i:n) -> Maybe Float given (n|Ix, a) = +def get_scaled(sd:ScaledData n a, i:n) -> Maybe Float given (n|Ix, a:Type) = apply_scale sd.scale sd.dat[i] low_color = [1.0, 0.5, 0.0] @@ -54,8 +55,8 @@ def make_rgb_color(c: Color) -> HtmlColor = def color_scale(x:Float) -> HtmlColor = make_rgb_color $ interpolate low_color high_color x -def plot_to_diagram(plot:Plot n a b c) -> Diagram given (a, b, c, n|Ix) = - points = concat_diagrams for i. +def plot_to_diagram(plot:Plot n a b c) -> Diagram given (a:Type, b:Type, c:Type, n|Ix) = + points = concat_diagrams for i:n. x = get_scaled plot.xs i y = get_scaled plot.ys i c = get_scaled plot.cs i @@ -70,16 +71,17 @@ def plot_to_diagram(plot:Plot n a b c) -> Diagram given (a, b, c, n|Ix) = boundingBox = move_xy(rect 1.0 1.0, 0.5, 0.5) boundingBox <> points -def show_plot(plot:Plot n a b c) -> String given (a, b, c, n|Ix) = +def show_plot(plot:Plot n a b c) -> String given (a:Type, b:Type, c:Type, n|Ix) = render_svg (plot_to_diagram plot) (Point 0.0 0.0, Point 1.0 1.0) -def blank_data() ->> ScaledData n () given (n|Ix) = - ScaledData unit_type_scale (for i. ()) +def blank_data(n|Ix) -> ScaledData n () = + ScaledData unit_type_scale (for i:n. ()) -def blank_plot() ->> Plot n () () () given (n|Ix) = - Plot blank_data blank_data blank_data +def blank_plot(n|Ix) -> Plot n () () () = + -- TODO: figure out why we need the annotations here. Top-down inference should work. + Plot(blank_data(n), blank_data(n), blank_data(n)) --- -- TODO: generalize beyond Float with a type class for auto scaling +-- TODO: generalize beyond Float with a type class for auto scaling def auto_scale(xs:n=>Float) -> ScaledData n Float given (n|Ix) = max = maximum xs min = minimum xs @@ -88,29 +90,29 @@ def auto_scale(xs:n=>Float) -> ScaledData n Float given (n|Ix) = padding = maximum [space, max * 0.001, 0.000001] ScaledData (float_scale (min - padding) (max + padding)) xs -def set_x_data(plot:Plot n a b c, xs:ScaledData n new) -> Plot n new b c given (n|Ix, a, b, c, new) = +def set_x_data(plot:Plot n a b c, xs:ScaledData n new) -> Plot n new b c given (n|Ix, a:Type, b:Type, c:Type, new:Type) = -- We can't use `setAt` here because we're changing the type Plot xs plot.ys plot.cs -def set_y_data(plot:Plot n a b c, ys:ScaledData n new) -> Plot n a new c given (n|Ix, a, b, c, new) = +def set_y_data(plot:Plot n a b c, ys:ScaledData n new) -> Plot n a new c given (n|Ix, a:Type, b:Type, c:Type, new:Type) = Plot plot.xs ys plot.cs -def set_c_data(plot:Plot n a b c, cs:ScaledData n new) -> Plot n a b new given (n|Ix, a, b, c, new) = +def set_c_data(plot:Plot n a b c, cs:ScaledData n new) -> Plot n a b new given (n|Ix, a:Type, b:Type, c:Type, new:Type) = Plot plot.xs plot.ys cs def xy_plot(xs:n=>Float, ys:n=>Float) -> Plot n Float Float () given (n|Ix) = - blank_plot | + blank_plot(n) | set_x_data (auto_scale xs) | set_y_data (auto_scale ys) def xyc_plot(xs:n=>Float, ys:n=>Float, cs:n=>Float) -> Plot n Float Float Float given (n|Ix) = - blank_plot | + blank_plot(n) | set_x_data (auto_scale xs) | set_y_data (auto_scale ys) | set_c_data (auto_scale cs) def y_plot(ys:n=>Float) -> Plot n Float Float () given (n|Ix) = - xs = for i. n_to_f $ ordinal i + xs = for i:n. n_to_f $ ordinal i xy_plot xs ys -- xs = linspace (Fin 100) 0. 1.0 @@ -120,14 +122,10 @@ def y_plot(ys:n=>Float) -> Plot n Float Float () given (n|Ix) = -- TODO: scales def matshow(img:n=>m=>Float) -> Html given (n|Ix, m|Ix) = - low = minimum $ for p. - (i, j) = p - img[i,j] - high = maximum $ for p. - (i, j) = p - img[i,j] + low = minimum $ flatten2D(img) + high = maximum $ flatten2D(img) range = high - low - img_to_html $ make_png for i j. + img_to_html $ make_png for i:n j:m. x = if range == 0.0 then float_to_8bit $ 0.5 else float_to_8bit $ (img[i,j] - low) / range diff --git a/lib/png.dx b/lib/png.dx index faa6b7db7..3937b0350 100644 --- a/lib/png.dx +++ b/lib/png.dx @@ -28,7 +28,7 @@ Base64 = Byte -- first two bits should be zero -- This could go in the prelude, or in a library of array-dicing functions. -- An explicit "view" builder would be good here, to avoid copies def get_chunks(chunkSize:Nat, padVal:a, xs:n=>a) - -> List (Fin chunkSize => a) given (n|Ix, a) = + -> List (Fin chunkSize => a) given (n|Ix, a:Type) = numChunks = idiv_ceil (size n) chunkSize paddedSize = numChunks * chunkSize xsPadded = pad_to (Fin paddedSize) padVal xs @@ -44,29 +44,28 @@ def base64s_to_bytes(chunk : Fin 4 => Base64) -> Fin 3 => Byte = def bytes_to_base64s(chunk : Fin 3 => Byte) -> Fin 4 => Base64 = [a, b, c] = chunk -- '?' is 00111111 - map (\x. x .&. '?') $ - [ a .>>. 2 - , (a .<<. 4) .|. (b .>>. 4) - , (b .<<. 2) .|. (c .>>. 6) - , c ] + tmp = [ a .>>. 2 + , (a .<<. 4) .|. (b .>>. 4) + , (b .<<. 2) .|. (c .>>. 6) + , c ] + each tmp \x. x .&. '?' def base64_to_ascii(x:Base64) -> Char = encoding_table[from_ordinal (w8_to_n x)] def encode_chunk(chunk : Fin 3 => Char) -> Fin 4 => Char = - map base64_to_ascii $ bytes_to_base64s chunk + each (bytes_to_base64s chunk) base64_to_ascii -- TODO: the `AsList` unpacking is very tedious. Daniel's change will help def base64_encode(s:String) -> String = AsList(n, cs) = s AsList(numChunks, chunks) = get_chunks 3 '\NUL' cs - encodedChunks = map encode_chunk chunks - flattened = for pair. - (i, j) = pair - encodedChunks[i, j] + encodedChunks = each chunks encode_chunk + FlatIxType : Type = (Fin numChunks, Fin 4) + flattened = flatten2D(encodedChunks) padChars = rem (unsafe_nat_diff 3 (rem n 3)) 3 validOutputChars = unsafe_nat_diff (numChunks * 4) padChars - to_list for i. case ordinal i < validOutputChars of + to_list for i:FlatIxType. case ordinal i < validOutputChars of True -> flattened[i] False -> '=' @@ -74,7 +73,7 @@ def ascii_to_base64(c:Char) -> Maybe Base64 = decoding_table[from_ordinal (w8_to_n c)] def decode_chunk(chunk : Fin 4 => Char) -> Maybe (Fin 3 => Char) = - case seq_maybes $ map ascii_to_base64 chunk of + case chunk | each(ascii_to_base64) | seq_maybes of Nothing -> Nothing Just base64s -> Just $ base64s_to_bytes base64s @@ -87,16 +86,14 @@ def replace(pair:(a,a), x:a) -> a given (a|Eq) = def base64_decode(s:String) -> Maybe String = AsList(n, cs) = s - numValidInputChars = sum for i. b_to_n $ cs[i] /= '=' + numValidInputChars = sum for i:(Fin n). b_to_n $ cs[i] /= '=' numValidOutputChars = idiv (numValidInputChars * 3) 4 - csZeroed = map (\x. replace(('=', 'A'), x)) cs -- swap padding char with 'zero' char + csZeroed = each cs \c. replace(('=', 'A'), c) -- swap padding char with 'zero' char AsList(_, chunks) = get_chunks 4 '\NUL' csZeroed - case seq_maybes $ map decode_chunk chunks of + case chunks | each(decode_chunk) | seq_maybes of Nothing -> Nothing Just decodedChunks -> - resultPadded = for pair. - (i, j) = pair - decodedChunks[i, j] + resultPadded = flatten2D(decodedChunks) Just $ to_list $ slice resultPadded 0 (Fin numValidOutputChars) '## PNG FFI @@ -108,7 +105,7 @@ Gif : Type = String foreign "encodePNG" encodePNG : (RawPtr, Word32, Word32) -> {IO} (Word32, RawPtr) def make_png(img:n=>m=>(Fin 3)=>Word8) -> Png given (n|Ix, m|Ix) = unsafe_io \. - AsList(_, imgFlat) = to_list for triple. + AsList(_, imgFlat) = to_list for triple:(n,(m,Fin 3)). (i, (j, k)) = triple img[i, j, k] with_table_ptr imgFlat \ptr. @@ -116,13 +113,13 @@ def make_png(img:n=>m=>(Fin 3)=>Word8) -> Png given (n|Ix, m|Ix) = unsafe_io \. (sz, ptr') = encodePNG rawPtr (nat_to_rep $ size m) (nat_to_rep $ size n) AsList((rep_to_nat sz), table_from_ptr(Ptr(ptr'))) -def pngs_to_gif(delay:Int, pngs:t=>Png) -> Gif given (t|Ix) = unsafe_io \. - with_temp_files \pngFiles. - for i. write_file pngFiles[i] pngs[i] +def pngs_to_gif(pngs:t=>Png, delay:Int) -> Gif given (t|Ix) = unsafe_io \. + with_temp_files(t) \pngFiles. + for i:t. write_file pngFiles[i] pngs[i] with_temp_file \gifFile. shell_out $ "convert" <> " -delay " <> show delay <> " " <> - concat (for i. "png:" <> pngFiles[i] <> " ") <> + concat (for i:t. "png:" <> pngFiles[i] <> " ") <> "gif:" <> gifFile read_file gifFile @@ -135,7 +132,7 @@ def float_to_8bit(x:Float) -> Word8 = n_to_w8 $ f_to_n $ 255.0 * clip (0.0, 1.0) x def img_to_png(img:n=>m=>(Fin 3)=>Float) -> Png given (n|Ix, m|Ix) = - make_png for i j k. float_to_8bit img[i, j, k] + make_png for i:n j:m k:(Fin 3). float_to_8bit img[i, j, k] '## API entry point @@ -143,4 +140,4 @@ def imshow(img:n=>m=>(Fin 3)=>Float) -> Html given (n|Ix, m|Ix) = img_to_html $ img_to_png img def imseqshow(imgs:t=>n=>m=>(Fin 3)=>Float) -> Html given (t|Ix, n|Ix, m|Ix) = - img_to_html $ pngs_to_gif 50 $ map img_to_png imgs + imgs | each(img_to_png) | pngs_to_gif(50) | img_to_html diff --git a/lib/prelude.dx b/lib/prelude.dx index c2baec9e6..243e09e03 100644 --- a/lib/prelude.dx +++ b/lib/prelude.dx @@ -32,7 +32,7 @@ interface Data(a:Type) '### Casting -def internal_cast(x:from) -> to given (from, to) = +def internal_cast(x:from) -> to given (from:Type, to:Type) = %cast(to, x) def unsafe_coerce(x:from) -> to given (from|Data, to|Data) = %unsafeCoerce(to, x) @@ -266,7 +266,7 @@ instance Mul(()) '#### Integral Integer-like things. -interface Integral(a) +interface Integral(a:Type) idiv : (a,a)->a rem : (a,a)->a @@ -298,7 +298,7 @@ instance Integral(Nat) Rational-like things. Includes floating point and two field rational representations. -interface Fractional(a) +interface Fractional(a:Type) divide : (a, a) -> a instance Fractional(Float64) @@ -314,7 +314,7 @@ interface Ix(n|Data) ordinal : (n) -> Nat unsafe_from_ordinal : (Nat) -> n -def size(n|Ix) -> Nat = size'(n=n) +def size(n:Type|Ix) -> Nat = size'(n=n) def Fin(n:Nat) -> Type = %Fin(n) @@ -323,42 +323,42 @@ def (-|)(x: Nat, y:Nat) -> Nat = x' = nat_to_rep x y' = nat_to_rep y requires_clamp = %ilt(x', y') - rep_to_nat %select(requires_clamp, 0, (%isub(x', y'))) + rep_to_nat %select(requires_clamp, 0::NatRep, (%isub(x', y'))) def unsafe_nat_diff(x:Nat, y:Nat) -> Nat = x' = nat_to_rep x y' = nat_to_rep y rep_to_nat %isub(x', y') --- `(i..)` parses as `RangeFrom _ i` +-- `(i..)` parses as `RangeFrom(i)` -- TODO: need to a way to indicate constructor as private -struct RangeFrom(q:Type, i:q) = val : Nat +struct RangeFrom(i:q) given (q:Type) = val : Nat --- `(i<..)` parses as `RangeFromExc _ i` -struct RangeFromExc(q:Type, i:q) = val : Nat +-- `(i<..)` parses as `RangeFromExc i` +struct RangeFromExc(i:q) given (q:Type) = val : Nat --- `(..i)` parses as `RangeTo _ i` -struct RangeTo(q:Type, i:q) = val : Nat +-- `(..i)` parses as `RangeTo i` +struct RangeTo(i:q) given (q:Type) = val : Nat --- `(.. n=>Nat = for i. ordinal i +def iota(n:Type|Ix) -> n=>Nat = for i. ordinal i '## Arithmetic instances for table types @@ -408,10 +408,10 @@ instance Mul(n=>a) given (a|Mul, n|Ix) '## Basic polymorphic functions and types -def fst(pair:(a, b)) -> a given (a, b) = pair.0 -def snd(pair:(a, b)) -> b given (a, b) = pair.1 +def fst(pair:(a, b)) -> a given (a:Type, b:Type) = pair.0 +def snd(pair:(a, b)) -> b given (a:Type, b:Type) = pair.1 -def swap(pair:(a, b)) -> (b, a) given (a, b) = +def swap(pair:(a, b)) -> (b, a) given (a:Type, b:Type) = (x, y) = pair (y, x) @@ -443,7 +443,7 @@ instance Ix((a, b, c)) given (a|Ix, b|Ix, c|Ix) (i, j, k) = tup ordinal((i,(j,k))) def unsafe_from_ordinal(o) = - (i, (j, k)) = unsafe_from_ordinal o + (i, (j, k)) = unsafe_from_ordinal(n=(a,(b,c))::Type, o) (i, j, k) instance Ix((a, b, c, d)) given (a|Ix, b|Ix, c|Ix, d|Ix) @@ -452,7 +452,7 @@ instance Ix((a, b, c, d)) given (a|Ix, b|Ix, c|Ix, d|Ix) (i, j, k, m) = tup ordinal((i,(j,(k,m)))) def unsafe_from_ordinal(o) = - (i, (j, (k, m))) = unsafe_from_ordinal o + (i, (j, (k, m))) = unsafe_from_ordinal(n=(a,(b,(c,d)))::Type, o) (i, j, k, m) '## Vector spaces @@ -519,7 +519,7 @@ TODO: move these with the others? -- Can't use `%select` because it lowers to `ISelect`, which requires -- `a` to be a `BaseTy`. -def select(p:Bool, x:a, y:a) -> a given (a) = +def select(p:Bool, x:a, y:a) -> a given (a:Type) = case p of True -> x False -> y @@ -547,14 +547,14 @@ data Maybe(a:Type) = Nothing Just(a) -def is_nothing(x:Maybe a) -> Bool given (a) = +def is_nothing(x:Maybe a) -> Bool given (a:Type) = case x of Nothing -> True Just(_) -> False -def is_just(x:Maybe a) -> Bool given (a) = not $ is_nothing x +def is_just(x:Maybe a) -> Bool given (a:Type) = not $ is_nothing x -def maybe(d:b, f:(a)->b, x:Maybe a) -> b given (a, b) = +def maybe(d:b, f:(a)->b, x:Maybe a) -> b given (a:Type, b:Type) = case x of Nothing -> d Just(x') -> f x' @@ -593,7 +593,6 @@ def i_to_n(x:Int) -> Maybe Nat = '### Monoid A [monoid](https://en.wikipedia.org/wiki/Monoid) is a things that have an associative binary operator and an identity element. -This is a very useful and general calls of things. It includes: - Addition and Multiplication of Numbers - Boolean Logic @@ -626,55 +625,55 @@ named-instance MulMonoid(a|Mul) -> Monoid(a) '## Effects -def Ref(r:Heap, a|Data) -> Type = %Ref(r, a) -def get(ref:Ref h s) -> {State h} s given (h, s) = %get(ref) -def (:=)(ref:Ref h s, x:s) -> {State h} () given (h, s) = %put(ref, x) +def Ref(r:Heap, a:Type|Data) -> Type = %Ref(r, a) +def get(ref:Ref h s) -> {State h} s given (h:Heap, s|Data) = %get(ref) +def (:=)(ref:Ref h s, x:s) -> {State h} () given (h:Heap, s|Data) = %put(ref, x) -def ask(ref:Ref h r) -> {Read h} r given (h, r) = %ask(ref) +def ask(ref:Ref h r) -> {Read h} r given (h:Heap, r|Data) = %ask(ref) data AccumMonoidData(h:Heap, w:Type) = UnsafeMkAccumMonoidData(b:Type, Monoid b) -interface AccumMonoid(h:Heap, w) +interface AccumMonoid(h:Heap, w:Type) getAccumMonoidData : AccumMonoidData(h, w) -instance AccumMonoid(h, n=>w) given (n|Ix, h, w) (am:AccumMonoid(h, w)) +instance AccumMonoid(h, n=>w) given (n|Ix, h:Heap, w:Type) (am:AccumMonoid(h, w)) getAccumMonoidData = UnsafeMkAccumMonoidData(b, bm) = %applyMethod0(am) UnsafeMkAccumMonoidData(b, bm) def (+=)(ref:Ref h w, x:w) -> {Accum h} () - given (h, w) (am:AccumMonoid(h, w)) = + given (h:Heap, w|Data) (am:AccumMonoid(h, w)) = UnsafeMkAccumMonoidData(b, bm) = %applyMethod0(am) empty = %applyMethod0(bm) %mextend(ref, empty, \x:b y:b. %applyMethod1(bm, x, y), x) -def (!)(ref: Ref h (n=>a), i:n) -> Ref h a given (n|Ix, a|Data, h) = %indexRef(ref, i) -def fst_ref(ref: Ref h (a,b)) -> Ref h a given (b, a|Data, h) = ref.0 -def snd_ref(ref: Ref h (a,b)) -> Ref h b given (a, b|Data, h) = ref.1 +def (!)(ref: Ref h (n=>a), i:n) -> Ref h a given (n|Ix, a|Data, h:Heap) = %indexRef(ref, i) +def fst_ref(ref: Ref h (a,b)) -> Ref h a given (b|Data, a|Data, h:Heap) = ref.0 +def snd_ref(ref: Ref h (a,b)) -> Ref h b given (a|Data, b|Data, h:Heap) = ref.1 def run_reader( init:r, - action:(given (h), Ref h r) -> {Read h|eff} a - ) -> {|eff} a given (r|Data, a, eff) = + action:(given (h:Heap), Ref h r) -> {Read h|eff} a + ) -> {|eff} a given (r|Data, a:Type, eff:Effects) = def explicitAction(h':Heap, ref:Ref h' r) -> {Read h'|eff} a = action ref %runReader(init, explicitAction) def with_reader( init:r, - action: (given (h), Ref(h,r)) -> {Read h|eff} a - ) -> {|eff} a given (r|Data, a, eff) = + action: (given (h:Heap), Ref(h,r)) -> {Read h|eff} a + ) -> {|eff} a given (r|Data, a:Type, eff:Effects) = run_reader(init, action) def MonoidLifter(b:Type, w:Type) -> Type = - (given (h) (AccumMonoid(h, b))) ->> AccumMonoid(h, w) + (given (h:Heap) (AccumMonoid(h, b))) ->> AccumMonoid(h, w) -named-instance mk_accum_monoid (given (h, w), d:AccumMonoidData(h, w)) -> AccumMonoid(h, w) +named-instance mk_accum_monoid (given (h:Heap, w:Type), d:AccumMonoidData(h, w)) -> AccumMonoid(h, w) getAccumMonoidData = d def run_accum( bm:Monoid b, - action: (given (h) (AccumMonoid(h, b)), Ref h w) -> {Accum h|eff} a - ) -> {|eff} (a, w) given (a, b, w|Data, eff) (MonoidLifter(b,w)) = + action: (given (h:Heap) (AccumMonoid(h, b)), Ref h w) -> {Accum h|eff} a + ) -> {|eff} (a, w) given (a:Type, b:Type, w|Data, eff:Effects) (MonoidLifter(b,w)) = empty = %applyMethod0(bm) def explicitAction(h':Heap, ref:Ref h' w) -> {Accum h'|eff} a = accumMonoidData : AccumMonoidData h' b = UnsafeMkAccumMonoidData b bm @@ -684,32 +683,32 @@ def run_accum( def yield_accum( m:Monoid b, - action: (given (h) (AccumMonoid(h, b)), Ref h w) -> {Accum h|eff} a - ) -> {|eff} w given (a, b, w|Data, eff) (MonoidLifter b w) = + action: (given (h:Heap) (AccumMonoid(h, b)), Ref h w) -> {Accum h|eff} a + ) -> {|eff} w given (a:Type, b:Type, w|Data, eff:Effects) (MonoidLifter b w) = snd $ run_accum(m, action) def run_state( init:s, - action: (given (h), Ref h s) -> {State h |eff} a - ) -> {|eff} (a,s) given (a, s|Data, eff) = + action: (given (h:Heap), Ref h s) -> {State h |eff} a + ) -> {|eff} (a,s) given (a:Type, s|Data, eff:Effects) = def explicitAction(h':Heap, ref:Ref h' s) -> {State h'|eff} a = action ref %runState(init, explicitAction) def with_state( init:s, - action: (given (h), Ref h s) -> {State h |eff} a - ) -> {|eff} a given (a, s|Data, eff) = + action: (given (h:Heap), Ref h s) -> {State h |eff} a + ) -> {|eff} a given (a:Type, s|Data, eff:Effects) = fst $ run_state(init, action) def yield_state( init:s, - action: (given (h), Ref h s) -> {State h |eff} a - ) -> {|eff} s given (a, s|Data, eff) = + action: (given (h:Heap), Ref h s) -> {State h |eff} a + ) -> {|eff} s given (a:Type, s|Data, eff:Effects) = snd $ run_state(init, action) def unsafe_io( f:()->{IO|eff} a - ) -> {|eff} a given (a, eff) = + ) -> {|eff} a given (a:Type, eff:Effects) = f' : (() -> {IO|eff} a) = \. f() %runIO(f') @@ -849,10 +848,10 @@ instance Ord(Nat) def (<)(x, y) = nat_to_rep x < nat_to_rep y -- TODO: we want Eq and Ord for all index sets, not just `Fin n` -instance Eq(Fin n) given (n) +instance Eq(Fin n) given (n:Nat) def (==)(x, y) = ordinal x == ordinal y -instance Ord(Fin n) given (n) +instance Ord(Fin n) given (n:Nat) def (>)(x, y) = ordinal x > ordinal y def (<)(x, y) = ordinal x < ordinal y @@ -870,7 +869,7 @@ instance Ix(Maybe a) given (a|Ix) Nothing -> size a def unsafe_from_ordinal(o) = case o == size a of - False -> Just $ unsafe_from_ordinal o + False -> Just $ unsafe_from_ordinal(n=a, o) True -> Nothing interface NonEmpty(n|Ix) @@ -916,13 +915,13 @@ def left_fence(p:Post n) -> Maybe n given (n|Ix) = ix = ordinal p if ix == 0 then Nothing - else Just $ unsafe_from_ordinal $ ix -| 1 + else Just $ unsafe_from_ordinal(n=n, ix -| 1) def right_fence(p:Post n) -> Maybe n given (n|Ix) = ix = ordinal p if ix == size n then Nothing - else Just $ unsafe_from_ordinal ix + else Just $ unsafe_from_ordinal(n=n, ix) def last_ix() ->> n given (n|NonEmpty) = unsafe_from_ordinal(unsafe_i_to_n(n_to_i(size n) - 1)) @@ -930,19 +929,6 @@ def last_ix() ->> n given (n|NonEmpty) = instance NonEmpty(Post n) given (n|Ix) first_ix = unsafe_from_ordinal(n=Post n, 0) -def scan( - init:a, - body:(n, a)->(a,b) - ) -> (a, n=>b) given (a|Data, b, n|Ix) = - swap $ run_state(init) \s. for i. - c = get s - (c', y) = body(i, c) - s := c' - y - -def fold(init:a, body:(n,a)->a) -> a given (n|Ix, a|Data) = - fst $ scan init \i x. (body(i, x), ()) - def compare(x:a, y:a) -> Ordering given (a|Ord) = if x < y then LT @@ -961,42 +947,32 @@ instance Monoid(Ordering) instance Eq(n=>a) given (n|Ix, a|Eq) def (==)(xs, ys) = yield_accum AndMonoid \ref. - for i. ref += xs[i] == ys[i] - -instance Ord(n=>a) given (n|Ix, a|Ord) - def (>)(xs, ys) = - f: Ordering = - fold EQ $ \i c. c <> compare(xs[i], ys[i]) - f == GT - def (<)(xs, ys) = - f: Ordering = - fold EQ $ \i c. c <> compare(xs[i], ys[i]) - f == LT + for i:n. ref += xs[i] == ys[i] '## Subset class -interface Subset(subset, superset) +interface Subset(subset:Type, superset:Type) inject' : (subset) -> superset project' : (superset) -> Maybe subset unsafe_project' : (superset) -> subset -- wrappers with more helpful implicit arg names -def inject(x:from) -> to given (to, from) (Subset(from, to)) = inject'(x) -def project(x:from) -> Maybe to given (to, from) (Subset(to, from)) = project'(x) -def unsafe_project(x:from) -> to given (to, from) (Subset(to, from)) = unsafe_project'(x) +def inject(x:from) -> to given (to:Type, from:Type) (Subset(from, to)) = inject'(x) +def project(x:from) -> Maybe to given (to:Type, from:Type) (Subset(to, from)) = project'(x) +def unsafe_project(x:from) -> to given (to:Type, from:Type) (Subset(to, from)) = unsafe_project'(x) -instance Subset(a, c) given (a, b, c) (Subset(a, b), Subset(b, c)) +instance Subset(a, c) given (a:Type, b:Type, c:Type) (Subset(a, b), Subset(b, c)) def inject'(x) = inject $ inject(to=b, x) def project'(x) = case project(to=b, x) of Nothing -> Nothing Just(y)-> project y def unsafe_project'(x) = unsafe_project $ unsafe_project(to=b, x) -def unsafe_project_rangefrom(j:q) -> RangeFrom(q, i) given (q|Ix, i:q) = +def unsafe_project_rangefrom(j:q) -> RangeFrom(i) given (q|Ix, i:q) = RangeFrom unsafe_nat_diff(ordinal j, ordinal i) -instance Subset(RangeFrom(q, i), q) given (q|Ix, i:q) +instance Subset(RangeFrom(i), q) given (q|Ix, i:q) def inject'(j) = unsafe_from_ordinal $ j.val + ordinal i def project'(j) = @@ -1007,7 +983,7 @@ instance Subset(RangeFrom(q, i), q) given (q|Ix, i:q) else Just $ RangeFrom $ unsafe_nat_diff(j', i') def unsafe_project'(j) = RangeFrom unsafe_nat_diff(ordinal j, ordinal i) -instance Subset(RangeFromExc(q, i), q) given (q|Ix, i:q) +instance Subset(RangeFromExc(i), q) given (q|Ix, i:q) def inject'(j) = unsafe_from_ordinal $ j.val + ordinal i + 1 def project'(j) = j' = ordinal j @@ -1018,7 +994,7 @@ instance Subset(RangeFromExc(q, i), q) given (q|Ix, i:q) def unsafe_project'(j) = RangeFromExc unsafe_nat_diff(ordinal j, ordinal i + 1) -instance Subset(RangeTo(q, i), q) given (q|Ix, i:q) +instance Subset(RangeTo(i), q) given (q|Ix, i:q) def inject'(j) = unsafe_from_ordinal j.val def project'(j) = j' = ordinal j @@ -1028,7 +1004,7 @@ instance Subset(RangeTo(q, i), q) given (q|Ix, i:q) else Just $ RangeTo j' def unsafe_project'(j) = RangeTo (ordinal j) -instance Subset(RangeToExc(q, i), q) given (q|Ix, i:q) +instance Subset(RangeToExc(i), q) given (q|Ix, i:q) def inject'(j) = unsafe_from_ordinal j.val def project'(j) = j' = ordinal j @@ -1038,7 +1014,7 @@ instance Subset(RangeToExc(q, i), q) given (q|Ix, i:q) else Just $ RangeToExc j' def unsafe_project'(j) = RangeToExc (ordinal j) -instance Subset(RangeToExc(q, i), RangeTo(q, i)) given (q|Ix, i:q) +instance Subset(RangeToExc(i), RangeTo(i)) given (q|Ix, i:q) def inject'(j) = unsafe_from_ordinal j.val def project'(j) = j' = ordinal j @@ -1140,7 +1116,7 @@ instance Floating(Float32) struct Ptr(a:Type) = val : RawPtr -def cast_ptr(ptr: Ptr a) -> Ptr b given (a, b) = Ptr(ptr.val) +def cast_ptr(ptr: Ptr from) -> Ptr to given (from:Type, to:Type) = Ptr(ptr.val) interface Storable(a|Data) store : (Ptr a, a) -> {IO} () @@ -1168,11 +1144,11 @@ instance Storable(Float32) def storage_size() = 4 instance Storable(Nat) - def store(ptr, x) = store(Ptr(ptr.val), nat_to_rep x) + def store(ptr, x) = store(cast_ptr(ptr, to=%Word32()), nat_to_rep x) def load(ptr) = rep_to_nat $ load(Ptr(ptr.val)) def storage_size() = storage_size(a=NatRep) -instance Storable(Ptr a) given (a) +instance Storable(Ptr a) given (a:Type) def store(ptr, x) = %ptrStore(internal_cast(to=%PtrPtr(), ptr.val), x.val) def load(ptr) = Ptr(%ptrLoad(internal_cast(to=%PtrPtr(), ptr))) def storage_size() = 8 -- TODO: something more portable? @@ -1183,7 +1159,7 @@ def malloc(n:Nat) -> {IO} (Ptr a) given (a|Storable) = numBytes = storage_size(a=a) * n Ptr(%alloc(nat_to_rep numBytes)) -def free(ptr:Ptr a) -> {IO} () given (a) = %free(ptr.val) +def free(ptr:Ptr a) -> {IO} () given (a:Type) = %free(ptr.val) def (+>>)(ptr:Ptr a, i:Nat) -> Ptr a given (a|Storable) = i' = nat_to_rep $ i * storage_size(a=a) @@ -1191,7 +1167,7 @@ def (+>>)(ptr:Ptr a, i:Nat) -> Ptr a given (a|Storable) = -- TODO: consider making a Storable instance for tables instead def store_table(ptr: Ptr a, tab:n=>a) -> {IO} () given (a|Storable, n|Ix) = - for_ i. store(ptr +>> ordinal i, tab[i]) + for_ i:n. store(ptr +>> ordinal i, tab[i]) def memcpy(dest:Ptr a, src:Ptr a, n:Nat) -> {IO} () given (a|Storable) = for_ i:(Fin n). @@ -1201,10 +1177,11 @@ def memcpy(dest:Ptr a, src:Ptr a, n:Nat) -> {IO} () given (a|Storable) = -- TODO: generalize these brackets to allow other effects -- TODO: make sure that freeing happens even if there are run-time errors def with_alloc( + a|Storable, n:Nat, action: (Ptr a) -> {IO} b - ) -> {IO} b given (a|Storable, b) = - ptr = malloc n + ) -> {IO} b given (b:Type) = + ptr = malloc(a=a, n) result = action ptr free ptr result @@ -1212,9 +1189,9 @@ def with_alloc( def with_table_ptr( xs:n=>a, action: (Ptr a) -> {IO} b - ) -> {IO} b given (a|Storable, b, n|Ix) = - ptr <- with_alloc(size n) - for i. store(ptr +>> ordinal i, xs[i]) + ) -> {IO} b given (a|Storable, b:Type, n|Ix) = + ptr <- with_alloc(a, size n) + for i:n. store(ptr +>> ordinal i, xs[i]) action ptr def table_from_ptr(ptr:Ptr a) -> {IO} n=>a given (a|Storable, n|Ix) = @@ -1224,22 +1201,31 @@ def table_from_ptr(ptr:Ptr a) -> {IO} n=>a given (a|Storable, n|Ix) = pi : Float = 3.141592653589793 -def id(x:a) -> a given (a) = x -def dup(x:a) -> (a, a) given (a) = (x, x) -def map(f:(a)->{|eff} b, xs: n=>a) -> {|eff} (n=>b) given (a, b, n|Ix, eff) = - for i. f xs[i] +def id(x:a) -> a given (a:Type) = x +def dup(x:a) -> (a, a) given (a:Type) = (x, x) -- map, flipped so that the function goes last -def each(xs: n=>a, f:(a)->{|eff} b) -> {|eff} (n=>b) given (a, b, n|Ix, eff) = +def each(xs: n=>a, f:(a)->{|eff} b) -> {|eff} (n=>b) given (a:Type, b:Type, n|Ix, eff:Effects) = for i. f xs[i] -def zip(xs:n=>a, ys:n=>b) -> (n=>(a,b)) given (a, b, n|Ix) = for i. (xs[i], ys[i]) -def unzip(xys:n=>(a,b)) -> (n=>a , n=>b) given (a, b, n|Ix)= (each xys fst, each xys snd) -def fanout(x:a) -> n=>a given (n|Ix, a) = for i. x +def zip(xs:n=>a, ys:n=>b) -> (n=>(a,b)) given (a:Type, b:Type, n|Ix) = for i. (xs[i], ys[i]) +def unzip(xys:n=>(a,b)) -> (n=>a , n=>b) given (a:Type, b:Type, n|Ix) = + (each xys \xy. fst(xy), each xys \xy. snd(xy)) +def fanout(x:a) -> n=>a given (n|Ix, a:Type) = for i. x def sq(x:a) -> a given (a|Mul) = x * x -def abs(x:a) -> a given (a|Sub|Ord) = select(x > zero, x, zero - x) +def abs(x:a) -> a given (a|Sub|Ord) = select(x > zero, x, (zero::a) - x) def mod(x:a, y:a) -> a given (a|Add|Integral) = rem(y + rem(x, y), y) -def (>>>)(f:(a) -> b, g:(b) -> c) -> (a) -> c given (a, b, c) = \x. g(f(x)) -def (<<<)(f:(b) -> c, g:(a) -> b) -> (a) -> c given (a, b, c) = \x. f(g(x)) +def (>>>)(f:(a) -> b, g:(b) -> c) -> (a) -> c given (a:Type, b:Type, c:Type) = \x. g(f(x)) +def (<<<)(f:(b) -> c, g:(a) -> b) -> (a) -> c given (a:Type, b:Type, c:Type) = \x. f(g(x)) + +def flatten2D(mat:n=>m=>a) -> (n,m)=>a given (n|Ix, m|Ix, a:Type) = + for pair. + (i, j) = pair + mat[i,j] + +def flatten3D(array:l=>n=>m=>a) -> (l,n,m)=>a given (l|Ix, n|Ix, m|Ix, a:Type) = + for triple. + (i, j, k) = triple + array[i,j,k] '## Table Operations @@ -1267,23 +1253,34 @@ instance Floating(n=>a) given (a|Floating, n|Ix) '### Reductions +def scan( + init:c, + xs: n=>a, + body:(n, a, c)->(b, c) + ) -> (n=>b, c) given (a:Type, b:Type, c|Data, n|Ix) = + run_state(init) \ref. for i:n. + carry = get ref + (y, carry') = body(i, xs[i], carry) + ref := carry' + y + +def fold(init:c, xs:n=>a, body:(n, a, c)-> c) -> c given (a:Type, n|Ix, c|Data) = + snd $ scan(init, xs) \i x carry. ((), body(i, x, carry)) + -- `combine` should be a commutative and associative, and form a -- commutative monoid with `identity` -def reduce(identity:a, combine:(a,a)->a, xs:n=>a) -> a given (a|Data, n|Ix) = +def reduce(xs:n=>a, identity:a, combine:(a,a)->a) -> a given (a|Data, n|Ix) = -- TODO: implement with the accumulator effect - fold identity \i c. combine(c, xs[i]) + fold(identity, xs) \i x c. combine(c, x) --- TODO: call this `scan` and call the current `scan` something else -def scan'(init:a, body:(n,a)->a) -> n=>a given (a|Data, n|Ix) = - snd $ scan init \i x. dup(body(i, x)) def fsum(xs:n=>Float) -> Float given (n|Ix) = - yield_accum(AddMonoid Float) \ref. for i. ref += xs[i] -def sum(xs:n=>v) -> v given (n|Ix, v|Add) = reduce(zero, (+), xs) -def prod(xs:n=>v) -> v given (n|Ix, v|Mul) = reduce(one , (*), xs) + yield_accum(AddMonoid Float) \ref. each xs \x. ref += x +def sum(xs:n=>v) -> v given (n|Ix, v|Add) = reduce(xs, zero, (+)) +def prod(xs:n=>v) -> v given (n|Ix, v|Mul) = reduce(xs, one , (*)) def mean(xs:n=>v) -> v given (n|Ix, v|VSpace) = sum xs / n_to_f (size n) def std(xs:n=>v) -> v given (n|Ix, v|Mul|Sub|VSpace|Floating) = sqrt $ mean (each xs sq) - sq (mean xs) -def any(xs:n=>Bool) -> Bool given (n|Ix) = reduce(False, (||), xs) -def all(xs:n=>Bool) -> Bool given (n|Ix) = reduce(True , (&&), xs) +def any(xs:n=>Bool) -> Bool given (n|Ix) = reduce(xs, False, (||)) +def all(xs:n=>Bool) -> Bool given (n|Ix) = reduce(xs, True , (&&)) '### apply_n @@ -1296,15 +1293,15 @@ TODO: Move this to be with reductions? It's a kind of `scan`. def cumsum(xs: n=>a) -> n=>a given (n|Ix, a|Add) = - total <- with_state zero - for i. + total <- with_state (zero::a) + for i:n. newTotal = get total + xs[i] total := newTotal newTotal def cumsum_low(xs: n=>a) -> n=>a given (n|Ix, a|Add) = - total <- with_state zero - for i. + total <- with_state (zero::a) + for i:n. oldTotal = get total total := oldTotal + xs[i] oldTotal @@ -1314,18 +1311,18 @@ def cumsum_low(xs: n=>a) -> n=>a given (n|Ix, a|Add) = '### AD operations -- TODO: add vector space constraints -def linearize(f:(a)->b, x:a) -> (b, (a)->b) given (a, b) = - %linearize(\x. f x, x) +def linearize(f:(a)->b, x:a) -> (b, (a)->b) given (a:Type, b:Type) = + %linearize(\x:a. f x, x) -def jvp(f:(a)->b, x:a, t:a) -> b given (a, b) = (snd $ linearize(f, x))(t) -def transpose_linear(f:(a)->b) -> (b)->a given (a, b) = \ct. - %linearTranspose(\x. f x, ct) +def jvp(f:(a)->b, x:a, t:a) -> b given (a:Type, b:Type) = (snd $ linearize(f, x))(t) +def transpose_linear(f:(a)->b) -> (b)->a given (a:Type, b:Type) = \ct. + %linearTranspose(\x:a. f x, ct) -def vjp(f:(a)->b, x:a) -> (b, (b)->a) given (a, b) = +def vjp(f:(a)->b, x:a) -> (b, (b)->a) given (a:Type, b:Type) = (y, df) = linearize(f, x) (y, transpose_linear df) -def grad(f:(a)->Float, x:a) -> a given (a) = (snd vjp(f, x))(1.0) +def grad(f:(a)->Float, x:a) -> a given (a:Type) = (snd vjp(f, x))(1.0) def deriv(f:(Float)->Float, x:Float) -> Float = jvp(f, x, 1.0) @@ -1333,7 +1330,7 @@ def deriv_rev(f:(Float)->Float, x:Float) -> Float = (snd vjp(f, x))(1.0) -- XXX: Watch out when editing this data type! We depend on its structure -- deep inside the compiler (mostly in linearization and during rule registration). -data SymbolicTangent(a) = +data SymbolicTangent(a:Type) = ZeroTangent SomeTangent(a) @@ -1345,15 +1342,15 @@ def someTangent(x:SymbolicTangent a) -> a given (a|VSpace) = '### Approximate Equality TODO: move this outside the AD section to be with equality? -interface HasAllClose(a) +interface HasAllClose(a:Type) allclose : (a, a, a, a) -> Bool -interface HasDefaultTolerance(a) +interface HasDefaultTolerance(a:Type) default_atol : a default_rtol : a def (~~)(x:a, y:a) -> Bool given (a|HasAllClose|HasDefaultTolerance) = - allclose(default_atol, default_rtol, x, y) + allclose(a=a, default_atol, default_rtol, x, y) instance HasAllClose(Float32) def allclose(atol, rtol, x, y) = abs (x - y) <= (atol + rtol * abs y) @@ -1402,7 +1399,7 @@ def check_deriv(f:(Float)->Float, x:Float) -> Bool = '## Length-erased lists -data List(a)= +data List(a:Type) = AsList(n:Nat, elements:(Fin n => a)) instance Eq(List a) given (a|Eq) @@ -1414,15 +1411,15 @@ instance Eq(List a) given (a|Eq) else all for i:(Fin nx). xs[i] == ys[unsafe_from_ordinal (ordinal i)] -def unsafe_cast_table(xs:from=>a) -> to=>a given (to|Ix, from|Ix, a) = +def unsafe_cast_table(xs:from=>a) -> to=>a given (to|Ix, from|Ix, a:Type) = for i. xs[unsafe_from_ordinal (ordinal i)] -def to_list(xs:n=>a) -> List a given (n|Ix, a) = +def to_list(xs:n=>a) -> List a given (n|Ix, a:Type) = n' = size n - AsList(_, unsafe_cast_table(to=Fin n', xs)) + AsList(n', unsafe_cast_table(to=Fin n', xs)) instance Monoid(List a) given (a|Data) - mempty = AsList(_, []) + mempty = to_list([] :: Fin 0 => a) def (<>)(x, y) = AsList(nx,xs) = x AsList(ny,ys) = y @@ -1440,11 +1437,11 @@ named-instance ListMonoid (a|Data) -> Monoid(List a) -- TODO Eliminate or reimplement this operation, since it costs O(n) -- where n is the length of the list held in the reference. def append(list: Ref(h, List a), x:a) -> {Accum h} () - given (a|Data, h) (AccumMonoid(h, List a)) = + given (a|Data, h:Heap) (AccumMonoid(h, List a)) = list += to_list [x] -- TODO: replace `slice` with this? -def post_slice(xs:n=>a, start:Post n, end:Post n) -> List a given (n|Ix, a) = +def post_slice(xs:n=>a, start:Post n, end:Post n) -> List a given (n|Ix, a:Type) = slice_size = unsafe_nat_diff(ordinal end, ordinal start) to_list for i:(Fin slice_size). xs[unsafe_from_ordinal(n=n, ordinal i + ordinal start)] @@ -1466,7 +1463,7 @@ struct CString = def with_c_string( s:String, action: (CString) -> {IO} a - ) -> {IO} a given (a) = + ) -> {IO} a given (a:Type) = AsList(n, s') = s <> "\NUL" with_table_ptr s' \ptr. action CString(ptr.val) @@ -1477,7 +1474,7 @@ No particular promises are made to exactly what that representation will contain In particular it is **not** promised to be parseable. Nor does it promise a particular level of precision for numeric values. -interface Show(a) +interface Show(a:Type) show : (a) -> String instance Show(String) @@ -1535,7 +1532,7 @@ instance Show((a, b, c, d)) given (a|Show, b|Show, c|Show, d|Show) '### Parse interface For types that can be parsed from a `String`. -interface Parse(a) +interface Parse(a:Type) parseString : (String) -> Maybe a foreign "strtof" strtofFFI : (RawPtr, RawPtr) -> {IO} Float @@ -1544,7 +1541,7 @@ instance Parse(Float) def parseString(str) = unsafe_io \. AsList(str_len, _) = str with_c_string str \cStr. - with_alloc 1 \end_ptr:(Ptr (Ptr Char)). + with_alloc (Ptr Char) 1 \end_ptr. result = strtofFFI(cStr.ptr, end_ptr.val) str_end_ptr = load end_ptr consumed = raw_ptr_to_i64 str_end_ptr.val - raw_ptr_to_i64 cStr.ptr @@ -1585,7 +1582,7 @@ FilePath : Type = String def is_null_raw_ptr(ptr:RawPtr) -> Bool = raw_ptr_to_i64 ptr == 0 -def from_nullable_raw_ptr(ptr:RawPtr) -> Maybe (Ptr a) given (a) = +def from_nullable_raw_ptr(ptr:RawPtr) -> Maybe (Ptr a) given (a:Type) = if is_null_raw_ptr ptr then Nothing else Just $ Ptr ptr @@ -1615,7 +1612,7 @@ def fopen(path:String, mode:StreamMode) -> {IO} (Stream mode) = with_c_string modeStr \cMode. Stream $ fopenFFI(cPath.ptr, cMode.ptr) -def fclose(stream:Stream mode) -> {IO} () given (mode) = +def fclose(stream:Stream mode) -> {IO} () given (mode:StreamMode) = fcloseFFI stream.ptr () @@ -1629,7 +1626,7 @@ def fwrite(stream:Stream WriteMode, s:String) -> {IO} () = '### Iteration TODO: move this out of the file-system section -def while(body: () -> {|eff} Bool) -> {|eff} () given (eff) = +def while(body: () -> {|eff} Bool) -> {|eff} () given (eff:Effects) = body' : () -> {|eff} Word8 = \. b_to_w8 $ body() %while(body') @@ -1637,19 +1634,14 @@ data IterResult(a|Data) = Continue Done(a) --- TODO: can we improve effect inference so we don't need this? -def lift_state(ref: Ref(h, c), f:(a) -> {|eff} b, x:a) -> {State h|eff} b - given (a, b, c, h, eff) = - f x - -- A little iteration combinator -def iter(body: (Nat) -> {|eff} IterResult a) -> {|eff} a given (a|Data, eff) = - result = yield_state Nothing \resultRef. - i <- with_state 0 +def iter(body: (Nat) -> {|eff} IterResult a) -> {|eff} a given (a|Data, eff:Effects) = + result = yield_state (Nothing::Maybe a) \resultRef. + i <- with_state (0::Nat) while \. continue = is_nothing $ get resultRef if continue then - case lift_state(resultRef, (\x. lift_state(i, body, x)), get i) of + case body(get(i)) of Continue -> i := get i + 1 Done(result) -> resultRef := Just result continue @@ -1661,7 +1653,7 @@ def bounded_iter( maxIters:Nat, fallback:a, body:(Nat) -> {|eff} IterResult a - ) -> {|eff} a given (a|Data, eff) = iter \i. + ) -> {|eff} a given (a|Data, eff:Effects) = iter \i. if i >= maxIters then Done fallback else body i @@ -1691,7 +1683,6 @@ def error(s:String) -> a given (a|Data) = unsafe_io \. def todo() ->> a given (a|Data) = error "TODO: implement it!" - '### Table operations @noinline @@ -1720,12 +1711,12 @@ def cast_table(xs:to=>a) -> from=>a given (from|Ix, to|Ix, a|Data) = def asidx(i:Nat) -> n given (n|Ix) = from_ordinal i def (@)(i:Nat, n|Ix) -> n = from_ordinal i -def slice(xs:n=>a, start:Nat, m|Ix) -> m=>a given (n|Ix, a) = +def slice(xs:n=>a, start:Nat, m|Ix) -> m=>a given (n|Ix, a:Type) = for i. xs[from_ordinal (ordinal i + start)] -def head(xs:n=>a) -> a given (n|Ix, a) = xs[0@_] +def head(xs:n=>a) -> a given (n|Ix, a:Type) = xs[0@_] -def tail(xs:n=>a, start:Nat) -> List a given (n|Ix, a) = +def tail(xs:n=>a, start:Nat) -> List a given (n|Ix, a:Type) = numElts = size n -| start to_list $ slice(xs, start, Fin numElts) @@ -1742,8 +1733,8 @@ Key = Word64 @noinline def threefry_2x32(k:Word64, count:Word64) -> Word64 = -- Based on jax's threefry_2x32 by Matt Johnson and Peter Hawkins - rotations1 = [13, 15, 26, 6] - rotations2 = [17, 29, 16, 24] + rotations1 : Fin 4 => Int32 = [13, 15, 26, 6] + rotations2 : Fin 4 => Int32 = [17, 29, 16, 24] k0 = low_word k k1 = high_word k @@ -1758,9 +1749,9 @@ def threefry_2x32(k:Word64, count:Word64) -> Word64 = rotations = [rotations1, rotations2] ks = [k1, k2, k0] (x, y) = yield_state (x, y) \ref. for i:(Fin 5). - for j. + for j:(Fin 4). (x, y) = get ref - rotationIndex = unsafe_from_ordinal (ordinal i `mod` 2) + rotationIndex : Fin 2 = unsafe_from_ordinal (ordinal i `mod` 2) rot = rotations[rotationIndex, j] x = x + y y = (y .<<. rot) .|. (y .>>. (32 - rot)) @@ -1777,7 +1768,7 @@ def hash(x:Key, y:Nat) -> Key = y64 = n_to_w64 y threefry_2x32(x, y64) def new_key(x:Nat) -> Key = hash(0, x) -def many(f:(Key)->a, k:Key, i:n) -> a given (a, n|Ix) = f hash(k, ordinal i) +def many(f:(Key)->a, k:Key, i:n) -> a given (a:Type, n|Ix) = f hash(k, ordinal i) def ixkey(k:Key, i:n) -> Key given (n|Ix) = hash(k, ordinal i) def split_key(k:Key) -> Fin n => Key given (n:Nat) = for i. ixkey(k, i) @@ -1786,19 +1777,19 @@ These functions generate samples taken from, different distributions. Such as `rand_mat` with samples from the distribution of floating point matrices where each element is taken from a i.i.d. uniform distribution. Note that additional standard distributions are provided by the `stats` library. def rand(k:Key) -> Float = - exponent_bits = 1065353216 -- 1065353216 = 127 << 23 + exponent_bits : Word32 = 1065353216 -- 1065353216 = 127 << 23 mantissa_bits = (high_word k .&. 8388607) -- 8388607 == (1 << 23) - 1 bits = exponent_bits .|. mantissa_bits %bitcast(Float, bits) - 1.0 -def rand_vec(n:Nat, f: (Key) -> a, k: Key) -> Fin n => a given (a) = +def rand_vec(n:Nat, f: (Key) -> a, k: Key) -> Fin n => a given (a:Type) = for i:(Fin n). f ixkey(k, i) -def rand_mat(n:Nat, m:Nat, f: (Key) -> a, k: Key) -> Fin n => Fin m => a given (a) = +def rand_mat(n:Nat, m:Nat, f: (Key) -> a, k: Key) -> Fin n => Fin m => a given (a:Type) = for i j. f ixkey(k, (i, j)) def randn(k:Key) -> Float = - [k1, k2] = split_key k + [k1, k2] = split_key(n=2, k) -- rand is uniform between 0 and 1, but implemented such that it rounds to 0 -- (in float32) once every few million draws, but never rounds to 1. u1 = 1.0 - (rand k1) @@ -1823,12 +1814,12 @@ instance InnerProd(Float) def inner_prod(x, y) = x * y instance InnerProd(n=>a) given (a|InnerProd, n|Ix) - def inner_prod(x, y) =sum for i. inner_prod(x[i], y[i]) + def inner_prod(x, y) =sum for i:n. inner_prod(x[i], y[i]) '## Arbitrary Type class for generating example values -interface Arbitrary(a) +interface Arbitrary(a:Type) arb : (Key) -> a instance Arbitrary(Bool) @@ -1860,10 +1851,10 @@ instance Arbitrary((i:n)=>(i<..) => a) given (n|Ix, a|Arbitrary) instance Arbitrary((a, b)) given (a|Arbitrary, b|Arbitrary) def arb(key) = - [k1, k2] = split_key key + [k1, k2] = split_key(n=2, key) (arb k1, arb k2) -instance Arbitrary(Fin n) given (n) +instance Arbitrary(Fin n) given (n:Nat) def arb(key) = rand_idx key '## Ord on Arrays @@ -1888,7 +1879,7 @@ def search_sorted(xs:n=>a, x:a) -> Post n given (n|Ix, a|Ord) = else if x < xs[from_ordinal 0] then first_ix else - low <- with_state(0) + low <- with_state(0::Nat) high <- with_state(size n) _ <- iter numLeft = n_to_i (get high) - n_to_i (get low) @@ -1911,40 +1902,41 @@ def search_sorted_exact(xs:n=>a, x:a) -> Maybe n given (n|Ix, a|Ord) = '### min / max etc -def min_by(f:(a)->o, x:a, y:a) -> a given (o|Ord, a) = select(f x < f y, x, y) -def max_by(f:(a)->o, x:a, y:a) -> a given (o|Ord, a) = select(f x > f y, x, y) +def min_by(f:(a)->o, x:a, y:a) -> a given (o|Ord, a:Type) = select(f x < f y, x, y) +def max_by(f:(a)->o, x:a, y:a) -> a given (o|Ord, a:Type) = select(f x > f y, x, y) def min(x1: o, x2: o) -> o given (o|Ord) = min_by(id, x1, x2) def max(x1: o, x2: o) -> o given (o|Ord) = max_by(id, x1, x2) -def minimum_by(f:(a)->o, xs:n=>a) -> a given (a|Data, o|Ord, n|Ix) = - reduce(xs[0@_], \x y. min_by(f, x, y), xs) -def maximum_by(f:(a)->o, xs:n=>a) -> a given (a|Data, o|Ord, n|Ix) = - reduce(xs[0@_], \x y. max_by(f, x, y), xs) +def minimum_by(xs:n=>a, f:(a)->o) -> a given (a|Data, o|Ord, n|Ix) = + reduce(xs, xs[0@_], \x y. min_by(f, x, y)) +def maximum_by(xs:n=>a, f:(a)->o) -> a given (a|Data, o|Ord, n|Ix) = + reduce(xs, xs[0@_], \x y. max_by(f, x, y)) -def minimum(xs:n=>o) -> o given (n|Ix, o|Ord) = minimum_by(id, xs) -def maximum(xs:n=>o) -> o given (n|Ix, o|Ord) = maximum_by(id, xs) +def minimum(xs:n=>o) -> o given (n|Ix, o|Ord) = minimum_by(xs, id) +def maximum(xs:n=>o) -> o given (n|Ix, o|Ord) = maximum_by(xs, id) '### argmin/argmax -- TODO: put in same section as `searchsorted` -def argscan(comp:(o,o)->Bool, xs:n=>o) -> n given (o|Ord, n|Ix) = - zeroth = (0@_, xs[0@_]) - compare = \p1 p2. +def argscan(xs:n=>a, comp:(a,a)->Bool) -> n given (a|Ord, n|Ix) = + AccumTy : Type = (n, a) + zeroth : AccumTy = (0@_, xs[0@_]) + compare = \p1:AccumTy p2:AccumTy. (idx1, x1) = p1 (idx2, x2) = p2 select(comp(x1, x2), (idx1, x1), (idx2, x2)) - zipped = for i. (i, xs[i]) - fst $ reduce(zeroth, compare, zipped) + zipped = for i:n. (i, xs[i]) + fst $ reduce(zipped, zeroth, compare) -def argmin(xs:n=>o) -> n given (n|Ix, o|Ord) = argscan((<), xs) -def argmax(xs:n=>o) -> n given (n|Ix, o|Ord) = argscan((>), xs) +def argmin(xs:n=>a) -> n given (n|Ix, a|Ord) = argscan(xs, (<)) +def argmax(xs:n=>a) -> n given (n|Ix, a|Ord) = argscan(xs, (>)) def lexical_order( - compareElements:(n,n)->Bool, - compareLengths: (Nat,Nat)->Bool, xList:List n, - yList:List n + yList:List n, + compareElements:(n,n)->Bool, + compareLengths: (Nat,Nat)->Bool ) -> Bool given (n|Ord) = -- Orders Lists according to the order of their elements, -- in the same way a dictionary does. @@ -1973,8 +1965,8 @@ def lexical_order( False -> Done False instance Ord(List n) given (n|Ord) - def (>)(xs, ys) = lexical_order((>), (>), xs, ys) - def (<)(xs, ys) = lexical_order((<), (<), xs, ys) + def (>)(xs, ys) = lexical_order(xs, ys, (>), (>)) + def (<)(xs, ys) = lexical_order(xs, ys, (<), (<)) '### clip @@ -2033,13 +2025,13 @@ TODO: all of these should be in some other section def reflect(i:n) -> n given (n|Ix) = unsafe_from_ordinal $ unsafe_nat_diff(size n, ordinal i + 1) -def reverse(x:n=>a) -> n=>a given (n|Ix, a) = - for i. x[reflect i] +def reverse(x:n=>a) -> n=>a given (n|Ix, a:Type) = + for i:n. x[reflect i] def wrap_periodic(n|Ix, i:Nat) -> n = unsafe_from_ordinal(n=n, i `mod` size n) -def pad_to(m|Ix, x:a, xs:n=>a) -> m=>a given (n|Ix, a) = +def pad_to(m|Ix, x:a, xs:n=>a) -> m=>a given (n|Ix, a:Type) = n' = size n for i. i' = ordinal i @@ -2062,7 +2054,7 @@ def is_power_of_2(x:Nat) -> Bool = x' = nat_to_rep x if x' == 0 then False - else 0 == %and(x', (%isub(x', 1::NatRep))) + else %and(x', (%isub(x', 1::NatRep))) == 0 -- This computes the integer part of the binary logarithm of the input. -- TODO: natlog2 0 should do something other than underflow the answer. @@ -2071,8 +2063,8 @@ def is_power_of_2(x:Nat) -> Bool = -- we have with a fixed-point argument. -- https://llvm.org/docs/LangRef.html#llvm-ctlz-intrinsic def natlog2(x:Nat) -> Nat = - tmp = yield_state 0 \ans. - cmp <- run_state 1 + tmp = yield_state (0::Nat) \ans. + cmp <- run_state (1::Nat) while \. if x >= (get cmp) then @@ -2093,7 +2085,7 @@ def general_integer_power( one:a, base:a, power:Nat ) -> a given (a|Data) = - iters = if power == 0 then 0 else 1 + natlog2 power + iters : Nat = if power == 0 then 0 else 1 + natlog2 power -- Implements exponentiation by squaring. -- This could be nicer if there were a way to explicitly -- specify which typelcass instance to use for Mul. @@ -2109,33 +2101,33 @@ def general_integer_power( def intpow(base:a, power:Nat) -> a given (a|Mul) = general_integer_power((*), one, base, power) -def from_just(x:Maybe a) -> a given (a) = case x of Just(x') -> x' +def from_just(x:Maybe a) -> a given (a:Type) = case x of Just(x') -> x' -def any_sat(f:(a)->Bool, xs:n=>a) -> Bool given (a, n|Ix) = any(each xs f) +def any_sat(xs:n=>a, f:(a)->Bool) -> Bool given (a:Type, n|Ix) = any(each xs f) -def seq_maybes(xs: n=>Maybe a) -> Maybe (n => a) given (n|Ix, a) = +def seq_maybes(xs: n=>Maybe a) -> Maybe (n => a) given (n|Ix, a:Type) = -- is it possible to implement this safely? (i.e. without using partial -- functions) - case any_sat(is_nothing, xs) of + case any_sat(xs, is_nothing) of True -> Nothing False -> Just $ each xs from_just def linear_search(xs:n=>a, query:a) -> Maybe n given (n|Ix, a|Eq) = - yield_state Nothing \ref. for i. + yield_state Nothing \ref. for i:n. case xs[i] == query of True -> ref := Just i False -> () -def list_length(l:List a) -> Nat given (a) = +def list_length(l:List a) -> Nat given (a:Type) = AsList(n, _) = l n -- This is for efficiency (rather than using `<>` repeatedly) -- TODO: we want this for any monoid but this implementation won't work. -def concat(lists:n=>(List a)) -> List a given (a, n|Ix) = - totalSize = sum for i. list_length lists[i] - to_list $ with_state 0 \listIdx. - eltIdx <- with_state 0 +def concat(lists:n=>(List a)) -> List a given (a:Type, n|Ix) = + totalSize = sum for i:n. list_length lists[i] + to_list $ with_state (0::Nat) \listIdx. + eltIdx <- with_state (0::Nat) for i:(Fin totalSize). while \. continue = get eltIdx >= list_length (lists[(get listIdx)@_]) @@ -2151,8 +2143,10 @@ def concat(lists:n=>(List a)) -> List a given (a, n|Ix) = xs[eltIdxVal@_] def cat_maybes(xs:n=>Maybe a) -> List a given (n|Ix, a|Data) = - (num_res, res_inds) = yield_state (0::Nat, for i:n. Nothing) \ref. - for i. case xs[i] of + StateTy : Type = (Nat, n=>Maybe n) + init_state : StateTy = (0, for i. Nothing) + (num_res, res_inds) = yield_state init_state \ref. + for i:n. case xs[i] of Just(_) -> ix = get ref.0 ref.1 ! (unsafe_from_ordinal ix) := Just i @@ -2166,10 +2160,10 @@ def cat_maybes(xs:n=>Maybe a) -> List a given (n|Ix, a|Data) = Nothing -> todo -- Impossible def filter(xs:n=>a, condition:(a)->Bool) -> List a given (a|Data, n|Ix) = - cat_maybes $ for i. if condition xs[i] then Just xs[i] else Nothing + cat_maybes $ for i:n. if condition xs[i] then Just xs[i] else Nothing def arg_filter(xs:n=>a, condition:(a)->Bool) -> List n given (a|Data, n|Ix) = - cat_maybes $ for i. if condition xs[i] then Just i else Nothing + cat_maybes $ for i:n. if condition xs[i] then Just i else Nothing -- TODO: use `ix_offset : [Ix n] -> n -> Int -> Maybe n` instead def prev_ix(i:n) -> Maybe n given (n|Ix) = @@ -2178,15 +2172,15 @@ def prev_ix(i:n) -> Maybe n given (n|Ix) = Just(i_prev) -> unsafe_from_ordinal(i_prev) | Just def lines(source:String) -> List String = - AsList(_, s) = source - AsList(num_lines, newline_ixs) = cat_maybes for i_char. + AsList(num_chars, s) = source + AsList(num_lines, newline_ixs) = cat_maybes for i_char:(Fin num_chars). if s[i_char] == '\n' then Just(i_char) else Nothing to_list for i_line:(Fin num_lines). start = case prev_ix i_line of - Nothing -> first_ix Just(i) -> right_post newline_ixs[i] + Nothing -> first_ix end = left_post newline_ixs[i_line] post_slice(s, start, end) @@ -2201,7 +2195,7 @@ def normalize_pdf(xs: d=>Float) -> d=>Float given (d|Ix) = xs / sum xs def cdf_for_categorical(logprobs: n=>Float) -> n=>Float given (n|Ix) = maxLogProb = maximum logprobs - cumsum_low $ normalize_pdf $ for i. exp(logprobs[i] - maxLogProb) + cumsum_low $ normalize_pdf $ for i:n. exp(logprobs[i] - maxLogProb) def categorical(logprobs: n=>Float, key: Key) -> n given (n|Ix) = categorical_from_cdf(cdf_for_categorical logprobs, key) @@ -2210,11 +2204,11 @@ def categorical(logprobs: n=>Float, key: Key) -> n given (n|Ix) = -- (alternatively we could rely on hoisting of loop constants) def categorical_batch(logprobs: n=>Float, key: Key) -> m=>n given (n|Ix, m|Ix) = cdf = cdf_for_categorical logprobs - for i. categorical_from_cdf(cdf, ixkey(key, i)) + for i:m. categorical_from_cdf(cdf, ixkey(key, i)) def logsumexp(x: n=>Float) -> Float given (n|Ix) = m = maximum x - m + (log $ sum for i. exp (x[i] - m)) + m + (log $ sum for i:n. exp (x[i] - m)) def logsoftmax(x: n=>Float) -> n=>Float given (n|Ix) = lse = logsumexp x @@ -2222,25 +2216,25 @@ def logsoftmax(x: n=>Float) -> n=>Float given (n|Ix) = def softmax(x: n=>Float) -> n=>Float given (n|Ix) = m = maximum x - e = for i. exp (x[i] - m) + e = for i:n. exp (x[i] - m) s = sum e for i. e[i] / s '## Polynomials TODO: Move this somewhere else -def evalpoly(coefficients:n=>v, x:Float) -> v given (n|Ix, v|VSpace) = +def evalpoly(coeffs:n=>v, x:Float) -> v given (n|Ix, v|VSpace) = -- Evaluate a polynomial at x. Same as Numpy's polyval. - fold zero \i c. coefficients[i] + x .* c + fold zero coeffs \i coeff c. coeff + x .* c '## Exception effect -- TODO: move `error` and `todo` to here. -def catch(f:() -> {Except|eff} a) -> {|eff} Maybe a given (a, eff)= +def catch(f:() -> {Except|eff} a) -> {|eff} Maybe a given (a:Type, eff:Effects)= f' : (() -> {Except|eff} a) = \. f() %catchException(f') -def throw() -> {Except} a given (a) = +def throw() -> {Except} a given (a:Type) = %throwException(a) def assert(b:Bool) -> {Except} () = @@ -2270,16 +2264,16 @@ instance Subset(b, Either(a,b)) given (a|Data, b|Data) def int_to_reversed_digits(k:Nat) -> a=>b given (a|Ix, b|Ix) = base = size b - snd $ scan k \_ cur_k. + fst $ scan k (for i:a. ()) \_ _ cur_k. next_k = cur_k `idiv` base digit = cur_k `mod` base - (next_k, unsafe_from_ordinal(n=b, digit)) + (unsafe_from_ordinal(n=b, digit), next_k) def reversed_digits_to_int(digits: a=>b) -> Nat given (a|Ix, b|Ix) = base = size b - fst $ fold (0, 1) \j pair. + fst $ fold (0::Nat, 1::Nat) digits \j digit pair. (cur_k, cur_base) = pair - next_k = cur_k + ordinal digits[j] * cur_base + next_k = cur_k + ordinal digit * cur_base next_base = cur_base * base (next_k, next_base) @@ -2352,18 +2346,19 @@ struct Stack(h:Heap, a|Data) = self.size_ref := n_new Just $ get buf!(unsafe_from_ordinal n_new) -stack_init_size = 16 +stack_init_size : Nat = 16 + def with_stack( a|Data, action:(given (h:Heap), Stack(h, a)) -> {State h|eff} r - ) -> {|eff} r given (eff, r) = - init_stack = to_list for i:(Fin stack_init_size). uninitialized_value() - with_state (0, init_stack) \ref . action(Stack(ref.0, ref.1)) + ) -> {|eff} r given (eff:Effects, r:Type) = + init_stack = to_list for i:(Fin stack_init_size). uninitialized_value() :: a + with_state (0::Nat, init_stack) \ref . action(Stack(ref.0, ref.1)) -def stack_extend_internal(stack:Stack(h, Char), x:Fin n=>Char) -> {State h} () given (n, h) = +def stack_extend_internal(stack:Stack(h, Char), x:Fin n=>Char) -> {State h} () given (n:Nat, h:Heap) = stack.extend(x) -def stack_push_internal(stack:Stack(h, Char), x:Char) -> {State h} () given (h) = +def stack_push_internal(stack:Stack(h, Char), x:Char) -> {State h} () given (h:Heap) = stack.push(x) def with_stack_internal(f:(given (h:Heap), Stack(h, Char)) -> {State h} ()) -> List Char = @@ -2387,7 +2382,7 @@ def from_c_string(s:CString) -> {IO} (Maybe String) = stack.push(c) Continue -def show_any(x:a) -> String given (a) = unsafe_coerce(to=String, %showAny(x)) +def show_any(x:a) -> String given (a:Type) = unsafe_coerce(to=String, %showAny(x)) def coerce_table(m|Ix, x:n=>a) -> m => a given (n|Ix, a|Data) = if size m == size n @@ -2423,19 +2418,18 @@ def dex_test_mode() -> Bool = unsafe_io \. check_env "DEX_TEST_MODE" def fread(stream:Stream ReadMode) -> {IO} String = -- TODO: allow reading longer files! - n = 4096 - ptr:(Ptr Char) <- with_alloc n + n : Nat = 4096 + ptr <- with_alloc(Char, n) stack <- with_stack Char iter \_. numRead = i_to_w32 $ i64_to_i $ freadFFI(ptr.val, 1, n_to_i64 n, stream.ptr) AsList(_, new_chars) = string_from_char_ptr(numRead, ptr) stack.extend(new_chars) - if numRead == n_to_w32 n - then Continue - else Done () + case numRead == n_to_w32 n of + True -> Continue :: IterResult () + False -> Done () stack.read() - '### Shelling Out foreign "popen" popenFFI : (RawPtr, RawPtr) -> {IO} RawPtr @@ -2447,8 +2441,7 @@ def shell_out(command:String) -> {IO} String = modeStr = "r" with_c_string command \command'. with_c_string modeStr \modeStr'. - pipe = Stream $ popenFFI(command'.ptr, modeStr'.ptr) - fread pipe + fread $ Stream $ popenFFI(command'.ptr, modeStr'.ptr) '### File Operations @@ -2491,16 +2484,16 @@ def new_temp_file() -> {IO} FilePath = closeFFI fd string_from_char_ptr(15, (Ptr s.ptr)) -def with_temp_file(action: (FilePath) -> {IO} a) -> {IO} a given (a) = +def with_temp_file(action: (FilePath) -> {IO} a) -> {IO} a given (a:Type) = tmpFile = new_temp_file() result = action tmpFile delete_file tmpFile result -def with_temp_files(action: (n=>FilePath) -> {IO} a) -> {IO} a given (n|Ix, a) = - tmpFiles = for i. new_temp_file() +def with_temp_files(n|Ix, action: (n=>FilePath) -> {IO} a) -> {IO} a given (a:Type) = + tmpFiles = for i:n. new_temp_file() result = action tmpFiles - for i. delete_file tmpFiles[i] + each tmpFiles delete_file result '### Linear Algebra @@ -2509,12 +2502,12 @@ def linspace(n|Ix, low:Float, high:Float) -> n=>Float = dx = (high - low) / n_to_f (size n) for i:n. low + n_to_f (ordinal i) * dx -def transpose(x:n=>m=>a) -> m=>n=>a given (n|Ix, m|Ix, a) = for i j. x[j,i] -def vdot(x:n=>Float, y:n=>Float) -> Float given (n|Ix) = fsum for i. x[i] * y[i] -def dot(s:n=>Float, vs:n=>v) -> v given (n|Ix, v|VSpace) = sum for j. s[j] .* vs[j] +def transpose(x:n=>m=>a) -> m=>n=>a given (n|Ix, m|Ix, a:Type) = for i j. x[j,i] +def vdot(x:n=>Float, y:n=>Float) -> Float given (n|Ix) = fsum for i:n. x[i] * y[i] +def dot(s:n=>Float, vs:n=>v) -> v given (n|Ix, v|VSpace) = sum for j:n. s[j] .* vs[j] def naive_matmul(x: l=>m=>Float, y: m=>n=>Float) -> (l=>n=>Float) given (l|Ix, m|Ix, n|Ix) = - for i k. fsum for j. x[i,j] * y[j,k] + for i k. fsum for j:m. x[i,j] * y[j,k] -- A `FullTileIx` type represents `tile_ix`th full tile (of size -- `tile_size`) iterating over the index set `n`. @@ -2552,7 +2545,7 @@ def tile( n|Ix, tile_size: Nat, body:(m:Type, given () (Ix m, Subset(m, n))) -> {|eff} () - ) -> {|eff} () given (eff) = + ) -> {|eff} () given (eff:Effects) = num_tiles = size n `idiv` tile_size coda_size = size n `rem` tile_size coda_offset = num_tiles * tile_size @@ -2567,9 +2560,9 @@ def tiled_matmul( y: m=>n=>Float ) -> l=>n=>Float given (l|Ix, m|Ix, n|Ix) = -- Tile sizes picked for axch's laptop - l_tile_size = 32 - n_tile_size = 128 - m_tile_size = 8 + l_tile_size : Nat = 32 + n_tile_size : Nat = 128 + m_tile_size : Nat = 8 yield_accum (AddMonoid Float) \result. tile(l, l_tile_size) \l_set. tile(n, n_tile_size) \n_set. @@ -2577,9 +2570,9 @@ def tiled_matmul( for_ l_offset:l_set. l_ix = inject(to=l, l_offset) for_ m_offset:m_set. - m_ix = inject m_offset + m_ix = inject(to=m, m_offset) for_ n_offset:n_set. - n_ix = inject n_offset + n_ix = inject(to=n, n_offset) result!l_ix!n_ix += x[l_ix][m_ix] * y[m_ix][n_ix] -- matmul. Better symbol to use? `@`? @@ -2592,8 +2585,9 @@ def (**)( def matmul_linearization( x: l=>m=>Float, y: m=>n=>Float - ) -> _ given (l|Ix, m|Ix, n|Ix) = - def lin(xt: l=>m=>Float, yt: m=>n=>Float) -> _ = + ) -> (l=>n=>Float, (l=>m=>Float, m=>n=>Float)->l=>n=>Float) + given (l|Ix, m|Ix, n|Ix) = + def lin(xt: l=>m=>Float, yt: m=>n=>Float) -> l=>n=>Float = x ** yt + xt ** y (x ** y, lin) @@ -2605,7 +2599,7 @@ def(.**)(v: n=>Float, mat: n=>m=>Float) -> (m=>Float) given (n|Ix, m|Ix) = transpose mat **. v def inner(x:n=>Float, mat:n=>m=>Float, y:m=>Float) -> Float given (n|Ix, m|Ix) = - fsum for p. + fsum for p:(n,m). (i,j) = p x[i] * mat[i,j] * y[j] diff --git a/src/lib/AbstractSyntax.hs b/src/lib/AbstractSyntax.hs index e2143fd9a..367959987 100644 --- a/src/lib/AbstractSyntax.hs +++ b/src/lib/AbstractSyntax.hs @@ -112,10 +112,10 @@ topDecl = dropSrc topDecl' where tyConParams' <- aExplicitParams tyConParams givens' <- aOptGivens givens constructors' <- forM constructors \(v, ps) -> do - ps' <- toNest <$> mapM tyOptBinder ps + ps' <- toNest <$> mapM (tyOptBinder Explicit) ps return (v, ps') return $ UDataDefDecl - (UDataDef name (catUOptAnnExplBinders givens' tyConParams') $ + (UDataDef name (givens' >>> tyConParams') $ map (\(name', cons) -> (name', UDataDefTrail cons)) constructors') (fromString name) (toNest $ map (fromString . fst) constructors') @@ -126,7 +126,7 @@ topDecl = dropSrc topDecl' where methods <- forM defs \(ann, d) -> do (methodName, lam) <- aDef d return (ann, methodName, Abs (UBindSource emptySrcPosCtx "self") lam) - return $ UStructDecl (fromString name) (UStructDef name (catUOptAnnExplBinders givens' params') fields' methods) + return $ UStructDecl (fromString name) (UStructDef name (givens' >>> params') fields' methods) topDecl' (CInterface name params methods) = do params' <- aExplicitParams params (methodNames, methodTys) <- unzip <$> forM methods \(methodName, ty) -> do @@ -134,8 +134,6 @@ topDecl = dropSrc topDecl' where return (fromString methodName, ty') return $ UInterface params' methodTys (fromString name) (toNest methodNames) topDecl' (CInstanceDecl def) = aInstanceDef def - topDecl' (CEffectDecl _ _) = error "not implemented" - topDecl' (CHandlerDecl _ _ _ _ _ _) = error "not implemented" decl :: LetAnn -> CSDecl -> SyntaxM (UDecl VoidS VoidS) decl ann = propagateSrcB \case @@ -162,68 +160,96 @@ aInstanceDef (CInstanceDef clName args givens methods instNameAndParams) = do case optParams of Just params -> do params' <- aExplicitParams params - return $ UInstance clName' (catUOptAnnExplBinders givens' params') args' methods' instName' ExplicitApp + return $ UInstance clName' (givens' >>> params') args' methods' instName' ExplicitApp Nothing -> return $ UInstance clName' givens' args' methods' instName' ImplicitApp aDef :: CDef -> SyntaxM (SourceName, ULamExpr VoidS) aDef (CDef name params optRhs optGivens body) = do - explicitParams <- aExplicitParams params + explicitParams <- explicitBindersOptAnn params let rhsDefault = (ExplicitApp, Nothing, Nothing) (expl, effs, resultTy) <- fromMaybeM optRhs rhsDefault \(expl, optEffs, resultTy) -> do effs <- fromMaybeM optEffs UPure aEffects resultTy' <- expr resultTy return (expl, Just effs, Just resultTy') implicitParams <- aOptGivens optGivens - let allParams = catUOptAnnExplBinders implicitParams explicitParams + let allParams = implicitParams >>> explicitParams body' <- block body return (name, ULamExpr allParams expl effs resultTy body') -catUOptAnnExplBinders :: UOptAnnExplBinders n l -> UOptAnnExplBinders l l' -> UOptAnnExplBinders n l' -catUOptAnnExplBinders (expls, bs) (expls', bs') = (expls <> expls', bs >>> bs') - stripParens :: Group -> Group stripParens (WithSrc _ (CParens [g])) = stripParens g stripParens g = g -aExplicitParams :: ExplicitParams -> SyntaxM ([Explicitness], Nest UOptAnnBinder VoidS VoidS) -aExplicitParams gs = generalBinders DataParam Explicit gs +-- === combinators for different sorts of binder lists === + +aOptGivens :: Maybe GivenClause -> SyntaxM (Nest UAnnBinder VoidS VoidS) +aOptGivens optGivens = fromMaybeM optGivens Empty aGivens + +binderList + :: [Group] -> (Group -> SyntaxM (Nest UAnnBinder VoidS VoidS)) + -> SyntaxM (Nest UAnnBinder VoidS VoidS) +binderList gs cont = concatNests <$> forM gs \case + WithSrc _ (CGivens gs') -> aGivens gs' + g -> cont g + +withTrailingConstraints + :: Group -> (Group -> SyntaxM (UAnnBinder VoidS VoidS)) + -> SyntaxM (Nest UAnnBinder VoidS VoidS) +withTrailingConstraints g cont = case g of + Binary Pipe lhs c -> do + Nest (UAnnBinder expl b ann cs) bs <- withTrailingConstraints lhs cont + (ctx, s) <- case b of + UBindSource ctx s -> return (ctx, s) + UIgnore -> throw SyntaxErr "Can't constrain anonymous binders" + UBind _ _ _ -> error "Shouldn't have internal names until renaming pass" + c' <- expr c + let v = WithSrcE ctx $ UVar (SourceName ctx s) + return $ UnaryNest (UAnnBinder expl b ann (cs ++ [c'])) + >>> bs + >>> UnaryNest (asConstraintBinder v c') + _ -> UnaryNest <$> cont g + where + asConstraintBinder :: UExpr VoidS -> UConstraint VoidS -> UAnnBinder VoidS VoidS + asConstraintBinder v c = do + let t = ns $ UApp c [v] [] + UAnnBinder (Inferred Nothing (Synth Full)) UIgnore (UAnn t) [] + +aGivens :: GivenClause -> SyntaxM (Nest UAnnBinder VoidS VoidS) +aGivens (implicits, optConstraints) = do + implicits' <- concatNests <$> forM implicits \b -> withTrailingConstraints b implicitArgBinder + constraints <- fromMaybeM optConstraints Empty (\gs -> toNest <$> mapM synthBinder gs) + return $ implicits' >>> constraints + +synthBinder :: Group -> SyntaxM (UAnnBinder VoidS VoidS) +synthBinder g = tyOptBinder (Inferred Nothing (Synth Full)) g -aOptGivens :: Maybe GivenClause -> SyntaxM (UOptAnnExplBinders VoidS VoidS) -aOptGivens optGivens = do - (expls, implicitParams) <- unzip <$> fromMaybeM optGivens [] aGivens - return (expls, toNest implicitParams) +concatNests :: [Nest b VoidS VoidS] -> Nest b VoidS VoidS +concatNests [] = Empty +concatNests (b:bs) = b >>> concatNests bs -aGivens :: GivenClause -> SyntaxM [(Explicitness, UOptAnnBinder VoidS VoidS)] -aGivens (implicits, optConstraints) = do - implicits' <- mapM (generalBinder DataParam (Inferred Nothing Unify)) implicits - constraints <- fromMaybeM optConstraints [] \gs -> do - mapM (generalBinder TypeParam (Inferred Nothing (Synth Full))) gs - return $ implicits' <> constraints - -generalBinders - :: ParamStyle -> Explicitness -> [Group] - -> SyntaxM ([Explicitness], Nest UOptAnnBinder VoidS VoidS) -generalBinders paramStyle expl params = do - (expls, bs) <- unzip . concat <$> forM params \case - WithSrc _ (CGivens gs) -> aGivens gs - p -> (:[]) <$> generalBinder paramStyle expl p - return (expls, toNest bs) - -generalBinder :: ParamStyle -> Explicitness -> Group - -> SyntaxM (Explicitness, UOptAnnBinder VoidS VoidS) -generalBinder paramStyle expl g = case expl of - Inferred _ (Synth _) -> (expl,) <$> tyOptBinder g - Inferred _ Unify -> do - b <- binderOptTy g - expl' <- return case b of - UAnnBinder (UBindSource _ s) _ _ -> Inferred (Just s) Unify - _ -> expl - return (expl', b) - Explicit -> (expl,) <$> case paramStyle of - TypeParam -> tyOptBinder g - DataParam -> binderOptTy g - - -- Binder pattern with an optional type annotation +implicitArgBinder :: Group -> SyntaxM (UAnnBinder VoidS VoidS) +implicitArgBinder g = do + UAnnBinder _ b ann cs <- binderOptTy (Inferred Nothing Unify) g + s <- case b of + UBindSource _ s -> return $ Just s + _ -> return Nothing + return $ UAnnBinder (Inferred s Unify) b ann cs + +aExplicitParams :: ExplicitParams -> SyntaxM (Nest UAnnBinder VoidS VoidS) +aExplicitParams bs = binderList bs \b -> withTrailingConstraints b \b' -> + binderOptTy Explicit b' + +aPiBinders :: ExplicitParams -> SyntaxM (Nest UAnnBinder VoidS VoidS) +aPiBinders bs = binderList bs \b -> + UnaryNest <$> tyOptBinder Explicit b + +explicitBindersOptAnn :: ExplicitParams -> SyntaxM (Nest UAnnBinder VoidS VoidS) +explicitBindersOptAnn bs = binderList bs \b -> withTrailingConstraints b \b' -> + binderOptTy Explicit b' + +-- === + +-- Binder pattern with an optional type annotation patOptAnn :: Group -> SyntaxM (UPat VoidS VoidS, Maybe (UType VoidS)) patOptAnn (Binary Colon lhs typeAnn) = (,) <$> pat lhs <*> (Just <$> expr typeAnn) patOptAnn (WithSrc _ (CParens [g])) = patOptAnn g @@ -236,14 +262,14 @@ uBinder (WithSrc src b) = addSrcContext src $ case b of _ -> throw SyntaxErr "Binder must be an identifier or `_`" -- Type annotation with an optional binder pattern -tyOptPat :: Group -> SyntaxM (UOptAnnBinder VoidS VoidS) +tyOptPat :: Group -> SyntaxM (UAnnBinder VoidS VoidS) tyOptPat = \case -- Named type - Binary Colon lhs typeAnn -> UAnnBinder <$> uBinder lhs <*> (UAnn <$> expr typeAnn) <*> pure [] + Binary Colon lhs typeAnn -> UAnnBinder Explicit <$> uBinder lhs <*> (UAnn <$> expr typeAnn) <*> pure [] -- Binder in grouping parens. WithSrc _ (CParens [g]) -> tyOptPat g -- Anonymous type - g -> UAnnBinder UIgnore <$> (UAnn <$> expr g) <*> pure [] + g -> UAnnBinder Explicit UIgnore <$> (UAnn <$> expr g) <*> pure [] -- Pattern of a case binder. This treats bare names specially, in -- that they become (nullary) constructors to match rather than names @@ -280,41 +306,33 @@ pat = propagateSrcB pat' where _ -> error "unexpected postfix group (should be ruled out at grouping stage)" pat' _ = throw SyntaxErr "Illegal pattern" -data ParamStyle - = TypeParam -- binder optional, used in pi types - | DataParam -- type optional , used in lambda - -tyOptBinder :: Group -> SyntaxM (UAnnBinder req VoidS VoidS) -tyOptBinder = \case +tyOptBinder :: Explicitness -> Group -> SyntaxM (UAnnBinder VoidS VoidS) +tyOptBinder expl = \case Binary Pipe _ _ -> throw SyntaxErr "Unexpected constraint" Binary Colon name ty -> do b <- uBinder name ann <- UAnn <$> expr ty - return $ UAnnBinder b ann [] + return $ UAnnBinder expl b ann [] g -> do ty <- expr g - return $ UAnnBinder UIgnore (UAnn ty) [] - -binderOptTy :: Group -> SyntaxM (UOptAnnBinder VoidS VoidS) -binderOptTy g = do - (g', constraints) <- trailingConstraints g - case g' of - Binary Colon name ty -> do - b <- uBinder name - ann <- UAnn <$> expr ty - return $ UAnnBinder b ann constraints - _ -> do - b <- uBinder g' - return $ UAnnBinder b UNoAnn constraints - -trailingConstraints :: Group -> SyntaxM (Group, [UConstraint VoidS]) -trailingConstraints gTop = go [] gTop where - go :: [UConstraint VoidS] -> Group -> SyntaxM (Group, [UConstraint VoidS]) - go cs = \case - Binary Pipe lhs c -> do - c' <- expr c - go (c':cs) lhs - g -> return (g, cs) + return $ UAnnBinder expl UIgnore (UAnn ty) [] + +binderOptTy :: Explicitness -> Group -> SyntaxM (UAnnBinder VoidS VoidS) +binderOptTy expl = \case + Binary Colon name ty -> do + b <- uBinder name + ann <- UAnn <$> expr ty + return $ UAnnBinder expl b ann [] + g -> do + b <- uBinder g + return $ UAnnBinder expl b UNoAnn [] + +binderReqTy :: Explicitness -> Group -> SyntaxM (UAnnBinder VoidS VoidS) +binderReqTy expl (Binary Colon name ty) = do + b <- uBinder name + ann <- UAnn <$> expr ty + return $ UAnnBinder expl b ann [] +binderReqTy _ _ = throw SyntaxErr $ "Expected an annotated binder" argList :: [Group] -> SyntaxM ([UExpr VoidS], [UNamedArg VoidS]) argList gs = partitionEithers <$> mapM singleArg gs @@ -356,7 +374,7 @@ aMethod (WithSrc src d) = Just . WithSrcE src <$> addSrcContext src case d of (name, lam) <- aDef def return $ UMethodDef (fromString name) lam CLet (WithSrc _ (CIdentifier name)) rhs -> do - rhs' <- ULamExpr ([], Empty) ImplicitApp Nothing Nothing <$> block rhs + rhs' <- ULamExpr Empty ImplicitApp Nothing Nothing <$> block rhs return $ UMethodDef (fromString name) rhs' _ -> throw SyntaxErr "Unexpected method definition. Expected `def` or `x = ...`." @@ -377,10 +395,10 @@ blockDecls [WithSrc src d] = addSrcContext src case d of CExpr g -> (Empty,) <$> expr g _ -> throw SyntaxErr "Block must end in expression" blockDecls (WithSrc pos (CBind b rhs):ds) = do - (_, b') <- generalBinder DataParam Explicit b + b' <- binderOptTy Explicit b rhs' <- asExpr <$> block rhs body <- block $ IndentedBlock ds - let lam = ULam $ ULamExpr ([Explicit], UnaryNest b') ExplicitApp Nothing Nothing body + let lam = ULam $ ULamExpr (UnaryNest b') ExplicitApp Nothing Nothing body return (Empty, WithSrcE pos $ extendAppRight rhs' (ns lam)) blockDecls (d:ds) = do d' <- decl PlainLet d @@ -411,7 +429,7 @@ expr = propagateSrcE expr' where expr' (CArrow lhs effs rhs) = do case lhs of WithSrc _ (CParens gs) -> do - bs <- generalBinders TypeParam Explicit gs + bs <- aPiBinders gs effs' <- fromMaybeM effs UPure aEffects resultTy <- expr rhs return $ UPi $ UPiExpr bs ExplicitApp effs' resultTy @@ -451,7 +469,7 @@ expr = propagateSrcE expr' where Colon -> throw SyntaxErr "Colon separates binders from their type annotations, is not a standalone operator.\nIf you are trying to write a dependent type, use parens: (i:Fin 4) => (..i)" ImplicitArrow -> case lhs of WithSrc _ (CParens gs) -> do - bs <- generalBinders TypeParam Explicit gs + bs <- aPiBinders gs resultTy <- expr rhs return $ UPi $ UPiExpr bs ImplicitApp UPure resultTy _ -> throw SyntaxErr "Argument types should be in parentheses" @@ -481,7 +499,7 @@ expr = propagateSrcE expr' where _ -> throw SyntaxErr $ "Prefix (" ++ name ++ ") not legal as a bare expression" where range :: UExpr VoidS -> UExpr VoidS -> UExpr' VoidS - range rangeName lim = explicitApp rangeName [ns UHole, lim] + range rangeName lim = explicitApp rangeName [lim] expr' (CPostfix name g) = case name of ".." -> range "RangeFrom" <$> expr g @@ -489,9 +507,9 @@ expr = propagateSrcE expr' where _ -> throw SyntaxErr $ "Postfix (" ++ name ++ ") not legal as a bare expression" where range :: UExpr VoidS -> UExpr VoidS -> UExpr' VoidS - range rangeName lim = explicitApp rangeName [ns UHole, lim] + range rangeName lim = explicitApp rangeName [lim] expr' (CLambda params body) = do - params' <- aExplicitParams $ map stripParens params + params' <- explicitBindersOptAnn $ map stripParens params body' <- block body return $ ULam $ ULamExpr params' ExplicitApp Nothing Nothing body' expr' (CFor kind indices body) = do @@ -501,7 +519,7 @@ expr = propagateSrcE expr' where KRof -> (Rev, False) KRof_ -> (Rev, True) -- TODO: Can we fetch the source position from the error context, to feed into `buildFor`? - e <- buildFor (0, 0) dir <$> mapM binderOptTy indices <*> block body + e <- buildFor (0, 0) dir <$> mapM (binderOptTy Explicit) indices <*> block body if trailingUnit then return $ UDo $ ns $ UBlock (UnaryNest (nsB $ UExprDecl e)) (ns unitExpr) else return $ dropSrcE e @@ -520,7 +538,7 @@ expr = propagateSrcE expr' where ty <- expr lhs case rhs of [b] -> do - b' <- binderOptTy b + b' <- binderReqTy Explicit b return $ UDepPairTy $ UDepPairType ImplicitDepPair b' ty _ -> error "n-ary dependent pairs not implemented" @@ -533,7 +551,7 @@ unitExpr = UPrim (UCon $ P.ProdCon) [] -- === Builders === -- TODO Does this generalize? Swap list for Nest? -buildFor :: SrcPos -> Direction -> [UOptAnnBinder VoidS VoidS] -> UBlock VoidS -> UExpr VoidS +buildFor :: SrcPos -> Direction -> [UAnnBinder VoidS VoidS] -> UBlock VoidS -> UExpr VoidS buildFor pos dir binders body = case binders of [] -> error "should have nonempty list of binder" [b] -> WithSrcE (fromPos pos) $ UFor dir $ UForExpr b body diff --git a/src/lib/Algebra.hs b/src/lib/Algebra.hs index 65491714e..a0b022fdf 100644 --- a/src/lib/Algebra.hs +++ b/src/lib/Algebra.hs @@ -48,8 +48,8 @@ newtype Polynomial (n::S) = -- This is the main entrypoint. Doing polynomial math sometimes lets -- us compute sums in closed form. This tries to compute -- `\sum_{i=0}^(lim-1) body`. `i`, `lim`, and `body` should all have type `Nat`. -sumUsingPolys :: (Builder SimpIR m, Fallible1 m, Emits n) - => Atom SimpIR n -> Abs (Binder SimpIR) (Block SimpIR) n -> m n (Atom SimpIR n) +sumUsingPolys :: Emits n + => Atom SimpIR n -> Abs (Binder SimpIR) (Block SimpIR) n -> BuilderM SimpIR n (Atom SimpIR n) sumUsingPolys lim (Abs i body) = do sumAbs <- refreshAbs (Abs i body) \(i':>_) body' -> do blockAsPoly body' >>= \case @@ -192,7 +192,7 @@ blockAsPolyRec decls result = case decls of -- coefficients. This is why we have to find the least common multiples and do the -- accumulation over numbers multiplied by that LCM. We essentially do fixed point -- fractional math here. -emitPolynomial :: (Emits n, Builder SimpIR m) => Polynomial n -> m n (Atom SimpIR n) +emitPolynomial :: Emits n => Polynomial n -> BuilderM SimpIR n (Atom SimpIR n) emitPolynomial (Polynomial p) = do let constLCM = asAtom $ foldl lcm 1 $ fmap (denominator . snd) $ toList p monoAtoms <- flip traverse (toList p) $ \(m, c) -> do @@ -206,7 +206,7 @@ emitPolynomial (Polynomial p) = do -- because it might be causing overflows due to all arithmetic being shifted. asAtom = IdxRepVal . fromInteger -emitMonomial :: (Emits n, Builder SimpIR m) => Monomial n -> m n (Atom SimpIR n) +emitMonomial :: Emits n => Monomial n -> BuilderM SimpIR n (Atom SimpIR n) emitMonomial (Monomial m) = do varAtoms <- forM (toList m) \(v, e) -> case v of LeftE v' -> do @@ -217,7 +217,7 @@ emitMonomial (Monomial m) = do ipow atom e foldM imul (IdxRepVal 1) varAtoms -ipow :: (Emits n, Builder SimpIR m) => Atom SimpIR n -> Int -> m n (Atom SimpIR n) +ipow :: Emits n => Atom SimpIR n -> Int -> BuilderM SimpIR n (Atom SimpIR n) ipow x i = foldM imul (IdxRepVal 1) (replicate i x) -- === instances === diff --git a/src/lib/Builder.hs b/src/lib/Builder.hs index c539d01b4..9319f41f4 100644 --- a/src/lib/Builder.hs +++ b/src/lib/Builder.hs @@ -38,11 +38,11 @@ import Util (enumerate, transitiveClosureM, bindM2, toSnocList, (...)) -- === Ordinary (local) builder class === -class (EnvReader m, EnvExtender m, Fallible1 m, IRRep r) +class (EnvReader m, Fallible1 m, IRRep r) => Builder (r::IR) (m::MonadKind1) | m -> r where rawEmitDecl :: Emits n => NameHint -> LetAnn -> Expr r n -> m n (AtomVar r n) -class Builder r m => ScopableBuilder (r::IR) (m::MonadKind1) | m -> r where +class (EnvExtender m, Builder r m) => ScopableBuilder (r::IR) (m::MonadKind1) | m -> r where buildScopedAndThen :: SinkableE e => (forall l. (Emits l, DExt n l) => m l (e l)) @@ -252,9 +252,9 @@ instance ( IRRep r, RenameB frag, HoistableB frag, OutFrag frag {-# INLINE refreshAbs #-} instance (SinkableV v, HoistingTopBuilder f m) => HoistingTopBuilder f (SubstReaderT v m i) where - emitHoistedEnv ab = SubstReaderT $ lift $ emitHoistedEnv ab + emitHoistedEnv ab = liftSubstReaderT $ emitHoistedEnv ab {-# INLINE emitHoistedEnv #-} - canHoistToTop e = SubstReaderT $ lift $ canHoistToTop e + canHoistToTop e = liftSubstReaderT $ canHoistToTop e {-# INLINE canHoistToTop #-} -- === Top-level builder class === @@ -302,7 +302,7 @@ emitSynthCandidates sc = emitLocalModuleEnv $ mempty {envSynthCandidates = sc} addInstanceSynthCandidate :: TopBuilder m => ClassName n -> InstanceName n -> m n () addInstanceSynthCandidate className instanceName = - emitSynthCandidates $ SynthCandidates [] (M.singleton className [instanceName]) + emitSynthCandidates $ SynthCandidates (M.singleton className [instanceName]) updateTransposeRelation :: (Mut n, TopBuilder m) => TopFunName n -> TopFunName n -> m n () updateTransposeRelation f1 f2 = @@ -401,13 +401,13 @@ instance Fallible m => TopBuilder (TopBuilderT m) where {-# INLINE localTopBuilder #-} instance (SinkableV v, TopBuilder m) => TopBuilder (SubstReaderT v m i) where - emitBinding hint binding = SubstReaderT $ lift $ emitBinding hint binding + emitBinding hint binding = liftSubstReaderT $ emitBinding hint binding {-# INLINE emitBinding #-} - emitEnv ab = SubstReaderT $ lift $ emitEnv ab + emitEnv ab = liftSubstReaderT $ emitEnv ab {-# INLINE emitEnv #-} - emitNamelessEnv bs = SubstReaderT $ lift $ emitNamelessEnv bs + emitNamelessEnv bs = liftSubstReaderT $ emitNamelessEnv bs {-# INLINE emitNamelessEnv #-} - localTopBuilder cont = SubstReaderT $ ReaderT \env -> do + localTopBuilder cont = SubstReaderT \env -> do localTopBuilder do Distinct <- getDistinct runReaderT (runSubstReaderT' cont) (sink env) @@ -440,7 +440,7 @@ type BuilderEmissions r = RNest (Decl r) newtype BuilderT (r::IR) (m::MonadKind) (n::S) (a:: *) = BuilderT { runBuilderT' :: InplaceT Env (BuilderEmissions r) m n a } deriving ( Functor, Applicative, Monad, MonadTrans1, MonadFail, Fallible - , Catchable, CtxReader, ScopeReader, Alternative, Searcher + , Catchable, CtxReader, ScopeReader, Alternative , MonadWriter w, MonadReader r') type BuilderM (r::IR) = BuilderT r HardFailM @@ -514,14 +514,14 @@ instance (IRRep r, Fallible m) => EnvExtender (BuilderT r m) where {-# INLINE refreshAbs #-} instance (SinkableV v, ScopableBuilder r m) => ScopableBuilder r (SubstReaderT v m i) where - buildScopedAndThen cont1 cont2 = SubstReaderT $ ReaderT \env -> + buildScopedAndThen cont1 cont2 = SubstReaderT \env -> buildScopedAndThen (runReaderT (runSubstReaderT' cont1) (sink env)) (\d e -> runReaderT (runSubstReaderT' $ cont2 d e) (sink env)) {-# INLINE buildScopedAndThen #-} instance (SinkableV v, Builder r m) => Builder r (SubstReaderT v m i) where - rawEmitDecl hint ann expr = SubstReaderT $ lift $ emitDecl hint ann expr + rawEmitDecl hint ann expr = liftSubstReaderT $ emitDecl hint ann expr {-# INLINE rawEmitDecl #-} instance (SinkableE e, ScopableBuilder r m) => ScopableBuilder r (OutReaderT e m) where @@ -555,6 +555,10 @@ instance (SinkableE e, Builder r m) => Builder r (ReaderT1 e m) where ReaderT1 $ lift $ emitDecl hint ann expr {-# INLINE rawEmitDecl #-} +instance (DiffStateE s d, Builder r m) => Builder r (DiffStateT1 s d m) where + rawEmitDecl hint ann expr = lift11 $ rawEmitDecl hint ann expr + {-# INLINE rawEmitDecl #-} + instance (SinkableE e, HoistableState e, Builder r m) => Builder r (StateT1 e m) where rawEmitDecl hint ann expr = lift11 $ emitDecl hint ann expr {-# INLINE rawEmitDecl #-} diff --git a/src/lib/CheapReduction.hs b/src/lib/CheapReduction.hs index c4cc41bb1..99559bc17 100644 --- a/src/lib/CheapReduction.hs +++ b/src/lib/CheapReduction.hs @@ -20,7 +20,6 @@ module CheapReduction where import Control.Applicative -import Control.Monad.Trans import Control.Monad.Writer.Strict hiding (Alt) import Control.Monad.State.Strict import Control.Monad.Reader @@ -42,7 +41,6 @@ import Types.Core import Types.Imp import Types.Primitives import Util -import {-# SOURCE #-} Inference (trySynthTerm) -- Carry out the reductions we are willing to carry out during type -- inference. The goal is to support type aliases like `Int = Int32` @@ -102,9 +100,9 @@ class ( Alternative2 m, SubstReader AtomSubstVal m lookupCache :: AtomName r o -> m i o (Maybe (Maybe (Atom r o))) instance IRRep r => CheapReducer (CheapReducerM r) r where - updateCache v u = CheapReducerM $ SubstReaderT $ lift $ lift11 $ + updateCache v u = CheapReducerM $ liftSubstReaderT $ lift11 $ modify (MapE . M.insert v (toMaybeE u) . fromMapE) - lookupCache v = CheapReducerM $ SubstReaderT $ lift $ lift11 $ + lookupCache v = CheapReducerM $ liftSubstReaderT $ lift11 $ fmap fromMaybeE <$> gets (M.lookup v . fromMapE) liftCheapReducerM @@ -185,11 +183,6 @@ instance IRRep r => CheaplyReducibleE r (Atom r) (Atom r) where -- TODO: we don't collect the dict holes here, so there's a danger of -- dropping them if they turn out to be phantom. Lam _ -> substM a - DictHole ctx ty' access -> do - ty <- cheapReduceE ty' - runFallibleT1 (trySynthTerm ty access) >>= \case - Success d -> return d - Failure _ -> return $ DictHole ctx ty access -- We traverse the Atom constructors that might contain lambda expressions -- explicitly, to make sure that we can skip normalizing free vars inside those. Con con -> Con <$> traverseOp con cheapReduceE cheapReduceE (error "unexpected lambda") @@ -316,6 +309,9 @@ instance (CheaplyReducibleE r e1 e1', CheaplyReducibleE r e2 e2') cheapReduceE (LeftE e) = LeftE <$> cheapReduceE e cheapReduceE (RightE e) = RightE <$> cheapReduceE e +instance CheaplyReducibleE r e e' => CheaplyReducibleE r (ListE e) (ListE e') where + cheapReduceE (ListE xs) = ListE <$> mapM cheapReduceE xs + -- XXX: TODO: figure out exactly what our normalization invariants are. We -- shouldn't have to choose `normalizeProj` or `asNaryProj` on a -- case-by-case basis. This is here for now because it makes it easier to switch @@ -616,7 +612,6 @@ visitAtomPartial = \case Eff eff -> Eff <$> visitGeneric eff DictCon t d -> DictCon <$> visitType t <*> visitGeneric d NewtypeCon con x -> NewtypeCon <$> visitGeneric con <*> visitGeneric x - DictHole ctx ty access -> DictHole ctx <$> visitGeneric ty <*> pure access TypeAsAtom t -> TypeAsAtom <$> visitGeneric t RepValAtom repVal -> RepValAtom <$> visitGeneric repVal diff --git a/src/lib/CheckType.hs b/src/lib/CheckType.hs index b8677c185..14573bf26 100644 --- a/src/lib/CheckType.hs +++ b/src/lib/CheckType.hs @@ -180,9 +180,6 @@ instance IRRep r => CheckableE r (Atom r) where con' <- typeCheckNewtypeCon con xTy return $ NewtypeCon con' x' SimpInCore x -> SimpInCore <$> checkE x - DictHole ctx ty access -> do - ty' <- ty |: TyKind - return $ DictHole ctx ty' access ProjectElt resultTy UnwrapNewtype x -> do resultTy' <- resultTy |: TyKind (x', NewtypeTyCon con) <- checkAndGetType x diff --git a/src/lib/ConcreteSyntax.hs b/src/lib/ConcreteSyntax.hs index 06cbc702b..32eabb071 100644 --- a/src/lib/ConcreteSyntax.hs +++ b/src/lib/ConcreteSyntax.hs @@ -199,7 +199,6 @@ topDecl' = <|> interfaceDef <|> (CInstanceDecl <$> instanceDef True) <|> (CInstanceDecl <$> instanceDef False) - <|> effectDef proseBlock :: Parser SourceBlock' proseBlock = label "prose block" $ char '\'' >> fmap (Misc . ProseBlock . fst) (withSource consumeTillBreak) @@ -291,26 +290,6 @@ interfaceDef = do return (methodName, ty) return $ CInterface className params methodDecls -effectDef :: Parser CTopDecl' -effectDef = do - keyWord EffectKW - effName <- anyName - sigs <- opSigList - return $ CEffectDecl (fromString effName) sigs - -opSigList :: Parser [(SourceName, UResumePolicy, Group)] -opSigList = onePerLine do - policy <- resumePolicy - v <- anyName - void $ sym ":" - ty <- cGroup - return (fromString v, policy, ty) - -resumePolicy :: Parser UResumePolicy -resumePolicy = (keyWord JmpKW $> UNoResume) - <|> (keyWord DefKW $> ULinearResume) - <|> (keyWord CtlKW $> UAnyResume) - nameAndType :: Parser (SourceName, Group) nameAndType = do n <- anyName @@ -693,11 +672,11 @@ ops = , [symOpL "@"] , [symOpN "::"] , [symOpR "$"] - , [symOpL "|"] , [symOpN "+=", symOpN ":="] -- Associate right so the mistaken utterance foo : i:Fin 4 => (..i) -- groups as a bad pi type rather than a bad binder , [symOpR ":"] + , [symOpL "|"] , [symOpR ",>"] , [symOpR "&>"] , [withClausePostfix] diff --git a/src/lib/Core.hs b/src/lib/Core.hs index 8bdac679e..6f4bd6b5a 100644 --- a/src/lib/Core.hs +++ b/src/lib/Core.hs @@ -79,7 +79,7 @@ type EnvExtender2 (m::MonadKind2) = forall (n::S). EnvExtender (m n) newtype EnvReaderT (m::MonadKind) (n::S) (a:: *) = EnvReaderT {runEnvReaderT' :: ReaderT (DistinctEvidence n, Env n) m a } deriving ( Functor, Applicative, Monad, MonadFail - , MonadWriter w, Fallible, Searcher, Alternative) + , MonadWriter w, Fallible, Alternative) type EnvReaderM = EnvReaderT Identity type FallibleEnvReaderM = EnvReaderT FallibleM @@ -132,6 +132,15 @@ instance MonadIO m => MonadIO (EnvReaderT m n) where deriving instance (Monad m, MonadState s m) => MonadState s (EnvReaderT m o) +instance (Monad m, CtxReader m) => CtxReader (EnvReaderT m o) where + getErrCtx = EnvReaderT $ lift getErrCtx + {-# INLINE getErrCtx #-} + +instance (Monad m, Catchable m) => Catchable (EnvReaderT m o) where + catchErr (EnvReaderT (ReaderT m)) f = EnvReaderT $ ReaderT \env -> + m env `catchErr` \err -> runReaderT (runEnvReaderT' $ f err) env + {-# INLINE catchErr #-} + -- === Instances for Name monads === instance (SinkableE e, EnvReader m) @@ -389,12 +398,6 @@ withFreshBinders (binding:rest) cont = do cont (Nest b bs) (sink (binderName b) : vs) -getInstanceDicts :: EnvReader m => ClassName n -> m n [InstanceName n] -getInstanceDicts name = do - env <- withEnv moduleEnv - return $ M.findWithDefault [] name $ instanceDicts $ envSynthCandidates env -{-# INLINE getInstanceDicts #-} - -- These `fromNary` functions traverse a chain of unary structures (LamExpr, -- TabLamExpr, CorePiType, respectively) up to the given maxDepth, and return the -- discovered binders packed as the nary structure (NaryLamExpr or PiType), diff --git a/src/lib/Err.hs b/src/lib/Err.hs index 426250884..6af6141c8 100644 --- a/src/lib/Err.hs +++ b/src/lib/Err.hs @@ -4,22 +4,20 @@ -- license that can be found in the LICENSE file or at -- https://developers.google.com/open-source/licenses/bsd -module Err (Err (..), Errs (..), ErrType (..), Except (..), +module Err (Err (..), ErrType (..), Except (..), ErrCtx (..), SrcTextCtx, Fallible (..), Catchable (..), catchErrExcept, FallibleM (..), HardFailM (..), CtxReader (..), - runFallibleM, runHardFail, throw, throwErr, + runFallibleM, runHardFail, throw, addContext, addSrcContext, addSrcTextContext, catchIOExcept, liftExcept, liftExceptAlt, assertEq, ignoreExcept, - pprint, docAsStr, getCurrentCallStack, printCurrentCallStack, - FallibleApplicativeWrapper, traverseMergingErrs, - SearcherM (..), Searcher (..), runSearcherM) where + pprint, docAsStr, getCurrentCallStack, printCurrentCallStack + ) where import Control.Exception hiding (throw) import Control.Applicative import Control.Monad -import Control.Monad.Trans.Maybe import Control.Monad.Identity import Control.Monad.Writer.Strict import Control.Monad.State.Strict @@ -39,7 +37,6 @@ import SourceInfo -- === core API === data Err = Err ErrType ErrCtx String deriving (Show, Eq) -newtype Errs = Errs [Err] deriving (Eq, Semigroup, Monoid) data ErrType = NoErr | ParseErr @@ -63,6 +60,7 @@ data ErrType = NoErr | EscapedNameErr | ModuleImportErr | MonadFailErr + | SearchFailure -- used as the identity for `Alternative` instances deriving (Show, Eq) type SrcTextCtx = Maybe (Int, Text) -- Int is the offset in the source file @@ -74,11 +72,11 @@ data ErrCtx = ErrCtx deriving (Show, Eq, Generic) class MonadFail m => Fallible m where - throwErrs :: Errs -> m a + throwErr :: Err -> m a addErrCtx :: ErrCtx -> m a -> m a class Fallible m => Catchable m where - catchErr :: m a -> (Errs -> m a) -> m a + catchErr :: m a -> (Err -> m a) -> m a catchErrExcept :: Catchable m => m a -> m (Except a) catchErrExcept m = catchErr (Success <$> m) (\e -> return $ Failure e) @@ -88,19 +86,16 @@ catchErrExcept m = catchErr (Success <$> m) (\e -> return $ Failure e) class Fallible m => CtxReader m where getErrCtx :: m ErrCtx --- We have this in its own class because StateT can't implement it --- (but FallibleM, Except and IO all can) -class Fallible m => FallibleApplicative m where - mergeErrs :: m a -> m b -> m (a, b) - newtype FallibleM a = FallibleM { fromFallibleM :: ReaderT ErrCtx Except a } deriving (Functor, Applicative, Monad) instance Fallible FallibleM where - throwErrs (Errs errs) = FallibleM $ ReaderT \ambientCtx -> - throwErrs $ Errs [Err errTy (ambientCtx <> ctx) s | Err errTy ctx s <- errs] - {-# INLINE throwErrs #-} + -- TODO: we end up adding the context multiple times when we do throw/catch. + -- We should fix it. + throwErr (Err errTy ctx s) = FallibleM $ ReaderT \ambientCtx -> + throwErr $ Err errTy (ambientCtx <> ctx) s + {-# INLINE throwErr #-} addErrCtx ctx (FallibleM m) = FallibleM $ local (<> ctx) m {-# INLINE addErrCtx #-} @@ -110,17 +105,27 @@ instance Catchable FallibleM where Failure errs -> runReaderT (fromFallibleM $ handler errs) ctx Success ans -> return ans -instance FallibleApplicative FallibleM where - mergeErrs (FallibleM (ReaderT f1)) (FallibleM (ReaderT f2)) = - FallibleM $ ReaderT \ctx -> mergeErrs (f1 ctx) (f2 ctx) - instance CtxReader FallibleM where getErrCtx = FallibleM ask {-# INLINE getErrCtx #-} +instance Alternative FallibleM where + empty = throw SearchFailure "" + {-# INLINE empty #-} + m1 <|> m2 = do + catchSearchFailure m1 >>= \case + Nothing -> m2 + Just x -> return x + {-# INLINE (<|>) #-} + +catchSearchFailure :: Catchable m => m a -> m (Maybe a) +catchSearchFailure m = (Just <$> m) `catchErr` \case + Err SearchFailure _ _ -> return Nothing + err -> throwErr err + instance Fallible IO where - throwErrs errs = throwIO errs - {-# INLINE throwErrs #-} + throwErr errs = throwIO errs + {-# INLINE throwErr #-} addErrCtx ctx m = do result <- catchIOExcept m liftExcept $ addErrCtx ctx result @@ -132,23 +137,17 @@ instance Catchable IO where Success result -> return result Failure errs -> handler errs -instance FallibleApplicative IO where - mergeErrs m1 m2 = do - result1 <- catchIOExcept m1 - result2 <- catchIOExcept m2 - liftExcept $ mergeErrs result1 result2 - runFallibleM :: FallibleM a -> Except a runFallibleM m = runReaderT (fromFallibleM m) mempty {-# INLINE runFallibleM #-} -- === Except type === --- Except is isomorphic to `Either Errs` but having a distinct type makes it +-- Except is isomorphic to `Either Err` but having a distinct type makes it -- easier to debug type errors. data Except a = - Failure Errs + Failure Err | Success a deriving (Show, Eq) @@ -169,23 +168,6 @@ instance Monad Except where Success x >>= f = f x {-# INLINE (>>=) #-} --- === FallibleApplicativeWrapper === - --- Wraps a Fallible monad, presenting an applicative interface that sequences --- actions using the error-concatenating `mergeErrs` instead of the default --- abort-on-failure sequencing. - -newtype FallibleApplicativeWrapper m a = - FallibleApplicativeWrapper { fromFallibleApplicativeWrapper :: m a } - deriving (Functor) - -instance FallibleApplicative m => Applicative (FallibleApplicativeWrapper m) where - pure x = FallibleApplicativeWrapper $ pure x - {-# INLINE pure #-} - liftA2 f (FallibleApplicativeWrapper m1) (FallibleApplicativeWrapper m2) = - FallibleApplicativeWrapper $ fmap (uncurry f) (mergeErrs m1 m2) - {-# INLINE liftA2 #-} - -- === HardFail === -- Implements Fallible by crashing. Used in type querying when we want to avoid @@ -222,24 +204,17 @@ instance MonadFail HardFailM where {-# INLINE fail #-} instance Fallible HardFailM where - throwErrs errs = error $ pprint errs - {-# INLINE throwErrs #-} + throwErr errs = error $ pprint errs + {-# INLINE throwErr #-} addErrCtx _ cont = cont {-# INLINE addErrCtx #-} -instance FallibleApplicative HardFailM where - mergeErrs cont1 cont2 = (,) <$> cont1 <*> cont2 - -- === convenience layer === throw :: Fallible m => ErrType -> String -> m a -throw errTy s = throwErrs $ Errs [addCompilerStackCtx $ Err errTy mempty s] +throw errTy s = throwErr $ addCompilerStackCtx $ Err errTy mempty s {-# INLINE throw #-} -throwErr :: Fallible m => Err -> m a -throwErr err = throwErrs $ Errs [addCompilerStackCtx err] -{-# INLINE throwErr #-} - addCompilerStackCtx :: Err -> Err addCompilerStackCtx (Err ty ctx msg) = Err ty ctx{stackCtx = compilerStack} msg where @@ -278,17 +253,17 @@ addSrcTextContext offset text m = catchIOExcept :: MonadIO m => IO a -> m (Except a) catchIOExcept m = liftIO $ (liftM Success m) `catches` - [ Handler \(e::Errs) -> return $ Failure e - , Handler \(e::IOError) -> return $ Failure $ Errs [Err DataIOErr mempty $ show e] + [ Handler \(e::Err) -> return $ Failure e + , Handler \(e::IOError) -> return $ Failure $ Err DataIOErr mempty $ show e -- Propagate asynchronous exceptions like ThreadKilled; they are -- part of normal operation (of the live evaluation modes), not -- compiler bugs. , Handler \(e::AsyncException) -> liftIO $ throwIO e - , Handler \(e::SomeException) -> return $ Failure $ Errs [Err CompilerErr mempty $ show e] + , Handler \(e::SomeException) -> return $ Failure $ Err CompilerErr mempty $ show e ] liftExcept :: Fallible m => Except a -> m a -liftExcept (Failure errs) = throwErrs errs +liftExcept (Failure errs) = throwErr errs liftExcept (Success ans) = return ans {-# INLINE liftExcept #-} @@ -310,83 +285,21 @@ assertEq x y s = if x == y then return () ++ pprint x ++ " != " ++ pprint y ++ "\n\n" ++ prettyCallStack callStack ++ "\n" --- === search monad === - -infix 0 -class (Monad m, Alternative m) => Searcher m where - -- Runs the second computation when the first yields an empty set of results. - -- This is just `<|>` for greedy searchers like `Maybe`, but in other cases, - -- like the list monad, it matters that the second computation isn't run if - -- the first succeeds. - () :: m a -> m a -> m a - --- Adds an extra error case to `FallibleM` so we can give it an Alternative --- instance with an identity element. -newtype SearcherM a = SearcherM { runSearcherM' :: MaybeT FallibleM a } - deriving (Functor, Applicative, Monad) - -runSearcherM :: SearcherM a -> Except (Maybe a) -runSearcherM m = runFallibleM $ runMaybeT (runSearcherM' m) -{-# INLINE runSearcherM #-} - -instance MonadFail SearcherM where - fail _ = SearcherM $ MaybeT $ return Nothing - {-# INLINE fail #-} - -instance Fallible SearcherM where - throwErrs e = SearcherM $ lift $ throwErrs e - {-# INLINE throwErrs #-} - addErrCtx ctx (SearcherM (MaybeT m)) = SearcherM $ MaybeT $ - addErrCtx ctx $ m - {-# INLINE addErrCtx #-} - -instance Alternative SearcherM where - empty = SearcherM $ MaybeT $ return Nothing - SearcherM (MaybeT m1) <|> SearcherM (MaybeT m2) = SearcherM $ MaybeT do - m1 >>= \case - Just ans -> return $ Just ans - Nothing -> m2 - -instance Catchable SearcherM where - SearcherM (MaybeT m) `catchErr` handler = SearcherM $ MaybeT $ - m `catchErr` \errs -> runMaybeT $ runSearcherM' $ handler errs - -instance Searcher SearcherM where - () = (<|>) - {-# INLINE () #-} - -instance CtxReader SearcherM where - getErrCtx = SearcherM $ lift getErrCtx - {-# INLINE getErrCtx #-} - -instance Searcher [] where - [] m = m - m _ = m - {-# INLINE () #-} - -instance (Monoid w, Searcher m) => Searcher (WriterT w m) where - WriterT m1 WriterT m2 = WriterT (m1 m2) - {-# INLINE () #-} - instance (Monoid w, Fallible m) => Fallible (WriterT w m) where - throwErrs errs = lift $ throwErrs errs - {-# INLINE throwErrs #-} + throwErr errs = lift $ throwErr errs + {-# INLINE throwErr #-} addErrCtx ctx (WriterT m) = WriterT $ addErrCtx ctx m {-# INLINE addErrCtx #-} -instance Searcher m => Searcher (ReaderT r m) where - ReaderT f1 ReaderT f2 = ReaderT \r -> f1 r f2 r - {-# INLINE () #-} - instance Fallible [] where - throwErrs _ = [] - {-# INLINE throwErrs #-} + throwErr _ = [] + {-# INLINE throwErr #-} addErrCtx _ m = m {-# INLINE addErrCtx #-} instance Fallible Maybe where - throwErrs _ = Nothing - {-# INLINE throwErrs #-} + throwErr _ = Nothing + {-# INLINE throwErr #-} addErrCtx _ m = m {-# INLINE addErrCtx #-} @@ -404,11 +317,6 @@ layout :: LayoutOptions layout = if unbounded then LayoutOptions Unbounded else defaultLayoutOptions where unbounded = unsafePerformIO $ (Just "1"==) <$> lookupEnv "DEX_PPRINT_UNBOUNDED" -traverseMergingErrs :: (Traversable f, FallibleApplicative m) - => (a -> m b) -> f a -> m (f b) -traverseMergingErrs f xs = - fromFallibleApplicativeWrapper $ traverse (\x -> FallibleApplicativeWrapper $ f x) xs - -- === instances === instance MonadFail FallibleM where @@ -416,29 +324,19 @@ instance MonadFail FallibleM where {-# INLINE fail #-} instance Fallible Except where - throwErrs errs = Failure errs - {-# INLINE throwErrs #-} + throwErr errs = Failure errs + {-# INLINE throwErr #-} addErrCtx _ (Success ans) = Success ans - addErrCtx ctx (Failure (Errs errs)) = - Failure $ Errs [Err errTy (ctx <> ctx') s | Err errTy ctx' s <- errs] + addErrCtx ctx (Failure (Err errTy ctx' s)) = + Failure $ Err errTy (ctx <> ctx') s {-# INLINE addErrCtx #-} -instance FallibleApplicative Except where - mergeErrs (Success x) (Success y) = Success (x, y) - mergeErrs x y = Failure (getErrs x <> getErrs y) - where getErrs :: Except a -> Errs - getErrs = \case Failure e -> e - Success _ -> mempty - instance MonadFail Except where - fail s = Failure $ Errs [Err CompilerErr mempty s] + fail s = Failure $ Err CompilerErr mempty s {-# INLINE fail #-} -instance Exception Errs - -instance Show Errs where - show errs = pprint errs +instance Exception Err instance Pretty Err where pretty (Err e ctx s) = pretty e <> pretty s <> prettyCtx @@ -496,10 +394,11 @@ instance Pretty ErrType where EscapedNameErr -> "Leaked local variables:" ModuleImportErr -> "Module import error: " MonadFailErr -> "MonadFail error (internal error)" + SearchFailure -> "Search error (internal error)" instance Fallible m => Fallible (ReaderT r m) where - throwErrs errs = lift $ throwErrs errs - {-# INLINE throwErrs #-} + throwErr errs = lift $ throwErr errs + {-# INLINE throwErr #-} addErrCtx ctx (ReaderT f) = ReaderT \r -> addErrCtx ctx $ f r {-# INLINE addErrCtx #-} @@ -507,21 +406,13 @@ instance Catchable m => Catchable (ReaderT r m) where ReaderT f `catchErr` handler = ReaderT \r -> f r `catchErr` \e -> runReaderT (handler e) r -instance FallibleApplicative m => FallibleApplicative (ReaderT r m) where - mergeErrs (ReaderT f1) (ReaderT f2) = - ReaderT \r -> mergeErrs (f1 r) (f2 r) - instance CtxReader m => CtxReader (ReaderT r m) where getErrCtx = lift getErrCtx {-# INLINE getErrCtx #-} -instance Pretty Errs where - pretty (Errs [err]) = pretty err - pretty (Errs errs) = prettyLines errs - instance Fallible m => Fallible (StateT s m) where - throwErrs errs = lift $ throwErrs errs - {-# INLINE throwErrs #-} + throwErr errs = lift $ throwErr errs + {-# INLINE throwErr #-} addErrCtx ctx (StateT f) = StateT \s -> addErrCtx ctx $ f s {-# INLINE addErrCtx #-} diff --git a/src/lib/Imp.hs b/src/lib/Imp.hs index 56cfa95de..bd3c462a2 100644 --- a/src/lib/Imp.hs +++ b/src/lib/Imp.hs @@ -25,7 +25,6 @@ import Data.Maybe (fromJust, isJust) import Data.Text.Prettyprint.Doc import Control.Category import Control.Monad.Identity -import Control.Monad.Reader import Control.Monad.Writer.Strict import Control.Monad.State.Strict hiding (State) import qualified Control.Monad.State.Strict as MTL @@ -246,14 +245,14 @@ instance ImpBuilder ImpM where {-# INLINE extendAllocsToFree #-} instance ImpBuilder m => ImpBuilder (SubstReaderT AtomSubstVal m i) where - emitMultiReturnInstr instr = SubstReaderT $ lift $ emitMultiReturnInstr instr + emitMultiReturnInstr instr = liftSubstReaderT $ emitMultiReturnInstr instr {-# INLINE emitMultiReturnInstr #-} - emitDeclsImp ab = SubstReaderT $ lift $ emitDeclsImp ab + emitDeclsImp ab = liftSubstReaderT $ emitDeclsImp ab {-# INLINE emitDeclsImp #-} - buildScopedImp cont = SubstReaderT $ ReaderT \env -> + buildScopedImp cont = SubstReaderT \env -> buildScopedImp $ runSubstReaderT (sink env) $ cont {-# INLINE buildScopedImp #-} - extendAllocsToFree ptr = SubstReaderT $ lift $ extendAllocsToFree ptr + extendAllocsToFree ptr = liftSubstReaderT $ extendAllocsToFree ptr {-# INLINE extendAllocsToFree #-} instance ImpBuilder m => Imper (SubstReaderT AtomSubstVal m) diff --git a/src/lib/ImpToLLVM.hs b/src/lib/ImpToLLVM.hs index b747c2c2c..c5691cd2f 100644 --- a/src/lib/ImpToLLVM.hs +++ b/src/lib/ImpToLLVM.hs @@ -95,7 +95,7 @@ newtype CompileM i o a = , EnvReader, SubstReader OperandSubstVal ) instance MonadState CompileState (CompileM i o) where - state f = CompileM $ SubstReaderT $ lift $ EnvReaderT $ lift $ state f + state f = CompileM $ liftSubstReaderT $ EnvReaderT $ lift $ state f class MonadState CompileState m => LLVMBuilder (m::MonadKind) diff --git a/src/lib/Inference.hs b/src/lib/Inference.hs index d4384ddb9..373844202 100644 --- a/src/lib/Inference.hs +++ b/src/lib/Inference.hs @@ -4,33 +4,28 @@ -- license that can be found in the LICENSE file or at -- https://developers.google.com/open-source/licenses/bsd -{-# LANGUAGE AllowAmbiguousTypes #-} {-# LANGUAGE UndecidableInstances #-} {-# OPTIONS_GHC -Wno-orphans #-} module Inference - ( inferTopUDecl, checkTopUType, inferTopUExpr - , trySynthTerm, generalizeDict, asTopBlock - , synthTopE, UDeclInferenceResult (..), asFFIFunType) where + ( inferTopUDecl, checkTopUType, inferTopUExpr , generalizeDict, asTopBlock + , UDeclInferenceResult (..), asFFIFunType) where import Prelude hiding ((.), id) import Control.Category import Control.Applicative import Control.Monad import Control.Monad.State.Strict -import Control.Monad.Writer.Strict hiding (Alt) import Control.Monad.Reader import Data.Either (partitionEithers) import Data.Foldable (toList, asum) import Data.Functor ((<&>)) import Data.List (sortOn) import Data.Maybe (fromJust, fromMaybe, catMaybes) -import Data.Text.Prettyprint.Doc (Pretty (..), (<+>), vcat, group, line, nest) +import Data.Text.Prettyprint.Doc (Pretty (..)) import Data.Word import qualified Data.HashMap.Strict as HM import qualified Data.Map.Strict as M -import qualified Data.Set as S -import qualified Unsafe.Coerce as TrulyUnsafe import GHC.Generics (Generic (..)) import Builder @@ -49,17 +44,15 @@ import Types.Imp import Types.Primitives import Types.Source import Util hiding (group) -import PPrint (prettyBlock) -- === Top-level interface === checkTopUType :: (Fallible1 m, EnvReader m) => UType n -> m n (CType n) -checkTopUType ty = liftInfererM $ solveLocal $ withApplyDefaults $ checkUType ty +checkTopUType ty = liftInfererM $ checkUType ty {-# SCC checkTopUType #-} inferTopUExpr :: (Fallible1 m, EnvReader m) => UExpr n -> m n (TopBlock CoreIR n) -inferTopUExpr e = asTopBlock =<< liftInfererM do - solveLocal $ buildBlockInf $ withApplyDefaults $ inferSigma noHint e +inferTopUExpr e = asTopBlock =<< liftInfererM (buildScoped $ bottomUp e) {-# SCC inferTopUExpr #-} data UDeclInferenceResult e n = @@ -67,27 +60,22 @@ data UDeclInferenceResult e n = | UDeclResultBindName LetAnn (TopBlock CoreIR n) (Abs (UBinder (AtomNameC CoreIR)) e n) | UDeclResultBindPattern NameHint (TopBlock CoreIR n) (ReconAbs CoreIR e n) -inferTopUDecl :: (Mut n, Fallible1 m, TopBuilder m, SinkableE e, HoistableE e, RenameE e) +inferTopUDecl :: (Mut n, Fallible1 m, TopBuilder m, HasNamesE e) => UTopDecl n l -> e l -> m n (UDeclInferenceResult e n) inferTopUDecl (UStructDecl tc def) result = do tc' <- emitBinding (getNameHint tc) $ TyConBinding Nothing (DotMethods mempty) - def' <- liftInfererM $ solveLocal do - extendRenamer (tc@>sink tc') $ inferStructDef def - def'' <- synthTyConDef def' - updateTopEnv $ UpdateTyConDef tc' def'' - UStructDef _ (_, paramBs) _ methods <- return def + def' <- liftInfererM $ extendRenamer (tc@>tc') $ inferStructDef def + updateTopEnv $ UpdateTyConDef tc' def' + UStructDef _ paramBs _ methods <- return def forM_ methods \(letAnn, methodName, methodDef) -> do - method <- liftInfererM $ solveLocal $ - extendRenamer (tc@>sink tc') $ - inferDotMethod (sink tc') (Abs paramBs methodDef) - methodSynth <- synthTopE (Lam method) - method' <- emitTopLet (getNameHint methodName) letAnn (Atom methodSynth) + method <- liftInfererM $ extendRenamer (tc@>tc') $ + inferDotMethod tc' (Abs paramBs methodDef) + method' <- emitTopLet (getNameHint methodName) letAnn (Atom $ Lam method) updateTopEnv $ UpdateFieldDef tc' methodName (atomVarName method') UDeclResultDone <$> applyRename (tc @> tc') result inferTopUDecl (UDataDefDecl def tc dcs) result = do - tcDef <- liftInfererM $ solveLocal $ inferTyConDef def - tcDef'@(TyConDef _ _ _ (ADTCons dataCons)) <- synthTyConDef tcDef - tc' <- emitBinding (getNameHint tcDef') $ TyConBinding (Just tcDef') (DotMethods mempty) + tcDef@(TyConDef _ _ _ (ADTCons dataCons)) <- liftInfererM $ inferTyConDef def + tc' <- emitBinding (getNameHint tcDef) $ TyConBinding (Just tcDef) (DotMethods mempty) dcs' <- forM (enumerate dataCons) \(i, dcDef) -> emitBinding (getNameHint dcDef) $ DataConBinding tc' i let subst = tc @> tc' <.> dcs @@> dcs' @@ -95,33 +83,26 @@ inferTopUDecl (UDataDefDecl def tc dcs) result = do inferTopUDecl (UInterface paramBs methodTys className methodNames) result = do let sn = getSourceName className let methodSourceNames = nestToList getSourceName methodNames - classDef <- liftInfererM $ solveLocal $ inferClassDef sn methodSourceNames paramBs methodTys + classDef <- liftInfererM $ inferClassDef sn methodSourceNames paramBs methodTys className' <- emitBinding (getNameHint sn) $ ClassBinding classDef - methodNames' <- - forM (enumerate methodSourceNames) \(i, prettyName) -> do - emitBinding (getNameHint prettyName) $ MethodBinding className' i + methodNames' <- forM (enumerate methodSourceNames) \(i, prettyName) -> do + emitBinding (getNameHint prettyName) $ MethodBinding className' i let subst = className @> className' <.> methodNames @@> methodNames' UDeclResultDone <$> applyRename subst result -inferTopUDecl (UInstance className instanceBs params methods maybeName expl) result = do +inferTopUDecl (UInstance className bs params methods maybeName expl) result = do let (InternalName _ _ className') = className - ab <- liftInfererM $ solveLocal do - withRoleUBinders instanceBs do - ClassDef _ _ _ roleExpls paramBinders _ _ <- lookupClassDef (sink className') - let expls = snd <$> roleExpls - params' <- checkInstanceParams expls paramBinders params - className'' <- sinkM className' - body <- checkInstanceBody className'' params' methods - return (ListE params' `PairE` body) - Abs bs' (ListE params' `PairE` body) <- return ab - let (roleExpls, bs'') = unzipAttrs bs' - let def = InstanceDef className' roleExpls bs'' params' body + def <- liftInfererM $ withRoleUBinders bs \(ZipB roleExpls bs') -> do + ClassDef _ _ _ _ paramBinders _ _ <- lookupClassDef (sink className') + params' <- checkInstanceParams paramBinders params + body <- checkInstanceBody (sink className') params' methods + return $ InstanceDef className' roleExpls bs' params' body UDeclResultDone <$> case maybeName of RightB UnitB -> do - void $ synthInstanceDefAndAddSynthCandidate def + instanceName <- emitInstanceDef def + addInstanceSynthCandidate className' instanceName return result JustB instanceName' -> do - def' <- synthInstanceDef def - instanceName <- emitInstanceDef def' + instanceName <- emitInstanceDef def lam <- instanceFun instanceName expl instanceAtomName <- emitTopLet (getNameHint instanceName') PlainLet $ Atom lam applyRename (instanceName' @> atomVarName instanceAtomName) result @@ -131,21 +112,18 @@ inferTopUDecl (ULocalDecl (WithSrcB src decl)) result = addSrcContext src case d UExprDecl _ -> error "Shouldn't have this at the top level (should have become a command instead)" ULet letAnn p tyAnn rhs -> case p of WithSrcB _ (UPatBinder b) -> do - block <- liftInfererM $ solveLocal $ buildBlockInf do - checkMaybeAnnExpr (getNameHint b) tyAnn rhs <* applyDefaults + block <- liftInfererM $ buildScoped do + checkMaybeAnnExpr tyAnn rhs topBlock <- asTopBlock block return $ UDeclResultBindName letAnn topBlock (Abs b result) _ -> do - PairE block recon <- liftInfererM $ solveLocal $ buildBlockInfWithRecon do - val <- checkMaybeAnnExpr (getNameHint p) tyAnn rhs + PairE block recon <- liftInfererM $ buildBlockInfWithRecon do + val <- checkMaybeAnnExpr tyAnn rhs v <- emitHinted (getNameHint p) $ Atom val bindLetPat p v do - applyDefaults renameM result topBlock <- asTopBlock block return $ UDeclResultBindPattern (getNameHint p) topBlock recon -inferTopUDecl (UEffectDecl _ _ _) _ = error "not implemented" -inferTopUDecl (UHandlerDecl _ _ _ _ _ _ _) _ = error "not implemented" {-# SCC inferTopUDecl #-} asTopBlock :: EnvReader m => CBlock n -> m n (TopBlock CoreIR n) @@ -161,869 +139,404 @@ getInstanceType (InstanceDef className roleExpls bs params _) = liftEnvReaderM d let dTy = DictTy $ DictType classSourceName className' params' return $ CorePiType ImplicitApp (snd <$> roleExpls) bs' $ EffTy Pure dTy --- === Inferer interface === - -buildAbsInfWithExpl - :: (HasNamesE e, SubstE AtomSubstVal e) - => EmitsInf n - => NameHint -> Explicitness -> CType n - -> (forall l. (EmitsInf l, DExt n l) => CAtomVar l -> InfererM i l (e l)) - -> InfererM i n (Abs (WithExpl CBinder) e n) -buildAbsInfWithExpl hint expl ty cont = do - Abs b e <- buildAbsInf hint expl ty cont - return $ Abs (WithAttrB expl b) e - -buildNaryAbsInfWithExpl - :: (HasNamesE e, SubstE AtomSubstVal e) - => EmitsInf n - => [Explicitness] -> EmptyAbs (Nest CBinder) n - -> (forall l. (EmitsInf l, DExt n l) => [CAtomVar l] -> InfererM i l (e l)) - -> InfererM i n (Abs (Nest (WithExpl CBinder)) e n) -buildNaryAbsInfWithExpl expls bs cont = do - Abs bs' e <- buildNaryAbsInf expls bs cont - return $ Abs (zipAttrs expls bs') e - -buildNaryAbsInf - :: (HasNamesE e, SubstE AtomSubstVal e) - => EmitsInf n - => [Explicitness] -> EmptyAbs (Nest CBinder) n - -> (forall l. (EmitsInf l, DExt n l) => [CAtomVar l] -> InfererM i l (e l)) - -> InfererM i n (Abs (Nest CBinder) e n) -buildNaryAbsInf [] (Abs Empty UnitE) cont = getDistinct >>= \Distinct -> Abs Empty <$> cont [] -buildNaryAbsInf (expl:expls) (Abs (Nest (b:>ty) bs) UnitE) cont = - prependAbs <$> buildAbsInf (getNameHint b) expl ty \v -> do - bs' <- applyRename (b@>atomVarName v) (Abs bs UnitE) - buildNaryAbsInf expls bs' \vs -> cont (sink v:vs) -buildNaryAbsInf _ _ _ = error "zip error" - -buildDeclsInf - :: (HasNamesE e, SubstE AtomSubstVal e) - => EmitsInf n - => (forall l. (EmitsBoth l, DExt n l) => InfererM i l (e l)) - -> InfererM i n (Abs (Nest CDecl) e n) -buildDeclsInf cont = buildDeclsInfUnzonked $ cont >>= zonk - -applyDefaults :: EmitsInf o => InfererM i o () -applyDefaults = do - defaults <- getDefaults - applyDefault (intDefaults defaults) (BaseTy $ Scalar Int32Type) - applyDefault (natDefaults defaults) NatTy - where - applyDefault ds ty = - forM_ (nameSetToList ds) \v -> do - v' <- toAtomVar v - tryConstrainEq (Var v') (Type ty) - -withApplyDefaults :: EmitsInf o => InfererM i o a -> InfererM i o a -withApplyDefaults cont = cont <* applyDefaults -{-# INLINE withApplyDefaults #-} - --- === Concrete Inferer monad === - -data InfOutMap (n::S) = - InfOutMap - (Env n) - (SolverSubst n) - (Defaults n) - -- the subset of the names in the bindings whose definitions may contain - -- inference vars (this is so we can avoid zonking everything in scope when - -- we zonk bindings) - (UnsolvedEnv n) - -- allowed effects - (EffectRow CoreIR n) - -data DefaultType = IntDefault | NatDefault - -data Defaults (n::S) = Defaults - { intDefaults :: NameSet n -- Set of names that should be defaulted to Int32 - , natDefaults :: NameSet n } -- Set of names that should be defaulted to Nat32 - -instance Semigroup (Defaults n) where - Defaults d1 d2 <> Defaults d1' d2' = Defaults (d1 <> d1') (d2 <> d2') - -instance Monoid (Defaults n) where - mempty = Defaults mempty mempty - -instance SinkableE Defaults where - sinkingProofE _ _ = todoSinkableProof -instance HoistableE Defaults where - freeVarsE (Defaults d1 d2) = d1 <> d2 -instance RenameE Defaults where - renameE env (Defaults d1 d2) = Defaults (substDefaultSet d1) (substDefaultSet d2) - where - substDefaultSet d = freeVarsE $ renameE env $ ListE $ nameSetToList @(AtomNameC CoreIR) d +-- === Inferer monad === -instance Pretty (Defaults n) where - pretty (Defaults ints nats) = - attach "Names defaulting to Int32" (nameSetToList @(AtomNameC CoreIR) ints) - <+> attach "Names defaulting to Nat32" (nameSetToList @(AtomNameC CoreIR) nats) - where - attach _ [] = mempty - attach s l = s <+> pretty l +newtype SolverSubst n = SolverSubst { fromSolverSubst :: M.Map (CAtomName n) (CAtom n) } -zonkDefaults :: SolverSubst n -> Defaults n -> Defaults n -zonkDefaults s (Defaults d1 d2) = - Defaults (zonkDefaultSet d1) (zonkDefaultSet d2) - where - zonkDefaultSet d = flip foldMap (nameSetToList @(AtomNameC CoreIR) d) \v -> - case lookupSolverSubst s v of - Rename v' -> freeVarsE v' - SubstVal (Var v') -> freeVarsE v' - _ -> mempty - -data InfOutFrag (n::S) (l::S) = InfOutFrag (InfEmissions n l) (Defaults l) (Constraints l) - -instance Pretty (InfOutFrag n l) where - pretty (InfOutFrag emissions defaults solverSubst) = - vcat [ "Pending emissions:" <+> pretty (unRNest emissions) - , "Defaults:" <+> pretty defaults - , "Solver substitution:" <+> pretty solverSubst - ] - -type InfEmission = EitherE (DeclBinding CoreIR) SolverBinding -type InfEmissions = RNest (BinderP (AtomNameC CoreIR) InfEmission) - -instance GenericB InfOutFrag where - type RepB InfOutFrag = PairB InfEmissions (LiftB (PairE Defaults Constraints)) - fromB (InfOutFrag emissions defaults solverSubst) = - PairB emissions (LiftB (PairE defaults solverSubst)) - toB (PairB emissions (LiftB (PairE defaults solverSubst))) = - InfOutFrag emissions defaults solverSubst - -instance ProvesExt InfOutFrag -instance RenameB InfOutFrag -instance BindsNames InfOutFrag -instance SinkableB InfOutFrag -instance HoistableB InfOutFrag - -instance OutFrag InfOutFrag where - emptyOutFrag = InfOutFrag REmpty mempty mempty - catOutFrags (InfOutFrag em ds ss) (InfOutFrag em' ds' ss') = - withExtEvidence em' $ - InfOutFrag (em >>> em') (sink ds <> ds') (sink ss <> ss') - -instance HasScope InfOutMap where - toScope (InfOutMap bindings _ _ _ _) = toScope bindings - -instance OutMap InfOutMap where - emptyOutMap = InfOutMap emptyOutMap emptySolverSubst mempty mempty Pure - -instance ExtOutMap InfOutMap EnvFrag where - extendOutMap (InfOutMap bindings ss dd oldUn effs) frag = - withExtEvidence frag do - let newUn = UnsolvedEnv $ getAtomNames frag - let newEnv = bindings `extendOutMap` frag - -- As an optimization, only do the zonking for the new stuff. - let (zonkedUn, zonkedEnv) = zonkUnsolvedEnv (sink ss) newUn newEnv - InfOutMap zonkedEnv (sink ss) (sink dd) (sink oldUn <> zonkedUn) (sink effs) - -newtype UnsolvedEnv (n::S) = - UnsolvedEnv { fromUnsolvedEnv :: S.Set (CAtomName n) } - deriving (Semigroup, Monoid) - -instance SinkableE UnsolvedEnv where - sinkingProofE = todoSinkableProof +emptySolverSubst :: SolverSubst n +emptySolverSubst = SolverSubst mempty -getAtomNames :: Distinct l => EnvFrag n l -> S.Set (CAtomName l) -getAtomNames frag = S.fromList $ nameSetToList $ toNameSet $ toScopeFrag frag - --- TODO: zonk the allowed effects and synth candidates in the bindings too --- TODO: the reason we need this is that `getType` uses the bindings to obtain --- type information, and we need this information when we emit decls. For --- example, if we emit `f x` and we don't know that `f` has a type of the form --- `a -> b` then `getType` will crash. But we control the inference-specific --- implementation of `emitDecl`, so maybe we could instead do something like --- emit a fresh inference variable in the case thea `getType` fails. --- XXX: It might be tempting to add a check for empty solver substs here, --- but please don't do that! We use this function to filter overestimates of --- UnsolvedEnv, and for performance reasons we should do that even when the --- SolverSubst is empty. -zonkUnsolvedEnv :: Distinct n => SolverSubst n -> UnsolvedEnv n -> Env n - -> (UnsolvedEnv n, Env n) -zonkUnsolvedEnv ss unsolved env = - flip runState env $ execWriterT do - forM_ (S.toList $ fromUnsolvedEnv unsolved) \v -> do - flip lookupEnvPure v . topEnv <$> get >>= \case - AtomNameBinding rhs -> do - let rhs' = zonkAtomBindingWithOutMap (InfOutMap env ss mempty mempty Pure) rhs - modify \e -> e {topEnv = updateEnv v (AtomNameBinding rhs') (topEnv e)} - let rhsHasInfVars = runEnvReaderM env $ hasInferenceVars rhs' - when rhsHasInfVars $ tell $ UnsolvedEnv $ S.singleton v - --- TODO: we need this shim because top level emissions can't implement `SubstE --- AtomSubstVal` so GHC doesn't know how to zonk them. If we split up top-level --- emissions from local ones in the name color system then we won't have this --- problem. -zonkAtomBindingWithOutMap - :: Distinct n => InfOutMap n -> AtomBinding CoreIR n -> AtomBinding CoreIR n -zonkAtomBindingWithOutMap outMap = \case - LetBound e -> LetBound $ zonkWithOutMap outMap e - MiscBound e -> MiscBound $ zonkWithOutMap outMap e - SolverBound e -> SolverBound $ zonkWithOutMap outMap e - NoinlineFun t e -> NoinlineFun (zonkWithOutMap outMap t) (zonkWithOutMap outMap e) - FFIFunBound x y -> FFIFunBound (zonkWithOutMap outMap x) (zonkWithOutMap outMap y) - --- TODO: Wouldn't it be faster to carry the set of inference-emitted names in the out map? -hasInferenceVars :: (EnvReader m, HoistableE e) => e n -> m n Bool -hasInferenceVars e = liftEnvReaderM $ anyInferenceVars $ freeAtomVarsList e -{-# INLINE hasInferenceVars #-} +data InfState (n::S) = InfState + { givens :: Givens n + , infEffects :: EffectRow CoreIR n } -anyInferenceVars :: [CAtomName n] -> EnvReaderM n Bool -anyInferenceVars = \case - [] -> return False - (v:vs) -> isInferenceVar v >>= \case - True -> return True - False -> anyInferenceVars vs +newtype InfererM (i::S) (o::S) (a:: *) = InfererM + { runInfererM' :: SubstReaderT Name (ReaderT1 InfState (BuilderT CoreIR FallibleM)) i o a } + deriving (Functor, Applicative, Monad, MonadFail, Alternative, Builder CoreIR, + EnvExtender, ScopableBuilder CoreIR, + ScopeReader, EnvReader, Fallible, Catchable, CtxReader, SubstReader Name) + +type InfererCPSB b i o a = (forall o'. DExt o o' => b o o' -> InfererM i o' a) -> InfererM i o a +type InfererCPSB2 b i i' o a = (forall o'. DExt o o' => b o o' -> InfererM i' o' a) -> InfererM i o a + +liftInfererM :: (EnvReader m, Fallible1 m) => InfererM n n a -> m n a +liftInfererM cont = do + ansM <- liftBuilderT $ runReaderT1 emptyInfState $ runSubstReaderT idSubst $ runInfererM' cont + liftExcept $ runFallibleM ansM + where + emptyInfState :: InfState n + emptyInfState = InfState (Givens HM.empty) Pure +{-# INLINE liftInfererM #-} -isInferenceVar :: EnvReader m => CAtomName n -> m n Bool -isInferenceVar v = lookupEnv v >>= \case - AtomNameBinding (SolverBound _) -> return True - _ -> return False +-- === Solver monad === -instance ExtOutMap InfOutMap InfOutFrag where - extendOutMap m (InfOutFrag em ds' cs) = do - let InfOutMap env ss ds us effs = m `extendOutMap` toEnvFrag em - let ds'' = sink ds <> ds' - let (env', us', ss') = extendOutMapWithConstraints env us ss cs - InfOutMap env' ss' ds'' us' effs - -extendOutMapWithConstraints - :: Distinct n => Env n -> UnsolvedEnv n -> SolverSubst n -> Constraints n - -> (Env n, UnsolvedEnv n, SolverSubst n) -extendOutMapWithConstraints env us ss (Constraints allCs) = case tryUnsnoc allCs of - Nothing -> (env, us, ss) - Just (cs, (v, x)) -> do - let (env', us', SolverSubst ss') = extendOutMapWithConstraints env us ss (Constraints cs) - let s = M.singleton v x - let (us'', env'') = zonkUnsolvedEnv (SolverSubst s) us' env' - let ss'' = fmap (applySolverSubstE env'' (SolverSubst s)) ss' - let ss''' = SolverSubst $ ss'' <> s - (env'', us'', ss''') +type Solution = PairE CAtomName CAtom +newtype SolverDiff (n::S) = SolverDiff (RListE Solution n) + deriving (MonoidE, SinkableE, HoistableE, RenameE) +type SolverM i o a = DiffStateT1 SolverSubst SolverDiff (InfererM i) o a -newtype InfererM (i::S) (o::S) (a:: *) = InfererM - { runInfererM' :: SubstReaderT Name (InplaceT InfOutMap InfOutFrag SearcherM) i o a } - deriving (Functor, Applicative, Monad, MonadFail, Alternative, Searcher, - ScopeReader, Fallible, Catchable, CtxReader, SubstReader Name) +type Zonkable e = (HasNamesE e, SubstE AtomSubstVal e) -liftInfererMSubst :: (Fallible2 m, SubstReader Name m, EnvReader2 m) - => InfererM i o a -> m i o a -liftInfererMSubst cont = do - env <- unsafeGetEnv - subst <- getSubst - Distinct <- getDistinct - (InfOutFrag REmpty _ _, result) <- - liftExcept $ liftM fromJust $ runSearcherM $ runInplaceT (initInfOutMap env) $ - runSubstReaderT subst $ runInfererM' $ cont - return result - -liftInfererM :: (EnvReader m, Fallible1 m) - => InfererM n n a -> m n a -liftInfererM cont = runSubstReaderT idSubst $ liftInfererMSubst $ cont -{-# INLINE liftInfererM #-} +liftSolverM :: SolverM i o a -> InfererM i o a +liftSolverM cont = fst <$> runDiffStateT1 emptySolverSubst cont -runLocalInfererM - :: SinkableE e - => (forall l. (EmitsInf l, DExt n l) => InfererM i l (e l)) - -> InfererM i n (Abs InfOutFrag e n) -runLocalInfererM cont = InfererM $ SubstReaderT $ ReaderT \env -> do - locallyMutableInplaceT (do - Distinct <- getDistinct - EmitsInf <- fabricateEmitsInfEvidenceM - runSubstReaderT (sink env) $ runInfererM' cont) - (\d e -> return $ Abs d e) -{-# INLINE runLocalInfererM #-} - -initInfOutMap :: Env n -> InfOutMap n -initInfOutMap bindings = - InfOutMap bindings emptySolverSubst mempty (UnsolvedEnv mempty) Pure - -newtype InfDeclEmission (n::S) (l::S) = InfDeclEmission (BinderP (AtomNameC CoreIR) InfEmission n l) -instance ExtOutMap InfOutMap InfDeclEmission where - extendOutMap env (InfDeclEmission d) = env `extendOutMap` toEnvFrag d - {-# INLINE extendOutMap #-} -instance ExtOutFrag InfOutFrag InfDeclEmission where - extendOutFrag (InfOutFrag ems ds ss) (InfDeclEmission em) = - withSubscopeDistinct em $ InfOutFrag (RNest ems em) (sink ds) (sink ss) - {-# INLINE extendOutFrag #-} - -emitInfererM :: Mut o => NameHint -> InfEmission o -> InfererM i o (CAtomVar o) -emitInfererM hint emission = do - v <- InfererM $ SubstReaderT $ lift $ freshExtendSubInplaceT hint \b -> - (InfDeclEmission (b :> emission), binderName b) - return $ AtomVar v $ getType emission -{-# INLINE emitInfererM #-} - -extendSolverSubst :: CAtomName n -> CAtom n -> InfererM i n () -extendSolverSubst v ty = do - InfererM $ SubstReaderT $ lift $ - void $ extendTrivialInplaceT $ - InfOutFrag REmpty mempty (singleConstraint v ty) -{-# INLINE extendSolverSubst #-} - -zonk :: (SubstE AtomSubstVal e, SinkableE e) => e n -> InfererM i n (e n) -zonk e = InfererM $ SubstReaderT $ lift do - Distinct <- getDistinct - solverOutMap <- getOutMapInplaceT - return $ zonkWithOutMap solverOutMap e +zonk :: Zonkable e => e n -> SolverM i n (e n) +zonk e = do + s <- getDiffState + applySolverSubst s e {-# INLINE zonk #-} -emitSolver :: EmitsInf n => SolverBinding n -> InfererM i n (CAtomVar n) -emitSolver binding = emitInfererM (getNameHint @String "?") $ RightE binding -{-# INLINE emitSolver #-} - -solveLocal :: HasNamesE e - => (forall l. (EmitsInf l, Ext n l, Distinct l) => InfererM i l (e l)) - -> InfererM i n (e n) -solveLocal cont = do - Abs (InfOutFrag unsolvedInfVars _ _) result <- dceInfFrag =<< runLocalInfererM cont - case unRNest unsolvedInfVars of - Empty -> return result - Nest (b:>RightE (InfVarBound ty (ctx, desc))) _ -> addSrcContext ctx $ - throw TypeErr $ formatAmbiguousVarErr (binderName b) ty desc - _ -> error "shouldn't be possible" +applySolverSubst :: (EnvReader m, Zonkable e) => SolverSubst n -> e n -> m n (e n) +applySolverSubst subst e = do + Distinct <- getDistinct + env <- unsafeGetEnv + return $ fmapNames env (lookupSolverSubst subst) e +{-# INLINE applySolverSubst #-} formatAmbiguousVarErr :: CAtomName n -> CType n' -> InfVarDesc -> String formatAmbiguousVarErr infVar ty = \case AnnotationInfVar v -> "Couldn't infer type of unannotated binder " <> v ImplicitArgInfVar (f, argName) -> - "Couldn't infer implicit argument " <> argName <> " of " <> f + "Couldn't infer implicit argument `" <> argName <> "` of " <> f TypeInstantiationInfVar t -> "Couldn't infer instantiation of type " <> t MiscInfVar -> "Ambiguous type variable: " ++ pprint infVar ++ ": " ++ pprint ty --- XXX: we should almost always used the zonking `buildDeclsInf` , --- except where it's not possible because the result isn't atom-substitutable, --- such as the source map at the top level. -buildDeclsInfUnzonked - :: (SinkableE e, HoistableE e, RenameE e) - => EmitsInf n - => (forall l. (EmitsBoth l, DExt n l) => InfererM i l (e l)) - -> InfererM i n (Abs (Nest CDecl) e n) -buildDeclsInfUnzonked cont = do - InfererM $ SubstReaderT $ ReaderT \env -> do - Abs frag result <- locallyMutableInplaceT (do - Emits <- fabricateEmitsEvidenceM - EmitsInf <- fabricateEmitsInfEvidenceM - runSubstReaderT (sink env) $ runInfererM' cont) - (\d e -> return $ Abs d e) - extendInplaceT =<< hoistThroughDecls frag result - -buildAbsInf - :: (SinkableE e, HoistableE e, RenameE e, SubstE AtomSubstVal e) - => EmitsInf n - => NameHint -> Explicitness -> CType n - -> (forall l. (EmitsInf l, DExt n l) => CAtomVar l -> InfererM i l (e l)) - -> InfererM i n (Abs CBinder e n) -buildAbsInf hint expl ty cont = do - ab <- InfererM $ SubstReaderT $ ReaderT \env -> do - extendInplaceT =<< withFreshBinder hint ty \bWithTy@(b:>_) -> do - ab <- locallyMutableInplaceT (do - v <- sinkM $ binderVar bWithTy - extendInplaceTLocal (extendSynthCandidatesInf expl $ atomVarName v) do - EmitsInf <- fabricateEmitsInfEvidenceM - -- zonking is needed so that dceInfFrag works properly - runSubstReaderT (sink env) (runInfererM' $ cont v >>= zonk)) - (\d e -> return $ Abs d e) - ab' <- dceInfFrag ab - refreshAbs ab' \infFrag result -> do - case exchangeBs $ PairB b infFrag of - HoistSuccess (PairB infFrag' b') -> do - return $ withSubscopeDistinct b' $ - Abs infFrag' $ Abs b' result - HoistFailure vs -> do - throw EscapedNameErr $ (pprint vs) - ++ "\nFailed to exchange binders in buildAbsInf" - ++ "\n" ++ pprint infFrag - Abs b e <- return ab - ty' <- zonk ty - return $ Abs (b:>ty') e - -dceInfFrag - :: (EnvReader m, EnvExtender m, Fallible1 m, RenameE e, HoistableE e) - => Abs InfOutFrag e n -> m n (Abs InfOutFrag e n) -dceInfFrag ab@(Abs frag@(InfOutFrag bs _ _) e) = - case bs of - REmpty -> return ab - _ -> hoistThroughDecls frag e >>= \case - Abs frag' (Abs Empty e') -> return $ Abs frag' e' - _ -> error "Shouldn't have any decls without `Emits` constraint" - -addDefault :: CAtomName o -> DefaultType ->InfererM i o () -addDefault v defaultType = - InfererM $ SubstReaderT $ lift $ - extendTrivialInplaceT $ InfOutFrag REmpty defaults mempty - where - defaults = case defaultType of - IntDefault -> Defaults (freeVarsE v) mempty - NatDefault -> Defaults mempty (freeVarsE v) - -getDefaults :: InfererM i o (Defaults o) -getDefaults = InfererM $ SubstReaderT $ lift do - InfOutMap _ _ defaults _ _ <- getOutMapInplaceT - return defaults - -instance Builder CoreIR (InfererM i) where - rawEmitDecl hint ann expr = do - -- This zonking, and the zonking of the bindings elsewhere, is only to - -- prevent `getType` from failing. But maybe we should just catch the - -- failure if it occurs and generate a fresh inference name for the type in - -- that case? - expr' <- zonk expr - emitInfererM hint $ LeftE $ DeclBinding ann expr' - {-# INLINE rawEmitDecl #-} - -getAllowedEffects :: InfererM i n (EffectRow CoreIR n) -getAllowedEffects = do - InfOutMap _ _ _ _ effs <- InfererM $ SubstReaderT $ lift $ getOutMapInplaceT - return effs - -withoutEffects :: InfererM i o a -> InfererM i o a -withoutEffects cont = withAllowedEffects Pure cont +withFreshBinderInf :: NameHint -> Explicitness -> CType o -> InfererCPSB CBinder i o a +withFreshBinderInf hint expl ty cont = + withFreshBinder hint ty \b -> do + givens <- case expl of + Inferred _ (Synth _) -> return [Var $ binderVar b] + _ -> return [] + extendGivens givens $ cont b +{-# INLINE withFreshBinderInf #-} + +withFreshBindersInf + :: (SinkableE e, RenameE e) + => [Explicitness] -> Abs (Nest CBinder) e o + -> (forall o'. DExt o o' => Nest CBinder o o' -> e o' -> InfererM i o' a) + -> InfererM i o a +withFreshBindersInf explsTop (Abs bsTop e) contTop = + runSubstReaderT idSubst $ go explsTop bsTop \bs' -> do + e' <- renameM e + liftSubstReaderT $ contTop bs' e' + where + go :: [Explicitness] -> Nest CBinder ii ii' + -> (forall o'. DExt o o' => Nest CBinder o o' -> SubstReaderT Name (InfererM i) ii' o' a) + -> SubstReaderT Name (InfererM i) ii o a + go [] Empty cont = withDistinct $ cont Empty + go (expl:expls) (Nest b bs) cont = do + ty <- renameM $ binderType b + SubstReaderT \s -> withFreshBinderInf (getNameHint b) expl ty \b' -> do + runSubstReaderT (sink s) $ extendSubst (b@>binderName b') do + go expls bs \bs' -> cont (Nest b' bs') + go _ _ _ = error "zip error" +{-# INLINE withFreshBindersInf #-} + +withInferenceVar + :: (Zonkable e, Emits o, ToBinding binding (AtomNameC CoreIR)) => NameHint -> binding o + -> (forall o'. (Emits o', DExt o o') => CAtomName o' -> SolverM i o' (e o', CAtom o')) + -> SolverM i o (e o) +withInferenceVar hint binding cont = diffStateT1 \s -> do + declsAndAns <- withFreshBinder hint binding \(b:>_) -> do + hardHoist b <$> buildScoped do + v <- sinkM $ binderName b + s' <- sinkM s + (PairE ans soln, diff) <- runDiffStateT1 s' do + toPairE <$> cont v + let subst = SolverSubst $ M.singleton v soln + ans' <- applySolverSubst subst ans + diff' <- applySolutionToDiff subst v diff + return $ PairE ans' diff' + fromPairE <$> emitDecls declsAndAns + where + applySolutionToDiff :: SolverSubst n -> CAtomName n -> SolverDiff n -> InfererM i n (SolverDiff n) + applySolutionToDiff subst vSoln (SolverDiff (RListE (ReversedList cs))) = do + SolverDiff . RListE . ReversedList <$> forMFilter cs \(PairE v x) -> + case v == vSoln of + True -> return Nothing + False -> Just . PairE v <$> applySolverSubst subst x +{-# INLINE withInferenceVar #-} + +withFreshUnificationVar + :: (Zonkable e, Emits o) => InfVarDesc -> Kind CoreIR o + -> (forall o'. (Emits o', DExt o o') => CAtomVar o' -> SolverM i o' (e o')) + -> SolverM i o (e o) +withFreshUnificationVar desc k cont = do + -- TODO: we shouldn't need the context stuff on `InfVarBound` anymore + ctx <- srcPosCtx <$> getErrCtx + withInferenceVar "_unif_" (InfVarBound k (ctx, desc)) \v -> do + ans <- toAtomVar v >>= cont + soln <- (M.lookup v <$> fromSolverSubst <$> getDiffState) >>= \case + Just soln -> return soln + Nothing -> throw TypeErr $ formatAmbiguousVarErr v k desc + return (ans, soln) +{-# INLINE withFreshUnificationVar #-} + +withFreshUnificationVarNoEmits + :: (Zonkable e) => InfVarDesc -> Kind CoreIR o + -> (forall o'. (DExt o o') => CAtomVar o' -> SolverM i o' (e o')) + -> SolverM i o (e o) +withFreshUnificationVarNoEmits desc k cont = diffStateT1 \s -> do + Abs Empty resultAndDiff <- buildScoped do + liftM toPairE $ runDiffStateT1 (sink s) $ + withFreshUnificationVar desc (sink k) cont + return $ fromPairE resultAndDiff + +withFreshDictVar + :: (Zonkable e, Emits o) => CType o + -- This tells us how to synthesize the dict. The supplied CType won't contain inference vars. + -> (forall o'. ( DExt o o') => CType o' -> SolverM i o' (CAtom o')) + -> (forall o'. (Emits o', DExt o o') => CAtom o' -> SolverM i o' (e o')) + -> SolverM i o (e o) +withFreshDictVar dictTy synthIt cont = hasInferenceVars dictTy >>= \case + False -> withDistinct $ synthIt dictTy >>= cont + True -> withInferenceVar "_dict_" (DictBound dictTy) \v -> do + ans <- cont =<< (Var <$> toAtomVar v) + dictTy' <- zonk $ sink dictTy + dict <- synthIt dictTy' + return (ans, dict) +{-# INLINE withFreshDictVar #-} + +withFreshDictVarNoEmits + :: (Zonkable e) => CType o + -- This tells us how to synthesize the dict. The supplied CType won't contain inference vars. + -> (forall o'. (DExt o o') => CType o' -> SolverM i o' (CAtom o')) + -> (forall o'. (DExt o o') => CAtom o' -> SolverM i o' (e o')) + -> SolverM i o (e o) +withFreshDictVarNoEmits dictTy synthIt cont = diffStateT1 \s -> do + Abs Empty resultAndDiff <- buildScoped do + liftM toPairE $ runDiffStateT1 (sink s) $ + withFreshDictVar (sink dictTy) synthIt cont + return $ fromPairE resultAndDiff +{-# INLINE withFreshDictVarNoEmits #-} + +withDict + :: (Zonkable e, Emits o) => Kind CoreIR o + -> (forall o'. (Emits o', DExt o o') => CAtom o' -> SolverM i o' (e o')) + -> SolverM i o (e o) +withDict dictTy cont = withFreshDictVar dictTy + (\dictTy' -> lift11 $ trySynthTerm dictTy' Full) + cont +{-# INLINE withDict#-} + +addConstraint :: CAtomName o -> CAtom o -> SolverM i o () +addConstraint v ty = updateDiffStateM (SolverDiff $ RListE $ toSnocList [PairE v ty]) +{-# INLINE addConstraint #-} + +getInfState :: InfererM i o (InfState o) +getInfState = InfererM $ liftSubstReaderT ask +{-# INLINE getInfState #-} + +withInfState :: (InfState o -> InfState o) -> InfererM i o a -> InfererM i o a +withInfState f cont = InfererM $ local f (runInfererM' cont) +{-# INLINE withInfState #-} withAllowedEffects :: EffectRow CoreIR o -> InfererM i o a -> InfererM i o a -withAllowedEffects effs cont = do - InfererM $ SubstReaderT $ ReaderT \env -> do - extendInplaceTLocal (\(InfOutMap x y z w _) -> InfOutMap x y z w effs) do - runSubstReaderT env $ runInfererM' do - cont - -type InferenceNameBinders = Nest (BinderP (AtomNameC CoreIR) SolverBinding) - --- When we finish building a block of decls we need to hoist the local solver --- information into the outer scope. If the local solver state mentions local --- variables which are about to go out of scope then we emit a "escaped scope" --- error. To avoid false positives, we clean up as much dead (i.e. solved) --- solver state as possible. -hoistThroughDecls - :: ( RenameE e, HoistableE e, Fallible1 m, ScopeReader m, EnvExtender m) - => InfOutFrag n l - -> e l - -> m n (Abs InfOutFrag (Abs (Nest CDecl) e) n) -hoistThroughDecls outFrag result = do - env <- unsafeGetEnv - refreshAbs (Abs outFrag result) \outFrag' result' -> do - liftExcept $ hoistThroughDecls' env outFrag' result' -{-# INLINE hoistThroughDecls #-} - -hoistThroughDecls' - :: (HoistableE e, Distinct l) - => Env n - -> InfOutFrag n l - -> e l - -> Except (Abs InfOutFrag (Abs (Nest CDecl) e) n) -hoistThroughDecls' env (InfOutFrag emissions defaults constraints) result = do - withSubscopeDistinct emissions do - let subst = constraintsToSubst (env `extendOutMap` toEnvFrag emissions) constraints - HoistedSolverState infVars defaults' subst' decls result' <- - hoistInfStateRec env emissions emptyInferenceNameBindersFV - (zonkDefaults subst defaults) (UnhoistedSolverSubst emptyOutFrag subst) Empty result - let constraints' = substToConstraints subst' - let hoistedInfFrag = InfOutFrag (infNamesToEmissions infVars) defaults' constraints' - return $ Abs hoistedInfFrag $ Abs decls result' - -constraintsToSubst :: Distinct n => Env n -> Constraints n -> SolverSubst n -constraintsToSubst env (Constraints csTop) = case tryUnsnoc csTop of - Nothing -> emptySolverSubst - Just (cs, (v, x)) -> do - let SolverSubst m = constraintsToSubst env (Constraints cs) - let s = M.singleton v x - SolverSubst $ fmap (applySolverSubstE env (SolverSubst s)) m <> s - -substToConstraints :: SolverSubst n -> Constraints n -substToConstraints (SolverSubst m) = Constraints $ toSnocList $ M.toList m - -data HoistedSolverState e n where - HoistedSolverState - :: InferenceNameBinders n l1 - -> Defaults l1 - -> SolverSubst l1 - -> Nest CDecl l1 l2 - -> e l2 - -> HoistedSolverState e n - --- XXX: Be careful how you construct DelayedSolveNests! When the substitution is --- applied, the pieces are concatenated through regular map concatenation, not --- through recursive substitutions as in catSolverSubsts! This is safe to do when --- the individual SolverSubsts come from a projection of a larger SolverSubst, --- which is how we use them in `hoistInfStateRec`. -type DelayedSolveNest (b::B) (n::S) (l::S) = Nest (EitherB b (LiftB SolverSubst)) n l - -resolveDelayedSolve :: Distinct l => Env n -> SolverSubst n -> DelayedSolveNest CDecl n l -> Nest CDecl n l -resolveDelayedSolve env subst = \case - Empty -> Empty - Nest (RightB (LiftB sfrag)) rest -> resolveDelayedSolve env (subst `unsafeCatSolverSubst` sfrag) rest - Nest (LeftB (Let b rhs) ) rest -> - withSubscopeDistinct rest $ withSubscopeDistinct b $ - Nest (Let b (applySolverSubstE env subst rhs)) $ - resolveDelayedSolve (env `extendOutMap` toEnvFrag (b:>rhs)) (sink subst) rest - where - unsafeCatSolverSubst :: SolverSubst n -> SolverSubst n -> SolverSubst n - unsafeCatSolverSubst (SolverSubst a) (SolverSubst b) = SolverSubst $ a <> b - -data InferenceNameBindersFV (n::S) (l::S) = InferenceNameBindersFV (NameSet n) (InferenceNameBinders n l) -instance BindsNames InferenceNameBindersFV where - toScopeFrag = toScopeFrag . dropInferenceNameBindersFV -instance BindsEnv InferenceNameBindersFV where - toEnvFrag = toEnvFrag . dropInferenceNameBindersFV -instance ProvesExt InferenceNameBindersFV where - toExtEvidence = toExtEvidence . dropInferenceNameBindersFV -instance HoistableB InferenceNameBindersFV where - freeVarsB (InferenceNameBindersFV fvs _) = fvs - -emptyInferenceNameBindersFV :: InferenceNameBindersFV n n -emptyInferenceNameBindersFV = InferenceNameBindersFV mempty Empty - -dropInferenceNameBindersFV :: InferenceNameBindersFV n l -> InferenceNameBinders n l -dropInferenceNameBindersFV (InferenceNameBindersFV _ bs) = bs - -prependNameBinder - :: BinderP (AtomNameC CoreIR) SolverBinding n q - -> InferenceNameBindersFV q l -> InferenceNameBindersFV n l -prependNameBinder b (InferenceNameBindersFV fvs bs) = - InferenceNameBindersFV (freeVarsB b <> hoistFilterNameSet b fvs) (Nest b bs) - --- XXX: Stashing Distinct here is a little naughty, since that's generally not allowed. --- Here it should be ok, because it's only used in hoistInfStateRec, which doesn't emit. -data UnhoistedSolverSubst (n::S) where - UnhoistedSolverSubst :: Distinct l => ScopeFrag n l -> SolverSubst l -> UnhoistedSolverSubst n - -delayedHoistSolverSubst :: BindsNames b => b n l -> UnhoistedSolverSubst l -> UnhoistedSolverSubst n -delayedHoistSolverSubst b (UnhoistedSolverSubst frag s) = UnhoistedSolverSubst (toScopeFrag b >>> frag) s - -hoistSolverSubst :: UnhoistedSolverSubst n -> HoistExcept (SolverSubst n) -hoistSolverSubst (UnhoistedSolverSubst frag s) = hoist frag s - --- TODO: Instead of delaying the solve, compute the most-nested scope once --- and then use it for all _eager_ substitutions while hoisting! Using a super-scope --- for substitution shouldn't be a problem! -hoistInfStateRec - :: forall n l l1 l2 e. (Distinct n, Distinct l2, HoistableE e) - => Env n -> InfEmissions n l - -> InferenceNameBindersFV l l1 -> Defaults l1 -> UnhoistedSolverSubst l1 - -> DelayedSolveNest CDecl l1 l2 -> e l2 - -> Except (HoistedSolverState e n) -hoistInfStateRec env emissions !infVars defaults !subst decls e = case emissions of - REmpty -> do - subst' <- liftHoistExcept' "Failed to hoist solver substitution in hoistInfStateRec" - $ hoistSolverSubst subst - let decls' = withSubscopeDistinct decls $ - resolveDelayedSolve (env `extendOutMap` toEnvFrag infVars) subst' decls - return $ HoistedSolverState (dropInferenceNameBindersFV infVars) defaults subst' decls' e - RNest rest (b :> infEmission) -> do - withSubscopeDistinct decls do - case infEmission of - RightE binding@(InfVarBound _ _) -> do - UnhoistedSolverSubst frag (SolverSubst substMap) <- return subst - let vHoist :: CAtomName l1 = withSubscopeDistinct infVars $ sink $ binderName b -- binder name at l1 - let vUnhoist = withExtEvidence frag $ sink vHoist -- binder name below frag - case M.lookup vUnhoist substMap of - -- Unsolved inference variables are just gathered as they are. - Nothing -> - hoistInfStateRec env rest (prependNameBinder (b:>binding) infVars) - defaults subst decls e - -- If a variable is solved, we eliminate it. - Just bSolutionUnhoisted -> do - bSolution <- - liftHoistExcept' "Failed to eliminate solved variable in hoistInfStateRec " - $ hoist frag bSolutionUnhoisted - case exchangeBs $ PairB b infVars of - -- This used to be accepted by the code at some point (and handled the same way - -- as the Nothing) branch above, but I don't understand why. We don't even seem - -- to be exercising it anyway, so throw a not implemented error for now. - HoistFailure _ -> throw NotImplementedErr "Unzonked unsolved variables" - HoistSuccess (PairB infVars' b') -> do - let defaults' = hoistDefaults b' defaults - let bZonkedDecls = Nest (RightB (LiftB $ SolverSubst $ M.singleton vHoist bSolution)) decls -#ifdef DEX_DEBUG - -- Hoist the subst eagerly, unlike the unsafe implementation. - hoistedSubst@(SolverSubst hoistMap) <- liftHoistExcept $ hoistSolverSubst subst - let subst' = withSubscopeDistinct b' $ UnhoistedSolverSubst (toScopeFrag b') $ - SolverSubst $ M.delete vHoist hoistMap - -- Zonk the decls with `v @> bSolution` to make sure hoisting will succeed. - -- This is quadratic, which is why we don't do this in the fast implementation! - let allEmissions = RNest rest (b :> infEmission) - let declsScope = withSubscopeDistinct infVars $ - (env `extendOutMap` toEnvFrag allEmissions) `extendOutMap` toEnvFrag infVars - let resolvedDecls = resolveDelayedSolve declsScope hoistedSubst bZonkedDecls - PairB resolvedDecls' b'' <- liftHoistExcept $ exchangeBs $ PairB b' resolvedDecls - let decls' = fmapNest LeftB resolvedDecls' - -- NB: We assume that e is hoistable above e! This has to be taken - -- care of by zonking the result before this function is entered. - e' <- liftHoistExcept $ hoist b'' e - withSubscopeDistinct b'' $ - hoistInfStateRec env rest infVars' defaults' subst' decls' e' -#else - -- SolverSubst should be recursively zonked, so any v that's a member - -- should never appear in an rhs. Hence, deleting the entry corresponding to - -- v should hoist the substitution above b'. - let subst' = unsafeCoerceE $ UnhoistedSolverSubst frag $ SolverSubst $ M.delete vUnhoist substMap - -- Applying the substitution `v @> bSolution` would eliminate `b` from decls, so this - -- is equivalent to hoisting above b'. This is of course not reflected in the type - -- system, which is why we use unsafe coercions. - let decls' = unsafeCoerceB bZonkedDecls - -- This is much more sketchy, but it reflects the e-hoistability assumption - -- that our safe implementation makes as well. Except here it's obviously unchecked. - let e' :: e UnsafeS = unsafeCoerceE e - Distinct <- return $ fabricateDistinctEvidence @UnsafeS - hoistInfStateRec env rest infVars' defaults' subst' decls' e' -#endif - RightE (SkolemBound _) -> do -#ifdef DEX_DEBUG - PairB infVars' b' <- liftHoistExcept' "Skolem leak?" $ exchangeBs $ PairB b infVars - defaults' <- liftHoistExcept' "Skolem leak?" $ hoist b' defaults - let subst' = delayedHoistSolverSubst b' subst - PairB decls' b'' <- liftHoistExcept' "Skolem leak?" $ exchangeBs $ PairB b' decls - e' <- liftHoistExcept' "Skolem leak?" $ hoist b'' e - withSubscopeDistinct b'' $ hoistInfStateRec env rest infVars' defaults' subst' decls' e' -#else - -- Skolem vars are only instantiated in unification, and we're very careful to - -- never let them leak into the types of inference vars emitted while unifying - -- and into the solver subst. - Distinct <- return $ fabricateDistinctEvidence @UnsafeS - hoistInfStateRec @n @UnsafeS @UnsafeS @UnsafeS - env - (unsafeCoerceB rest) (unsafeCoerceB infVars) - (unsafeCoerceE defaults) (unsafeCoerceE subst) - (unsafeCoerceB decls) (unsafeCoerceE e) -#endif - LeftE emission -> do - -- Move the binder below all unsolved inference vars. Failure to do so is - -- an inference error --- a variable cannot be solved once we exit the env - -- of all variables it mentions in its type. - -- TODO: Shouldn't this be an ambiguous type error? - PairB infVars' (b':>emission') <- - liftHoistExcept' "Failed to move binder below unsovled inference vars" - $ exchangeBs (PairB (b:>emission) infVars) - -- Now, those are real leakage errors. We never want to leak this var through a solution! - -- But since we delay hoisting, they will only be raised later. - let subst' = delayedHoistSolverSubst b' subst - let defaults' = hoistDefaults b' defaults - let decls' = Nest (LeftB (Let b' emission')) decls - hoistInfStateRec env rest infVars' defaults' subst' decls' e - -hoistDefaults :: BindsNames b => b n l -> Defaults l -> Defaults n -hoistDefaults b (Defaults d1 d2) = Defaults (hoistFilterNameSet b d1) - (hoistFilterNameSet b d2) - -infNamesToEmissions :: InferenceNameBinders n l -> InfEmissions n l -infNamesToEmissions = go REmpty - where - go :: InfEmissions n q -> InferenceNameBinders q l -> InfEmissions n l - go acc = \case - Empty -> acc - Nest (b:>binding) rest -> go (RNest acc (b:>RightE binding)) rest - -instance EnvReader (InfererM i) where - unsafeGetEnv = do - InfOutMap bindings _ _ _ _ <- InfererM $ SubstReaderT $ lift $ getOutMapInplaceT - return bindings - {-# INLINE unsafeGetEnv #-} - -instance EnvExtender (InfererM i) where - refreshAbs ab cont = InfererM $ SubstReaderT $ ReaderT \env -> do - refreshAbs ab \b e -> runSubstReaderT (sink env) $ runInfererM' $ cont b e - {-# INLINE refreshAbs #-} - --- === helpers for extending synthesis candidates === - --- TODO: we should pull synth candidates out of the Env and then we can treat it --- like an ordinary reader without all this ceremony. - -extendSynthCandidatesInf :: Explicitness -> CAtomName n -> InfOutMap n -> InfOutMap n -extendSynthCandidatesInf c v (InfOutMap env x y z w) = - InfOutMap (extendSynthCandidates c v env) x y z w -{-# INLINE extendSynthCandidatesInf #-} - -extendSynthCandidates :: Explicitness -> CAtomName n -> Env n -> Env n -extendSynthCandidates (Inferred _ (Synth _)) v (Env topEnv (ModuleEnv a b scs)) = - Env topEnv (ModuleEnv a b scs') - where scs' = scs <> SynthCandidates [v] mempty -extendSynthCandidates _ _ env = env -{-# INLINE extendSynthCandidates #-} - -extendSynthCandidatess :: Distinct n => [Explicitness] -> Nest CBinder n' n -> Env n -> Env n -extendSynthCandidatess (expl:expls) (Nest b bs) env = - extendSynthCandidatess expls bs env' - where env' = extendSynthCandidates expl (withExtEvidence bs $ sink $ binderName b) env -extendSynthCandidatess [] Empty env = env -extendSynthCandidatess _ _ _ = error "zip error" -{-# INLINE extendSynthCandidatess #-} +withAllowedEffects effs cont = withInfState (\(InfState g _) -> InfState g effs) cont +{-# INLINE withAllowedEffects #-} -- === actual inference pass === -data RequiredTy (e::E) (n::S) = - Check (e n) +data RequiredTy (n::S) = + Check (CType n) | Infer deriving Show -checkSigma :: EmitsBoth o - => NameHint -> UExpr i -> CType o -> InfererM i o (CAtom o) -checkSigma hint expr sTy = confuseGHC >>= \_ -> case sTy of - Pi piTy@(CorePiType _ expls _ _) -> do - if all (== Explicit) expls - then fallback - else case expr of - WithSrcE src (ULam lam) -> addSrcContext src $ Lam <$> checkULam lam piTy - _ -> Lam <$> buildLamInf piTy \args resultTy -> do - explicits <- return $ catMaybes $ args <&> \case - (Explicit, arg) -> Just arg - _ -> Nothing - expr' <- inferWithoutInstantiation expr >>= zonk - dropSubst $ checkOrInferApp expr' explicits [] (Check resultTy) - DepPairTy depPairTy -> case depPairTy of - DepPairType ImplicitDepPair (_ :> lhsTy) _ -> do - -- TODO: check for the case that we're given some of the implicit dependent pair args explicitly - lhsVal <- Var <$> freshInferenceName MiscInfVar lhsTy - -- TODO: make an InfVarDesc case for dep pair instantiation - rhsTy <- instantiate depPairTy [lhsVal] - rhsVal <- checkSigma noHint expr rhsTy - return $ DepPair lhsVal rhsVal depPairTy - _ -> fallback - _ -> fallback - where fallback = checkOrInferRho hint expr (Check sTy) - -inferSigma :: EmitsBoth o => NameHint -> UExpr i -> InfererM i o (CAtom o) -inferSigma hint (WithSrcE pos expr) = case expr of - ULam lam -> addSrcContext pos $ Lam <$> inferULam lam - _ -> inferRho hint (WithSrcE pos expr) - -checkRho :: EmitsBoth o => - NameHint -> UExpr i -> CType o -> InfererM i o (CAtom o) -checkRho hint expr ty = checkOrInferRho hint expr (Check ty) -{-# INLINE checkRho #-} - -inferRho :: EmitsBoth o => - NameHint -> UExpr i -> InfererM i o (CAtom o) -inferRho hint expr = checkOrInferRho hint expr Infer -{-# INLINE inferRho #-} - -getImplicitArg :: EmitsInf o => InferenceArgDesc -> InferenceMechanism -> CType o -> InfererM i o (CAtom o) -getImplicitArg desc inf argTy = case inf of - Unify -> Var <$> freshInferenceName (ImplicitArgInfVar desc) argTy - Synth reqMethodAccess -> do - ctx <- srcPosCtx <$> getErrCtx - return $ DictHole (AlwaysEqual ctx) argTy reqMethodAccess - -withBlockDecls - :: EmitsBoth o - => UBlock i -> (forall i'. UExpr i' -> InfererM i' o a) -> InfererM i o a -withBlockDecls (WithSrcE src (UBlock declsTop result)) contTop = - addSrcContext src $ go declsTop $ contTop result where - go :: EmitsBoth o => Nest UDecl i i' -> InfererM i' o a -> InfererM i o a - go decls cont = case decls of - Empty -> cont - Nest d ds -> withUDecl d $ go ds $ cont - -withUDecl - :: EmitsBoth o - => UDecl i i' - -> InfererM i' o a - -> InfererM i o a -withUDecl (WithSrcB src d) cont = addSrcContext src case d of - UPass -> cont - UExprDecl e -> inferSigma noHint e >> cont - ULet letAnn p ann rhs -> do - val <- checkMaybeAnnExpr (getNameHint p) ann rhs - var <- emitDecl (getNameHint p) letAnn $ Atom val - bindLetPat p var cont - --- "rho" means the required type here should not be (at the top level) an implicit pi type or --- an implicit dependent pair type. We don't want to unify those directly. --- The name hint names the object being computed -checkOrInferRho - :: forall i o. EmitsBoth o - => NameHint -> UExpr i -> RequiredTy CType o -> InfererM i o (CAtom o) -checkOrInferRho hint uExprWithSrc@(WithSrcE pos expr) reqTy = do - addSrcContext pos $ confuseGHC >>= \_ -> case expr of - UVar _ -> inferAndInstantiate - ULit l -> matchRequirement $ Con $ Lit l - ULam lamExpr -> do +data PartialPiType (n::S) where + PartialPiType + :: AppExplicitness -> [Explicitness] + -> Nest CBinder n l + -> EffectRow CoreIR l + -> RequiredTy l + -> PartialPiType n + +data PartialType (n::S) = + PartialType (PartialPiType n) + | FullType (CType n) + +checkOrInfer :: Emits o => RequiredTy o -> UExpr i -> InfererM i o (CAtom o) +checkOrInfer reqTy expr = case reqTy of + Infer -> bottomUp expr + Check t -> topDown t expr + +topDown :: forall i o. Emits o => CType o -> UExpr i -> InfererM i o (CAtom o) +topDown ty uexpr = topDownPartial (typeAsPartialType ty) uexpr + +topDownPartial :: Emits o => PartialType o -> UExpr i -> InfererM i o (CAtom o) +topDownPartial partialTy exprWithSrc@(WithSrcE pos expr) = addSrcContext pos $ + case partialTy of + PartialType partialPiTy -> case expr of + ULam lam -> Lam <$> checkULamPartial partialPiTy lam + _ -> Lam <$> etaExpandPartialPi partialPiTy \resultTy explicitArgs -> do + expr' <- bottomUpExplicit exprWithSrc + dropSubst $ checkOrInferApp expr' explicitArgs [] resultTy + FullType ty -> topDownExplicit ty exprWithSrc + +-- Creates a lambda for all args and returns (via CPA) the explicit args +etaExpandPartialPi + :: PartialPiType o + -> (forall o'. (Emits o', DExt o o') => RequiredTy o' -> [CAtom o'] -> InfererM i o' (CAtom o')) + -> InfererM i o (CoreLamExpr o) +etaExpandPartialPi (PartialPiType appExpl expls bs effs reqTy) cont = do + withFreshBindersInf expls (Abs bs (PairE effs reqTy)) \bs' (PairE effs' reqTy') -> do + let args = zip expls (Var <$> bindersVars bs') + explicits <- return $ catMaybes $ args <&> \case + (Explicit, arg) -> Just arg + _ -> Nothing + withAllowedEffects effs' do + body <- buildScoped $ cont (sink reqTy') (sink <$> explicits) + resultTy <- blockTy body + let piTy = CorePiType appExpl expls bs' (EffTy effs' resultTy) + return $ CoreLamExpr piTy $ LamExpr bs' body + +-- Doesn't introduce implicit pi binders or dependent pairs +topDownExplicit :: forall i o. Emits o => CType o -> UExpr i -> InfererM i o (CAtom o) +topDownExplicit reqTy exprWithSrc@(WithSrcE pos expr) = addSrcContext pos case expr of + ULam lamExpr -> case reqTy of + Pi piTy -> Lam <$> checkULam lamExpr piTy + _ -> throw TypeErr $ "Unexpected lambda. Expected: " ++ pprint reqTy + UFor dir uFor -> case reqTy of + TabPi tabPiTy -> do + lam@(UnaryLamExpr b' _) <- checkUForExpr uFor tabPiTy + ixTy <- asIxType $ binderType b' + emitHof $ For dir ixTy lam + _ -> throw TypeErr $ "Unexpected `for` expression. Expected: " ++ pprint reqTy + UApp f posArgs namedArgs -> do + f' <- bottomUpExplicit f + checkOrInferApp f' posArgs namedArgs (Check reqTy) + UDepPair lhs rhs -> case reqTy of + DepPairTy ty@(DepPairType _ (_ :> lhsTy) _) -> do + lhs' <- checkSigmaDependent lhs (FullType lhsTy) + rhsTy <- instantiate ty [lhs'] + rhs' <- topDown rhsTy rhs + return $ DepPair lhs' rhs' ty + _ -> throw TypeErr $ "Unexpected dependent pair. Expected: " ++ pprint reqTy + UCase scrut alts -> do + scrut' <- bottomUp scrut + let scrutTy = getType scrut' + alts' <- mapM (checkCaseAlt (Check reqTy) scrutTy) alts + buildSortedCase scrut' alts' reqTy + UDo block -> withBlockDecls block \result -> topDownExplicit (sink reqTy) result + UTabCon xs -> do case reqTy of - Check (Pi piTy) -> Lam <$> checkULam lamExpr piTy - Check _ -> Lam <$> inferULam lamExpr >>= matchRequirement - Infer -> Lam <$> inferULam lamExpr + TabPi tabPiTy -> checkTabCon tabPiTy xs + _ -> throw TypeErr $ "Unexpected table constructor. Expected: " ++ pprint reqTy + UNatLit x -> do + let litVal = Con $ Lit $ Word64Lit $ fromIntegral x + applyFromLiteralMethod reqTy "from_unsigned_integer" litVal + UIntLit x -> do + let litVal = Con $ Lit $ Int64Lit $ fromIntegral x + applyFromLiteralMethod reqTy "from_integer" litVal + UPrim UTuple xs -> case reqTy of + TyKind -> Type . ProdTy <$> mapM checkUType xs + ProdTy reqTys -> do + when (length reqTys /= length xs) $ throw TypeErr "Tuple length mismatch" + ProdVal <$> forM (zip reqTys xs) \(reqTy', x) -> topDown reqTy' x + _ -> throw TypeErr $ "Unexpected tuple. Expected: " ++ pprint reqTy + UFieldAccess _ _ -> infer + UVar _ -> infer + UTypeAnn _ _ -> infer + UTabApp _ _ -> infer + UFloatLit _ -> infer + UPrim _ _ -> infer + ULit _ -> infer + UPi _ -> infer + UTabPi _ -> infer + UDepPairTy _ -> infer + UHole -> throw TypeErr "Can't infer value of hole" + where + infer :: InfererM i o (CAtom o) + infer = do + sigmaAtom <- maybeInterpretPunsAsTyCons (Check reqTy) =<< bottomUpExplicit exprWithSrc + instantiateSigma (Check reqTy) sigmaAtom + +bottomUp :: Emits o => UExpr i -> InfererM i o (CAtom o) +bottomUp expr = bottomUpExplicit expr >>= instantiateSigma Infer + +-- Doesn't instantiate implicit args +bottomUpExplicit :: Emits o => UExpr i -> InfererM i o (SigmaAtom o) +bottomUpExplicit (WithSrcE pos expr) = addSrcContext pos case expr of + UVar ~(InternalName _ sn v) -> do + v' <- renameM v + ty <- getUVarType v' + return $ SigmaUVar sn ty v' + ULit l -> return $ SigmaAtom Nothing $ Con $ Lit l + UFieldAccess x (WithSrc pos' field) -> addSrcContext pos' do + x' <- bottomUp x + ty <- return $ getType x' + fields <- getFieldDefs ty + case M.lookup field fields of + Just def -> case def of + FieldProj i -> SigmaAtom Nothing <$> projectField i x' + FieldDotMethod method (TyConParams _ params) -> do + method' <- toAtomVar method + resultTy <- partialAppType (getType method') (params ++ [x']) + return $ SigmaPartialApp resultTy (Var method') (params ++ [x']) + Nothing -> throw TypeErr $ + "Can't resolve field " ++ pprint field ++ " of type " ++ pprint ty ++ + "\nKnown fields are: " ++ pprint (M.keys fields) + ULam lamExpr -> SigmaAtom Nothing <$> Lam <$> inferULam lamExpr UFor dir uFor -> do - lam@(UnaryLamExpr b' _) <- case reqTy of - Check (TabPi tabPiTy) -> do checkUForExpr uFor tabPiTy - Check _ -> inferUForExpr uFor - Infer -> inferUForExpr uFor + lam@(UnaryLamExpr b' _) <- inferUForExpr uFor ixTy <- asIxType $ binderType b' - result <- emitHof $ For dir ixTy lam - matchRequirement result + liftM (SigmaAtom Nothing) $ emitHof $ For dir ixTy lam UApp f posArgs namedArgs -> do - f' <- inferWithoutInstantiation f >>= zonk - checkOrInferApp f' posArgs namedArgs reqTy + f' <- bottomUpExplicit f + SigmaAtom Nothing <$> checkOrInferApp f' posArgs namedArgs Infer UTabApp tab args -> do - tab' <- inferRho noHint tab >>= zonk - inferTabApp (srcPos tab) tab' args >>= matchRequirement + tab' <- bottomUp tab + SigmaAtom Nothing <$> inferTabApp (srcPos tab) tab' args UPi (UPiExpr bs appExpl effs ty) -> do -- TODO: check explicitness constraints - ab <- withUBinders bs \_ -> EffTy <$> checkUEffRow effs <*> checkUType ty - Abs bs' effTy' <- return ab - let (expls, bs'') = unzipAttrs bs' - matchRequirement $ Type $ Pi $ CorePiType appExpl expls bs'' effTy' - UTabPi (UTabPiExpr (UAnnBinder b ann cs) ty) -> do - unless (null cs) $ throw TypeErr "`=>` shouldn't have constraints" - ann' <- asIxType =<< checkAnn (getSourceName b) ann - piTy <- case b of - UIgnore -> - buildTabPiInf noHint ann' \_ -> checkUType ty - _ -> buildTabPiInf (getNameHint b) ann' \v -> extendRenamer (b@>atomVarName v) do - let msg = "Can't reduce type expression: " ++ docAsStr (pretty ty) - Type rhs <- withReducibleEmissions msg $ Type <$> checkUType ty - return rhs - matchRequirement $ Type $ TabPi piTy - UDepPairTy (UDepPairType expl (UAnnBinder b ann cs) rhs) -> do - unless (null cs) $ throw TypeErr "Dependent pair binders shouldn't have constraints" - ann' <- checkAnn (getSourceName b) ann - matchRequirement =<< liftM (Type . DepPairTy) do - buildDepPairTyInf (getNameHint b) expl ann' \v -> extendRenamer (b@>atomVarName v) do - let msg = "Can't reduce type expression: " ++ docAsStr (pretty rhs) - withReducibleEmissions msg $ checkUType rhs - UDepPair lhs rhs -> do - case reqTy of - Check (DepPairTy ty@(DepPairType _ (_ :> lhsTy) _)) -> do - lhs' <- checkSigmaDependent noHint lhs lhsTy - rhsTy <- instantiate ty [lhs'] - rhs' <- checkSigma noHint rhs rhsTy - return $ DepPair lhs' rhs' ty - _ -> throw TypeErr $ "Can't infer the type of a dependent pair; please annotate it" - UCase scrut alts -> do - scrut' <- inferRho noHint scrut - scrutTy <- return $ getType scrut' - reqTy' <- case reqTy of - Infer -> freshType - Check req -> return req - alts' <- mapM (checkCaseAlt reqTy' scrutTy) alts - scrut'' <- zonk scrut' - buildSortedCase scrut'' alts' reqTy' - UDo block -> withBlockDecls block \result -> checkOrInferRho hint result reqTy - UTabCon xs -> inferTabCon hint xs reqTy >>= matchRequirement - UHole -> case reqTy of - Infer -> throw MiscErr "Can't infer type of hole" - Check ty -> freshAtom ty + withUBinders bs \(ZipB expls bs') -> do + effTy' <- EffTy <$> checkUEffRow effs <*> checkUType ty + return $ SigmaAtom Nothing $ Type $ + Pi $ CorePiType appExpl expls bs' effTy' + UTabPi (UTabPiExpr b ty) -> do + Abs b' ty' <- withUBinder b \(WithAttrB _ b') -> + liftM (Abs b') $ checkUType ty + d <- getIxDict $ binderType b' + let piTy = TabPiType d b' ty' + return $ SigmaAtom Nothing $ Type $ TabPi piTy + UDepPairTy (UDepPairType expl b rhs) -> do + withUBinder b \(WithAttrB _ b') -> do + rhs' <- checkUType rhs + return $ SigmaAtom Nothing $ Type $ DepPairTy $ DepPairType expl b' rhs' + UDepPair _ _ -> throw TypeErr $ + "Can't infer the type of a dependent pair; please annotate its type" + UCase scrut (alt:alts) -> do + scrut' <- bottomUp scrut + let scrutTy = getType scrut' + alt'@(IndexedAlt _ altAbs) <- checkCaseAlt Infer scrutTy alt + Abs b ty <- liftEnvReaderM $ refreshAbs altAbs \b body -> do + ty <- blockTy body + return $ Abs b ty + resultTy <- liftHoistExcept $ hoist b ty + alts' <- mapM (checkCaseAlt (Check resultTy) scrutTy) alts + SigmaAtom Nothing <$> buildSortedCase scrut' (alt':alts') resultTy + UCase _ [] -> throw TypeErr "Can't infer empty case expressions" + UDo block -> withBlockDecls block \result -> bottomUpExplicit result + UTabCon xs -> liftM (SigmaAtom Nothing) $ inferTabCon xs UTypeAnn val ty -> do - ty' <- zonk =<< checkUType ty - val' <- checkSigma hint val ty' - matchRequirement val' - UPrim UTuple xs -> case reqTy of - Check TyKind -> Type . ProdTy <$> mapM checkUType xs - _ -> do - xs' <- mapM (inferRho noHint) xs - matchRequirement $ ProdVal xs' + ty' <- checkUType ty + liftM (SigmaAtom Nothing) $ topDown ty' val + UPrim UTuple xs -> do + xs' <- forM xs \x -> bottomUp x + return $ SigmaAtom Nothing $ ProdVal xs' UPrim UMonoLiteral [WithSrcE _ l] -> case l of - UIntLit x -> matchRequirement $ Con $ Lit $ Int32Lit $ fromIntegral x - UNatLit x -> matchRequirement $ Con $ Lit $ Word32Lit $ fromIntegral x + UIntLit x -> return $ SigmaAtom Nothing $ Con $ Lit $ Int32Lit $ fromIntegral x + UNatLit x -> return $ SigmaAtom Nothing $ Con $ Lit $ Word32Lit $ fromIntegral x _ -> throw MiscErr "argument to %monoLit must be a literal" UPrim UExplicitApply (f:xs) -> do - f' <- inferWithoutInstantiation f - xs' <- mapM (inferRho noHint) xs - applySigmaAtom f' xs' >>= matchRequirement + f' <- bottomUpExplicit f + xs' <- mapM bottomUp xs + SigmaAtom Nothing <$> applySigmaAtom f' xs' UPrim UProjNewtype [x] -> do - x' <- inferRho hint x >>= emitHinted hint . Atom - unwrapNewtype $ Var x' + x' <- bottomUp x >>= unwrapNewtype + return $ SigmaAtom Nothing x' UPrim prim xs -> do xs' <- forM xs \x -> do inferPrimArg x >>= \case @@ -1031,44 +544,88 @@ checkOrInferRho hint uExprWithSrc@(WithSrcE pos expr) reqTy = do LetBound (DeclBinding _ (Atom e)) -> return e _ -> return $ Var v x' -> return x' - matchRequirement =<< matchPrimApp prim xs' - UFieldAccess _ _ -> inferAndInstantiate - UNatLit x -> do - let defaultVal = Con $ Lit $ Word32Lit $ fromIntegral x - let litVal = Con $ Lit $ Word64Lit $ fromIntegral x - matchRequirement =<< applyFromLiteralMethod "from_unsigned_integer" defaultVal NatDefault litVal - UIntLit x -> do - let defaultVal = Con $ Lit $ Int32Lit $ fromIntegral x - let litVal = Con $ Lit $ Int64Lit $ fromIntegral x - matchRequirement =<< applyFromLiteralMethod "from_integer" defaultVal IntDefault litVal - UFloatLit x -> matchRequirement $ Con $ Lit $ Float32Lit $ realToFrac x - -- TODO: Make sure that this conversion is not lossy! + liftM (SigmaAtom Nothing) $ matchPrimApp prim xs' + UNatLit _ -> throw TypeErr $ "Can't infer type of literal. Try an explicit annotation" + UIntLit _ -> throw TypeErr $ "Can't infer type of literal. Try an explicit annotation" + UFloatLit x -> return $ SigmaAtom Nothing $ Con $ Lit $ Float32Lit $ realToFrac x + UHole -> throw TypeErr "Can't infer value of hole" + +expectEq :: (PrettyE e, AlphaEqE e) => e o -> e o -> InfererM i o () +expectEq reqTy actualTy = alphaEq reqTy actualTy >>= \case + True -> return () + False -> throw TypeErr $ "Expected: " ++ pprint reqTy ++ + "\nActual: " ++ pprint actualTy +{-# INLINE expectEq #-} + +matchReq :: Ext o o' => RequiredTy o -> CAtom o' -> InfererM i o' (CAtom o') +matchReq (Check reqTy) x = do + reqTy' <- sinkM reqTy + return x <* expectEq reqTy' (getType x) +matchReq Infer x = return x +{-# INLINE matchReq #-} + +instantiateSigma :: Emits o => RequiredTy o -> SigmaAtom o -> InfererM i o (CAtom o) +instantiateSigma reqTy sigmaAtom = case sigmaAtom of + SigmaUVar _ _ _ -> case getType sigmaAtom of + Pi (CorePiType ImplicitApp expls bs (EffTy _ resultTy)) -> do + bsConstrained <- buildConstraints (Abs bs resultTy) \_ resultTy' -> do + case reqTy of + Infer -> return [] + Check reqTy' -> return [TypeConstraint (sink reqTy') resultTy'] + args <- inferMixedArgs @UExpr fDesc expls bsConstrained ([], []) + applySigmaAtom sigmaAtom args + _ -> fallback + _ -> fallback where - matchRequirement :: CAtom o -> InfererM i o (CAtom o) - matchRequirement x = return x <* - case reqTy of - Infer -> return () - Check req -> do - ty <- return $ getType x - constrainTypesEq req ty - {-# INLINE matchRequirement #-} - - inferAndInstantiate :: InfererM i o (CAtom o) - inferAndInstantiate = do - sigmaAtom <- maybeInterpretPunsAsTyCons reqTy =<< inferWithoutInstantiation uExprWithSrc - instantiateSigma sigmaAtom >>= matchRequirement - {-# INLINE inferAndInstantiate #-} - -applyFromLiteralMethod :: EmitsBoth n => SourceName -> CAtom n -> DefaultType -> CAtom n -> InfererM i n (CAtom n) -applyFromLiteralMethod methodName defaultVal defaultTy litVal = do + fallback = forceSigmaAtom sigmaAtom >>= matchReq reqTy + fDesc = getSourceName sigmaAtom + +forceSigmaAtom :: Emits o => SigmaAtom o -> InfererM i o (CAtom o) +forceSigmaAtom sigmaAtom = case sigmaAtom of + SigmaAtom _ x -> return x + SigmaUVar _ _ v -> case v of + UAtomVar v' -> Var <$> toAtomVar v' + _ -> applySigmaAtom sigmaAtom [] + SigmaPartialApp _ _ _ -> error "not implemented" -- better error message? + +withBlockDecls + :: (Emits o, Zonkable e) + => UBlock i + -> (forall i' o'. (Emits o', DExt o o') => UExpr i' -> InfererM i' o' (e o')) + -> InfererM i o (e o) +withBlockDecls (WithSrcE src (UBlock declsTop result)) contTop = + addSrcContext src $ go declsTop $ contTop result where + go :: (Emits o, Zonkable e) + => Nest UDecl i i' + -> (forall o'. (Emits o', DExt o o') => InfererM i' o' (e o')) + -> InfererM i o (e o) + go decls cont = case decls of + Empty -> withDistinct cont + Nest d ds -> withUDecl d $ go ds $ cont + +withUDecl + :: (Emits o, Zonkable e) + => UDecl i i' + -> (forall o'. (Emits o', DExt o o') => InfererM i' o' (e o')) + -> InfererM i o (e o) +withUDecl (WithSrcB src d) cont = addSrcContext src case d of + UPass -> withDistinct cont + UExprDecl e -> withDistinct $ bottomUp e >> cont + ULet letAnn p ann rhs -> do + val <- checkMaybeAnnExpr ann rhs + var <- emitDecl (getNameHint p) letAnn $ Atom val + bindLetPat p var cont + +applyFromLiteralMethod + :: Emits n => CType n -> SourceName -> CAtom n -> InfererM i n (CAtom n) +applyFromLiteralMethod resultTy methodName litVal = lookupSourceMap methodName >>= \case - Nothing -> return defaultVal + Nothing -> error $ "prelude function not found: " ++ methodName Just ~(UMethodVar methodName') -> do MethodBinding className _ <- lookupEnv methodName' - resultTyVar <- freshInferenceName MiscInfVar TyKind - dictTy <- DictTy <$> dictType className [Var resultTyVar] - addDefault (atomVarName resultTyVar) defaultTy - emitExpr =<< mkApplyMethod (DictHole (AlwaysEqual emptySrcPosCtx) dictTy Full) 0 [litVal] + dictTy <- DictTy <$> dictType className [Type resultTy] + d <- trySynthTerm dictTy Full + emitExpr =<< mkApplyMethod d 0 [litVal] -- atom that requires instantiation to become a rho type data SigmaAtom n = @@ -1093,51 +650,6 @@ instance HasSourceName (SigmaAtom n) where SigmaUVar sn _ _ -> sn SigmaPartialApp _ _ _ -> "" -instance SinkableE SigmaAtom where - sinkingProofE = error "it's fine, trust me" - -instance SubstE AtomSubstVal SigmaAtom where - substE env (SigmaAtom sn x) = SigmaAtom sn $ substE env x - substE env (SigmaUVar sn ty uvar) = case uvar of - UAtomVar v -> substE env $ SigmaAtom (Just sn) $ Var (AtomVar v ty) - UTyConVar v -> SigmaUVar sn ty' $ UTyConVar $ substE env v - UDataConVar v -> SigmaUVar sn ty' $ UDataConVar $ substE env v - UPunVar v -> SigmaUVar sn ty' $ UPunVar $ substE env v - UClassVar v -> SigmaUVar sn ty' $ UClassVar $ substE env v - UMethodVar v -> SigmaUVar sn ty' $ UMethodVar $ substE env v - UEffectVar _ -> error "not implemented" - UEffectOpVar _ -> error "not implemented" - where ty' = substE env ty - substE env (SigmaPartialApp ty f xs) = - SigmaPartialApp (substE env ty) (substE env f) (map (substE env) xs) - --- XXX: this must handle the complement of the cases that `checkOrInferRho` --- handles directly or else we'll just keep bouncing between the two. -inferWithoutInstantiation - :: forall i o. EmitsBoth o - => UExpr i -> InfererM i o (SigmaAtom o) -inferWithoutInstantiation (WithSrcE pos expr) = - addSrcContext pos $ confuseGHC >>= \_ -> case expr of - UVar ~(InternalName _ sn v) -> do - v' <- renameM v - ty <- getUVarType v' - return $ SigmaUVar sn ty v' - UFieldAccess x (WithSrc pos' field) -> addSrcContext pos' do - x' <- inferRho noHint x >>= zonk - ty <- return $ getType x' - fields <- getFieldDefs ty - case M.lookup field fields of - Just def -> case def of - FieldProj i -> SigmaAtom Nothing <$> projectField i x' - FieldDotMethod method (TyConParams _ params) -> do - method' <- toAtomVar method - resultTy <- partialAppType (getType method') (params ++ [x']) - return $ SigmaPartialApp resultTy (Var method') (params ++ [x']) - Nothing -> throw TypeErr $ - "Can't resolve field " ++ pprint field ++ " of type " ++ pprint ty ++ - "\nKnown fields are: " ++ pprint (M.keys fields) - _ -> SigmaAtom Nothing <$> inferRho noHint (WithSrcE pos expr) - data FieldDef (n::S) = FieldProj Int | FieldDotMethod (CAtomName n) (TyConParams n) @@ -1169,32 +681,6 @@ getFieldDefs ty = case ty of where noFields s = throw TypeErr $ "Can't get fields for type " ++ pprint ty ++ s -instantiateSigma :: forall i o. EmitsBoth o => SigmaAtom o -> InfererM i o (CAtom o) -instantiateSigma sigmaAtom = case getType sigmaAtom of - Pi piTy@(CorePiType ExplicitApp _ _ _) -> do - Lam <$> etaExpandExplicits fDesc piTy \args -> - applySigmaAtom (sink sigmaAtom) args - Pi (CorePiType ImplicitApp expls bs (EffTy _ resultTy)) -> do - args <- inferMixedArgs @UExpr fDesc expls (Abs bs resultTy) [] [] - applySigmaAtom sigmaAtom args - DepPairTy (DepPairType ImplicitDepPair _ _) -> - -- TODO: we should probably call instantiateSigma again here in case - -- we have nested dependent pairs. Also, it looks like this doesn't - -- get called after function application. We probably want to fix that. - fallback >>= getSnd - _ -> fallback - where - fallback = case sigmaAtom of - SigmaAtom _ x -> return x - SigmaUVar _ _ v -> case v of - UAtomVar v' -> do - v'' <- toAtomVar v' - return $ Var v'' - _ -> applySigmaAtom sigmaAtom [] - SigmaPartialApp _ _ _ -> error "shouldn't hit this case because we should have a pi type here" - fDesc :: SourceName - fDesc = getSourceName sigmaAtom - projectField :: Emits o => Int -> CAtom o -> InfererM i o (CAtom o) projectField i x = case getType x of ProdTy _ -> projectTuple i x @@ -1206,129 +692,65 @@ projectField i x = case getType x of _ -> bad where bad = error $ "bad projection: " ++ pprint (i, x) --- creates a lambda term with just the explicit binders, but provides --- args corresponding to all the binders (explicit and implicit) -etaExpandExplicits - :: EmitsInf o => SourceName -> CorePiType o - -> (forall o'. (EmitsBoth o', DExt o o') => [CAtom o'] -> InfererM i o' (CAtom o')) - -> InfererM i o (CoreLamExpr o) -etaExpandExplicits fSourceName (CorePiType _ explsTop bsTop (EffTy effs _)) contTop = do - Abs bs body <- go explsTop bsTop \xs -> do - effs' <- applySubst (bsTop@@>(SubstVal<$>xs)) effs - withAllowedEffects effs' do - body <- buildBlockInf $ contTop $ sinkList xs - return $ PairE effs' body - let (expls, bs') = unzipAttrs bs - coreLamExpr ExplicitApp expls $ Abs bs' body - where - go :: (EmitsInf o, SinkableE e, RenameE e, SubstE AtomSubstVal e, HoistableE e ) - => [Explicitness] -> Nest CBinder o any - -> (forall o'. (EmitsInf o', DExt o o') => [CAtom o'] -> InfererM i o' (e o')) - -> InfererM i o (Abs (Nest (WithExpl CBinder)) e o) - go [] Empty cont = getDistinct >>= \Distinct -> Abs Empty <$> cont [] - go (expl:expls) (Nest (b:>ty) rest) cont = case expl of - Explicit -> do - prependAbs <$> buildAbsInfWithExpl (getNameHint b) expl ty \v -> do - Abs rest' UnitE <- applyRename (b@>atomVarName v) $ Abs rest UnitE - go expls rest' \args -> cont (sink (Var v) : args) - Inferred argSourceName infMech -> do - arg <- getImplicitArg (fSourceName, fromMaybe "_" argSourceName) infMech ty - Abs rest' UnitE <- applySubst (b@>SubstVal arg) $ Abs rest UnitE - go expls rest' \args -> cont (sink arg : args) - go _ _ _ = error "zip error" - -buildLamInf - :: EmitsInf o => CorePiType o - -> (forall o' . (EmitsBoth o', DExt o o') - => [(Explicitness, CAtom o')] -> CType o' -> InfererM i o' (CAtom o')) - -> InfererM i o (CoreLamExpr o) -buildLamInf (CorePiType appExpl explsTop bsTop effTy) contTop = do - ab <- go explsTop bsTop \xs -> do - let (expls, xs') = unzip xs - EffTy effs' resultTy' <- applySubst (bsTop@@>(SubstVal<$>xs')) effTy - withAllowedEffects effs' do - body <- buildBlockInf $ contTop (zip expls $ sinkList xs') (sink resultTy') - return $ PairE effs' body - coreLamExpr appExpl explsTop ab - where - go :: (EmitsInf o, HoistableE e, SinkableE e, SubstE AtomSubstVal e, RenameE e) - => [Explicitness] -> Nest CBinder o any - -> (forall o'. (EmitsInf o', DExt o o') => [(Explicitness, CAtom o')] -> InfererM i o' (e o')) - -> InfererM i o (Abs (Nest CBinder) e o) - go [] Empty cont = getDistinct >>= \Distinct -> Abs Empty <$> cont [] - go (expl:expls) (Nest b rest) cont = do - prependAbs <$> buildAbsInf (getNameHint b) expl (binderType b) \v -> do - Abs rest' UnitE <- applyRename (b@>atomVarName v) $ Abs rest UnitE - go expls rest' \args -> cont $ (expl, sink $ Var v) : args - go _ _ _ = error "zip error" - -class ExplicitArg (e::E) where - checkExplicitArg :: EmitsBoth o => IsDependent -> e i -> CType o -> InfererM i o (CAtom o) - inferExplicitArg :: EmitsBoth o => e i -> InfererM i o (CAtom o) +class PrettyE e => ExplicitArg (e::E) where + checkExplicitArg :: Emits o => IsDependent -> e i -> PartialType o -> InfererM i o (CAtom o) + inferExplicitArg :: Emits o => e i -> InfererM i o (CAtom o) + isHole :: e n -> Bool instance ExplicitArg UExpr where - checkExplicitArg isDependent arg argTy = + checkExplicitArg isDependent arg argTy = do if isDependent - then checkSigmaDependent noHint arg argTy - else checkSigma noHint arg argTy + then checkSigmaDependent arg argTy -- do we actually need this? + else topDownPartial argTy arg - inferExplicitArg arg = inferRho noHint arg + inferExplicitArg arg = bottomUp arg + isHole = \case + WithSrcE _ UHole -> True + _ -> False instance ExplicitArg CAtom where checkExplicitArg _ arg argTy = do arg' <- renameM arg - constrainTypesEq argTy $ getType arg' + case argTy of + FullType argTy' -> expectEq argTy' (getType arg') + PartialType _ -> return () -- TODO? return arg' inferExplicitArg arg = renameM arg + isHole _ = False checkOrInferApp - :: forall i o arg - . (EmitsBoth o, ExplicitArg arg) + :: forall i o arg . (Emits o, ExplicitArg arg) => SigmaAtom o -> [arg i] -> [(SourceName, arg i)] - -> RequiredTy CType o - -> InfererM i o (CAtom o) + -> RequiredTy o -> InfererM i o (CAtom o) checkOrInferApp f' posArgs namedArgs reqTy = do f <- maybeInterpretPunsAsTyCons reqTy f' case getType f of - Pi (CorePiType appExpl expls bs effTy) -> case appExpl of + Pi piTy@(CorePiType appExpl expls _ _) -> case appExpl of ExplicitApp -> do - checkArity expls posArgs - args' <- inferMixedArgs fDesc expls (Abs bs effTy) posArgs namedArgs - applySigmaAtom f args' >>= matchRequirement - ImplicitApp -> do - -- TODO: should this already have been done by the time we get `f`? - implicitArgs <- inferMixedArgs @UExpr fDesc expls (Abs bs effTy) [] [] - f'' <- SigmaAtom (Just fDesc) <$> applySigmaAtom f implicitArgs - checkOrInferApp f'' posArgs namedArgs Infer >>= matchRequirement - -- TODO: special-case error for when `fTy` can't possibly be a function - fTy -> do - when (not $ null namedArgs) do - throw TypeErr "Can't infer function types with named arguments" - args' <- mapM inferExplicitArg posArgs - argTys <- return $ map getType args' - resultTy <- getResultTy - let expected = nonDepPiType argTys Pure resultTy - constrainTypesEq (Pi expected) fTy - f'' <- zonk f - applySigmaAtom f'' args' + checkExplicitArity expls posArgs + bsConstrained <- buildAppConstraints reqTy piTy + args <- inferMixedArgs fDesc expls bsConstrained (posArgs, namedArgs) + applySigmaAtom f args + ImplicitApp -> error "should already have handled this case" + ty -> throw TypeErr $ "Expected a function type. Got: " ++ pprint ty where fDesc :: SourceName fDesc = getSourceName f' - getResultTy :: InfererM i o (CType o) - getResultTy = case reqTy of - Infer -> freshType - Check req -> return req - - matchRequirement :: CAtom o -> InfererM i o (CAtom o) - matchRequirement x = return x <* - case reqTy of - Infer -> return () - Check req -> do - ty <- return $ getType x - constrainTypesEq req ty - -maybeInterpretPunsAsTyCons :: RequiredTy CType n -> SigmaAtom n -> InfererM i n (SigmaAtom n) +buildAppConstraints :: RequiredTy n -> CorePiType n -> InfererM i n (Abs (Nest CBinder) Constraints n) +buildAppConstraints reqTy (CorePiType _ _ bs effTy) = do + effsAllowed <- infEffects <$> getInfState + buildConstraints (Abs bs effTy) \_ (EffTy effs resultTy) -> do + resultTyConstraints <- return case reqTy of + Infer -> [] + Check reqTy' -> [TypeConstraint (sink reqTy') resultTy] + EffectRow _ t <- return effs + effConstraints <- case t of + NoTail -> return [] + EffectRowTail _ -> return [EffectConstraint (sink effsAllowed) effs] + return $ resultTyConstraints ++ effConstraints + +maybeInterpretPunsAsTyCons :: RequiredTy n -> SigmaAtom n -> InfererM i n (SigmaAtom n) maybeInterpretPunsAsTyCons (Check TyKind) (SigmaUVar sn _ (UPunVar v)) = do let v' = UTyConVar v ty <- getUVarType v' @@ -1337,7 +759,7 @@ maybeInterpretPunsAsTyCons _ x = return x type IsDependent = Bool -applySigmaAtom :: EmitsBoth o => SigmaAtom o -> [CAtom o] -> InfererM i o (CAtom o) +applySigmaAtom :: Emits o => SigmaAtom o -> [CAtom o] -> InfererM i o (CAtom o) applySigmaAtom (SigmaAtom _ f) args = emitExprWithEffects =<< mkApp f args applySigmaAtom (SigmaUVar _ _ f) args = case f of UAtomVar f' -> do @@ -1366,8 +788,6 @@ applySigmaAtom (SigmaUVar _ _ f) args = case f of -- params aren't needed because they're already implied by the dict argument let (dictArg:args') = drop numParams args emitExprWithEffects =<< mkApplyMethod dictArg methodIdx args' - UEffectVar _ -> error "not implemented" - UEffectOpVar _ -> error "not implemented" applySigmaAtom (SigmaPartialApp _ f prevArgs) args = emitExprWithEffects =<< mkApp f (prevArgs ++ args) @@ -1409,58 +829,130 @@ applyDataCon tc conIx topArgs = do where h:t = args _ -> error $ "Unexpected data con representation type: " ++ pprint rty -emitExprWithEffects :: EmitsBoth o => CExpr o -> InfererM i o (CAtom o) +emitExprWithEffects :: Emits o => CExpr o -> InfererM i o (CAtom o) emitExprWithEffects expr = do addEffects $ getEffects expr emitExpr expr -checkArity :: [Explicitness] -> [a] -> InfererM i o () -checkArity expls args = do +checkExplicitArity :: [Explicitness] -> [a] -> InfererM i o () +checkExplicitArity expls args = do let arity = length [() | Explicit <- expls] let numArgs = length args when (numArgs /= arity) do throw TypeErr $ "Wrong number of positional arguments provided. Expected " ++ pprint arity ++ " but got " ++ pprint numArgs +type MixedArgs arg = ([arg], [(SourceName, arg)]) -- positional args, named args +data Constraint (n::S) = + TypeConstraint (CType n) (CType n) + -- permitted effects (no inference vars), proposed effects + | EffectConstraint (EffectRow CoreIR n) (EffectRow CoreIR n) +type Constraints = ListE Constraint + +buildConstraints + :: RenameE e + => Abs (Nest CBinder) e o + -> (forall o'. DExt o o' => [CAtom o'] -> e o' -> EnvReaderM o' [Constraint o']) + -> InfererM i o (Abs (Nest CBinder) Constraints o) +buildConstraints ab cont = liftEnvReaderM do + refreshAbs ab \bs e -> do + cs <- cont (Var <$> bindersVars bs) e + return $ Abs bs $ ListE cs + -- TODO: check that there are no extra named args provided inferMixedArgs - :: forall arg i o e - . (ExplicitArg arg, EmitsBoth o, SubstE (SubstVal Atom) e, SinkableE e, HoistableE e) - => SourceName -> [Explicitness] - -> Abs (Nest CBinder) e o -> [arg i] -> [(SourceName, arg i)] + :: forall arg i o . (Emits o, ExplicitArg arg) + => SourceName + -> [Explicitness] -> Abs (Nest CBinder) Constraints o + -> MixedArgs (arg i) -> InfererM i o [CAtom o] -inferMixedArgs fSourceName explsTop bsAbs posArgs namedArgs = do - checkNamedArgValidity explsTop (map fst namedArgs) - liftM fst $ runStreamReaderT1 posArgs $ go explsTop bsAbs +inferMixedArgs fSourceName explsTop bsAbs argsTop@(_, namedArgsTop) = do + checkNamedArgValidity explsTop (map fst namedArgsTop) + liftSolverM $ fromListE <$> go explsTop bsAbs argsTop where - go :: (EmitsBoth o, SubstE (SubstVal Atom) e, SinkableE e, HoistableE e) - => [Explicitness] -> Abs (Nest CBinder) e o - -> StreamReaderT1 (arg i) (InfererM i) o [CAtom o] - go [] (Abs Empty _) = return [] - go (expl:expls) (Abs (Nest b bs) result) = do - let rest = Abs bs result - let isDependent = binderName b `isFreeIn` rest - arg <- inferMixedArg isDependent (binderType b) expl - arg' <- lift11 $ zonk arg - rest' <- applySubst (b @> SubstVal arg') rest - (arg:) <$> go expls rest' - go _ _ = error "zip error" + go :: Emits oo + => [Explicitness] -> Abs (Nest CBinder) Constraints oo -> MixedArgs (arg i) + -> SolverM i oo (ListE CAtom oo) + go expls (Abs bs cs) args = do + cs' <- eagerlyApplyConstraints bs cs + case (expls, bs) of + ([], Empty) -> return mempty + (expl:explsRest, Nest b bsRest) -> do + let isDependent = binderName b `isFreeIn` Abs bsRest cs' + inferMixedArg isDependent (binderType b) expl args \arg restArgs -> do + bs' <- applySubst (b @> SubstVal arg) (Abs bsRest cs') + (ListE [arg] <>) <$> go explsRest bs' restArgs + (_, _) -> error "zip error" + + eagerlyApplyConstraints + :: Nest CBinder oo oo' -> Constraints oo' + -> SolverM i oo (Constraints oo') + eagerlyApplyConstraints Empty (ListE cs) = mapM_ applyConstraint cs >> return (ListE []) + eagerlyApplyConstraints bs (ListE cs) = ListE <$> forMFilter cs \c -> do + case hoist bs c of + HoistSuccess c' -> case c' of + TypeConstraint _ _ -> applyConstraint c' >> return Nothing + EffectConstraint _ (EffectRow specificEffs _) -> + hasInferenceVars specificEffs >>= \case + False -> applyConstraint c' >> return Nothing + -- we delay applying the constraint in this case because we might + -- learn more about the specific effects after we've seen more + -- arguments (like a `Ref h a` that tells us about the `h`) + True -> return $ Just c + HoistFailure _ -> return $ Just c + + inferMixedArg + :: (Emits oo, Zonkable e) => IsDependent -> CType oo -> Explicitness -> MixedArgs (arg i) + -> (forall o'. (Emits o', DExt oo o') => CAtom o' -> MixedArgs (arg i) -> SolverM i o' (e o')) + -> SolverM i oo (e oo) + inferMixedArg isDependent argTy' expl args cont = do + argTy <- zonk argTy' + case expl of + Explicit -> do + -- this should succeed because we've already done the arity check + (arg:argsRest, namedArgs) <- return args + if isHole arg + then do + let desc = (fSourceName, "_") + withFreshUnificationVar (ImplicitArgInfVar desc) argTy \v -> + cont (Var v) (argsRest, namedArgs) + else do + arg' <- checkOrInferExplicitArg isDependent arg argTy + withDistinct $ cont arg' (argsRest, namedArgs) + Inferred argName infMech -> do + let desc = (fSourceName, fromMaybe "_" argName) + case lookupNamedArg args argName of + Just arg -> do + arg' <- checkOrInferExplicitArg isDependent arg argTy + withDistinct $ cont arg' args + Nothing -> case infMech of + Unify -> withFreshUnificationVar (ImplicitArgInfVar desc) argTy \v -> cont (Var v) args + Synth _ -> withDict argTy \d -> cont d args + + checkOrInferExplicitArg :: Emits oo => Bool -> arg i -> CType oo -> SolverM i oo (CAtom oo) + checkOrInferExplicitArg isDependent arg argTy = do + arg' <- lift11 $ withoutInfVarsPartial argTy >>= \case + Just partialTy -> checkExplicitArg isDependent arg partialTy + Nothing -> inferExplicitArg arg + constrainTypesEq argTy (getType arg') + return arg' - inferMixedArg :: EmitsBoth o => IsDependent -> CType o -> Explicitness - -> StreamReaderT1 (arg i) (InfererM i) o (CAtom o) - inferMixedArg isDependent argTy = \case - Explicit -> do - -- this should succeed because we've already done the arity check - Just arg <- readStream - lift11 $ checkExplicitArg isDependent arg argTy - Inferred argName infMech -> lift11 do - case lookupNamedArg argName of - Nothing -> getImplicitArg (fSourceName, fromMaybe "_" argName) infMech argTy - Just arg -> checkExplicitArg isDependent arg argTy - - lookupNamedArg :: Maybe SourceName -> Maybe (arg i) - lookupNamedArg Nothing = Nothing - lookupNamedArg (Just v) = lookup v namedArgs + lookupNamedArg :: MixedArgs x -> Maybe SourceName -> Maybe x + lookupNamedArg _ Nothing = Nothing + lookupNamedArg (_, namedArgs) (Just v) = lookup v namedArgs + + withoutInfVarsPartial :: CType n -> InfererM i n (Maybe (PartialType n)) + withoutInfVarsPartial = \case + Pi piTy -> + withoutInfVars piTy >>= \case + Just piTy' -> return $ Just $ PartialType $ piAsPartialPi piTy' + Nothing -> withoutInfVars $ PartialType $ piAsPartialPiDropResultTy piTy + ty -> liftM (FullType <$>) $ withoutInfVars ty + + withoutInfVars :: HoistableE e => e n -> InfererM i n (Maybe (e n)) + withoutInfVars x = hasInferenceVars x >>= \case + True -> return Nothing + False -> return $ Just x checkNamedArgValidity :: Fallible m => [Explicitness] -> [SourceName] -> m () checkNamedArgValidity expls offeredNames = do @@ -1476,9 +968,9 @@ checkNamedArgValidity expls offeredNames = do throw TypeErr $ "Unrecognized named arguments: " ++ pprint unrecognizedNames ++ "\nShould be one of: " ++ pprint acceptedNames -inferPrimArg :: EmitsBoth o => UExpr i -> InfererM i o (CAtom o) +inferPrimArg :: Emits o => UExpr i -> InfererM i o (CAtom o) inferPrimArg x = do - xBlock <- buildBlockInf $ inferRho noHint x + xBlock <- buildScoped $ bottomUp x EffTy _ ty <- blockEffTy xBlock case ty of TyKind -> cheapReduce xBlock >>= \case @@ -1549,46 +1041,41 @@ pattern ExplicitCoreLam bs body <- Lam (CoreLamExpr _ (LamExpr bs body)) -- === n-ary applications === -inferTabApp :: EmitsBoth o => SrcPosCtx -> CAtom o -> [UExpr i] -> InfererM i o (CAtom o) +inferTabApp :: Emits o => SrcPosCtx -> CAtom o -> [UExpr i] -> InfererM i o (CAtom o) inferTabApp tabCtx tab args = addSrcContext tabCtx do tabTy <- return $ getType tab args' <- inferNaryTabAppArgs tabTy args - tab' <- zonk tab - emitExpr =<< mkTabApp tab' args' + emitExpr =<< mkTabApp tab args' -inferNaryTabAppArgs - :: EmitsBoth o - => CType o -> [UExpr i] -> InfererM i o [CAtom o] +inferNaryTabAppArgs :: Emits o => CType o -> [UExpr i] -> InfererM i o [CAtom o] inferNaryTabAppArgs _ [] = return [] -inferNaryTabAppArgs tabTy (arg:rest) = do - TabPiType _ b resultTy <- fromTabPiType True tabTy - let ixTy = binderType b - let isDependent = binderName b `isFreeIn` resultTy - arg' <- if isDependent - then checkSigmaDependent (getNameHint b) arg ixTy - else checkSigma (getNameHint b) arg ixTy - arg'' <- zonk arg' - resultTy' <- applySubst (b @> SubstVal arg'') resultTy - rest' <- inferNaryTabAppArgs resultTy' rest - return $ arg'':rest' - -checkSigmaDependent :: EmitsBoth o - => NameHint -> UExpr i -> CType o -> InfererM i o (CAtom o) -checkSigmaDependent hint e@(WithSrcE ctx _) ty = addSrcContext ctx $ - withReducibleEmissions depFunErrMsg $ checkSigma hint e (sink ty) +inferNaryTabAppArgs tabTy (arg:rest) = case tabTy of + TabPi (TabPiType _ b resultTy) -> do + let ixTy = binderType b + let isDependent = binderName b `isFreeIn` resultTy + arg' <- if isDependent + then checkSigmaDependent arg (FullType ixTy) + else topDown ixTy arg + resultTy' <- applySubst (b @> SubstVal arg') resultTy + rest' <- inferNaryTabAppArgs resultTy' rest + return $ arg':rest' + _ -> throw TypeErr $ "Expected a table type but got: " ++ pprint tabTy + +checkSigmaDependent :: Emits o => UExpr i -> PartialType o -> InfererM i o (CAtom o) +checkSigmaDependent e@(WithSrcE ctx _) ty = addSrcContext ctx $ + withReducibleEmissions depFunErrMsg $ topDownPartial (sink ty) e where depFunErrMsg = "Dependent functions can only be applied to fully evaluated expressions. " ++ "Bind the argument to a name before you apply the function." withReducibleEmissions - :: ( EmitsInf o, SinkableE e, RenameE e, SubstE AtomSubstVal e - , HoistableE e, CheaplyReducibleE CoreIR e e) + :: ( Zonkable e, CheaplyReducibleE CoreIR e e, SubstE AtomSubstVal e) => String - -> (forall o' . (EmitsBoth o', DExt o o') => InfererM i o' (e o')) + -> (forall o' . (Emits o', DExt o o') => InfererM i o' (e o')) -> InfererM i o (e o) withReducibleEmissions msg cont = do - Abs decls result <- buildDeclsInf cont + Abs decls result <- buildScoped cont cheapReduceWithDecls decls result >>= \case Just t -> return t _ -> throw TypeErr msg @@ -1650,83 +1137,63 @@ buildSortedCase scrut alts resultTy = do instanceFun :: EnvReader m => InstanceName n -> AppExplicitness -> m n (CAtom n) instanceFun instanceName appExpl = do InstanceDef _ expls bs _ _ <- lookupInstanceDef instanceName - ab <- liftEnvReaderM $ refreshAbs (Abs bs UnitE) \bs' UnitE -> do + liftEnvReaderM $ refreshAbs (Abs bs UnitE) \bs' UnitE -> do args <- mapM toAtomVar $ nestToNames bs' result <- mkDictAtom $ InstanceDict (sink instanceName) (Var <$> args) - return $ Abs bs' (PairE Pure (WithoutDecls result)) - Lam <$> coreLamExpr appExpl (snd<$>expls) ab - -checkMaybeAnnExpr :: EmitsBoth o - => NameHint -> Maybe (UType i) -> UExpr i -> InfererM i o (CAtom o) -checkMaybeAnnExpr hint ty expr = confuseGHC >>= \_ -> case ty of - Nothing -> inferSigma hint expr - Just ty' -> checkSigma hint expr =<< zonk =<< checkUType ty' - -inferRole :: CType o -> Explicitness -> InfererM i o ParamRole -inferRole ty = \case - Inferred _ (Synth _) -> return DictParam - _ -> do - zonk ty >>= \case - TyKind -> return TypeParam - ty' -> isData ty' >>= \case - True -> return DataParam - -- TODO(dougalm): the `False` branch should throw an error but that's - -- currently too conservative. e.g. `data RangeFrom q:Type i:q = ...` - -- fails because `q` isn't data. We should be able to fix it once we - -- have a `Data a` class (see issue #680). - False -> return DataParam -{-# INLINE inferRole #-} - -inferTyConDef :: EmitsInf o => UDataDef i -> InfererM i o (TyConDef o) + let effTy = EffTy Pure (getType result) + let body = WithoutDecls result + let piTy = CorePiType appExpl (snd<$>expls) bs' effTy + return $ Lam $ CoreLamExpr piTy (LamExpr bs' body) + +checkMaybeAnnExpr :: Emits o => Maybe (UType i) -> UExpr i -> InfererM i o (CAtom o) +checkMaybeAnnExpr ty expr = confuseGHC >>= \_ -> case ty of + Nothing -> bottomUp expr + Just ty' -> do + ty'' <- checkUType ty' + topDown ty'' expr + +inferTyConDef :: UDataDef i -> InfererM i o (TyConDef o) inferTyConDef (UDataDef tyConName paramBs dataCons) = do - Abs paramBs' dataCons' <- - withRoleUBinders paramBs do - ADTCons <$> mapM inferDataCon dataCons - let (roleExpls, paramBs'') = unzipAttrs paramBs' - return (TyConDef tyConName roleExpls paramBs'' dataCons') + withRoleUBinders paramBs \(ZipB roleExpls paramBs') -> do + dataCons' <- ADTCons <$> mapM inferDataCon dataCons + return (TyConDef tyConName roleExpls paramBs' dataCons') -inferStructDef :: EmitsInf o => UStructDef i -> InfererM i o (TyConDef o) +inferStructDef :: UStructDef i -> InfererM i o (TyConDef o) inferStructDef (UStructDef tyConName paramBs fields _) = do - let (fieldNames, fieldTys) = unzip fields - Abs paramBs' dataConDefs <- withRoleUBinders paramBs do + withRoleUBinders paramBs \(ZipB roleExpls paramBs') -> do + let (fieldNames, fieldTys) = unzip fields tys <- mapM checkUType fieldTys - return $ StructFields $ zip fieldNames tys - let (roleExpls, paramBs'') = unzipAttrs paramBs' - return $ TyConDef tyConName roleExpls paramBs'' dataConDefs + let dataConDefs = StructFields $ zip fieldNames tys + return $ TyConDef tyConName roleExpls paramBs' dataConDefs inferDotMethod - :: EmitsInf o => TyConName o - -> Abs (Nest UOptAnnBinder) (Abs UAtomBinder ULamExpr) i + :: TyConName o + -> Abs (Nest UAnnBinder) (Abs UAtomBinder ULamExpr) i -> InfererM i o (CoreLamExpr o) inferDotMethod tc (Abs uparamBs (Abs selfB lam)) = do TyConDef sn roleExpls paramBs _ <- lookupTyCon tc let expls = snd <$> roleExpls - ab <- buildNaryAbsInfWithExpl expls (Abs paramBs UnitE) \paramVs -> do - let paramVs' = catMaybes $ zip expls paramVs <&> \(expl, v) -> case expl of - Inferred _ (Synth _) -> Nothing - _ -> Just v - extendRenamer (uparamBs @@> (atomVarName <$> paramVs')) do + withFreshBindersInf expls (Abs paramBs UnitE) \paramBs' UnitE -> do + let paramVs = bindersVars paramBs' + extendRenamer (uparamBs @@> (atomVarName <$> paramVs)) do let selfTy = NewtypeTyCon $ UserADTType sn (sink tc) (TyConParams expls (Var <$> paramVs)) - buildAbsInfWithExpl "self" Explicit selfTy \vSelf -> - extendRenamer (selfB @> atomVarName vSelf) $ inferULam lam - Abs paramBs'' (Abs selfB' lam') <- return ab - return $ prependCoreLamExpr (paramBs'' >>> UnaryNest selfB') lam' - -prependCoreLamExpr :: Nest (WithExpl CBinder) n l -> CoreLamExpr l -> CoreLamExpr n -prependCoreLamExpr bs e = case e of - CoreLamExpr (CorePiType appExpl piExpls piBs effTy) (LamExpr lamBs body) -> do - let (expls, bs') = unzipAttrs bs - let piType = CorePiType appExpl (expls <> piExpls) (bs' >>> piBs) effTy - let lamExpr = LamExpr (fmapNest withoutAttr bs >>> lamBs) body - CoreLamExpr piType lamExpr - -inferDataCon :: EmitsInf o => (SourceName, UDataDefTrail i) -> InfererM i o (DataConDef o) + withFreshBinderInf "self" Explicit selfTy \selfB' -> do + lam' <- extendRenamer (selfB @> binderName selfB') $ inferULam lam + return $ prependCoreLamExpr (expls ++ [Explicit]) (paramBs' >>> UnaryNest selfB') lam' + + where + prependCoreLamExpr :: [Explicitness] -> Nest CBinder n l -> CoreLamExpr l -> CoreLamExpr n + prependCoreLamExpr expls bs e = case e of + CoreLamExpr (CorePiType appExpl piExpls piBs effTy) (LamExpr lamBs body) -> do + let piType = CorePiType appExpl (expls <> piExpls) (bs >>> piBs) effTy + let lamExpr = LamExpr (bs >>> lamBs) body + CoreLamExpr piType lamExpr + +inferDataCon :: (SourceName, UDataDefTrail i) -> InfererM i o (DataConDef o) inferDataCon (sourceName, UDataDefTrail argBs) = do - let expls = nestToList (const Explicit) argBs - Abs argBs' UnitE <- withUBinders (expls, argBs) \_ -> return UnitE - let argBs'' = Abs (fmapNest withoutAttr argBs') UnitE - let (repTy, projIdxs) = dataConRepTy argBs'' - return $ DataConDef sourceName argBs'' repTy projIdxs + withUBinders argBs \(ZipB _ argBs') -> do + let (repTy, projIdxs) = dataConRepTy $ EmptyAbs argBs' + return $ DataConDef sourceName (EmptyAbs argBs') repTy projIdxs dataConRepTy :: EmptyAbs (Nest CBinder) n -> (CType n, [[Projection]]) dataConRepTy (Abs topBs UnitE) = case topBs of @@ -1754,188 +1221,182 @@ dataConRepTy (Abs topBs UnitE) = case topBs of depTy = DepPairTy $ DepPairType ExplicitDepPair b tailTy inferClassDef - :: EmitsInf o - => SourceName -> [SourceName] - -> UOptAnnExplBinders i i' - -> [UType i'] + :: SourceName -> [SourceName] -> Nest UAnnBinder i i' -> [UType i'] -> InfererM i o (ClassDef o) -inferClassDef className methodNames paramBs@(expls, paramBs') methods = do - let paramBsWithAttrBs = zipWithNest paramBs' expls \b expl -> WithAttrB expl b - let paramNames = catMaybes $ nestToList - (\(WithAttrB expl (UAnnBinder b _ _)) -> case expl of - Inferred _ (Synth _) -> Nothing - _ -> Just $ Just $ getSourceName b) paramBsWithAttrBs - ab <- withRoleUBinders paramBs do - ListE <$> forM methods \m -> do - checkUType m >>= \case - Pi t -> return t - t -> return $ CorePiType ImplicitApp [] Empty (EffTy Pure t) - Abs (PairB bs scs) (ListE mtys) <- identifySuperclasses ab - let (roleExpls, bs') = unzipAttrs bs - return $ ClassDef className methodNames paramNames roleExpls bs' scs mtys - -identifySuperclasses - :: RenameE e => Abs (Nest (WithRoleExpl CBinder)) e n - -> InfererM i n (Abs (PairB (Nest (WithRoleExpl CBinder)) (Nest CBinder)) e n) -identifySuperclasses ab = do - refreshAbs ab \bs e -> do - bs' <- partitionBinders bs \b@(WithAttrB (_, expl) b') -> case expl of - Explicit -> return $ LeftB b - Inferred _ Unify -> throw TypeErr "Interfaces can't have implicit parameters" - Inferred _ (Synth _) -> return $ RightB b' - return $ Abs bs' e - -withUBinders - :: (EmitsInf o, HasNamesE e, SubstE AtomSubstVal e) - => UAnnExplBinders req i i' - -> (forall o'. (EmitsInf o', DExt o o') => [CAtomVar o'] -> InfererM i' o' (e o')) - -> InfererM i o (Abs (Nest (WithExpl CBinder)) e o) -withUBinders bs cont = case bs of - ([], Empty) -> getDistinct >>= \Distinct -> Abs Empty <$> cont [] - (expl:expls, Nest (UAnnBinder b ann cs) rest) -> do - ann' <- checkAnn (getSourceName b) ann - prependAbs <$> buildAbsInfWithExpl (getNameHint b) expl ann' \v -> - concatAbs <$> withConstraintBinders cs v do - extendSubst (b@>sink (atomVarName v)) $ withUBinders (expls, rest) \vs -> - cont (sink v : vs) - _ -> error "zip error" - -withConstraintBinders - :: (EmitsInf o, HasNamesE e, SubstE AtomSubstVal e) - => [UConstraint i] - -> CAtomVar o - -> (forall o'. (EmitsInf o', DExt o o') => InfererM i o' (e o')) - -> InfererM i o (Abs (Nest (WithExpl CBinder)) e o) -withConstraintBinders [] _ cont = getDistinct >>= \Distinct -> Abs Empty <$> cont -withConstraintBinders (c:cs) v cont = do - Type dictTy <- withReducibleEmissions "Can't reduce interface constraint" do - c' <- inferWithoutInstantiation c >>= zonk - dropSubst $ checkOrInferApp c' [Var $ sink v] [] (Check TyKind) - prependAbs <$> buildAbsInfWithExpl "d" (Inferred Nothing (Synth Full)) dictTy \_ -> - withConstraintBinders cs (sink v) cont - -withRoleUBinders - :: forall i i' o e req. (EmitsInf o, HasNamesE e, SubstE AtomSubstVal e) - => UAnnExplBinders req i i' - -> (forall o'. (EmitsInf o', DExt o o') => InfererM i' o' (e o')) - -> InfererM i o (Abs (Nest (WithRoleExpl CBinder)) e o) -withRoleUBinders roleBs cont = case roleBs of - ([], Empty) -> getDistinct >>= \Distinct -> Abs Empty <$> cont - (expl:expls, Nest (UAnnBinder b ann cs) rest) -> do - ann' <- checkAnn (getSourceName b) ann - Abs b' (Abs bs' e) <- buildAbsInf (getNameHint b) expl ann' \v -> do - Abs ds (Abs bs' e) <- withConstraintBinders cs v $ - extendSubst (b@>sink (atomVarName v)) $ withRoleUBinders (expls, rest) cont - let ds' = fmapNest (\(WithAttrB expl' b') -> WithAttrB (DictParam, expl') b') ds - return $ Abs (ds' >>> bs') e - role <- inferRole (binderType b') expl - return $ Abs (Nest (WithAttrB (role,expl) b') bs') e - _ -> error "zip error" - -inferULam :: EmitsInf o => ULamExpr i -> InfererM i o (CoreLamExpr o) -inferULam (ULamExpr bs appExpl effs resultTy body) = do - ab <- withUBinders bs \_ -> do - effs' <- fromMaybe Pure <$> mapM checkUEffRow effs - resultTy' <- mapM checkUType resultTy - body' <- buildBlockInf $ withAllowedEffects (sink effs') do - case resultTy' of - Nothing -> withBlockDecls body \result -> inferSigma noHint result - Just resultTy'' -> - withBlockDecls body \result -> - checkSigma noHint result (sink resultTy'') - return (PairE effs' body') - Abs bs' (PairE effs' body') <- return ab +inferClassDef className methodNames paramBs methodTys = do + withRoleUBinders paramBs \(ZipB roleExpls paramBs') -> do + let paramNames = catMaybes $ nestToListFlip paramBs \(UAnnBinder expl b _ _) -> + case expl of Inferred _ (Synth _) -> Nothing + _ -> Just $ Just $ getSourceName b + methodTys' <- forM methodTys \m -> do + checkUType m >>= \case + Pi t -> return t + t -> return $ CorePiType ImplicitApp [] Empty (EffTy Pure t) + PairB paramBs'' superclassBs <- partitionBinders (zipAttrs roleExpls paramBs') $ + \b@(WithAttrB (_, expl) b') -> case expl of + Explicit -> return $ LeftB b + Inferred _ Unify -> throw TypeErr "Interfaces can't have implicit parameters" + Inferred _ (Synth _) -> return $ RightB b' + let (roleExpls', paramBs''') = unzipAttrs paramBs'' + return $ ClassDef className methodNames paramNames roleExpls' paramBs''' superclassBs methodTys' + +withUBinder :: UAnnBinder i i' -> InfererCPSB2 (WithExpl CBinder) i i' o a +withUBinder (UAnnBinder expl b ann cs) cont = do + ty <- inferAnn ann cs + withFreshBinderInf (getNameHint b) expl ty \b' -> + extendSubst (b@>binderName b') $ cont (WithAttrB expl b') + +withUBinders :: Nest UAnnBinder i i' -> InfererCPSB2 (Nest (WithExpl CBinder)) i i' o a +withUBinders bs cont = do + Abs bs' UnitE <- inferUBinders bs \_ -> return UnitE let (expls, bs'') = unzipAttrs bs' - case appExpl of - ImplicitApp -> checkImplicitLamRestrictions bs'' effs' - ExplicitApp -> return () - coreLamExpr appExpl expls $ Abs bs'' $ PairE effs' body' - -checkImplicitLamRestrictions :: Nest CBinder o o' -> EffectRow CoreIR o' -> InfererM i o () -checkImplicitLamRestrictions _ _ = return () -- TODO - -checkUForExpr :: EmitsBoth o => UForExpr i -> TabPiType CoreIR o -> InfererM i o (LamExpr CoreIR o) -checkUForExpr (UForExpr (UAnnBinder bFor ann cs) body) tabPi@(TabPiType _ bPi _) = do - unless (null cs) $ throw TypeErr "`for` binders shouldn't have constraints" - let iTy = binderAnn bPi - case ann of - UNoAnn -> return () - UAnn forAnn -> checkUType forAnn >>= constrainTypesEq iTy - Abs b body' <- buildAbsInf (getNameHint bFor) Explicit iTy \i -> do - extendRenamer (bFor@>atomVarName i) do - TabPiType _ bPi' resultTy <- sinkM tabPi - resultTy' <- applyRename (bPi'@>atomVarName i) resultTy - buildBlockInf do - withBlockDecls body \result -> - checkSigma noHint result $ sink resultTy' - return $ LamExpr (UnaryNest b) body' - -inferUForExpr :: EmitsBoth o => UForExpr i -> InfererM i o (LamExpr CoreIR o) -inferUForExpr (UForExpr (UAnnBinder bFor ann cs) body) = do - unless (null cs) $ throw TypeErr "`for` binders shouldn't have constraints" - iTy <- checkAnn (getSourceName bFor) ann - Abs b body' <- buildAbsInf (getNameHint bFor) Explicit iTy \i -> - extendRenamer (bFor@>atomVarName i) $ buildBlockInf $ - withBlockDecls body \result -> - checkOrInferRho noHint result Infer - return $ LamExpr (UnaryNest b) body' + withFreshBindersInf expls (Abs bs'' UnitE) \bs''' UnitE -> do + extendSubst (bs@@> (atomVarName <$> bindersVars bs''')) $ + cont $ zipAttrs expls bs''' -checkULam :: EmitsInf o => ULamExpr i -> CorePiType o -> InfererM i o (CoreLamExpr o) -checkULam (ULamExpr (_, lamBs) lamAppExpl lamEffs lamResultTy body) - (CorePiType piAppExpl expls piBs effTy) = do - checkArity expls (nestToList (const ()) lamBs) +inferUBinders + :: Zonkable e => Nest UAnnBinder i i' + -> (forall o'. DExt o o' => [CAtomName o'] -> InfererM i' o' (e o')) + -> InfererM i o (Abs (Nest (WithExpl CBinder)) e o) +inferUBinders Empty cont = withDistinct $ Abs Empty <$> cont [] +inferUBinders (Nest (UAnnBinder expl b ann cs) bs) cont = do + -- TODO: factor out the common part of each case (requires an annotated + -- `where` clause because of the rank-2 type) + ty <- inferAnn ann cs + withFreshBinderInf (getNameHint b) expl ty \b' -> do + extendSubst (b@>binderName b') do + Abs bs' e <- inferUBinders bs \vs -> cont (sink (binderName b') : vs) + return $ Abs (Nest (WithAttrB expl b') bs') e + +withRoleUBinders :: Nest UAnnBinder i i' -> InfererCPSB2 (Nest (WithRoleExpl CBinder)) i i' o a +withRoleUBinders bs cont = do + withUBinders bs \(ZipB expls bs') -> do + let tys = getType <$> bindersVars bs' + roleExpls <- forM (zip tys expls) \(ty, expl) -> do + role <- inferRole ty expl + return (role, expl) + cont (zipAttrs roleExpls bs') + where + inferRole :: CType o -> Explicitness -> InfererM i o ParamRole + inferRole ty = \case + Inferred _ (Synth _) -> return DictParam + _ -> case ty of + TyKind -> return TypeParam + _ -> isData ty >>= \case + True -> return DataParam + -- TODO(dougalm): the `False` branch should throw an error but that's + -- currently too conservative. e.g. `data RangeFrom q:Type i:q = ...` + -- fails because `q` isn't data. We should be able to fix it once we + -- have a `Data a` class (see issue #680). + False -> return DataParam + {-# INLINE inferRole #-} + +inferAnn :: UAnn i -> [UConstraint i] -> InfererM i o (CType o) +inferAnn ann cs = case ann of + UAnn ty -> checkUType ty + UNoAnn -> case cs of + WithSrcE _ (UVar ~(InternalName _ _ v)):_ -> do + renameM v >>= getUVarType >>= \case + Pi (CorePiType ExplicitApp [Explicit] (UnaryNest (_:>ty)) _) -> return ty + ty -> throw TypeErr $ "Constraint should be a unary function. Got: " ++ pprint ty + _ -> throw TypeErr "Type annotation or constraint required" + +checkULamPartial :: PartialPiType o -> ULamExpr i -> InfererM i o (CoreLamExpr o) +checkULamPartial partialPiTy lamExpr = do + PartialPiType piAppExpl expls piBs piEffs piReqTy <- return partialPiTy + ULamExpr lamBs lamAppExpl lamEffs lamResultTy body <- return lamExpr + checkExplicitArity expls (nestToList (const ()) lamBs) when (piAppExpl /= lamAppExpl) $ throw TypeErr $ "Wrong arrow. Expected " ++ pprint piAppExpl ++ " got " ++ pprint lamAppExpl - Abs explBs body' <- checkLamBinders expls piBs lamBs \vs -> do - EffTy piEffs' piResultTy' <- applyRename (piBs@@>map atomVarName vs) effTy - case lamResultTy of - Nothing -> return () - Just t -> checkUType t >>= constrainTypesEq piResultTy' + checkLamBinders expls piBs lamBs \lamBs' -> do + PairE piEffs' piReqTy' <- applyRename (piBs @@> (atomVarName <$> bindersVars lamBs')) (PairE piEffs piReqTy) + resultTy <- case (lamResultTy, piReqTy') of + (Nothing, Infer ) -> return Infer + (Just t , Infer ) -> Check <$> checkUType t + (Nothing, Check t) -> Check <$> return t + (Just t , Check t') -> checkUType t >>= expectEq t' >> return (Check t') forM_ lamEffs \lamEffs' -> do lamEffs'' <- checkUEffRow lamEffs' - constrainEq (Eff piEffs') (Eff lamEffs'') - withAllowedEffects piEffs' do - body' <- buildBlockInf do - piResultTy'' <- sinkM piResultTy' - withBlockDecls body \result -> - checkSigma noHint result piResultTy'' - return $ PairE piEffs' body' - let (expls', bs') = unzipAttrs explBs - coreLamExpr piAppExpl expls' $ Abs bs' body' - -checkLamBinders - :: (EmitsInf o, HasNamesE e, SubstE AtomSubstVal e) - => [Explicitness] -> Nest CBinder o any - -> Nest UOptAnnBinder i i' - -> (forall o'. (EmitsInf o', DExt o o') => [CAtomVar o'] -> InfererM i' o' (e o')) - -> InfererM i o (Abs (Nest (WithExpl CBinder)) e o) -checkLamBinders [] Empty Empty cont = getDistinct >>= \Distinct -> Abs Empty <$> cont [] -checkLamBinders (piExpl:piExpls) (Nest (piB:>piAnn) piBs) lamBs cont = do - prependAbs <$> case piExpl of - Inferred _ _ -> - buildAbsInfWithExpl (getNameHint piB) piExpl piAnn \v -> do - Abs piBs' UnitE <- applyRename (piB@>atomVarName v) $ Abs piBs UnitE - checkLamBinders piExpls piBs' lamBs \vs -> - cont (sink v:vs) - Explicit -> case lamBs of - Nest (UAnnBinder lamB ann cs) lamBsRest -> do - case ann of - UAnn lamAnn -> checkUType lamAnn >>= constrainTypesEq piAnn - UNoAnn -> return () - buildAbsInfWithExpl (getNameHint lamB) Explicit piAnn \v -> do - concatAbs <$> withConstraintBinders cs v do - Abs piBs' UnitE <- applyRename (piB@>sink (atomVarName v)) $ Abs piBs UnitE - extendRenamer (lamB@>sink (atomVarName v)) $ checkLamBinders piExpls piBs' lamBsRest \vs -> - cont (sink v:vs) - Empty -> error "zip error" -checkLamBinders _ _ _ _ = error "zip error" - -checkInstanceParams :: EmitsInf o => [Explicitness] -> Nest CBinder o any -> [UExpr i] -> InfererM i o [CAtom o] -checkInstanceParams expls bsTop paramsTop = do - checkArity expls paramsTop - go bsTop paramsTop + expectEq (Eff piEffs') (Eff lamEffs'') + body' <- withAllowedEffects piEffs' do + buildScoped $ withBlockDecls body \result -> checkOrInfer (sink resultTy) result + resultTy' <- case resultTy of + Infer -> blockTy body' + Check t -> return t + let piTy = CorePiType piAppExpl expls lamBs' (EffTy piEffs' resultTy') + return $ CoreLamExpr piTy (LamExpr lamBs' body') + where + checkLamBinders + :: [Explicitness] -> Nest CBinder o any -> Nest UAnnBinder i i' + -> InfererCPSB2 (Nest CBinder) i i' o a + checkLamBinders [] Empty Empty cont = withDistinct $ cont Empty + checkLamBinders (piExpl:piExpls) (Nest (piB:>piAnn) piBs) lamBs cont = do + case piExpl of + Inferred _ _ -> do + withFreshBinderInf (getNameHint piB) piExpl piAnn \b -> do + Abs piBs' UnitE <- applyRename (piB@>binderName b) (EmptyAbs piBs) + checkLamBinders piExpls piBs' lamBs \bs -> cont (Nest b bs) + Explicit -> case lamBs of + Nest (UAnnBinder _ lamB lamAnn _) lamBsRest -> do + case lamAnn of + UAnn lamAnn' -> checkUType lamAnn' >>= expectEq piAnn + UNoAnn -> return () + withFreshBinderInf (getNameHint lamB) Explicit piAnn \b -> do + Abs piBs' UnitE <- applyRename (piB@>binderName b) (EmptyAbs piBs) + extendRenamer (lamB@>sink (binderName b)) $ + checkLamBinders piExpls piBs' lamBsRest \bs -> cont (Nest b bs) + Empty -> error "zip error" + checkLamBinders _ _ _ _ = error "zip error" + +inferUForExpr :: Emits o => UForExpr i -> InfererM i o (LamExpr CoreIR o) +inferUForExpr (UForExpr b body) = do + withUBinder b \(WithAttrB _ b') -> do + body' <- buildScoped $ withBlockDecls body \result -> bottomUp result + return $ LamExpr (UnaryNest b') body' + +checkUForExpr :: Emits o => UForExpr i -> TabPiType CoreIR o -> InfererM i o (LamExpr CoreIR o) +checkUForExpr (UForExpr bFor body) (TabPiType _ bPi resultTy) = do + let uLamExpr = ULamExpr (UnaryNest bFor) ExplicitApp Nothing Nothing body + effsAllowed <- infEffects <$> getInfState + partialPi <- liftEnvReaderM $ refreshAbs (Abs bPi resultTy) \bPi' resultTy' -> do + return $ PartialPiType ExplicitApp [Explicit] (UnaryNest bPi') (sink effsAllowed) (Check resultTy') + CoreLamExpr _ lamExpr <- checkULamPartial partialPi uLamExpr + return lamExpr + +inferULam :: ULamExpr i -> InfererM i o (CoreLamExpr o) +inferULam (ULamExpr bs appExpl effs resultTy body) = do + Abs (ZipB expls bs') (PairE effTy body') <- inferUBinders bs \_ -> do + effs' <- fromMaybe Pure <$> mapM checkUEffRow effs + resultTy' <- mapM checkUType resultTy + body' <- buildScoped $ withAllowedEffects (sink effs') do + withBlockDecls body \result -> + case resultTy' of + Nothing -> bottomUp result + Just resultTy'' -> topDown (sink resultTy'') result + resultTy'' <- blockTy body' + let effTy = EffTy effs' resultTy'' + return $ PairE effTy body' + return $ CoreLamExpr (CorePiType appExpl expls bs' effTy) (LamExpr bs' body') + +checkULam :: ULamExpr i -> CorePiType o -> InfererM i o (CoreLamExpr o) +checkULam ulam piTy = checkULamPartial (piAsPartialPi piTy) ulam + +piAsPartialPi :: CorePiType n -> PartialPiType n +piAsPartialPi (CorePiType appExpl expls bs (EffTy effs ty)) = + PartialPiType appExpl expls bs effs (Check ty) + +typeAsPartialType :: CType n -> PartialType n +typeAsPartialType (Pi piTy) = PartialType $ piAsPartialPi piTy +typeAsPartialType ty = FullType ty + +piAsPartialPiDropResultTy :: CorePiType n -> PartialPiType n +piAsPartialPiDropResultTy (CorePiType appExpl expls bs (EffTy effs _)) = + PartialPiType appExpl expls bs effs Infer + +checkInstanceParams :: Nest CBinder o any -> [UExpr i] -> InfererM i o [CAtom o] +checkInstanceParams bsTop paramsTop = go bsTop paramsTop where - go :: EmitsInf o => Nest CBinder o any -> [UExpr i] -> InfererM i o [CAtom o] + go :: Nest CBinder o any -> [UExpr i] -> InfererM i o [CAtom o] go Empty [] = return [] go (Nest (b:>ty) bs) (x:xs) = do x' <- checkUParam ty x @@ -1944,7 +1405,7 @@ checkInstanceParams expls bsTop paramsTop = do go _ _ = error "zip error" checkInstanceBody - :: EmitsInf o => ClassName o -> [CAtom o] + :: ClassName o -> [CAtom o] -> [UMethodDef i] -> InfererM i o (InstanceBody o) checkInstanceBody className params methods = do ClassDef _ methodNames _ _ paramBs scBs methodTys <- lookupClassDef className @@ -1966,8 +1427,7 @@ superclassDictTys (Nest b bs) = do Abs bs' UnitE <- liftHoistExcept $ hoist b $ Abs bs UnitE (binderType b:) <$> superclassDictTys bs' -checkMethodDef :: EmitsInf o - => ClassName o -> [CorePiType o] -> UMethodDef i -> InfererM i o (Int, CAtom o) +checkMethodDef :: ClassName o -> [CorePiType o] -> UMethodDef i -> InfererM i o (Int, CAtom o) checkMethodDef className methodTys (WithSrcE src m) = addSrcContext src do UMethodDef ~(InternalName _ sourceName v) rhs <- return m MethodBinding className' i <- renameM v >>= lookupEnv @@ -1976,40 +1436,32 @@ checkMethodDef className methodTys (WithSrcE src m) = addSrcContext src do throw TypeErr $ pprint sourceName ++ " is not a method of " ++ pprint classSourceName (i,) <$> Lam <$> checkULam rhs (methodTys !! i) -checkUEffRow :: EmitsInf o => UEffectRow i -> InfererM i o (EffectRow CoreIR o) +checkUEffRow :: UEffectRow i -> InfererM i o (EffectRow CoreIR o) checkUEffRow (UEffectRow effs t) = do effs' <- liftM eSetFromList $ mapM checkUEff $ toList effs t' <- case t of Nothing -> return NoTail Just (~(SIInternalName _ v _ _)) -> do v' <- toAtomVar =<< renameM v - constrainVarTy v' EffKind + expectEq EffKind (getType v') return $ EffectRowTail v' return $ EffectRow effs' t' -checkUEff :: EmitsInf o => UEffect i -> InfererM i o (Effect CoreIR o) +checkUEff :: UEffect i -> InfererM i o (Effect CoreIR o) checkUEff eff = case eff of URWSEffect rws (~(SIInternalName _ region _ _)) -> do region' <- renameM region >>= toAtomVar - constrainVarTy region' (TC HeapType) + expectEq (TC HeapType) (getType region') return $ RWSEffect rws (Var region') UExceptionEffect -> return ExceptionEffect UIOEffect -> return IOEffect -constrainVarTy :: EmitsInf o => CAtomVar o -> CType o -> InfererM i o () -constrainVarTy v tyReq = do - varTy <- return $ getType $ Var v - constrainTypesEq tyReq varTy - type CaseAltIndex = Int -checkCaseAlt :: EmitsBoth o - => CType o -> CType o -> UAlt i -> InfererM i o (IndexedAlt o) +checkCaseAlt :: Emits o => RequiredTy o -> CType o -> UAlt i -> InfererM i o (IndexedAlt o) checkCaseAlt reqTy scrutineeTy (UAlt pat body) = do alt <- checkCasePat pat scrutineeTy do - reqTy' <- sinkM reqTy - withBlockDecls body \result -> - checkOrInferRho noHint result (Check reqTy') + withBlockDecls body \result -> checkOrInfer (sink reqTy) result idx <- getCaseAltIndex pat return $ IndexedAlt idx alt @@ -2020,286 +1472,158 @@ getCaseAltIndex (WithSrcB _ pat) = case pat of return con _ -> throw TypeErr $ "Case patterns must start with a data constructor or variant pattern" -checkCasePat :: EmitsBoth o - => UPat i i' - -> CType o - -> (forall o'. (EmitsBoth o', Ext o o') => InfererM i' o' (CAtom o')) - -> InfererM i o (Alt CoreIR o) +checkCasePat + :: Emits o + => UPat i i' -> CType o + -> (forall o'. (Emits o', Ext o o') => InfererM i' o' (CAtom o')) + -> InfererM i o (Alt CoreIR o) checkCasePat (WithSrcB pos pat) scrutineeTy cont = addSrcContext pos $ case pat of UPatCon ~(InternalName _ _ conName) ps -> do (dataDefName, con) <- renameM conName >>= lookupDataCon - TyConDef sourceName roleExpls paramBs (ADTCons cons) <- lookupTyCon dataDefName + tyConDef <- lookupTyCon dataDefName + params <- inferParams scrutineeTy dataDefName + ADTCons cons <- instantiateTyConDef tyConDef params DataConDef _ _ repTy idxs <- return $ cons !! con when (length idxs /= nestLength ps) $ throw TypeErr $ "Unexpected number of pattern binders. Expected " ++ show (length idxs) ++ " got " ++ show (nestLength ps) - (params, repTy') <- inferParams sourceName roleExpls (Abs paramBs repTy) - constrainTypesEq scrutineeTy $ TypeCon sourceName dataDefName params - buildAltInf repTy' \arg -> do - args <- forM idxs \projs -> do - ans <- normalizeNaryProj (init projs) (Var arg) - emit $ Atom ans - bindLetPats ps args $ cont + withFreshBinderInf noHint Explicit repTy \b -> Abs b <$> do + buildScoped do + args <- forM idxs \projs -> do + ans <- normalizeNaryProj (init projs) (sink $ Var $ binderVar b) + emit $ Atom ans + bindLetPats ps args $ cont _ -> throw TypeErr $ "Case patterns must start with a data constructor or variant pattern" -inferParams :: (EmitsBoth o, HasNamesE e, SubstE AtomSubstVal e) - => SourceName -> [RoleExpl] -> Abs (Nest CBinder) e o -> InfererM i o (TyConParams o, e o) -inferParams sourceName roleExpls (Abs paramBs bodyTop) = do - let expls = snd <$> roleExpls - (params, e') <- go expls (Abs paramBs bodyTop) - return (TyConParams expls params, e') - where - go :: (EmitsBoth o, HasNamesE e, SubstE AtomSubstVal e) - => [Explicitness] -> Abs (Nest CBinder) e o -> InfererM i o ([CAtom o], e o) - go [] (Abs Empty body) = return ([], body) - go (expl:expls) (Abs (Nest (b:>ty) bs) body) = do - x <- case expl of - Explicit -> Var <$> freshInferenceName (TypeInstantiationInfVar sourceName) ty - Inferred argName infMech -> getImplicitArg (sourceName, fromMaybe "_" argName) infMech ty - rest <- applySubst (b@>SubstVal x) $ Abs bs body - (params, body') <- go expls rest - return (x:params, body') - go _ _ = error "zip error" - -bindLetPats :: EmitsBoth o - => Nest UPat i i' -> [CAtomVar o] -> InfererM i' o a -> InfererM i o a -bindLetPats Empty [] cont = cont -bindLetPats (Nest p ps) (x:xs) cont = bindLetPat p x $ bindLetPats ps xs cont +inferParams :: Emits o => CType o -> TyConName o -> InfererM i o (TyConParams o) +inferParams ty dataDefName = do + TyConDef sourceName roleExpls paramBs _ <- lookupTyCon dataDefName + let paramExpls = snd <$> roleExpls + let inferenceExpls = paramExpls <&> \case + Explicit -> Inferred Nothing Unify + expl -> expl + paramBsAbs <- buildConstraints (Abs paramBs UnitE) \params _ -> do + let ty' = TypeCon sourceName (sink dataDefName) $ TyConParams paramExpls params + return [TypeConstraint (sink ty) ty'] + args <- inferMixedArgs sourceName inferenceExpls paramBsAbs emptyMixedArgs + return $ TyConParams paramExpls args + +bindLetPats + :: (Emits o, HasNamesE e) + => Nest UPat i i' -> [CAtomVar o] + -> (forall o'. (Emits o', DExt o o') => InfererM i' o' (e o')) + -> InfererM i o (e o) +bindLetPats Empty [] cont = getDistinct >>= \Distinct -> cont +bindLetPats (Nest p ps) (x:xs) cont = bindLetPat p x $ bindLetPats ps (sink <$> xs) cont bindLetPats _ _ _ = error "mismatched number of args" -bindLetPat :: EmitsBoth o => UPat i i' -> CAtomVar o -> InfererM i' o a -> InfererM i o a +bindLetPat + :: (Emits o, HasNamesE e) + => UPat i i' -> CAtomVar o + -> (forall o'. (Emits o', DExt o o') => InfererM i' o' (e o')) + -> InfererM i o (e o) bindLetPat (WithSrcB pos pat) v cont = addSrcContext pos $ case pat of - UPatBinder b -> extendSubst (b @> atomVarName v) cont + UPatBinder b -> getDistinct >>= \Distinct -> extendSubst (b @> atomVarName v) cont UPatProd ps -> do let n = nestLength ps - ty <- return $ getType v - _ <- fromProdType n ty - x <- zonk $ Var v + case getType v of + ProdTy ts | length ts == n -> return () + ty -> throw TypeErr $ "Expected a product type but got: " ++ pprint ty xs <- forM (iota n) \i -> do - normalizeProj (ProjectProduct i) x >>= emit . Atom + normalizeProj (ProjectProduct i) (Var v) >>= emit . Atom bindLetPats ps xs cont UPatDepPair (PairB p1 p2) -> do - let x = Var v - ty <- return $ getType x - _ <- fromDepPairType ty - x' <- zonk x -- ensure it has a dependent pair type before unpacking - x1 <- getFst x' >>= zonk >>= emit . Atom + case getType v of + DepPairTy _ -> return () + ty -> throw TypeErr $ "Expected a dependent pair, but got: " ++ pprint ty + x1 <- getFst (Var v) >>= emit . Atom bindLetPat p1 x1 do - x2 <- getSnd x' >>= zonk >>= emit . Atom + x2 <- getSnd (sink $ Var v) >>= emit . Atom bindLetPat p2 x2 do cont UPatCon ~(InternalName _ _ conName) ps -> do (dataDefName, _) <- lookupDataCon =<< renameM conName - TyConDef sourceName roleExpls paramBs cons <- lookupTyCon dataDefName + TyConDef _ _ _ cons <- lookupTyCon dataDefName case cons of ADTCons [DataConDef _ _ _ idxss] -> do when (length idxss /= nestLength ps) $ throw TypeErr $ "Unexpected number of pattern binders. Expected " ++ show (length idxss) ++ " got " ++ show (nestLength ps) - (params, UnitE) <- inferParams sourceName roleExpls (Abs paramBs UnitE) - constrainVarTy v $ TypeCon sourceName dataDefName params - x <- cheapNormalize =<< zonk (Var v) + void $ inferParams (getType $ Var v) dataDefName + x <- cheapNormalize $ Var v xs <- forM idxss \idxs -> normalizeNaryProj idxs x >>= emit . Atom bindLetPats ps xs cont _ -> throw TypeErr $ "sum type constructor in can't-fail pattern" UPatTable ps -> do - elemTy <- freshType let n = fromIntegral (nestLength ps) :: Word32 - let iTy = FinConst n - idxTy <- asIxType iTy - ty <- return $ getType $ Var v - constrainTypesEq ty (idxTy ==> elemTy) - v' <- zonk $ Var v + cheapNormalize (getType v) >>= \case + TabPi (TabPiType _ (_:>FinConst n') _) | n == n' -> return () + ty -> throw TypeErr $ "Expected a Fin " ++ show n ++ " table type but got: " ++ pprint ty xs <- forM [0 .. n - 1] \i -> do - emit =<< mkTabApp v' [NewtypeCon (FinCon (NatVal n)) (NatVal $ fromIntegral i)] + emit =<< mkTabApp (Var v) [NewtypeCon (FinCon (NatVal n)) (NatVal $ fromIntegral i)] bindLetPats ps xs cont -checkAnn :: EmitsInf o => SourceName -> UAnn req i -> InfererM i o (CType o) -checkAnn binderSourceName ann = case ann of - UAnn ty -> checkUType ty - UNoAnn -> do - let desc = AnnotationInfVar binderSourceName - TyVar <$> freshInferenceName desc TyKind - -checkUType :: EmitsInf o => UType i -> InfererM i o (CType o) +checkUType :: UType i -> InfererM i o (CType o) checkUType t = do Type t' <- checkUParam TyKind t return t' -checkUParam :: EmitsInf o => Kind CoreIR o -> UType i -> InfererM i o (CAtom o) +checkUParam :: Kind CoreIR o -> UType i -> InfererM i o (CAtom o) checkUParam k uty@(WithSrcE pos _) = addSrcContext pos $ - withReducibleEmissions msg $ withoutEffects $ checkRho noHint uty (sink k) + withReducibleEmissions msg $ withAllowedEffects Pure $ topDownExplicit (sink k) uty where msg = "Can't reduce type expression: " ++ pprint uty -inferTabCon :: forall i o. EmitsBoth o - => NameHint -> [UExpr i] -> RequiredTy CType o -> InfererM i o (CAtom o) -inferTabCon hint xs reqTy = do +inferTabCon :: forall i o. Emits o => [UExpr i] -> InfererM i o (CAtom o) +inferTabCon xs = do let n = fromIntegral (length xs) :: Word32 let finTy = FinConst n - ctx <- srcPosCtx <$> getErrCtx - let dataDictHole dTy = Just $ WhenIRE $ DictHole (AlwaysEqual ctx) dTy Full - case reqTy of - Infer -> do - elemTy <- case xs of - [] -> freshType - (x:_) -> getType <$> inferRho noHint x - ixTy <- asIxType finTy - let tabTy = ixTy ==> elemTy - xs' <- forM xs \x -> checkRho noHint x elemTy - dTy <- DictTy <$> dataDictType elemTy - liftM Var $ emitHinted hint $ TabCon (dataDictHole dTy) tabTy xs' - Check tabTy -> do - TabPiType _ b elemTy <- fromTabPiType True tabTy - constrainTypesEq (binderType b) finTy - xs' <- forM (enumerate xs) \(i, x) -> do - let i' = NewtypeCon (FinCon (NatVal n)) (NatVal $ fromIntegral i) :: CAtom o - elemTy' <- applySubst (b@>SubstVal i') elemTy - checkRho noHint x elemTy' - dTy <- case hoist b elemTy of - HoistSuccess elemTy' -> DictTy <$> dataDictType elemTy' - HoistFailure _ -> ignoreExcept <$> liftEnvReaderT do - withFreshBinder noHint finTy \b' -> do - elemTy' <- applyRename (b@>binderName b') elemTy - dTy <- DictTy <$> dataDictType elemTy' - return $ Pi $ CorePiType ImplicitApp [Inferred Nothing Unify] (UnaryNest b') (EffTy Pure dTy) - liftM Var $ emitHinted hint $ TabCon (dataDictHole dTy) tabTy xs' - --- Bool flag is just to tweak the reported error message -fromTabPiType :: EmitsBoth o => Bool -> CType o -> InfererM i o (TabPiType CoreIR o) -fromTabPiType _ (TabPi piTy) = return piTy -fromTabPiType expectPi ty = do - a <- freshType - b <- freshType - a' <- asIxType a - let piTy = nonDepTabPiType a' b - if expectPi then constrainTypesEq (TabPi piTy) ty - else constrainTypesEq ty (TabPi piTy) - return piTy - -fromProdType :: EmitsBoth o => Int -> CType o -> InfererM i o [CType o] -fromProdType n (ProdTy ts) | length ts == n = return ts -fromProdType n ty = do - ts <- mapM (const $ freshType) (replicate n ()) - constrainTypesEq (ProdTy ts) ty - return ts - -fromDepPairType :: EmitsBoth o => CType o -> InfererM i o (DepPairType CoreIR o) -fromDepPairType (DepPairTy t) = return t -fromDepPairType ty = throw TypeErr $ "Expected a dependent pair, but got: " ++ pprint ty - -addEffects :: EmitsBoth o => EffectRow CoreIR o -> InfererM i o () + elemTy <- case xs of + [] -> throw TypeErr "Can't infer type of empty table" + x:_ -> getType <$> bottomUp x + ixTy <- asIxType finTy + let tabTy = ixTy ==> elemTy + xs' <- forM xs \x -> topDown elemTy x + dTy <- DictTy <$> dataDictType elemTy + dataDict <- trySynthTerm dTy Full + emitExpr $ TabCon (Just $ WhenIRE dataDict) tabTy xs' + +checkTabCon :: forall i o. Emits o => TabPiType CoreIR o -> [UExpr i] -> InfererM i o (CAtom o) +checkTabCon tabTy@(TabPiType _ b elemTy) xs = do + let n = fromIntegral (length xs) :: Word32 + let finTy = FinConst n + expectEq (binderType b) finTy + xs' <- forM (enumerate xs) \(i, x) -> do + let i' = NewtypeCon (FinCon (NatVal n)) (NatVal $ fromIntegral i) :: CAtom o + elemTy' <- applySubst (b@>SubstVal i') elemTy + topDown elemTy' x + dTy <- case hoist b elemTy of + HoistSuccess elemTy' -> DictTy <$> dataDictType elemTy' + HoistFailure _ -> ignoreExcept <$> liftEnvReaderT do + withFreshBinder noHint finTy \b' -> do + elemTy' <- applyRename (b@>binderName b') elemTy + dTy <- DictTy <$> dataDictType elemTy' + return $ Pi $ CorePiType ImplicitApp [Inferred Nothing Unify] (UnaryNest b') (EffTy Pure dTy) + dataDict <- trySynthTerm dTy Full + emitExpr $ TabCon (Just $ WhenIRE dataDict) (TabPi tabTy) xs' + +addEffects :: EffectRow CoreIR o -> InfererM i o () +addEffects Pure = return () addEffects eff = do - allowed <- checkAllowedUnconditionally eff - unless allowed $ do - effsAllowed <- getAllowedEffects - eff' <- openEffectRow eff - constrainEq (Eff effsAllowed) (Eff eff') - -checkAllowedUnconditionally :: EffectRow CoreIR o -> InfererM i o Bool -checkAllowedUnconditionally Pure = return True -checkAllowedUnconditionally eff = do - eff' <- zonk eff - effAllowed <- getAllowedEffects >>= zonk - return $ case checkExtends effAllowed eff' of - Failure _ -> False - Success () -> True - -openEffectRow :: EmitsBoth o => EffectRow CoreIR o -> InfererM i o (EffectRow CoreIR o) -openEffectRow (EffectRow effs NoTail) = extendEffRow effs <$> freshEff -openEffectRow effRow = return effRow + effsAllowed <- infEffects <$> getInfState + case checkExtends effsAllowed eff of + Success () -> return () + Failure _ -> expectEq (Eff effsAllowed) (Eff eff) + +getIxDict :: CType o -> InfererM i o (IxDict CoreIR o) +getIxDict t = do + dictTy <- DictTy <$> ixDictType t + IxDictAtom <$> trySynthTerm dictTy Full asIxType :: CType o -> InfererM i o (IxType CoreIR o) -asIxType ty = do - dictTy <- DictTy <$> ixDictType ty - ctx <- srcPosCtx <$> getErrCtx - return $ IxType ty $ IxDictAtom $ DictHole (AlwaysEqual ctx) dictTy Full -{-# SCC asIxType #-} +asIxType ty = IxType ty <$> getIxDict ty -- === Solver === -newtype SolverSubst n = SolverSubst (M.Map (CAtomName n) (CAtom n)) - -instance Pretty (SolverSubst n) where - pretty (SolverSubst m) = pretty $ M.toList m - -type SolverOutMap = InfOutMap - -data SolverOutFrag (n::S) (l::S) = - SolverOutFrag (SolverEmissions n l) (Constraints l) -newtype Constraints n = Constraints (SnocList (CAtomName n, CAtom n)) - deriving (Monoid, Semigroup) -type SolverEmissions = RNest (BinderP (AtomNameC CoreIR) SolverBinding) - -instance GenericE Constraints where - type RepE Constraints = ListE (CAtomName `PairE` CAtom) - fromE (Constraints xs) = ListE [PairE x y | (x,y) <- toList xs] - {-# INLINE fromE #-} - toE (ListE xs) = Constraints $ toSnocList $ [(x,y) | PairE x y <- xs] - {-# INLINE toE #-} - -instance SinkableE Constraints -instance RenameE Constraints -instance HoistableE Constraints -instance Pretty (Constraints n) where - pretty (Constraints xs) = pretty $ unsnoc xs - -instance GenericB SolverOutFrag where - type RepB SolverOutFrag = PairB SolverEmissions (LiftB Constraints) - fromB (SolverOutFrag em subst) = PairB em (LiftB subst) - toB (PairB em (LiftB subst)) = SolverOutFrag em subst - -instance ProvesExt SolverOutFrag -instance RenameB SolverOutFrag -instance BindsNames SolverOutFrag -instance SinkableB SolverOutFrag - -instance OutFrag SolverOutFrag where - emptyOutFrag = SolverOutFrag REmpty mempty - catOutFrags (SolverOutFrag em ss) (SolverOutFrag em' ss') = - withExtEvidence em' $ - SolverOutFrag (em >>> em') (sink ss <> ss') - -instance ExtOutMap InfOutMap SolverOutFrag where - extendOutMap infOutMap outFrag = - extendOutMap infOutMap $ liftSolverOutFrag outFrag - -type SolverM = InfererM VoidS - -liftSolverM :: EnvReader m => SolverM n a -> m n (Except a) -liftSolverM cont = do - env <- unsafeGetEnv - Distinct <- getDistinct - return do - maybeResult <- runSearcherM $ runInplaceT (initInfOutMap env) $ - runSubstReaderT (newSubst absurdNameFunction) $ runInfererM' cont - case maybeResult of - Nothing -> throw TypeErr "No solution" - Just (_, result) -> return result -{-# INLINE liftSolverM #-} - -newtype SolverEmission (n::S) (l::S) = SolverEmission (BinderP (AtomNameC CoreIR) SolverBinding n l) -instance ExtOutMap SolverOutMap SolverEmission where - extendOutMap env (SolverEmission e) = env `extendOutMap` toEnvFrag e -instance ExtOutFrag SolverOutFrag SolverEmission where - extendOutFrag (SolverOutFrag es substs) (SolverEmission e) = - withSubscopeDistinct e $ SolverOutFrag (RNest es e) (sink substs) - -freshInferenceName :: EmitsInf n => InfVarDesc -> Kind CoreIR n -> InfererM i n (CAtomVar n) -freshInferenceName desc k = do - ctx <- srcPosCtx <$> getErrCtx - emitSolver $ InfVarBound k (ctx, desc) -{-# INLINE freshInferenceName #-} - -freshSkolemName :: EmitsInf n => Kind CoreIR n -> InfererM i n (CAtomVar n) -freshSkolemName k = emitSolver $ SkolemBound k -{-# INLINE freshSkolemName #-} - -emptySolverSubst :: SolverSubst n -emptySolverSubst = SolverSubst mempty - -singleConstraint :: CAtomName n -> CAtom n -> Constraints n -singleConstraint v ty = Constraints $ toSnocList [(v, ty)] - -- TODO: put this pattern and friends in the Name library? Don't really want to -- have to think about `eqNameColorRep` just to implement a partial map. lookupSolverSubst :: forall c n. Color c => SolverSubst n -> Name c n -> AtomSubstVal c n @@ -2308,48 +1632,32 @@ lookupSolverSubst (SolverSubst m) name = Nothing -> Rename name Just (ColorsEqual :: ColorsEqual c (AtomNameC CoreIR))-> case M.lookup name m of Nothing -> Rename name - Just ty -> SubstVal ty - -applySolverSubstE :: (SubstE AtomSubstVal e, Distinct n) - => Env n -> SolverSubst n -> e n -> e n -applySolverSubstE env solverSubst@(SolverSubst m) e = - if M.null m then e else fmapNames env (lookupSolverSubst solverSubst) e - -zonkWithOutMap :: (SubstE AtomSubstVal e, Distinct n) - => InfOutMap n -> e n -> e n -zonkWithOutMap (InfOutMap bindings solverSubst _ _ _) e = - applySolverSubstE bindings solverSubst e - -liftSolverOutFrag :: Distinct l => SolverOutFrag n l -> InfOutFrag n l -liftSolverOutFrag (SolverOutFrag emissions subst) = - InfOutFrag (liftSolverEmissions emissions) mempty subst - -liftSolverEmissions :: Distinct l => SolverEmissions n l -> InfEmissions n l -liftSolverEmissions emissions = - fmapRNest (\(b:>emission) -> (b:>RightE emission)) emissions - -fmapRNest :: (forall ii ii'. b ii ii' -> b' ii ii') - -> RNest b i i' - -> RNest b' i i' -fmapRNest _ REmpty = REmpty -fmapRNest f (RNest rest b) = RNest (fmapRNest f rest) (f b) - -instance GenericE SolverSubst where - -- XXX: this is a bit sketchy because it's not actually bijective... - type RepE SolverSubst = ListE (PairE CAtomName CAtom) - fromE (SolverSubst m) = ListE $ map (uncurry PairE) $ M.toList m - {-# INLINE fromE #-} - toE (ListE pairs) = SolverSubst $ M.fromList $ map fromPairE pairs - {-# INLINE toE #-} - -instance SinkableE SolverSubst where -instance RenameE SolverSubst where -instance HoistableE SolverSubst - -constrainTypesEq :: EmitsInf o => CType o -> CType o -> InfererM i o () + Just sol -> SubstVal sol + +applyConstraint :: Constraint o -> SolverM i o () +applyConstraint = \case + TypeConstraint t1 t2 -> constrainTypesEq t1 t2 + EffectConstraint r1 r2' -> do + -- r1 shouldn't have inference variables. And we can't infer anything about + -- any inference variables in r2's explicit effects because we don't know + -- how they line up with r1's. So this is just about figuring out r2's tail. + r2 <- zonk r2' + let msg = "Allowed effects: " ++ pprint r1 ++ + "\nRequested effects: " ++ pprint r2 + case checkExtends r1 r2 of + Success () -> return () + Failure _ -> addContext msg $ searchFailureAsTypeErr do + EffectRow effs1 t1 <- return r1 + EffectRow effs2 (EffectRowTail v2) <- return r2 + guard =<< isUnificationName (atomVarName v2) + guard $ null (eSetToList $ effs2 `eSetDifference` effs1) + let extras1 = effs1 `eSetDifference` effs2 + extendSolution (atomVarName v2) (Eff $ EffectRow extras1 t1) + +constrainTypesEq :: CType o -> CType o -> SolverM i o () constrainTypesEq t1 t2 = constrainEq (Type t1) (Type t2) -- TODO: use a type class instead? -constrainEq :: EmitsInf o => CAtom o -> CAtom o -> InfererM i o () +constrainEq :: CAtom o -> CAtom o -> SolverM i o () constrainEq t1 t2 = do t1' <- zonk t1 t2' <- zonk t2 @@ -2361,25 +1669,23 @@ constrainEq t1 t2 = do ++ (case infVars of Empty -> "" _ -> "\n(Solving for: " ++ pprint (nestToList pprint infVars) ++ ")") - void $ addContext msg $ withSubst (newSubst absurdNameFunction) $ unify t1' t2' + void $ addContext msg $ unify t1' t2' -class (AlphaEqE e, SinkableE e, SubstE AtomSubstVal e) => Unifiable (e::E) where - unifyZonked :: EmitsInf n => e n -> e n -> SolverM n () +class (AlphaEqE e, Zonkable e) => Unifiable (e::E) where + unifyZonked :: e n -> e n -> SolverM i n () -tryConstrainEq :: EmitsInf o => CAtom o -> CAtom o -> InfererM i o () -tryConstrainEq t1 t2 = do - constrainEq t1 t2 `catchErr` \errs -> case errs of - Errs [Err TypeErr _ _] -> return () - _ -> throwErrs errs - -unify :: (EmitsInf n, Unifiable e) => e n -> e n -> SolverM n () +unify :: Unifiable e => e n -> e n -> SolverM i n () unify e1 e2 = do e1' <- zonk e1 e2' <- zonk e2 - (unifyZonked e1' e2' throw TypeErr "") + searchFailureAsTypeErr $ unifyZonked e1' e2' {-# INLINE unify #-} {-# SCC unify #-} +searchFailureAsTypeErr :: SolverM i n a -> SolverM i n a +searchFailureAsTypeErr cont = cont <|> throw TypeErr "" +{-# INLINE searchFailureAsTypeErr #-} + instance Unifiable (Atom CoreIR) where unifyZonked e1 e2 = confuseGHC >>= \_ -> case sameConstructor e1 e2 of False -> case (e1, e2) of @@ -2437,7 +1743,7 @@ instance Unifiable (EffectRow CoreIR) where <|> unifyZip x1 x2 where - unifyDirect :: EmitsInf n => EffectRow CoreIR n -> EffectRow CoreIR n -> SolverM n () + unifyDirect :: EffectRow CoreIR n -> EffectRow CoreIR n -> SolverM i n () unifyDirect r@(EffectRow effs' mv') (EffectRow effs (EffectRowTail v)) | null (eSetToList effs) = case mv' of EffectRowTail v' | v == v' -> guard $ null $ eSetToList effs' @@ -2445,17 +1751,27 @@ instance Unifiable (EffectRow CoreIR) where unifyDirect _ _ = empty {-# INLINE unifyDirect #-} - unifyZip :: EmitsInf n => EffectRow CoreIR n -> EffectRow CoreIR n -> SolverM n () + unifyZip :: EffectRow CoreIR n -> EffectRow CoreIR n -> SolverM i n () unifyZip r1 r2 = case (r1, r2) of (EffectRow effs1 t1, EffectRow effs2 t2) | not (eSetNull effs1 || eSetNull effs2) -> do let extras1 = effs1 `eSetDifference` effs2 let extras2 = effs2 `eSetDifference` effs1 - newRow <- freshEff - unify (EffectRow mempty t1) (extendEffRow extras2 newRow) - unify (extendEffRow extras1 newRow) (EffectRow mempty t2) + void $ withFreshEff \newRow -> do + unify (EffectRow mempty (sink t1)) (extendEffRow (sink extras2) newRow) + unify (extendEffRow (sink extras1) newRow) (EffectRow mempty (sink t2)) + return UnitE _ -> unifyEq r1 r2 -unifyEq :: AlphaEqE e => e n -> e n -> SolverM n () +withFreshEff + :: Zonkable e + => (forall o'. DExt o o' => EffectRow CoreIR o' -> SolverM i o' (e o')) + -> SolverM i o (e o) +withFreshEff cont = + withFreshUnificationVarNoEmits MiscInfVar EffKind \v -> do + cont $ EffectRow mempty $ EffectRowTail v +{-# INLINE withFreshEff #-} + +unifyEq :: AlphaEqE e => e n -> e n -> SolverM i n () unifyEq e1 e2 = guard =<< alphaEq e1 e2 {-# INLINE unifyEq #-} @@ -2466,33 +1782,48 @@ instance Unifiable CorePiType where unless (expls1 == expls2) empty go (Abs bsTop1 effTy1) (Abs bsTop2 effTy2) where - go :: EmitsInf n - => Abs (Nest CBinder) (EffTy CoreIR) n + go :: Abs (Nest CBinder) (EffTy CoreIR) n -> Abs (Nest CBinder) (EffTy CoreIR) n - -> SolverM n () + -> SolverM i n () go (Abs Empty (EffTy e1 t1)) (Abs Empty (EffTy e2 t2)) = unify t1 t2 >> unify e1 e2 go (Abs (Nest (b1:>t1) bs1) rest1) (Abs (Nest (b2:>t2) bs2) rest2) = do unify t1 t2 - v <- freshSkolemName t1 - ab1 <- zonk =<< applySubst (b1@>SubstVal (Var v)) (Abs bs1 rest1) - ab2 <- zonk =<< applySubst (b2@>SubstVal (Var v)) (Abs bs2 rest2) - go ab1 ab2 + void $ withFreshSkolemName t1 \v -> do + ab1 <- zonk =<< applyRename (b1@>atomVarName v) (Abs bs1 rest1) + ab2 <- zonk =<< applyRename (b2@>atomVarName v) (Abs bs2 rest2) + go ab1 ab2 + return UnitE go _ _ = empty -unifyTabPiType :: EmitsInf n => TabPiType CoreIR n -> TabPiType CoreIR n -> SolverM n () +unifyTabPiType :: TabPiType CoreIR n -> TabPiType CoreIR n -> SolverM i n () unifyTabPiType (TabPiType _ b1 ty1) (TabPiType _ b2 ty2) = do let ann1 = binderType b1 let ann2 = binderType b2 unify ann1 ann2 - v <- freshSkolemName ann1 - ty1' <- applySubst (b1@>SubstVal (Var v)) ty1 - ty2' <- applySubst (b2@>SubstVal (Var v)) ty2 - unify ty1' ty2' - -extendSolution :: CAtomName n -> CAtom n -> SolverM n () + void $ withFreshSkolemName ann1 \v -> do + ty1' <- applyRename (b1@>atomVarName v) ty1 + ty2' <- applyRename (b2@>atomVarName v) ty2 + unify ty1' ty2' + return UnitE + +withFreshSkolemName + :: Zonkable e => Kind CoreIR o + -> (forall o'. DExt o o' => CAtomVar o' -> SolverM i o' (e o')) + -> SolverM i o (e o) +withFreshSkolemName ty cont = diffStateT1 \s -> do + withFreshBinder "skol" (SkolemBound ty) \b -> do + (ans, diff) <- runDiffStateT1 (sink s) do + v <- toAtomVar $ binderName b + ans <- cont v >>= zonk + liftHoistExcept $ hoist b ans + diff' <- liftHoistExcept $ hoist b diff + return (ans, diff') +{-# INLINE withFreshSkolemName #-} + +extendSolution :: CAtomName n -> CAtom n -> SolverM i n () extendSolution v t = - isInferenceName v >>= \case + isUnificationName v >>= \case True -> do when (v `isFreeIn` t) $ throw TypeErr $ "Occurs check failure: " ++ pprint (v, t) -- When we unify under a pi binder we replace its occurrences with a @@ -2501,14 +1832,20 @@ extendSolution v t = -- of worms. forM_ (freeAtomVarsList t) \fv -> whenM (isSkolemName fv) $ throw TypeErr $ "Can't unify with skolem vars" - extendSolverSubst v t + addConstraint v t False -> empty -isInferenceName :: EnvReader m => CAtomName n -> m n Bool -isInferenceName v = lookupEnv v >>= \case +isUnificationName :: EnvReader m => CAtomName n -> m n Bool +isUnificationName v = lookupEnv v >>= \case AtomNameBinding (SolverBound (InfVarBound _ _)) -> return True _ -> return False -{-# INLINE isInferenceName #-} +{-# INLINE isUnificationName #-} + +isSolverName :: EnvReader m => CAtomName n -> m n Bool +isSolverName v = lookupEnv v >>= \case + AtomNameBinding (SolverBound _) -> return True + _ -> return False + isSkolemName :: EnvReader m => CAtomName n -> m n Bool isSkolemName v = lookupEnv v >>= \case @@ -2516,22 +1853,10 @@ isSkolemName v = lookupEnv v >>= \case _ -> return False {-# INLINE isSkolemName #-} -freshType :: EmitsInf n => InfererM i n (CType n) -freshType = TyVar <$> freshInferenceName MiscInfVar TyKind -{-# INLINE freshType #-} - -freshAtom :: EmitsInf n => Type CoreIR n -> InfererM i n (CAtom n) -freshAtom t = Var <$> freshInferenceName MiscInfVar t -{-# INLINE freshAtom #-} - -freshEff :: EmitsInf n => InfererM i n (EffectRow CoreIR n) -freshEff = EffectRow mempty . EffectRowTail <$> freshInferenceName MiscInfVar EffKind -{-# INLINE freshEff #-} - renameForPrinting :: (EnvReader m, HasNamesE e) => e n -> m n (Abs (Nest (AtomNameBinder CoreIR)) e n) renameForPrinting e = do - infVars <- filterM isInferenceVar $ freeAtomVarsList e + infVars <- filterM isSolverName $ freeAtomVarsList e let ab = abstractFreeVarsNoAnn infVars e let hints = take (length infVars) $ map getNameHint $ map (:[]) ['a'..'z'] ++ map show [(0::Int)..] @@ -2545,17 +1870,6 @@ renameForPrinting e = do -- === dictionary synthesis === -synthTopE :: (EnvReader m, Fallible1 m, DictSynthTraversable e) => e n -> m n (e n) -synthTopE block = do - (liftExcept =<<) $ liftDictSynthTraverserM $ dsTraverse block -{-# SCC synthTopE #-} - -synthTyConDef :: (EnvReader m, Fallible1 m) => TyConDef n -> m n (TyConDef n) -synthTyConDef (TyConDef sn roleExpls bs body) = (liftExcept =<<) $ liftDictSynthTraverserM do - dsTraverseExplBinders (snd <$> roleExpls) bs \bs' -> - TyConDef sn roleExpls bs' <$> dsTraverse body -{-# SCC synthTyConDef #-} - -- Given a simplified dict (an Atom of type `DictTy _` in the -- post-simplification IR), and a requested, more general, dict type, generalize -- the dict to match the more general type. This is only possible because we @@ -2563,133 +1877,90 @@ synthTyConDef (TyConDef sn roleExpls bs body) = (liftExcept =<<) $ liftDictSynth -- valid to implement `generalizeDict` by re-synthesizing the whole dictionary, -- but we know that the derivation tree has to be the same, so we take the -- shortcut of just generalizing the data parameters. -generalizeDict :: (EnvReader m) => CType n -> Dict n -> m n (Dict n) +generalizeDict :: EnvReader m => CType n -> Dict n -> m n (Dict n) generalizeDict ty dict = do - result <- liftSolverM $ solveLocal $ generalizeDictAndUnify (sink ty) (sink dict) + result <- liftEnvReaderT $ liftInfererM $ generalizeDictRec ty dict case result of Failure e -> error $ "Failed to generalize " ++ pprint dict ++ " to " ++ pprint ty ++ " because " ++ pprint e Success ans -> return ans -generalizeDictAndUnify :: EmitsInf n => CType n -> Dict n -> SolverM n (Dict n) -generalizeDictAndUnify ty dict = do - dict' <- generalizeDictRec dict - dictTy <- return $ getType dict' - unify ty dictTy - zonk dict' - -generalizeDictRec :: EmitsInf n => Dict n -> SolverM n (Dict n) -generalizeDictRec dict = do +generalizeDictRec :: CType n -> Dict n -> InfererM i n (Dict n) +generalizeDictRec targetTy dict = do -- TODO: we should be able to avoid the normalization here . We only need it -- because we sometimes end up with superclass projections. But they shouldn't -- really be allowed to occur in the post-simplification IR. DictCon _ dict' <- cheapNormalize dict - mkDictAtom =<< case dict' of + case dict' of InstanceDict instanceName args -> do InstanceDef _ roleExpls bs _ _ <- lookupInstanceDef instanceName - args' <- generalizeInstanceArgs roleExpls bs args - return $ InstanceDict instanceName args' - IxFin _ -> IxFin <$> Var <$> freshInferenceName MiscInfVar NatTy + liftSolverM $ generalizeInstanceArgs roleExpls bs args \args' -> do + d <- mkDictAtom $ InstanceDict (sink instanceName) args' + constrainEq (sink $ Type targetTy) (Type $ getType d) + return d + IxFin _ -> case targetTy of + DictTy (DictType "Ix" _ [Type (NewtypeTyCon (Fin n))]) -> mkDictAtom $ IxFin n + _ -> error $ "not an Ix(Fin _) dict: " ++ pprint targetTy InstantiatedGiven _ _ -> notSimplifiedDict SuperclassProj _ _ -> notSimplifiedDict - DataData ty -> DataData <$> TyVar <$> freshInferenceName MiscInfVar ty + DataData _ -> case targetTy of + DictTy (DictType "Data" _ [Type t]) -> mkDictAtom $ DataData t + _ -> error "not a data dict" where notSimplifiedDict = error $ "Not a simplified dict: " ++ pprint dict -generalizeInstanceArgs :: EmitsInf n => [RoleExpl] -> Nest CBinder n l -> [CAtom n] -> SolverM n [CAtom n] -generalizeInstanceArgs [] Empty [] = return [] -generalizeInstanceArgs ((role,_):expls) (Nest (b:>ty) bs) (arg:args) = do - arg' <- case role of - -- XXX: for `TypeParam` we can just emit a fresh inference name rather than - -- traversing the whole type like we do in `Generalize.hs`. The reason is - -- that it's valid to implement `generalizeDict` by synthesizing an entirely - -- fresh dictionary, and if we were to do that, we would infer this type - -- parameter exactly as we do here, using inference. - TypeParam -> Var <$> freshInferenceName MiscInfVar TyKind - DictParam -> generalizeDictAndUnify ty arg - DataParam -> Var <$> freshInferenceName MiscInfVar ty - Abs bs' UnitE <- applySubst (b@>SubstVal arg') (Abs bs UnitE) - args' <- generalizeInstanceArgs expls bs' args - return $ arg':args' -generalizeInstanceArgs _ _ _ = error "zip error" - -synthInstanceDefAndAddSynthCandidate - :: (Mut n, TopBuilder m, EnvReader m, Fallible1 m) => InstanceDef n -> m n (InstanceName n) -synthInstanceDefAndAddSynthCandidate def@(InstanceDef className expls bs params (InstanceBody superclasses _)) = do - let emptyDef = InstanceDef className expls bs params $ InstanceBody superclasses [] - instanceName <- emitInstanceDef emptyDef - addInstanceSynthCandidate className instanceName - synthInstanceDefRec instanceName def - return instanceName +generalizeInstanceArgs + :: Zonkable e => [RoleExpl] -> Nest CBinder o any -> [CAtom o] + -> (forall o'. DExt o o' => [CAtom o'] -> SolverM i o' (e o')) + -> SolverM i o (e o) +generalizeInstanceArgs [] Empty [] cont = withDistinct $ cont [] +generalizeInstanceArgs ((role,_):expls) (Nest (b:>ty) bs) (arg:args) cont = do + generalizeInstanceArg role ty arg \arg' -> do + Abs bs' UnitE <- applySubst (b@>SubstVal arg') (Abs bs UnitE) + generalizeInstanceArgs expls bs' (sink <$> args) \args' -> + cont $ sink arg' : args' +generalizeInstanceArgs _ _ _ _ = error "zip error" + +generalizeInstanceArg + :: Zonkable e => ParamRole -> CType o -> CAtom o + -> (forall o'. DExt o o' => CAtom o' -> SolverM i o' (e o')) + -> SolverM i o (e o) +generalizeInstanceArg role ty arg cont = case role of + -- XXX: for `TypeParam` we can just emit a fresh inference name rather than + -- traversing the whole type like we do in `Generalize.hs`. The reason is + -- that it's valid to implement `generalizeDict` by synthesizing an entirely + -- fresh dictionary, and if we were to do that, we would infer this type + -- parameter exactly as we do here, using inference. + TypeParam -> withFreshUnificationVarNoEmits MiscInfVar TyKind \v -> cont $ Var v + DictParam -> withFreshDictVarNoEmits ty (\ty' -> lift11 $ generalizeDictRec ty' (sink arg)) cont + DataParam -> withFreshUnificationVarNoEmits MiscInfVar ty \v -> cont $ Var v emitInstanceDef :: (Mut n, TopBuilder m) => InstanceDef n -> m n (Name InstanceNameC n) emitInstanceDef instanceDef@(InstanceDef className _ _ _ _) = do ty <- getInstanceType instanceDef emitBinding (getNameHint className) $ InstanceBinding instanceDef ty -type InstanceDefAbsBodyT = - ((ListE CAtom) `PairE` (ListE CAtom) `PairE` (ListE CAtom) `PairE` (ListE CAtom)) - -pattern InstanceDefAbsBody :: [CAtom n] -> [CAtom n] -> [CAtom n] -> [CAtom n] - -> InstanceDefAbsBodyT n -pattern InstanceDefAbsBody params superclasses doneMethods todoMethods = - ListE params `PairE` (ListE superclasses) `PairE` (ListE doneMethods) `PairE` (ListE todoMethods) - -type InstanceDefAbsT n = ([RoleExpl], Abs (Nest CBinder) InstanceDefAbsBodyT n) - -pattern InstanceDefAbs :: [RoleExpl] -> Nest CBinder h n -> [CAtom n] -> [CAtom n] -> [CAtom n] -> [CAtom n] - -> InstanceDefAbsT h -pattern InstanceDefAbs expls bs params superclasses doneMethods todoMethods = - (expls, Abs bs (InstanceDefAbsBody params superclasses doneMethods todoMethods)) - -synthInstanceDefRec - :: (Mut n, TopBuilder m, EnvReader m, Fallible1 m) => InstanceName n -> InstanceDef n -> m n () -synthInstanceDefRec instanceName def = do - InstanceDef className roleExplsTop bs params (InstanceBody superclasses methods) <- return def - let ab = InstanceDefAbs roleExplsTop bs params superclasses [] methods - recur ab className instanceName - where - recur :: (Mut n, TopBuilder m, EnvReader m, Fallible1 m) - => InstanceDefAbsT n -> ClassName n -> InstanceName n -> m n () - recur (InstanceDefAbs _ _ _ _ _ []) _ _ = return () - recur (roleExpls, ab) cname iname = do - (def', ab') <- liftExceptEnvReaderM $ refreshAbs ab - \bs' (InstanceDefAbsBody ps scs doneMethods (m:ms)) -> do - EnvReaderT $ ReaderT \(Distinct, env) -> do - let env' = extendSynthCandidatess (snd<$>roleExpls) bs' env - flip runReaderT (Distinct, env') $ runEnvReaderT' do - m' <- synthTopE m - let doneMethods' = doneMethods ++ [m'] - let ab' = InstanceDefAbs roleExpls bs' ps scs doneMethods' ms - let def' = InstanceDef cname roleExpls bs' ps $ InstanceBody scs doneMethods' - return (def', ab') - updateTopEnv $ UpdateInstanceDef iname def' - recur ab' cname iname - -synthInstanceDef - :: (EnvReader m, Fallible1 m) => InstanceDef n -> m n (InstanceDef n) -synthInstanceDef (InstanceDef className expls bs params body) = do - liftExceptEnvReaderM $ refreshAbs (Abs bs (ListE params `PairE` body)) - \bs' (ListE params' `PairE` InstanceBody superclasses methods) -> do - EnvReaderT $ ReaderT \(Distinct, env) -> do - let env' = extendSynthCandidatess (snd<$>expls) bs' env - flip runReaderT (Distinct, env') $ runEnvReaderT' do - methods' <- mapM synthTopE methods - return $ InstanceDef className expls bs' params' $ InstanceBody superclasses methods' - -- main entrypoint to dictionary synthesizer -trySynthTerm :: (Fallible1 m, EnvReader m) => CType n -> RequiredMethodAccess -> m n (SynthAtom n) +trySynthTerm :: CType n -> RequiredMethodAccess -> InfererM i n (SynthAtom n) trySynthTerm ty reqMethodAccess = do hasInferenceVars ty >>= \case - True -> throw TypeErr "Can't synthesize a dictionary for a type with inference vars" - False -> do + True -> throw TypeErr $ "Can't synthesize a dictionary for a type with inference vars: " ++ pprint ty + False -> withVoidSubst do synthTy <- liftExcept $ typeAsSynthType ty - solutions <- liftSyntherM $ synthTerm synthTy reqMethodAccess - case solutions of - [] -> throw TypeErr $ "Couldn't synthesize a class dictionary for: " ++ pprint ty - [d] -> cheapNormalize d -- normalize to reduce code size - _ -> throw TypeErr $ "Multiple candidate class dictionaries for: " ++ pprint ty + synthTerm synthTy reqMethodAccess + <|> throw TypeErr ("Couldn't synthesize a class dictionary for: " ++ pprint ty) {-# SCC trySynthTerm #-} +hasInferenceVars :: (EnvReader m, HoistableE e) => e n -> m n Bool +hasInferenceVars e = liftEnvReaderM $ anyInferenceVars $ freeAtomVarsList e +{-# INLINE hasInferenceVars #-} + +anyInferenceVars :: [CAtomName n] -> EnvReaderM n Bool +anyInferenceVars = \case + [] -> return False + (v:vs) -> isSolverName v >>= \case + True -> return True + False -> anyInferenceVars vs + type SynthAtom = CAtom type SynthPiType n = ([Explicitness], Abs (Nest CBinder) DictType n) data SynthType n = @@ -2699,38 +1970,13 @@ data SynthType n = data Givens n = Givens { fromGivens :: HM.HashMap (EKey SynthType n) (SynthAtom n) } -class (Alternative1 m, Searcher1 m, EnvReader m, EnvExtender m) - => Synther m where - getGivens :: m n (Givens n) - withGivens :: Givens n -> m n a -> m n a - -newtype SyntherM (n::S) (a:: *) = SyntherM - { runSyntherM' :: OutReaderT Givens (EnvReaderT []) n a } - deriving ( Functor, Applicative, Monad, EnvReader, EnvExtender - , ScopeReader, MonadFail - , Alternative, Searcher, OutReader Givens) - -instance Synther SyntherM where - getGivens = askOutReader - {-# INLINE getGivens #-} - withGivens givens cont = localOutReader givens cont - {-# INLINE withGivens #-} - -liftSyntherM :: EnvReader m => SyntherM n a -> m n [a] -liftSyntherM cont = - liftEnvReaderT do - initGivens <- givensFromEnv - runOutReaderT initGivens $ runSyntherM' cont -{-# INLINE liftSyntherM #-} - -givensFromEnv :: EnvReader m => m n (Givens n) -givensFromEnv = do - env <- withEnv moduleEnv - givens <- mapM toAtomVar $ lambdaDicts $ envSynthCandidates env - getSuperclassClosure (Givens HM.empty) (Var <$> givens) -{-# SCC givensFromEnv #-} +getGivens :: InfererM i o (Givens o) +getGivens = givens <$> getInfState + +withGivens :: Givens o -> InfererM i o a -> InfererM i o a +withGivens givens cont = withInfState (\s -> s { givens = givens }) cont -extendGivens :: Synther m => [SynthAtom n] -> m n a -> m n a +extendGivens :: [SynthAtom o] -> InfererM i o a -> InfererM i o a extendGivens newGivens cont = do prevGivens <- getGivens finalGivens <- getSuperclassClosure prevGivens newGivens @@ -2745,7 +1991,7 @@ typeAsSynthType :: CType n -> Except (SynthType n) typeAsSynthType = \case DictTy dictTy -> return $ SynthDictType dictTy Pi (CorePiType ImplicitApp expls bs (EffTy Pure (DictTy d))) -> return $ SynthPiType (expls, Abs bs d) - ty -> Failure $ Errs [Err TypeErr mempty $ "Can't synthesize terms of type: " ++ pprint ty] + ty -> Failure $ Err TypeErr mempty $ "Can't synthesize terms of type: " ++ pprint ty {-# SCC typeAsSynthType #-} getSuperclassClosure :: EnvReader m => Givens n -> [SynthAtom n] -> m n (Givens n) @@ -2755,8 +2001,7 @@ getSuperclassClosure givens newGivens = do return $ getSuperclassClosurePure env givens newGivens {-# INLINE getSuperclassClosure #-} -getSuperclassClosurePure - :: Distinct n => Env n -> Givens n -> [SynthAtom n] -> Givens n +getSuperclassClosurePure :: Distinct n => Env n -> Givens n -> [SynthAtom n] -> Givens n getSuperclassClosurePure env givens newGivens = snd $ runState (runEnvReaderT env (mapM_ visitGiven newGivens)) givens where @@ -2788,20 +2033,22 @@ getSuperclassClosurePure env givens newGivens = forM (enumerate superclasses) \(i, ty) -> do return $ DictCon ty $ SuperclassProj synthExpr i -synthTerm :: SynthType n -> RequiredMethodAccess -> SyntherM n (SynthAtom n) +synthTerm :: SynthType n -> RequiredMethodAccess -> InfererM i n (SynthAtom n) synthTerm targetTy reqMethodAccess = confuseGHC >>= \_ -> case targetTy of SynthPiType (expls, ab) -> do - ab' <- withGivenBinders expls ab \bs targetTy' -> do - Abs bs <$> synthTerm (SynthDictType targetTy') reqMethodAccess - Abs bs synthExpr <- return ab' - liftM Lam $ coreLamExpr ImplicitApp expls $ Abs bs $ PairE Pure (WithoutDecls synthExpr) + ab' <- withFreshBindersInf expls ab \bs' targetTy' -> do + Abs bs' <$> synthTerm (SynthDictType targetTy') reqMethodAccess + Abs bs' synthExpr <- return ab' + let piTy = CorePiType ImplicitApp expls bs' (EffTy Pure (getType synthExpr)) + let lamExpr = LamExpr bs' (WithoutDecls synthExpr) + return $ Lam $ CoreLamExpr piTy lamExpr SynthDictType dictTy -> case dictTy of DictType "Ix" _ [Type (NewtypeTyCon (Fin n))] -> return $ DictCon (DictTy dictTy) $ IxFin n DictType "Data" _ [Type t] -> do - void (synthDictForData dictTy synthDictFromGiven dictTy) + void (synthDictForData dictTy <|> synthDictFromGiven dictTy) return $ DictCon (DictTy dictTy) $ DataData t _ -> do - dict <- synthDictFromInstance dictTy synthDictFromGiven dictTy + dict <- synthDictFromInstance dictTy <|> synthDictFromGiven dictTy case dict of DictCon _ (InstanceDict instanceName _) -> do isReqMethodAccessAllowed <- reqMethodAccess `isMethodAccessAllowedBy` instanceName @@ -2811,40 +2058,6 @@ synthTerm targetTy reqMethodAccess = confuseGHC >>= \_ -> case targetTy of _ -> return dict {-# SCC synthTerm #-} -coreLamExpr :: EnvReader m => AppExplicitness - -> [Explicitness] -> Abs (Nest CBinder) (PairE (EffectRow CoreIR) CBlock) n - -> m n (CoreLamExpr n) -coreLamExpr appExpl expls ab = liftEnvReaderM do - refreshAbs ab \bs' (PairE effs' body') -> do - EffTy _ resultTy <- blockEffTy body' - return $ CoreLamExpr (CorePiType appExpl expls bs' (EffTy effs' resultTy)) (LamExpr bs' body') - -withGivenBinders - :: HasNamesE e => [Explicitness] -> Abs (Nest CBinder) e n - -> (forall l. DExt n l => Nest CBinder n l -> e l -> SyntherM l a) - -> SyntherM n a -withGivenBinders explsTop (Abs bsTop e) contTop = - runSubstReaderT idSubst $ go explsTop bsTop \bsTop' -> do - e' <- renameM e - liftSubstReaderT $ contTop bsTop' e' - where - go :: [Explicitness] -> Nest CBinder i i' - -> (forall o'. DExt o o' => Nest CBinder o o' -> SubstReaderT Name SyntherM i' o' a) - -> SubstReaderT Name SyntherM i o a - go expls bs cont = case (expls, bs) of - ([], Empty) -> getDistinct >>= \Distinct -> cont Empty - (expl:explsRest, Nest b rest) -> do - argTy <- renameM $ binderType b - withFreshBinder (getNameHint b) argTy \b' -> do - givens <- case expl of - Inferred _ (Synth _) -> return [Var $ binderVar b'] - _ -> return [] - s <- getSubst - liftSubstReaderT $ extendGivens givens $ - runSubstReaderT (s <>> b@>binderName b') $ - go explsRest rest \rest' -> cont (Nest b' rest') - _ -> error "zip error" - isMethodAccessAllowedBy :: EnvReader m => RequiredMethodAccess -> InstanceName n -> m n Bool isMethodAccessAllowedBy access instanceName = do InstanceDef className _ _ _ (InstanceBody _ methods) <- lookupInstanceDef instanceName @@ -2855,7 +2068,7 @@ isMethodAccessAllowedBy access instanceName = do Full -> return $ numClassMethods == numInstanceMethods Partial numReqMethods -> return $ numReqMethods <= numInstanceMethods -synthDictFromGiven :: DictType n -> SyntherM n (SynthAtom n) +synthDictFromGiven :: DictType n -> InfererM i n (SynthAtom n) synthDictFromGiven targetTy = do givens <- ((HM.elems . fromGivens) <$> getGivens) asum $ givens <&> \given -> do @@ -2867,41 +2080,37 @@ synthDictFromGiven targetTy = do args <- instantiateSynthArgs targetTy givenPiTy return $ DictCon (DictTy targetTy) $ InstantiatedGiven given args -synthDictFromInstance :: DictType n -> SyntherM n (SynthAtom n) +synthDictFromInstance :: DictType n -> InfererM i n (SynthAtom n) synthDictFromInstance targetTy@(DictType _ targetClass _) = do instances <- getInstanceDicts targetClass - asum $ instances <&> \candidate -> do + asum $ instances <&> \candidate -> typeErrAsSearchFailure do CorePiType _ expls bs (EffTy _ (DictTy candidateTy)) <- lookupInstanceTy candidate args <- instantiateSynthArgs targetTy (expls, Abs bs candidateTy) return $ DictCon (DictTy targetTy) $ InstanceDict candidate args -instantiateSynthArgs :: DictType n -> SynthPiType n -> SyntherM n [CAtom n] -instantiateSynthArgs targetTop (explsTop, Abs bsTop resultTyTop) = do - ListE args <- (liftExceptAlt =<<) $ liftSolverM $ solveLocal do - args <- runSubstReaderT idSubst $ go (sink targetTop) explsTop (sink $ Abs bsTop resultTyTop) - zonk $ ListE args - forM args \case - DictHole _ argTy req -> liftExceptAlt (typeAsSynthType argTy) >>= flip synthTerm req - arg -> return arg - where - go :: EmitsInf o - => DictType o -> [Explicitness] -> Abs (Nest CBinder) DictType i - -> SubstReaderT AtomSubstVal SolverM i o [CAtom o] - go target allExpls (Abs bs proposed) = case (allExpls, bs) of - ([], Empty) -> do - proposed' <- substM proposed - liftSubstReaderT $ unify target proposed' - return [] - (expl:expls, Nest b rest) -> do - argTy <- substM $ binderType b - arg <- liftSubstReaderT case expl of - Explicit -> error "instances shouldn't have explicit args" - Inferred _ Unify -> Var <$> freshInferenceName MiscInfVar argTy - Inferred _ (Synth req) -> return $ DictHole (AlwaysEqual emptySrcPosCtx) argTy req - liftM (arg:) $ extendSubst (b@>SubstVal arg) $ go target expls (Abs rest proposed) - _ -> error "zip error" - -synthDictForData :: forall n. DictType n -> SyntherM n (SynthAtom n) +getInstanceDicts :: EnvReader m => ClassName n -> m n [InstanceName n] +getInstanceDicts name = do + env <- withEnv moduleEnv + return $ M.findWithDefault [] name $ instanceDicts $ envSynthCandidates env +{-# INLINE getInstanceDicts #-} + +instantiateSynthArgs :: DictType n -> SynthPiType n -> InfererM i n [CAtom n] +instantiateSynthArgs target (expls, synthPiTy) = do + liftM fromListE $ withReducibleEmissions "dict args" do + bsConstrained <- buildConstraints (sink synthPiTy) \_ resultTy -> do + return [TypeConstraint (DictTy $ sink target) (DictTy resultTy)] + ListE <$> inferMixedArgs "dict" expls bsConstrained emptyMixedArgs + +emptyMixedArgs :: MixedArgs (CAtom n) +emptyMixedArgs = ([], []) + +typeErrAsSearchFailure :: InfererM i n a -> InfererM i n a +typeErrAsSearchFailure cont = cont `catchErr` \err@(Err errTy _ _) -> do + case errTy of + TypeErr -> empty + _ -> throwErr err + +synthDictForData :: forall i n. DictType n -> InfererM i n (SynthAtom n) synthDictForData dictTy@(DictType "Data" dName [Type ty]) = case ty of -- TODO Deduplicate vs CheckType.checkDataLike -- The "Var" case is different @@ -2922,10 +2131,12 @@ synthDictForData dictTy@(DictType "Data" dName [Type ty]) = case ty of _ -> notData where recur ty' = synthDictForData $ DictType "Data" dName [Type ty'] - recurBinder :: (RenameB b, BindsEnv b) => Abs b CType n -> SyntherM n (SynthAtom n) - recurBinder bAbs = refreshAbs bAbs \b' ty'' -> do - ans <- synthDictForData $ DictType "Data" (sink dName) [Type ty''] - return $ ignoreHoistFailure $ hoist b' ans + recurBinder :: Abs CBinder CType n -> InfererM i n (SynthAtom n) + recurBinder (Abs b body) = + withFreshBinderInf noHint Explicit (binderType b) \b' -> do + body' <- applyRename (b@>binderName b') body + ans <- synthDictForData $ DictType "Data" (sink dName) [Type body'] + return $ ignoreHoistFailure $ hoist b' ans notData = empty success = return $ DictCon (DictTy dictTy) $ DataData ty synthDictForData dictTy = error $ "Malformed Data dictTy " ++ pprint dictTy @@ -2939,183 +2150,22 @@ instance GenericE Givens where instance SinkableE Givens where --- === Dictionary synthesis traversal === - -liftDictSynthTraverserM - :: EnvReader m - => DictSynthTraverserM n n a - -> m n (Except a) -liftDictSynthTraverserM m = do - (ans, LiftE errs) <- liftM runHardFail $ liftBuilderT $ - runStateT1 (runSubstReaderT idSubst $ runDictSynthTraverserM m) (LiftE $ Errs []) - return $ case errs of - Errs [] -> Success ans - _ -> Failure errs - -newtype DictSynthTraverserM i o a = - DictSynthTraverserM - { runDictSynthTraverserM :: - SubstReaderT Name (StateT1 (LiftE Errs) (BuilderM CoreIR)) i o a} - deriving (MonadFail, Fallible, Functor, Applicative, Monad, ScopeReader, - EnvReader, EnvExtender, Builder CoreIR, SubstReader Name, - ScopableBuilder CoreIR, MonadState (LiftE Errs o)) - -instance NonAtomRenamer (DictSynthTraverserM i o) i o where renameN = renameM -instance Visitor (DictSynthTraverserM i o) CoreIR i o where - visitType = dsTraverse - visitAtom = dsTraverse - visitPi = visitPiDefault - visitLam = visitLamNoEmits -instance ExprVisitorNoEmits (DictSynthTraverserM i o) CoreIR i o where - visitExprNoEmits = visitGeneric - -class DictSynthTraversable (e::E) where - dsTraverse :: e i -> DictSynthTraverserM i o (e o) - -instance DictSynthTraversable (TopLam CoreIR) where - dsTraverse (TopLam d ty lam) = TopLam d <$> visitPiDefault ty <*> visitLamNoEmits lam - -instance DictSynthTraversable CAtom where - dsTraverse atom = case atom of - DictHole (AlwaysEqual ctx) ty access -> do - ty' <- cheapNormalize =<< dsTraverse ty - ans <- liftEnvReaderT $ addSrcContext ctx $ trySynthTerm ty' access - case ans of - Failure errs -> put (LiftE errs) >> renameM atom - Success d -> return d - Lam (CoreLamExpr piTy@(CorePiType _ expls _ _) (LamExpr bsLam (Abs decls result))) -> do - Pi piTy' <- dsTraverse $ Pi piTy - lam' <- dsTraverseExplBinders expls bsLam \bsLam' -> do - visitDeclsNoEmits decls \decls' -> do - LamExpr bsLam' <$> Abs decls' <$> dsTraverse result - return $ Lam $ CoreLamExpr piTy' lam' - Var _ -> renameM atom - SimpInCore _ -> renameM atom - ProjectElt _ _ _ -> renameM atom - _ -> visitAtomPartial atom - -instance DictSynthTraversable CType where - dsTraverse ty = case ty of - Pi (CorePiType appExpl expls bs (EffTy effs resultTy)) -> Pi <$> - dsTraverseExplBinders expls bs \bs' -> do - CorePiType appExpl expls bs' <$> (EffTy <$> renameM effs <*> dsTraverse resultTy) - TyVar _ -> renameM ty - ProjectEltTy _ _ _ -> renameM ty - _ -> visitTypePartial ty - -instance DictSynthTraversable DataConDefs where dsTraverse = visitGeneric - -dsTraverseExplBinders - :: [Explicitness] -> Nest CBinder i i' - -> (forall o'. DExt o o' => Nest CBinder o o' -> DictSynthTraverserM i' o' a) - -> DictSynthTraverserM i o a -dsTraverseExplBinders [] Empty cont = getDistinct >>= \Distinct -> cont Empty -dsTraverseExplBinders (expl:expls) (Nest b bs) cont = do - ty <- dsTraverse $ binderType b - withFreshBinder (getNameHint b) ty \b' -> do - let v = binderName b' - extendSynthCandidatesDict expl v $ extendRenamer (b@>v) do - dsTraverseExplBinders expls bs \bs' -> cont $ Nest b' bs' -dsTraverseExplBinders _ _ _ = error "zip error" - -extendSynthCandidatesDict :: Explicitness -> CAtomName n -> DictSynthTraverserM i n a -> DictSynthTraverserM i n a -extendSynthCandidatesDict c v cont = DictSynthTraverserM do - SubstReaderT $ ReaderT \env -> StateT1 \s -> BuilderT do - extendInplaceTLocal (extendSynthCandidates c v) $ runBuilderT' $ - runStateT1 (runSubstReaderT env $ runDictSynthTraverserM $ cont) s -{-# INLINE extendSynthCandidatesDict #-} - -- === Inference-specific builder patterns === --- The higher-order functions in Builder, like `buildLam` can't be easily used --- in inference because they don't allow for the emission of inference --- variables, which must be handled each time we leave a scope. In an earlier --- version we tried to put this logic in the implementation of InfererM's --- instance of Builder, but it forced us to overfit the Builder API to satisfy --- the needs of inference, like adding `SubstE AtomSubstVal e` constraints in --- various places. - type WithExpl = WithAttrB Explicitness type WithRoleExpl = WithAttrB RoleExpl -buildBlockInf - :: EmitsInf n - => (forall l. (EmitsBoth l, DExt n l) => InfererM i l (CAtom l)) - -> InfererM i n (CBlock n) -buildBlockInf cont = do - Abs decls (PairE result ty) <- buildDeclsInf do - ans <- cont - ty <- cheapNormalize $ getType ans - return $ PairE ans ty - let msg = "Block:" <> nest 1 (prettyBlock decls result) <> line - <> group ("Of type:" <> nest 2 (line <> pretty ty)) <> line - void $ liftHoistExcept' (docAsStr msg) $ hoist decls ty - return $ Abs decls result -{-# INLINE buildBlockInf #-} - buildBlockInfWithRecon - :: (EmitsInf n, HasNamesE e) - => (forall l. (EmitsBoth l, DExt n l) => InfererM i l (e l)) + :: HasNamesE e + => (forall l. (Emits l, DExt n l) => InfererM i l (e l)) -> InfererM i n (PairE CBlock (ReconAbs CoreIR e) n) buildBlockInfWithRecon cont = do - ab <- buildDeclsInfUnzonked cont - (block, recon) <- refreshAbs ab \decls result -> do + ab <- buildScoped cont + liftEnvReaderM $ liftM toPairE $ refreshAbs ab \decls result -> do (newResult, recon) <- telescopicCapture decls result return (Abs decls newResult, recon) - return $ PairE block recon {-# INLINE buildBlockInfWithRecon #-} -buildTabPiInf - :: EmitsInf n - => NameHint -> IxType CoreIR n - -> (forall l. (EmitsInf l, Ext n l) => CAtomVar l -> InfererM i l (CType l)) - -> InfererM i n (TabPiType CoreIR n) -buildTabPiInf hint (IxType t d) body = do - Abs b resultTy <- buildAbsInf hint Explicit t \v -> withoutEffects $ body v - return $ TabPiType d b resultTy - -buildDepPairTyInf - :: EmitsInf n - => NameHint -> DepPairExplicitness -> CType n - -> (forall l. (EmitsInf l, Ext n l) => CAtomVar l -> InfererM i l (CType l)) - -> InfererM i n (DepPairType CoreIR n) -buildDepPairTyInf hint expl ty body = do - Abs b resultTy <- buildAbsInf hint Explicit ty body - return $ DepPairType expl b resultTy - -buildAltInf - :: EmitsInf n - => CType n - -> (forall l. (EmitsBoth l, Ext n l) => CAtomVar l -> InfererM i l (CAtom l)) - -> InfererM i n (Alt CoreIR n) -buildAltInf ty body = do - buildAbsInf noHint Explicit ty \v -> - buildBlockInf do - Distinct <- getDistinct - body $ sink v - --- === EmitsInf predicate === - -type EmitsBoth n = (EmitsInf n, Emits n) - -class Mut n => EmitsInf (n::S) -data EmitsInfEvidence (n::S) where - EmitsInf :: EmitsInf n => EmitsInfEvidence n -instance EmitsInf UnsafeS - -fabricateEmitsInfEvidence :: forall n. EmitsInfEvidence n -fabricateEmitsInfEvidence = withFabricatedEmitsInf @n EmitsInf - -fabricateEmitsInfEvidenceM :: forall m n. Monad1 m => m n (EmitsInfEvidence n) -fabricateEmitsInfEvidenceM = return fabricateEmitsInfEvidence - -withFabricatedEmitsInf :: forall n a. (EmitsInf n => a) -> a -withFabricatedEmitsInf cont = fromWrapWithEmitsInf - ( TrulyUnsafe.unsafeCoerce ( WrapWithEmitsInf cont :: WrapWithEmitsInf n a - ) :: WrapWithEmitsInf UnsafeS a) -newtype WrapWithEmitsInf n r = - WrapWithEmitsInf { fromWrapWithEmitsInf :: EmitsInf n => r } - -- === IFunType === asFFIFunType :: EnvReader m => CType n -> m n (Maybe (IFunType, CorePiType n)) @@ -3155,6 +2205,49 @@ checkScalarOrPairType ty = throw TypeErr $ pprint ty -- === instances === +instance DiffStateE SolverSubst SolverDiff where + updateDiffStateE :: forall n. Distinct n => Env n -> SolverSubst n -> SolverDiff n -> SolverSubst n + updateDiffStateE _ initState (SolverDiff (RListE diffs)) = foldl update initState (unsnoc diffs) + where + update :: Distinct n => SolverSubst n -> Solution n -> SolverSubst n + update (SolverSubst subst) (PairE v x) = SolverSubst $ M.insert v x subst + +instance SinkableE InfState where sinkingProofE _ = todoSinkableProof + +instance GenericE SigmaAtom where + type RepE SigmaAtom = EitherE3 (LiftE (Maybe SourceName) `PairE` CAtom) + (LiftE SourceName `PairE` CType `PairE` UVar) + (CType `PairE` CAtom `PairE` ListE CAtom) + fromE = \case + SigmaAtom x y -> Case0 $ LiftE x `PairE` y + SigmaUVar x y z -> Case1 $ LiftE x `PairE` y `PairE` z + SigmaPartialApp x y z -> Case2 $ x `PairE` y `PairE` ListE z + {-# INLINE fromE #-} + + toE = \case + Case0 (LiftE x `PairE` y) -> SigmaAtom x y + Case1 (LiftE x `PairE` y `PairE` z) -> SigmaUVar x y z + Case2 (x `PairE` y `PairE` ListE z) -> SigmaPartialApp x y z + _ -> error "impossible" + {-# INLINE toE #-} + +instance RenameE SigmaAtom +instance HoistableE SigmaAtom +instance SinkableE SigmaAtom + +instance SubstE AtomSubstVal SigmaAtom where + substE env (SigmaAtom sn x) = SigmaAtom sn $ substE env x + substE env (SigmaUVar sn ty uvar) = case uvar of + UAtomVar v -> substE env $ SigmaAtom (Just sn) $ Var (AtomVar v ty) + UTyConVar v -> SigmaUVar sn ty' $ UTyConVar $ substE env v + UDataConVar v -> SigmaUVar sn ty' $ UDataConVar $ substE env v + UPunVar v -> SigmaUVar sn ty' $ UPunVar $ substE env v + UClassVar v -> SigmaUVar sn ty' $ UClassVar $ substE env v + UMethodVar v -> SigmaUVar sn ty' $ UMethodVar $ substE env v + where ty' = substE env ty + substE env (SigmaPartialApp ty f xs) = + SigmaPartialApp (substE env ty) (substE env f) (map (substE env) xs) + instance PrettyE e => Pretty (UDeclInferenceResult e l) where pretty = \case UDeclResultDone e -> pretty e @@ -3172,34 +2265,6 @@ instance (RenameE e, CheckableE CoreIR e) => CheckableE CoreIR (UDeclInferenceRe UDeclResultBindPattern hint block recon -> UDeclResultBindPattern hint <$> checkE block <*> renameM recon -- TODO: check recon -instance HasType CoreIR InfEmission where - getType = \case - LeftE (DeclBinding _ e) -> getType e - RightE b -> case b of - InfVarBound t _ -> t - SkolemBound t -> t - -instance (Monad m, ExtOutMap InfOutMap decls, OutFrag decls) - => EnvReader (InplaceT InfOutMap decls m) where - unsafeGetEnv = do - InfOutMap env _ _ _ _ <- getOutMapInplaceT - return env - -instance (Monad m, ExtOutMap InfOutMap decls, OutFrag decls) - => EnvExtender (InplaceT InfOutMap decls m) where - refreshAbs ab cont = UnsafeMakeInplaceT \env decls -> - refreshAbsPure (toScope env) ab \_ b e -> do - let subenv = extendOutMap env $ toEnvFrag b - (ans, d, _) <- unsafeRunInplaceT (cont b e) subenv emptyOutFrag - case fabricateDistinctEvidence @UnsafeS of - Distinct -> do - let env' = extendOutMap (unsafeCoerceE env) d - return (ans, catOutFrags decls d, env') - {-# INLINE refreshAbs #-} - -instance BindsEnv InfOutFrag where - toEnvFrag (InfOutFrag frag _ _) = toEnvFrag frag - instance GenericE SynthType where type RepE SynthType = EitherE2 DictType (PairE (LiftE [Explicitness]) (Abs (Nest CBinder) DictType)) fromE (SynthDictType d) = Case0 d @@ -3215,6 +2280,74 @@ instance HoistableE SynthType instance RenameE SynthType instance SubstE AtomSubstVal SynthType +instance GenericE Constraint where + type RepE Constraint = EitherE + (PairE CType CType) + (PairE (EffectRow CoreIR) (EffectRow CoreIR)) + fromE (TypeConstraint t1 t2) = LeftE (PairE t1 t2) + fromE (EffectConstraint e1 e2) = RightE (PairE e1 e2) + {-# INLINE fromE #-} + toE (LeftE (PairE t1 t2)) = TypeConstraint t1 t2 + toE (RightE (PairE e1 e2)) = EffectConstraint e1 e2 + {-# INLINE toE #-} + +instance SinkableE Constraint +instance HoistableE Constraint +instance (SubstE AtomSubstVal) Constraint + +instance GenericE RequiredTy where + type RepE RequiredTy = EitherE CType UnitE + fromE (Check ty) = LeftE ty + fromE Infer = RightE UnitE + {-# INLINE fromE #-} + toE (LeftE ty) = Check ty + toE (RightE UnitE) = Infer + {-# INLINE toE #-} + +instance SinkableE RequiredTy +instance HoistableE RequiredTy +instance AlphaEqE RequiredTy +instance RenameE RequiredTy + +instance GenericE PartialType where + type RepE PartialType = EitherE PartialPiType CType + fromE (PartialType ty) = LeftE ty + fromE (FullType ty) = RightE ty + {-# INLINE fromE #-} + toE (LeftE ty) = PartialType ty + toE (RightE ty) = FullType ty + {-# INLINE toE #-} + +instance SinkableE PartialType +instance HoistableE PartialType +instance AlphaEqE PartialType +instance RenameE PartialType + +instance GenericE SolverSubst where + -- XXX: this is a bit sketchy because it's not actually bijective... + type RepE SolverSubst = ListE (PairE CAtomName CAtom) + fromE (SolverSubst m) = ListE $ map (uncurry PairE) $ M.toList m + {-# INLINE fromE #-} + toE (ListE pairs) = SolverSubst $ M.fromList $ map fromPairE pairs + {-# INLINE toE #-} + +instance SinkableE SolverSubst where +instance RenameE SolverSubst where +instance HoistableE SolverSubst + +instance GenericE PartialPiType where + type RepE PartialPiType = LiftE (AppExplicitness, [Explicitness]) `PairE` Abs (Nest CBinder) + (EffectRow CoreIR `PairE` RequiredTy) + fromE (PartialPiType ex exs b eff ty) = LiftE (ex, exs) `PairE` Abs b (PairE eff ty) + {-# INLINE fromE #-} + toE (LiftE (ex, exs) `PairE` Abs b (PairE eff ty)) = PartialPiType ex exs b eff ty + {-# INLINE toE #-} + +instance SinkableE PartialPiType +instance HoistableE PartialPiType +instance AlphaEqE PartialPiType +instance RenameE PartialPiType + -- See Note [Confuse GHC] from Simplify.hs confuseGHC :: EnvReader m => m n (DistinctEvidence n) confuseGHC = getDistinct diff --git a/src/lib/Inference.hs-boot b/src/lib/Inference.hs-boot deleted file mode 100644 index a8f219389..000000000 --- a/src/lib/Inference.hs-boot +++ /dev/null @@ -1,14 +0,0 @@ --- Copyright 2021 Google LLC --- --- Use of this source code is governed by a BSD-style --- license that can be found in the LICENSE file or at --- https://developers.google.com/open-source/licenses/bsd - -module Inference (trySynthTerm) where - -import Core -import Name -import Types.Core -import Types.Primitives (RequiredMethodAccess) - -trySynthTerm :: (Fallible1 m, EnvReader m) => CType n -> RequiredMethodAccess -> m n (CAtom n) diff --git a/src/lib/MTL1.hs b/src/lib/MTL1.hs index 56fb1cdba..d6c6f8a9d 100644 --- a/src/lib/MTL1.hs +++ b/src/lib/MTL1.hs @@ -6,16 +6,7 @@ {-# LANGUAGE UndecidableInstances #-} -module MTL1 ( - MonadTrans11 (..), HoistableState (..), - WriterT1, pattern WriterT1, runWriterT1, runWriterT1From, - StateT1, pattern StateT1, runStateT1, evalStateT1, MonadState1, - MaybeT1 (..), runMaybeT1, ReaderT1 (..), runReaderT1, - ScopedT1, pattern ScopedT1, runScopedT1, - FallibleT1, runFallibleT1, - runStreamWriterT1, StreamWriter (..), StreamWriterT1 (..), - runStreamReaderT1, StreamReader (..), StreamReaderT1 (..), - ) where +module MTL1 where import Control.Monad.Reader import Control.Monad.Writer.Class @@ -27,6 +18,7 @@ import Data.Foldable (toList) import Name import Err +import Types.Core (Env) import Core (EnvReader (..), EnvExtender (..)) import Util (SnocList (..), snoc, emptySnocList) @@ -117,6 +109,14 @@ deriving instance MonadWriter s (m n) => MonadWriter s (ReaderT1 r m n) deriving instance MonadState s (m n) => MonadState s (ReaderT1 r m n) +instance (Monad1 m, Alternative1 m) => Alternative ((ReaderT1 r m) n) where + empty = lift11 empty + {-# INLINE empty #-} + ReaderT1 (ReaderT m1) <|> ReaderT1 (ReaderT m2) = + ReaderT1 $ ReaderT \r -> m1 r <|> m2 r + {-# INLINE (<|>) #-} + + instance (SinkableE r, EnvReader m) => EnvReader (ReaderT1 r m) where unsafeGetEnv = lift11 unsafeGetEnv {-# INLINE unsafeGetEnv #-} @@ -136,7 +136,7 @@ instance (SinkableE r, EnvExtender m) => EnvExtender (ReaderT1 r m) where refreshAbs ab \b e -> runReaderT1 (sink r) $ cont b e instance (Monad1 m, Fallible (m n)) => Fallible (ReaderT1 r m n) where - throwErrs = lift11 . throwErrs + throwErr = lift11 . throwErr addErrCtx ctx (ReaderT1 m) = ReaderT1 $ addErrCtx ctx m {-# INLINE addErrCtx #-} @@ -193,7 +193,7 @@ instance (SinkableE s, ScopeReader m) => ScopeReader (StateT1 s m) where {-# INLINE getDistinct #-} instance (Monad1 m, Fallible (m n)) => Fallible (StateT1 s m n) where - throwErrs = lift11 . throwErrs + throwErr = lift11 . throwErr addErrCtx ctx (WrapStateT1 m) = WrapStateT1 $ addErrCtx ctx m {-# INLINE addErrCtx #-} @@ -204,6 +204,12 @@ instance (Monad1 m, CtxReader (m n)) => CtxReader (StateT1 s m n) where getErrCtx = lift11 getErrCtx {-# INLINE getErrCtx #-} +instance (Monad1 m, Alternative1 m) => Alternative ((StateT1 s m) n) where + empty = lift11 empty + {-# INLINE empty #-} + StateT1 m1 <|> StateT1 m2 = StateT1 \s -> m1 s <|> m2 s + {-# INLINE (<|>) #-} + class HoistableState (s::E) where hoistState :: BindsNames b => s n -> b n l -> s l -> s n @@ -279,7 +285,7 @@ instance Monad (m n) => MonadFail (MaybeT1 m n) where {-# INLINE fail #-} instance Monad (m n) => Fallible (MaybeT1 m n) where - throwErrs _ = empty + throwErr _ = empty addErrCtx _ cont = cont {-# INLINE addErrCtx #-} @@ -300,7 +306,7 @@ instance EnvExtender m => EnvExtender (MaybeT1 m) where -------------------- FallibleT1 -------------------- newtype FallibleT1 (m::MonadKind1) (n::S) a = - FallibleT1 { fromFallibleT :: ReaderT ErrCtx (MTE.ExceptT Errs (m n)) a } + FallibleT1 { fromFallibleT :: ReaderT ErrCtx (MTE.ExceptT Err (m n)) a } deriving (Functor, Applicative, Monad) runFallibleT1 :: Monad1 m => FallibleT1 m n a -> m n (Except a) @@ -315,8 +321,8 @@ instance Monad1 m => MonadFail (FallibleT1 m n) where {-# INLINE fail #-} instance Monad1 m => Fallible (FallibleT1 m n) where - throwErrs (Errs errs) = FallibleT1 $ ReaderT \ambientCtx -> - MTE.throwE $ Errs [Err errTy (ambientCtx <> ctx) s | Err errTy ctx s <- errs] + throwErr (Err errTy ctx s) = FallibleT1 $ ReaderT \ambientCtx -> + MTE.throwE $ Err errTy (ambientCtx <> ctx) s addErrCtx ctx (FallibleT1 m) = FallibleT1 $ local (<> ctx) m {-# INLINE addErrCtx #-} @@ -370,3 +376,90 @@ runStreamReaderT1 rs m = do (ans, LiftE rsRemaining) <- runStateT1 (runStreamReaderT1' m) (LiftE rs) return (ans, rsRemaining) {-# INLINE runStreamReaderT1 #-} + +-------------------- DiffState -------------------- + +class MonoidE (d::E) where + emptyE :: d n + catE :: d n -> d n -> d n + +class MonoidE d => DiffStateE (s::E) (d::E) where + updateDiffStateE :: Distinct n => Env n -> s n -> d n -> s n + +newtype DiffStateT1 (s::E) (d::E) (m::MonadKind1) (n::S) (a:: *) = + DiffStateT1' { runDiffStateT1'' :: StateT (s n, d n) (m n) a } + deriving ( Functor, Applicative, Monad, MonadFail, MonadIO + , Fallible, Catchable, CtxReader) + +pattern DiffStateT1 :: ((s n, d n) -> m n (a, (s n, d n))) -> DiffStateT1 s d m n a +pattern DiffStateT1 cont = DiffStateT1' (StateT cont) + +diffStateT1 + :: (EnvReader m, DiffStateE s d, MonoidE d) + => (s n -> m n (a, d n)) -> DiffStateT1 s d m n a +diffStateT1 cont = DiffStateT1 \(s, d) -> do + (ans, d') <- cont s + env <- unsafeGetEnv + Distinct <- getDistinct + return (ans, (updateDiffStateE env s d', catE d d')) +{-# INLINE diffStateT1 #-} + +runDiffStateT1 + :: (EnvReader m, DiffStateE s d, MonoidE d) + => s n -> DiffStateT1 s d m n a -> m n (a, d n) +runDiffStateT1 s (DiffStateT1' (StateT cont)) = do + (ans, (_, d)) <- cont (s, emptyE) + return (ans, d) +{-# INLINE runDiffStateT1 #-} + +class (Monad1 m, MonoidE d) + => MonadDiffState1 (m::MonadKind1) (s::E) (d::E) | m -> s, m -> d where + withDiffState :: s n -> m n a -> m n (a, d n) + updateDiffStateM :: d n -> m n () + getDiffState :: m n (s n) + +instance (EnvReader m, DiffStateE s d, MonoidE d) => MonadDiffState1 (DiffStateT1 s d m) s d where + getDiffState = DiffStateT1' $ fst <$> get + {-# INLINE getDiffState #-} + + withDiffState s cont = DiffStateT1' do + (sOld, dOld) <- get + put (s, emptyE) + ans <- runDiffStateT1'' cont + (_, dLocal) <- get + put (sOld, dOld) + return (ans, dLocal) + {-# INLINE withDiffState #-} + + updateDiffStateM d = DiffStateT1' do + (s, d') <- get + env <- lift unsafeGetEnv + Distinct <- lift getDistinct + put (updateDiffStateE env s d, catE d d') + {-# INLINE updateDiffStateM #-} + +instance MonoidE (ListE e) where + emptyE = mempty + catE = (<>) + +instance MonoidE (RListE e) where + emptyE = mempty + catE = (<>) + +instance (Monad1 m, Alternative1 m, MonoidE d) => Alternative ((DiffStateT1 s d m) n) where + empty = DiffStateT1' $ StateT \_ -> empty + {-# INLINE empty #-} + DiffStateT1' (StateT m1) <|> DiffStateT1' (StateT m2) = DiffStateT1' $ StateT \s -> + m1 s <|> m2 s + {-# INLINE (<|>) #-} + +instance (ScopeReader m, MonoidE d) => ScopeReader (DiffStateT1 s d m) where + unsafeGetScope = lift11 unsafeGetScope + getDistinct = lift11 getDistinct + +instance (EnvReader m, MonoidE d) => EnvReader (DiffStateT1 s d m) where + unsafeGetEnv = lift11 unsafeGetEnv + +instance MonadTrans11 (DiffStateT1 s d) where + lift11 m = DiffStateT1' $ lift m + {-# INLINE lift11 #-} diff --git a/src/lib/Name.hs b/src/lib/Name.hs index cdd5aa3e5..bf12eec1f 100644 --- a/src/lib/Name.hs +++ b/src/lib/Name.hs @@ -43,7 +43,7 @@ import qualified Unsafe.Coerce as TrulyUnsafe import RawName ( RawNameMap, RawName, NameHint, HasNameHint (..) , freshRawName, rawNameFromHint, rawNames, noHint) import qualified RawName as R -import Util ( zipErr, onFst, onSnd, transitiveClosure, SnocList (..) ) +import Util ( zipErr, onFst, onSnd, transitiveClosure, SnocList (..), unsnoc ) import Err import IRVariants @@ -247,6 +247,17 @@ instance Color c => RenameB (NameBinder c) where _ -> sink env <>> b @> (fromName $ binderName b') cont (scope', env') b' +-- === E-kinded functor === + +class FunctorE (f::E -> E) where + fmapE :: (forall l. e l -> e' l) -> f e n -> f e' n + +instance FunctorE ListE where + fmapE f (ListE xs) = ListE (fmap f xs) + +instance FunctorE (Abs b) where + fmapE f (Abs b e) = Abs b (f e) + -- === monadic type classes for reading and extending envs and scopes === data WithScope (e::E) (n::S) where @@ -262,6 +273,10 @@ class Monad1 m => ScopeReader (m::MonadKind1) where unsafeGetScope :: m n (Scope n) getDistinct :: m n (DistinctEvidence n) +withDistinct :: ScopeReader m => (Distinct n => m n a) -> m n a +withDistinct cont = getDistinct >>= \Distinct -> cont +{-# INLINE withDistinct #-} + class ScopeReader m => ScopeExtender (m::MonadKind1) where -- We normally use the EnvReader version, `refreshAbs`, but sometime we're -- working with raw binders that don't have env information associated with @@ -470,6 +485,9 @@ forgetEitherE (RightE x) = x newtype ListE (e::E) (n::S) = ListE { fromListE :: [e n] } deriving (Show, Eq, Generic) +newtype RListE (e::E) (n::S) = RListE { fromRListE :: (SnocList (e n)) } + deriving (Show, Eq, Generic) + newtype MapE (k::E) (v::E) (n::S) = MapE { fromMapE :: M.Map (k n) (v n) } deriving (Semigroup, Monoid) @@ -525,6 +543,9 @@ data WithAttrB (a:: *) (b::B) (n::S) (l::S) = WithAttrB {getAttr :: a , withoutAttr :: b n l } deriving (Show, Generic) +pattern ZipB :: [a] -> Nest b n l -> Nest (WithAttrB a b) n l +pattern ZipB attrs bs <- (unzipAttrs -> (attrs, bs)) + unzipAttrs :: Nest (WithAttrB a b) n l -> ([a], Nest b n l) unzipAttrs Empty = ([], Empty) unzipAttrs (Nest (WithAttrB a b) rest) = (a:as, Nest b bs) @@ -860,9 +881,6 @@ type MonadIO2 (m :: MonadKind2) = forall (n::S) (l::S) . MonadIO (m n l) type Catchable1 (m :: MonadKind1) = forall (n::S) . Catchable (m n ) type Catchable2 (m :: MonadKind2) = forall (n::S) (l::S) . Catchable (m n l) -type Searcher1 (m :: MonadKind1) = forall (n::S) . Searcher (m n ) -type Searcher2 (m :: MonadKind2) = forall (n::S) (l::S) . Searcher (m n l) - type CtxReader1 (m :: MonadKind1) = forall (n::S) . CtxReader (m n ) type CtxReader2 (m :: MonadKind2) = forall (n::S) (l::S) . CtxReader (m n l) @@ -1316,12 +1334,6 @@ instance (Monad1 m, Alternative (m n)) => Alternative (OutReaderT e m n) where f1 env <|> f2 env {-# INLINE (<|>) #-} -instance Searcher1 m => Searcher (OutReaderT e m n) where - OutReaderT (ReaderT f1) OutReaderT (ReaderT f2) = - OutReaderT $ ReaderT \env -> - f1 env f2 env - {-# INLINE () #-} - instance MonadWriter w (m n) => MonadWriter w (OutReaderT e m n) where tell w = OutReaderT $ lift $ tell w {-# INLINE tell #-} @@ -1549,7 +1561,7 @@ instance (ExtOutMap bindings decls, BindsNames decls, SinkableB decls, Monad m, instance (ExtOutMap bindings decls, BindsNames decls, SinkableB decls, Monad m, Fallible m) => Fallible (InplaceT bindings decls m n) where - throwErrs errs = UnsafeMakeInplaceT \_ _ -> throwErrs errs + throwErr errs = UnsafeMakeInplaceT \_ _ -> throwErr errs addErrCtx ctx cont = UnsafeMakeInplaceT \env decls -> addErrCtx ctx $ unsafeRunInplaceT cont env decls {-# INLINE addErrCtx #-} @@ -1567,13 +1579,6 @@ instance ( ExtOutMap bindings decls, BindsNames decls, SinkableB decls, Monad m f1 env decls <|> f2 env decls {-# INLINE (<|>) #-} -instance ( ExtOutMap bindings decls, BindsNames decls, SinkableB decls, - Monad m, Alternative m, Searcher m) - => Searcher (InplaceT bindings decls m n) where - UnsafeMakeInplaceT f1 UnsafeMakeInplaceT f2 = UnsafeMakeInplaceT \env decls -> - f1 env decls f2 env decls - {-# INLINE () #-} - instance ( ExtOutMap bindings decls, BindsNames decls, SinkableB decls, Catchable m) => Catchable (InplaceT bindings decls m n) where @@ -2004,23 +2009,40 @@ instance (SinkableE k, SinkableE v, OrdE k) => SinkableE (MapE k v) where itemsE = ListE $ toPairE <$> M.toList m newItems = fromPairE <$> (fromListE $ sinkingProofE fresh itemsE) -instance SinkableE e => SinkableE (ListE e) where - sinkingProofE fresh (ListE xs) = ListE $ map (sinkingProofE fresh) xs - instance SinkableE e => SinkableE (NonEmptyListE e) where sinkingProofE fresh (NonEmptyListE xs) = NonEmptyListE $ fmap (sinkingProofE fresh) xs +instance SinkableE e => SinkableE (ListE e) where + sinkingProofE fresh (ListE xs) = ListE $ map (sinkingProofE fresh) xs + instance AlphaEqE e => AlphaEqE (ListE e) where alphaEqE (ListE xs) (ListE ys) | length xs == length ys = mapM_ (uncurry alphaEqE) (zip xs ys) | otherwise = zipErr instance Monoid (ListE e n) where - mempty = ListE [] + mempty = ListE mempty instance Semigroup (ListE e n) where ListE xs <> ListE ys = ListE $ xs <> ys +instance SinkableE e => SinkableE (RListE e) where + sinkingProofE fresh (RListE xs) = RListE $ fmap (sinkingProofE fresh) xs + +instance RenameE e => RenameE (RListE e) where + renameE env (RListE xs) = RListE $ fmap (renameE env) xs + +instance AlphaEqE e => AlphaEqE (RListE e) where + alphaEqE (RListE xs) (RListE ys) + | length xs == length ys = mapM_ (uncurry alphaEqE) (zip (fromReversedList xs) (fromReversedList ys)) + | otherwise = zipErr + +instance Monoid (RListE e n) where + mempty = RListE mempty + +instance Semigroup (RListE e n) where + RListE xs <> RListE ys = RListE $ xs <> ys + instance (EqE k, HashableE k) => GenericE (HashMapE k v) where type RepE (HashMapE k v) = ListE (PairE k v) fromE (HashMapE m) = ListE $ map (uncurry PairE) $ HM.toList m @@ -2149,6 +2171,9 @@ instance (PrettyE e1, PrettyE e2) => Pretty (EitherE e1 e2 n) where instance PrettyE e => Pretty (ListE e n) where pretty (ListE e) = pretty e +instance PrettyE e => Pretty (RListE e n) where + pretty (RListE e) = pretty $ unsnoc e + instance ( Generic (b UnsafeS UnsafeS) , Generic (body UnsafeS) ) => Generic (Abs b body n) where @@ -2744,6 +2769,10 @@ ignoreHoistFailure :: HasCallStack => HoistExcept a -> a ignoreHoistFailure (HoistSuccess x) = x ignoreHoistFailure (HoistFailure _) = error "hoist failure" +-- TODO: make this a no-op in the non-debug build +hardHoist :: (HasCallStack, BindsNames b, HoistableE e) => b n l -> e l -> e n +hardHoist b e = ignoreHoistFailure $ hoist b e + hoist :: (BindsNames b, HoistableE e) => b n l -> e l -> HoistExcept (e n) hoist b e = case R.disjoint fvs frag of @@ -2888,6 +2917,9 @@ instance HoistableB UnitB where instance HoistableE e => HoistableE (ListE e) where freeVarsE (ListE xs) = foldMap freeVarsE xs +instance HoistableE e => HoistableE (RListE e) where + freeVarsE (RListE xs) = foldMap freeVarsE xs + -- === environments === -- The `Subst` type is purely an optimization. We could do everything using diff --git a/src/lib/PPrint.hs b/src/lib/PPrint.hs index 23bc7ea60..e223e2155 100644 --- a/src/lib/PPrint.hs +++ b/src/lib/PPrint.hs @@ -233,9 +233,6 @@ instance IRRep r => PrettyPrec (DepPairType r n) where prettyPrec (DepPairType _ b rhs) = atPrec ArgPrec $ align $ group $ parensSep (spaceIfColinear <> "&> ") [p b, p rhs] -instance Pretty (EffectOpType n) where - pretty (EffectOpType pol ty) = "[" <+> p pol <+> ":" <+> p ty <+> "]" - instance Pretty (CoreLamExpr n) where pretty (CoreLamExpr _ lam) = p lam @@ -254,7 +251,6 @@ instance IRRep r => PrettyPrec (Atom r n) where ProjectElt _ idxs v -> atPrec LowestPrec $ "ProjectElt" <+> p idxs <+> p v NewtypeCon con x -> prettyPrecNewtype con x SimpInCore x -> prettyPrec x - DictHole _ e _ -> atPrec LowestPrec $ "synthesize" <+> pApp e TypeAsAtom ty -> prettyPrec ty instance IRRep r => Pretty (Type r n) where pretty = prettyFromPrettyPrec @@ -399,6 +395,7 @@ instance Pretty IxMethod where instance Pretty (SolverBinding n) where pretty (InfVarBound ty _) = "Inference variable of type:" <+> p ty pretty (SkolemBound ty ) = "Skolem variable of type:" <+> p ty + pretty (DictBound ty ) = "Dictionary variable of type:" <+> p ty instance Pretty (Binding c n) where pretty b = case b of @@ -623,10 +620,10 @@ instance Pretty (UAlt n) where pretty (UAlt pat body) = p pat <+> "->" <+> p body instance Pretty (UTopDecl n l) where - pretty (UDataDefDecl (UDataDef nm (_, bs) dataCons) bTyCon bDataCons) = + pretty (UDataDefDecl (UDataDef nm bs dataCons) bTyCon bDataCons) = "data" <+> p bTyCon <+> p nm <+> spaced (fromNest bs) <+> "where" <> nest 2 (prettyLines (zip (toList $ fromNest bDataCons) dataCons)) - pretty (UStructDecl bTyCon (UStructDef nm (_, bs) fields defs)) = + pretty (UStructDecl bTyCon (UStructDef nm bs fields defs)) = "struct" <+> p bTyCon <+> p nm <+> spaced (fromNest bs) <+> "where" <> nest 2 (prettyLines fields <> prettyLines defs) pretty (UInterface params methodTys interfaceName methodNames) = @@ -641,14 +638,6 @@ instance Pretty (UTopDecl n l) where pretty (UInstance className bs params methods (LeftB v) _) = "named-instance" <+> p v <+> ":" <+> p bs <+> p className <+> p params <> prettyLines methods - pretty (UEffectDecl opTys effName opNames) = - "effect" <+> p effName <> hardline <> foldMap (<>hardline) ops - where ops = [ p pol <+> p b <> ":" <> p (unsafeCoerceE ty) - | (b, UEffectOpType pol ty) <- zip (toList $ fromNest opNames) opTys] - pretty (UHandlerDecl effName bodyTyArg tyArgs retEff retTy opDefs name) = - "handler" <+> p name <+> "of" <+> p effName <+> p bodyTyArg <+> p tyArgs - <+> ":" <+> p retEff <+> p retTy <> hardline - <> foldMap ((<>hardline) . p) opDefs pretty (ULocalDecl decl) = p decl instance Pretty (UDecl' n l) where @@ -657,15 +646,6 @@ instance Pretty (UDecl' n l) where pretty (UExprDecl expr) = p expr pretty UPass = "pass" -instance Pretty (UEffectOpDef n) where - pretty (UEffectOpDef rp n body) = p rp <+> p n <+> "=" <+> p body - pretty (UReturnOpDef body) = "return =" <+> p body - -instance Pretty UResumePolicy where - pretty UNoResume = "jmp" - pretty ULinearResume = "def" - pretty UAnyResume = "ctl" - instance Pretty (UEffectRow n) where pretty (UEffectRow x Nothing) = encloseSep "<" ">" "," $ (p <$> toList x) pretty (UEffectRow x (Just y)) = "{" <> (hsep $ punctuate "," (p <$> toList x)) <+> "|" <+> p y <> "}" @@ -676,15 +656,10 @@ prettyBinderNest bs = nest 6 $ line' <> (sep $ map p $ fromNest bs) instance Pretty (UDataDefTrail n) where pretty (UDataDefTrail bs) = p $ fromNest bs -instance Pretty (UAnnBinder req n l) where - pretty (UAnnBinder b ty cs) = p b <> ":" <> p ty <> printConstraints cs - -printConstraints :: Pretty a => [a] -> Doc ann -printConstraints = \case - [] -> mempty - c:cs -> "|" <> pretty c <> printConstraints cs +instance Pretty (UAnnBinder n l) where + pretty (UAnnBinder _ b ty _) = p b <> ":" <> p ty -instance Pretty (UAnn req n) where +instance Pretty (UAnn n) where pretty (UAnn ty) = ":" <> p ty pretty UNoAnn = mempty @@ -716,9 +691,7 @@ instance Pretty (Cache n) where pretty (Cache _ _ _ _ _ _) = "" -- TODO instance Pretty (SynthCandidates n) where - pretty scs = - "lambda dicts:" <+> p (lambdaDicts scs) <> hardline - <> "instance dicts:" <+> p (M.toList $ instanceDicts scs) + pretty scs = "instance dicts:" <+> p (M.toList $ instanceDicts scs) instance Pretty (LoadedModules n) where pretty _ = "" diff --git a/src/lib/QueryType.hs b/src/lib/QueryType.hs index 50a976816..031d688b8 100644 --- a/src/lib/QueryType.hs +++ b/src/lib/QueryType.hs @@ -205,8 +205,6 @@ getUVarType = \case ClassDef _ _ _ roleExpls bs _ _ <- lookupClassDef v return $ Pi $ CorePiType ExplicitApp (map snd roleExpls) bs $ EffTy Pure TyKind UMethodVar v -> getMethodNameType v - UEffectVar _ -> error "not implemented" - UEffectOpVar _ -> error "not implemented" getMethodNameType :: EnvReader m => MethodName n -> m n (CType n) getMethodNameType v = liftEnvReaderM $ lookupEnv v >>= \case @@ -427,4 +425,4 @@ checkExtends allowed (EffectRow effs effTail) = do forM_ (eSetToList effs) \eff -> unless (eff `eSetMember` allowedEffs) $ throw CompilerErr $ "Unexpected effect: " ++ pprint eff ++ "\nAllowed: " ++ pprint allowed - +{-# INLINE checkExtends #-} diff --git a/src/lib/QueryTypePure.hs b/src/lib/QueryTypePure.hs index 9be267241..258b5f5c0 100644 --- a/src/lib/QueryTypePure.hs +++ b/src/lib/QueryTypePure.hs @@ -30,6 +30,7 @@ instance IRRep r => HasType r (AtomBinding r) where MiscBound ty -> ty SolverBound (InfVarBound ty _) -> ty SolverBound (SkolemBound ty) -> ty + SolverBound (DictBound ty) -> ty NoinlineFun ty _ -> ty TopDataBound (RepVal ty _) -> ty FFIFunBound piTy _ -> Pi piTy @@ -80,7 +81,6 @@ instance IRRep r => HasType r (Atom r) where RepValAtom (RepVal ty _) -> ty ProjectElt t _ _ -> t SimpInCore x -> getType x - DictHole _ ty _ -> ty TypeAsAtom ty -> getType ty instance IRRep r => HasType r (Type r) where diff --git a/src/lib/Simplify.hs b/src/lib/Simplify.hs index 86e395e80..c3d46e4a8 100644 --- a/src/lib/Simplify.hs +++ b/src/lib/Simplify.hs @@ -98,7 +98,6 @@ tryAsDataAtom atom = do Lam _ -> notData DictCon _ _ -> notData Eff _ -> notData - DictHole _ _ _ -> notData TypeAsAtom _ -> notData where notData = error $ "Not runtime-representable data: " ++ pprint atom @@ -629,7 +628,6 @@ simplifyAtom atom = confuseGHC >>= \_ -> case atom of Eff eff -> Eff <$> substM eff PtrVar t v -> PtrVar t <$> substM v DictCon t d -> (DictCon <$> substM t <*> substM d) >>= cheapNormalize - DictHole _ _ _ -> error "shouldn't have dict holes past inference" NewtypeCon _ _ -> substM atom ProjectElt _ i x -> normalizeProj i =<< simplifyAtom x SimpInCore _ -> substM atom diff --git a/src/lib/SourceRename.hs b/src/lib/SourceRename.hs index 3ee3b13b1..f9c4abcd4 100644 --- a/src/lib/SourceRename.hs +++ b/src/lib/SourceRename.hs @@ -148,13 +148,6 @@ instance SourceRenamableE (SourceNameOr (Name ClassNameC)) where _ -> throw TypeErr $ "Not a class name: " ++ pprint sourceName sourceRenameE _ = error "Shouldn't be source-renaming internal names" -instance SourceRenamableE (SourceNameOr (Name EffectNameC)) where - sourceRenameE (SourceName pos sourceName) = do - lookupSourceName sourceName >>= \case - UEffectVar v -> return $ InternalName pos sourceName v - _ -> throw TypeErr $ "Not an effect name: " ++ pprint sourceName - sourceRenameE _ = error "Shouldn't be source-renaming internal names" - instance SourceRenamableE (SourceNameOr (Name c)) => SourceRenamableE (SourceOrInternalName c) where sourceRenameE (SourceOrInternalName x) = SourceOrInternalName <$> sourceRenameE x @@ -164,25 +157,24 @@ instance (SourceRenamableE e, SourceRenamableB b) => SourceRenamableE (Abs b e) instance SourceRenamableB (UBinder (AtomNameC CoreIR)) where sourceRenameB b cont = sourceRenameUBinder UAtomVar b cont -instance SourceRenamableE (UAnn req) where +instance SourceRenamableE UAnn where sourceRenameE UNoAnn = return UNoAnn sourceRenameE (UAnn ann) = UAnn <$> sourceRenameE ann -instance SourceRenamableB (UAnnBinder req) where - sourceRenameB (UAnnBinder b ann cs) cont = do +instance SourceRenamableB UAnnBinder where + sourceRenameB (UAnnBinder expl b ann cs) cont = do ann' <- sourceRenameE ann - cs' <- mapM sourceRenameE cs - sourceRenameB b \b' -> - cont $ UAnnBinder b' ann' cs' + cs' <- mapM sourceRenameE cs + sourceRenameB b \b' -> cont $ UAnnBinder expl b' ann' cs' instance SourceRenamableE UExpr' where sourceRenameE expr = setMayShadow True case expr of UVar v -> UVar <$> sourceRenameE v ULit l -> return $ ULit l ULam lam -> ULam <$> sourceRenameE lam - UPi (UPiExpr (attrs, pats) appExpl eff body) -> + UPi (UPiExpr pats appExpl eff body) -> sourceRenameB pats \pats' -> - UPi <$> (UPiExpr (attrs, pats') <$> pure appExpl <*> sourceRenameE eff <*> sourceRenameE body) + UPi <$> (UPiExpr pats' <$> pure appExpl <*> sourceRenameE eff <*> sourceRenameE body) UApp f xs ys -> UApp <$> sourceRenameE f <*> forM xs sourceRenameE <*> forM ys (\(name, y) -> (name,) <$> sourceRenameE y) @@ -245,26 +237,20 @@ instance SourceRenamableB UTopDecl where sourceRenameUBinder UPunVar tyConName \tyConName' -> do structDef' <- sourceRenameE structDef cont $ UStructDecl tyConName' structDef' - UInterface (attrs, paramBs) methodTys className methodNames -> do + UInterface paramBs methodTys className methodNames -> do Abs paramBs' (ListE methodTys') <- sourceRenameB paramBs \paramBs' -> do methodTys' <- mapM sourceRenameE methodTys return $ Abs paramBs' $ ListE methodTys' sourceRenameUBinder UClassVar className \className' -> sourceRenameUBinderNest UMethodVar methodNames \methodNames' -> - cont $ UInterface (attrs, paramBs') methodTys' className' methodNames' - UInstance className (roleExpls, conditions) params methodDefs instanceName expl -> do + cont $ UInterface paramBs' methodTys' className' methodNames' + UInstance className conditions params methodDefs instanceName expl -> do className' <- sourceRenameE className Abs conditions' (PairE (ListE params') (ListE methodDefs')) <- sourceRenameE $ Abs conditions (PairE (ListE params) $ ListE methodDefs) sourceRenameB instanceName \instanceName' -> - cont $ UInstance className' (roleExpls, conditions') params' methodDefs' instanceName' expl - UEffectDecl opTypes effName opNames -> do - opTypes' <- mapM (\(UEffectOpType p ty) -> (UEffectOpType p) <$> sourceRenameE ty) opTypes - sourceRenameUBinder UEffectVar effName \effName' -> - sourceRenameUBinderNest UEffectOpVar opNames \opNames' -> - cont $ UEffectDecl opTypes' effName' opNames' - UHandlerDecl _ _ _ _ _ _ _ -> error "not implemented" + cont $ UInstance className' conditions' params' methodDefs' instanceName' expl instance SourceRenamableB UDecl' where sourceRenameB decl cont = case decl of @@ -277,8 +263,8 @@ instance SourceRenamableB UDecl' where UPass -> cont UPass instance SourceRenamableE ULamExpr where - sourceRenameE (ULamExpr (expls, args) expl effs resultTy body) = - sourceRenameB args \args' -> ULamExpr (expls, args') + sourceRenameE (ULamExpr args expl effs resultTy body) = + sourceRenameB args \args' -> ULamExpr args' <$> pure expl <*> mapM sourceRenameE effs <*> mapM sourceRenameE resultTy @@ -336,15 +322,15 @@ sourceRenameUBinder asUVar ubinder cont = case ubinder of UIgnore -> cont UIgnore instance SourceRenamableE UDataDef where - sourceRenameE (UDataDef tyConName (expls, paramBs) dataCons) = do + sourceRenameE (UDataDef tyConName paramBs dataCons) = do sourceRenameB paramBs \paramBs' -> do dataCons' <- forM dataCons \(dataConName, argBs) -> do argBs' <- sourceRenameE argBs return (dataConName, argBs') - return $ UDataDef tyConName (expls, paramBs') dataCons' + return $ UDataDef tyConName paramBs' dataCons' instance SourceRenamableE UStructDef where - sourceRenameE (UStructDef tyConName (expls, paramBs) fields methods) = do + sourceRenameE (UStructDef tyConName paramBs fields methods) = do sourceRenameB paramBs \paramBs' -> do fields' <- forM fields \(fieldName, ty) -> do ty' <- sourceRenameE ty @@ -352,7 +338,7 @@ instance SourceRenamableE UStructDef where methods' <- forM methods \(ann, methodName, lam) -> do lam' <- sourceRenameE lam return (ann, methodName, lam') - return $ UStructDef tyConName (expls, paramBs') fields' methods' + return $ UStructDef tyConName paramBs' fields' methods' instance SourceRenamableE UDataDefTrail where sourceRenameE (UDataDefTrail args) = sourceRenameB args \args' -> @@ -377,14 +363,6 @@ instance SourceRenamableE UMethodDef' where UMethodVar v' -> UMethodDef (InternalName pos v v') <$> sourceRenameE expr _ -> throw TypeErr $ "not a method name: " ++ pprint v -instance SourceRenamableE UEffectOpDef where - sourceRenameE (UReturnOpDef expr) = do - UReturnOpDef <$> sourceRenameE expr - sourceRenameE (UEffectOpDef rp ~(SourceName pos v) expr) = do - lookupSourceName v >>= \case - UEffectOpVar v' -> UEffectOpDef rp (InternalName pos v v') <$> sourceRenameE expr - _ -> throw TypeErr $ "not an effect operation name: " ++ pprint v - instance SourceRenamableB b => SourceRenamableB (Nest b) where sourceRenameB (Nest b bs) cont = sourceRenameB b \b' -> @@ -489,9 +467,6 @@ instance HasSourceNames UTopDecl where UInterface _ _ ~(UBindSource _ className) methodNames -> do S.singleton className <> sourceNames methodNames UInstance _ _ _ _ instanceName _ -> sourceNames instanceName - UEffectDecl _ ~(UBindSource _ effName) opNames -> do - S.singleton effName <> sourceNames opNames - UHandlerDecl _ _ _ _ _ _ handlerName -> sourceNames handlerName instance HasSourceNames UDecl' where sourceNames = \case diff --git a/src/lib/Subst.hs b/src/lib/Subst.hs index bf5036346..41bdf78e6 100644 --- a/src/lib/Subst.hs +++ b/src/lib/Subst.hs @@ -16,6 +16,7 @@ import Control.Monad.Reader import Control.Monad.State.Strict import Name +import MTL1 import IRVariants import Types.Core import Core @@ -35,6 +36,10 @@ dropSubst :: (SubstReader v m, FromName v) => m o o a -> m i o a dropSubst cont = withSubst idSubst cont {-# INLINE dropSubst #-} +withVoidSubst :: (SubstReader v m, FromName v) => m VoidS o a -> m i o a +withVoidSubst cont = withSubst (newSubst absurdNameFunction) cont +{-# INLINE withVoidSubst #-} + extendSubst :: SubstReader v m => SubstFrag v i i' o -> m i' o a -> m i o a extendSubst frag cont = do env <- (<>>frag) <$> getSubst @@ -280,37 +285,43 @@ asAtomSubstValSubst subst = newSubst \v -> toSubstVal (subst ! v) -- === SubstReaderT transformer === newtype SubstReaderT (v::V) (m::MonadKind1) (i::S) (o::S) (a:: *) = - SubstReaderT { runSubstReaderT' :: ReaderT (Subst v i o) (m o) a } + SubstReaderT' { runSubstReaderT' :: ReaderT (Subst v i o) (m o) a } + +pattern SubstReaderT :: (Subst v i o -> m o a) -> SubstReaderT v m i o a +pattern SubstReaderT f = SubstReaderT' (ReaderT f) + +runSubstReaderT :: Subst v i o -> SubstReaderT v m i o a -> m o a +runSubstReaderT env m = runReaderT (runSubstReaderT' m) env +{-# INLINE runSubstReaderT #-} instance (forall n. Functor (m n)) => Functor (SubstReaderT v m i o) where - fmap f (SubstReaderT m) = SubstReaderT $ fmap f m + fmap f (SubstReaderT' m) = SubstReaderT' $ fmap f m {-# INLINE fmap #-} instance Monad1 m => Applicative (SubstReaderT v m i o) where - pure = SubstReaderT . pure + pure = SubstReaderT' . pure {-# INLINE pure #-} - liftA2 f (SubstReaderT x) (SubstReaderT y) = SubstReaderT $ liftA2 f x y + liftA2 f (SubstReaderT' x) (SubstReaderT' y) = SubstReaderT' $ liftA2 f x y {-# INLINE liftA2 #-} - (SubstReaderT f) <*> (SubstReaderT x) = SubstReaderT $ f <*> x + (SubstReaderT' f) <*> (SubstReaderT' x) = SubstReaderT' $ f <*> x {-# INLINE (<*>) #-} instance (forall n. Monad (m n)) => Monad (SubstReaderT v m i o) where - return = SubstReaderT . return + return = SubstReaderT' . return {-# INLINE return #-} - (SubstReaderT m) >>= f = SubstReaderT (m >>= (runSubstReaderT' . f)) + (SubstReaderT' m) >>= f = SubstReaderT' (m >>= (runSubstReaderT' . f)) {-# INLINE (>>=) #-} deriving instance (Monad1 m, MonadFail1 m) => MonadFail (SubstReaderT v m i o) deriving instance (Monad1 m, Alternative1 m) => Alternative (SubstReaderT v m i o) deriving instance Fallible1 m => Fallible (SubstReaderT v m i o) -deriving instance Searcher1 m => Searcher (SubstReaderT v m i o) deriving instance Catchable1 m => Catchable (SubstReaderT v m i o) deriving instance CtxReader1 m => CtxReader (SubstReaderT v m i o) type ScopedSubstReader (v::V) = SubstReaderT v (ScopeReaderT Identity) :: MonadKind2 liftSubstReaderT :: Monad1 m => m o a -> SubstReaderT v m i o a -liftSubstReaderT m = SubstReaderT $ lift m +liftSubstReaderT m = SubstReaderT' $ lift m {-# INLINE liftSubstReaderT #-} runScopedSubstReader :: Distinct o => Scope o -> Subst v i o @@ -319,39 +330,43 @@ runScopedSubstReader scope env m = runIdentity $ runScopeReaderT scope $ runSubstReaderT env m {-# INLINE runScopedSubstReader #-} -runSubstReaderT :: Subst v i o -> SubstReaderT v m i o a -> m o a -runSubstReaderT env m = runReaderT (runSubstReaderT' m) env -{-# INLINE runSubstReaderT #-} - withSubstReaderT :: FromName v => SubstReaderT v m n n a -> m n a withSubstReaderT = runSubstReaderT idSubst {-# INLINE withSubstReaderT #-} instance (SinkableV v, Monad1 m) => SubstReader v (SubstReaderT v m) where - getSubst = SubstReaderT ask + getSubst = SubstReaderT' ask {-# INLINE getSubst #-} - withSubst env (SubstReaderT cont) = SubstReaderT $ withReaderT (const env) cont + withSubst env (SubstReaderT' cont) = SubstReaderT' $ withReaderT (const env) cont {-# INLINE withSubst #-} instance (SinkableV v, ScopeReader m) => ScopeReader (SubstReaderT v m i) where - unsafeGetScope = SubstReaderT $ lift unsafeGetScope + unsafeGetScope = liftSubstReaderT unsafeGetScope {-# INLINE unsafeGetScope #-} - getDistinct = SubstReaderT $ lift getDistinct + getDistinct = liftSubstReaderT getDistinct {-# INLINE getDistinct #-} instance (SinkableV v, EnvReader m) => EnvReader (SubstReaderT v m i) where - unsafeGetEnv = SubstReaderT $ lift unsafeGetEnv + unsafeGetEnv = liftSubstReaderT unsafeGetEnv {-# INLINE unsafeGetEnv #-} instance (SinkableV v, ScopeReader m, EnvExtender m) => EnvExtender (SubstReaderT v m i) where - refreshAbs ab cont = SubstReaderT $ ReaderT \subst -> + refreshAbs ab cont = SubstReaderT \subst -> refreshAbs ab \b e -> do subst' <- sinkM subst - let SubstReaderT (ReaderT cont') = cont b e + let SubstReaderT cont' = cont b e cont' subst' {-# INLINE refreshAbs #-} +instance MonadDiffState1 m s d => MonadDiffState1 (SubstReaderT v m i) s d where + withDiffState s m = + SubstReaderT \subst -> do + withDiffState s $ runSubstReaderT subst m + + updateDiffStateM d = liftSubstReaderT $ updateDiffStateM d + getDiffState = liftSubstReaderT getDiffState + type SubstEnvReaderM v = SubstReaderT v EnvReaderM :: MonadKind2 liftSubstEnvReaderM @@ -363,25 +378,24 @@ liftSubstEnvReaderM cont = liftEnvReaderM $ runSubstReaderT idSubst $ cont instance (SinkableV v, ScopeReader m, ScopeExtender m) => ScopeExtender (SubstReaderT v m i) where - refreshAbsScope ab cont = SubstReaderT $ ReaderT \env -> + refreshAbsScope ab cont = SubstReaderT \env -> refreshAbsScope ab \b e -> do - let SubstReaderT (ReaderT cont') = cont b e + let SubstReaderT cont' = cont b e env' <- sinkM env cont' env' instance (SinkableV v, MonadIO1 m) => MonadIO (SubstReaderT v m i o) where - liftIO m = SubstReaderT $ lift $ liftIO m + liftIO m = liftSubstReaderT $ liftIO m {-# INLINE liftIO #-} instance (Monad1 m, MonadState (s o) (m o)) => MonadState (s o) (SubstReaderT v m i o) where - state = SubstReaderT . lift . state + state = liftSubstReaderT . state {-# INLINE state #-} instance (Monad1 m, MonadReader (r o) (m o)) => MonadReader (r o) (SubstReaderT v m i o) where - ask = SubstReaderT $ ReaderT $ const ask + ask = SubstReaderT $ const ask {-# INLINE ask #-} - local r (SubstReaderT (ReaderT f)) = SubstReaderT $ ReaderT $ \env -> - local r $ f env + local r (SubstReaderT' (ReaderT f)) = SubstReaderT \env -> local r $ f env {-# INLINE local #-} -- === instances === @@ -467,6 +481,9 @@ instance FromName v => SubstE v (LiftE a) where instance SubstE v e => SubstE v (ListE e) where substE env (ListE xs) = ListE $ map (substE env) xs +instance SubstE v e => SubstE v (RListE e) where + substE env (RListE xs) = RListE $ fmap (substE env) xs + instance SubstE v e => SubstE v (NonEmptyListE e) where substE env (NonEmptyListE xs) = NonEmptyListE $ fmap (substE env) xs diff --git a/src/lib/TopLevel.hs b/src/lib/TopLevel.hs index f2379f259..f2f206038 100644 --- a/src/lib/TopLevel.hs +++ b/src/lib/TopLevel.hs @@ -363,8 +363,6 @@ runEnvQuery query = do UDataConVar v' -> pprint <$> lookupEnv v' UClassVar v' -> pprint <$> lookupEnv v' UMethodVar v' -> pprint <$> lookupEnv v' - UEffectVar v' -> pprint <$> lookupEnv v' - UEffectOpVar v' -> pprint <$> lookupEnv v' UPunVar v' -> do val <- lookupEnv v' return $ pprint val ++ "\n(type constructor and data constructor share the same name)" @@ -536,12 +534,7 @@ whenOpt x act = getConfig <&> optLevel >>= \case evalBlock :: (Topper m, Mut n) => TopBlock CoreIR n -> m n (CAtom n) evalBlock typed = do - -- Be careful when adding new compilation passes here. If you do, be sure to - -- also check compileTopLevelFun, below, and Export.prepareFunctionForExport. - -- In most cases it should be easiest to add new passes to simpOptimizations or - -- loweredOptimizations, below, because those are reused in all three places. - synthed <- checkPass SynthPass $ synthTopE typed - SimplifiedTopLam simp recon <- checkPass SimpPass $ simplifyTopBlock synthed + SimplifiedTopLam simp recon <- checkPass SimpPass $ simplifyTopBlock typed opt <- simpOptimizations simp simpResult <- case opt of TopLam _ _ (LamExpr Empty (WithoutDecls result)) -> return result diff --git a/src/lib/Types/Core.hs b/src/lib/Types/Core.hs index 067af737e..94d2248f1 100644 --- a/src/lib/Types/Core.hs +++ b/src/lib/Types/Core.hs @@ -58,8 +58,6 @@ data Atom (r::IR) (n::S) where Eff :: EffectRow CoreIR n -> Atom CoreIR n DictCon :: Type CoreIR n -> DictExpr n -> Atom CoreIR n NewtypeCon :: NewtypeCon n -> Atom CoreIR n -> Atom CoreIR n - DictHole :: AlwaysEqual SrcPosCtx -> Type CoreIR n -> RequiredMethodAccess - -> Atom CoreIR n TypeAsAtom :: Type CoreIR n -> Atom CoreIR n -- === Shims between IRs === SimpInCore :: SimpInCore n -> Atom CoreIR n @@ -623,10 +621,8 @@ data TopEnvUpdate n = | UpdateFieldDef (TyConName n) SourceName (CAtomName n) -- TODO: we could add a lot more structure for querying by dict type, caching, etc. --- TODO: make these `Name n` instead of `Atom n` so they're usable as cache keys. data SynthCandidates n = SynthCandidates - { lambdaDicts :: [AtomName CoreIR n] - , instanceDicts :: M.Map (ClassName n) [InstanceName n] } + { instanceDicts :: M.Map (ClassName n) [InstanceName n] } deriving (Show, Generic) emptyImportStatus :: ImportStatus n @@ -811,6 +807,7 @@ data InfVarDesc = data SolverBinding (n::S) = InfVarBound (CType n) InfVarCtx | SkolemBound (CType n) + | DictBound (CType n) deriving (Show, Generic) -- Context for why we created an inference variable. @@ -1458,12 +1455,9 @@ instance IRRep r => GenericE (Atom r) where {- ProjectElt -} (Type r `PairE` LiftE Projection `PairE` Atom r) {- Lam -} (WhenCore r CoreLamExpr) {- DepPair -} (Atom r `PairE` Atom r `PairE` DepPairType r) - ) (EitherE4 + ) (EitherE3 {- DictCon -} (WhenCore r (CType `PairE` DictExpr)) {- NewtypeCon -} (WhenCore r (NewtypeCon `PairE` Atom r)) - {- DictHole -} (WhenCore r (LiftE (AlwaysEqual SrcPosCtx) `PairE` - (Type CoreIR) `PairE` - (LiftE RequiredMethodAccess))) {- Con -} (Con r) ) (EitherE5 {- Eff -} ( WhenCore r (EffectRow r)) @@ -1480,8 +1474,7 @@ instance IRRep r => GenericE (Atom r) where DepPair l r ty -> Case0 (Case3 $ l `PairE` r `PairE` ty) DictCon t d -> Case1 $ Case0 $ WhenIRE $ t `PairE` d NewtypeCon c x -> Case1 $ Case1 $ WhenIRE (c `PairE` x) - DictHole s t access -> Case1 $ Case2 $ WhenIRE (LiftE s `PairE` t `PairE` LiftE access) - Con con -> Case1 $ Case3 con + Con con -> Case1 $ Case2 con Eff effs -> Case2 $ Case0 $ WhenIRE effs PtrVar t v -> Case2 $ Case1 $ LiftE t `PairE` v RepValAtom rv -> Case2 $ Case2 $ WhenIRE $ rv @@ -1499,8 +1492,7 @@ instance IRRep r => GenericE (Atom r) where Case1 val -> case val of Case0 (WhenIRE (t `PairE` d)) -> DictCon t d Case1 (WhenIRE (c `PairE` x)) -> NewtypeCon c x - Case2 (WhenIRE (LiftE s `PairE` t `PairE` LiftE access)) -> DictHole s t access - Case3 con -> Con con + Case2 con -> Con con _ -> error "impossible" Case2 val -> case val of Case0 (WhenIRE effs) -> Eff effs @@ -2119,12 +2111,11 @@ deriving instance IRRep r => Show (DepPairType r n) deriving via WrapE (DepPairType r) n instance IRRep r => Generic (DepPairType r n) instance GenericE SynthCandidates where - type RepE SynthCandidates = - ListE (AtomName CoreIR) `PairE` ListE (PairE ClassName (ListE InstanceName)) - fromE (SynthCandidates xs ys) = ListE xs `PairE` ListE ys' + type RepE SynthCandidates = ListE (PairE ClassName (ListE InstanceName)) + fromE (SynthCandidates ys) = ListE ys' where ys' = map (\(k,vs) -> PairE k (ListE vs)) (M.toList ys) {-# INLINE fromE #-} - toE (ListE xs `PairE` ListE ys) = SynthCandidates xs ys' + toE (ListE ys) = SynthCandidates ys' where ys' = M.fromList $ map (\(PairE k (ListE vs)) -> (k,vs)) ys {-# INLINE toE #-} @@ -2259,17 +2250,20 @@ instance AlphaEqE LinearizationSpec instance AlphaHashableE LinearizationSpec instance GenericE SolverBinding where - type RepE SolverBinding = EitherE2 + type RepE SolverBinding = EitherE3 (PairE CType (LiftE InfVarCtx)) CType + CType fromE = \case InfVarBound ty ctx -> Case0 (PairE ty (LiftE ctx)) SkolemBound ty -> Case1 ty + DictBound ty -> Case2 ty {-# INLINE fromE #-} toE = \case Case0 (PairE ty (LiftE ct)) -> InfVarBound ty ct Case1 ty -> SkolemBound ty + Case2 ty -> DictBound ty _ -> error "impossible" {-# INLINE toE #-} @@ -2455,11 +2449,10 @@ instance IRRep r => BindsOneName (Decl r) (AtomNameC r) where {-# INLINE binderName #-} instance Semigroup (SynthCandidates n) where - SynthCandidates xs ys <> SynthCandidates xs' ys' = - SynthCandidates (xs<>xs') (M.unionWith (<>) ys ys') + SynthCandidates xs <> SynthCandidates xs' = SynthCandidates (M.unionWith (<>) xs xs') instance Monoid (SynthCandidates n) where - mempty = SynthCandidates mempty mempty + mempty = SynthCandidates mempty instance GenericB EnvFrag where type RepB EnvFrag = RecSubstFrag Binding diff --git a/src/lib/Types/Source.hs b/src/lib/Types/Source.hs index 21a6974ee..fdb2a3eba 100644 --- a/src/lib/Types/Source.hs +++ b/src/lib/Types/Source.hs @@ -95,13 +95,6 @@ data CTopDecl' SourceName -- Interface name ExplicitParams [(SourceName, Group)] -- Method declarations - | CEffectDecl SourceName [(SourceName, UResumePolicy, Group)] - | CHandlerDecl SourceName -- Handler name - SourceName -- Effect name - SourceName -- Body type parameter - Group -- Handler arguments - Group -- Handler type annotation - [(SourceName, Maybe UResumePolicy, CSBlock)] -- Handler methods -- header, givens (may be empty), methods, optional name. The header should contain -- the prerequisites, class name, and class arguments. | CInstanceDecl CInstanceDef @@ -216,9 +209,7 @@ data UVar (n::S) = | UTyConVar (Name TyConNameC n) | UDataConVar (Name DataConNameC n) | UClassVar (Name ClassNameC n) - | UEffectVar (Name EffectNameC n) | UMethodVar (Name MethodNameC n) - | UEffectOpVar (Name EffectOpNameC n) | UPunVar (Name TyConNameC n) -- for names also used as data constructors deriving (Eq, Ord, Show, Generic) @@ -274,12 +265,9 @@ data FieldName' = | FieldNum Int deriving (Show, Eq, Ord) -type UAnnExplBinders req n l = ([Explicitness], Nest (UAnnBinder req) n l) -type UOptAnnExplBinders n l = UAnnExplBinders AnnOptional n l - data ULamExpr (n::S) where ULamExpr - :: UOptAnnExplBinders n l -- args + :: Nest UAnnBinder n l -- args -> AppExplicitness -> Maybe (UEffectRow l) -- optional effect -> Maybe (UType l) -- optional result type @@ -287,33 +275,33 @@ data ULamExpr (n::S) where -> ULamExpr n data UPiExpr (n::S) where - UPiExpr :: UOptAnnExplBinders n l -> AppExplicitness -> UEffectRow l -> UType l -> UPiExpr n + UPiExpr :: Nest UAnnBinder n l -> AppExplicitness -> UEffectRow l -> UType l -> UPiExpr n data UTabPiExpr (n::S) where - UTabPiExpr :: UOptAnnBinder n l -> UType l -> UTabPiExpr n + UTabPiExpr :: UAnnBinder n l -> UType l -> UTabPiExpr n data UDepPairType (n::S) where - UDepPairType :: DepPairExplicitness -> UOptAnnBinder n l -> UType l -> UDepPairType n + UDepPairType :: DepPairExplicitness -> UAnnBinder n l -> UType l -> UDepPairType n -type UConDef (n::S) (l::S) = (SourceName, Nest UReqAnnBinder n l) +type UConDef (n::S) (l::S) = (SourceName, Nest UAnnBinder n l) data UDataDef (n::S) where UDataDef :: SourceName -- source name for pretty printing - -> UOptAnnExplBinders n l + -> Nest UAnnBinder n l -> [(SourceName, UDataDefTrail l)] -- data constructor types -> UDataDef n data UStructDef (n::S) where UStructDef :: SourceName -- source name for pretty printing - -> UOptAnnExplBinders n l + -> Nest UAnnBinder n l -> [(SourceName, UType l)] -- named payloads -> [(LetAnn, SourceName, Abs UAtomBinder ULamExpr l)] -- named methods (initial binder is for `self`) -> UStructDef n data UDataDefTrail (l::S) where - UDataDefTrail :: Nest UReqAnnBinder l l' -> UDataDefTrail l + UDataDefTrail :: Nest UAnnBinder l l' -> UDataDefTrail l data UTopDecl (n::S) (l::S) where ULocalDecl :: UDecl n l -> UTopDecl n l @@ -327,42 +315,24 @@ data UTopDecl (n::S) (l::S) where -> UStructDef l -- actual definition -> UTopDecl n l UInterface - :: UOptAnnExplBinders n p -- parameter binders + :: Nest UAnnBinder n p -- parameter binders -> [UType p] -- method types -> UBinder ClassNameC n l' -- class name -> Nest (UBinder MethodNameC) l' l -- method names -> UTopDecl n l UInstance :: SourceNameOr (Name ClassNameC) n -- class name - -> UOptAnnExplBinders n l' + -> Nest UAnnBinder n l' -> [UExpr l'] -- class parameters -> [UMethodDef l'] -- method definitions -- Maybe we should make a separate color (namespace) for instance names? -> MaybeB UAtomBinder n l -- optional instance name -> AppExplicitness -- explicitness (only relevant for named instances) -> UTopDecl n l - UEffectDecl - :: [UEffectOpType n] -- operation types - -> UBinder EffectNameC n l' -- effect name - -> Nest (UBinder EffectOpNameC) l' l -- operation names - -> UTopDecl n l - UHandlerDecl - :: SourceNameOr (Name EffectNameC) n -- effect name - -> UAtomBinder n b -- body type argument - -> UOptAnnExplBinders b l' -- type args - -> UEffectRow l' -- returning effect - -> UType l' -- returning type - -> [UEffectOpDef l'] -- operation definitions - -> UBinder HandlerNameC n l -- handler name - -> UTopDecl n l type UType = UExpr type UConstraint = UExpr -data UEffectOpType (n::S) where - UEffectOpType :: UResumePolicy -> UType s -> UEffectOpType s - deriving (Show, Generic) - data UResumePolicy = UNoResume | ULinearResume @@ -373,32 +343,18 @@ instance Hashable UResumePolicy instance Store UResumePolicy data UForExpr (n::S) where - UForExpr :: UOptAnnBinder n l -> UBlock l -> UForExpr n + UForExpr :: UAnnBinder n l -> UBlock l -> UForExpr n type UMethodDef = WithSrcE UMethodDef' data UMethodDef' (n::S) = UMethodDef (SourceNameOr (Name MethodNameC) n) (ULamExpr n) deriving (Show, Generic) -data UEffectOpDef (n::S) = - UEffectOpDef UResumePolicy (SourceNameOr (Name EffectOpNameC) n) (UExpr n) - | UReturnOpDef (UExpr n) - deriving (Show, Generic) +data UAnn (n::S) = UAnn (UType n) | UNoAnn deriving Show -data AnnRequirement = AnnRequired | AnnOptional - -data UAnn (annReq::AnnRequirement) (n::S) where - UAnn :: UType n -> UAnn annReq n - UNoAnn :: UAnn AnnOptional n -deriving instance Show (UAnn annReq n) - - -data UAnnBinder (annReq::AnnRequirement) (n::S) (l::S) = - UAnnBinder (UAtomBinder n l) (UAnn annReq n) [UConstraint n] +data UAnnBinder (n::S) (l::S) = + UAnnBinder Explicitness (UAtomBinder n l) (UAnn n) [UConstraint n] deriving (Show, Generic) -type UReqAnnBinder = UAnnBinder AnnRequired :: B -type UOptAnnBinder = UAnnBinder AnnOptional :: B - data UAlt (n::S) where UAlt :: UPat n l -> UBlock l -> UAlt n @@ -419,8 +375,8 @@ pattern UPatIgnore = UPatBinder UIgnore class HasSourceName a where getSourceName :: a -> SourceName -instance HasSourceName (UAnnBinder req n l) where - getSourceName (UAnnBinder b _ _) = getSourceName b +instance HasSourceName (UAnnBinder n l) where + getSourceName (UAnnBinder _ b _ _) = getSourceName b instance HasSourceName (UBinder c n l) where getSourceName = \case @@ -448,6 +404,12 @@ instance HasSrcPos (WithSrcE (a::E) (n::S)) where instance HasSrcPos (WithSrcB (b::B) (n::S) (n::S)) where srcPos (WithSrcB pos _) = pos +instance HasSrcPos (UBinder c n l) where + srcPos = \case + UBindSource ctx _ -> ctx + UIgnore -> SrcPosCtx Nothing Nothing + UBind ctx _ _ -> ctx + -- === SourceMap === data SourceNameDef n = @@ -529,7 +491,7 @@ data PrintBackend = data OutFormat = Printed (Maybe PrintBackend) | RenderHtml deriving (Show, Eq, Generic) -data PassName = Parse | RenamePass | TypePass | SynthPass | SimpPass | ImpPass | JitPass +data PassName = Parse | RenamePass | TypePass | SimpPass | ImpPass | JitPass | LLVMOpt | AsmPass | JAXPass | JAXSimpPass | LLVMEval | LowerOptPass | LowerPass | ResultPass | JaxprAndHLO | EarlyOptPass | OptPass | VectPass | OccAnalysisPass | InlinePass @@ -537,8 +499,7 @@ data PassName = Parse | RenamePass | TypePass | SynthPass | SimpPass | ImpPass | instance Show PassName where show p = case p of - Parse -> "parse" ; RenamePass -> "rename"; - TypePass -> "typed" ; SynthPass -> "synth" + Parse -> "parse" ; RenamePass -> "rename"; TypePass -> "typed" SimpPass -> "simp" ; ImpPass -> "imp" ; JitPass -> "llvm" LLVMOpt -> "llvmopt" ; AsmPass -> "asm" JAXPass -> "jax" ; JAXSimpPass -> "jsimp"; ResultPass -> "result" @@ -627,19 +588,16 @@ instance Pretty (SourceMap n) where fold [pretty v <+> "@>" <+> pretty x <> hardline | (v, x) <- M.toList m ] instance GenericE UVar where - type RepE UVar = EitherE8 (Name (AtomNameC CoreIR)) (Name TyConNameC) + type RepE UVar = EitherE6 (Name (AtomNameC CoreIR)) (Name TyConNameC) (Name DataConNameC) (Name ClassNameC) - (Name MethodNameC) (Name EffectNameC) - (Name EffectOpNameC) (Name TyConNameC) + (Name MethodNameC) (Name TyConNameC) fromE name = case name of UAtomVar v -> Case0 v UTyConVar v -> Case1 v UDataConVar v -> Case2 v UClassVar v -> Case3 v UMethodVar v -> Case4 v - UEffectVar v -> Case5 v - UEffectOpVar v -> Case6 v - UPunVar v -> Case7 v + UPunVar v -> Case5 v {-# INLINE fromE #-} toE name = case name of @@ -648,9 +606,8 @@ instance GenericE UVar where Case2 v -> UDataConVar v Case3 v -> UClassVar v Case4 v -> UMethodVar v - Case5 v -> UEffectVar v - Case6 v -> UEffectOpVar v - Case7 v -> UPunVar v + Case5 v -> UPunVar v + _ -> error "impossible" {-# INLINE toE #-} instance Pretty (UVar n) where @@ -660,8 +617,6 @@ instance Pretty (UVar n) where UDataConVar v -> "Data constructor name: " <> pretty v UClassVar v -> "Class name: " <> pretty v UMethodVar v -> "Method name: " <> pretty v - UEffectVar v -> "Effect name: " <> pretty v - UEffectOpVar v -> "Effect operation name: " <> pretty v UPunVar v -> "Shared type constructor / data constructor name: " <> pretty v -- TODO: name subst instances for the rest of UExpr @@ -710,12 +665,12 @@ instance Color c => RenameB (UBinder c) where UIgnore -> cont env UIgnore UBind ctx sn b -> renameB env b \env' b' -> cont env' $ UBind ctx sn b' -instance ProvesExt (UAnnBinder req) where -instance BindsNames (UAnnBinder req) where - toScopeFrag (UAnnBinder b _ _) = toScopeFrag b +instance ProvesExt UAnnBinder where +instance BindsNames UAnnBinder where + toScopeFrag (UAnnBinder _ b _ _) = toScopeFrag b -instance BindsAtMostOneName (UAnnBinder req) (AtomNameC CoreIR) where - UAnnBinder b _ _ @> x = b @> x +instance BindsAtMostOneName UAnnBinder (AtomNameC CoreIR) where + UAnnBinder _ b _ _ @> x = b @> x instance GenericE (WithSrcE e) where type RepE (WithSrcE e) = PairE (LiftE SrcPosCtx) e @@ -770,8 +725,8 @@ instance IsString (UBinder s VoidS VoidS) where instance IsString (UPat' VoidS VoidS) where fromString = UPatBinder . fromString -instance IsString (UOptAnnBinder VoidS VoidS) where - fromString s = UAnnBinder (fromString s) UNoAnn [] +instance IsString (UAnnBinder VoidS VoidS) where + fromString s = UAnnBinder Explicit (fromString s) UNoAnn [] instance IsString (UExpr' VoidS) where fromString = UVar . fromString diff --git a/src/lib/Util.hs b/src/lib/Util.hs index 7f9a89859..853c384e5 100644 --- a/src/lib/Util.hs +++ b/src/lib/Util.hs @@ -306,7 +306,7 @@ getAlternative xs = asum $ map pure xs {-# INLINE getAlternative #-} newtype SnocList a = ReversedList { fromReversedList :: [a] } - deriving Functor -- XXX: NOT deriving order-sensitive things like Monoid, Applicative etc + deriving (Show, Eq, Ord, Generic, Functor) -- XXX: NOT deriving order-sensitive things like Monoid, Applicative etc instance Semigroup (SnocList a) where (ReversedList x) <> (ReversedList y) = ReversedList $ y ++ x @@ -320,6 +320,10 @@ instance Foldable SnocList where foldMap f (ReversedList xs) = foldMap f (reverse xs) {-# INLINE foldMap #-} +instance Traversable SnocList where + traverse f (ReversedList xs) = ReversedList . reverse <$> traverse f (reverse xs) + {-# INLINE traverse #-} + snoc :: SnocList a -> a -> SnocList a snoc (ReversedList xs) x = ReversedList (x:xs) {-# INLINE snoc #-} diff --git a/src/lib/Vectorize.hs b/src/lib/Vectorize.hs index 88a6ef48e..bf585a089 100644 --- a/src/lib/Vectorize.hs +++ b/src/lib/Vectorize.hs @@ -85,13 +85,13 @@ newtype TopVectorizeM (i::S) (o::S) (a:: *) = TopVectorizeM SubstReaderT Name (ReaderT1 CommuteMap (ReaderT1 (LiftE Word32) - (StateT1 (LiftE Errs) (BuilderT SimpIR FallibleM)))) i o a } + (StateT1 (LiftE [Err]) (BuilderT SimpIR FallibleM)))) i o a } deriving ( Functor, Applicative, Monad, MonadFail, MonadReader (CommuteMap o) - , MonadState (LiftE Errs o), Fallible, ScopeReader, EnvReader + , MonadState (LiftE [Err] o), Fallible, ScopeReader, EnvReader , EnvExtender, Builder SimpIR, ScopableBuilder SimpIR, Catchable , SubstReader Name) -vectorizeLoops :: EnvReader m => Word32 -> STopLam n -> m n (STopLam n, Errs) +vectorizeLoops :: EnvReader m => Word32 -> STopLam n -> m n (STopLam n, [Err]) vectorizeLoops width (TopLam d ty (LamExpr bsDestB body)) = liftEnvReaderM do case popNest bsDestB of Just (PairB bs b) -> @@ -102,7 +102,7 @@ vectorizeLoops width (TopLam d ty (LamExpr bsDestB body)) = liftEnvReaderM do {-# SCC vectorizeLoops #-} liftTopVectorizeM :: (EnvReader m) - => Word32 -> TopVectorizeM i i a -> m i (a, Errs) + => Word32 -> TopVectorizeM i i a -> m i (a, [Err]) liftTopVectorizeM vectorByteWidth action = do fallible <- liftBuilderT $ flip runStateT1 mempty $ runReaderT1 (LiftE vectorByteWidth) $ @@ -111,7 +111,7 @@ liftTopVectorizeM vectorByteWidth action = do case runFallibleM fallible of -- The failure case should not occur: vectorization errors should have been -- caught inside `vectorizeLoopsDecls` (and should have been added to the - -- `Errs` state of the `StateT` instance that is run with `runStateT` above). + -- `Err` state of the `StateT` instance that is run with `runStateT` above). Failure errs -> error $ pprint errs Success (a, (LiftE errs)) -> return $ (a, errs) @@ -123,12 +123,11 @@ addVectErrCtx name payload m = throwVectErr :: Fallible m => String -> m a throwVectErr msg = throwErr (Err MiscErr mempty msg) -prependCtxToErrs :: ErrCtx -> Errs -> Errs -prependCtxToErrs ctx (Errs errs) = - Errs $ map (\(Err ty ctx' msg) -> Err ty (ctx <> ctx') msg) errs +prependCtxToErr :: ErrCtx -> Err -> Err +prependCtxToErr ctx (Err ty ctx' msg) = Err ty (ctx <> ctx') msg askVectorByteWidth :: TopVectorizeM i o Word32 -askVectorByteWidth = TopVectorizeM $ SubstReaderT $ lift $ lift11 (fromLiftE <$> ask) +askVectorByteWidth = TopVectorizeM $ liftSubstReaderT $ lift11 (fromLiftE <$> ask) extendCommuteMap :: AtomName SimpIR o -> MonoidCommutes -> TopVectorizeM i o a -> TopVectorizeM i o a extendCommuteMap name commutativity = local $ insertNameMapE name $ LiftE commutativity @@ -186,11 +185,11 @@ vectorizeLoopsExpr expr = do seqOp <- mkSeq dir (IxType IdxRepTy (IxDictRawFin (IdxRepVal vn))) dest' body' return $ PrimOp $ DAMOp seqOp) else renameM expr) - `catchErr` \errs -> do + `catchErr` \err -> do let msg = "In `vectorizeLoopsDecls`:\nExpr:\n" ++ pprint expr ctx = mempty { messageCtx = [msg] } - errs' = prependCtxToErrs ctx errs - modify (<> LiftE errs') + err' = prependCtxToErr ctx err + modify (\(LiftE errs) -> LiftE (err':errs)) recurSeq expr _ -> recurSeq expr PrimOp (Hof (TypedHof _ (RunReader item (BinaryLamExpr hb' refb' body)))) -> do @@ -352,12 +351,12 @@ liftVectorizeM loopWidth action = do let fallible = flip runReaderT loopWidth act case runFallibleM fallible of Success a -> return a - Failure errs -> throwErrs errs -- re-raise inside ambient monad + Failure errs -> throwErr errs -- re-raise inside ambient monad where vSubst subst val = VRename $ subst ! val getLoopWidth :: VectorizeM i o Word32 -getLoopWidth = VectorizeM $ SubstReaderT $ ReaderT $ const $ ask +getLoopWidth = VectorizeM $ SubstReaderT $ const $ ask -- TODO When needed, can code a variant of this that also returns the Stability -- of the value returned by the LamExpr. From f46a6bb34bdc5186696144e9b72a04d6676c7ed1 Mon Sep 17 00:00:00 2001 From: Dougal Date: Tue, 17 Oct 2023 00:18:58 -0400 Subject: [PATCH 03/41] Move `Var` and `ProjectElt` into a separate data type, `Stuck`. This lets us add more cases to the list of "expression-like things that can appear in types" and it enforces normalization syntactically. I'm also hoping we can handle the stuck case generically in many places so that it's not burdensome to add another case to it. --- src/lib/Builder.hs | 127 ++++--- src/lib/CheapReduction.hs | 636 ++++++++++++------------------------ src/lib/CheckType.hs | 124 +++---- src/lib/Imp.hs | 26 +- src/lib/Inference.hs | 193 ++++++----- src/lib/Inline.hs | 12 +- src/lib/Linearize.hs | 22 +- src/lib/Lower.hs | 28 +- src/lib/OccAnalysis.hs | 12 +- src/lib/Optimize.hs | 7 +- src/lib/PPrint.hs | 48 +-- src/lib/QueryType.hs | 23 +- src/lib/QueryTypePure.hs | 26 +- src/lib/RuntimePrint.hs | 9 +- src/lib/Simplify.hs | 121 +++++-- src/lib/Subst.hs | 18 + src/lib/TopLevel.hs | 7 +- src/lib/Transpose.hs | 52 +-- src/lib/Types/Core.hs | 178 ++++++---- src/lib/Types/Primitives.hs | 2 + src/lib/Vectorize.hs | 25 +- 21 files changed, 822 insertions(+), 874 deletions(-) diff --git a/src/lib/Builder.hs b/src/lib/Builder.hs index 9319f41f4..ee41f1b4a 100644 --- a/src/lib/Builder.hs +++ b/src/lib/Builder.hs @@ -70,6 +70,10 @@ emit :: (Builder r m, Emits n) => Expr r n -> m n (AtomVar r n) emit expr = emitDecl noHint PlainLet expr {-# INLINE emit #-} +emitInline :: (Builder r m, Emits n) => Atom r n -> m n (AtomVar r n) +emitInline atom = emitDecl noHint InlineLet $ Atom atom +{-# INLINE emitInline #-} + emitHinted :: (Builder r m, Emits n) => NameHint -> Expr r n -> m n (AtomVar r n) emitHinted hint expr = emitDecl hint PlainLet expr {-# INLINE emitHinted #-} @@ -502,6 +506,8 @@ instance (IRRep r, Fallible m) => Builder r (BuilderT r m) where ty <- return $ getType expr v <- BuilderT $ freshExtendSubInplaceT hint \b -> (BuilderDeclEmission $ Let b $ DeclBinding ann expr, binderName b) + -- -- Debugging snippet + -- traceM $ pprint v ++ " = " ++ pprint expr return $ AtomVar v ty {-# INLINE rawEmitDecl #-} @@ -726,7 +732,7 @@ injectAltResult :: EnvReader m => [SType n] -> Int -> Alt SimpIR n -> m n (Alt S injectAltResult sumTys con (Abs b body) = liftBuilder do buildAlt (binderType b) \v -> do originalResult <- emitBlock =<< applySubst (b@>SubstVal (Var v)) body - (dataResult, nonDataResult) <- fromPair originalResult + (dataResult, nonDataResult) <- fromPairReduced originalResult return $ PairVal dataResult $ Con $ SumCon (sinkList sumTys) con nonDataResult -- TODO: consider a version with nonempty list of alternatives where we figure @@ -1010,22 +1016,16 @@ ieq :: (Builder r m, Emits n) => Atom r n -> Atom r n -> m n (Atom r n) ieq x@(Con (Lit _)) y@(Con (Lit _)) = return $ applyIntCmpOp (==) x y ieq x y = emitOp $ BinOp (ICmp Equal) x y -fromPair :: (Fallible1 m, EnvReader m, IRRep r) => Atom r n -> m n (Atom r n, Atom r n) +fromPair :: (Builder r m, Emits n) => Atom r n -> m n (Atom r n, Atom r n) fromPair pair = do getUnpacked pair >>= \case [x, y] -> return (x, y) _ -> error "expected a pair" -getFst :: Builder r m => Atom r n -> m n (Atom r n) -getFst p = fst <$> fromPair p - -getSnd :: Builder r m => Atom r n -> m n (Atom r n) -getSnd p = snd <$> fromPair p - -- the rightmost index is applied first -getNaryProjRef :: (Builder r m, Emits n) => [Projection] -> Atom r n -> m n (Atom r n) -getNaryProjRef [] ref = return ref -getNaryProjRef (i:is) ref = getProjRef i =<< getNaryProjRef is ref +applyProjectionsRef :: (Builder r m, Emits n) => [Projection] -> Atom r n -> m n (Atom r n) +applyProjectionsRef [] ref = return ref +applyProjectionsRef (i:is) ref = getProjRef i =<< applyProjectionsRef is ref getProjRef :: (Builder r m, Emits n) => Projection -> Atom r n -> m n (Atom r n) getProjRef i r = emitOp =<< mkProjRef r i @@ -1033,51 +1033,53 @@ getProjRef i r = emitOp =<< mkProjRef r i -- XXX: getUnpacked must reduce its argument to enforce the invariant that -- ProjectElt atoms are always fully reduced (to avoid type errors between two -- equivalent types spelled differently). -getUnpacked :: (Fallible1 m, EnvReader m, IRRep r) => Atom r n -> m n [Atom r n] -getUnpacked atom = do - atom' <- cheapNormalize atom - ty <- return $ getType atom' - positions <- case ty of - ProdTy tys -> return $ void tys - DepPairTy _ -> return [(), ()] - _ -> error $ "not a product type: " ++ pprint ty - forM (enumerate positions) \(i, _) -> - normalizeProj (ProjectProduct i) atom' +getUnpacked :: (Builder r m, Emits n) => Atom r n -> m n [Atom r n] +getUnpacked atom = forM (productIdxs atom) \i -> proj i atom {-# SCC getUnpacked #-} -getProj :: (Builder r m, Emits n) => Int -> Atom r n -> m n (Atom r n) -getProj i atom = do - atom' <- cheapNormalize atom - normalizeProj (ProjectProduct i) atom' - -emitUnpacked :: (Builder r m, Emits n) => Atom r n -> m n [AtomVar r n] -emitUnpacked tup = do - xs <- getUnpacked tup - forM xs \x -> emit $ Atom x +productIdxs :: IRRep r => Atom r n -> [Int] +productIdxs atom = + let positions = case getType atom of + ProdTy tys -> void tys + DepPairTy _ -> [(), ()] + ty -> error $ "not a product type: " ++ pprint ty + in fst <$> enumerate positions -unwrapNewtype :: EnvReader m => CAtom n -> m n (CAtom n) +unwrapNewtype :: (Emits n, Builder CoreIR m) => CAtom n -> m n (CAtom n) unwrapNewtype (NewtypeCon _ x) = return x unwrapNewtype x = case getType x of NewtypeTyCon con -> do (_, ty) <- unwrapNewtypeType con - return $ ProjectElt ty UnwrapNewtype x + emitExpr $ Unwrap ty x _ -> error "not a newtype" {-# INLINE unwrapNewtype #-} -projectTuple :: (IRRep r, EnvReader m) => Int -> Atom r n -> m n (Atom r n) -projectTuple i x = normalizeProj (ProjectProduct i) x +proj ::(Builder r m, Emits n) => Int -> Atom r n -> m n (Atom r n) +proj i = \case + ProdVal xs -> return $ xs !! i + DepPair l _ _ | i == 0 -> return l + DepPair _ r _ | i == 1 -> return r + x -> do + ty <- projType i x + emitExpr $ Project ty i x -projectStruct :: EnvReader m => Int -> CAtom n -> m n (CAtom n) +getFst :: (Builder r m, Emits n) => Atom r n -> m n (Atom r n) +getFst = proj 0 + +getSnd :: (Builder r m, Emits n) => Atom r n -> m n (Atom r n) +getSnd = proj 1 + +projectStruct :: (Builder CoreIR m, Emits n) => Int -> CAtom n -> m n (CAtom n) projectStruct i x = do projs <- getStructProjections i (getType x) - normalizeNaryProj projs x + applyProjections projs x {-# INLINE projectStruct #-} projectStructRef :: (Builder CoreIR m, Emits n) => Int -> CAtom n -> m n (CAtom n) projectStructRef i x = do RefTy _ valTy <- return $ getType x projs <- getStructProjections i valTy - getNaryProjRef projs x + applyProjectionsRef projs x {-# INLINE projectStructRef #-} getStructProjections :: EnvReader m => Int -> CType n -> m n [Projection] @@ -1089,6 +1091,24 @@ getStructProjections i (NewtypeTyCon (UserADTType _ tyConName _)) = do _ -> [ProjectProduct i, UnwrapNewtype] getStructProjections _ _ = error "not a struct" +-- the rightmost index is applied first +applyProjections :: (Builder CoreIR m, Emits n) => [Projection] -> CAtom n -> m n (CAtom n) +applyProjections [] x = return x +applyProjections (p:ps) x = do + x' <- applyProjections ps x + case p of + ProjectProduct i -> proj i x' + UnwrapNewtype -> unwrapNewtype x' + +-- the rightmost index is applied first +applyProjectionsReduced :: EnvReader m => [Projection] -> CAtom n -> m n (CAtom n) +applyProjectionsReduced [] x = return x +applyProjectionsReduced (p:ps) x = do + x' <- applyProjectionsReduced ps x + case p of + ProjectProduct i -> reduceProj i x' + UnwrapNewtype -> reduceUnwrap x' + mkApp :: EnvReader m => CAtom n -> [CAtom n] -> m n (CExpr n) mkApp f xs = do et <- appEffTy (getType f) xs @@ -1109,10 +1129,23 @@ mkApplyMethod d i xs = do resultTy <- typeOfApplyMethod d i xs return $ ApplyMethod resultTy d i xs -mkDictAtom :: EnvReader m => DictExpr n -> m n (CAtom n) -mkDictAtom d = do - ty <- typeOfDictExpr d - return $ DictCon ty d +mkIxFin :: (EnvReader m, Fallible1 m) => CAtom n -> m n (DictCon n) +mkIxFin n = do + dictTy <- liftM DictTy $ ixDictType $ NewtypeTyCon $ Fin n + return $ IxFin dictTy n + +mkDataData :: (EnvReader m, Fallible1 m) => CType n -> m n (DictCon n) +mkDataData dataTy = do + dictTy <- DictTy <$> dataDictType dataTy + return $ DataData dictTy dataTy + +mkInstanceDict :: EnvReader m => InstanceName n -> [CAtom n] -> m n (DictCon n) +mkInstanceDict instanceName args = do + instanceDef@(InstanceDef className _ _ _ _) <- lookupInstanceDef instanceName + sourceName <- getSourceName <$> lookupClassDef className + PairE (ListE params) _ <- instantiate instanceDef args + let ty = DictTy $ DictType sourceName className params + return $ InstanceDict ty instanceName args mkCase :: (EnvReader m, IRRep r) => Atom r n -> Type r n -> [Alt r n] -> m n (Expr r n) mkCase scrut resultTy alts = liftEnvReaderM do @@ -1456,14 +1489,20 @@ unpackTelescope (ReconBinders tyTop _) xTop = go tyTop xTop where go :: (Fallible1 m, EnvReader m, IRRep r) => TelescopeType c e l-> Atom r n -> m n [Atom r n] go ty x = case ty of - ProdTelescope _ -> getUnpacked x + ProdTelescope _ -> getUnpackedReduced x DepTelescope ty1 (Abs _ ty2) -> do - (x1, xPair) <- fromPair x - (xDep, x2) <- fromPair xPair + (x1, xPair) <- fromPairReduced x + (xDep, x2) <- fromPairReduced xPair xs1 <- go ty1 x1 xs2 <- go ty2 x2 return $ xs1 ++ (xDep : xs2) +fromPairReduced :: (Fallible1 m, EnvReader m, IRRep r) => Atom r n -> m n (Atom r n, Atom r n) +fromPairReduced pair = (,) <$> reduceProj 0 pair <*> reduceProj 1 pair + +getUnpackedReduced :: (Fallible1 m, EnvReader m, IRRep r) => Atom r n -> m n [Atom r n] +getUnpackedReduced xs = forM (productIdxs xs) \i -> reduceProj i xs + -- sorts name-annotation pairs so that earlier names may be occur free in later -- annotations but not vice versa. type AnnVar c e n = (Name c n, e n) diff --git a/src/lib/CheapReduction.hs b/src/lib/CheapReduction.hs index 99559bc17..0f850166f 100644 --- a/src/lib/CheapReduction.hs +++ b/src/lib/CheapReduction.hs @@ -8,32 +8,26 @@ {-# OPTIONS_GHC -Wno-orphans #-} module CheapReduction - ( CheaplyReducibleE (..), cheapReduce, cheapReduceWithDecls, cheapNormalize - , normalizeProj, asNaryProj, normalizeNaryProj - , depPairLeftTy, instantiateTyConDef - , dataDefRep, unwrapNewtypeType, repValAtom - , unwrapLeadingNewtypesType, wrapNewtypesData, liftSimpAtom, liftSimpType - , liftSimpFun, makeStructRepVal, NonAtomRenamer (..), Visitor (..), VisitGeneric (..) + ( reduceWithDecls, reduceExpr, reduceBlock + , instantiateTyConDef, dataDefRep, unwrapNewtypeType + , NonAtomRenamer (..), Visitor (..), VisitGeneric (..) , visitAtomPartial, visitTypePartial, visitAtomDefault, visitTypeDefault, Visitor2 , visitBinders, visitPiDefault, visitAlt, toAtomVar, instantiate, withInstantiated - , bindersToVars, bindersToAtoms, instantiateNames, withInstantiatedNames, assumeConst) + , bindersToVars, bindersToAtoms, instantiateNames, withInstantiatedNames, assumeConst + , repValAtom, projType, reduceUnwrap, reduceProj, reduceSuperclassProj + , reduceInstantiateGiven, typeOfApp) where import Control.Applicative import Control.Monad.Writer.Strict hiding (Alt) -import Control.Monad.State.Strict import Control.Monad.Reader -import Data.Foldable (toList) -import Data.Functor.Identity import Data.Functor ((<&>)) -import qualified Data.List.NonEmpty as NE -import qualified Data.Map.Strict as M +import Data.Maybe (fromJust) import Subst import Core import Err import IRVariants -import MTL1 import Name import PPrint () import QueryTypePure @@ -52,315 +46,153 @@ import Util -- === api === -type NiceE r e = (HoistableE e, SinkableE e, SubstE AtomSubstVal e, RenameE e, IRRep r) - -cheapReduce :: forall r e' e m n - . (EnvReader m, CheaplyReducibleE r e e', NiceE r e, NiceE r e') - => e n -> m n (Maybe (e' n)) -cheapReduce e = liftCheapReducerM idSubst $ cheapReduceE e -{-# INLINE cheapReduce #-} -{-# SCC cheapReduce #-} - -cheapReduceWithDecls - :: forall r e' e m n l - . ( CheaplyReducibleE r e e', NiceE r e', NiceE r e, EnvReader m ) - => Nest (Decl r) n l -> e l -> m n (Maybe (e' n)) -cheapReduceWithDecls decls result = do - Abs decls' result' <- sinkM $ Abs decls result - liftCheapReducerM idSubst $ - cheapReduceWithDeclsB decls' $ - cheapReduceE result' -{-# INLINE cheapReduceWithDecls #-} -{-# SCC cheapReduceWithDecls #-} - -cheapNormalize :: (EnvReader m, CheaplyReducibleE r e e, NiceE r e) => e n -> m n (e n) -cheapNormalize a = cheapReduce a >>= \case - Just ans -> return ans - _ -> error "couldn't normalize expression" -{-# INLINE cheapNormalize #-} +reduceWithDecls + :: (IRRep r, HasNamesE e, SubstE AtomSubstVal e, EnvReader m) + => WithDecls r e n -> m n (Maybe (e n)) +reduceWithDecls (Abs decls e) = + liftReducerM $ reduceWithDeclsM decls $ substM e + +reduceExpr :: (IRRep r, EnvReader m) => Expr r n -> m n (Maybe (Atom r n)) +reduceExpr e = liftReducerM $ reduceExprM e +{-# INLINE reduceExpr #-} + +reduceBlock :: (IRRep r, EnvReader m) => Block r n -> m n (Maybe (Atom r n)) +reduceBlock e = liftReducerM $ reduceBlockM e +{-# INLINE reduceBlock #-} + +reduceProj :: (IRRep r, EnvReader m) => Int -> Atom r n -> m n (Atom r n) +reduceProj i x = liftM fromJust $ liftReducerM $ reduceProjM i x +{-# INLINE reduceProj #-} + +reduceUnwrap :: (IRRep r, EnvReader m) => Atom r n -> m n (Atom r n) +reduceUnwrap x = liftM fromJust $ liftReducerM $ reduceUnwrapM x +{-# INLINE reduceUnwrap #-} + +reduceSuperclassProj :: EnvReader m => Int -> CAtom n -> m n (CAtom n) +reduceSuperclassProj i x = liftM fromJust $ liftReducerM $ reduceSuperclassProjM i x +{-# INLINE reduceSuperclassProj #-} + +reduceInstantiateGiven :: EnvReader m => CAtom n -> [CAtom n] -> m n (CAtom n) +reduceInstantiateGiven f xs = liftM fromJust $ liftReducerM $ reduceInstantiateGivenM f xs +{-# INLINE reduceInstantiateGiven #-} -- === internal === -newtype CheapReducerM (r::IR) (i :: S) (o :: S) (a :: *) = - CheapReducerM - (SubstReaderT AtomSubstVal - (MaybeT1 - (ScopedT1 (MapE (AtomName r) (MaybeE (Atom r))) - (EnvReaderT Identity))) i o a) - deriving (Functor, Applicative, Monad, Alternative) - -deriving instance IRRep r => ScopeReader (CheapReducerM r i) -deriving instance IRRep r => EnvReader (CheapReducerM r i) -deriving instance IRRep r => EnvExtender (CheapReducerM r i) -deriving instance IRRep r => SubstReader AtomSubstVal (CheapReducerM r) - -class ( Alternative2 m, SubstReader AtomSubstVal m - , EnvReader2 m, EnvExtender2 m) => CheapReducer m r | m -> r where - updateCache :: AtomName r o -> Maybe (Atom r o) -> m i o () - lookupCache :: AtomName r o -> m i o (Maybe (Maybe (Atom r o))) - -instance IRRep r => CheapReducer (CheapReducerM r) r where - updateCache v u = CheapReducerM $ liftSubstReaderT $ lift11 $ - modify (MapE . M.insert v (toMaybeE u) . fromMapE) - lookupCache v = CheapReducerM $ liftSubstReaderT $ lift11 $ - fmap fromMaybeE <$> gets (M.lookup v . fromMapE) - -liftCheapReducerM - :: (EnvReader m, IRRep r) - => Subst AtomSubstVal i o -> CheapReducerM r i o a - -> m o (Maybe a) -liftCheapReducerM subst (CheapReducerM m) = do - liftM runIdentity $ liftEnvReaderT $ runScopedT1 - (runMaybeT1 $ runSubstReaderT subst m) mempty -{-# INLINE liftCheapReducerM #-} - -cheapReduceWithDeclsB - :: NiceE r e - => Nest (Decl r) i i' - -> (forall o'. Ext o o' => CheapReducerM r i' o' (e o')) - -> CheapReducerM r i o (e o) -cheapReduceWithDeclsB decls cont = do - Abs irreducibleDecls result <- cheapReduceWithDeclsRec decls cont - case hoist irreducibleDecls result of - HoistSuccess result' -> return result' - HoistFailure _ -> empty - -cheapReduceWithDeclsRec - :: NiceE r e - => Nest (Decl r) i i' - -> (forall o'. Ext o o' => CheapReducerM r i' o' (e o')) - -> CheapReducerM r i o (Abs (Nest (Decl r)) e o) -cheapReduceWithDeclsRec decls cont = case decls of - Empty -> Abs Empty <$> cont - Nest (Let b binding@(DeclBinding _ expr)) rest -> do - optional (cheapReduceE expr) >>= \case - Nothing -> do - binding' <- substM binding - withFreshBinder (getNameHint b) binding' \(b':>_) -> do - updateCache (binderName b') Nothing - extendSubst (b@>Rename (binderName b')) do - Abs decls' result <- cheapReduceWithDeclsRec rest cont - return $ Abs (Nest (Let b' binding') decls') result - Just x -> - extendSubst (b@>SubstVal x) $ - cheapReduceWithDeclsRec rest cont - -cheapReduceName :: forall c r i o . (IRRep r, Color c) => Name c o -> CheapReducerM r i o (AtomSubstVal c o) -cheapReduceName v = - case eqColorRep @c @(AtomNameC r) of - Just ColorsEqual -> - lookupEnv v >>= \case - AtomNameBinding binding -> cheapReduceAtomBinding v binding - Nothing -> stuck - where stuck = return $ Rename v - -cheapReduceAtomBinding - :: forall r i o. IRRep r - => AtomName r o -> AtomBinding r o -> CheapReducerM r i o (AtomSubstVal (AtomNameC r) o) -cheapReduceAtomBinding v = \case - LetBound (DeclBinding _ e) -> do - cachedVal <- lookupCache v >>= \case - Nothing -> do - result <- optional (dropSubst $ cheapReduceE e) - updateCache v result - return result - Just result -> return result - case cachedVal of - Nothing -> stuck - Just ans -> return $ SubstVal ans - _ -> stuck - where stuck = return $ Rename v - -class CheaplyReducibleE (r::IR) (e::E) (e'::E) | e -> e', e -> r where - cheapReduceE :: e i -> CheapReducerM r i o (e' o) - -instance IRRep r => CheaplyReducibleE r (Atom r) (Atom r) where - cheapReduceE :: forall i o. Atom r i -> CheapReducerM r i o (Atom r o) - cheapReduceE a = confuseGHC >>= \_ -> case a of - -- Don't try to eagerly reduce lambda bodies. We might get stuck long before - -- we have a chance to apply tham. Also, recursive traversal of those bodies - -- means that we will follow the full call chain, so it's really expensive! - -- TODO: we don't collect the dict holes here, so there's a danger of - -- dropping them if they turn out to be phantom. - Lam _ -> substM a - -- We traverse the Atom constructors that might contain lambda expressions - -- explicitly, to make sure that we can skip normalizing free vars inside those. - Con con -> Con <$> traverseOp con cheapReduceE cheapReduceE (error "unexpected lambda") - DictCon t d -> do - t' <- cheapReduceE t - cheapReduceDictExpr t' d - SimpInCore (LiftSimp t x) -> do - t' <- cheapReduceE t - x' <- substM x - liftSimpAtom t' x' - -- These two are a special-case hack. TODO(dougalm): write a traversal over - -- the NewtypeTyCon (or types generally) - NewtypeCon NatCon n -> NewtypeCon NatCon <$> cheapReduceE n - -- Do recursive reduction via substitution - -- TODO: we don't collect the dict holes here, so there's a danger of - -- dropping them if they turn out to be phantom. - _ -> do - a' <- substM a - dropSubst $ traverseNames cheapReduceName a' - -instance IRRep r => CheaplyReducibleE r (Type r) (Type r) where - cheapReduceE :: forall i o. Type r i -> CheapReducerM r i o (Type r o) - cheapReduceE a = case a of - -- Don't try to eagerly reduce lambda bodies. We might get stuck long before - -- we have a chance to apply tham. Also, recursive traversal of those bodies - -- means that we will follow the full call chain, so it's really expensive! - -- TODO: we don't collect the dict holes here, so there's a danger of - -- dropping them if they turn out to be phantom. - TabPi (TabPiType d (b:>t) resultTy) -> do - t' <- cheapReduceE t - d' <- cheapReduceE d - withFreshBinder (getNameHint b) t' \b' -> do - resultTy' <- extendSubst (b@>Rename (binderName b')) $ cheapReduceE resultTy - return $ TabPi $ TabPiType d' b' resultTy' - -- We traverse the Atom constructors that might contain lambda expressions - -- explicitly, to make sure that we can skip normalizing free vars inside those. - NewtypeTyCon (Fin n) -> NewtypeTyCon . Fin <$> cheapReduceE n - -- Do recursive reduction via substitution - -- TODO: we don't collect the dict holes here, so there's a danger of - -- dropping them if they turn out to be phantom. - _ -> do - a' <- substM a - dropSubst $ traverseNames cheapReduceName a' - -cheapReduceDictExpr :: CType o -> DictExpr i -> CheapReducerM CoreIR i o (CAtom o) -cheapReduceDictExpr resultTy d = case d of - SuperclassProj child superclassIx -> do - cheapReduceE child >>= \case - DictCon _ (InstanceDict instanceName args) -> dropSubst do - args' <- mapM cheapReduceE args - InstanceDef _ _ bs _ body <- lookupInstanceDef instanceName - let InstanceBody superclasses _ = body - instantiate (Abs bs (superclasses !! superclassIx)) args' - child' -> return $ DictCon resultTy $ SuperclassProj child' superclassIx - InstantiatedGiven f xs -> - reduceApp <|> justSubst - where reduceApp = do - f' <- cheapReduceE f - xs' <- mapM cheapReduceE (toList xs) - cheapReduceApp f' xs' - InstanceDict _ _ -> justSubst - IxFin _ -> justSubst - DataData ty -> DictCon resultTy . DataData <$> cheapReduceE ty - where justSubst = DictCon resultTy <$> substM d - -instance CheaplyReducibleE CoreIR TyConParams TyConParams where - cheapReduceE (TyConParams infs ps) = - TyConParams infs <$> mapM cheapReduceE ps - -instance (CheaplyReducibleE r e e', NiceE r e') => CheaplyReducibleE r (Abs (Nest (Decl r)) e) e' where - cheapReduceE (Abs decls result) = cheapReduceWithDeclsB decls $ cheapReduceE result - -instance IRRep r => CheaplyReducibleE r (Expr r) (Atom r) where - cheapReduceE expr = confuseGHC >>= \_ -> case expr of - Atom atom -> cheapReduceE atom - App _ f' xs' -> do - xs <- mapM cheapReduceE xs' - f <- cheapReduceE f' - cheapReduceApp f xs - -- TODO: Make sure that this wraps correctly - -- TODO: Other casts? - PrimOp (MiscOp (CastOp ty' val')) -> do - ty <- cheapReduceE ty' - case ty of - BaseTy (Scalar Word32Type) -> do - val <- cheapReduceE val' - case val of - Con (Lit (Word64Lit v)) -> return $ Con $ Lit $ Word32Lit $ fromIntegral v - _ -> empty - _ -> empty - ApplyMethod _ dict i explicitArgs -> do - explicitArgs' <- mapM cheapReduceE explicitArgs - cheapReduceE dict >>= \case - DictCon _ (InstanceDict instanceName args) -> dropSubst do - args' <- mapM cheapReduceE args - def <- lookupInstanceDef instanceName - withInstantiated def args' \(PairE _ (InstanceBody _ methods)) -> do - method' <- cheapReduceE $ methods !! i - cheapReduceApp method' explicitArgs' - _ -> empty +type ReducerM = SubstReaderT AtomSubstVal (EnvReaderT FallibleM) + +liftReducerM :: EnvReader m => ReducerM n n a -> m n (Maybe a) +liftReducerM cont = do + liftM (ignoreExcept . runFallibleM) $ liftEnvReaderT $ runSubstReaderT idSubst do + (Just <$> cont) <|> return Nothing + +reduceWithDeclsM :: IRRep r => Nest (Decl r) i i' -> ReducerM i' o a -> ReducerM i o a +reduceWithDeclsM Empty cont = cont +reduceWithDeclsM (Nest (Let b (DeclBinding _ expr)) rest) cont = do + x <- reduceExprM expr + extendSubst (b@>SubstVal x) $ reduceWithDeclsM rest cont + +reduceBlockM :: IRRep r => Block r i -> ReducerM i o (Atom r o) +reduceBlockM (Abs decls x) = reduceWithDeclsM decls $ substM x + +reduceExprM :: IRRep r => Expr r i -> ReducerM i o (Atom r o) +reduceExprM = \case + Atom x -> substM x + App _ f xs -> mapM substM xs >>= reduceApp f + Unwrap _ x -> substM x >>= reduceUnwrapM + Project _ i x -> substM x >>= reduceProjM i + ApplyMethod _ dict i explicitArgs -> do + explicitArgs' <- mapM substM explicitArgs + dict' <- substM dict + case dict' of + DictCon (InstanceDict _ instanceName args) -> dropSubst do + def <- lookupInstanceDef instanceName + withInstantiated def args \(PairE _ (InstanceBody _ methods)) -> do + reduceApp (methods !! i) explicitArgs' + _ -> empty + PrimOp (MiscOp (CastOp ty' val')) -> do + ty <- substM ty' + val <- substM val' + case (ty, val) of + (BaseTy (Scalar Word32Type), Con (Lit (Word64Lit v))) -> return $ Con $ Lit $ Word32Lit $ fromIntegral v + _ -> empty + TopApp _ _ _ -> empty + TabApp _ _ _ -> empty + Case _ _ _ -> empty + TabCon _ _ _ -> empty + PrimOp _ -> empty + +reduceApp :: CAtom i -> [CAtom o] -> ReducerM i o (CAtom o) +reduceApp f xs = do + f' <- substM f -- TODO: avoid double-subst + case f' of + Lam lam -> dropSubst $ withInstantiated lam xs \body -> reduceBlockM body + -- TODO: check ultrapure + Var v -> lookupAtomName (atomVarName v) >>= \case + LetBound (DeclBinding _ (Atom f'')) -> dropSubst $ reduceApp f'' xs + _ -> empty _ -> empty -cheapReduceApp :: CAtom o -> [CAtom o] -> CheapReducerM CoreIR i o (CAtom o) -cheapReduceApp f xs = case f of - Lam lam -> dropSubst $ withInstantiated lam xs \body -> cheapReduceE body - _ -> empty - -instance IRRep r => CheaplyReducibleE r (IxType r) (IxType r) where - cheapReduceE (IxType t d) = IxType <$> cheapReduceE t <*> cheapReduceE d - -instance IRRep r => CheaplyReducibleE r (IxDict r) (IxDict r) where - cheapReduceE = \case - IxDictAtom x -> IxDictAtom <$> cheapReduceE x - IxDictRawFin n -> IxDictRawFin <$> cheapReduceE n - IxDictSpecialized t d xs -> - IxDictSpecialized <$> cheapReduceE t <*> substM d <*> mapM cheapReduceE xs - -instance (CheaplyReducibleE r e1 e1', CheaplyReducibleE r e2 e2') - => CheaplyReducibleE r (PairE e1 e2) (PairE e1' e2') where - cheapReduceE (PairE e1 e2) = PairE <$> cheapReduceE e1 <*> cheapReduceE e2 - -instance (CheaplyReducibleE r e1 e1', CheaplyReducibleE r e2 e2') - => CheaplyReducibleE r (EitherE e1 e2) (EitherE e1' e2') where - cheapReduceE (LeftE e) = LeftE <$> cheapReduceE e - cheapReduceE (RightE e) = RightE <$> cheapReduceE e - -instance CheaplyReducibleE r e e' => CheaplyReducibleE r (ListE e) (ListE e') where - cheapReduceE (ListE xs) = ListE <$> mapM cheapReduceE xs - --- XXX: TODO: figure out exactly what our normalization invariants are. We --- shouldn't have to choose `normalizeProj` or `asNaryProj` on a --- case-by-case basis. This is here for now because it makes it easier to switch --- to the new version of `ProjectElt`. -asNaryProj :: IRRep r => Projection -> Atom r n -> (NE.NonEmpty Projection, AtomVar r n) -asNaryProj p (Var v) = (p NE.:| [], v) -asNaryProj p1 (ProjectElt _ p2 x) = do - let (p2' NE.:| ps, v) = asNaryProj p2 x - (p1 NE.:| (p2':ps), v) -asNaryProj p x = error $ "Can't normalize projection: " ++ pprint p ++ " " ++ pprint x - --- assumes the atom is already normalized -normalizeNaryProj :: IRRep r => EnvReader m => [Projection] -> Atom r n -> m n (Atom r n) -normalizeNaryProj [] x = return x -normalizeNaryProj (i:is) x = normalizeProj i =<< normalizeNaryProj is x - --- assumes the atom itself is already normalized -normalizeProj :: IRRep r => EnvReader m => Projection -> Atom r n -> m n (Atom r n) -normalizeProj UnwrapNewtype atom = case atom of - NewtypeCon _ x -> return x - SimpInCore (LiftSimp (NewtypeTyCon t) x) -> do - t' <- snd <$> unwrapNewtypeType t - return $ SimpInCore $ LiftSimp t' x - x -> case getType x of - NewtypeTyCon t -> do - t' <- snd <$> unwrapNewtypeType t - return $ ProjectElt t' UnwrapNewtype x - _ -> error "expected a newtype" -normalizeProj (ProjectProduct i) atom = case atom of - Con (ProdCon xs) -> return $ xs !! i +reduceProjM :: IRRep r => Int -> Atom r o -> ReducerM i o (Atom r o) +reduceProjM i x = case x of + ProdVal xs -> return $ xs !! i DepPair l _ _ | i == 0 -> return l DepPair _ r _ | i == 1 -> return r - SimpInCore (LiftSimp _ x) -> do - x' <- normalizeProj (ProjectProduct i) x - resultTy <- getResultTy - return $ SimpInCore $ LiftSimp resultTy x' + SimpInCore (LiftSimp _ simpAtom) -> do + simpAtom' <- dropSubst $ reduceProjM i simpAtom + resultTy <- getResultType + return $ SimpInCore $ LiftSimp resultTy simpAtom' RepValAtom (RepVal _ tree) -> case tree of Branch trees -> do - resultTy <- getResultTy + resultTy <- getResultType repValAtom $ RepVal resultTy (trees!!i) Leaf _ -> error "unexpected leaf" - _ -> do - resultTy <- getResultTy - return $ ProjectElt resultTy (ProjectProduct i) atom - where - getResultTy = projType i (getType atom) atom -{-# INLINE normalizeProj #-} - --- === lifting imp to simp and simp to core === + Stuck e -> do + resultTy <- getResultType + return $ Stuck $ StuckProject resultTy i e + _ -> empty + where getResultType = projType i x + +reduceSuperclassProjM :: Int -> CAtom o -> ReducerM i o (CAtom o) +reduceSuperclassProjM superclassIx dict = case dict of + DictCon (InstanceDict _ instanceName args) -> dropSubst do + args' <- mapM substM args + InstanceDef _ _ bs _ body <- lookupInstanceDef instanceName + let InstanceBody superclasses _ = body + instantiate (Abs bs (superclasses !! superclassIx)) args' + Stuck child' -> do + resultTy <- superclassProjType superclassIx (getType dict) + return $ Stuck $ SuperclassProj resultTy superclassIx child' + _ -> error "invalid superclass projection" + +reduceInstantiateGivenM :: CAtom o -> [CAtom o] -> ReducerM i o (CAtom o) +reduceInstantiateGivenM f xs = case f of + Lam lam -> dropSubst $ withInstantiated lam xs \body -> reduceBlockM body + Stuck f' -> do + resultTy <- typeOfApp (getType f) xs + return $ Stuck $ InstantiatedGiven resultTy f' xs + _ -> error "bad instantiation" + +projType :: (IRRep r, EnvReader m) => Int -> Atom r n -> m n (Type r n) +projType i x = case getType x of + ProdTy xs -> return $ xs !! i + DepPairTy t | i == 0 -> return $ depPairLeftTy t + DepPairTy t | i == 1 -> do + liftReducerM (reduceProjM 0 x) >>= \case + Just xFst -> instantiate t [xFst] + Nothing -> err + _ -> err + where err = error $ "Can't project type: " ++ pprint (getType x) + +superclassProjType :: EnvReader m => Int -> CType n -> m n (CType n) +superclassProjType i (DictTy (DictType _ className params)) = do + ClassDef _ _ _ _ bs superclasses _ <- lookupClassDef className + instantiate (Abs bs $ getSuperclassType REmpty superclasses i) params +superclassProjType _ _ = error "bad superclass projection" + +typeOfApp :: (IRRep r, EnvReader m) => Type r n -> [Atom r n] -> m n (Type r n) +typeOfApp (Pi piTy) xs = withSubstReaderT $ + withInstantiated piTy xs \(EffTy _ ty) -> substM ty +typeOfApp _ _ = error "expected a pi type" repValAtom :: EnvReader m => SRepVal n -> m n (SAtom n) repValAtom (RepVal ty tree) = case ty of @@ -377,55 +209,23 @@ repValAtom (RepVal ty tree) = case ty of malformed = error "malformed repval" {-# INLINE repValAtom #-} -liftSimpType :: EnvReader m => SType n -> m n (CType n) -liftSimpType = \case - BaseTy t -> return $ BaseTy t - ProdTy ts -> ProdTy <$> mapM rec ts - SumTy ts -> SumTy <$> mapM rec ts - t -> error $ "not implemented: " ++ pprint t - where rec = liftSimpType -{-# INLINE liftSimpType #-} - -liftSimpAtom :: EnvReader m => Type CoreIR n -> SAtom n -> m n (CAtom n) -liftSimpAtom ty simpAtom = case simpAtom of - Var _ -> justLift - ProjectElt _ _ _ -> justLift - RepValAtom _ -> justLift -- TODO(dougalm): should we make more effort to pull out products etc? - _ -> do - (cons , ty') <- unwrapLeadingNewtypesType ty - atom <- case (ty', simpAtom) of - (BaseTy _ , Con (Lit v)) -> return $ Con $ Lit v - (ProdTy tys, Con (ProdCon xs)) -> Con . ProdCon <$> zipWithM rec tys xs - (SumTy tys, Con (SumCon _ i x)) -> Con . SumCon tys i <$> rec (tys!!i) x - (DepPairTy dpt@(DepPairType _ (b:>t1) t2), DepPair x1 x2 _) -> do - x1' <- rec t1 x1 - t2' <- applySubst (b@>SubstVal x1') t2 - x2' <- rec t2' x2 - return $ DepPair x1' x2' dpt - _ -> error $ "can't lift " <> pprint simpAtom <> " to " <> pprint ty' - return $ wrapNewtypesData cons atom - where - rec = liftSimpAtom - justLift = return $ SimpInCore $ LiftSimp ty simpAtom -{-# INLINE liftSimpAtom #-} - -liftSimpFun :: EnvReader m => Type CoreIR n -> LamExpr SimpIR n -> m n (CAtom n) -liftSimpFun (Pi piTy) f = return $ SimpInCore $ LiftSimpFun piTy f -liftSimpFun _ _ = error "not a pi type" - --- See Note [Confuse GHC] from Simplify.hs -confuseGHC :: IRRep r => CheapReducerM r i n (DistinctEvidence n) -confuseGHC = getDistinct -{-# INLINE confuseGHC #-} - --- TODO: These used to live in QueryType. Think about a better way to organize --- them. Maybe a common set of low-level type-querying utils that both --- CheapReduction and QueryType import? - depPairLeftTy :: DepPairType r n -> Type r n depPairLeftTy (DepPairType _ (_:>ty) _) = ty {-# INLINE depPairLeftTy #-} +reduceUnwrapM :: IRRep r => Atom r o -> ReducerM i o (Atom r o) +reduceUnwrapM = \case + NewtypeCon _ x -> return x + SimpInCore (LiftSimp (NewtypeTyCon t) x) -> do + t' <- snd <$> unwrapNewtypeType t + return $ SimpInCore $ LiftSimp t' x + Stuck e -> case getType e of + NewtypeTyCon t -> do + t' <- snd <$> unwrapNewtypeType t + return $ Stuck $ StuckUnwrap t' e + _ -> error "expected a newtype" + _ -> empty + unwrapNewtypeType :: EnvReader m => NewtypeTyCon n -> m n (NewtypeCon n, Type CoreIR n) unwrapNewtypeType = \case Nat -> return (NatCon, IdxRepTy) @@ -437,27 +237,6 @@ unwrapNewtypeType = \case ty -> error $ "Shouldn't be projecting: " ++ pprint ty {-# INLINE unwrapNewtypeType #-} -projType :: (IRRep r, EnvReader m) => Int -> Type r n -> Atom r n -> m n (Type r n) -projType i ty x = case ty of - ProdTy xs -> return $ xs !! i - DepPairTy t | i == 0 -> return $ depPairLeftTy t - DepPairTy t | i == 1 -> do - xFst <- normalizeProj (ProjectProduct 0) x - instantiate t [xFst] - _ -> error $ "Can't project type: " ++ pprint ty - -unwrapLeadingNewtypesType :: EnvReader m => CType n -> m n ([NewtypeCon n], CType n) -unwrapLeadingNewtypesType = \case - NewtypeTyCon tyCon -> do - (dataCon, ty) <- unwrapNewtypeType tyCon - (dataCons, ty') <- unwrapLeadingNewtypesType ty - return (dataCon:dataCons, ty') - ty -> return ([], ty) - -wrapNewtypesData :: [NewtypeCon n] -> CAtom n-> CAtom n -wrapNewtypesData [] x = x -wrapNewtypesData (c:cs) x = NewtypeCon c $ wrapNewtypesData cs x - instantiateTyConDef :: EnvReader m => TyConDef n -> TyConParams n -> m n (DataConDefs n) instantiateTyConDef (TyConDef _ _ bs conDefs) (TyConParams _ xs) = do applySubst (bs @@> (SubstVal <$> xs)) conDefs @@ -508,15 +287,6 @@ dataDefRep (StructFields fields) = case map snd fields of [ty] -> ty tys -> ProdTy tys -makeStructRepVal :: (Fallible1 m, EnvReader m) => TyConName n -> [CAtom n] -> m n (CAtom n) -makeStructRepVal tyConName args = do - TyConDef _ _ _ (StructFields fields) <- lookupTyCon tyConName - case fields of - [_] -> case args of - [arg] -> return arg - _ -> error "wrong number of args" - _ -> return $ ProdVal args - -- === traversable terms === class Monad m => NonAtomRenamer m i o | m -> i, m -> o where @@ -558,17 +328,15 @@ visitAtomDefault :: (IRRep r, Visitor (m i o) r i o, AtomSubstReader v m, EnvReader2 m) => Atom r i -> m i o (Atom r o) visitAtomDefault atom = case atom of - Var _ -> atomSubstM atom + Stuck _ -> atomSubstM atom SimpInCore _ -> atomSubstM atom - ProjectElt t i x -> ProjectElt <$> visitType t <*> pure i <*> visitGeneric x _ -> visitAtomPartial atom visitTypeDefault :: (IRRep r, Visitor (m i o) r i o, AtomSubstReader v m, EnvReader2 m) => Type r i -> m i o (Type r o) -visitTypeDefault = \case - TyVar v -> atomSubstM $ TyVar v - ProjectEltTy t i x -> ProjectEltTy <$> visitType t <*> pure i <*> visitGeneric x +visitTypeDefault ty = case ty of + StuckTy _ -> atomSubstM ty x -> visitTypePartial x visitPiDefault @@ -592,15 +360,13 @@ visitBinders (Nest (b:>ty) bs) cont = do visitBinders bs \bs' -> cont $ Nest b' bs' --- XXX: This doesn't handle the `Var`, `ProjectElt`, `SimpInCore` cases. These --- should be handled explicitly beforehand. TODO: split out these cases under a --- separate constructor, perhaps even a `hole` paremeter to `Atom` or part of --- `IR`. +-- XXX: This doesn't handle the `Stuck` or `SimpInCore` cases. These should be +-- handled explicitly beforehand. TODO: split out these cases under a separate +-- constructor, perhaps even a `hole` paremeter to `Atom` or part of `IR`. visitAtomPartial :: (IRRep r, Visitor m r i o) => Atom r i -> m (Atom r o) visitAtomPartial = \case - Var _ -> error "Not handled generically" + Stuck _ -> error "Not handled generically" SimpInCore _ -> error "Not handled generically" - ProjectElt _ _ _ -> error "Not handled generically" Con con -> Con <$> visitGeneric con PtrVar t v -> PtrVar t <$> renameN v DepPair x y t -> do @@ -610,17 +376,16 @@ visitAtomPartial = \case return $ DepPair x' y' t' Lam lam -> Lam <$> visitGeneric lam Eff eff -> Eff <$> visitGeneric eff - DictCon t d -> DictCon <$> visitType t <*> visitGeneric d + DictCon d -> DictCon <$> visitGeneric d NewtypeCon con x -> NewtypeCon <$> visitGeneric con <*> visitGeneric x TypeAsAtom t -> TypeAsAtom <$> visitGeneric t RepValAtom repVal -> RepValAtom <$> visitGeneric repVal --- XXX: This doesn't handle the `TyVar` or `ProjectEltTy` cases. These should be --- handled explicitly beforehand. +-- XXX: This doesn't handle the `Stuck` case. It should be handled explicitly +-- beforehand. visitTypePartial :: (IRRep r, Visitor m r i o) => Type r i -> m (Type r o) visitTypePartial = \case - TyVar _ -> error "Not handled generically" - ProjectEltTy _ _ _ -> error "Not handled generically" + StuckTy _ -> error "Not handled generically" NewtypeTyCon t -> NewtypeTyCon <$> visitGeneric t Pi t -> Pi <$> visitGeneric t TabPi t -> TabPi <$> visitGeneric t @@ -646,6 +411,8 @@ instance IRRep r => VisitGeneric (Expr r) r where PrimOp op -> PrimOp <$> visitGeneric op App et fAtom xs -> App <$> visitGeneric et <*> visitGeneric fAtom <*> mapM visitGeneric xs ApplyMethod et m i xs -> ApplyMethod <$> visitGeneric et <*> visitGeneric m <*> pure i <*> mapM visitGeneric xs + Project t i x -> Project <$> visitGeneric t <*> pure i <*> visitGeneric x + Unwrap t x -> Unwrap <$> visitGeneric t <*> visitGeneric x instance IRRep r => VisitGeneric (PrimOp r) r where visitGeneric = \case @@ -703,13 +470,11 @@ instance IRRep r => VisitGeneric (EffectRow r) r where _ -> error "Not a valid effect substitution" return $ extendEffRow effs' tailEffRow -instance VisitGeneric DictExpr CoreIR where +instance VisitGeneric DictCon CoreIR where visitGeneric = \case - InstantiatedGiven x xs -> InstantiatedGiven <$> visitGeneric x <*> mapM visitGeneric xs - SuperclassProj x i -> SuperclassProj <$> visitGeneric x <*> pure i - InstanceDict v xs -> InstanceDict <$> renameN v <*> mapM visitGeneric xs - IxFin x -> IxFin <$> visitGeneric x - DataData t -> DataData <$> visitGeneric t + InstanceDict t v xs -> InstanceDict <$> visitGeneric t <*> renameN v <*> mapM visitGeneric xs + IxFin t x -> IxFin <$> visitGeneric t <*> visitGeneric x + DataData t dataTy -> DataData <$> visitGeneric t <*> visitGeneric dataTy instance VisitGeneric NewtypeCon CoreIR where visitGeneric = \case @@ -797,11 +562,6 @@ instance VisitGeneric (MemOp r) r where visitGeneric = traverseOpTerm -- === SubstE/SubstB instances === -- These live here, as orphan instances, because we normalize as we substitute. -toAtomVar :: (EnvReader m, IRRep r) => AtomName r n -> m n (AtomVar r n) -toAtomVar v = do - ty <- getType <$> lookupAtomName v - return $ AtomVar v ty - bindersToVars :: (EnvReader m, IRRep r) => Nest (Binder r) n' n -> m n [AtomVar r n] bindersToVars bs = do withExtEvidence bs do @@ -833,29 +593,42 @@ instance Color c => SubstE AtomSubstVal (AtomSubstVal c) where instance SubstV (SubstVal Atom) (SubstVal Atom) where instance IRRep r => SubstE AtomSubstVal (Atom r) where - substE es@(env, subst) = \case - Var (AtomVar v ty) -> case subst!v of - Rename v' -> Var $ AtomVar v' (substE es ty) - SubstVal x -> x - SimpInCore x -> SimpInCore (substE es x) - ProjectElt _ i x -> do - let x' = substE es x - runEnvReaderM env $ normalizeProj i x' + substE es = \case + Stuck e -> substStuck es e + SimpInCore x -> SimpInCore (substE es x) atom -> runReader (runSubstVisitor $ visitAtomPartial atom) es instance IRRep r => SubstE AtomSubstVal (Type r) where - substE es@(env, subst) = \case - TyVar (AtomVar v ty) -> case subst ! v of - Rename v' -> TyVar $ AtomVar v' (substE es ty) - SubstVal (Type t) -> t - SubstVal atom -> error $ "bad substitution: " ++ pprint v ++ " -> " ++ pprint atom - ProjectEltTy _ i x -> do - let x' = substE es x - case runEnvReaderM env $ normalizeProj i x' of - Type t -> t - _ -> error "bad substitution" + substE es = \case + StuckTy e -> case substStuck es e of + Type t -> t + _ -> error "bad substitution" ty -> runReader (runSubstVisitor $ visitTypePartial ty) es +substStuck :: (IRRep r, Distinct o) => (Env o, Subst AtomSubstVal i o) -> Stuck r i -> Atom r o +substStuck (env, subst) stuck = + ignoreExcept $ runFallibleM $ runEnvReaderT env $ runSubstReaderT subst $ reduceStuck stuck + +reduceStuck :: (IRRep r, Distinct o) => Stuck r i -> ReducerM i o (Atom r o) +reduceStuck = \case + StuckVar (AtomVar v ty) -> do + lookupSubstM v >>= \case + Rename v' -> Var . AtomVar v' <$> substM ty + SubstVal x -> return x + StuckProject _ i x -> do + x' <- reduceStuck x + dropSubst $ reduceProjM i x' + StuckUnwrap _ x -> do + x' <- reduceStuck x + dropSubst $ reduceUnwrapM x' + InstantiatedGiven _ f xs -> do + xs' <- mapM substM xs + f' <- reduceStuck f + reduceInstantiateGivenM f' xs' + SuperclassProj _ superclassIx child -> do + child' <- reduceStuck child + reduceSuperclassProjM superclassIx child' + instance SubstE AtomSubstVal SimpInCore instance IRRep r => SubstE AtomSubstVal (EffectRow r) where @@ -904,7 +677,6 @@ instance IRRep r => SubstE AtomSubstVal (Expr r) instance IRRep r => SubstE AtomSubstVal (GenericOpRep const r) instance SubstE AtomSubstVal InstanceBody instance SubstE AtomSubstVal DictType -instance SubstE AtomSubstVal DictExpr instance IRRep r => SubstE AtomSubstVal (LamExpr r) instance SubstE AtomSubstVal CorePiType instance SubstE AtomSubstVal CoreLamExpr diff --git a/src/lib/CheckType.hs b/src/lib/CheckType.hs index 14573bf26..24d96fb53 100644 --- a/src/lib/CheckType.hs +++ b/src/lib/CheckType.hs @@ -91,11 +91,9 @@ checkTypesEq :: IRRep r => Type r o -> Type r o -> TyperM r i o () checkTypesEq reqTy ty = alphaEq reqTy ty >>= \case True -> return () False -> {-# SCC typeNormalization #-} do - reqTy' <- cheapNormalize reqTy - ty' <- cheapNormalize ty - alphaEq reqTy' ty' >>= \case + alphaEq reqTy ty >>= \case True -> return () - False -> throw TypeErr $ pprint reqTy' ++ " != " ++ pprint ty' + False -> throw TypeErr $ pprint reqTy ++ " != " ++ pprint ty {-# INLINE checkTypesEq #-} class SinkableE e => CheckableE (r::IR) (e::E) | e -> r where @@ -156,12 +154,7 @@ instance IRRep r => CheckableE r (AtomName r) where instance IRRep r => CheckableE r (Atom r) where checkE = \case - Var name -> do - name' <- checkE name - case getType name' of - RawRefTy _ -> affineUsed $ atomVarName name' - _ -> return () - return $ Var name' + Stuck e -> Stuck <$> checkE e Lam lam -> Lam <$> checkE lam DepPair l r ty -> do l' <- checkE l @@ -173,31 +166,13 @@ instance IRRep r => CheckableE r (Atom r) where Eff eff -> Eff <$> checkE eff PtrVar t v -> PtrVar t <$> renameM v -- TODO: check against cached type - DictCon ty dictExpr -> DictCon <$> checkE ty <*> checkE dictExpr + DictCon con -> DictCon <$> checkE con RepValAtom repVal -> RepValAtom <$> renameM repVal -- TODO: check NewtypeCon con x -> do (x', xTy) <- checkAndGetType x con' <- typeCheckNewtypeCon con xTy return $ NewtypeCon con' x' SimpInCore x -> SimpInCore <$> checkE x - ProjectElt resultTy UnwrapNewtype x -> do - resultTy' <- resultTy |: TyKind - (x', NewtypeTyCon con) <- checkAndGetType x - resultTy'' <- snd <$> unwrapNewtypeType con - checkTypesEq resultTy' resultTy'' - return $ ProjectElt resultTy' UnwrapNewtype x' - ProjectElt resultTy (ProjectProduct i) x -> do - resultTy' <- resultTy |: TyKind - (x', xTy) <- checkAndGetType x - resultTy'' <- case xTy of - ProdTy tys -> return $ tys !! i - DepPairTy t | i == 0 -> return $ depPairLeftTy t - DepPairTy t | i == 1 -> do - xFst <- normalizeProj (ProjectProduct 0) x' - checkInstantiation t [xFst] - _ -> throw TypeErr $ "Not a product type:" ++ pprint xTy - checkTypesEq resultTy' resultTy'' - return $ ProjectElt resultTy' (ProjectProduct i) x' TypeAsAtom ty -> TypeAsAtom <$> checkE ty instance IRRep r => CheckableE r (AtomVar r) where @@ -221,26 +196,7 @@ instance IRRep r => CheckableE r (Type r) where params' <- mapM checkE params void $ checkInstantiation (Abs paramBs UnitE) params' return $ DictTy (DictType sn className' params') - TyVar v -> TyVar <$> checkE v - ProjectEltTy resultTy UnwrapNewtype x -> do - resultTy' <- resultTy |: TyKind - x' <- checkE x - NewtypeTyCon con <- return $ getType x' - ty <- snd <$> unwrapNewtypeType con - checkTypesEq resultTy' ty - return $ ProjectEltTy resultTy' UnwrapNewtype x' - ProjectEltTy resultTy (ProjectProduct i) x -> do - resultTy' <- resultTy |: TyKind - (x', ty) <- checkAndGetType x - resultTy'' <- case ty of - ProdTy tys -> return $ tys !! i - DepPairTy t | i == 0 -> return $ depPairLeftTy t - DepPairTy t | i == 1 -> do - xFst <- normalizeProj (ProjectProduct 0) x' - instantiate t [xFst] - _ -> throw TypeErr $ "Not a product type:" ++ pprint ty - checkTypesEq resultTy' resultTy'' - return $ ProjectEltTy resultTy' (ProjectProduct i) x' + StuckTy e -> StuckTy <$> checkE e instance CheckableE CoreIR SimpInCore where checkE x = renameM x -- TODO: check @@ -316,26 +272,70 @@ instance IRRep r => CheckableWithEffects r (Expr r) where -- each index from the ix dict. HoistFailure _ -> forM xs checkE return $ TabCon maybeD' ty' xs' + Project resultTy i x -> do + resultTy' <- resultTy |: TyKind + (x', xTy) <- checkAndGetType x + resultTy'' <- case xTy of + ProdTy tys -> return $ tys !! i + DepPairTy t | i == 0 -> return $ depPairLeftTy t + DepPairTy t | i == 1 -> do + xFst <- reduceProj 0 x' + checkInstantiation t [xFst] + _ -> throw TypeErr $ "Not a product type:" ++ pprint xTy + checkTypesEq resultTy' resultTy'' + return $ Project resultTy' i x' + Unwrap resultTy x -> do + resultTy' <- resultTy |: TyKind + (x', NewtypeTyCon con) <- checkAndGetType x + resultTy'' <- snd <$> unwrapNewtypeType con + checkTypesEq resultTy' resultTy'' + return $ Unwrap resultTy' x' instance CheckableE CoreIR TyConParams where checkE (TyConParams expls params) = TyConParams expls <$> mapM checkE params -instance CheckableE CoreIR DictExpr where +instance IRRep r => CheckableE r (Stuck r) where checkE = \case - InstanceDict instanceName args -> do + StuckVar name -> do + name' <- checkE name + case getType name' of + RawRefTy _ -> affineUsed $ atomVarName name' + _ -> return () + return $ StuckVar name' + StuckUnwrap resultTy x -> do + Unwrap resultTy' (Stuck x') <- checkWithEffects Pure $ Unwrap resultTy (Stuck x) + return $ StuckUnwrap resultTy' x' + StuckProject resultTy i x -> do + Project resultTy' i' (Stuck x') <- checkWithEffects Pure $ Project resultTy i (Stuck x) + return $ StuckProject resultTy' i' x' + InstantiatedGiven resultTy given args -> do + resultTy' <- resultTy |: TyKind + (given', Pi piTy) <- checkAndGetType given + args' <- mapM checkE args + EffTy Pure ty <- checkInstantiation piTy args' + checkTypesEq resultTy' ty + return $ InstantiatedGiven resultTy' given' args' + SuperclassProj t i d -> SuperclassProj <$> checkE t <*> pure i <*> checkE d -- TODO: check index in range + +depPairLeftTy :: DepPairType r n -> Type r n +depPairLeftTy (DepPairType _ (_:>ty) _) = ty +{-# INLINE depPairLeftTy #-} + +instance CheckableE CoreIR DictCon where + checkE = \case + InstanceDict ty instanceName args -> do + ty' <- ty |: TyKind instanceName' <- renameM instanceName args' <- mapM checkE args instanceDef <- lookupInstanceDef instanceName' void $ checkInstantiation instanceDef args' - return $ InstanceDict instanceName' args' - InstantiatedGiven given args -> do - (given', Pi piTy) <- checkAndGetType given - args' <- mapM checkE args - EffTy Pure _ <- checkInstantiation piTy args' - return $ InstantiatedGiven given' args' - SuperclassProj d i -> SuperclassProj <$> checkE d <*> pure i -- TODO: check index in range - IxFin n -> IxFin <$> n |: NatTy - DataData ty -> DataData <$> ty |: TyKind + return $ InstanceDict ty' instanceName' args' + IxFin ty n -> do + ty' <- ty |: TyKind + IxFin ty' <$> n |: NatTy + DataData ty dataTy -> do + ty' <- ty |: TyKind + DataData ty' <$> dataTy |: TyKind instance IRRep r => CheckableE r (DepPairType r) where checkE (DepPairType expl b ty) = do @@ -532,13 +532,13 @@ instance IRRep r => CheckableWithEffects r (MiscOp r) where x' <- checkE x y' <- y |: getType x' return $ Select p' x' y' - CastOp t@(TyVar _) e -> CastOp <$> (t|:TyKind) <*> renameM e + CastOp t@(StuckTy (StuckVar _)) e -> CastOp <$> (t|:TyKind) <*> renameM e CastOp destTy e -> do e' <- checkE e destTy' <- destTy |: TyKind checkValidCast (getType e') destTy' return $ CastOp destTy' e' - BitcastOp t@(TyVar _) e -> BitcastOp <$> (t|:TyKind) <*> renameM e + BitcastOp t@(StuckTy (StuckVar _)) e -> BitcastOp <$> (t|:TyKind) <*> renameM e BitcastOp destTy e -> do destTy' <- destTy |: TyKind e' <- checkE e diff --git a/src/lib/Imp.hs b/src/lib/Imp.hs index bd3c462a2..7bf22d6ef 100644 --- a/src/lib/Imp.hs +++ b/src/lib/Imp.hs @@ -9,8 +9,7 @@ {-# OPTIONS_GHC -Wno-orphans #-} module Imp - ( toImpFunction - , impFunType, getIType, abstractLinktimeObjects + ( toImpFunction, repValAtom, impFunType, getIType, abstractLinktimeObjects , repValFromFlatList, addImpTracing -- These are just for the benefit of serialization/printing. otherwise we wouldn't need them , BufferType (..), IdxNest, IndexStructure, IExprInterpretation (..), typeToTree @@ -331,6 +330,7 @@ translateExpr expr = confuseGHC >>= \_ -> case expr of void $ translateBlock body return UnitVal TabCon _ _ _ -> error "Unexpected `TabCon` in Imp pass." + Project _ i x -> reduceProj i =<< substM x toImpRefOp :: Emits o => SAtom i -> RefOp SimpIR i -> SubstImpM i o (SAtom o) @@ -873,25 +873,13 @@ atomToRepVal x = RepVal (getType x) <$> go x where else buildGarbageVal t <&> \(RepValAtom (RepVal _ tree)) -> tree return $ Branch $ tag':xs Con HeapVal -> return $ Branch [] - Var v -> lookupAtomName (atomVarName v) >>= \case + PtrVar ty p -> return $ Leaf $ IPtrVar p ty + Stuck (StuckVar v) -> lookupAtomName (atomVarName v) >>= \case TopDataBound (RepVal _ tree) -> return tree _ -> error "should only have pointer and data atom names left" - PtrVar ty p -> return $ Leaf $ IPtrVar p ty - ProjectElt _ p val -> do - (ps, v) <- return $ asNaryProj p val - lookupAtomName (atomVarName v) >>= \case - TopDataBound (RepVal _ tree) -> applyProjection (toList ps) tree - _ -> error "should only be projecting a data atom" - where - applyProjection :: [Projection] -> Tree (IExpr n) -> SubstImpM i n (Tree (IExpr n)) - applyProjection [] t = return t - applyProjection (i:is) t = do - t' <- applyProjection is t - case i of - UnwrapNewtype -> error "impossible" - ProjectProduct idx -> case t' of - Branch ts -> return $ ts !! idx - _ -> error "should only be projecting a branch" + Stuck (StuckProject _ i val) -> do + Branch ts <- go $ Stuck val + return $ ts !! i -- XXX: We used to have a function called `destToAtom` which loaded the value -- from the dest. This version is not that. It just lifts a dest into an atom of diff --git a/src/lib/Inference.hs b/src/lib/Inference.hs index 373844202..981519268 100644 --- a/src/lib/Inference.hs +++ b/src/lib/Inference.hs @@ -52,7 +52,7 @@ checkTopUType ty = liftInfererM $ checkUType ty {-# SCC checkTopUType #-} inferTopUExpr :: (Fallible1 m, EnvReader m) => UExpr n -> m n (TopBlock CoreIR n) -inferTopUExpr e = asTopBlock =<< liftInfererM (buildScoped $ bottomUp e) +inferTopUExpr e = fst <$> (asTopBlock =<< liftInfererM (buildScoped $ bottomUp e)) {-# SCC inferTopUExpr #-} data UDeclInferenceResult e n = @@ -114,22 +114,23 @@ inferTopUDecl (ULocalDecl (WithSrcB src decl)) result = addSrcContext src case d WithSrcB _ (UPatBinder b) -> do block <- liftInfererM $ buildScoped do checkMaybeAnnExpr tyAnn rhs - topBlock <- asTopBlock block - return $ UDeclResultBindName letAnn topBlock (Abs b result) + (topBlock, resultTy) <- asTopBlock block + let letAnn' = considerInlineAnn letAnn resultTy + return $ UDeclResultBindName letAnn' topBlock (Abs b result) _ -> do PairE block recon <- liftInfererM $ buildBlockInfWithRecon do val <- checkMaybeAnnExpr tyAnn rhs v <- emitHinted (getNameHint p) $ Atom val bindLetPat p v do renameM result - topBlock <- asTopBlock block + (topBlock, _) <- asTopBlock block return $ UDeclResultBindPattern (getNameHint p) topBlock recon {-# SCC inferTopUDecl #-} -asTopBlock :: EnvReader m => CBlock n -> m n (TopBlock CoreIR n) +asTopBlock :: EnvReader m => CBlock n -> m n (TopBlock CoreIR n, CType n) asTopBlock block = do - effTy <- blockEffTy block - return $ TopLam False (PiType Empty effTy) (LamExpr Empty block) + effTy@(EffTy _ ty) <- blockEffTy block + return (TopLam False (PiType Empty effTy) (LamExpr Empty block), ty) getInstanceType :: EnvReader m => InstanceDef n -> m n (CorePiType n) getInstanceType (InstanceDef className roleExpls bs params _) = liftEnvReaderM do @@ -584,7 +585,7 @@ forceSigmaAtom :: Emits o => SigmaAtom o -> InfererM i o (CAtom o) forceSigmaAtom sigmaAtom = case sigmaAtom of SigmaAtom _ x -> return x SigmaUVar _ _ v -> case v of - UAtomVar v' -> Var <$> toAtomVar v' + UAtomVar v' -> inlineTypeAliases v' _ -> applySigmaAtom sigmaAtom [] SigmaPartialApp _ _ _ -> error "not implemented" -- better error message? @@ -613,9 +614,14 @@ withUDecl (WithSrcB src d) cont = addSrcContext src case d of UExprDecl e -> withDistinct $ bottomUp e >> cont ULet letAnn p ann rhs -> do val <- checkMaybeAnnExpr ann rhs - var <- emitDecl (getNameHint p) letAnn $ Atom val + let letAnn' = considerInlineAnn letAnn (getType val) + var <- emitDecl (getNameHint p) letAnn' $ Atom val bindLetPat p var cont +considerInlineAnn :: LetAnn -> CType n -> LetAnn +considerInlineAnn PlainLet TyKind = InlineLet +considerInlineAnn ann _ = ann + applyFromLiteralMethod :: Emits n => CType n -> SourceName -> CAtom n -> InfererM i n (CAtom n) applyFromLiteralMethod resultTy methodName litVal = @@ -683,7 +689,7 @@ getFieldDefs ty = case ty of projectField :: Emits o => Int -> CAtom o -> InfererM i o (CAtom o) projectField i x = case getType x of - ProdTy _ -> projectTuple i x + ProdTy _ -> proj i x NewtypeTyCon _ -> projectStruct i x RefTy _ valTy -> case valTy of ProdTy _ -> getProjRef (ProjectProduct i) x @@ -693,31 +699,33 @@ projectField i x = case getType x of where bad = error $ "bad projection: " ++ pprint (i, x) class PrettyE e => ExplicitArg (e::E) where - checkExplicitArg :: Emits o => IsDependent -> e i -> PartialType o -> InfererM i o (CAtom o) + checkExplicitNonDependentArg :: Emits o => e i -> PartialType o -> InfererM i o (CAtom o) + checkExplicitDependentArg :: e i -> PartialType o -> InfererM i o (CAtom o) inferExplicitArg :: Emits o => e i -> InfererM i o (CAtom o) isHole :: e n -> Bool instance ExplicitArg UExpr where - checkExplicitArg isDependent arg argTy = do - if isDependent - then checkSigmaDependent arg argTy -- do we actually need this? - else topDownPartial argTy arg - + checkExplicitDependentArg arg argTy = checkSigmaDependent arg argTy + checkExplicitNonDependentArg arg argTy = topDownPartial argTy arg inferExplicitArg arg = bottomUp arg isHole = \case WithSrcE _ UHole -> True _ -> False instance ExplicitArg CAtom where - checkExplicitArg _ arg argTy = do - arg' <- renameM arg - case argTy of - FullType argTy' -> expectEq argTy' (getType arg') - PartialType _ -> return () -- TODO? - return arg' + checkExplicitDependentArg = checkCAtom + checkExplicitNonDependentArg = checkCAtom inferExplicitArg arg = renameM arg isHole _ = False +checkCAtom :: CAtom i -> PartialType o -> InfererM i o (CAtom o) +checkCAtom arg argTy = do + arg' <- renameM arg + case argTy of + FullType argTy' -> expectEq argTy' (getType arg') + PartialType _ -> return () -- TODO? + return arg' + checkOrInferApp :: forall i o arg . (Emits o, ExplicitArg arg) => SigmaAtom o -> [arg i] -> [(SourceName, arg i)] @@ -737,7 +745,7 @@ checkOrInferApp f' posArgs namedArgs reqTy = do fDesc :: SourceName fDesc = getSourceName f' -buildAppConstraints :: RequiredTy n -> CorePiType n -> InfererM i n (Abs (Nest CBinder) Constraints n) +buildAppConstraints :: RequiredTy n -> CorePiType n -> InfererM i n (ConstrainedBinders n) buildAppConstraints reqTy (CorePiType _ _ bs effTy) = do effsAllowed <- infEffects <$> getInfState buildConstraints (Abs bs effTy) \_ (EffTy effs resultTy) -> do @@ -759,12 +767,18 @@ maybeInterpretPunsAsTyCons _ x = return x type IsDependent = Bool +inlineTypeAliases :: CAtomName n -> InfererM i n (CAtom n) +inlineTypeAliases v = do + lookupAtomName v >>= \case + LetBound (DeclBinding InlineLet (Atom e)) -> return e + _ -> Var <$> toAtomVar v + applySigmaAtom :: Emits o => SigmaAtom o -> [CAtom o] -> InfererM i o (CAtom o) applySigmaAtom (SigmaAtom _ f) args = emitExprWithEffects =<< mkApp f args applySigmaAtom (SigmaUVar _ _ f) args = case f of UAtomVar f' -> do - f'' <- toAtomVar f' - emitExprWithEffects =<< mkApp (Var f'') args + f'' <- inlineTypeAliases f' + emitExprWithEffects =<< mkApp f'' args UTyConVar f' -> do TyConDef sn roleExpls _ _ <- lookupTyCon f' let expls = snd <$> roleExpls @@ -847,42 +861,48 @@ data Constraint (n::S) = TypeConstraint (CType n) (CType n) -- permitted effects (no inference vars), proposed effects | EffectConstraint (EffectRow CoreIR n) (EffectRow CoreIR n) + type Constraints = ListE Constraint +type ConstrainedBinders n = ([IsDependent], Abs (Nest CBinder) Constraints n) buildConstraints - :: RenameE e + :: HasNamesE e => Abs (Nest CBinder) e o -> (forall o'. DExt o o' => [CAtom o'] -> e o' -> EnvReaderM o' [Constraint o']) - -> InfererM i o (Abs (Nest CBinder) Constraints o) + -> InfererM i o (ConstrainedBinders o) buildConstraints ab cont = liftEnvReaderM do refreshAbs ab \bs e -> do cs <- cont (Var <$> bindersVars bs) e - return $ Abs bs $ ListE cs + return (getDependence (Abs bs e), Abs bs $ ListE cs) + where + getDependence :: HasNamesE e => Abs (Nest CBinder) e n -> [IsDependent] + getDependence (Abs Empty _) = [] + getDependence (Abs (Nest b bs) e) = + (binderName b `isFreeIn` Abs bs e) : getDependence (Abs bs e) -- TODO: check that there are no extra named args provided inferMixedArgs :: forall arg i o . (Emits o, ExplicitArg arg) => SourceName - -> [Explicitness] -> Abs (Nest CBinder) Constraints o + -> [Explicitness] -> ConstrainedBinders o -> MixedArgs (arg i) -> InfererM i o [CAtom o] -inferMixedArgs fSourceName explsTop bsAbs argsTop@(_, namedArgsTop) = do +inferMixedArgs fSourceName explsTop (dependenceTop, bsAbs) argsTop@(_, namedArgsTop) = do checkNamedArgValidity explsTop (map fst namedArgsTop) - liftSolverM $ fromListE <$> go explsTop bsAbs argsTop + liftSolverM $ fromListE <$> go explsTop dependenceTop bsAbs argsTop where go :: Emits oo - => [Explicitness] -> Abs (Nest CBinder) Constraints oo -> MixedArgs (arg i) + => [Explicitness] -> [IsDependent] -> Abs (Nest CBinder) Constraints oo -> MixedArgs (arg i) -> SolverM i oo (ListE CAtom oo) - go expls (Abs bs cs) args = do + go expls dependence (Abs bs cs) args = do cs' <- eagerlyApplyConstraints bs cs - case (expls, bs) of - ([], Empty) -> return mempty - (expl:explsRest, Nest b bsRest) -> do - let isDependent = binderName b `isFreeIn` Abs bsRest cs' + case (expls, dependence, bs) of + ([], [], Empty) -> return mempty + (expl:explsRest, isDependent:dependenceRest, Nest b bsRest) -> do inferMixedArg isDependent (binderType b) expl args \arg restArgs -> do bs' <- applySubst (b @> SubstVal arg) (Abs bsRest cs') - (ListE [arg] <>) <$> go explsRest bs' restArgs - (_, _) -> error "zip error" + (ListE [arg] <>) <$> go explsRest dependenceRest bs' restArgs + (_, _, _) -> error "zip error" eagerlyApplyConstraints :: Nest CBinder oo oo' -> Constraints oo' @@ -932,8 +952,10 @@ inferMixedArgs fSourceName explsTop bsAbs argsTop@(_, namedArgsTop) = do checkOrInferExplicitArg :: Emits oo => Bool -> arg i -> CType oo -> SolverM i oo (CAtom oo) checkOrInferExplicitArg isDependent arg argTy = do arg' <- lift11 $ withoutInfVarsPartial argTy >>= \case - Just partialTy -> checkExplicitArg isDependent arg partialTy - Nothing -> inferExplicitArg arg + Just partialTy -> case isDependent of + True -> checkExplicitDependentArg arg partialTy + False -> checkExplicitNonDependentArg arg partialTy + Nothing -> inferExplicitArg arg constrainTypesEq argTy (getType arg') return arg' @@ -973,7 +995,7 @@ inferPrimArg x = do xBlock <- buildScoped $ bottomUp x EffTy _ ty <- blockEffTy xBlock case ty of - TyKind -> cheapReduce xBlock >>= \case + TyKind -> reduceBlock xBlock >>= \case Just reduced -> return reduced _ -> throw CompilerErr "Type args to primops must be reducible" _ -> emitBlock xBlock @@ -1061,7 +1083,7 @@ inferNaryTabAppArgs tabTy (arg:rest) = case tabTy of return $ arg':rest' _ -> throw TypeErr $ "Expected a table type but got: " ++ pprint tabTy -checkSigmaDependent :: Emits o => UExpr i -> PartialType o -> InfererM i o (CAtom o) +checkSigmaDependent :: UExpr i -> PartialType o -> InfererM i o (CAtom o) checkSigmaDependent e@(WithSrcE ctx _) ty = addSrcContext ctx $ withReducibleEmissions depFunErrMsg $ topDownPartial (sink ty) e where @@ -1070,13 +1092,13 @@ checkSigmaDependent e@(WithSrcE ctx _) ty = addSrcContext ctx $ "Bind the argument to a name before you apply the function." withReducibleEmissions - :: ( Zonkable e, CheaplyReducibleE CoreIR e e, SubstE AtomSubstVal e) + :: Zonkable e => String -> (forall o' . (Emits o', DExt o o') => InfererM i o' (e o')) -> InfererM i o (e o) withReducibleEmissions msg cont = do - Abs decls result <- buildScoped cont - cheapReduceWithDecls decls result >>= \case + withDecls <- buildScoped cont + reduceWithDecls withDecls >>= \case Just t -> return t _ -> throw TypeErr msg @@ -1139,7 +1161,7 @@ instanceFun instanceName appExpl = do InstanceDef _ expls bs _ _ <- lookupInstanceDef instanceName liftEnvReaderM $ refreshAbs (Abs bs UnitE) \bs' UnitE -> do args <- mapM toAtomVar $ nestToNames bs' - result <- mkDictAtom $ InstanceDict (sink instanceName) (Var <$> args) + result <- DictCon <$> mkInstanceDict (sink instanceName) (Var <$> args) let effTy = EffTy Pure (getType result) let body = WithoutDecls result let piTy = CorePiType appExpl (snd<$>expls) bs' effTy @@ -1490,7 +1512,7 @@ checkCasePat (WithSrcB pos pat) scrutineeTy cont = addSrcContext pos $ case pat withFreshBinderInf noHint Explicit repTy \b -> Abs b <$> do buildScoped do args <- forM idxs \projs -> do - ans <- normalizeNaryProj (init projs) (sink $ Var $ binderVar b) + ans <- applyProjectionsReduced (init projs) (sink $ Var $ binderVar b) emit $ Atom ans bindLetPats ps args $ cont _ -> throw TypeErr $ "Case patterns must start with a data constructor or variant pattern" @@ -1529,16 +1551,17 @@ bindLetPat (WithSrcB pos pat) v cont = addSrcContext pos $ case pat of case getType v of ProdTy ts | length ts == n -> return () ty -> throw TypeErr $ "Expected a product type but got: " ++ pprint ty - xs <- forM (iota n) \i -> do - normalizeProj (ProjectProduct i) (Var v) >>= emit . Atom + xs <- forM (iota n) \i -> proj i (Var v) >>= emitInline bindLetPats ps xs cont UPatDepPair (PairB p1 p2) -> do case getType v of DepPairTy _ -> return () ty -> throw TypeErr $ "Expected a dependent pair, but got: " ++ pprint ty - x1 <- getFst (Var v) >>= emit . Atom + -- XXX: we're careful here to reduce the projection because of the dependent + -- types. We do the same in the `UPatCon` case. + x1 <- reduceProj 0 (Var v) >>= emitInline bindLetPat p1 x1 do - x2 <- getSnd (sink $ Var v) >>= emit . Atom + x2 <- getSnd (sink $ Var v) >>= emitInline bindLetPat p2 x2 do cont UPatCon ~(InternalName _ _ conName) ps -> do @@ -1550,13 +1573,12 @@ bindLetPat (WithSrcB pos pat) v cont = addSrcContext pos $ case pat of "Unexpected number of pattern binders. Expected " ++ show (length idxss) ++ " got " ++ show (nestLength ps) void $ inferParams (getType $ Var v) dataDefName - x <- cheapNormalize $ Var v - xs <- forM idxss \idxs -> normalizeNaryProj idxs x >>= emit . Atom + xs <- forM idxss \idxs -> applyProjectionsReduced idxs (Var v) >>= emitInline bindLetPats ps xs cont _ -> throw TypeErr $ "sum type constructor in can't-fail pattern" UPatTable ps -> do let n = fromIntegral (nestLength ps) :: Word32 - cheapNormalize (getType v) >>= \case + case getType v of TabPi (TabPiType _ (_:>FinConst n') _) | n == n' -> return () ty -> throw TypeErr $ "Expected a Fin " ++ show n ++ " table type but got: " ++ pprint ty xs <- forM [0 .. n - 1] \i -> do @@ -1868,6 +1890,17 @@ renameForPrinting e = do e' <- applyRename (bsAbs@@>nestToNames bs') eAbs return $ Abs bs' e' +-- === builder and type querying helpers === + +makeStructRepVal :: (Fallible1 m, EnvReader m) => TyConName n -> [CAtom n] -> m n (CAtom n) +makeStructRepVal tyConName args = do + TyConDef _ _ _ (StructFields fields) <- lookupTyCon tyConName + case fields of + [_] -> case args of + [arg] -> return arg + _ -> error "wrong number of args" + _ -> return $ ProdVal args + -- === dictionary synthesis === -- Given a simplified dict (an Atom of type `DictTy _` in the @@ -1886,27 +1919,20 @@ generalizeDict ty dict = do Success ans -> return ans generalizeDictRec :: CType n -> Dict n -> InfererM i n (Dict n) -generalizeDictRec targetTy dict = do - -- TODO: we should be able to avoid the normalization here . We only need it - -- because we sometimes end up with superclass projections. But they shouldn't - -- really be allowed to occur in the post-simplification IR. - DictCon _ dict' <- cheapNormalize dict - case dict' of - InstanceDict instanceName args -> do - InstanceDef _ roleExpls bs _ _ <- lookupInstanceDef instanceName - liftSolverM $ generalizeInstanceArgs roleExpls bs args \args' -> do - d <- mkDictAtom $ InstanceDict (sink instanceName) args' - constrainEq (sink $ Type targetTy) (Type $ getType d) - return d - IxFin _ -> case targetTy of - DictTy (DictType "Ix" _ [Type (NewtypeTyCon (Fin n))]) -> mkDictAtom $ IxFin n - _ -> error $ "not an Ix(Fin _) dict: " ++ pprint targetTy - InstantiatedGiven _ _ -> notSimplifiedDict - SuperclassProj _ _ -> notSimplifiedDict - DataData _ -> case targetTy of - DictTy (DictType "Data" _ [Type t]) -> mkDictAtom $ DataData t - _ -> error "not a data dict" - where notSimplifiedDict = error $ "Not a simplified dict: " ++ pprint dict +generalizeDictRec targetTy (DictCon dict) = case dict of + InstanceDict _ instanceName args -> do + InstanceDef _ roleExpls bs _ _ <- lookupInstanceDef instanceName + liftSolverM $ generalizeInstanceArgs roleExpls bs args \args' -> do + d <- DictCon <$> mkInstanceDict (sink instanceName) args' + constrainEq (sink $ Type targetTy) (Type $ getType d) + return d + IxFin _ _ -> case targetTy of + DictTy (DictType "Ix" _ [Type (NewtypeTyCon (Fin n))]) -> DictCon <$> mkIxFin n + _ -> error $ "not an Ix(Fin _) dict: " ++ pprint targetTy + DataData _ _ -> case targetTy of + DictTy (DictType "Data" _ [Type t]) -> DictCon <$> mkDataData t + _ -> error "not a data dict" +generalizeDictRec _ _ = error "not a simplified dict" generalizeInstanceArgs :: Zonkable e => [RoleExpl] -> Nest CBinder o any -> [CAtom o] @@ -2030,8 +2056,7 @@ getSuperclassClosurePure env givens newGivens = superclasses <- case synthTy of SynthPiType _ -> return [] SynthDictType dTy -> getSuperclassTys dTy - forM (enumerate superclasses) \(i, ty) -> do - return $ DictCon ty $ SuperclassProj synthExpr i + forM (enumerate superclasses) \(i, _) -> reduceSuperclassProj i synthExpr synthTerm :: SynthType n -> RequiredMethodAccess -> InfererM i n (SynthAtom n) synthTerm targetTy reqMethodAccess = confuseGHC >>= \_ -> case targetTy of @@ -2043,14 +2068,14 @@ synthTerm targetTy reqMethodAccess = confuseGHC >>= \_ -> case targetTy of let lamExpr = LamExpr bs' (WithoutDecls synthExpr) return $ Lam $ CoreLamExpr piTy lamExpr SynthDictType dictTy -> case dictTy of - DictType "Ix" _ [Type (NewtypeTyCon (Fin n))] -> return $ DictCon (DictTy dictTy) $ IxFin n + DictType "Ix" _ [Type (NewtypeTyCon (Fin n))] -> return $ DictCon $ IxFin (DictTy dictTy) n DictType "Data" _ [Type t] -> do void (synthDictForData dictTy <|> synthDictFromGiven dictTy) - return $ DictCon (DictTy dictTy) $ DataData t + return $ DictCon $ DataData (DictTy dictTy) t _ -> do dict <- synthDictFromInstance dictTy <|> synthDictFromGiven dictTy case dict of - DictCon _ (InstanceDict instanceName _) -> do + DictCon (InstanceDict _ instanceName _) -> do isReqMethodAccessAllowed <- reqMethodAccess `isMethodAccessAllowedBy` instanceName if isReqMethodAccessAllowed then return dict @@ -2078,7 +2103,7 @@ synthDictFromGiven targetTy = do return given SynthPiType givenPiTy -> do args <- instantiateSynthArgs targetTy givenPiTy - return $ DictCon (DictTy targetTy) $ InstantiatedGiven given args + reduceInstantiateGiven given args synthDictFromInstance :: DictType n -> InfererM i n (SynthAtom n) synthDictFromInstance targetTy@(DictType _ targetClass _) = do @@ -2086,7 +2111,7 @@ synthDictFromInstance targetTy@(DictType _ targetClass _) = do asum $ instances <&> \candidate -> typeErrAsSearchFailure do CorePiType _ expls bs (EffTy _ (DictTy candidateTy)) <- lookupInstanceTy candidate args <- instantiateSynthArgs targetTy (expls, Abs bs candidateTy) - return $ DictCon (DictTy targetTy) $ InstanceDict candidate args + return $ DictCon $ InstanceDict (DictTy targetTy) candidate args getInstanceDicts :: EnvReader m => ClassName n -> m n [InstanceName n] getInstanceDicts name = do @@ -2138,7 +2163,7 @@ synthDictForData dictTy@(DictType "Data" dName [Type ty]) = case ty of ans <- synthDictForData $ DictType "Data" (sink dName) [Type body'] return $ ignoreHoistFailure $ hoist b' ans notData = empty - success = return $ DictCon (DictTy dictTy) $ DataData ty + success = return $ DictCon $ DataData (DictTy dictTy) ty synthDictForData dictTy = error $ "Malformed Data dictTy " ++ pprint dictTy instance GenericE Givens where diff --git a/src/lib/Inline.hs b/src/lib/Inline.hs index 36cfc1c8e..1e6ff6655 100644 --- a/src/lib/Inline.hs +++ b/src/lib/Inline.hs @@ -6,8 +6,6 @@ module Inline (inlineBindings) where -import Data.List.NonEmpty qualified as NE - import Builder import Core import Err @@ -149,11 +147,13 @@ inlineDeclsSubst = \case inlineDeclsSubst rest where dropOccInfo PlainLet = PlainLet + dropOccInfo InlineLet = InlineLet dropOccInfo NoInlineLet = NoInlineLet dropOccInfo (OccInfoPure _) = PlainLet dropOccInfo (OccInfoImpure _) = PlainLet resolveWorkConservation PlainLet _ = NoInline -- No occurrence info, assume the worst + resolveWorkConservation InlineLet _ = NoInline resolveWorkConservation NoInlineLet _ = NoInline -- Quick hack to always unconditionally inline renames, until we get -- a better story about measuring the sizes of atoms and expressions. @@ -227,6 +227,7 @@ inlineDeclsSubst = \case preInlineUnconditionally :: LetAnn -> Bool preInlineUnconditionally = \case PlainLet -> False -- "Missing occurrence annotation" + InlineLet -> True NoInlineLet -> False OccInfoPure (UsageInfo s (0, d)) | s <= One && d <= One -> True OccInfoPure _ -> False @@ -280,10 +281,9 @@ inlineExpr ctx = \case inlineAtom :: Emits o => Context SExpr e o -> SAtom i -> InlineM i o (e o) inlineAtom ctx = \case - Var name -> inlineName ctx name - ProjectElt _ i x -> do - let (idxs, v) = asNaryProj i x - ans <- normalizeNaryProj (NE.toList idxs) =<< inline Stop (Var v) + Stuck (StuckVar name) -> inlineName ctx name + Stuck (StuckProject _ i x) -> do + ans <- proj i =<< inline Stop (Stuck x) reconstruct ctx $ Atom ans atom -> (Atom <$> visitAtomPartial atom) >>= reconstruct ctx diff --git a/src/lib/Linearize.hs b/src/lib/Linearize.hs index d32b5230a..80986e383 100644 --- a/src/lib/Linearize.hs +++ b/src/lib/Linearize.hs @@ -17,7 +17,6 @@ import GHC.Stack import Builder import Core -import CheapReduction import Imp import IRVariants import MTL1 @@ -326,20 +325,15 @@ linearizeLambdaApp _ _ = error "not implemented" linearizeAtom :: Emits o => Atom SimpIR i -> LinM i o SAtom SAtom linearizeAtom atom = case atom of - Var v -> do + Con con -> linearizePrimCon con + DepPair _ _ _ -> notImplemented + PtrVar _ _ -> emitZeroT + Stuck (StuckVar v) -> do v' <- renameM v activePrimalIdx v' >>= \case Nothing -> withZeroT $ return (Var v') Just idx -> return $ WithTangent (Var v') $ getTangentArg idx - Con con -> linearizePrimCon con - DepPair _ _ _ -> notImplemented - PtrVar _ _ -> emitZeroT - ProjectElt _ i x -> do - WithTangent x' tx <- linearizeAtom x - xi <- normalizeProj i x' - return $ WithTangent xi do - t <- tx - normalizeProj i t + Stuck (StuckProject ty i x) -> linearizeExpr $ Project ty i (Stuck x) RepValAtom _ -> emitZeroT where emitZeroT = withZeroT $ renameM atom @@ -428,6 +422,12 @@ linearizeExpr expr = case expr of ty' <- renameM ty seqLin (map linearizeAtom xs) `bindLin` \(ComposeE xs') -> emitExpr $ TabCon Nothing (sink ty') xs' + Project _ i x -> do + WithTangent x' tx <- linearizeAtom x + xi <- proj i x' + return $ WithTangent xi do + t <- tx + proj i t linearizeOp :: Emits o => PrimOp SimpIR i -> LinM i o SAtom SAtom linearizeOp op = case op of diff --git a/src/lib/Lower.hs b/src/lib/Lower.hs index 5b8456ff6..96a1ca8e8 100644 --- a/src/lib/Lower.hs +++ b/src/lib/Lower.hs @@ -143,10 +143,10 @@ lowerFor ansTy maybeDest dir ixTy (UnaryLamExpr (ib:>ty) body) = do let destTy = getType initDest body' <- buildUnaryLamExpr noHint (PairTy ty' destTy) \b' -> do (i, destProd) <- fromPair $ Var b' - dest <- normalizeProj (ProjectProduct 0) destProd + dest <- proj 0 destProd idest <- emitOp =<< mkIndexRef dest i extendSubst (ib @> SubstVal i) $ lowerBlockWithDest idest body $> UnitVal - ans <- emitSeq dir ixTy' initDest body' >>= getProj 0 + ans <- emitSeq dir ixTy' initDest body' >>= proj 0 return $ PrimOp $ DAMOp $ Freeze ans lowerFor _ _ _ _ _ = error "expected a unary lambda expression" @@ -221,7 +221,7 @@ type DestAssignment (i'::S) (o::S) = NameMap (AtomNameC SimpIR) (ProjDest o) i' data ProjDest o = FullDest (Dest SimpIR o) - | ProjDest (NE.NonEmpty Projection) (Dest SimpIR o) -- dest corresponds to the projection applied to name + | ProjDest (NE.NonEmpty Int) (Dest SimpIR o) -- dest corresponds to the projection applied to name deriving (Show) instance SinkableE ProjDest where @@ -239,11 +239,19 @@ lookupDest dests = fmap fromLiftE . flip lookupNameMapE dests -- XXX: When adding more cases, be careful about potentially repeated vars in the output! decomposeDest :: Emits o => Dest SimpIR o -> SAtom i' -> LowerM i o (Maybe (DestAssignment i' o)) decomposeDest dest = \case - Var v -> return $ Just $ singletonNameMapE (atomVarName v) $ LiftE $ FullDest dest - ProjectElt _ p x -> do + Stuck (StuckVar v) -> + return $ Just $ singletonNameMapE (atomVarName v) $ LiftE $ FullDest dest + Stuck (StuckProject _ p x) -> do (ps, v) <- return $ asNaryProj p x return $ Just $ singletonNameMapE (atomVarName v) $ LiftE $ ProjDest ps dest _ -> return Nothing + where + asNaryProj :: IRRep r => Int -> Stuck r n -> (NE.NonEmpty Int, AtomVar r n) + asNaryProj p (StuckVar v) = (p NE.:| [], v) + asNaryProj p1 (StuckProject _ p2 x) = do + let (p2' NE.:| ps, v) = asNaryProj p2 x + (p1 NE.:| (p2':ps), v) + asNaryProj _ _ = error $ "Can't normalize projection" lowerBlockWithDest :: Emits o => Dest SimpIR o -> SBlock i -> LowerM i o (SAtom o) lowerBlockWithDest dest (Abs decls ans) = do @@ -345,16 +353,20 @@ lowerExprWithDest dest expr = case expr of bd <- getProjRef (ProjectProduct 0) fd rd <- getProjRef (ProjectProduct 1) fd return $ Just (Just bd, Just rd) - ProjDest (ProjectProduct 0 NE.:| []) pd -> return $ Just (Just pd, Nothing) - ProjDest (ProjectProduct 1 NE.:| []) pd -> return $ Just (Nothing, Just pd) + ProjDest (0 NE.:| []) pd -> return $ Just (Just pd, Nothing) + ProjDest (1 NE.:| []) pd -> return $ Just (Nothing, Just pd) ProjDest _ _ -> return Nothing place :: Emits o => ProjDest o -> SAtom o -> LowerM i o () place pd x = case pd of FullDest d -> void $ emitOp $ DAMOp $ Place d x ProjDest p d -> do - x' <- normalizeNaryProj (NE.toList p) x + x' <- applyProjs (NE.toList p) x void $ emitOp $ DAMOp $ Place d x' + where + applyProjs :: Emits n => [Int] -> SAtom n -> LowerM i n (SAtom n) + applyProjs [] atom = return atom + applyProjs (p:ps) atom = proj p =<< applyProjs ps atom -- === Extensions to the name system === diff --git a/src/lib/OccAnalysis.hs b/src/lib/OccAnalysis.hs index 2e9b3d9aa..d80a161f0 100644 --- a/src/lib/OccAnalysis.hs +++ b/src/lib/OccAnalysis.hs @@ -249,12 +249,16 @@ class HasOCC (e::E) where instance HasOCC SAtom where occ a = \case - Var (AtomVar n ty) -> do + Stuck e -> Stuck <$> occ a e + atom -> runOCCMVisitor a $ visitAtomPartial atom + +instance HasOCC SStuck where + occ a = \case + StuckVar (AtomVar n ty) -> do modify (<> FV (singletonNameMapE n $ AccessInfo One a)) ty' <- occTy ty - return $ Var (AtomVar n ty') - ProjectElt t i x -> ProjectElt <$> occ a t <*> pure i <*> occ a x - atom -> runOCCMVisitor a $ visitAtomPartial atom + return $ StuckVar (AtomVar n ty') + StuckProject t i x -> StuckProject <$> occ a t <*> pure i <*> occ a x instance HasOCC SType where occ a ty = runOCCMVisitor a $ visitTypePartial ty diff --git a/src/lib/Optimize.hs b/src/lib/Optimize.hs index 65d81b043..f8d2537fc 100644 --- a/src/lib/Optimize.hs +++ b/src/lib/Optimize.hs @@ -330,8 +330,8 @@ licmExpr = \case extraDestsTyped <- forM extraDests' \(AtomVar d t) -> return (d, t) Abs extraDestBs (Abs lb bodyAbs) <- return $ abstractFreeVars extraDestsTyped ab body' <- withFreshBinder noHint lbTy \lb' -> do - (oldIx, allCarries) <- fromPair $ Var $ binderVar lb' - (oldCarries, newCarries) <- splitAt numCarriesOriginal <$> getUnpacked allCarries + (oldIx, allCarries) <- fromPairReduced $ Var $ binderVar lb' + (oldCarries, newCarries) <- splitAt numCarriesOriginal <$> getUnpackedReduced allCarries let oldLoopBinderVal = PairVal oldIx (ProdVal oldCarries) let s = extraDestBs @@> map SubstVal newCarries <.> lb @> SubstVal oldLoopBinderVal block <- applySubst s bodyAbs @@ -425,8 +425,7 @@ instance Color c => HasDCE (Name c) where instance HasDCE SAtom where dce = \case - Var n -> modify (<> FV (freeVarsE n)) $> Var n - ProjectElt t i x -> ProjectElt <$> dce t <*> pure i <*> dce x + Stuck e -> modify (<> FV (freeVarsE e)) $> Stuck e atom -> visitAtomPartial atom instance HasDCE SType where dce = visitTypePartial diff --git a/src/lib/PPrint.hs b/src/lib/PPrint.hs index e223e2155..421d49a9c 100644 --- a/src/lib/PPrint.hs +++ b/src/lib/PPrint.hs @@ -161,14 +161,17 @@ instance PrettyE ann => Pretty (BinderP c ann n l) instance IRRep r => Pretty (Expr r n) where pretty = prettyFromPrettyPrec instance IRRep r => PrettyPrec (Expr r n) where - prettyPrec (Atom x) = prettyPrec x - prettyPrec (App _ f xs) = atPrec AppPrec $ pApp f <+> spaced (toList xs) - prettyPrec (TopApp _ f xs) = atPrec AppPrec $ pApp f <+> spaced (toList xs) - prettyPrec (TabApp _ f xs) = atPrec AppPrec $ pApp f <> "." <> dotted (toList xs) - prettyPrec (Case e alts (EffTy effs _)) = prettyPrecCase "case" e alts effs - prettyPrec (TabCon _ _ es) = atPrec ArgPrec $ list $ pApp <$> es - prettyPrec (PrimOp op) = prettyPrec op - prettyPrec (ApplyMethod _ d i xs) = atPrec AppPrec $ "applyMethod" <+> p d <+> p i <+> p xs + prettyPrec = \case + Atom x -> prettyPrec x + App _ f xs -> atPrec AppPrec $ pApp f <+> spaced (toList xs) + TopApp _ f xs -> atPrec AppPrec $ pApp f <+> spaced (toList xs) + TabApp _ f xs -> atPrec AppPrec $ pApp f <> "." <> dotted (toList xs) + Case e alts (EffTy effs _) -> prettyPrecCase "case" e alts effs + TabCon _ _ es -> atPrec ArgPrec $ list $ pApp <$> es + PrimOp op -> prettyPrec op + ApplyMethod _ d i xs -> atPrec AppPrec $ "applyMethod" <+> p d <+> p i <+> p xs + Project _ i x -> atPrec AppPrec $ "Project" <+> p i <+> p x + Unwrap _ x -> atPrec AppPrec $ "Unwrap" <+> p x prettyPrecCase :: IRRep r => Doc ann -> Atom r n -> [Alt r n] -> EffectRow r n -> DocPrec ann prettyPrecCase name e alts effs = atPrec LowestPrec $ @@ -210,13 +213,11 @@ instance IRRep r => PrettyPrec (LamExpr r n) where instance IRRep r => Pretty (IxType r n) where pretty (IxType ty dict) = parens $ "IxType" <+> pretty ty <> prettyIxDict dict -instance Pretty (DictExpr n) where +instance Pretty (DictCon n) where pretty d = case d of - InstanceDict name args -> "Instance" <+> p name <+> p args - InstantiatedGiven v args -> "Given" <+> p v <+> p (toList args) - SuperclassProj d' i -> "SuperclassProj" <+> p d' <+> p i - IxFin n -> "Ix (Fin" <+> p n <> ")" - DataData a -> "Data " <+> p a + InstanceDict _ name args -> "Instance" <+> p name <+> p args + IxFin _ n -> "Ix (Fin" <+> p n <> ")" + DataData _ a -> "Data " <+> p a instance IRRep r => Pretty (IxDict r n) where pretty = \case @@ -239,16 +240,15 @@ instance Pretty (CoreLamExpr n) where instance IRRep r => Pretty (Atom r n) where pretty = prettyFromPrettyPrec instance IRRep r => PrettyPrec (Atom r n) where prettyPrec atom = case atom of - Var v -> atPrec ArgPrec $ p v + Stuck e -> prettyPrec e Lam lam -> atPrec LowestPrec $ p lam DepPair x y _ -> atPrec ArgPrec $ align $ group $ parens $ p x <+> ",>" <+> p y Con e -> prettyPrec e Eff e -> atPrec ArgPrec $ p e PtrVar _ v -> atPrec ArgPrec $ p v - DictCon _ d -> atPrec LowestPrec $ p d + DictCon d -> atPrec LowestPrec $ p d RepValAtom x -> atPrec LowestPrec $ pretty x - ProjectElt _ idxs v -> atPrec LowestPrec $ "ProjectElt" <+> p idxs <+> p v NewtypeCon con x -> prettyPrecNewtype con x SimpInCore x -> prettyPrec x TypeAsAtom ty -> prettyPrec ty @@ -262,9 +262,16 @@ instance IRRep r => PrettyPrec (Type r n) where TC e -> prettyPrec e DictTy t -> atPrec LowestPrec $ p t NewtypeTyCon con -> prettyPrec con - TyVar v -> atPrec ArgPrec $ p v - ProjectEltTy _ idxs v -> - atPrec LowestPrec $ "ProjectElt" <+> p idxs <+> p v + StuckTy e -> prettyPrec e + +instance IRRep r => Pretty (Stuck r n) where pretty = prettyFromPrettyPrec +instance IRRep r => PrettyPrec (Stuck r n) where + prettyPrec = \case + StuckVar v -> atPrec ArgPrec $ p v + StuckProject _ i v -> atPrec LowestPrec $ "StuckProject" <+> p i <+> p v + StuckUnwrap _ v -> atPrec LowestPrec $ "StuckUnwrap" <+> p v + InstantiatedGiven _ v args -> atPrec LowestPrec $ "Given" <+> p v <+> p (toList args) + SuperclassProj _ d' i -> atPrec LowestPrec $ "SuperclassProj" <+> p d' <+> p i instance Pretty (SimpInCore n) where pretty = prettyFromPrettyPrec instance PrettyPrec (SimpInCore n) where @@ -977,6 +984,7 @@ instance Pretty CallingConvention where instance Pretty LetAnn where pretty ann = case ann of PlainLet -> "" + InlineLet -> "%inline" NoInlineLet -> "%noinline" OccInfoPure u -> p u <> line OccInfoImpure u -> p u <> ", impure" <> line diff --git a/src/lib/QueryType.hs b/src/lib/QueryType.hs index 031d688b8..98c1b65f7 100644 --- a/src/lib/QueryType.hs +++ b/src/lib/QueryType.hs @@ -72,11 +72,6 @@ piTypeWithoutDest (PiType bsRefB _) = blockEff :: (EnvReader m, IRRep r) => Block r n -> m n (EffectRow r n) blockEff b = blockEffTy b <&> \(EffTy eff _) -> eff -typeOfApp :: (IRRep r, EnvReader m) => Type r n -> [Atom r n] -> m n (Type r n) -typeOfApp (Pi piTy) xs = withSubstReaderT $ - withInstantiated piTy xs \(EffTy _ ty) -> substM ty -typeOfApp _ _ = error "expected a pi type" - typeOfTabApp :: (IRRep r, EnvReader m) => Type r n -> [Atom r n] -> m n (Type r n) typeOfTabApp t [] = return t typeOfTabApp (TabPi tabTy) (i:rest) = do @@ -89,22 +84,6 @@ typeOfApplyMethod d i args = do ty <- Pi <$> getMethodType d i appEffTy ty args -typeOfDictExpr :: EnvReader m => DictExpr n -> m n (CType n) -typeOfDictExpr e = liftM ignoreExcept $ liftEnvReaderT $ case e of - InstanceDict instanceName args -> do - instanceDef@(InstanceDef className _ _ _ _) <- lookupInstanceDef instanceName - sourceName <- getSourceName <$> lookupClassDef className - PairE (ListE params) _ <- instantiate instanceDef args - return $ DictTy $ DictType sourceName className params - InstantiatedGiven given args -> typeOfApp (getType given) args - SuperclassProj d i -> do - DictTy (DictType _ className params) <- return $ getType d - classDef <- lookupClassDef className - withSubstReaderT $ withInstantiated classDef params \(Abs superclasses _) -> do - substM $ getSuperclassType REmpty superclasses i - IxFin n -> liftM DictTy $ ixDictType $ NewtypeTyCon $ Fin n - DataData ty -> DictTy <$> dataDictType ty - typeOfTopApp :: EnvReader m => TopFunName n -> [SAtom n] -> m n (EffTy SimpIR n) typeOfTopApp f xs = do piTy <- getTypeTopFun f @@ -350,7 +329,7 @@ getSuperclassDicts dict = do case getType dict of DictTy dTy -> do ts <- getSuperclassTys dTy - forM (enumerate ts) \(i, t) -> return $ DictCon t $ SuperclassProj dict i + forM (enumerate ts) \(i, _) -> reduceSuperclassProj i dict _ -> error "expected a dict type" getSuperclassTys :: EnvReader m => DictType n -> m n [CType n] diff --git a/src/lib/QueryTypePure.hs b/src/lib/QueryTypePure.hs index 258b5f5c0..89eb158ca 100644 --- a/src/lib/QueryTypePure.hs +++ b/src/lib/QueryTypePure.hs @@ -70,19 +70,24 @@ instance IRRep r => HasType r (AtomVar r) where instance IRRep r => HasType r (Atom r) where getType atom = case atom of - Var name -> getType name + Stuck e -> getType e Lam (CoreLamExpr piTy _) -> Pi piTy DepPair _ _ ty -> DepPairTy ty Con con -> getType con Eff _ -> EffKind PtrVar t _ -> PtrTy t - DictCon ty _ -> ty + DictCon d -> getType d NewtypeCon con _ -> getNewtypeType con RepValAtom (RepVal ty _) -> ty - ProjectElt t _ _ -> t SimpInCore x -> getType x TypeAsAtom ty -> getType ty +instance HasType CoreIR DictCon where + getType = \case + InstanceDict t _ _ -> t + IxFin t _ -> t + DataData t _ -> t + instance IRRep r => HasType r (Type r) where getType = \case NewtypeTyCon con -> getType con @@ -91,8 +96,15 @@ instance IRRep r => HasType r (Type r) where DepPairTy _ -> TyKind TC _ -> TyKind DictTy _ -> TyKind - TyVar v -> getType v - ProjectEltTy t _ _ -> t + StuckTy e -> getType e + +instance IRRep r => HasType r (Stuck r) where + getType = \case + StuckVar (AtomVar _ t) -> t + StuckProject t _ _ -> t + StuckUnwrap t _ -> t + InstantiatedGiven t _ _ -> t + SuperclassProj t _ _ -> t instance HasType CoreIR SimpInCore where getType = \case @@ -133,6 +145,8 @@ instance IRRep r => HasType r (Expr r) where PrimOp op -> getType op Case _ _ (EffTy _ resultTy) -> resultTy ApplyMethod (EffTy _ t) _ _ _ -> t + Project t _ _ -> t + Unwrap t _ -> t instance IRRep r => HasType r (DAMOp r) where getType = \case @@ -267,6 +281,8 @@ instance IRRep r => HasEffects (Expr r) r where TabCon _ _ _ -> Pure ApplyMethod (EffTy eff _) _ _ _ -> eff PrimOp primOp -> getEffects primOp + Project _ _ _ -> Pure + Unwrap _ _ -> Pure instance IRRep r => HasEffects (DeclBinding r) r where getEffects (DeclBinding _ expr) = getEffects expr diff --git a/src/lib/RuntimePrint.hs b/src/lib/RuntimePrint.hs index 4a4c2c6a5..64e87ee04 100644 --- a/src/lib/RuntimePrint.hs +++ b/src/lib/RuntimePrint.hs @@ -78,7 +78,6 @@ showAnyRec atom = case getType atom of parens $ sepBy ", " $ map rec xs -- TODO: traverse the type and print out data components TypeKind -> printAsConstant - ProjectEltTy _ _ _ -> error "not implemented" Pi _ -> printTypeOnly "function" TabPi _ -> brackets $ forEachTabElt atom \iOrd x -> do isFirst <- ieq iOrd (NatVal 0) @@ -94,7 +93,7 @@ showAnyRec atom = case getType atom of EffectRowKind -> printAsConstant -- hack to print strings nicely. TODO: make `Char` a newtype UserADTType "List" _ (TyConParams [Explicit] [Type Word8Ty]) -> do - charTab <- normalizeNaryProj [ProjectProduct 1, UnwrapNewtype] atom + charTab <- applyProjections [ProjectProduct 1, UnwrapNewtype] atom emitCharLit '"' emitCharTab charTab emitCharLit '"' @@ -121,14 +120,14 @@ showAnyRec atom = case getType atom of sepBy " " $ projss <&> \projs -> -- we use `init` to strip off the `UnwrapCompoundNewtype` since -- we're already under the case alternative - rec =<< normalizeNaryProj (init projs) arg + rec =<< applyProjections (init projs) arg DepPairTy _ -> parens do (x, y) <- fromPair atom rec x >> emitLit " ,> " >> rec y -- Done well, this could let you inspect the results of dictionary synthesis -- and maybe even debug synthesis failures. DictTy _ -> printAsConstant - TyVar v -> error $ "unexpected type variable: " ++ pprint v + StuckTy e -> error $ "unexpected stuck type expression: " ++ pprint e where rec :: Emits n' => CAtom n' -> Print n' rec = showAnyRec @@ -202,7 +201,7 @@ stringLitAsCharTab s = do finTabTyCore :: (Fallible1 m, EnvReader m) => CAtom n -> CType n -> m n (CType n) finTabTyCore n eltTy = do - d <- mkDictAtom $ IxFin n + d <- DictCon <$> mkIxFin n return $ IxType (FinTy n) (IxDictAtom d) ==> eltTy getPreludeFunction :: EnvReader m => String -> m n (CAtom n) diff --git a/src/lib/Simplify.hs b/src/lib/Simplify.hs index c3d46e4a8..9139af597 100644 --- a/src/lib/Simplify.hs +++ b/src/lib/Simplify.hs @@ -73,9 +73,15 @@ tryAsDataAtom atom = do where go :: Emits n => CAtom n -> SimplifyM i n (SAtom n) go = \case - Var v -> lookupAtomName (atomVarName v) >>= \case - LetBound (DeclBinding _ (Atom x)) -> go x - _ -> error "Shouldn't have irreducible top names left" + Stuck e -> case e of + StuckVar v -> lookupAtomName (atomVarName v) >>= \case + LetBound (DeclBinding _ (Atom x)) -> go x + _ -> error "Shouldn't have irreducible top names left" + StuckUnwrap _ x -> go (Stuck x) + -- TODO: do we need to think about a case like `fst (1, \x.x)`, where + -- the projection is data but the argument isn't? + StuckProject _ i x -> reduceProj i =<< go (Stuck x) + _ -> notData Con con -> Con <$> case con of Lit v -> return $ Lit v ProdCon xs -> ProdCon <$> mapM go xs @@ -85,10 +91,6 @@ tryAsDataAtom atom = do DepPair x y ty -> do DepPairTy ty' <- getRepType $ DepPairTy ty DepPair <$> go x <*> go y <*> pure ty' - ProjectElt _ UnwrapNewtype x -> go x - -- TODO: do we need to think about a case like `fst (1, \x.x)`, where - -- the projection is data but the argument isn't? - ProjectElt _ (ProjectProduct i) x -> normalizeProj (ProjectProduct i) =<< go x NewtypeCon _ x -> go x SimpInCore x -> case x of LiftSimp _ x' -> return x' @@ -96,7 +98,7 @@ tryAsDataAtom atom = do TabLam _ tabLam -> forceTabLam tabLam ACase scrut alts resultTy -> forceACase scrut alts resultTy Lam _ -> notData - DictCon _ _ -> notData + DictCon _ -> notData Eff _ -> notData TypeAsAtom _ -> notData where @@ -164,8 +166,7 @@ getRepType ty = go ty where go ty' Pi _ -> error notDataType DictTy _ -> error notDataType - TyVar _ -> error "Shouldn't have type variables in CoreIR IR with SimpIR builder names" - ProjectEltTy _ _ _ -> error "Shouldn't have this left" + StuckTy _ -> error "Shouldn't have stuck expressions in CoreIR IR with SimpIR builder names" where notDataType = "Not a type of runtime-representable data: " ++ pprint ty toDataAtom :: Emits n => CAtom n -> SimplifyM i n (SAtom n, Type CoreIR n) @@ -341,6 +342,18 @@ simplifyExpr hint expr = confuseGHC >>= \_ -> case expr of defuncCaseCore scrut' resultTy' \i x -> do Abs b body <- return $ alts !! i extendSubst (b@>SubstVal x) $ simplifyBlock body + Project ty i x -> do + ty' <- substM ty + x' <- substM x + tryAsDataAtom x' >>= \case + Just (x'', _) -> liftSimpAtom ty' =<< proj i x'' + Nothing -> requireReduced $ Project ty' i x' + Unwrap _ _ -> requireReduced =<< substM expr + +requireReduced :: CExpr o -> SimplifyM i o (CAtom o) +requireReduced expr = reduceExpr expr >>= \case + Just x -> return x + Nothing -> error "couldn't reduce expression" simplifyRefOp :: Emits o => RefOp CoreIR i -> SAtom o -> SimplifyM i o (SAtom o) simplifyRefOp op ref = case op of @@ -486,8 +499,8 @@ simplifyAtomAndInline atom = confuseGHC >>= \_ -> case atom of -- This is a hack because we weren't normalize the unwrapping of -- `unit_type_scale` in `plot.dx`. We need a better system for deciding how to -- normalize and inline. - ProjectElt _ i x -> do - x' <- simplifyAtom x >>= normalizeProj i + Stuck (StuckProject _ i x) -> do + x' <- simplifyStuck x >>= reduceProj i dropSubst $ simplifyAtomAndInline x' _ -> simplifyAtom atom >>= \case Var v -> doInline v @@ -572,11 +585,11 @@ simplifyIxType :: IxType CoreIR o -> SimplifyM i o (IxType SimpIR o) simplifyIxType (IxType t ixDict) = do t' <- getRepType t IxType t' <$> case ixDict of - IxDictAtom (DictCon _ (IxFin n)) -> do + IxDictAtom (DictCon (IxFin _ n)) -> do n' <- toDataAtomIgnoreReconAssumeNoDecls n return $ IxDictRawFin n' IxDictAtom d -> do - (dictAbs, params) <- generalizeIxDict =<< cheapNormalize d + (dictAbs, params) <- generalizeIxDict d params' <- mapM toDataAtomIgnoreReconAssumeNoDecls params sdName <- requireIxDictCache dictAbs return $ IxDictSpecialized t' sdName params' @@ -621,18 +634,23 @@ ixMethodType method absDict = do -- TODO: do we even need this, or is it just a glorified `SubstM`? simplifyAtom :: CAtom i -> SimplifyM i o (CAtom o) simplifyAtom atom = confuseGHC >>= \_ -> case atom of - Var v -> simplifyVar v + Stuck e -> simplifyStuck e Lam _ -> substM atom DepPair x y ty -> DepPair <$> simplifyAtom x <*> simplifyAtom y <*> substM ty Con con -> Con <$> traverseOp con substM simplifyAtom (error "unexpected lambda") Eff eff -> Eff <$> substM eff PtrVar t v -> PtrVar t <$> substM v - DictCon t d -> (DictCon <$> substM t <*> substM d) >>= cheapNormalize + DictCon _ -> substM atom NewtypeCon _ _ -> substM atom - ProjectElt _ i x -> normalizeProj i =<< simplifyAtom x SimpInCore _ -> substM atom TypeAsAtom _ -> substM atom +simplifyStuck :: CStuck i -> SimplifyM i o (CAtom o) +simplifyStuck = \case + StuckVar v -> simplifyVar v + StuckProject _ i x -> reduceProj i =<< simplifyStuck x + stuck -> substM (Stuck stuck) + simplifyVar :: AtomVar CoreIR i -> SimplifyM i o (CAtom o) simplifyVar v = do env <- getSubst @@ -680,12 +698,12 @@ splitDataComponents = \case { dataTy = ProdTy $ map dataTy splits , nonDataTy = ProdTy $ map nonDataTy splits , toSplit = \xProd -> do - xs <- getUnpacked xProd + xs <- getUnpackedReduced xProd (ys, zs) <- unzip <$> forM (zip xs splits) \(x, split) -> toSplit split x return (ProdVal ys, ProdVal zs) , fromSplit = \xsProd ysProd -> do - xs <- getUnpacked xsProd - ys <- getUnpacked ysProd + xs <- getUnpackedReduced xsProd + ys <- getUnpackedReduced ysProd zs <- forM (zip (zip xs ys) splits) \((x, y), split) -> fromSplit split x y return $ ProdVal zs } ty -> tryGetRepType ty >>= \case @@ -779,23 +797,22 @@ pattern CoerceReconAbs :: Abs (Nest b) ReconstructAtom n pattern CoerceReconAbs <- Abs _ (CoerceRecon _) applyDictMethod :: Emits o => CType o -> CAtom o -> Int -> [CAtom o] -> SimplifyM i o (CAtom o) -applyDictMethod resultTy d i methodArgs = do - cheapNormalize d >>= \case - DictCon _ (InstanceDict instanceName instanceArgs) -> dropSubst do - instanceArgs' <- mapM simplifyAtom instanceArgs - instanceDef <- lookupInstanceDef instanceName - withInstantiated instanceDef instanceArgs' \(PairE _ body) -> do - let InstanceBody _ methods = body - let method = methods !! i - simplifyApp noHint resultTy method methodArgs - DictCon _ (IxFin n) -> applyIxFinMethod (toEnum i) n methodArgs - d' -> error $ "Not a simplified dict: " ++ pprint d' +applyDictMethod resultTy d i methodArgs = case d of + DictCon (InstanceDict _ instanceName instanceArgs) -> dropSubst do + instanceArgs' <- mapM simplifyAtom instanceArgs + instanceDef <- lookupInstanceDef instanceName + withInstantiated instanceDef instanceArgs' \(PairE _ body) -> do + let InstanceBody _ methods = body + let method = methods !! i + simplifyApp noHint resultTy method methodArgs + DictCon (IxFin _ n) -> applyIxFinMethod (toEnum i) n methodArgs + d' -> error $ "Not a simplified dict: " ++ pprint d' where applyIxFinMethod :: EnvReader m => IxMethod -> CAtom n -> [CAtom n] -> m n (CAtom n) applyIxFinMethod method n args = do case (method, args) of (Size, []) -> return n -- result : Nat - (Ordinal, [ix]) -> unwrapNewtype ix -- result : Nat + (Ordinal, [ix]) -> reduceUnwrap ix -- result : Nat (UnsafeFromOrdinal, [ix]) -> return $ NewtypeCon (FinCon n) ix _ -> error "bad ix args" @@ -923,6 +940,44 @@ preludeMaybeNewtypeCon ty = do simplifyBlock :: Emits o => Block CoreIR i -> SimplifyM i o (CAtom o) simplifyBlock (Abs decls result) = simplifyDecls decls $ simplifyAtom result +liftSimpAtom :: EnvReader m => Type CoreIR n -> SAtom n -> m n (CAtom n) +liftSimpAtom ty simpAtom = case simpAtom of + Stuck _ -> justLift + RepValAtom _ -> justLift -- TODO(dougalm): should we make more effort to pull out products etc? + _ -> do + (cons , ty') <- unwrapLeadingNewtypesType ty + atom <- case (ty', simpAtom) of + (BaseTy _ , Con (Lit v)) -> return $ Con $ Lit v + (ProdTy tys, Con (ProdCon xs)) -> Con . ProdCon <$> zipWithM rec tys xs + (SumTy tys, Con (SumCon _ i x)) -> Con . SumCon tys i <$> rec (tys!!i) x + (DepPairTy dpt@(DepPairType _ (b:>t1) t2), DepPair x1 x2 _) -> do + x1' <- rec t1 x1 + t2' <- applySubst (b@>SubstVal x1') t2 + x2' <- rec t2' x2 + return $ DepPair x1' x2' dpt + _ -> error $ "can't lift " <> pprint simpAtom <> " to " <> pprint ty' + return $ wrapNewtypesData cons atom + where + rec = liftSimpAtom + justLift = return $ SimpInCore $ LiftSimp ty simpAtom +{-# INLINE liftSimpAtom #-} + +unwrapLeadingNewtypesType :: EnvReader m => CType n -> m n ([NewtypeCon n], CType n) +unwrapLeadingNewtypesType = \case + NewtypeTyCon tyCon -> do + (dataCon, ty) <- unwrapNewtypeType tyCon + (dataCons, ty') <- unwrapLeadingNewtypesType ty + return (dataCon:dataCons, ty') + ty -> return ([], ty) + +liftSimpFun :: EnvReader m => Type CoreIR n -> LamExpr SimpIR n -> m n (CAtom n) +liftSimpFun (Pi piTy) f = return $ SimpInCore $ LiftSimpFun piTy f +liftSimpFun _ _ = error "not a pi type" + +wrapNewtypesData :: [NewtypeCon n] -> CAtom n-> CAtom n +wrapNewtypesData [] x = x +wrapNewtypesData (c:cs) x = NewtypeCon c $ wrapNewtypesData cs x + -- === simplifying custom linearizations === linearizeTopFun :: (Mut n, Fallible1 m, TopBuilder m) => LinearizationSpec n -> m n (TopFunName n, TopFunName n) @@ -974,7 +1029,7 @@ simplifyCustomLinearization (Abs runtimeBs staticArgs) actives rule = do fCustom' <- sinkM fCustom resultTy <- typeOfApp (getType fCustom') staticArgs' pairResult <- dropSubst $ simplifyApp noHint resultTy fCustom' staticArgs' - (primalResult, fLin) <- fromPair pairResult + (primalResult, fLin) <- fromPairReduced pairResult primalResult' <- toDataAtomIgnoreRecon primalResult let explicitPrimalArgs = drop nImplicit staticArgs' allTangentTys <- forM explicitPrimalArgs \primalArg -> do diff --git a/src/lib/Subst.hs b/src/lib/Subst.hs index 41bdf78e6..038362357 100644 --- a/src/lib/Subst.hs +++ b/src/lib/Subst.hs @@ -21,6 +21,7 @@ import IRVariants import Types.Core import Core import qualified RawName as R +import QueryTypePure import Err -- === SubstReader class === @@ -152,6 +153,12 @@ fromConstAbs (Abs b e) = hoist b e extendRenamer :: (SubstReader v m, FromName v) => SubstFrag Name i i' o -> m i' o r -> m i o r extendRenamer frag = extendSubst (fmapSubstFrag (const fromName) frag) +extendBinderRename + :: (SubstReader v m, FromName v, BindsAtMostOneName b c, BindsOneName b' c) + => b i i' -> b' o o' -> m i' o' r -> m i o' r +extendBinderRename b b' cont = extendSubst (b@>fromName (binderName b')) cont +{-# INLINE extendBinderRename #-} + applyRename :: (ScopeReader m, RenameE e, SinkableE e) => Ext h o => SubstFrag Name h i o -> e i -> m o (e o) @@ -272,6 +279,17 @@ instance ToSubstVal (SubstVal atom) atom where type AtomSubstReader v m = (SubstReader v m, FromName v, ToSubstVal v Atom) +toAtomVar :: (EnvReader m, IRRep r) => AtomName r n -> m n (AtomVar r n) +toAtomVar v = do + ty <- getType <$> lookupAtomName v + return $ AtomVar v ty + +lookupAtomSubst :: (IRRep r, SubstReader AtomSubstVal m, EnvReader2 m) => AtomName r i -> m i o (Atom r o) +lookupAtomSubst v = do + lookupSubstM v >>= \case + Rename v' -> Var <$> toAtomVar v' + SubstVal x -> return x + atomSubstM :: (AtomSubstReader v m, EnvReader2 m, SinkableE e, SubstE AtomSubstVal e) => e i -> m i o (e o) atomSubstM e = do diff --git a/src/lib/TopLevel.hs b/src/lib/TopLevel.hs index f2f206038..6601e57cf 100644 --- a/src/lib/TopLevel.hs +++ b/src/lib/TopLevel.hs @@ -278,8 +278,7 @@ evalSourceBlock' mname block = case sbContents block of -- logTop $ ExportedFun name f GetType -> do -- TODO: don't actually evaluate it val <- evalUExpr expr - ty <- cheapNormalize $ getType val - logTop $ TextOut $ pprintCanonicalized ty + logTop $ TextOut $ pprintCanonicalized $ getType val DeclareForeign fname dexName cTy -> do let b = fromString dexName :: UBinder (AtomNameC CoreIR) VoidS VoidS ty <- evalUType =<< parseExpr cTy @@ -328,7 +327,7 @@ evalSourceBlock' mname block = case sbContents block of UnParseable _ s -> throw ParseErr s Misc m -> case m of GetNameType v -> do - ty <- cheapNormalize =<< sourceNameType v + ty <- sourceNameType v logTop $ TextOut $ pprintCanonicalized ty ImportModule moduleName -> importModule moduleName QueryEnv query -> void $ runEnvQuery query $> UnitE @@ -642,7 +641,7 @@ printCodegen :: (Topper m, Mut n) => CAtom n -> m n String printCodegen x = do block <- liftBuilder $ buildBlock do emitExpr $ PrimOp $ MiscOp $ ShowAny $ sink x - topBlock <- asTopBlock block + (topBlock, _) <- asTopBlock block getDexString =<< evalBlock topBlock loadObject :: (Topper m, Mut n) => FunObjCodeName n -> m n NativeFunction diff --git a/src/lib/Transpose.hs b/src/lib/Transpose.hs index 904e608d1..908dcf4cb 100644 --- a/src/lib/Transpose.hs +++ b/src/lib/Transpose.hs @@ -15,7 +15,6 @@ import GHC.Stack import Builder import Core -import CheapReduction import Err import Imp import IRVariants @@ -209,22 +208,23 @@ transposeExpr expr ct = case expr of TabApp _ x is -> do is' <- mapM substNonlin is case x of - Var v -> do + Stuck (StuckVar v) -> do lookupSubstM (atomVarName v) >>= \case RenameNonlin _ -> error "shouldn't happen" LinRef ref -> do refProj <- naryIndexRef ref (toList is') emitCTToRef refProj ct LinTrivial -> return () - ProjectElt _ i' x' -> do - let (idxs, v) = asNaryProj i' x' - lookupSubstM (atomVarName v) >>= \case - RenameNonlin _ -> error "an error, probably" - LinRef ref -> do - ref' <- getNaryProjRef (toList idxs) ref - refProj <- naryIndexRef ref' (toList is') - emitCTToRef refProj ct - LinTrivial -> return () + Stuck (StuckProject _ _ _) -> undefined + -- ProjectElt _ i' x' -> do + -- let (idxs, v) = asNaryProj i' x' + -- lookupSubstM (atomVarName v) >>= \case + -- RenameNonlin _ -> error "an error, probably" + -- LinRef ref -> do + -- ref' <- getNaryProjRef (toList idxs) ref + -- refProj <- naryIndexRef ref' (toList is') + -- emitCTToRef refProj ct + -- LinTrivial -> return () _ -> error $ "shouldn't occur: " ++ pprint x PrimOp op -> transposeOp op ct Case e alts _ -> do @@ -245,6 +245,7 @@ transposeExpr expr ct = case expr of forM_ (enumerate es) \(ordinalIdx, e) -> do i <- unsafeFromOrdinal idxTy (IdxRepVal $ fromIntegral ordinalIdx) tabApp ct i >>= transposeAtom e + Project _ _ _ -> undefined transposeOp :: Emits o => PrimOp SimpIR i -> SAtom o -> TransposeM i o () transposeOp op ct = case op of @@ -305,24 +306,25 @@ transposeMiscOp op _ = case op of transposeAtom :: HasCallStack => Emits o => SAtom i -> SAtom o -> TransposeM i o () transposeAtom atom ct = case atom of - Var v -> do + Con con -> transposeCon con ct + DepPair _ _ _ -> notImplemented + PtrVar _ _ -> notTangent + Stuck (StuckVar v) -> do lookupSubstM (atomVarName v) >>= \case RenameNonlin _ -> -- XXX: we seem to need this case, but it feels like it should be an error! return () LinRef ref -> emitCTToRef ref ct LinTrivial -> return () - Con con -> transposeCon con ct - DepPair _ _ _ -> notImplemented - PtrVar _ _ -> notTangent - ProjectElt _ i' x' -> do - let (idxs, v) = asNaryProj i' x' - lookupSubstM (atomVarName v) >>= \case - RenameNonlin _ -> error "an error, probably" - LinRef ref -> do - ref' <- getNaryProjRef (toList idxs) ref - emitCTToRef ref' ct - LinTrivial -> return () + Stuck (StuckProject _ _ _) -> undefined + -- Stuck (StuckProject _ i' x') -> do + -- let (idxs, v) = asNaryProj i' x' + -- lookupSubstM (atomVarName v) >>= \case + -- RenameNonlin _ -> error "an error, probably" + -- LinRef ref -> do + -- ref' <- applyProjectionsRef (toList idxs) ref + -- emitCTToRef ref' ct + -- LinTrivial -> return () RepValAtom _ -> error "not implemented" where notTangent = error $ "Not a tangent atom: " ++ pprint atom @@ -366,9 +368,7 @@ transposeCon :: Emits o => Con SimpIR i -> SAtom o -> TransposeM i o () transposeCon con ct = case con of Lit _ -> return () ProdCon [] -> return () - ProdCon xs -> - forM_ (enumerate xs) \(i, x) -> - projectTuple i ct >>= transposeAtom x + ProdCon xs -> forM_ (enumerate xs) \(i, x) -> proj i ct >>= transposeAtom x SumCon _ _ _ -> notImplemented HeapVal -> notTangent where notTangent = error $ "Not a tangent atom: " ++ pprint (Con con) diff --git a/src/lib/Types/Core.hs b/src/lib/Types/Core.hs index 94d2248f1..77045321a 100644 --- a/src/lib/Types/Core.hs +++ b/src/lib/Types/Core.hs @@ -48,15 +48,14 @@ import Types.Imp -- === core IR === data Atom (r::IR) (n::S) where - Var :: AtomVar r n -> Atom r n Con :: Con r n -> Atom r n + Stuck :: Stuck r n -> Atom r n PtrVar :: PtrType -> PtrName n -> Atom r n - ProjectElt :: Type r n -> Projection -> Atom r n -> Atom r n DepPair :: Atom r n -> Atom r n -> DepPairType r n -> Atom r n -- === CoreIR only === Lam :: CoreLamExpr n -> Atom CoreIR n Eff :: EffectRow CoreIR n -> Atom CoreIR n - DictCon :: Type CoreIR n -> DictExpr n -> Atom CoreIR n + DictCon :: DictCon n -> Atom CoreIR n NewtypeCon :: NewtypeCon n -> Atom CoreIR n -> Atom CoreIR n TypeAsAtom :: Type CoreIR n -> Atom CoreIR n -- === Shims between IRs === @@ -67,14 +66,23 @@ data Type (r::IR) (n::S) where TC :: TC r n -> Type r n TabPi :: TabPiType r n -> Type r n DepPairTy :: DepPairType r n -> Type r n - TyVar :: AtomVar CoreIR n -> Type CoreIR n + StuckTy :: Stuck CoreIR n -> Type CoreIR n DictTy :: DictType n -> Type CoreIR n Pi :: CorePiType n -> Type CoreIR n NewtypeTyCon :: NewtypeTyCon n -> Type CoreIR n - -- It was bad enough having this in `Atom`, but it's even worse now that it's - -- replicated in `Type` too. We should be able to remove both once - -- we represent types as normalized blocks. - ProjectEltTy :: CType n -> Projection -> CAtom n -> Type CoreIR n + +data Stuck (r::IR) (n::S) where + StuckVar :: AtomVar r n -> Stuck r n + StuckProject :: Type r n -> Int -> Stuck r n -> Stuck r n + StuckUnwrap :: CType n -> CStuck n -> Stuck CoreIR n + InstantiatedGiven :: CType n -> CStuck n -> [CAtom n] -> Stuck CoreIR n + SuperclassProj :: CType n -> Int -> CStuck n -> Stuck CoreIR n + +pattern Var :: AtomVar r n -> Atom r n +pattern Var v = Stuck (StuckVar v) + +pattern TyVar :: AtomVar CoreIR n -> Type CoreIR n +pattern TyVar v = StuckTy (StuckVar v) data AtomVar (r::IR) (n::S) = AtomVar { atomVarName :: AtomName r n @@ -91,8 +99,11 @@ data SimpInCore (n::S) = deriving instance IRRep r => Show (Atom r n) deriving instance IRRep r => Show (Type r n) +deriving instance IRRep r => Show (Stuck r n) + deriving via WrapE (Atom r) n instance IRRep r => Generic (Atom r n) deriving via WrapE (Type r) n instance IRRep r => Generic (Type r n) +deriving via WrapE (Stuck r) n instance IRRep r => Generic (Stuck r n) data Expr r n where TopApp :: EffTy SimpIR n -> TopFunName n -> [SAtom n] -> Expr SimpIR n @@ -101,7 +112,9 @@ data Expr r n where Atom :: Atom r n -> Expr r n TabCon :: Maybe (WhenCore r Dict n) -> Type r n -> [Atom r n] -> Expr r n PrimOp :: PrimOp r n -> Expr r n + Project :: Type r n -> Int -> Atom r n -> Expr r n App :: EffTy CoreIR n -> CAtom n -> [CAtom n] -> Expr CoreIR n + Unwrap :: CType n -> CAtom n -> Expr CoreIR n ApplyMethod :: EffTy CoreIR n -> CAtom n -> Int -> [CAtom n] -> Expr CoreIR n deriving instance IRRep r => Show (Expr r n) @@ -429,6 +442,7 @@ data RefOp r n = type CAtom = Atom CoreIR type CType = Type CoreIR +type CStuck = Stuck CoreIR type CBinder = Binder CoreIR type CExpr = Expr CoreIR type CBlock = Block CoreIR @@ -440,6 +454,7 @@ type CTopLam = TopLam CoreIR type SAtom = Atom SimpIR type SType = Type SimpIR +type SStuck = Stuck SimpIR type SExpr = Expr SimpIR type SBlock = Block SimpIR type SAlt = Alt SimpIR @@ -509,15 +524,12 @@ data InstanceBody (n::S) = data DictType (n::S) = DictType SourceName (ClassName n) [CAtom n] deriving (Show, Generic) -data DictExpr (n::S) = - InstantiatedGiven (CAtom n) [CAtom n] - | SuperclassProj (CAtom n) Int -- (could instantiate here too, but we don't need it for now) - -- We use NonEmpty because givens without args can be represented using `Var`. - | InstanceDict (InstanceName n) [CAtom n] +data DictCon (n::S) = + InstanceDict (CType n) (InstanceName n) [CAtom n] -- Special case for `Ix (Fin n)` (TODO: a more general mechanism for built-in classes and instances) - | IxFin (CAtom n) + | IxFin (CType n) (CAtom n) -- Special case for `Data ` - | DataData (CType n) + | DataData (CType n) (CType n) deriving (Show, Generic) -- TODO: Use an IntMap @@ -985,12 +997,10 @@ instance (ToBinding e1 c, ToBinding e2 c) => ToBinding (EitherE e1 e2) c where -- XXX: only use this pattern when you're actually expecting a type. If it's -- a Var, it doesn't check whether it's a type. pattern Type :: CType n -> CAtom n -pattern Type t <- ((\case Var v -> Just (TyVar v) - ProjectElt t i x -> Just $ ProjectEltTy t i x - TypeAsAtom t -> Just t +pattern Type t <- ((\case Stuck e -> Just (StuckTy e) + TypeAsAtom t -> Just t _ -> Nothing) -> Just t) - where Type (TyVar v) = Var v - Type (ProjectEltTy t i x) = ProjectElt t i x + where Type (StuckTy e) = Stuck e Type t = TypeAsAtom t pattern IdxRepScalarBaseTy :: ScalarBaseType @@ -1450,14 +1460,13 @@ instance IRRep r => GenericE (Atom r) where -- toE/fromE entirely. If you wish to modify the order, please consult the -- GHC Core dump to make sure you haven't regressed this optimization. type RepE (Atom r) = EitherE3 - (EitherE4 - {- Var -} (AtomVar r) - {- ProjectElt -} (Type r `PairE` LiftE Projection `PairE` Atom r) + (EitherE3 + {- Stuck -} (Stuck r) {- Lam -} (WhenCore r CoreLamExpr) {- DepPair -} (Atom r `PairE` Atom r `PairE` DepPairType r) ) (EitherE3 - {- DictCon -} (WhenCore r (CType `PairE` DictExpr)) - {- NewtypeCon -} (WhenCore r (NewtypeCon `PairE` Atom r)) + {- DictCon -} (WhenCore r DictCon) + {- NewtypeCon -} (WhenCore r (NewtypeCon `PairE` Atom r)) {- Con -} (Con r) ) (EitherE5 {- Eff -} ( WhenCore r (EffectRow r)) @@ -1468,11 +1477,10 @@ instance IRRep r => GenericE (Atom r) where ) fromE atom = case atom of - Var v -> Case0 (Case0 v) - ProjectElt t idxs x -> Case0 (Case1 (t `PairE` LiftE idxs `PairE` x)) - Lam lamExpr -> Case0 (Case2 (WhenIRE lamExpr)) - DepPair l r ty -> Case0 (Case3 $ l `PairE` r `PairE` ty) - DictCon t d -> Case1 $ Case0 $ WhenIRE $ t `PairE` d + Stuck x -> Case0 (Case0 x) + Lam lamExpr -> Case0 (Case1 (WhenIRE lamExpr)) + DepPair l r ty -> Case0 (Case2 $ l `PairE` r `PairE` ty) + DictCon d -> Case1 $ Case0 $ WhenIRE d NewtypeCon c x -> Case1 $ Case1 $ WhenIRE (c `PairE` x) Con con -> Case1 $ Case2 con Eff effs -> Case2 $ Case0 $ WhenIRE effs @@ -1484,13 +1492,12 @@ instance IRRep r => GenericE (Atom r) where toE atom = case atom of Case0 val -> case val of - Case0 v -> Var v - Case1 (t `PairE` LiftE idxs `PairE` x) -> ProjectElt t idxs x - Case2 (WhenIRE (lamExpr)) -> Lam lamExpr - Case3 (l `PairE` r `PairE` ty) -> DepPair l r ty + Case0 e -> Stuck e + Case1 (WhenIRE (lamExpr)) -> Lam lamExpr + Case2 (l `PairE` r `PairE` ty) -> DepPair l r ty _ -> error "impossible" Case1 val -> case val of - Case0 (WhenIRE (t `PairE` d)) -> DictCon t d + Case0 (WhenIRE d) -> DictCon d Case1 (WhenIRE (c `PairE` x)) -> NewtypeCon c x Case2 con -> Con con _ -> error "impossible" @@ -1510,6 +1517,36 @@ instance IRRep r => AlphaEqE (Atom r) instance IRRep r => AlphaHashableE (Atom r) instance IRRep r => RenameE (Atom r) +instance IRRep r => GenericE (Stuck r) where + type RepE (Stuck r) = EitherE5 + {- StuckVar -} (AtomVar r) + {- StuckProject -} (Type r `PairE` LiftE Int `PairE` Stuck r) + {- StuckUnwrap -} (WhenCore r (CType `PairE` CStuck)) + {- InstantiatedGiven -} (WhenCore r (CType `PairE` CStuck `PairE` ListE CAtom)) + {- SuperclassProj -} (WhenCore r (CType `PairE` LiftE Int `PairE` CStuck)) + fromE = \case + StuckVar v -> Case0 v + StuckProject t i e -> Case1 $ t `PairE` LiftE i `PairE` e + StuckUnwrap t e -> Case2 $ WhenIRE $ t `PairE` e + InstantiatedGiven t e xs -> Case3 $ WhenIRE $ t `PairE` e `PairE` ListE xs + SuperclassProj t i e -> Case4 $ WhenIRE $ t `PairE` LiftE i `PairE` e + {-# INLINE fromE #-} + + toE = \case + Case0 v -> StuckVar v + Case1 (t `PairE` LiftE i `PairE` e) -> StuckProject t i e + Case2 (WhenIRE (t `PairE` e)) -> StuckUnwrap t e + Case3 (WhenIRE (t `PairE` e `PairE` ListE xs)) -> InstantiatedGiven t e xs + Case4 (WhenIRE (t `PairE` LiftE i `PairE` e)) -> SuperclassProj t i e + _ -> error "impossible" + {-# INLINE toE #-} + +instance IRRep r => SinkableE (Stuck r) +instance IRRep r => HoistableE (Stuck r) +instance IRRep r => AlphaEqE (Stuck r) +instance IRRep r => AlphaHashableE (Stuck r) +instance IRRep r => RenameE (Stuck r) + instance IRRep r => GenericE (AtomVar r) where type RepE (AtomVar r) = PairE (AtomName r) (Type r) fromE (AtomVar v t) = PairE v t @@ -1538,36 +1575,34 @@ instance IRRep r => AlphaHashableE (AtomVar r) where instance IRRep r => RenameE (AtomVar r) instance IRRep r => GenericE (Type r) where - type RepE (Type r) = EitherE8 - {- TyVar -} (WhenCore r CAtomVar) + type RepE (Type r) = EitherE7 + {- StuckTy -} (WhenCore r CStuck) {- Pi -} (WhenCore r CorePiType) {- TabPi -} (TabPiType r) {- DepPairTy -} (DepPairType r) {- DictTy -} (WhenCore r DictType) {- NewtypeTyCon -} (WhenCore r NewtypeTyCon) {- TC -} (TC r) - {- ProjectEltTy -} (WhenCore r (Type r `PairE` LiftE Projection `PairE` Atom r)) fromE = \case - TyVar v -> Case0 $ WhenIRE v + StuckTy e -> Case0 $ WhenIRE e Pi t -> Case1 $ WhenIRE t TabPi t -> Case2 t DepPairTy t -> Case3 t DictTy d -> Case4 $ WhenIRE d NewtypeTyCon t -> Case5 $ WhenIRE t TC con -> Case6 $ con - ProjectEltTy t idxs x -> Case7 (WhenIRE (t `PairE` LiftE idxs `PairE` x)) {-# INLINE fromE #-} toE = \case - Case0 (WhenIRE v) -> TyVar v + Case0 (WhenIRE e) -> StuckTy e Case1 (WhenIRE t) -> Pi t Case2 t -> TabPi t Case3 t -> DepPairTy t Case4 (WhenIRE d) -> DictTy d Case5 (WhenIRE t) -> NewtypeTyCon t Case6 con -> TC con - Case7 (WhenIRE (t `PairE` LiftE idxs `PairE` x)) -> ProjectEltTy t idxs x + _ -> error "impossible" {-# INLINE toE #-} instance IRRep r => SinkableE (Type r) @@ -1585,20 +1620,23 @@ instance IRRep r => GenericE (Expr r) where {- Atom -} (Atom r) {- TopApp -} (WhenSimp r (EffTy r `PairE` TopFunName `PairE` ListE (Atom r))) ) - ( EitherE3 + ( EitherE5 {- TabCon -} (MaybeE (WhenCore r Dict) `PairE` Type r `PairE` ListE (Atom r)) {- PrimOp -} (PrimOp r) - {- ApplyMethod -} (WhenCore r (EffTy r `PairE` Atom r `PairE` LiftE Int `PairE` ListE (Atom r)))) - + {- ApplyMethod -} (WhenCore r (EffTy r `PairE` Atom r `PairE` LiftE Int `PairE` ListE (Atom r))) + {- Project -} (Type r `PairE` LiftE Int `PairE` Atom r) + {- Unwrap -} (WhenCore r (CType `PairE` CAtom))) fromE = \case App et f xs -> Case0 $ Case0 (WhenIRE (et `PairE` f `PairE` ListE xs)) TabApp t f xs -> Case0 $ Case1 (t `PairE` f `PairE` ListE xs) Case e alts effTy -> Case0 $ Case2 (e `PairE` ListE alts `PairE` effTy) Atom x -> Case0 $ Case3 (x) - TopApp et f xs -> Case0 $ Case4 (WhenIRE (et `PairE` f `PairE` ListE xs)) - TabCon d ty xs -> Case1 $ Case0 (toMaybeE d `PairE` ty `PairE` ListE xs) - PrimOp op -> Case1 $ Case1 op + TopApp et f xs -> Case0 $ Case4 (WhenIRE (et `PairE` f `PairE` ListE xs)) + TabCon d ty xs -> Case1 $ Case0 (toMaybeE d `PairE` ty `PairE` ListE xs) + PrimOp op -> Case1 $ Case1 op ApplyMethod et d i xs -> Case1 $ Case2 (WhenIRE (et `PairE` d `PairE` LiftE i `PairE` ListE xs)) + Project ty i x -> Case1 $ Case3 (ty `PairE` LiftE i `PairE` x) + Unwrap t x -> Case1 $ Case4 (WhenIRE (t `PairE` x)) {-# INLINE fromE #-} toE = \case Case0 case0 -> case case0 of @@ -1612,6 +1650,8 @@ instance IRRep r => GenericE (Expr r) where Case0 (d `PairE` ty `PairE` ListE xs) -> TabCon (fromMaybeE d) ty xs Case1 op -> PrimOp op Case2 (WhenIRE (et `PairE` d `PairE` LiftE i `PairE` ListE xs)) -> ApplyMethod et d i xs + Case3 (ty `PairE` LiftE i `PairE` x) -> Project ty i x + Case4 (WhenIRE (t `PairE` x)) -> Unwrap t x _ -> error "impossible" _ -> error "impossible" {-# INLINE toE #-} @@ -1903,33 +1943,26 @@ instance AlphaEqE DictType instance AlphaHashableE DictType instance RenameE DictType -instance GenericE DictExpr where - type RepE DictExpr = - EitherE5 - {- InstanceDict -} (PairE InstanceName (ListE CAtom)) - {- InstantiatedGiven -} (PairE CAtom (ListE CAtom)) - {- SuperclassProj -} (PairE CAtom (LiftE Int)) - {- IxFin -} CAtom - {- DataData -} CType +instance GenericE DictCon where + type RepE DictCon = EitherE3 + {- InstanceDict -} (CType `PairE` PairE InstanceName (ListE CAtom)) + {- IxFin -} (CType `PairE` CAtom) + {- DataData -} (CType `PairE` CType) fromE d = case d of - InstanceDict v args -> Case0 $ PairE v (ListE args) - InstantiatedGiven given args -> Case1 $ PairE given (ListE args) - SuperclassProj x i -> Case2 (PairE x (LiftE i)) - IxFin x -> Case3 x - DataData ty -> Case4 ty + InstanceDict t v args -> Case0 $ t `PairE` PairE v (ListE args) + IxFin t x -> Case1 $ t `PairE` x + DataData t ty -> Case2 $ t `PairE` ty toE d = case d of - Case0 (PairE v (ListE args)) -> InstanceDict v args - Case1 (PairE given (ListE args)) -> InstantiatedGiven given args - Case2 (PairE x (LiftE i)) -> SuperclassProj x i - Case3 x -> IxFin x - Case4 ty -> DataData ty + Case0 (t `PairE` (PairE v (ListE args))) -> InstanceDict t v args + Case1 (t `PairE` x) -> IxFin t x + Case2 (t `PairE` ty) -> DataData t ty _ -> error "impossible" -instance SinkableE DictExpr -instance HoistableE DictExpr -instance AlphaEqE DictExpr -instance AlphaHashableE DictExpr -instance RenameE DictExpr +instance SinkableE DictCon +instance HoistableE DictCon +instance AlphaEqE DictCon +instance AlphaHashableE DictCon +instance RenameE DictCon instance GenericE Cache where type RepE Cache = @@ -2727,6 +2760,7 @@ instance IRRep r => Store (PrimOp r n) instance IRRep r => Store (RepVal r n) instance IRRep r => Store (Type r n) instance IRRep r => Store (EffTy r n) +instance IRRep r => Store (Stuck r n) instance IRRep r => Store (Atom r n) instance IRRep r => Store (AtomVar r n) instance IRRep r => Store (Expr r n) @@ -2753,7 +2787,7 @@ instance Store (ClassDef n) instance Store (InstanceDef n) instance Store (InstanceBody n) instance Store (DictType n) -instance Store (DictExpr n) +instance Store (DictCon n) instance Store (EffectDef n) instance Store (EffectOpDef n) instance Store (EffectOpType n) diff --git a/src/lib/Types/Primitives.hs b/src/lib/Types/Primitives.hs index 002a6d09a..d85d88247 100644 --- a/src/lib/Types/Primitives.hs +++ b/src/lib/Types/Primitives.hs @@ -60,6 +60,8 @@ data RequiredMethodAccess = Full | Partial Int deriving (Show, Eq, Ord, Generic) data LetAnn = -- Binding with no additional information PlainLet + -- Binding explicitly tagged "inline immediately" + | InlineLet -- Binding explicitly tagged "do not inline" | NoInlineLet -- Bound expression is pure, and the binding's occurrences are summarized by diff --git a/src/lib/Vectorize.hs b/src/lib/Vectorize.hs index bf585a089..9f4c27e54 100644 --- a/src/lib/Vectorize.hs +++ b/src/lib/Vectorize.hs @@ -227,7 +227,7 @@ simplifyIxSize :: (EnvReader m, ScopableBuilder SimpIR m) => IxType SimpIR n -> m n (Maybe Word32) simplifyIxSize ixty = do sizeMethod <- buildBlock $ applyIxMethod (sink $ ixTypeDict ixty) Size [] - cheapReduce sizeMethod >>= \case + reduceBlock sizeMethod >>= \case Just (IdxRepVal n) -> return $ Just n _ -> return Nothing {-# INLINE simplifyIxSize #-} @@ -534,18 +534,17 @@ vectorizeType t = do vectorizeAtom :: SAtom i -> VectorizeM i o (VAtom o) vectorizeAtom atom = addVectErrCtx "vectorizeAtom" ("Atom:\n" ++ pprint atom) do case atom of - Var v -> lookupSubstM (atomVarName v) >>= \case - VRename v' -> VVal Uniform . Var <$> toAtomVar v' - v' -> return v' - -- Vectors of base newtypes are already newtype-stripped. - ProjectElt _ (ProjectProduct i) x -> do - VVal vv x' <- vectorizeAtom x - ov <- case vv of - ProdStability sbs -> return $ sbs !! i - _ -> throwVectErr "Invalid projection" - x'' <- normalizeProj (ProjectProduct i) x' - return $ VVal ov x'' - ProjectElt _ UnwrapNewtype _ -> error "Shouldn't have newtypes left" -- TODO: check statically + Stuck e -> case e of + StuckVar v -> lookupSubstM (atomVarName v) >>= \case + VRename v' -> VVal Uniform . Var <$> toAtomVar v' + v' -> return v' + StuckProject _ i x -> do + VVal vv x' <- vectorizeAtom (Stuck x) + ov <- case vv of + ProdStability sbs -> return $ sbs !! i + _ -> throwVectErr "Invalid projection" + x'' <- reduceProj i x' + return $ VVal ov x'' Con (Lit l) -> return $ VVal Uniform $ Con $ Lit l _ -> do subst <- getSubst From 0ff323354d426aead464b8a597d98a283da370be Mon Sep 17 00:00:00 2001 From: Dougal Date: Thu, 19 Oct 2023 17:44:15 -0400 Subject: [PATCH 04/41] Make `Block` just a case of `Expr` instead of a separate type. I also removed ProjDest from Lower. It doesn't make sense now that projections aren't in atoms except to handle dependent types. We need to re-do the destination passing pass more carefully. --- src/lib/Algebra.hs | 95 +++++++-------- src/lib/Builder.hs | 140 +++++++++++---------- src/lib/CheapReduction.hs | 17 +-- src/lib/CheckType.hs | 44 ++++--- src/lib/Core.hs | 10 +- src/lib/Imp.hs | 113 +++++++---------- src/lib/Inference.hs | 78 ++++++------ src/lib/Inline.hs | 202 +++++++++++++++---------------- src/lib/Linearize.hs | 54 ++++----- src/lib/Lower.hs | 249 ++++++++++++++------------------------ src/lib/OccAnalysis.hs | 30 ++--- src/lib/Optimize.hs | 93 +++++++------- src/lib/PPrint.hs | 6 +- src/lib/QueryType.hs | 35 ++---- src/lib/QueryTypePure.hs | 2 + src/lib/RuntimePrint.hs | 4 +- src/lib/Simplify.hs | 130 +++++++++----------- src/lib/TopLevel.hs | 2 +- src/lib/Transpose.hs | 31 +++-- src/lib/Types/Core.hs | 114 +++++++++++------ src/lib/Vectorize.hs | 101 +++++++--------- 21 files changed, 714 insertions(+), 836 deletions(-) diff --git a/src/lib/Algebra.hs b/src/lib/Algebra.hs index a0b022fdf..a8e125ecd 100644 --- a/src/lib/Algebra.hs +++ b/src/lib/Algebra.hs @@ -49,10 +49,10 @@ newtype Polynomial (n::S) = -- us compute sums in closed form. This tries to compute -- `\sum_{i=0}^(lim-1) body`. `i`, `lim`, and `body` should all have type `Nat`. sumUsingPolys :: Emits n - => Atom SimpIR n -> Abs (Binder SimpIR) (Block SimpIR) n -> BuilderM SimpIR n (Atom SimpIR n) + => Atom SimpIR n -> Abs (Binder SimpIR) (Expr SimpIR) n -> BuilderM SimpIR n (Atom SimpIR n) sumUsingPolys lim (Abs i body) = do sumAbs <- refreshAbs (Abs i body) \(i':>_) body' -> do - blockAsPoly body' >>= \case + exprAsPoly body' >>= \case Just poly' -> return $ Abs i' poly' Nothing -> throw NotImplementedErr $ "Algebraic simplification failed to model index computations:\n" @@ -134,56 +134,53 @@ instance FromName PolySubstVal where fromName = PolyRename type BlockTraverserM i o a = SubstReaderT PolySubstVal (MaybeT1 (BuilderM SimpIR)) i o a -blockAsPoly - :: (EnvExtender m, EnvReader m) - => Block SimpIR n -> m n (Maybe (Polynomial n)) -blockAsPoly (Abs decls result) = - liftBuilder $ runMaybeT1 $ runSubstReaderT idSubst $ blockAsPolyRec decls result - -blockAsPolyRec :: Nest (Decl SimpIR) i i' -> Atom SimpIR i' -> BlockTraverserM i o (Polynomial o) -blockAsPolyRec decls result = case decls of - Empty -> atomAsPoly result - Nest (Let b (DeclBinding _ expr)) restDecls -> do - p <- optional (exprAsPoly expr) - extendSubst (b@>PolySubstVal p) $ blockAsPolyRec restDecls result - - where - atomAsPoly :: Atom SimpIR i -> BlockTraverserM i o (Polynomial o) - atomAsPoly = \case - Var v -> atomVarAsPoly v - RepValAtom (RepVal _ (Leaf (IVar v' _))) -> impNameAsPoly v' - IdxRepVal i -> return $ poly [((fromIntegral i) % 1, mono [])] +exprAsPoly :: (EnvExtender m, EnvReader m) => SExpr n -> m n (Maybe (Polynomial n)) +exprAsPoly expr = liftBuilder $ runMaybeT1 $ runSubstReaderT idSubst $ exprAsPolyRec expr + +atomAsPoly :: Atom SimpIR i -> BlockTraverserM i o (Polynomial o) +atomAsPoly = \case + Var v -> atomVarAsPoly v + RepValAtom (RepVal _ (Leaf (IVar v' _))) -> impNameAsPoly v' + IdxRepVal i -> return $ poly [((fromIntegral i) % 1, mono [])] + _ -> empty + +impNameAsPoly :: ImpName i -> BlockTraverserM i o (Polynomial o) +impNameAsPoly v = getSubst <&> (!v) >>= \case + PolyRename v' -> return $ poly [(1, mono [(RightE v', 1)])] + +atomVarAsPoly :: AtomVar SimpIR i -> BlockTraverserM i o (Polynomial o) +atomVarAsPoly v = getSubst <&> (! atomVarName v) >>= \case + PolySubstVal Nothing -> empty + PolySubstVal (Just cp) -> return cp + PolyRename v' -> do + v'' <- toAtomVar v' + case getType v'' of + IdxRepTy -> return $ poly [(1, mono [(LeftE v', 1)])] _ -> empty - impNameAsPoly :: ImpName i -> BlockTraverserM i o (Polynomial o) - impNameAsPoly v = getSubst <&> (!v) >>= \case - PolyRename v' -> return $ poly [(1, mono [(RightE v', 1)])] - - atomVarAsPoly :: AtomVar SimpIR i -> BlockTraverserM i o (Polynomial o) - atomVarAsPoly v = getSubst <&> (! atomVarName v) >>= \case - PolySubstVal Nothing -> empty - PolySubstVal (Just cp) -> return cp - PolyRename v' -> do - v'' <- toAtomVar v' - case getType v'' of - IdxRepTy -> return $ poly [(1, mono [(LeftE v', 1)])] - _ -> empty - - exprAsPoly :: Expr SimpIR i -> BlockTraverserM i o (Polynomial o) - exprAsPoly e = case e of - Atom a -> atomAsPoly a - PrimOp (BinOp op x y) -> case op of - IAdd -> add <$> atomAsPoly x <*> atomAsPoly y - IMul -> mul <$> atomAsPoly x <*> atomAsPoly y - -- XXX: we rely on the wrapping behavior of subtraction on unsigned ints - -- so that the distributive law holds, `a * (b - c) == (a * b) - (a * c)` - ISub -> sub <$> atomAsPoly x <*> atomAsPoly y - -- This is to handle `idiv` generated by `emitPolynomial` - IDiv -> case y of - IdxRepVal n -> mulConst (1 / fromIntegral n) <$> atomAsPoly x - _ -> empty - _ -> empty +exprAsPolyRec :: Expr SimpIR i -> BlockTraverserM i o (Polynomial o) +exprAsPolyRec e = case e of + Block _ block -> blockAsPoly block + Atom a -> atomAsPoly a + PrimOp (BinOp op x y) -> case op of + IAdd -> add <$> atomAsPoly x <*> atomAsPoly y + IMul -> mul <$> atomAsPoly x <*> atomAsPoly y + -- XXX: we rely on the wrapping behavior of subtraction on unsigned ints + -- so that the distributive law holds, `a * (b - c) == (a * b) - (a * c)` + ISub -> sub <$> atomAsPoly x <*> atomAsPoly y + -- This is to handle `idiv` generated by `emitPolynomial` + IDiv -> case y of + IdxRepVal n -> mulConst (1 / fromIntegral n) <$> atomAsPoly x _ -> empty + _ -> empty + _ -> empty + +blockAsPoly :: SBlock i -> BlockTraverserM i o (Polynomial o) +blockAsPoly (Abs decls result) = case decls of + Empty -> exprAsPolyRec result + Nest (Let b (DeclBinding _ expr)) restDecls -> do + p <- optional (exprAsPolyRec expr) + extendSubst (b@>PolySubstVal p) $ blockAsPoly $ Abs restDecls result -- === polynomials to Core expressions === diff --git a/src/lib/Builder.hs b/src/lib/Builder.hs index ee41f1b4a..52425566d 100644 --- a/src/lib/Builder.hs +++ b/src/lib/Builder.hs @@ -78,16 +78,21 @@ emitHinted :: (Builder r m, Emits n) => NameHint -> Expr r n -> m n (AtomVar r n emitHinted hint expr = emitDecl hint PlainLet expr {-# INLINE emitHinted #-} -emitOp :: (Builder r m, IsPrimOp e, Emits n) => e r n -> m n (Atom r n) -emitOp op = Var <$> emit (PrimOp $ toPrimOp op) -{-# INLINE emitOp #-} - -emitExpr :: (Builder r m, Emits n) => Expr r n -> m n (Atom r n) -emitExpr expr = Var <$> emit expr +emitExpr :: (Builder r m, ToExpr e r, Emits n) => e n -> m n (Atom r n) +emitExpr e = case toExpr e of + Atom x -> return x + Block _ block -> emitDecls block >>= emitExpr + expr -> Var <$> emit expr {-# INLINE emitExpr #-} +emitToVar :: (Builder r m, ToExpr e r, Emits n) => e n -> m n (AtomVar r n) +emitToVar e = case toExpr e of + Atom (Var v) -> return v + expr -> emit expr +{-# INLINE emitToVar #-} + emitHof :: (Builder r m, Emits n) => Hof r n -> m n (Atom r n) -emitHof hof = mkTypedHof hof >>= emitOp +emitHof hof = mkTypedHof hof >>= emitExpr mkTypedHof :: (EnvReader m, IRRep r) => Hof r n -> m n (TypedHof r n) mkTypedHof hof = do @@ -95,12 +100,9 @@ mkTypedHof hof = do return $ TypedHof effTy hof emitUnOp :: (Builder r m, Emits n) => UnOp -> Atom r n -> m n (Atom r n) -emitUnOp op x = emitOp $ UnOp op x +emitUnOp op x = emitExpr $ UnOp op x {-# INLINE emitUnOp #-} -emitBlock :: (Builder r m, Emits n) => Block r n -> m n (Atom r n) -emitBlock = emitDecls - emitDecls :: (Builder r m, Emits n, RenameE e, SinkableE e) => WithDecls r e n -> m n (e n) emitDecls (Abs decls result) = runSubstReaderT idSubst $ emitDecls' decls result @@ -628,10 +630,11 @@ newtype WrapWithEmits n r = -- === lambda-like things === buildBlock - :: ScopableBuilder r m - => (forall l. (Emits l, DExt n l) => m l (Atom r l)) - -> m n (Block r n) -buildBlock = buildScoped + :: (ScopableBuilder r m, HasNamesE e, ToExpr e r) + => (forall l. (Emits l, DExt n l) => m l (e l)) + -> m n (Expr r n) +buildBlock cont = mkBlock =<< buildScoped cont +{-# INLINE buildBlock #-} buildCoreLam :: ScopableBuilder CoreIR m @@ -731,7 +734,7 @@ buildCaseAlts scrut indexedAltBody = do injectAltResult :: EnvReader m => [SType n] -> Int -> Alt SimpIR n -> m n (Alt SimpIR n) injectAltResult sumTys con (Abs b body) = liftBuilder do buildAlt (binderType b) \v -> do - originalResult <- emitBlock =<< applySubst (b@>SubstVal (Var v)) body + originalResult <- emitExpr =<< applySubst (b@>SubstVal (Var v)) body (dataResult, nonDataResult) <- fromPairReduced originalResult return $ PairVal dataResult $ Con $ SumCon (sinkList sumTys) con nonDataResult @@ -752,8 +755,7 @@ buildCase' scrut resultTy indexedAltBody = do (alts, effs) <- unzip <$> forM (enumerate altBinderTys) \(i, bTy) -> do (Abs b' (body `PairE` eff')) <- buildAbs noHint bTy \x -> do blk <- buildBlock $ indexedAltBody i $ Var $ sink x - EffTy eff _ <- blockEffTy blk - return $ blk `PairE` eff + return $ blk `PairE` getEffects blk return (Abs b' body, ignoreHoistFailure $ hoist b' eff') return $ Case scrut alts $ EffTy (mconcat effs) resultTy @@ -920,7 +922,7 @@ addTangent x y = do liftEmitBuilder $ buildFor (getNameHint t) Fwd (tabIxType t) \i -> do bindM2 addTangent (tabApp (sink x) (Var i)) (tabApp (sink y) (Var i)) TC con -> case con of - BaseType (Scalar _) -> emitOp $ BinOp FAdd x y + BaseType (Scalar _) -> emitExpr $ BinOp FAdd x y ProdType _ -> do xs <- getUnpacked x ys <- getUnpacked y @@ -958,63 +960,63 @@ fLitLike x t = do _ -> error "Expected a floating point scalar" neg :: (Builder r m, Emits n) => Atom r n -> m n (Atom r n) -neg x = emitOp $ UnOp FNeg x +neg x = emitExpr $ UnOp FNeg x add :: (Builder r m, Emits n) => Atom r n -> Atom r n -> m n (Atom r n) -add x y = emitOp $ BinOp FAdd x y +add x y = emitExpr $ BinOp FAdd x y -- TODO: Implement constant folding for fixed-width integer types as well! iadd :: (Builder r m, Emits n) => Atom r n -> Atom r n -> m n (Atom r n) iadd (Con (Lit l)) y | getIntLit l == 0 = return y iadd x (Con (Lit l)) | getIntLit l == 0 = return x iadd x@(Con (Lit _)) y@(Con (Lit _)) = return $ applyIntBinOp (+) x y -iadd x y = emitOp $ BinOp IAdd x y +iadd x y = emitExpr $ BinOp IAdd x y mul :: (Builder r m, Emits n) => Atom r n -> Atom r n -> m n (Atom r n) -mul x y = emitOp $ BinOp FMul x y +mul x y = emitExpr $ BinOp FMul x y imul :: (Builder r m, Emits n) => Atom r n -> Atom r n -> m n (Atom r n) imul (Con (Lit l)) y | getIntLit l == 1 = return y imul x (Con (Lit l)) | getIntLit l == 1 = return x imul x@(Con (Lit _)) y@(Con (Lit _)) = return $ applyIntBinOp (*) x y -imul x y = emitOp $ BinOp IMul x y +imul x y = emitExpr $ BinOp IMul x y sub :: (Builder r m, Emits n) => Atom r n -> Atom r n -> m n (Atom r n) -sub x y = emitOp $ BinOp FSub x y +sub x y = emitExpr $ BinOp FSub x y isub :: (Builder r m, Emits n) => Atom r n -> Atom r n -> m n (Atom r n) isub x (Con (Lit l)) | getIntLit l == 0 = return x isub x@(Con (Lit _)) y@(Con (Lit _)) = return $ applyIntBinOp (-) x y -isub x y = emitOp $ BinOp ISub x y +isub x y = emitExpr $ BinOp ISub x y select :: (Builder r m, Emits n) => Atom r n -> Atom r n -> Atom r n -> m n (Atom r n) select (Con (Lit (Word8Lit p))) x y = return $ if p /= 0 then x else y -select p x y = emitOp $ MiscOp $ Select p x y +select p x y = emitExpr $ MiscOp $ Select p x y div' :: (Builder r m, Emits n) => Atom r n -> Atom r n -> m n (Atom r n) -div' x y = emitOp $ BinOp FDiv x y +div' x y = emitExpr $ BinOp FDiv x y idiv :: (Builder r m, Emits n) => Atom r n -> Atom r n -> m n (Atom r n) idiv x (Con (Lit l)) | getIntLit l == 1 = return x idiv x@(Con (Lit _)) y@(Con (Lit _)) = return $ applyIntBinOp div x y -idiv x y = emitOp $ BinOp IDiv x y +idiv x y = emitExpr $ BinOp IDiv x y irem :: (Builder r m, Emits n) => Atom r n -> Atom r n -> m n (Atom r n) -irem x y = emitOp $ BinOp IRem x y +irem x y = emitExpr $ BinOp IRem x y fpow :: (Builder r m, Emits n) => Atom r n -> Atom r n -> m n (Atom r n) -fpow x y = emitOp $ BinOp FPow x y +fpow x y = emitExpr $ BinOp FPow x y flog :: (Builder r m, Emits n) => Atom r n -> m n (Atom r n) -flog x = emitOp $ UnOp Log x +flog x = emitExpr $ UnOp Log x ilt :: (Builder r m, Emits n) => Atom r n -> Atom r n -> m n (Atom r n) ilt x@(Con (Lit _)) y@(Con (Lit _)) = return $ applyIntCmpOp (<) x y -ilt x y = emitOp $ BinOp (ICmp Less) x y +ilt x y = emitExpr $ BinOp (ICmp Less) x y ieq :: (Builder r m, Emits n) => Atom r n -> Atom r n -> m n (Atom r n) ieq x@(Con (Lit _)) y@(Con (Lit _)) = return $ applyIntCmpOp (==) x y -ieq x y = emitOp $ BinOp (ICmp Equal) x y +ieq x y = emitExpr $ BinOp (ICmp Equal) x y fromPair :: (Builder r m, Emits n) => Atom r n -> m n (Atom r n, Atom r n) fromPair pair = do @@ -1028,7 +1030,7 @@ applyProjectionsRef [] ref = return ref applyProjectionsRef (i:is) ref = getProjRef i =<< applyProjectionsRef is ref getProjRef :: (Builder r m, Emits n) => Projection -> Atom r n -> m n (Atom r n) -getProjRef i r = emitOp =<< mkProjRef r i +getProjRef i r = emitExpr =<< mkProjRef r i -- XXX: getUnpacked must reduce its argument to enforce the invariant that -- ProjectElt atoms are always fully reduced (to avoid type errors between two @@ -1109,6 +1111,23 @@ applyProjectionsReduced (p:ps) x = do ProjectProduct i -> reduceProj i x' UnwrapNewtype -> reduceUnwrap x' +mkBlock :: (EnvReader m, IRRep r) => ToExpr e r => Abs (Decls r) e n -> m n (Expr r n) +mkBlock (Abs decls body) = do + let block = Abs decls (toExpr body) + effTy <- blockEffTy block + return $ Block effTy block + +blockEffTy :: (EnvReader m, IRRep r) => Block r n -> m n (EffTy r n) +blockEffTy block = liftEnvReaderM $ refreshAbs block \decls result -> do + effs <- declsEffects decls mempty + return $ ignoreHoistFailure $ hoist decls $ EffTy effs $ getType result + where + declsEffects :: IRRep r => Nest (Decl r) n l -> EffectRow r l -> EnvReaderM l (EffectRow r l) + declsEffects Empty !acc = return acc + declsEffects n@(Nest (Let _ (DeclBinding _ expr)) rest) !acc = withExtEvidence n do + expr' <- sinkM expr + declsEffects rest $ acc <> getEffects expr' + mkApp :: EnvReader m => CAtom n -> [CAtom n] -> m n (CExpr n) mkApp f xs = do et <- appEffTy (getType f) xs @@ -1150,14 +1169,12 @@ mkInstanceDict instanceName args = do mkCase :: (EnvReader m, IRRep r) => Atom r n -> Type r n -> [Alt r n] -> m n (Expr r n) mkCase scrut resultTy alts = liftEnvReaderM do eff' <- fold <$> forM alts \alt -> refreshAbs alt \b body -> do - EffTy eff _ <- blockEffTy body - return $ ignoreHoistFailure $ hoist b eff + return $ ignoreHoistFailure $ hoist b $ getEffects body return $ Case scrut alts (EffTy eff' resultTy) -mkCatchException :: EnvReader m => CBlock n -> m n (Hof CoreIR n) +mkCatchException :: EnvReader m => CExpr n -> m n (Hof CoreIR n) mkCatchException body = do - EffTy _ bodyTy <- blockEffTy body - resultTy <- makePreludeMaybeTy bodyTy + resultTy <- makePreludeMaybeTy (getType body) return $ CatchException resultTy body app :: (CBuilder m, Emits n) => CAtom n -> CAtom n -> m n (CAtom n) @@ -1175,7 +1192,7 @@ naryTopAppInlined :: (Builder SimpIR m, Emits n) => TopFunName n -> [SAtom n] -> naryTopAppInlined f xs = do TopFunBinding f' <- lookupEnv f case f' of - DexTopFun _ lam _ -> instantiate lam xs >>= emitBlock + DexTopFun _ lam _ -> instantiate lam xs >>= emitExpr _ -> naryTopApp f xs {-# INLINE naryTopAppInlined #-} @@ -1197,19 +1214,19 @@ naryTabAppHinted hint f xs = do Var <$> emitHinted hint expr indexRef :: (Builder r m, Emits n) => Atom r n -> Atom r n -> m n (Atom r n) -indexRef ref i = emitOp =<< mkIndexRef ref i +indexRef ref i = emitExpr =<< mkIndexRef ref i naryIndexRef :: (Builder r m, Emits n) => Atom r n -> [Atom r n] -> m n (Atom r n) naryIndexRef ref is = foldM indexRef ref is ptrOffset :: (Builder r m, Emits n) => Atom r n -> Atom r n -> m n (Atom r n) ptrOffset x (IdxRepVal 0) = return x -ptrOffset x i = emitOp $ MemOp $ PtrOffset x i +ptrOffset x i = emitExpr $ MemOp $ PtrOffset x i {-# INLINE ptrOffset #-} unsafePtrLoad :: (Builder r m, Emits n) => Atom r n -> m n (Atom r n) unsafePtrLoad x = do - body <- liftEmitBuilder $ buildBlock $ emitOp . MemOp . PtrLoad =<< sinkM x + body <- liftEmitBuilder $ buildBlock $ emitExpr . MemOp . PtrLoad =<< sinkM x emitHof $ RunIO body mkIndexRef :: (EnvReader m, Fallible1 m, IRRep r) => Atom r n -> Atom r n -> m n (PrimOp r n) @@ -1234,7 +1251,7 @@ applyIxMethod dict method args = case dict of IxDictSpecialized _ d params -> do SpecializedDict _ maybeFs <- lookupSpecDict d Just fs <- return maybeFs - instantiate (fs !! fromEnum method) (params ++ args) >>= emitBlock + instantiate (fs !! fromEnum method) (params ++ args) >>= emitExpr unsafeFromOrdinal :: (SBuilder m, Emits n) => IxType SimpIR n -> Atom SimpIR n -> m n (Atom SimpIR n) unsafeFromOrdinal (IxType _ dict) i = applyIxMethod dict UnsafeFromOrdinal [i] @@ -1266,7 +1283,7 @@ emitIf :: (Emits n, ScopableBuilder r m) -> (forall l. (Emits l, DExt n l) => m l (Atom r l)) -> m n (Atom r n) emitIf predicate resultTy trueCase falseCase = do - predicate' <- emitOp $ MiscOp $ ToEnum (SumTy [UnitTy, UnitTy]) predicate + predicate' <- emitExpr $ MiscOp $ ToEnum (SumTy [UnitTy, UnitTy]) predicate buildCase predicate' resultTy \i _ -> case i of 0 -> falseCase @@ -1290,7 +1307,7 @@ fromJustE :: (Emits n, Builder r m) => Atom r n -> m n (Atom r n) fromJustE x = liftEmitBuilder do MaybeTy a <- return $ getType x emitMaybeCase x a - (emitOp $ MiscOp $ ThrowError $ sink a) + (emitExpr $ MiscOp $ ThrowError $ sink a) (return) -- Maybe a -> Bool @@ -1311,7 +1328,7 @@ reduceE monoid xs = liftEmitBuilder do andMonoid :: (EnvReader m, IRRep r) => m n (BaseMonoid r n) andMonoid = liftM (BaseMonoid TrueAtom) $ liftBuilder $ buildBinaryLamExpr (noHint, BoolTy) (noHint, BoolTy) \x y -> - emitOp $ BinOp BAnd (sink $ Var x) (Var y) + emitExpr $ BinOp BAnd (sink $ Var x) (Var y) -- (a-> {|eff} b) -> n=>a -> {|eff} (n=>b) mapE :: (Emits n, ScopableBuilder r m) @@ -1552,24 +1569,8 @@ type ExprVisitorNoEmits2 m r = forall i o. ExprVisitorNoEmits (m i o) r i o visitLamNoEmits :: (ExprVisitorNoEmits2 m r, IRRep r, AtomSubstReader v m, EnvExtender2 m) => LamExpr r i -> m i o (LamExpr r o) -visitLamNoEmits (LamExpr bs (Abs decls result)) = - visitBinders bs \bs' -> LamExpr bs' <$> - visitDeclsNoEmits decls \decls' -> Abs decls' <$> do - visitAtom result - -visitDeclsNoEmits - :: (ExprVisitorNoEmits2 m r, IRRep r, AtomSubstReader v m, EnvExtender2 m) - => Nest (Decl r) i i' - -> (forall o'. DExt o o' => Nest (Decl r) o o' -> m i' o' a) - -> m i o a -visitDeclsNoEmits Empty cont = getDistinct >>= \Distinct -> cont Empty -visitDeclsNoEmits (Nest (Let b (DeclBinding ann expr)) decls) cont = do - expr' <- visitExprNoEmits expr - withFreshBinder (getNameHint b) (getType expr') \(b':>_) -> do - let decl' = Let b' $ DeclBinding ann expr' - extendRenamer (b@>binderName b') do - visitDeclsNoEmits decls \decls' -> - cont $ Nest decl' decls' +visitLamNoEmits (LamExpr bs body) = + visitBinders bs \bs' -> LamExpr bs' <$> visitExprNoEmits body -- === Emitting expression visitor === @@ -1593,12 +1594,7 @@ visitLamEmits :: (ExprVisitorEmits2 m r, IRRep r, SubstReader AtomSubstVal m, ScopableBuilder2 r m) => LamExpr r i -> m i o (LamExpr r o) visitLamEmits (LamExpr bs body) = visitBinders bs \bs' -> LamExpr bs' <$> - buildBlock (visitBlockEmits body) - -visitBlockEmits - :: (ExprVisitorEmits2 m r, SubstReader AtomSubstVal m, EnvExtender2 m, IRRep r, Emits o) - => Block r i -> m i o (Atom r o) -visitBlockEmits (Abs decls result) = visitDeclsEmits decls $ visitAtom result + buildBlock (visitExprEmits body) visitDeclsEmits :: (ExprVisitorEmits2 m r, SubstReader AtomSubstVal m, EnvExtender2 m, IRRep r, Emits o) diff --git a/src/lib/CheapReduction.hs b/src/lib/CheapReduction.hs index 0f850166f..c7fd96589 100644 --- a/src/lib/CheapReduction.hs +++ b/src/lib/CheapReduction.hs @@ -8,7 +8,7 @@ {-# OPTIONS_GHC -Wno-orphans #-} module CheapReduction - ( reduceWithDecls, reduceExpr, reduceBlock + ( reduceWithDecls, reduceExpr , instantiateTyConDef, dataDefRep, unwrapNewtypeType , NonAtomRenamer (..), Visitor (..), VisitGeneric (..) , visitAtomPartial, visitTypePartial, visitAtomDefault, visitTypeDefault, Visitor2 @@ -56,10 +56,6 @@ reduceExpr :: (IRRep r, EnvReader m) => Expr r n -> m n (Maybe (Atom r n)) reduceExpr e = liftReducerM $ reduceExprM e {-# INLINE reduceExpr #-} -reduceBlock :: (IRRep r, EnvReader m) => Block r n -> m n (Maybe (Atom r n)) -reduceBlock e = liftReducerM $ reduceBlockM e -{-# INLINE reduceBlock #-} - reduceProj :: (IRRep r, EnvReader m) => Int -> Atom r n -> m n (Atom r n) reduceProj i x = liftM fromJust $ liftReducerM $ reduceProjM i x {-# INLINE reduceProj #-} @@ -91,12 +87,10 @@ reduceWithDeclsM (Nest (Let b (DeclBinding _ expr)) rest) cont = do x <- reduceExprM expr extendSubst (b@>SubstVal x) $ reduceWithDeclsM rest cont -reduceBlockM :: IRRep r => Block r i -> ReducerM i o (Atom r o) -reduceBlockM (Abs decls x) = reduceWithDeclsM decls $ substM x - reduceExprM :: IRRep r => Expr r i -> ReducerM i o (Atom r o) reduceExprM = \case Atom x -> substM x + Block _ (Abs decls result) -> reduceWithDeclsM decls $ reduceExprM result App _ f xs -> mapM substM xs >>= reduceApp f Unwrap _ x -> substM x >>= reduceUnwrapM Project _ i x -> substM x >>= reduceProjM i @@ -125,7 +119,7 @@ reduceApp :: CAtom i -> [CAtom o] -> ReducerM i o (CAtom o) reduceApp f xs = do f' <- substM f -- TODO: avoid double-subst case f' of - Lam lam -> dropSubst $ withInstantiated lam xs \body -> reduceBlockM body + Lam lam -> dropSubst $ withInstantiated lam xs \body -> reduceExprM body -- TODO: check ultrapure Var v -> lookupAtomName (atomVarName v) >>= \case LetBound (DeclBinding _ (Atom f'')) -> dropSubst $ reduceApp f'' xs @@ -166,7 +160,7 @@ reduceSuperclassProjM superclassIx dict = case dict of reduceInstantiateGivenM :: CAtom o -> [CAtom o] -> ReducerM i o (CAtom o) reduceInstantiateGivenM f xs = case f of - Lam lam -> dropSubst $ withInstantiated lam xs \body -> reduceBlockM body + Lam lam -> dropSubst $ withInstantiated lam xs \body -> reduceExprM body Stuck f' -> do resultTy <- typeOfApp (getType f) xs return $ Stuck $ InstantiatedGiven resultTy f' xs @@ -308,7 +302,7 @@ instance VisitGeneric (Type r) r where visitGeneric = visitType instance VisitGeneric (LamExpr r) r where visitGeneric = visitLam instance VisitGeneric (PiType r) r where visitGeneric = visitPi -visitBlock :: Visitor m r i o => Block r i -> m (Block r o) +visitBlock :: Visitor m r i o => Expr r i -> m (Expr r o) visitBlock b = visitGeneric (LamExpr Empty b) >>= \case LamExpr Empty b' -> return b' _ -> error "not a block" @@ -395,6 +389,7 @@ visitTypePartial = \case instance IRRep r => VisitGeneric (Expr r) r where visitGeneric = \case + Block _ _ -> error "not handled generically" TopApp et v xs -> TopApp <$> visitGeneric et <*> renameN v <*> mapM visitGeneric xs TabApp t tab xs -> TabApp <$> visitType t <*> visitGeneric tab <*> mapM visitGeneric xs -- TODO: should we reuse the original effects? Whether it's valid depends on diff --git a/src/lib/CheckType.hs b/src/lib/CheckType.hs index 24d96fb53..6f3b2093a 100644 --- a/src/lib/CheckType.hs +++ b/src/lib/CheckType.hs @@ -6,7 +6,7 @@ {-# LANGUAGE UndecidableInstances #-} -module CheckType (CheckableE (..), CheckableB (..), checkBlock, checkTypes, checkTypeIs) where +module CheckType (CheckableE (..), CheckableB (..), checkTypes, checkTypeIs) where import Prelude hiding (id) import Control.Category ((>>>)) @@ -127,6 +127,12 @@ checkAndGetType x = do x' <- checkE x return (x', getType x') +checkWithEffTy :: (CheckableWithEffects r e, HasType r e, IRRep r) => EffTy r o -> e i -> TyperM r i o (e o) +checkWithEffTy (EffTy effs ty) e = do + e' <- checkWithEffects effs e + checkTypesEq ty (getType e') + return e' + instance CheckableE CoreIR SourceMap where checkE sm = renameM sm -- TODO? @@ -244,6 +250,12 @@ instance IRRep r => CheckableWithEffects r (Expr r) where return $ TopApp effTy' f' xs' Atom x -> Atom <$> checkE x PrimOp op -> PrimOp <$> checkWithEffects allowedEffs op + Block effTy (Abs decls body) -> do + effTy'@(EffTy effs ty) <- checkEffTy allowedEffs effTy + checkDecls effs decls \decls' -> do + body' <- checkWithEffects (sink effs) body + checkTypesEq (sink ty) (getType body') + return $ Block effTy' $ Abs decls' body' Case scrut alts effTy -> do effTy' <- checkEffTy allowedEffs effTy scrut' <- checkE scrut @@ -252,7 +264,7 @@ instance IRRep r => CheckableWithEffects r (Expr r) where alts' <- parallelAffines $ (zip alts altsBinderTys) <&> \(Abs b body, reqBinderTy) -> do checkB b \b' -> do checkTypesEq (sink reqBinderTy) (sink $ binderType b') - Abs b' <$> checkBlock (sink effTy') body + Abs b' <$> checkWithEffTy (sink effTy') body return $ Case scrut' alts' effTy' ApplyMethod effTy dict i args -> do effTy' <- checkEffTy allowedEffs effTy @@ -602,12 +614,6 @@ instance IRRep r => CheckableE r (VectorOp r) where unless (sbt == sbt') $ throw TypeErr "Scalar type mismatch" return $ VectorSubref ref' i' ty' -checkBlock :: IRRep r => EffTy r o -> Block r i -> TyperM r i o (Block r o) -checkBlock (EffTy effs ty) (Abs decls result) = - checkDecls effs decls \decls' -> do - result' <- result |: sink ty - return $ Abs decls' result' - checkHof :: IRRep r => EffTy r o -> Hof r i -> TyperM r i o (Hof r o) checkHof (EffTy effs reqTy) = \case For dir ixTy f -> do @@ -616,25 +622,25 @@ checkHof (EffTy effs reqTy) = \case TabPi tabTy <- return reqTy checkBinderType t b \b' -> do resultTy <- checkInstantiation (sink tabTy) [Var $ binderVar b'] - body' <- checkBlock (EffTy (sink effs) resultTy) body + body' <- checkWithEffTy (EffTy (sink effs) resultTy) body return $ For dir (IxType t d) (LamExpr (UnaryNest b') body') While body -> do let effTy = EffTy effs (BaseTy $ Scalar Word8Type) checkTypesEq reqTy UnitTy - While <$> checkBlock effTy body + While <$> checkWithEffTy effTy body Linearize f x -> do (x', xTy) <- checkAndGetType x LamExpr (UnaryNest b) body <- return f checkBinderType xTy b \b' -> do PairTy resultTy fLinTy <- sinkM reqTy - body' <- checkBlock (EffTy Pure resultTy) body + body' <- checkWithEffTy (EffTy Pure resultTy) body checkTypesEq fLinTy (Pi $ nonDepPiType [sink xTy] Pure resultTy) return $ Linearize (LamExpr (UnaryNest b') body') x' Transpose f x -> do (x', xTy) <- checkAndGetType x LamExpr (UnaryNest b) body <- return f checkB b \b' -> do - body' <- checkBlock (EffTy Pure (sink xTy)) body + body' <- checkWithEffTy (EffTy Pure (sink xTy)) body checkTypesEq (sink $ binderType b') (sink reqTy) return $ Transpose (LamExpr (UnaryNest b') body') x' RunReader r f -> do @@ -669,13 +675,13 @@ checkHof (EffTy effs reqTy) = \case declareEff effs InitEffect Just <$> dest |: RawRefTy sTy return $ RunState d' s' f' - RunIO body -> RunIO <$> checkBlock (EffTy (extendEffect IOEffect effs) reqTy) body - RunInit body -> RunInit <$> checkBlock (EffTy (extendEffect InitEffect effs) reqTy) body + RunIO body -> RunIO <$> checkWithEffTy (EffTy (extendEffect IOEffect effs) reqTy) body + RunInit body -> RunInit <$> checkWithEffTy (EffTy (extendEffect InitEffect effs) reqTy) body CatchException reqTy' body -> do reqTy'' <- checkE reqTy' checkTypesEq reqTy reqTy'' TypeCon _ _ (TyConParams _[Type ty]) <- return reqTy'' -- TODO: take more care in unpacking Maybe - body' <- checkBlock (EffTy (extendEffect ExceptionEffect effs) ty) body + body' <- checkWithEffTy (EffTy (extendEffect ExceptionEffect effs) ty) body return $ CatchException reqTy'' body' instance IRRep r => CheckableWithEffects r (DAMOp r) where @@ -692,7 +698,7 @@ instance IRRep r => CheckableWithEffects r (DAMOp r) where _ -> badCarry let binderReqTy = PairTy (ixTypeType ixTy') carryTy' checkBinderType binderReqTy b \b' -> do - body' <- checkBlock (EffTy (sink effAnn') UnitTy) body + body' <- checkWithEffTy (EffTy (sink effAnn') UnitTy) body return $ Seq effAnn' dir ixTy' carry' $ LamExpr (UnaryNest b') body' RememberDest effAnn d lam -> do LamExpr (UnaryNest b) body <- return lam @@ -700,7 +706,7 @@ instance IRRep r => CheckableWithEffects r (DAMOp r) where checkExtends effs effAnn' (d', dTy@(RawRefTy _)) <- checkAndGetType d checkBinderType dTy b \b' -> do - body' <- checkBlock (EffTy (sink effAnn') UnitTy) body + body' <- checkWithEffTy (EffTy (sink effAnn') UnitTy) body return $ RememberDest effAnn' d' $ LamExpr (UnaryNest b') body' AllocDest ty -> AllocDest <$> ty|:TyKind Place ref val -> do @@ -717,7 +723,7 @@ checkLamExpr :: IRRep r => PiType r o -> LamExpr r i -> TyperM r i o (LamExpr r checkLamExpr piTy (LamExpr bs body) = checkB bs \bs' -> do effTy <- checkInstantiation (sink piTy) (Var <$> bindersVars bs') - body' <- checkBlock effTy body + body' <- checkWithEffTy effTy body return $ LamExpr bs' body' checkDecls @@ -743,7 +749,7 @@ checkRWSAction resultTy referentTy effs rws f = do let refTy = RefTy h (sink referentTy) checkBinderType refTy bR \bR' -> do let effs' = extendEffect (RWSEffect rws $ sink h) (sink effs) - body' <- checkBlock (EffTy effs' (sink resultTy)) body + body' <- checkWithEffTy (EffTy effs' (sink resultTy)) body return $ BinaryLamExpr bH' bR' body' checkCaseAltsBinderTys :: IRRep r => Type r n -> TyperM r i n [Type r n] diff --git a/src/lib/Core.hs b/src/lib/Core.hs index 6f4bd6b5a..10c60999b 100644 --- a/src/lib/Core.hs +++ b/src/lib/Core.hs @@ -410,9 +410,10 @@ withFreshBinders (binding:rest) cont = do -- structure. Excess binders, if any, are still left in the unary structures. liftLamExpr :: (IRRep r, EnvReader m) - => (forall l m2. EnvReader m2 => Block r l -> m2 l (Block r l)) - -> TopLam r n -> m n (TopLam r n) -liftLamExpr f (TopLam d ty (LamExpr bs body)) = liftM (TopLam d ty) $ liftEnvReaderM $ + => TopLam r n + -> (forall l m2. EnvReader m2 => Expr r l -> m2 l (Expr r l)) + -> m n (TopLam r n) +liftLamExpr (TopLam d ty (LamExpr bs body)) f = liftM (TopLam d ty) $ liftEnvReaderM $ refreshAbs (Abs bs body) \bs' body' -> LamExpr bs' <$> f body' fromNaryForExpr :: IRRep r => Int -> Expr r n -> Maybe (Int, LamExpr r n) @@ -422,9 +423,8 @@ fromNaryForExpr maxDepth = \case extend <|> (Just $ (1, LamExpr (Nest b Empty) body)) where extend = do - expr <- exprBlock body guard $ maxDepth > 1 - (d, LamExpr bs body2) <- fromNaryForExpr (maxDepth - 1) expr + (d, LamExpr bs body2) <- fromNaryForExpr (maxDepth - 1) body return (d + 1, LamExpr (Nest b bs) body2) _ -> Nothing diff --git a/src/lib/Imp.hs b/src/lib/Imp.hs index 7bf22d6ef..630479799 100644 --- a/src/lib/Imp.hs +++ b/src/lib/Imp.hs @@ -62,14 +62,14 @@ toImpFunction cc (TopLam True destTy lam) = do RefTy _ ansTy -> allocDestUnmanaged =<< substM ansTy _ -> error "Expected a reference type for body destination" extendSubst (destB @> SubstVal (destToAtom dest)) do - void $ translateBlock body + void $ translateExpr body resultAtom <- loadAtom dest repValToList <$> atomToRepVal resultAtom _ -> do (argAtoms, resultDest) <- interpretImpArgsWithCC cc (sink ty) vs extendSubst (bs @@> (SubstVal <$> argAtoms)) do extendSubst (destB @> SubstVal (destToAtom (sink resultDest))) do - void $ translateBlock body + void $ translateExpr body return [] toImpFunction _ (TopLam False _ _) = error "expected a lambda in destination-passing form" {-# SCC toImpFunction #-} @@ -267,29 +267,17 @@ liftImpM cont = do -- === the actual pass === -translateBlock :: forall i o. Emits o - => SBlock i -> SubstImpM i o (SAtom o) -translateBlock (Abs decls result) = translateDeclNest decls $ substM result - -translateDeclNestSubst - :: Emits o => Subst AtomSubstVal l o - -> Nest SDecl l i' -> SubstImpM i o (Subst AtomSubstVal i' o) -translateDeclNestSubst !s = \case - Empty -> return s +translateDeclNest :: Emits o => Nest SDecl i i' -> SubstImpM i' o a -> SubstImpM i o a +translateDeclNest decls cont = case decls of + Empty -> cont Nest (Let b (DeclBinding _ expr)) rest -> do - x <- withSubst s $ translateExpr expr - translateDeclNestSubst (s <>> (b@>SubstVal x)) rest - -translateDeclNest :: Emits o - => Nest SDecl i i' -> SubstImpM i' o a -> SubstImpM i o a -translateDeclNest decls cont = do - s <- getSubst - s' <- translateDeclNestSubst s decls - withSubst s' cont + x <- translateExpr expr + extendSubst (b@>SubstVal x) $ translateDeclNest rest cont {-# INLINE translateDeclNest #-} translateExpr :: forall i o. Emits o => SExpr i -> SubstImpM i o (SAtom o) translateExpr expr = confuseGHC >>= \_ -> case expr of + Block _ (Abs decls result) -> translateDeclNest decls $ translateExpr result TopApp (EffTy _ resultTy') f' xs' -> do resultTy <- substM resultTy' f <- substM f' @@ -314,7 +302,7 @@ translateExpr expr = confuseGHC >>= \_ -> case expr of case trySelectBranch e' of Just (con, arg) -> do Abs b body <- return $ alts !! con - extendSubst (b @> SubstVal arg) $ translateBlock body + extendSubst (b @> SubstVal arg) $ translateExpr body Nothing -> do RepVal sumTy (Branch (tag:xss)) <- atomToRepVal e' ts <- caseAltsBinderTys sumTy @@ -327,7 +315,7 @@ translateExpr expr = confuseGHC >>= \_ -> case expr of emitSwitch tag' (zip xss alts) $ \(xs, Abs b body) -> extendSubst (b @> SubstVal (sink xs)) $ - void $ translateBlock body + void $ translateExpr body return UnitVal TabCon _ _ _ -> error "Unexpected `TabCon` in Imp pass." Project _ i x -> reduceProj i =<< substM x @@ -364,7 +352,7 @@ toImpRefOp refDest' m = do True -> do BinaryLamExpr xb yb body <- return bc body' <- applySubst (xb @> SubstVal x <.> yb @> SubstVal y) body - ans <- liftBuilderImp $ emitBlock (sink body') + ans <- liftBuilderImp $ emitExpr (sink body') storeAtom accDest ans False -> case accTy of TabPi t -> do @@ -393,12 +381,12 @@ toImpOp op = case op of emitLoop (getNameHint b) d n \i -> do idx <- unsafeFromOrdinalImp (sink ixTy) i void $ extendSubst (b @> SubstVal (PairVal idx (sink carry'))) $ - translateBlock body + translateExpr body return carry' RememberDest _ d f -> do UnaryLamExpr b body <- return f d' <- substM d - void $ extendSubst (b @> SubstVal d') $ translateBlock body + void $ extendSubst (b @> SubstVal d') $ translateExpr body return d' Place ref val -> do val' <- substM val @@ -531,7 +519,7 @@ toImpTypedHof (TypedHof (EffTy _ resultTy') hof) = do For _ _ _ -> error $ "Unexpected `for` in Imp pass " ++ pprint hof While body -> do body' <- buildBlockImp do - ans <- fromScalarAtom =<< translateBlock body + ans <- fromScalarAtom =<< translateExpr body return [ans] emitStatement $ IWhile body' return UnitVal @@ -541,7 +529,7 @@ toImpTypedHof (TypedHof (EffTy _ resultTy') hof) = do rDest <- allocDest $ getType r' storeAtom rDest r' extendSubst (h @> SubstVal (Con HeapVal) <.> ref @> SubstVal (destToAtom rDest)) $ - translateBlock body + translateExpr body RunWriter d (BaseMonoid e _) f -> do BinaryLamExpr h ref body <- return f let PairTy ansTy accTy = resultTy @@ -555,7 +543,7 @@ toImpTypedHof (TypedHof (EffTy _ resultTy') hof) = do PairE accTy' e'' <- sinkM $ PairE accTy e' liftMonoidEmpty wDest accTy' e'' extendSubst (h @> SubstVal (Con HeapVal) <.> ref @> SubstVal (destToAtom wDest)) $ - translateBlock body >>= storeAtom aDest + translateExpr body >>= storeAtom aDest PairVal <$> loadAtom aDest <*> loadAtom wDest RunState d s f -> do BinaryLamExpr h ref body <- return f @@ -568,10 +556,10 @@ toImpTypedHof (TypedHof (EffTy _ resultTy') hof) = do return (aDest, sDest) storeAtom sDest =<< substM s extendSubst (h @> SubstVal (Con HeapVal) <.> ref @> SubstVal (destToAtom sDest)) $ - translateBlock body >>= storeAtom aDest + translateExpr body >>= storeAtom aDest PairVal <$> loadAtom aDest <*> loadAtom sDest - RunIO body -> translateBlock body - RunInit body -> translateBlock body + RunIO body -> translateExpr body + RunInit body -> translateExpr body where liftMonoidEmpty :: Emits n => Dest n -> SType n -> SAtom n -> SubstImpM i n () liftMonoidEmpty accDest accTy x = do @@ -1068,7 +1056,7 @@ type SBuilderM = BuilderM SimpIR computeElemCountImp :: Emits n => IndexStructure SimpIR n -> SubstImpM i n (IExpr n) computeElemCountImp Singleton = return $ IIdxRepVal 1 computeElemCountImp idxs = do - result <- coreToImpBuilder do + result <- liftBuilderImp do idxs' <- sinkM idxs computeElemCount idxs' fromScalarAtom result @@ -1077,7 +1065,7 @@ computeOffsetImp :: Emits n => IndexStructure SimpIR n -> IExpr n -> SubstImpM i n (IExpr n) computeOffsetImp idxs ixOrd = do let ixOrd' = toScalarAtom ixOrd - result <- coreToImpBuilder do + result <- liftBuilderImp do PairE idxs' ixOrd'' <- sinkM $ PairE idxs ixOrd' computeOffset idxs' ixOrd'' fromScalarAtom result @@ -1106,7 +1094,7 @@ elemCountPoly (Abs bs UnitE) = case bs of computeSizeGivenOrdinal :: EnvReader m => IxBinder SimpIR n l -> IndexStructure SimpIR l - -> m n (Abs (Binder SimpIR) (Block SimpIR) n) + -> m n (Abs SBinder SExpr n) computeSizeGivenOrdinal (PairB (LiftB d) (b:>t)) idxStruct = liftBuilder do withFreshBinder noHint IdxRepTy \bOrdinal -> Abs bOrdinal <$> buildBlock do @@ -1138,8 +1126,8 @@ computeOffset (EmptyAbs (Nest b idxs)) idxOrdinal = do computeOffset _ _ = error "Expected a nonempty nest of idx binders" sumUsingPolysImp - :: Emits n => Atom SimpIR n - -> Abs (Binder SimpIR) (Block SimpIR) n -> BuilderM SimpIR n (SAtom n) + :: Emits n => SAtom n + -> Abs SBinder SExpr n -> BuilderM SimpIR n (SAtom n) sumUsingPolysImp lim (Abs i body) = do ab <- hoistDecls i body sumUsingPolys lim ab @@ -1147,30 +1135,31 @@ sumUsingPolysImp lim (Abs i body) = do hoistDecls :: ( Builder SimpIR m, EnvReader m, Emits n , BindsNames b, BindsEnv b, RenameB b, SinkableB b) - => b n l -> SBlock l -> m n (Abs b SBlock n) + => b n l -> SExpr l -> m n (Abs b SExpr n) hoistDecls b block = do emitDecls =<< liftEnvReaderM do - refreshAbs (Abs b block) \b' (Abs decls result) -> - hoistDeclsRec b' Empty decls result + refreshAbs (Abs b block) \b' body -> + hoistDeclsRec b' Empty body {-# INLINE hoistDecls #-} hoistDeclsRec :: (BindsNames b, SinkableB b) - => b n1 n2 -> SDecls n2 n3 -> SDecls n3 n4 -> SAtom n4 - -> EnvReaderM n3 (Abs SDecls (Abs b (Abs SDecls SAtom)) n1) -hoistDeclsRec b declsAbove Empty result = - return $ Abs Empty $ Abs b $ Abs declsAbove result -hoistDeclsRec b declsAbove (Nest decl declsBelow) result = do - let (Let _ expr) = decl - let exprIsPure = isPure expr - refreshAbs (Abs decl (Abs declsBelow result)) - \decl' (Abs declsBelow' result') -> - case exchangeBs (PairB (PairB b declsAbove) decl') of - HoistSuccess (PairB hoistedDecl (PairB b' declsAbove')) | exprIsPure -> do - Abs hoistedDecls fullResult <- hoistDeclsRec b' declsAbove' declsBelow' result' - return $ Abs (Nest hoistedDecl hoistedDecls) fullResult - _ -> hoistDeclsRec b (declsAbove >>> Nest decl' Empty) declsBelow' result' -{-# INLINE hoistDeclsRec #-} + => b n1 n2 -> SDecls n2 n3 -> SExpr n3 + -> EnvReaderM n3 (Abs SDecls (Abs b SExpr) n1) +hoistDeclsRec = undefined +-- hoistDeclsRec b declsAbove Empty result = +-- return $ Abs Empty $ Abs b $ Abs declsAbove result +-- hoistDeclsRec b declsAbove (Nest decl declsBelow) result = do +-- let (Let _ expr) = decl +-- let exprIsPure = isPure expr +-- refreshAbs (Abs decl (Abs declsBelow result)) +-- \decl' (Abs declsBelow' result') -> +-- case exchangeBs (PairB (PairB b declsAbove) decl') of +-- HoistSuccess (PairB hoistedDecl (PairB b' declsAbove')) | exprIsPure -> do +-- Abs hoistedDecls fullResult <- hoistDeclsRec b' declsAbove' declsBelow' result' +-- return $ Abs (Nest hoistedDecl hoistedDecls) fullResult +-- _ -> hoistDeclsRec b (declsAbove >>> Nest decl' Empty) declsBelow' result' +-- {-# INLINE hoistDeclsRec #-} -- === Imp IR builder === @@ -1366,8 +1355,6 @@ fromScalarAtom atom = atomToRepVal atom >>= \case toScalarAtom :: IExpr n -> SAtom n toScalarAtom x = RepValAtom $ RepVal (BaseTy (getIType x)) (Leaf x) --- TODO: we shouldn't need the rank-2 type here because ImpBuilder and Builder --- are part of the same conspiracy. liftBuilderImp :: (Emits n, SubstE AtomSubstVal e, SinkableE e) => (forall l. (Emits l, DExt n l) => BuilderM SimpIR l (e l)) -> SubstImpM i n (e n) @@ -1376,18 +1363,6 @@ liftBuilderImp cont = do dropSubst $ translateDeclNest decls $ substM result {-# INLINE liftBuilderImp #-} -coreToImpBuilder - :: (Emits n, ImpBuilder m, SinkableE e, RenameE e, SubstE AtomSubstVal e ) - => (forall l. (Emits l, DExt n l) => BuilderM SimpIR l (e l)) - -> m n (e n) -coreToImpBuilder cont = do - block <- liftBuilder $ buildScoped cont - result <- liftImpM $ buildScopedImp $ dropSubst do - Abs decls result <- sinkM block - translateDeclNest decls $ substM result - emitDeclsImp result -{-# INLINE coreToImpBuilder #-} - -- === Type classes === ordinalImp :: Emits n => IxType SimpIR n -> SAtom n -> SubstImpM i n (IExpr n) @@ -1415,7 +1390,7 @@ appSpecializedIxMethod :: Emits n => SpecDictName n -> IxMethod -> [SAtom n] -> appSpecializedIxMethod d method args = do SpecializedDict _ (Just fs) <- lookupSpecDict d TopLam _ _ (LamExpr bs body) <- return $ fs !! fromEnum method - dropSubst $ extendSubst (bs @@> map SubstVal args) $ translateBlock body + dropSubst $ extendSubst (bs @@> map SubstVal args) $ translateExpr body -- === Abstracting link-time objects === diff --git a/src/lib/Inference.hs b/src/lib/Inference.hs index 981519268..b89c77bd1 100644 --- a/src/lib/Inference.hs +++ b/src/lib/Inference.hs @@ -52,7 +52,7 @@ checkTopUType ty = liftInfererM $ checkUType ty {-# SCC checkTopUType #-} inferTopUExpr :: (Fallible1 m, EnvReader m) => UExpr n -> m n (TopBlock CoreIR n) -inferTopUExpr e = fst <$> (asTopBlock =<< liftInfererM (buildScoped $ bottomUp e)) +inferTopUExpr e = fst <$> (asTopBlock =<< liftInfererM (buildBlock $ bottomUp e)) {-# SCC inferTopUExpr #-} data UDeclInferenceResult e n = @@ -112,7 +112,7 @@ inferTopUDecl (ULocalDecl (WithSrcB src decl)) result = addSrcContext src case d UExprDecl _ -> error "Shouldn't have this at the top level (should have become a command instead)" ULet letAnn p tyAnn rhs -> case p of WithSrcB _ (UPatBinder b) -> do - block <- liftInfererM $ buildScoped do + block <- liftInfererM $ buildBlock do checkMaybeAnnExpr tyAnn rhs (topBlock, resultTy) <- asTopBlock block let letAnn' = considerInlineAnn letAnn resultTy @@ -127,10 +127,11 @@ inferTopUDecl (ULocalDecl (WithSrcB src decl)) result = addSrcContext src case d return $ UDeclResultBindPattern (getNameHint p) topBlock recon {-# SCC inferTopUDecl #-} -asTopBlock :: EnvReader m => CBlock n -> m n (TopBlock CoreIR n, CType n) +asTopBlock :: EnvReader m => CExpr n -> m n (TopBlock CoreIR n, CType n) asTopBlock block = do - effTy@(EffTy _ ty) <- blockEffTy block - return (TopLam False (PiType Empty effTy) (LamExpr Empty block), ty) + let effs = getEffects block + let ty = getType block + return (TopLam False (PiType Empty (EffTy effs ty)) (LamExpr Empty block), ty) getInstanceType :: EnvReader m => InstanceDef n -> m n (CorePiType n) getInstanceType (InstanceDef className roleExpls bs params _) = liftEnvReaderM do @@ -388,9 +389,8 @@ etaExpandPartialPi (PartialPiType appExpl expls bs effs reqTy) cont = do (Explicit, arg) -> Just arg _ -> Nothing withAllowedEffects effs' do - body <- buildScoped $ cont (sink reqTy') (sink <$> explicits) - resultTy <- blockTy body - let piTy = CorePiType appExpl expls bs' (EffTy effs' resultTy) + body <- buildBlock $ cont (sink reqTy') (sink <$> explicits) + let piTy = CorePiType appExpl expls bs' (EffTy effs' $ getType body) return $ CoreLamExpr piTy $ LamExpr bs' body -- Doesn't introduce implicit pi binders or dependent pairs @@ -513,8 +513,7 @@ bottomUpExplicit (WithSrcE pos expr) = addSrcContext pos case expr of let scrutTy = getType scrut' alt'@(IndexedAlt _ altAbs) <- checkCaseAlt Infer scrutTy alt Abs b ty <- liftEnvReaderM $ refreshAbs altAbs \b body -> do - ty <- blockTy body - return $ Abs b ty + return $ Abs b (getType body) resultTy <- liftHoistExcept $ hoist b ty alts' <- mapM (checkCaseAlt (Check resultTy) scrutTy) alts SigmaAtom Nothing <$> buildSortedCase scrut' (alt':alts') resultTy @@ -992,13 +991,12 @@ checkNamedArgValidity expls offeredNames = do inferPrimArg :: Emits o => UExpr i -> InfererM i o (CAtom o) inferPrimArg x = do - xBlock <- buildScoped $ bottomUp x - EffTy _ ty <- blockEffTy xBlock - case ty of - TyKind -> reduceBlock xBlock >>= \case + xBlock <- buildBlock $ bottomUp x + case getType xBlock of + TyKind -> reduceExpr xBlock >>= \case Just reduced -> return reduced _ -> throw CompilerErr "Type args to primops must be reducible" - _ -> emitBlock xBlock + _ -> emitExpr xBlock matchPrimApp :: Emits o => PrimName -> [CAtom o] -> InfererM i o (CAtom o) matchPrimApp = \case @@ -1009,13 +1007,13 @@ matchPrimApp = \case UNatCon -> \case ~[x] -> return $ NewtypeCon NatCon x UPrimTC op -> \x -> Type . TC <$> matchGenericOp (Right op) x UCon op -> \x -> Con <$> matchGenericOp (Right op) x - UMiscOp op -> \x -> emitOp =<< MiscOp <$> matchGenericOp op x - UMemOp op -> \x -> emitOp =<< MemOp <$> matchGenericOp op x - UBinOp op -> \case ~[x, y] -> emitOp $ BinOp op x y - UUnOp op -> \case ~[x] -> emitOp $ UnOp op x - UMAsk -> \case ~[r] -> emitOp $ RefOp r MAsk - UMGet -> \case ~[r] -> emitOp $ RefOp r MGet - UMPut -> \case ~[r, x] -> emitOp $ RefOp r $ MPut x + UMiscOp op -> \x -> emitExpr =<< MiscOp <$> matchGenericOp op x + UMemOp op -> \x -> emitExpr =<< MemOp <$> matchGenericOp op x + UBinOp op -> \case ~[x, y] -> emitExpr $ BinOp op x y + UUnOp op -> \case ~[x] -> emitExpr $ UnOp op x + UMAsk -> \case ~[r] -> emitExpr $ RefOp r MAsk + UMGet -> \case ~[r] -> emitExpr $ RefOp r MGet + UMPut -> \case ~[r, x] -> emitExpr $ RefOp r $ MPut x UIndexRef -> \case ~[r, i] -> indexRef r i UApplyMethod i -> \case ~(d:args) -> emitExpr =<< mkApplyMethod d i args ULinearize -> \case ~[f, x] -> do f' <- lam1 f; emitHof $ Linearize f' x @@ -1025,7 +1023,7 @@ matchPrimApp = \case UWhile -> \case ~[f] -> do f' <- lam0 f; emitHof $ While f' URunIO -> \case ~[f] -> do f' <- lam0 f; emitHof $ RunIO f' UCatchException-> \case ~[f] -> do f' <- lam0 f; emitHof =<< mkCatchException f' - UMExtend -> \case ~[r, z, f, x] -> do f' <- lam2 f; emitOp $ RefOp r $ MExtend (BaseMonoid z f') x + UMExtend -> \case ~[r, z, f, x] -> do f' <- lam2 f; emitExpr $ RefOp r $ MExtend (BaseMonoid z f') x URunWriter -> \args -> do [idVal, combiner, f] <- return args combiner' <- lam2 combiner @@ -1043,7 +1041,7 @@ matchPrimApp = \case ExplicitCoreLam (UnaryNest b) body <- return x return $ UnaryLamExpr b body - lam0 :: Fallible m => CAtom n -> m (CBlock n) + lam0 :: Fallible m => CAtom n -> m (CExpr n) lam0 x = do ExplicitCoreLam Empty body <- return x return body @@ -1058,7 +1056,7 @@ matchPrimApp = \case _ -> return $ Right x return $ fromJust $ toOp $ GenericOpRep op tyArgs dataArgs [] -pattern ExplicitCoreLam :: Nest CBinder n l -> CBlock l -> CAtom n +pattern ExplicitCoreLam :: Nest CBinder n l -> CExpr l -> CAtom n pattern ExplicitCoreLam bs body <- Lam (CoreLamExpr _ (LamExpr bs body)) -- === n-ary applications === @@ -1116,8 +1114,8 @@ buildNthOrderedAlt alts scrutTy resultTy i v = do case lookup (nthCaseAltIdx scrutTy i) [(idx, alt) | IndexedAlt idx alt <- alts] of Nothing -> do resultTy' <- sinkM resultTy - emitOp $ MiscOp $ ThrowError resultTy' - Just alt -> applyAbs alt (SubstVal v) >>= emitBlock + emitExpr $ ThrowError resultTy' + Just alt -> applyAbs alt (SubstVal v) >>= emitExpr -- converts from the ordinal index used in the core IR to the more complicated -- `CaseAltIndex` used in the surface IR. @@ -1151,7 +1149,7 @@ buildSortedCase scrut alts resultTy = do [_] -> do let [IndexedAlt _ alt] = alts scrut' <- unwrapNewtype scrut - emitBlock =<< applyAbs alt (SubstVal scrut') + emitExpr =<< applyAbs alt (SubstVal scrut') _ -> liftEmitBuilder $ buildMonomorphicCase alts scrut resultTy _ -> fail $ "Unexpected case expression type: " <> pprint scrutTy @@ -1163,9 +1161,8 @@ instanceFun instanceName appExpl = do args <- mapM toAtomVar $ nestToNames bs' result <- DictCon <$> mkInstanceDict (sink instanceName) (Var <$> args) let effTy = EffTy Pure (getType result) - let body = WithoutDecls result let piTy = CorePiType appExpl (snd<$>expls) bs' effTy - return $ Lam $ CoreLamExpr piTy (LamExpr bs' body) + return $ Lam $ CoreLamExpr piTy (LamExpr bs' $ Atom result) checkMaybeAnnExpr :: Emits o => Maybe (UType i) -> UExpr i -> InfererM i o (CAtom o) checkMaybeAnnExpr ty expr = confuseGHC >>= \_ -> case ty of @@ -1341,9 +1338,9 @@ checkULamPartial partialPiTy lamExpr = do lamEffs'' <- checkUEffRow lamEffs' expectEq (Eff piEffs') (Eff lamEffs'') body' <- withAllowedEffects piEffs' do - buildScoped $ withBlockDecls body \result -> checkOrInfer (sink resultTy) result + buildBlock $ withBlockDecls body \result -> checkOrInfer (sink resultTy) result resultTy' <- case resultTy of - Infer -> blockTy body' + Infer -> return $ getType body' Check t -> return t let piTy = CorePiType piAppExpl expls lamBs' (EffTy piEffs' resultTy') return $ CoreLamExpr piTy (LamExpr lamBs' body') @@ -1373,7 +1370,7 @@ checkULamPartial partialPiTy lamExpr = do inferUForExpr :: Emits o => UForExpr i -> InfererM i o (LamExpr CoreIR o) inferUForExpr (UForExpr b body) = do withUBinder b \(WithAttrB _ b') -> do - body' <- buildScoped $ withBlockDecls body \result -> bottomUp result + body' <- buildBlock $ withBlockDecls body \result -> bottomUp result return $ LamExpr (UnaryNest b') body' checkUForExpr :: Emits o => UForExpr i -> TabPiType CoreIR o -> InfererM i o (LamExpr CoreIR o) @@ -1390,13 +1387,12 @@ inferULam (ULamExpr bs appExpl effs resultTy body) = do Abs (ZipB expls bs') (PairE effTy body') <- inferUBinders bs \_ -> do effs' <- fromMaybe Pure <$> mapM checkUEffRow effs resultTy' <- mapM checkUType resultTy - body' <- buildScoped $ withAllowedEffects (sink effs') do + body' <- buildBlock $ withAllowedEffects (sink effs') do withBlockDecls body \result -> case resultTy' of Nothing -> bottomUp result Just resultTy'' -> topDown (sink resultTy'') result - resultTy'' <- blockTy body' - let effTy = EffTy effs' resultTy'' + let effTy = EffTy effs' (getType body') return $ PairE effTy body' return $ CoreLamExpr (CorePiType appExpl expls bs' effTy) (LamExpr bs' body') @@ -1510,7 +1506,7 @@ checkCasePat (WithSrcB pos pat) scrutineeTy cont = addSrcContext pos $ case pat "Unexpected number of pattern binders. Expected " ++ show (length idxs) ++ " got " ++ show (nestLength ps) withFreshBinderInf noHint Explicit repTy \b -> Abs b <$> do - buildScoped do + buildBlock do args <- forM idxs \projs -> do ans <- applyProjectionsReduced (init projs) (sink $ Var $ binderVar b) emit $ Atom ans @@ -2065,7 +2061,7 @@ synthTerm targetTy reqMethodAccess = confuseGHC >>= \_ -> case targetTy of Abs bs' <$> synthTerm (SynthDictType targetTy') reqMethodAccess Abs bs' synthExpr <- return ab' let piTy = CorePiType ImplicitApp expls bs' (EffTy Pure (getType synthExpr)) - let lamExpr = LamExpr bs' (WithoutDecls synthExpr) + let lamExpr = LamExpr bs' (Atom synthExpr) return $ Lam $ CoreLamExpr piTy lamExpr SynthDictType dictTy -> case dictTy of DictType "Ix" _ [Type (NewtypeTyCon (Fin n))] -> return $ DictCon $ IxFin (DictTy dictTy) n @@ -2183,12 +2179,14 @@ type WithRoleExpl = WithAttrB RoleExpl buildBlockInfWithRecon :: HasNamesE e => (forall l. (Emits l, DExt n l) => InfererM i l (e l)) - -> InfererM i n (PairE CBlock (ReconAbs CoreIR e) n) + -> InfererM i n (PairE CExpr (ReconAbs CoreIR e) n) buildBlockInfWithRecon cont = do ab <- buildScoped cont - liftEnvReaderM $ liftM toPairE $ refreshAbs ab \decls result -> do + (block, recon) <- liftEnvReaderM $ refreshAbs ab \decls result -> do (newResult, recon) <- telescopicCapture decls result return (Abs decls newResult, recon) + block' <- mkBlock block + return $ PairE block' recon {-# INLINE buildBlockInfWithRecon #-} -- === IFunType === diff --git a/src/lib/Inline.hs b/src/lib/Inline.hs index 1e6ff6655..1a271a34f 100644 --- a/src/lib/Inline.hs +++ b/src/lib/Inline.hs @@ -21,8 +21,7 @@ import Types.Primitives -- === External API === inlineBindings :: (EnvReader m) => STopLam n -> m n (STopLam n) -inlineBindings = liftLamExpr \(Abs decls ans) -> liftInlineM $ - buildScoped $ inlineDecls decls $ inline Stop ans +inlineBindings lam = liftLamExpr lam \body -> liftInlineM $ buildBlock $ inlineExpr Stop body {-# INLINE inlineBindings #-} {-# SCC inlineBindings #-} @@ -73,12 +72,6 @@ data SizePreservationInfo = | UsedMulti deriving (Eq, Show) -inlineDecls :: Emits o => Nest SDecl i i' -> InlineM i' o a -> InlineM i o a -inlineDecls decls cont = do - s <- inlineDeclsSubst decls - withSubst s cont -{-# INLINE inlineDecls #-} - inlineDeclsSubst :: Emits o => Nest SDecl i i' -> InlineM i o (Subst InlineSubstVal i' o) inlineDeclsSubst = \case Empty -> getSubst @@ -91,23 +84,7 @@ inlineDeclsSubst = \case -- If the inliner starts moving effectful expressions, it may become -- necessary to query the effects of the new expression here. let presInfo = resolveWorkConservation ann expr' - -- A subtlety from the Secrets paper. In Haskell, it is feasible to have - -- a binding whose occurrence information indicates multiple uses, but - -- which does a small, bounded amount of runtime work. GHC will inline - -- such a binding, but not into contexts where GHC knows that no further - -- optimizations are possible. The example given in the paper is - -- f = \x -> E - -- g = \ys -> map f ys - -- Inlining f here is useless because it's not applied, and mildly costly - -- because it causes the closure to be allocated at every call to g rather - -- than just once. - -- TODO If we want to track this subtlety, we should make room for it in - -- the SizePreservationInfo ADT (maybe rename it), maybe with a - -- OnceButDuplicatesBoundedWork constructor. Then only the true UsedOnce - -- would be inlined unconditionally here, and the - -- OnceButDuplicatesBoundedWork constructor could be inlined or not - -- depending on its usage context. (This would correspond to the case - -- OnceUnsafe with whnfOrBot == True in the Secrets paper.) + -- See NoteSecretsSubtlety if presInfo == UsedOnce then do let substVal = case expr' of Atom (Var name') -> Rename $ atomVarName name' @@ -117,33 +94,7 @@ inlineDeclsSubst = \case -- expr' can't be Atom (Var x) here name' <- emitDecl (getNameHint b) (dropOccInfo ann) expr' extendSubst (b @> Rename (atomVarName name')) do - -- TODO For now, this inliner does not do any conditional inlining. - -- In order to do it, we would need to augment the environment at this - -- point, associating name' to (expr', presInfo) so name' could be - -- inlined at use sites. - -- - -- Conditional inlining is different in Dex vs Haskell because Dex is - -- strict. To wit, once we have emitted the bidning for `expr'`, we - -- are committed to doing the work it represents unless it's inlined - -- _everywhere_. For example, - -- xs = - -- case of - -- Nothing -> xs -- ok to inline here - -- Just _ -> xs ... xs -- not ok here - -- If this were Haskell, it would be work-preserving for GHC to inline - -- `xs` into the `Nothing` arm, but in Dex it's not, unless we first - -- explicitly push the binding into the case like - -- case of - -- Nothing -> xs = ; xs - -- Just _ -> xs = ; xs ... xs - -- - -- That said, the Secrets paper says that GHC only conditionally - -- inlines zero-work bindings anyway (or, more precisely, "bounded - -- finite work" bindings). All the heuristics about whether to inline - -- at a particular site are about code size and not increasing it - -- overmuch. But, of course, inlining even zero-work bindings can - -- help runtime performance because it can unblock other optimizations - -- that otherwise could not occur across the binding. + -- See NoteConditionalInlining inlineDeclsSubst rest where dropOccInfo PlainLet = PlainLet @@ -214,12 +165,8 @@ inlineDeclsSubst = \case -- since their main purpose is to force inlining in the simplifier, and if -- one just stuck like this it has become equivalent to a `for` anyway. ixDepthExpr :: Expr SimpIR n -> Int - ixDepthExpr (PrimOp (Hof (TypedHof _ (For _ _ (UnaryLamExpr _ body))))) = 1 + ixDepthBlock body + ixDepthExpr (PrimOp (Hof (TypedHof _ (For _ _ (UnaryLamExpr _ body))))) = 1 + ixDepthExpr body ixDepthExpr _ = 0 - ixDepthBlock :: Block SimpIR n -> Int - ixDepthBlock (exprBlock -> (Just expr)) = ixDepthExpr expr - ixDepthBlock (Abs Empty result) = ixDepthExpr $ Atom result - ixDepthBlock _ = 0 -- Should we decide to inline this binding wherever it appears, before we even -- know the expression? "Yes" only if we know it only occurs once, and in a @@ -277,6 +224,9 @@ inlineExpr ctx = \case Case scrut alts (EffTy effs resultTy) -> do s <- getSubst inlineAtom (CaseCtx alts resultTy effs s ctx) scrut + Block _ (Abs decls ans) -> do + s <- inlineDeclsSubst decls + withSubst s $ inlineExpr ctx ans expr -> visitGeneric expr >>= reconstruct ctx inlineAtom :: Emits o => Context SExpr e o -> SAtom i -> InlineM i o (e o) @@ -314,10 +264,10 @@ instance Inlinable SType where inline ctx ty = visitTypePartial ty >>= reconstruct ctx instance Inlinable SLam where - inline ctx (LamExpr bs (Abs decls ans)) = do + inline ctx (LamExpr bs body) = do reconstruct ctx =<< withBinders bs \bs' -> do - (LamExpr bs' <$>) $ buildScoped $ - inlineDecls decls $ inline Stop ans + body' <- buildBlock $ inlineExpr Stop body + return $ LamExpr bs' body' withBinders :: Nest SBinder i i' @@ -336,10 +286,6 @@ instance Inlinable (PiType SimpIR) where effTy' <- buildScopedAssumeNoDecls $ inline Stop effTy return $ PiType bs' effTy' -inlineBlockEmits :: Emits o => Context SExpr e2 o -> SBlock i -> InlineM i o (e2 o) -inlineBlockEmits ctx (Abs decls ans) = do - inlineDecls decls $ inlineAtom ctx ans - -- Still using InlineM because we may call back into inlining, and we wish to -- retain our output binding environment. reconstruct :: Emits o => Context e1 e2 o -> e1 o -> InlineM i o (e2 o) @@ -358,50 +304,15 @@ reconstructTabApp ctx expr [] = do reconstruct ctx expr reconstructTabApp ctx expr ixs = case fromNaryForExpr (length ixs) expr of - Just (bsCount, LamExpr bs (Abs decls result)) -> do + Just (bsCount, LamExpr bs body) -> do + -- See NoteReconstructTabAppDecisions let (ixsPref, ixsRest) = splitAt bsCount ixs - -- Note: There's a decision here. Is it ok to inline the atoms in - -- `ixsPref` into the body `decls`? If so, should we pre-process them and - -- carry them in `DoneEx`, or suspend them in `SuspEx`? (If not, we can - -- emit fresh bindings and use `Rename`.) We can't make this decision - -- properly without annotating the `for` binders with occurrence - -- information; even though `ixsPref` itself are atoms, we may be carrying - -- suspended inlining decisions that would want to make one an expression, - -- and thus force-inlining it may duplicate work. - -- - -- There remains a decision between just emitting bindings, or running - -- `mapM (inline $ EmitToAtomCtx Stop)` and inlining the resulting atoms. - -- In the work-heavy case where an element of `ixsPref` becomes an - -- expression after inlining, the result will be the same; but in the - -- work-light case where the element remains an atom, more inlining can - -- proceed. This decision only affects the runtime of the inliner and the - -- code size of the IR the inliner produces. - -- - -- Current status: Emitting bindings in the interest if "launch and - -- iterate"; have not tried `EmitToAtomCtx`. ixsPref' <- mapM (inline $ EmitToNameCtx Stop) ixsPref let ixsPref'' = [v | AtomVar v _ <- ixsPref'] s <- getSubst let moreSubst = bs @@> map Rename ixsPref'' dropSubst $ extendSubst moreSubst do - -- Decision here. These decls have already been processed by the - -- inliner once, so their occurrence information is stale (and should - -- have been erased). Do we rerun occurrence analysis, or just complete - -- the pass without inlining any of them? - -- - Con rerunning: Slower - -- - Con completing: No detection of erroneous lack of occurrence info - -- For now went with "completing"; to detect erroneous lack of - -- occurrence info, change the relevant PlainLet cases above. - -- - -- There's also a missed opportunity here to do more inlining in one - -- pass: we lost the occurrence information of the bindings, so we lost - -- the ability to inline them into the result, so in the common case - -- that the result is a variable reference, we will find ourselves - -- emitting a rename, _which will inhibit downstream inlining_ because a - -- rename is not indexable. - inlineDecls decls do - let ctx' = TabAppCtx ixsRest s ctx - inlineAtom ctx' result + inlineExpr (TabAppCtx ixsRest s ctx) body Nothing -> do array' <- emitExprToAtom expr ixs' <- mapM (inline Stop) ixs @@ -418,11 +329,11 @@ reconstructCase ctx scrutExpr alts resultTy effs = -- of the arms of the outer case resultTy' <- inline Stop resultTy reconstruct ctx =<< (buildCase' sscrut resultTy' \i val -> do - ans <- applyAbs (sink $ salts !! i) (SubstVal val) >>= emitBlock + ans <- applyAbs (sink $ salts !! i) (SubstVal val) >>= emitExpr buildCase ans (sink resultTy') \j jval -> do Abs b body <- return $ alts !! j extendSubst (b @> (SubstVal $ DoneEx $ Atom jval)) do - inlineBlockEmits Stop body >>= emitExprToAtom) + inlineExpr Stop body >>= emitExprToAtom) _ -> do -- Attempt case-of-known-constructor optimization -- I can't use `buildCase` here because I want to propagate the incoming @@ -433,7 +344,7 @@ reconstructCase ctx scrutExpr alts resultTy effs = Just (i, val) -> do Abs b body <- return $ alts !! i extendSubst (b @> (SubstVal $ DoneEx $ Atom val)) do - inlineBlockEmits ctx body + inlineExpr ctx body Nothing -> do alts' <- mapM visitAlt alts resultTy' <- inline Stop resultTy @@ -442,3 +353,84 @@ reconstructCase ctx scrutExpr alts resultTy effs = instance Inlinable (EffectRow SimpIR) instance Inlinable (EffTy SimpIR) + +-- === NoteReconstructTabAppDecisions === + +-- There's a decision here. Is it ok to inline the atoms in `ixsPref` into the +-- body `decls`? If so, should we pre-process them and carry them in `DoneEx`, +-- or suspend them in `SuspEx`? (If not, we can emit fresh bindings and use +-- `Rename`.) We can't make this decision properly without annotating the `for` +-- binders with occurrence information; even though `ixsPref` itself are atoms, +-- we may be carrying suspended inlining decisions that would want to make one +-- an expression, and thus force-inlining it may duplicate work. +-- +-- There remains a decision between just emitting bindings, or running `mapM +-- (inline $ EmitToAtomCtx Stop)` and inlining the resulting atoms. In the +-- work-heavy case where an element of `ixsPref` becomes an expression after +-- inlining, the result will be the same; but in the work-light case where the +-- element remains an atom, more inlining can proceed. This decision only +-- affects the runtime of the inliner and the code size of the IR the inliner +-- produces. +-- +-- Current status: Emitting bindings in the interest if "launch and iterate"; +-- have not tried `EmitToAtomCtx`. Decision here. These decls have already been +-- processed by the inliner once, so their occurrence information is stale (and +-- should have been erased). Do we rerun occurrence analysis, or just complete +-- the pass without inlining any of them? +-- - Con rerunning: Slower +-- - Con completing: No detection of erroneous lack of occurrence info +-- For now went with "completing"; to detect erroneous lack of +-- occurrence info, change the relevant PlainLet cases above. +-- +-- There's also a missed opportunity here to do more inlining in one pass: we +-- lost the occurrence information of the bindings, so we lost the ability to +-- inline them into the result, so in the common case that the result is a +-- variable reference, we will find ourselves emitting a rename, _which will +-- inhibit downstream inlining_ because a rename is not indexable. + +-- === NoteConditionalInlining === + +-- TODO For now, this inliner does not do any conditional inlining. In order to +-- do it, we would need to augment the environment at this point, associating +-- name' to (expr', presInfo) so name' could be inlined at use sites. +-- +-- Conditional inlining is different in Dex vs Haskell because Dex is strict. To +-- wit, once we have emitted the bidning for `expr'`, we are committed to doing +-- the work it represents unless it's inlined _everywhere_. For example, +-- xs = +-- case of +-- Nothing -> xs -- ok to inline here +-- Just _ -> xs ... xs -- not ok here +-- If this were Haskell, it would be work-preserving for GHC to inline +-- `xs` into the `Nothing` arm, but in Dex it's not, unless we first +-- explicitly push the binding into the case like +-- case of +-- Nothing -> xs = ; xs +-- Just _ -> xs = ; xs ... xs +-- +-- That said, the Secrets paper says that GHC only conditionally inlines +-- zero-work bindings anyway (or, more precisely, "bounded finite work" +-- bindings). All the heuristics about whether to inline at a particular site +-- are about code size and not increasing it overmuch. But, of course, inlining +-- even zero-work bindings can help runtime performance because it can unblock +-- other optimizations that otherwise could not occur across the binding. + +-- === NoteSecretsSubtlety === + +-- A subtlety from the Secrets paper. In Haskell, it is feasible to have a +-- binding whose occurrence information indicates multiple uses, but which does +-- a small, bounded amount of runtime work. GHC will inline such a binding, but +-- not into contexts where GHC knows that no further optimizations are possible. +-- The example given in the paper is +-- f = \x -> E +-- g = \ys -> map f ys +-- Inlining f here is useless because it's not applied, and mildly costly +-- because it causes the closure to be allocated at every call to g rather than +-- just once. +-- TODO If we want to track this subtlety, we should make room for it in +-- the SizePreservationInfo ADT (maybe rename it), maybe with a +-- OnceButDuplicatesBoundedWork constructor. Then only the true UsedOnce +-- would be inlined unconditionally here, and the +-- OnceButDuplicatesBoundedWork constructor could be inlined or not +-- depending on its usage context. (This would correspond to the case +-- OnceUnsafe with whnfOrBot == True in the Secrets paper.) diff --git a/src/lib/Linearize.hs b/src/lib/Linearize.hs index 80986e383..4fbae982c 100644 --- a/src/lib/Linearize.hs +++ b/src/lib/Linearize.hs @@ -257,28 +257,29 @@ instance ReconFunctor ObligateReconAbs where linLam' <- applyReconAbs reconAbs residuals return (primal, linLam') -linearizeBlockDefunc :: SBlock i -> PrimalM i o (SBlock o, LinLamAbs o) -linearizeBlockDefunc = linearizeBlockDefuncGeneral emptyOutFrag +linearizeExprDefunc :: SExpr i -> PrimalM i o (SExpr o, LinLamAbs o) +linearizeExprDefunc = linearizeExprDefuncGeneral emptyOutFrag -linearizeBlockDefuncGeneral +linearizeExprDefuncGeneral :: ReconFunctor f - => ScopeFrag o' o -> SBlock i -> PrimalM i o (SBlock o, f SLam o') -linearizeBlockDefuncGeneral locals block = do + => ScopeFrag o' o -> SExpr i -> PrimalM i o (SExpr o, f SLam o') +linearizeExprDefuncGeneral locals expr = do Abs decls result <- buildScoped do - WithTangent primalResult tangentFun <- linearizeBlock block + WithTangent primalResult tangentFun <- linearizeExpr expr lam <- tangentFunAsLambda tangentFun return $ PairE primalResult lam - (block', recon) <- refreshAbs (Abs decls result) \decls' (PairE primal lam) -> do + (Abs decls' result', recon) <- refreshAbs (Abs decls result) \decls' (PairE primal lam) -> do (primal', recon) <- capture (locals >>> toScopeFrag decls') primal lam return (Abs decls' primal', recon) - return (block', recon) + block <- mkBlock (Abs decls' result') + return (block, recon) -- Inverse of tangentFunAsLambda. Should be used inside a returned tangent action. applyLinLam :: Emits o => SLam i -> SubstReaderT AtomSubstVal TangentM i o (Atom SimpIR o) applyLinLam (LamExpr bs body) = do TangentArgs args <- liftSubstReaderT $ getTangentArgs extendSubst (bs @@> ((Rename . atomVarName) <$> args)) do - substM body >>= emitBlock + substM body >>= emitExpr -- === actual linearization passs === @@ -295,7 +296,7 @@ linearizeTopLam (TopLam False _ (LamExpr bs body)) actives = do True -> return $ Just v False -> return $ Nothing (body', linLamAbs) <- extendActivePrimalss activeVs do - linearizeBlockDefuncGeneral emptyOutFrag body + linearizeExprDefuncGeneral emptyOutFrag body let primalFun = LamExpr bs' body' ObligateRecon ty (Abs bsRecon (LamExpr bsTangent tangentBody)) <- return linLamAbs tangentFun <- withFreshBinder "residuals" ty \bResidual -> do @@ -307,7 +308,7 @@ linearizeTopLam (TopLam False _ (LamExpr bs body)) actives = do ts <- getUnpacked $ Var $ sink $ binderVar bTangent let substFrag = bsRecon @@> map (SubstVal . sink) xs <.> bsTangent @@> map (SubstVal . sink) ts - emitBlock =<< applySubst substFrag tangentBody + emitExpr =<< applySubst substFrag tangentBody return $ LamExpr (bs' >>> BinaryNest bResidual bTangent) tangentBody' return (primalFun, tangentFun) (,) <$> asTopLam primalFun <*> asTopLam tangentFun @@ -318,7 +319,7 @@ linearizeLambdaApp :: Emits o => SLam i -> SAtom o -> PrimalM i o (SAtom o, SLam linearizeLambdaApp (UnaryLamExpr b body) x = do vp <- emit $ Atom x extendActiveSubst b vp do - WithTangent primalResult tangentAction <- linearizeBlock body + WithTangent primalResult tangentAction <- linearizeExpr body tanFun <- tangentFunAsLambda tangentAction return (primalResult, tanFun) linearizeLambdaApp _ _ = error "not implemented" @@ -337,10 +338,6 @@ linearizeAtom atom = case atom of RepValAtom _ -> emitZeroT where emitZeroT = withZeroT $ renameM atom -linearizeBlock :: Emits o => SBlock i -> LinM i o SAtom SAtom -linearizeBlock (Abs decls result) = - linearizeDecls decls $ linearizeAtom result - linearizeDecls :: Emits o => Nest SDecl i i' -> LinM i' o e1 e2 -> LinM i o e1 e2 linearizeDecls Empty cont = cont -- TODO: as an optimization, don't bother extending the tangent args if the @@ -366,6 +363,7 @@ linearizeDecls (Nest (Let b (DeclBinding ann expr)) rest) cont = do linearizeExpr :: Emits o => SExpr i -> LinM i o SAtom SAtom linearizeExpr expr = case expr of Atom x -> linearizeAtom x + Block _ (Abs decls result) -> linearizeDecls decls $ linearizeExpr result TopApp _ f xs -> do (xs', ts) <- unzip <$> forM xs \x -> do x' <- renameM x @@ -403,7 +401,7 @@ linearizeExpr expr = case expr of (alts', recons) <- unzip <$> buildCaseAlts e' \i b' -> do Abs b body <- return $ alts !! i extendSubst (b@>binderName b') do - (block, recon) <- linearizeBlockDefuncGeneral (toScopeFrag b') body + (block, recon) <- linearizeExprDefuncGeneral (toScopeFrag b') body return (Abs b' block, recon) let tys = recons <&> \(ObligateRecon t _) -> t alts'' <- forM (enumerate alts') \(i, alt) -> do @@ -445,8 +443,8 @@ linearizeOp op = case op of liftM Var $ emit $ PrimOp $ RefOp ref' $ MPut x' IndexRef _ i -> do zipLin (la ref) (pureLin i) `bindLin` \(PairE ref' i') -> - emitOp =<< mkIndexRef ref' i' - ProjRef _ i -> la ref `bindLin` \ref' -> emitOp =<< mkProjRef ref' i + emitExpr =<< mkIndexRef ref' i' + ProjRef _ i -> la ref `bindLin` \ref' -> emitExpr =<< mkProjRef ref' i UnOp uop x -> linearizeUnOp uop x BinOp bop x y -> linearizeBinOp bop x y -- XXX: This assumes that pointers are always constants @@ -462,7 +460,7 @@ linearizeMiscOp op = case op of SumTag _ -> emitZeroT ToEnum _ _ -> emitZeroT Select p t f -> (pureLin p `zipLin` la t `zipLin` la f) `bindLin` - \(p' `PairE` t' `PairE` f') -> emitOp $ MiscOp $ Select p' t' f' + \(p' `PairE` t' `PairE` f') -> emitExpr $ MiscOp $ Select p' t' f' CastOp t v -> do vt <- getType <$> renameM v t' <- renameM t @@ -471,14 +469,14 @@ linearizeMiscOp op = case op of ((&&) <$> (vtTangentType `alphaEq` vt) <*> (tTangentType `alphaEq` t')) >>= \case True -> do - linearizeAtom v `bindLin` \v' -> emitOp $ MiscOp $ CastOp (sink t') v' + linearizeAtom v `bindLin` \v' -> emitExpr $ MiscOp $ CastOp (sink t') v' False -> do WithTangent x xt <- linearizeAtom v yt <- case (vtTangentType, tTangentType) of (_ , UnitTy) -> return $ UnitVal (UnitTy, tt ) -> zeroAt tt _ -> error "Expected at least one side of the CastOp to have a trivial tangent type" - y <- emitOp $ MiscOp $ CastOp t' x + y <- emitExpr $ MiscOp $ CastOp t' x return $ WithTangent y do xt >> return (sink yt) BitcastOp _ _ -> notImplemented UnsafeCoerce _ _ -> notImplemented @@ -495,7 +493,7 @@ linearizeMiscOp op = case op of linearizeUnOp :: Emits o => UnOp -> Atom SimpIR i -> LinM i o SAtom SAtom linearizeUnOp op x' = do WithTangent x tx <- linearizeAtom x' - let emitZeroT = withZeroT $ emitOp $ UnOp op x + let emitZeroT = withZeroT $ emitExpr $ UnOp op x case op of Exp -> do y <- emitUnOp Exp x @@ -526,7 +524,7 @@ linearizeBinOp :: Emits o => BinOp -> SAtom i -> SAtom i -> LinM i o SAtom SAtom linearizeBinOp op x' y' = do WithTangent x tx <- linearizeAtom x' WithTangent y ty <- linearizeAtom y' - let emitZeroT = withZeroT $ emitOp $ BinOp op x y + let emitZeroT = withZeroT $ emitExpr $ BinOp op x y case op of IAdd -> emitZeroT ISub -> emitZeroT @@ -544,7 +542,7 @@ linearizeBinOp op x' y' = do ty' <- bindM2 div' (bindM2 mul (referToPrimal x) ty) (bindM2 mul (referToPrimal y) (referToPrimal y)) sub tx' ty' - FPow -> withT (emitOp $ BinOp FPow x y) do + FPow -> withT (emitExpr $ BinOp FPow x y) do px <- referToPrimal x py <- referToPrimal y c <- (1.0 `fLitLike` py) >>= (sub py) >>= fpow px @@ -590,7 +588,7 @@ linearizeHof hof = case hof of UnaryLamExpr ib body <- return lam ixTy <- renameM ixTy' (lam', Abs ib' linLam) <- withFreshBinder noHint (ixTypeType ixTy) \ib' -> do - (block', linLam) <- extendSubst (ib@>binderName ib') $ linearizeBlockDefunc body + (block', linLam) <- extendSubst (ib@>binderName ib') $ linearizeExprDefunc body return (UnaryLamExpr ib' block', Abs ib' linLam) primalsAux <- emitHof $ For d ixTy lam' case linLam of @@ -649,7 +647,7 @@ linearizeHof hof = case hof of withSubstReaderT $ applyLinLam $ sink linLam emitHof $ RunWriter Nothing bm'' tanEffLam RunIO body -> do - (body', recon) <- linearizeBlockDefunc body + (body', recon) <- linearizeExprDefunc body primalAux <- emitHof $ RunIO body' (primal, linLam) <- reconstruct primalAux recon return $ WithTangent primal do @@ -666,7 +664,7 @@ linearizeEffectFun rws (BinaryLamExpr hB refB body) = do (body', linLam) <- extendActiveSubst hB hVar $ extendActiveSubst refB ref $ -- TODO: maybe we should check whether we need to extend the active effects extendActiveEffs (RWSEffect rws (Var hVar)) do - linearizeBlockDefunc body + linearizeExprDefunc body -- TODO: this assumes that references aren't returned. Our type system -- ensures that such references can never be *used* once the effect runner -- returns, but technically it's legal to return them. diff --git a/src/lib/Lower.hs b/src/lib/Lower.hs index 96a1ca8e8..99fd67ef7 100644 --- a/src/lib/Lower.hs +++ b/src/lib/Lower.hs @@ -13,7 +13,6 @@ module Lower import Prelude hiding ((.)) import Data.Functor import Data.Maybe (fromJust) -import Data.List.NonEmpty qualified as NE import Control.Category import Control.Monad.Reader import Unsafe.Coerce @@ -59,45 +58,32 @@ import Util (enumerate) -- destination to a sub-block or sub-expression, hence "desintation -- passing style"). -type DestBlock = Abs (SBinder) SBlock +type DestBlock = Abs (SBinder) SExpr lowerFullySequential :: EnvReader m => Bool -> STopLam n -> m n (STopLam n) -lowerFullySequential wantDestStyle (TopLam False piTy (LamExpr bs body)) = liftEnvReaderM $ do - lam <- case wantDestStyle of - True -> do - refreshAbs (Abs bs body) \bs' body' -> do +lowerFullySequential wantDestStyle (TopLam False piTy (LamExpr bs body)) = liftEnvReaderM do + lam <- refreshAbs (Abs bs body) \bs' body' -> + liftAtomSubstBuilder case wantDestStyle of + True -> do xs <- bindersToAtoms bs' EffTy _ resultTy <- instantiate (sink piTy) xs - Abs b body'' <- lowerFullySequentialBlock resultTy body' - return $ LamExpr (bs' >>> UnaryNest b) body'' - False -> do - refreshAbs (Abs bs body) \bs' body' -> do - body'' <- lowerFullySequentialBlockNoDest body' - return $ LamExpr bs' body'' + let resultDestTy = RawRefTy resultTy + withFreshBinder "ans" resultDestTy \destBinder -> do + let dest = Var $ binderVar destBinder + LamExpr (bs' >>> UnaryNest destBinder) <$> buildBlock do + lowerExpr (Just (sink dest)) body' $> UnitVal + False -> LamExpr bs' <$> buildBlock (lowerExpr Nothing body') piTy' <- getLamExprType lam return $ TopLam wantDestStyle piTy' lam lowerFullySequential _ (TopLam True _ _) = error "already in destination style" -lowerFullySequentialBlock :: EnvReader m => SType n -> SBlock n -> m n (DestBlock n) -lowerFullySequentialBlock resultTy b = liftAtomSubstBuilder do - let resultDestTy = RawRefTy resultTy - withFreshBinder (getNameHint @String "ans") resultDestTy \destBinder -> do - Abs destBinder <$> buildBlock do - let dest = Var $ sink $ binderVar destBinder - lowerBlockWithDest dest b $> UnitVal -{-# SCC lowerFullySequentialBlock #-} - -lowerFullySequentialBlockNoDest :: EnvReader m => SBlock n -> m n (SBlock n) -lowerFullySequentialBlockNoDest b = liftAtomSubstBuilder $ buildBlock $ lowerBlock b -{-# SCC lowerFullySequentialBlockNoDest #-} - data LowerTag type LowerM = AtomSubstBuilder LowerTag SimpIR instance NonAtomRenamer (LowerM i o) i o where renameN = substM instance ExprVisitorEmits (LowerM i o) SimpIR i o where - visitExprEmits = lowerExpr + visitExprEmits = lowerExpr Nothing instance Visitor (LowerM i o) SimpIR i o where visitAtom = visitAtomDefault @@ -105,27 +91,13 @@ instance Visitor (LowerM i o) SimpIR i o where visitPi = visitPiDefault visitLam = visitLamEmits -lowerExpr :: Emits o => SExpr i -> LowerM i o (SAtom o) -lowerExpr expr = emitExpr =<< case expr of - TabCon Nothing ty els -> lowerTabCon Nothing ty els - PrimOp (Hof (TypedHof (EffTy _ resultTy) (For dir ixDict body))) -> do - resultTy' <- substM resultTy - lowerFor resultTy' Nothing dir ixDict body - -- this case is important because this pass changes effects - PrimOp (Hof (TypedHof _ hof)) -> - PrimOp . Hof <$> (visitGeneric hof >>= mkTypedHof) - Case e alts (EffTy _ ty) -> lowerCase Nothing e alts ty - _ -> visitGeneric expr - -lowerBlock :: Emits o => SBlock i -> LowerM i o (SAtom o) -lowerBlock = visitBlockEmits - -type Dest = Atom +type Dest = SAtom +type OptDest n = Maybe (Dest n) lowerFor :: Emits o - => SType o -> Maybe (Dest SimpIR o) -> ForAnn -> IxType SimpIR i -> LamExpr SimpIR i - -> LowerM i o (SExpr o) + => SType o -> OptDest o -> ForAnn -> IxType SimpIR i -> LamExpr SimpIR i + -> LowerM i o (SAtom o) lowerFor ansTy maybeDest dir ixTy (UnaryLamExpr (ib:>ty) body) = do ixTy' <- substM ixTy ty' <- substM ty @@ -133,30 +105,29 @@ lowerFor ansTy maybeDest dir ixTy (UnaryLamExpr (ib:>ty) body) = do True -> do body' <- buildUnaryLamExpr noHint (PairTy ty' UnitTy) \b' -> do (i, _) <- fromPair $ Var b' - extendSubst (ib @> SubstVal i) $ lowerBlock body $> UnitVal + extendSubst (ib @> SubstVal i) $ lowerExpr Nothing body $> UnitVal void $ emitSeq dir ixTy' UnitVal body' - Atom . fromJust <$> singletonTypeVal ansTy + fromJust <$> singletonTypeVal ansTy False -> do initDest <- ProdVal . (:[]) <$> case maybeDest of Just d -> return d - Nothing -> emitOp $ DAMOp $ AllocDest ansTy + Nothing -> emitExpr $ AllocDest ansTy let destTy = getType initDest body' <- buildUnaryLamExpr noHint (PairTy ty' destTy) \b' -> do (i, destProd) <- fromPair $ Var b' dest <- proj 0 destProd - idest <- emitOp =<< mkIndexRef dest i - extendSubst (ib @> SubstVal i) $ lowerBlockWithDest idest body $> UnitVal + idest <- emitExpr =<< mkIndexRef dest i + extendSubst (ib @> SubstVal i) $ lowerExpr (Just idest) body $> UnitVal ans <- emitSeq dir ixTy' initDest body' >>= proj 0 - return $ PrimOp $ DAMOp $ Freeze ans + emitExpr $ Freeze ans lowerFor _ _ _ _ _ = error "expected a unary lambda expression" -lowerTabCon :: forall i o. Emits o - => Maybe (Dest SimpIR o) -> SType i -> [SAtom i] -> LowerM i o (SExpr o) +lowerTabCon :: Emits o => OptDest o -> SType i -> [SAtom i] -> LowerM i o (SAtom o) lowerTabCon maybeDest tabTy elems = do TabPi tabTy' <- substM tabTy dest <- case maybeDest of Just d -> return d - Nothing -> emitExpr $ PrimOp $ DAMOp $ AllocDest $ TabPi tabTy' + Nothing -> emitExpr $ AllocDest $ TabPi tabTy' Abs bord ufoBlock <- buildAbs noHint IdxRepTy \ord -> do buildBlock $ unsafeFromOrdinal (sink $ tabIxType tabTy') $ Var $ sink ord -- This is emitting a chain of RememberDest ops to force `dest` to be used @@ -168,23 +139,23 @@ lowerTabCon maybeDest tabTy elems = do let go incoming_dest [] = return incoming_dest go incoming_dest ((ord, e):rest) = do i <- dropSubst $ extendSubst (bord@>SubstVal (IdxRepVal (fromIntegral ord))) $ - lowerBlock ufoBlock + lowerExpr Nothing ufoBlock carried_dest <- buildRememberDest "dest" incoming_dest \local_dest -> do idest <- indexRef (Var local_dest) (sink i) - place (FullDest idest) =<< visitAtom e + place idest =<< visitAtom e return UnitVal go carried_dest rest dest' <- go dest (enumerate elems) - return $ PrimOp $ DAMOp $ Freeze dest' + emitExpr $ Freeze dest' lowerCase :: Emits o - => Maybe (Dest SimpIR o) -> SAtom i -> [Alt SimpIR i] -> SType i - -> LowerM i o (SExpr o) + => OptDest o -> SAtom i -> [Alt SimpIR i] -> SType i + -> LowerM i o (SAtom o) lowerCase maybeDest scrut alts resultTy = do resultTy' <- substM resultTy dest <- case maybeDest of Just d -> return d - Nothing -> emitExpr $ PrimOp $ DAMOp $ AllocDest resultTy' + Nothing -> emitExpr $ AllocDest resultTy' scrut' <- visitAtom scrut dest' <- buildRememberDest "case_dest" dest \local_dest -> do alts' <- forM alts \(Abs (b:>ty) body) -> do @@ -192,10 +163,10 @@ lowerCase maybeDest scrut alts resultTy = do buildAbs (getNameHint b) ty' \b' -> extendSubst (b @> Rename (atomVarName b')) $ buildBlock do - lowerBlockWithDest (Var $ sink $ local_dest) body $> UnitVal + lowerExpr (Just (Var $ sink $ local_dest)) body $> UnitVal void $ mkCase (sink scrut') UnitTy alts' >>= emitExpr return UnitVal - return $ PrimOp $ DAMOp $ Freeze dest' + emitExpr $ Freeze dest' -- Destination-passing traversals -- @@ -217,17 +188,9 @@ lowerCase maybeDest scrut alts resultTy = do -- so that it never allocates scratch space for its result, but will put it directly in -- the corresponding slice of the full 2D buffer. -type DestAssignment (i'::S) (o::S) = NameMap (AtomNameC SimpIR) (ProjDest o) i' - -data ProjDest o - = FullDest (Dest SimpIR o) - | ProjDest (NE.NonEmpty Int) (Dest SimpIR o) -- dest corresponds to the projection applied to name - deriving (Show) +type DestAssignment (i'::S) (o::S) = NameMap (AtomNameC SimpIR) (Dest o) i' -instance SinkableE ProjDest where - sinkingProofE = todoSinkableProof - -lookupDest :: DestAssignment i' o -> SAtomName i' -> Maybe (ProjDest o) +lookupDest :: DestAssignment i' o -> SAtomName i' -> OptDest o lookupDest dests = fmap fromLiftE . flip lookupNameMapE dests -- Matches up the free variables of the atom, with the given dest. For example, if the @@ -237,42 +200,11 @@ lookupDest dests = fmap fromLiftE . flip lookupNameMapE dests -- as much as possible, but it can lead to unnecessary copies being done at run-time. -- -- XXX: When adding more cases, be careful about potentially repeated vars in the output! -decomposeDest :: Emits o => Dest SimpIR o -> SAtom i' -> LowerM i o (Maybe (DestAssignment i' o)) +decomposeDest :: Emits o => Dest o -> SExpr i' -> LowerM i o (Maybe (DestAssignment i' o)) decomposeDest dest = \case - Stuck (StuckVar v) -> - return $ Just $ singletonNameMapE (atomVarName v) $ LiftE $ FullDest dest - Stuck (StuckProject _ p x) -> do - (ps, v) <- return $ asNaryProj p x - return $ Just $ singletonNameMapE (atomVarName v) $ LiftE $ ProjDest ps dest + Atom (Stuck (StuckVar v)) -> + return $ Just $ singletonNameMapE (atomVarName v) $ LiftE dest _ -> return Nothing - where - asNaryProj :: IRRep r => Int -> Stuck r n -> (NE.NonEmpty Int, AtomVar r n) - asNaryProj p (StuckVar v) = (p NE.:| [], v) - asNaryProj p1 (StuckProject _ p2 x) = do - let (p2' NE.:| ps, v) = asNaryProj p2 x - (p1 NE.:| (p2':ps), v) - asNaryProj _ _ = error $ "Can't normalize projection" - -lowerBlockWithDest :: Emits o => Dest SimpIR o -> SBlock i -> LowerM i o (SAtom o) -lowerBlockWithDest dest (Abs decls ans) = do - decomposeDest dest ans >>= \case - Nothing -> do - ans' <- visitDeclsEmits decls $ visitAtom ans - place (FullDest dest) ans' - return ans' - Just destMap -> do - s <- getSubst - case isDistinctNest decls of - Nothing -> error "Non-distinct decls?" - Just DistinctBetween -> do - s' <- traverseDeclNestWithDestS destMap s decls - -- But we have to emit explicit writes, for all the vars that are not defined in decls! - forM_ (toListNameMapE $ hoistNameMap decls destMap) \(n, (LiftE d)) -> do - x <- case s ! n of - Rename v -> Var <$> toAtomVar v - SubstVal a -> return a - place d x - withSubst s' $ substM ans traverseDeclNestWithDestS :: forall i i' l o. (Emits o, DistinctBetween l i') @@ -280,93 +212,96 @@ traverseDeclNestWithDestS -> LowerM i o (Subst AtomSubstVal i' o) traverseDeclNestWithDestS destMap s = \case Empty -> return s - Nest (Let b (DeclBinding ann expr)) rest -> do + Nest (Let b (DeclBinding _ expr)) rest -> do DistinctBetween <- return $ withExtEvidence rest $ shortenBetween @i' b let maybeDest = lookupDest destMap $ sinkBetween $ binderName b - expr' <- withSubst s $ lowerExprWithDest maybeDest expr - v <- emitDecl (getNameHint b) ann expr' - traverseDeclNestWithDestS destMap (s <>> (b @> Rename (atomVarName v))) rest - -lowerExprWithDest :: forall i o. Emits o => Maybe (ProjDest o) -> SExpr i -> LowerM i o (SExpr o) -lowerExprWithDest dest expr = case expr of - TabCon Nothing ty els -> lowerTabCon tabDest ty els + result <- withSubst s $ lowerExpr maybeDest expr + traverseDeclNestWithDestS destMap (s <>> (b @> SubstVal result)) rest + +traverseDeclNest :: Emits o => Nest SDecl i i' -> LowerM i' o a -> LowerM i o a +traverseDeclNest decls cont = case decls of + Empty -> cont + Nest (Let b (DeclBinding _ expr)) rest -> do + x <- lowerExpr Nothing expr + extendSubst (b@>SubstVal x) $ traverseDeclNest rest cont + +lowerExpr :: forall i o. Emits o => OptDest o -> SExpr i -> LowerM i o (SAtom o) +lowerExpr dest expr = case expr of + Block _ (Abs decls result) -> case dest of + Nothing -> traverseDeclNest decls $ lowerExpr Nothing result + Just dest' -> do + decomposeDest dest' result >>= \case + Nothing -> do + traverseDeclNest decls do + lowerExpr (Just dest') result + Just destMap -> do + s <- getSubst + case isDistinctNest decls of + Nothing -> error "Non-distinct decls?" + Just DistinctBetween -> do + s' <- traverseDeclNestWithDestS destMap s decls + -- But we have to emit explicit writes, for all the vars that are not defined in decls! + forM_ (toListNameMapE $ hoistNameMap decls destMap) \(n, (LiftE d)) -> do + x <- case s ! n of + Rename v -> Var <$> toAtomVar v + SubstVal a -> return a + place d x + withSubst s' (substM result) >>= emitExpr + TabCon Nothing ty els -> lowerTabCon dest ty els PrimOp (Hof (TypedHof (EffTy _ ansTy) (For dir ixDict body))) -> do ansTy' <- substM ansTy - lowerFor ansTy' tabDest dir ixDict body + lowerFor ansTy' dest dir ixDict body PrimOp (Hof (TypedHof (EffTy _ ty) (RunWriter Nothing m body))) -> do PairTy _ ansTy <- visitType ty traverseRWS ansTy body \ref' body' -> do m' <- visitGeneric m - return $ RunWriter ref' m' body' + emitHof $ RunWriter ref' m' body' PrimOp (Hof (TypedHof (EffTy _ ty) (RunState Nothing s body))) -> do PairTy _ ansTy <- visitType ty traverseRWS ansTy body \ref' body' -> do s' <- visitAtom s - return $ RunState ref' s' body' + emitHof $ RunState ref' s' body' -- this case is important because this pass changes effects PrimOp (Hof (TypedHof _ hof)) -> do - hof' <- PrimOp . Hof <$> (visitGeneric hof >>= mkTypedHof) + hof' <- emitExpr =<< (visitGeneric hof >>= mkTypedHof) placeGeneric hof' - Case e alts (EffTy _ ty) -> case dest of - Nothing -> lowerCase Nothing e alts ty - Just (FullDest d) -> lowerCase (Just d) e alts ty - Just d -> do - ans <- lowerCase Nothing e alts ty >>= emitExprToAtom - place d ans - return $ Atom ans + Case e alts (EffTy _ ty) -> lowerCase dest e alts ty _ -> generic where - tabDest = dest <&> \case FullDest d -> d; ProjDest _ _ -> error "unexpected projection" - - generic = visitGeneric expr >>= placeGeneric + generic :: LowerM i o (SAtom o) + generic = visitGeneric expr >>= emitExpr >>= placeGeneric + placeGeneric :: SAtom o -> LowerM i o (SAtom o) placeGeneric e = do case dest of Nothing -> return e Just d -> do - ans <- Var <$> emit e - place d ans - return $ Atom ans + place d e + return e traverseRWS :: SType o -> LamExpr SimpIR i - -> (Maybe (Dest SimpIR o) -> LamExpr SimpIR o -> LowerM i o (Hof SimpIR o)) - -> LowerM i o (SExpr o) + -> (OptDest o -> LamExpr SimpIR o -> LowerM i o (SAtom o)) + -> LowerM i o (SAtom o) traverseRWS referentTy (LamExpr (BinaryNest hb rb) body) cont = do unpackRWSDest dest >>= \case Nothing -> generic Just (bodyDest, refDest) -> do - hof <- cont refDest =<< + cont refDest =<< buildEffLam (getNameHint rb) referentTy \hb' rb' -> extendRenamer (hb@>atomVarName hb' <.> rb@>atomVarName rb') do - case bodyDest of - Nothing -> lowerBlock body - Just bd -> lowerBlockWithDest (sink bd) body - PrimOp . Hof <$> mkTypedHof hof - + lowerExpr (sink <$> bodyDest) body traverseRWS _ _ _ = error "Expected a binary lambda expression" unpackRWSDest = \case Nothing -> return Nothing - Just d -> case d of - FullDest fd -> do - bd <- getProjRef (ProjectProduct 0) fd - rd <- getProjRef (ProjectProduct 1) fd - return $ Just (Just bd, Just rd) - ProjDest (0 NE.:| []) pd -> return $ Just (Just pd, Nothing) - ProjDest (1 NE.:| []) pd -> return $ Just (Nothing, Just pd) - ProjDest _ _ -> return Nothing - -place :: Emits o => ProjDest o -> SAtom o -> LowerM i o () -place pd x = case pd of - FullDest d -> void $ emitOp $ DAMOp $ Place d x - ProjDest p d -> do - x' <- applyProjs (NE.toList p) x - void $ emitOp $ DAMOp $ Place d x' - where - applyProjs :: Emits n => [Int] -> SAtom n -> LowerM i n (SAtom n) - applyProjs [] atom = return atom - applyProjs (p:ps) atom = proj p =<< applyProjs ps atom + Just d -> do + bd <- getProjRef (ProjectProduct 0) d + rd <- getProjRef (ProjectProduct 1) d + return $ Just (Just bd, Just rd) + +place :: Emits o => Dest o -> SAtom o -> LowerM i o () +place d x = void $ emitExpr $ Place d x -- === Extensions to the name system === diff --git a/src/lib/OccAnalysis.hs b/src/lib/OccAnalysis.hs index d80a161f0..59b7a4384 100644 --- a/src/lib/OccAnalysis.hs +++ b/src/lib/OccAnalysis.hs @@ -29,13 +29,9 @@ import QueryType -- unused pure bindings as it goes, since it has all the needed information. analyzeOccurrences :: EnvReader m => STopLam n -> m n (STopLam n) -analyzeOccurrences = liftLamExpr analyzeOccurrencesBlock +analyzeOccurrences lam = liftLamExpr lam \e -> liftOCCM $ occ accessOnce e {-# INLINE analyzeOccurrences #-} -analyzeOccurrencesBlock :: EnvReader m => SBlock n -> m n (SBlock n) -analyzeOccurrencesBlock = liftOCCM . occNest accessOnce -{-# SCC analyzeOccurrencesBlock #-} - -- === Overview === -- We analyze every binding in the program for occurrence information, @@ -272,7 +268,7 @@ occTy ty = occ accessOnce ty instance HasOCC SLam where occ a (LamExpr bs body) = do lam@(LamExpr bs' _) <- refreshAbs (Abs bs body) \bs' body' -> - LamExpr bs' <$> occNest (sink a) body' + LamExpr bs' <$> occ (sink a) body' countFreeVarsAsOccurrencesB bs' return lam @@ -294,11 +290,11 @@ instance HasOCC (EffTy SimpIR) where return $ EffTy effs ty' data ElimResult (n::S) where - ElimSuccess :: Abs (Nest SDecl) SAtom n -> ElimResult n - ElimFailure :: SDecl n l -> UsageInfo -> Abs (Nest SDecl) SAtom l -> ElimResult n + ElimSuccess :: Abs (Nest SDecl) SExpr n -> ElimResult n + ElimFailure :: SDecl n l -> UsageInfo -> Abs (Nest SDecl) SExpr l -> ElimResult n -occNest :: Access n -> Abs (Nest SDecl) SAtom n - -> OCCM n (Abs (Nest SDecl) SAtom n) +occNest :: Access n -> Abs (Nest SDecl) SExpr n + -> OCCM n (Abs (Nest SDecl) SExpr n) occNest a (Abs decls ans) = case decls of Empty -> Abs Empty <$> occ a ans Nest d@(Let _ binding) ds -> do @@ -358,6 +354,10 @@ instance HasOCC (DeclBinding SimpIR) where instance HasOCC SExpr where occ a = \case + Block effTy (Abs decls ans) -> do + effTy' <- occ a effTy + Abs decls' ans' <- occNest a (Abs decls ans) + return $ Block effTy' (Abs decls' ans') TabApp t array ixs -> do t' <- occTy t (a', ixs') <- go a ixs @@ -395,7 +395,7 @@ occAlt acc scrut alt = do -- case statement in that event. scrutIx <- unknown $ sink scrut extend nb scrutIx do - body' <- occNest (sink acc) body + body' <- occ (sink acc) body return $ Abs b body' ty' <- occTy ty return $ Abs (b':>ty') body' @@ -415,10 +415,10 @@ instance HasOCC (Hof SimpIR) where ixDict' <- inlinedLater ixDict occWithBinder (Abs b body) \b' body' -> do extend b' (Occ.Var $ binderName b') do - body'' <- censored (abstractFor b') (occNest accessOnce body') + body'' <- censored (abstractFor b') (occ accessOnce body') return $ For ann ixDict' (UnaryLamExpr b' body'') For _ _ _ -> error "For body should be a unary lambda expression" - While body -> While <$> censored useManyTimes (occNest accessOnce body) + While body -> While <$> censored useManyTimes (occ accessOnce body) RunReader ini bd -> do iniIx <- summary ini bd' <- oneShot a [Deterministic [], iniIx] bd @@ -455,7 +455,7 @@ instance HasOCC (Hof SimpIR) where return $ RunState Nothing ini' bd' RunState (Just _) _ _ -> error "Expecting to do occurrence analysis before destination passing." - RunIO bd -> RunIO <$> occNest a bd + RunIO bd -> RunIO <$> occ a bd RunInit _ -> -- Though this is probably not too hard to implement. Presumably -- the lambda is one-shot. @@ -463,7 +463,7 @@ instance HasOCC (Hof SimpIR) where oneShot :: Access n -> [IxExpr n] -> LamExpr SimpIR n -> OCCM n (LamExpr SimpIR n) oneShot acc [] (LamExpr Empty body) = - LamExpr Empty <$> occNest acc body + LamExpr Empty <$> occ acc body oneShot acc (ix:ixs) (LamExpr (Nest b bs) body) = do occWithBinder (Abs b (LamExpr bs body)) \b' restLam -> extend b' (sink ix) do diff --git a/src/lib/Optimize.hs b/src/lib/Optimize.hs index f8d2537fc..3c187d58c 100644 --- a/src/lib/Optimize.hs +++ b/src/lib/Optimize.hs @@ -208,12 +208,12 @@ peepholeExpr expr = case expr of -- === Loop unrolling === unrollLoops :: EnvReader m => STopLam n -> m n (STopLam n) -unrollLoops = liftLamExpr unrollLoopsBlock +unrollLoops lam = liftLamExpr lam unrollLoopsExpr {-# SCC unrollLoops #-} -unrollLoopsBlock :: EnvReader m => SBlock n -> m n (SBlock n) -unrollLoopsBlock b = liftM fst $ - liftBuilder $ runStateT1 (runSubstReaderT idSubst (runULM $ ulBlock b)) (ULS 0) +unrollLoopsExpr :: EnvReader m => SExpr n -> m n (SExpr n) +unrollLoopsExpr b = liftM fst $ + liftBuilder $ runStateT1 (runSubstReaderT idSubst (runULM $ buildBlock $ ulExpr b)) (ULS 0) newtype ULS n = ULS Int deriving Show newtype ULM i o a = ULM { runULM :: SubstReaderT AtomSubstVal (StateT1 ULS (BuilderM SimpIR)) i o a} @@ -236,12 +236,6 @@ instance Visitor (ULM i o) SimpIR i o where instance ExprVisitorEmits (ULM i o) SimpIR i o where visitExprEmits = ulExpr -ulBlock :: SBlock i -> ULM i o (SBlock o) -ulBlock b = buildBlock $ visitBlockEmits b - -emitSubstBlock :: Emits o => SBlock i -> ULM i o (SAtom o) -emitSubstBlock (Abs decls ans) = visitDeclsEmits decls $ visitAtom ans - -- TODO: Refine the cost accounting so that operations that will become -- constant-foldable after inlining don't count towards it. ulExpr :: Emits o => SExpr i -> ULM i o (SAtom o) @@ -255,7 +249,7 @@ ulExpr expr = case expr of True -> case body' of UnaryLamExpr b' block' -> do vals <- dropSubst $ forM (iota n) \i -> do - extendSubst (b' @> SubstVal (IdxRepVal i)) $ emitSubstBlock block' + extendSubst (b' @> SubstVal (IdxRepVal i)) $ ulExpr block' inc $ fromIntegral n -- To account for the TabCon we emit below getLamExprType body' >>= \case PiType (UnaryNest (tb:>_)) (EffTy _ valTy) -> do @@ -270,6 +264,7 @@ ulExpr expr = case expr of _ -> nothingSpecial -- Avoid unrolling loops with large table literals TabCon _ _ els -> inc (length els) >> nothingSpecial + Block _ (Abs decls body) -> visitDeclsEmits decls $ ulExpr body _ -> nothingSpecial where inc i = modify \(ULS n) -> ULS (n + i) @@ -301,12 +296,12 @@ instance Visitor (LICMM i o) SimpIR i o where instance ExprVisitorEmits (LICMM i o) SimpIR i o where visitExprEmits = licmExpr -hoistLoopInvariantBlock :: EnvReader m => SBlock n -> m n (SBlock n) -hoistLoopInvariantBlock body = liftLICMM $ buildBlock $ visitBlockEmits body -{-# SCC hoistLoopInvariantBlock #-} +hoistLoopInvariantExpr :: EnvReader m => SExpr n -> m n (SExpr n) +hoistLoopInvariantExpr body = liftLICMM $ buildBlock $ visitExprEmits body +{-# SCC hoistLoopInvariantExpr #-} hoistLoopInvariant :: EnvReader m => STopLam n -> m n (STopLam n) -hoistLoopInvariant = liftLamExpr hoistLoopInvariantBlock +hoistLoopInvariant lam = liftLamExpr lam hoistLoopInvariantExpr {-# INLINE hoistLoopInvariant #-} licmExpr :: Emits o => SExpr i -> LICMM i o (SAtom o) @@ -317,7 +312,7 @@ licmExpr = \case let numCarriesOriginal = length dests' Abs hdecls destsAndBody <- visitBinders (UnaryNest b) \(UnaryNest b') -> do -- First, traverse the block, to allow any Hofs inside it to hoist their own decls. - Abs decls ans <- buildBlock $ visitBlockEmits body + Abs decls ans <- buildScoped $ visitExprEmits body -- Now, we process the decls and decide which ones to hoist. liftEnvReaderM $ runSubstReaderT idSubst $ seqLICM REmpty mempty (asNameBinder b') REmpty decls ans @@ -334,19 +329,19 @@ licmExpr = \case (oldCarries, newCarries) <- splitAt numCarriesOriginal <$> getUnpackedReduced allCarries let oldLoopBinderVal = PairVal oldIx (ProdVal oldCarries) let s = extraDestBs @@> map SubstVal newCarries <.> lb @> SubstVal oldLoopBinderVal - block <- applySubst s bodyAbs + block <- mkBlock =<< applySubst s bodyAbs return $ UnaryLamExpr lb' block emitSeq dir ix' dests'' body' PrimOp (Hof (TypedHof _ (For dir ix (LamExpr (UnaryNest b) body)))) -> do ix' <- substM ix Abs hdecls destsAndBody <- visitBinders (UnaryNest b) \(UnaryNest b') -> do - Abs decls ans <- buildBlock $ visitBlockEmits body + Abs decls ans <- buildScoped $ visitExprEmits body liftEnvReaderM $ runSubstReaderT idSubst $ seqLICM REmpty mempty (asNameBinder b') REmpty decls ans PairE (ListE []) (Abs lnb bodyAbs) <- emitDecls $ Abs hdecls destsAndBody ixTy <- substM $ binderType b body' <- withFreshBinder noHint ixTy \i -> do - block <- applyRename (lnb@>binderName i) bodyAbs + block <- mkBlock =<< applyRename (lnb@>binderName i) bodyAbs return $ UnaryLamExpr i block emitHof $ For dir ix' body' expr -> visitGeneric expr >>= emitExpr @@ -401,12 +396,12 @@ newtype DCEM n a = DCEM { runDCEM :: StateT1 FV EnvReaderM n a } , MonadState (FV n), EnvExtender) dceTop :: EnvReader m => STopLam n -> m n (STopLam n) -dceTop = liftLamExpr dceBlock +dceTop lam = liftLamExpr lam dceExpr {-# INLINE dceTop #-} -dceBlock :: EnvReader m => SBlock n -> m n (SBlock n) -dceBlock b = liftEnvReaderM $ evalStateT1 (runDCEM $ dceBlock' b) mempty -{-# SCC dceBlock #-} +dceExpr :: EnvReader m => SExpr n -> m n (SExpr n) +dceExpr b = liftEnvReaderM $ evalStateT1 (runDCEM $ dce b) mempty +{-# SCC dceExpr #-} class HasDCE (e::E) where dce :: e n -> DCEM n (e n) @@ -434,7 +429,25 @@ instance HasDCE (PiType SimpIR) where dceBinders bs effTy \bs' effTy' -> PiType bs' <$> dce effTy' instance HasDCE (LamExpr SimpIR) where - dce (LamExpr bs e) = dceBinders bs e \bs' e' -> LamExpr bs' <$> dceBlock' e' + dce (LamExpr bs e) = dceBinders bs e \bs' e' -> LamExpr bs' <$> dce e' + +instance HasDCE (Expr SimpIR) where + dce = \case + Block effTy block -> do + -- The free vars accumulated in the state of DCEM should correspond to + -- the free vars of the Abs of the block answer, by the decls traversed + -- so far. dceNest takes care to uphold this invariant, but we temporarily + -- reset the state to an empty map, just so that names from the surrounding + -- block don't end up influencing elimination decisions here. Note that we + -- restore the state (and accumulate free vars of the DCE'd block into it) + -- right after dceNest. + effTy' <- dce effTy + old <- get + put mempty + block' <- dceBlock block + modify (<> old) + return $ Block effTy' block' + e -> visitGeneric e dceBinders :: (HoistableB b, BindsEnv b, RenameB b, RenameE e) @@ -447,21 +460,6 @@ dceBinders b e cont = do return ans {-# INLINE dceBinders #-} -dceBlock' :: SBlock n -> DCEM n (SBlock n) -dceBlock' (Abs decls ans) = do - -- The free vars accumulated in the state of DCEM should correspond to - -- the free vars of the Abs of the block answer, by the decls traversed - -- so far. dceNest takes care to uphold this invariant, but we temporarily - -- reset the state to an empty map, just so that names from the surrounding - -- block don't end up influencing elimination decisions here. Note that we - -- restore the state (and accumulate free vars of the DCE'd block into it) - -- right after dceNest. - old <- get - put mempty - block <- dceNest decls ans - modify (<> old) - return block - wrapWithCachedFVs :: HoistableE e => e n -> DCEM n (CachedFVs e n) wrapWithCachedFVs e = do FV fvs <- get @@ -482,11 +480,11 @@ hoistUsingCachedFVs :: (BindsNames b, HoistableE e) => hoistUsingCachedFVs b e = hoistViaCachedFVs b <$> wrapWithCachedFVs e data ElimResult n where - ElimSuccess :: Abs (Nest SDecl) SAtom n -> ElimResult n - ElimFailure :: SDecl n l -> Abs (Nest SDecl) SAtom l -> ElimResult n + ElimSuccess :: SBlock n -> ElimResult n + ElimFailure :: SDecl n l -> SBlock l -> ElimResult n -dceNest :: Nest SDecl n l -> SAtom l -> DCEM n (Abs (Nest SDecl) SAtom n) -dceNest decls ans = case decls of +dceBlock :: SBlock n -> DCEM n (SBlock n) +dceBlock (Abs decls ans) = case decls of Empty -> Abs Empty <$> dce ans Nest b@(Let _ decl) bs -> do -- Note that we only ever dce the abs below under this refreshAbs, @@ -494,7 +492,7 @@ dceNest decls ans = case decls of -- because refreshAbs of StateT1 triggers hoistState, which we -- implement by deleting the entries that can't hoist). dceAttempt <- refreshAbs (Abs b (Abs bs ans)) \b' (Abs bs' ans') -> do - below <- dceNest bs' ans' + below <- dceBlock $ Abs bs' ans' case isPure decl of False -> return $ ElimFailure b' below True -> do @@ -503,11 +501,10 @@ dceNest decls ans = case decls of HoistFailure _ -> ElimFailure b' below case dceAttempt of ElimSuccess below' -> return below' - ElimFailure (Let b' decl') (Abs bs'' ans'') -> do - decl'' <- dce decl' + ElimFailure (Let b' (DeclBinding ann expr)) (Abs bs'' ans'') -> do + expr' <- dce expr modify (<>FV (freeVarsB b')) - return $ Abs (Nest (Let b' decl'') bs'') ans'' + return $ Abs (Nest (Let b' (DeclBinding ann expr')) bs'') ans'' instance HasDCE (EffectRow SimpIR) -instance HasDCE (DeclBinding SimpIR) instance HasDCE (EffTy SimpIR) diff --git a/src/lib/PPrint.hs b/src/lib/PPrint.hs index 421d49a9c..7e8708892 100644 --- a/src/lib/PPrint.hs +++ b/src/lib/PPrint.hs @@ -130,11 +130,6 @@ pApp a = prettyPrec a AppPrec pArg :: PrettyPrec a => a -> Doc ann pArg a = prettyPrec a ArgPrec -instance IRRep r => Pretty (Block r n) where - pretty (Abs decls expr) = prettyBlock decls expr -instance IRRep r => PrettyPrec (Block r n) where - prettyPrec (Abs decls expr) = atPrec LowestPrec $ prettyBlock decls expr - prettyBlock :: (IRRep r, PrettyPrec (e l)) => Nest (Decl r) n l -> e l -> Doc ann prettyBlock Empty expr = group $ line <> pLowest expr prettyBlock decls expr = prettyLines decls' <> hardline <> pLowest expr @@ -163,6 +158,7 @@ instance IRRep r => Pretty (Expr r n) where pretty = prettyFromPrettyPrec instance IRRep r => PrettyPrec (Expr r n) where prettyPrec = \case Atom x -> prettyPrec x + Block _ (Abs decls body) -> atPrec AppPrec $ prettyBlock decls body App _ f xs -> atPrec AppPrec $ pApp f <+> spaced (toList xs) TopApp _ f xs -> atPrec AppPrec $ pApp f <+> spaced (toList xs) TabApp _ f xs -> atPrec AppPrec $ pApp f <> "." <> dotted (toList xs) diff --git a/src/lib/QueryType.hs b/src/lib/QueryType.hs index 98c1b65f7..424ec295e 100644 --- a/src/lib/QueryType.hs +++ b/src/lib/QueryType.hs @@ -48,20 +48,6 @@ caseAltsBinderTys ty = case ty of extendEffect :: IRRep r => Effect r n -> EffectRow r n -> EffectRow r n extendEffect eff (EffectRow effs t) = EffectRow (effs <> eSetSingleton eff) t -blockEffTy :: (EnvReader m, IRRep r) => Block r n -> m n (EffTy r n) -blockEffTy block = liftEnvReaderM $ refreshAbs block \decls result -> do - effs <- declsEffects decls mempty - return $ ignoreHoistFailure $ hoist decls $ EffTy effs $ getType result - where - declsEffects :: IRRep r => Nest (Decl r) n l -> EffectRow r l -> EnvReaderM l (EffectRow r l) - declsEffects Empty !acc = return acc - declsEffects n@(Nest (Let _ (DeclBinding _ expr)) rest) !acc = withExtEvidence n do - expr' <- sinkM expr - declsEffects rest $ acc <> getEffects expr' - -blockTy :: (EnvReader m, IRRep r) => Block r n -> m n (Type r n) -blockTy b = blockEffTy b <&> \(EffTy _ t) -> t - piTypeWithoutDest :: PiType SimpIR n -> PiType SimpIR n piTypeWithoutDest (PiType bsRefB _) = case popNest bsRefB of @@ -69,9 +55,6 @@ piTypeWithoutDest (PiType bsRefB _) = PiType bs $ EffTy Pure ansTy -- XXX: we ignore the effects here _ -> error "expected trailing dest binder" -blockEff :: (EnvReader m, IRRep r) => Block r n -> m n (EffectRow r n) -blockEff b = blockEffTy b <&> \(EffTy eff _) -> eff - typeOfTabApp :: (IRRep r, EnvReader m) => Type r n -> [Atom r n] -> m n (Type r n) typeOfTabApp t [] = return t typeOfTabApp (TabPi tabTy) (i:rest) = do @@ -144,22 +127,22 @@ typeOfHof = \case RunState _ _ f -> do (resultTy, stateTy) <- getTypeRWSAction f return $ PairTy resultTy stateTy - RunIO f -> blockTy f - RunInit f -> blockTy f + RunIO f -> return $ getType f + RunInit f -> return $ getType f CatchException ty _ -> return ty hofEffects :: (EnvReader m, IRRep r) => Hof r n -> m n (EffectRow r n) hofEffects = \case For _ _ f -> functionEffs f - While body -> blockEff body + While body -> return $ getEffects body Linearize _ _ -> return Pure -- Body has to be a pure function Transpose _ _ -> return Pure -- Body has to be a pure function RunReader _ f -> rwsFunEffects Reader f RunWriter d _ f -> maybeInit d <$> rwsFunEffects Writer f RunState d _ f -> maybeInit d <$> rwsFunEffects State f - RunIO f -> deleteEff IOEffect <$> blockEff f - RunInit f -> deleteEff InitEffect <$> blockEff f - CatchException _ f -> deleteEff ExceptionEffect <$> blockEff f + RunIO f -> return $ deleteEff IOEffect $ getEffects f + RunInit f -> return $ deleteEff InitEffect $ getEffects f + CatchException _ f -> return $ deleteEff ExceptionEffect $ getEffects f where maybeInit :: IRRep r => Maybe (Atom r i) -> (EffectRow r o -> EffectRow r o) maybeInit d = case d of Just _ -> (<>OneEffect InitEffect); Nothing -> id @@ -308,10 +291,8 @@ rwsFunEffects rws f = getLamExprType f >>= \case _ -> error "Expected a binary function type" getLamExprType :: (IRRep r, EnvReader m) => LamExpr r n -> m n (PiType r n) -getLamExprType (LamExpr bs body) = liftEnvReaderM $ - refreshAbs (Abs bs body) \bs' body' -> do - effTy <- blockEffTy body' - return $ PiType bs' effTy +getLamExprType (LamExpr bs body) = + return $ PiType bs $ EffTy (getEffects body) (getType body) getTypeRWSAction :: (IRRep r, EnvReader m) => LamExpr r n -> m n (Type r n, Type r n) getTypeRWSAction f = getLamExprType f >>= \case diff --git a/src/lib/QueryTypePure.hs b/src/lib/QueryTypePure.hs index 89eb158ca..99b4687b8 100644 --- a/src/lib/QueryTypePure.hs +++ b/src/lib/QueryTypePure.hs @@ -141,6 +141,7 @@ instance IRRep r => HasType r (Expr r) where TopApp (EffTy _ ty) _ _ -> ty TabApp t _ _ -> t Atom x -> getType x + Block (EffTy _ ty) _ -> ty TabCon _ ty _ -> ty PrimOp op -> getType op Case _ _ (EffTy _ resultTy) -> resultTy @@ -274,6 +275,7 @@ ixTyFromDict ixDict = flip IxType ixDict $ case ixDict of instance IRRep r => HasEffects (Expr r) r where getEffects = \case Atom _ -> Pure + Block (EffTy eff _) _ -> eff App (EffTy eff _) _ _ -> eff TopApp (EffTy eff _) _ _ -> eff TabApp _ _ _ -> Pure diff --git a/src/lib/RuntimePrint.hs b/src/lib/RuntimePrint.hs index 64e87ee04..593a3fc9c 100644 --- a/src/lib/RuntimePrint.hs +++ b/src/lib/RuntimePrint.hs @@ -27,13 +27,13 @@ newtype Printer (n::S) (a :: *) = Printer { runPrinter' :: ReaderT1 (Atom CoreIR , Fallible, ScopeReader, MonadFail, EnvExtender, CBuilder, ScopableBuilder CoreIR) type Print n = Printer n () -showAny :: EnvReader m => Atom CoreIR n -> m n (Block CoreIR n) +showAny :: EnvReader m => Atom CoreIR n -> m n (CExpr n) showAny x = liftPrinter $ showAnyRec (sink x) liftPrinter :: EnvReader m => (forall l. (DExt n l, Emits l) => Print l) - -> m n (CBlock n) + -> m n (CExpr n) liftPrinter cont = liftBuilder $ buildBlock $ withBuffer \buf -> runReaderT1 buf (runPrinter' cont) diff --git a/src/lib/Simplify.hs b/src/lib/Simplify.hs index 9139af597..09c966548 100644 --- a/src/lib/Simplify.hs +++ b/src/lib/Simplify.hs @@ -27,7 +27,7 @@ import IRVariants import Linearize import Name import Subst -import Optimize (peepholeOp) +import Optimize (peepholeExpr) import QueryType import RuntimePrint import Transpose @@ -241,13 +241,13 @@ deriving instance ScopableBuilder SimpIR (SimplifyM i) -- === Top-level API === data SimplifiedTopLam n = SimplifiedTopLam (STopLam n) (ReconstructAtom n) -data SimplifiedBlock n = SimplifiedBlock (SBlock n) (ReconstructAtom n) +data SimplifiedBlock n = SimplifiedBlock (SExpr n) (ReconstructAtom n) simplifyTopBlock :: (TopBuilder m, Mut n) => TopBlock CoreIR n -> m n (SimplifiedTopLam n) simplifyTopBlock (TopLam _ _ (LamExpr Empty body)) = do SimplifiedBlock block recon <- liftSimplifyM do - {-# SCC "Simplify" #-} buildSimplifiedBlock $ simplifyBlock body + {-# SCC "Simplify" #-} buildSimplifiedBlock $ simplifyExpr body topLam <- asTopLam $ LamExpr Empty block return $ SimplifiedTopLam topLam recon simplifyTopBlock _ = error "not a block (nullary lambda)" @@ -263,7 +263,7 @@ applyReconTop :: (EnvReader m, Fallible1 m) => ReconstructAtom n -> SAtom n -> m applyReconTop = applyRecon instance GenericE SimplifiedBlock where - type RepE SimplifiedBlock = PairE SBlock ReconstructAtom + type RepE SimplifiedBlock = PairE SExpr ReconstructAtom fromE (SimplifiedBlock block recon) = PairE block recon {-# INLINE fromE #-} toE (PairE block recon) = SimplifiedBlock block recon @@ -272,13 +272,6 @@ instance GenericE SimplifiedBlock where instance SinkableE SimplifiedBlock instance RenameE SimplifiedBlock instance HoistableE SimplifiedBlock -instance CheckableE SimpIR SimplifiedBlock where - checkE (SimplifiedBlock block recon) = do - block' <- renameM block - effTy <- blockEffTy block' -- TODO: store this in the simplified block instead - block'' <- dropSubst $ checkBlock effTy block' - recon' <- renameM recon -- TODO: CheckableE instance for the recon too - return $ SimplifiedBlock block'' recon' instance Pretty (SimplifiedBlock n) where pretty (SimplifiedBlock block recon) = @@ -310,22 +303,22 @@ simpDeclsSubst simpDeclsSubst !s = \case Empty -> return s Nest (Let b (DeclBinding _ expr)) rest -> do - let hint = (getNameHint b) - x <- withSubst s $ simplifyExpr hint expr + x <- withSubst s $ simplifyExpr expr simpDeclsSubst (s <>> (b@>SubstVal x)) rest -simplifyExpr :: Emits o => NameHint -> Expr CoreIR i -> SimplifyM i o (CAtom o) -simplifyExpr hint expr = confuseGHC >>= \_ -> case expr of +simplifyExpr :: Emits o => Expr CoreIR i -> SimplifyM i o (CAtom o) +simplifyExpr expr = confuseGHC >>= \_ -> case expr of + Block _ (Abs decls body) -> simplifyDecls decls $ simplifyExpr body App (EffTy _ ty) f xs -> do ty' <- substM ty xs' <- mapM simplifyAtom xs - simplifyApp hint ty' f xs' + simplifyApp ty' f xs' TabApp _ f xs -> do xs' <- mapM simplifyAtom xs f' <- simplifyAtom f simplifyTabApp f' xs' Atom x -> simplifyAtom x - PrimOp op -> simplifyOp hint op + PrimOp op -> simplifyOp op ApplyMethod (EffTy _ ty) dict i xs -> do ty' <- substM ty xs' <- mapM simplifyAtom xs @@ -341,7 +334,7 @@ simplifyExpr hint expr = confuseGHC >>= \_ -> case expr of resultTy' <- substM resultTy defuncCaseCore scrut' resultTy' \i x -> do Abs b body <- return $ alts !! i - extendSubst (b@>SubstVal x) $ simplifyBlock body + extendSubst (b@>SubstVal x) $ simplifyExpr body Project ty i x -> do ty' <- substM ty x' <- substM x @@ -362,17 +355,17 @@ simplifyRefOp op ref = case op of x' <- simplifyDataAtom x (cb', CoerceReconAbs) <- simplifyLam cb emitRefOp $ MExtend (BaseMonoid em' cb') x' - MGet -> emitOp $ RefOp ref MGet + MGet -> emitExpr $ RefOp ref MGet MPut x -> do x' <- simplifyDataAtom x emitRefOp $ MPut x' MAsk -> emitRefOp MAsk IndexRef _ x -> do x' <- simplifyDataAtom x - emitOp =<< mkIndexRef ref x' - ProjRef _ (ProjectProduct i) -> emitOp =<< mkProjRef ref (ProjectProduct i) + emitExpr =<< mkIndexRef ref x' + ProjRef _ (ProjectProduct i) -> emitExpr =<< mkProjRef ref (ProjectProduct i) ProjRef _ UnwrapNewtype -> return ref - where emitRefOp op' = emitOp $ RefOp ref op' + where emitRefOp op' = emitExpr $ RefOp ref op' defuncCaseCore :: Emits o => Atom CoreIR o -> Type CoreIR o @@ -449,18 +442,19 @@ simplifyAlt split ty cont = do (resultData, resultNonData) <- toSplit split result (newResult, reconAbs) <- telescopicCapture locals resultNonData return (Abs decls (PairVal resultData newResult), LamRecon reconAbs) - EffTy _ (PairTy _ nonDataType) <- blockEffTy body + body' <- mkBlock body + PairTy _ nonDataType <- return $ getType body' let nonDataType' = ignoreHoistFailure $ hoist b nonDataType - return (Abs b body, nonDataType', recon) + return (Abs b body', nonDataType', recon) simplifyApp :: forall i o. Emits o - => NameHint -> CType o -> CAtom i -> [CAtom o] -> SimplifyM i o (CAtom o) -simplifyApp hint resultTy f xs = case f of + => CType o -> CAtom i -> [CAtom o] -> SimplifyM i o (CAtom o) +simplifyApp resultTy f xs = case f of Lam (CoreLamExpr _ lam) -> fast lam _ -> slow =<< simplifyAtomAndInline f where fast :: LamExpr CoreIR i' -> SimplifyM i' o (CAtom o) - fast lam = withInstantiated lam xs \body -> simplifyBlock body + fast lam = withInstantiated lam xs \body -> simplifyExpr body slow :: CAtom o -> SimplifyM i o (CAtom o) slow = \case @@ -470,10 +464,10 @@ simplifyApp hint resultTy f xs = case f of Abs b body <- return $ alts !! i extendSubst (b@>SubstVal x) do xs' <- mapM sinkM xs - simplifyApp hint (sink resultTy) body xs' + simplifyApp (sink resultTy) body xs' SimpInCore (LiftSimpFun _ lam) -> do xs' <- mapM toDataAtomIgnoreRecon xs - result <- instantiate lam xs' >>= emitBlock + result <- instantiate lam xs' >>= emitExpr liftSimpAtom resultTy result Var v -> do lookupAtomName (atomVarName v) >>= \case @@ -680,7 +674,7 @@ simplifyLam (LamExpr bsTop body) = case bsTop of (LamExpr bs' body', Abs bsRecon recon) <- simplifyLam $ LamExpr bs body return (LamExpr (Nest (b':>tySimp) bs') body', Abs (Nest b' bsRecon) recon) Empty -> do - SimplifiedBlock body' recon <- buildSimplifiedBlock $ simplifyBlock body + SimplifiedBlock body' recon <- buildSimplifiedBlock $ simplifyExpr body return (LamExpr Empty body', Abs Empty recon) data SplitDataNonData n = SplitDataNonData @@ -731,19 +725,21 @@ buildSimplifiedBlock cont = do return $ RightE (dataResult `PairE` ansTy) case eitherResult of LeftE ans -> do - (block, recon) <- refreshAbs (Abs decls ans) \decls' ans' -> do + (blockAbs, recon) <- refreshAbs (Abs decls ans) \decls' ans' -> do (newResult, reconAbs) <- telescopicCapture (toScopeFrag decls') ans' return (Abs decls' newResult, LamRecon reconAbs) - return $ SimplifiedBlock block recon + block' <- mkBlock blockAbs + return $ SimplifiedBlock block' recon RightE (ans `PairE` ty) -> do let ty' = ignoreHoistFailure $ hoist (toScopeFrag decls) ty - return $ SimplifiedBlock (Abs decls ans) (CoerceRecon ty') + block <- mkBlock $ Abs decls ans + return $ SimplifiedBlock block (CoerceRecon ty') -simplifyOp :: Emits o => NameHint -> PrimOp CoreIR i -> SimplifyM i o (CAtom o) -simplifyOp hint op = case op of +simplifyOp :: Emits o => PrimOp CoreIR i -> SimplifyM i o (CAtom o) +simplifyOp op = case op of Hof (TypedHof (EffTy _ ty) hof) -> do ty' <- substM ty - simplifyHof hint ty' hof + simplifyHof ty' hof MemOp op' -> simplifyGenericOp op' VectorOp op' -> simplifyGenericOp op' RefOp ref eff -> do @@ -759,10 +755,10 @@ simplifyOp hint op = case op of IDiv -> idiv x y ICmp Less -> ilt x y ICmp Equal -> ieq x y - _ -> emitOp $ BinOp binop x y + _ -> emitExpr $ BinOp binop x y UnOp unOp x' -> do x <- simplifyDataAtom x' - liftResult =<< emitOp (UnOp unOp x) + liftResult =<< emitExpr (UnOp unOp x) MiscOp op' -> case op' of Select c' x' y' -> do c <- simplifyDataAtom c' @@ -771,7 +767,7 @@ simplifyOp hint op = case op of liftResult =<< select c x y ShowAny x' -> do x <- simplifyAtom x' - dropSubst $ showAny x >>= simplifyBlock + dropSubst $ showAny x >>= simplifyExpr _ -> simplifyGenericOp op' where liftResult x = do @@ -779,7 +775,7 @@ simplifyOp hint op = case op of liftSimpAtom ty x simplifyGenericOp - :: (GenericOp op, IsPrimOp op, HasType CoreIR (op CoreIR), Emits o, + :: (GenericOp op, ToExpr (op SimpIR) SimpIR, HasType CoreIR (op CoreIR), Emits o, OpConst op CoreIR ~ OpConst op SimpIR) => op CoreIR i -> SimplifyM i o (CAtom o) @@ -789,7 +785,7 @@ simplifyGenericOp op = do (substM >=> getRepType) (simplifyAtom >=> toDataAtomIgnoreRecon) (error "shouldn't have lambda left") - result <- liftEnvReaderM (peepholeOp $ toPrimOp op') >>= emitExprToAtom + result <- liftEnvReaderM (peepholeExpr $ toExpr op') >>= emitExprToAtom liftSimpAtom ty result {-# INLINE simplifyGenericOp #-} @@ -804,7 +800,7 @@ applyDictMethod resultTy d i methodArgs = case d of withInstantiated instanceDef instanceArgs' \(PairE _ body) -> do let InstanceBody _ methods = body let method = methods !! i - simplifyApp noHint resultTy method methodArgs + simplifyApp resultTy method methodArgs DictCon (IxFin _ n) -> applyIxFinMethod (toEnum i) n methodArgs d' -> error $ "Not a simplified dict: " ++ pprint d' where @@ -816,8 +812,8 @@ applyDictMethod resultTy d i methodArgs = case d of (UnsafeFromOrdinal, [ix]) -> return $ NewtypeCon (FinCon n) ix _ -> error "bad ix args" -simplifyHof :: Emits o => NameHint -> CType o -> Hof CoreIR i -> SimplifyM i o (CAtom o) -simplifyHof _hint resultTy = \case +simplifyHof :: Emits o => CType o -> Hof CoreIR i -> SimplifyM i o (CAtom o) +simplifyHof resultTy = \case For d ixTypeCore' lam -> do (lam', Abs (UnaryNest bIx) recon) <- simplifyLam lam ixTypeCore <- substM ixTypeCore' @@ -833,7 +829,7 @@ simplifyHof _hint resultTy = \case xs <- unpackTelescope bsClosure =<< tabApp (sink ans) (Var i') applySubst (bIx@>Rename (atomVarName i') <.> bsClosure @@> map SubstVal xs) reconResult While body -> do - SimplifiedBlock body' (CoerceRecon _) <- buildSimplifiedBlock $ simplifyBlock body + SimplifiedBlock body' (CoerceRecon _) <- buildSimplifiedBlock $ simplifyExpr body result <- emitHof $ While body' liftSimpAtom resultTy result RunReader r lam -> do @@ -865,11 +861,11 @@ simplifyHof _hint resultTy = \case return $ PairVal ans' sOut' RunState _ _ _ -> error "Shouldn't see a RunState with a dest in Simplify" RunIO body -> do - SimplifiedBlock body' recon <- buildSimplifiedBlock $ simplifyBlock body + SimplifiedBlock body' recon <- buildSimplifiedBlock $ simplifyExpr body ans <- emitHof $ RunIO body' applyRecon recon ans RunInit body -> do - SimplifiedBlock body' recon <- buildSimplifiedBlock $ simplifyBlock body + SimplifiedBlock body' recon <- buildSimplifiedBlock $ simplifyExpr body ans <- emitHof $ RunInit body' applyRecon recon ans Linearize lam x -> do @@ -889,11 +885,10 @@ simplifyHof _hint resultTy = \case result <- transpose lam' x' liftSimpAtom resultTy result CatchException _ body-> do - SimplifiedBlock body' recon <- buildSimplifiedBlock $ simplifyBlock body - simplifiedResultTy <- blockTy body' + SimplifiedBlock body' recon <- buildSimplifiedBlock $ simplifyExpr body block <- liftBuilder $ runSubstReaderT idSubst $ buildBlock $ - exceptToMaybeBlock (sink simplifiedResultTy) body' - result <- emitBlock block + exceptToMaybeExpr body' + result <- emitExpr block case recon of CoerceRecon ty -> do maybeTy <- makePreludeMaybeTy ty @@ -937,9 +932,6 @@ preludeMaybeNewtypeCon ty = do let params = TyConParams [Explicit] [Type ty] return $ UserADTData sn tyConName params -simplifyBlock :: Emits o => Block CoreIR i -> SimplifyM i o (CAtom o) -simplifyBlock (Abs decls result) = simplifyDecls decls $ simplifyAtom result - liftSimpAtom :: EnvReader m => Type CoreIR n -> SAtom n -> m n (CAtom n) liftSimpAtom ty simpAtom = case simpAtom of Stuck _ -> justLift @@ -1028,7 +1020,7 @@ simplifyCustomLinearization (Abs runtimeBs staticArgs) actives rule = do ListE staticArgs' <- instantiate (sink $ Abs runtimeBs staticArgs) (sink <$> runtimeArgs) fCustom' <- sinkM fCustom resultTy <- typeOfApp (getType fCustom') staticArgs' - pairResult <- dropSubst $ simplifyApp noHint resultTy fCustom' staticArgs' + pairResult <- dropSubst $ simplifyApp resultTy fCustom' staticArgs' (primalResult, fLin) <- fromPairReduced pairResult primalResult' <- toDataAtomIgnoreRecon primalResult let explicitPrimalArgs = drop nImplicit staticArgs' @@ -1051,7 +1043,7 @@ simplifyCustomLinearization (Abs runtimeBs staticArgs) actives rule = do let tangentCoreTys = fromNonDepNest bs tangentArgs' <- zipWithM liftSimpAtom tangentCoreTys tangentArgs resultTyTangent <- typeOfApp (getType fLin') tangentArgs' - tangentResult <- dropSubst $ simplifyApp noHint resultTyTangent fLin' tangentArgs' + tangentResult <- dropSubst $ simplifyApp resultTyTangent fLin' tangentArgs' toDataAtomIgnoreRecon tangentResult return $ PairE primalResult' fLin' PairE primalFun tangentFun <- defuncLinearized linearized @@ -1093,10 +1085,10 @@ defuncLinearized ab = liftBuilder $ refreshAbs ab \bs ab' -> do return $ Abs (Nest rB tBs') UnitE residualsTangentsBs' <- return $ ignoreHoistFailure $ hoist decls residualsTangentsBs return (Abs decls (PairVal primalResult residuals), reconAbs, residualsTangentsBs') - let primalFun = LamExpr bs declsAndResult + primalFun <- LamExpr bs <$> mkBlock declsAndResult LamExpr residualAndTangentBs tangentBody <- buildLamExpr residualsTangentsBs \(residuals:tangents) -> do LamExpr tangentBs' body <- applyReconAbs (sink reconAbs) (Var residuals) - applyRename (tangentBs' @@> (atomVarName <$> tangents)) body >>= emitBlock + applyRename (tangentBs' @@> (atomVarName <$> tangents)) body >>= emitExpr let tangentFun = LamExpr (bs >>> residualAndTangentBs) tangentBody return $ PairE primalFun tangentFun @@ -1106,7 +1098,7 @@ type HandlerM = SubstReaderT AtomSubstVal (BuilderM SimpIR) exceptToMaybeBlock :: Emits o => SType o -> SBlock i -> HandlerM i o (SAtom o) exceptToMaybeBlock ty (Abs Empty result) = do - result' <- substM result + result' <- exceptToMaybeExpr result return $ JustAtom ty result' exceptToMaybeBlock resultTy (Abs (Nest (Let b (DeclBinding _ rhs)) decls) finalResult) = do maybeResult <- exceptToMaybeExpr rhs @@ -1121,14 +1113,16 @@ exceptToMaybeBlock resultTy (Abs (Nest (Let b (DeclBinding _ rhs)) decls) finalR exceptToMaybeExpr :: Emits o => SExpr i -> HandlerM i o (SAtom o) exceptToMaybeExpr expr = case expr of + Block (EffTy _ ty) body -> do + ty' <- substM ty + exceptToMaybeBlock ty' body Case e alts (EffTy _ resultTy) -> do e' <- substM e resultTy' <- substM $ MaybeTy resultTy buildCase e' resultTy' \i v -> do Abs b body <- return $ alts !! i extendSubst (b @> SubstVal v) do - blockResultTy <- blockTy =<< substM body -- TODO: avoid this by caching the type - exceptToMaybeBlock blockResultTy body + exceptToMaybeExpr body Atom x -> do x' <- substM x let ty = getType x' @@ -1136,9 +1130,7 @@ exceptToMaybeExpr expr = case expr of PrimOp (Hof (TypedHof _ (For ann ixTy' (UnaryLamExpr b body)))) -> do ixTy <- substM ixTy' maybes <- buildForAnn (getNameHint b) ann ixTy \i -> do - extendSubst (b@>Rename (atomVarName i)) do - blockResultTy <- blockTy =<< substM body -- TODO: avoid this by caching the type - exceptToMaybeBlock blockResultTy body + extendSubst (b@>Rename (atomVarName i)) $ exceptToMaybeExpr body catMaybesE maybes PrimOp (MiscOp (ThrowException _)) -> do ty <- substM $ getType expr @@ -1148,8 +1140,7 @@ exceptToMaybeExpr expr = case expr of BinaryLamExpr h ref body <- return lam result <- emitRunState noHint s' \h' ref' -> extendSubst (h @> Rename (atomVarName h') <.> ref @> Rename (atomVarName ref')) do - blockResultTy <- blockTy =<< substM body -- TODO: avoid this by caching the type - exceptToMaybeBlock blockResultTy body + exceptToMaybeExpr body (maybeAns, newState) <- fromPair result a <- substM $ getType expr emitMaybeCase maybeAns (MaybeTy a) @@ -1160,16 +1151,13 @@ exceptToMaybeExpr expr = case expr of PairTy _ accumTy <- substM resultTy result <- emitRunWriter noHint accumTy monoid' \h' ref' -> extendSubst (h @> Rename (atomVarName h') <.> ref @> Rename (atomVarName ref')) do - blockResultTy <- blockTy =<< substM body -- TODO: avoid this by caching the type - exceptToMaybeBlock blockResultTy body + exceptToMaybeExpr body (maybeAns, accumResult) <- fromPair result a <- substM $ getType expr emitMaybeCase maybeAns (MaybeTy a) (return $ NothingAtom $ sink a) (\ans -> return $ JustAtom (sink a) $ PairVal ans (sink accumResult)) - PrimOp (Hof (TypedHof _ (While body))) -> do - blockResultTy <- blockTy =<< substM body -- TODO: avoid this by caching the type - runMaybeWhile $ exceptToMaybeBlock (sink blockResultTy) body + PrimOp (Hof (TypedHof _ (While body))) -> runMaybeWhile $ exceptToMaybeExpr body _ -> do expr' <- substM expr case hasExceptions expr' of diff --git a/src/lib/TopLevel.hs b/src/lib/TopLevel.hs index 6601e57cf..87f4c4a19 100644 --- a/src/lib/TopLevel.hs +++ b/src/lib/TopLevel.hs @@ -536,7 +536,7 @@ evalBlock typed = do SimplifiedTopLam simp recon <- checkPass SimpPass $ simplifyTopBlock typed opt <- simpOptimizations simp simpResult <- case opt of - TopLam _ _ (LamExpr Empty (WithoutDecls result)) -> return result + TopLam _ _ (LamExpr Empty (Atom result)) -> return result _ -> do lowered <- checkPass LowerPass $ lowerFullySequential True opt lOpt <- checkPass OptPass $ loweredOptimizations lowered diff --git a/src/lib/Transpose.hs b/src/lib/Transpose.hs index 908dcf4cb..0cde8255b 100644 --- a/src/lib/Transpose.hs +++ b/src/lib/Transpose.hs @@ -33,7 +33,7 @@ transpose lam ct = liftEmitBuilder $ runTransposeM do UnaryLamExpr b body <- sinkM lam withAccumulator (binderType b) \refSubstVal -> extendSubst (b @> refSubstVal) $ - transposeBlock body (sink ct) + transposeExpr body (sink ct) {-# SCC transpose #-} runTransposeM :: TransposeM n n a -> BuilderM SimpIR n a @@ -52,16 +52,15 @@ transposeTopFun (TopLam False _ lam) = liftBuilder $ runTransposeM do inTy <- substNonlin $ binderType bLin withAccumulator inTy \refSubstVal -> extendSubst (bLin @> refSubstVal) $ - transposeBlock body (sink ct) - EffTy _ bodyTy <- blockEffTy body' - let piTy = PiType (bsNonlin' >>> UnaryNest bCT) (EffTy Pure bodyTy) + transposeExpr body (sink ct) + let piTy = PiType (bsNonlin' >>> UnaryNest bCT) (EffTy Pure (getType body')) let lamT = LamExpr (bsNonlin' >>> UnaryNest bCT) body' return $ TopLam False piTy lamT transposeTopFun (TopLam True _ _) = error "shouldn't be transposing in destination passing style" unpackLinearLamExpr :: (MonadFail1 m, EnvReader m) => LamExpr SimpIR n - -> m n ( Abs (Nest SBinder) (Abs SBinder SBlock) n + -> m n ( Abs (Nest SBinder) (Abs SBinder SExpr) n , Abs (Nest SBinder) SType n) unpackLinearLamExpr lam@(LamExpr bs body) = do let numNonlin = nestLength bs - 1 @@ -146,7 +145,7 @@ withAccumulator ty cont = do emitCTToRef :: (Emits n, Builder SimpIR m) => SAtom n -> SAtom n -> m n () emitCTToRef ref ct = do baseMonoid <- tangentBaseMonoidFor (getType ct) - void $ emitOp $ RefOp ref $ MExtend baseMonoid ct + void $ emitExpr $ RefOp ref $ MExtend baseMonoid ct getLinRegions :: TransposeM i o [SAtomVar o] getLinRegions = asks fromListE @@ -156,11 +155,8 @@ extendLinRegions v cont = local (\(ListE vs) -> ListE (v:vs)) cont -- === actual pass === -transposeBlock :: Emits o => SBlock i -> SAtom o -> TransposeM i o () -transposeBlock (Abs decls result) ct = transposeWithDecls decls result ct - -transposeWithDecls :: Emits o => Nest SDecl i i' -> SAtom i' -> SAtom o -> TransposeM i o () -transposeWithDecls Empty atom ct = transposeAtom atom ct +transposeWithDecls :: Emits o => Nest SDecl i i' -> SExpr i' -> SAtom o -> TransposeM i o () +transposeWithDecls Empty atom ct = transposeExpr atom ct transposeWithDecls (Nest (Let b (DeclBinding _ expr)) rest) result ct = substExprIfNonlin expr >>= \case Nothing -> do @@ -196,6 +192,7 @@ getTransposedTopFun f = do transposeExpr :: Emits o => SExpr i -> SAtom o -> TransposeM i o () transposeExpr expr ct = case expr of + Block _ (Abs decls result) -> transposeWithDecls decls result ct Atom atom -> transposeAtom atom ct TopApp _ f xs -> do Just fT <- getTransposedTopFun =<< substNonlin f @@ -237,7 +234,7 @@ transposeExpr expr ct = case expr of v' <- emit (Atom v) Abs b body <- return $ alts !! i extendSubst (b @> RenameNonlin (atomVarName v')) do - transposeBlock body (sink ct) + transposeExpr body (sink ct) return UnitVal TabCon _ ty es -> do TabTy d b _ <- return ty @@ -252,7 +249,7 @@ transposeOp op ct = case op of DAMOp _ -> error "unreachable" -- TODO: rule out statically RefOp refArg m -> do refArg' <- substNonlin refArg - let emitEff = emitOp . RefOp refArg' + let emitEff = emitExpr . RefOp refArg' case m of MAsk -> do baseMonoid <- tangentBaseMonoidFor (getType ct) @@ -335,14 +332,14 @@ transposeHof hof ct = case hof of ixTy <- substNonlin ixTy' void $ buildForAnn (getNameHint b) (flipDir ann) ixTy \i -> do ctElt <- tabApp (sink ct) (Var i) - extendSubst (b@>RenameNonlin (atomVarName i)) $ transposeBlock body ctElt + extendSubst (b@>RenameNonlin (atomVarName i)) $ transposeExpr body ctElt return UnitVal RunState Nothing s (BinaryLamExpr hB refB body) -> do (ctBody, ctState) <- fromPair ct (_, cts) <- (fromPair =<<) $ emitRunState noHint ctState \h ref -> do extendSubst (hB@>RenameNonlin (atomVarName h)) $ extendSubst (refB@>RenameNonlin (atomVarName ref)) $ extendLinRegions h $ - transposeBlock body (sink ctBody) + transposeExpr body (sink ctBody) return UnitVal transposeAtom s cts RunReader r (BinaryLamExpr hB refB body) -> do @@ -351,7 +348,7 @@ transposeHof hof ct = case hof of (_, ct') <- (fromPair =<<) $ emitRunWriter noHint accumTy baseMonoid \h ref -> do extendSubst (hB@>RenameNonlin (atomVarName h)) $ extendSubst (refB@>RenameNonlin (atomVarName ref)) $ extendLinRegions h $ - transposeBlock body (sink ct) + transposeExpr body (sink ct) return UnitVal transposeAtom r ct' RunWriter Nothing _ (BinaryLamExpr hB refB body)-> do @@ -360,7 +357,7 @@ transposeHof hof ct = case hof of void $ emitRunReader noHint ctEff \h ref -> do extendSubst (hB@>RenameNonlin (atomVarName h)) $ extendSubst (refB@>RenameNonlin (atomVarName ref)) $ extendLinRegions h $ - transposeBlock body (sink ctBody) + transposeExpr body (sink ctBody) return UnitVal _ -> notImplemented diff --git a/src/lib/Types/Core.hs b/src/lib/Types/Core.hs index 77045321a..2ef17f6d9 100644 --- a/src/lib/Types/Core.hs +++ b/src/lib/Types/Core.hs @@ -106,7 +106,8 @@ deriving via WrapE (Type r) n instance IRRep r => Generic (Type r n) deriving via WrapE (Stuck r) n instance IRRep r => Generic (Stuck r n) data Expr r n where - TopApp :: EffTy SimpIR n -> TopFunName n -> [SAtom n] -> Expr SimpIR n + Block :: EffTy r n -> Block r n -> Expr r n + TopApp :: EffTy SimpIR n -> TopFunName n -> [SAtom n] -> Expr SimpIR n TabApp :: Type r n -> Atom r n -> [Atom r n] -> Expr r n Case :: Atom r n -> [Alt r n] -> EffTy r n -> Expr r n Atom :: Atom r n -> Expr r n @@ -152,7 +153,7 @@ type FunObjCodeName = Name FunObjCodeNameC type AtomBinderP (r::IR) = BinderP (AtomNameC r) type Binder r = AtomBinderP r (Type r) :: B -type Alt r = Abs (Binder r) (Block r) :: E +type Alt r = Abs (Binder r) (Expr r) :: E newtype DotMethods n = DotMethods (M.Map SourceName (CAtomName n)) deriving (Show, Generic, Monoid, Semigroup) @@ -186,7 +187,7 @@ data TyConParams n = TyConParams [Explicitness] [Atom CoreIR n] deriving (Show, Generic) type WithDecls (r::IR) = Abs (Decls r) :: E -> E -type Block (r::IR) = WithDecls r (Atom r) :: E +type Block (r::IR) = WithDecls r (Expr r) :: E type TopBlock = TopLam -- used for nullary lambda type IsDestLam = Bool @@ -194,7 +195,7 @@ data TopLam (r::IR) (n::S) = TopLam IsDestLam (PiType r n) (LamExpr r n) deriving (Show, Generic) data LamExpr (r::IR) (n::S) where - LamExpr :: Nest (Binder r) n l -> Block r l -> LamExpr r n + LamExpr :: Nest (Binder r) n l -> Expr r l -> LamExpr r n data CoreLamExpr (n::S) = CoreLamExpr (CorePiType n) (LamExpr CoreIR n) @@ -250,7 +251,7 @@ class ToBindersAbs (e::E) (body::E) (r::IR) | e -> body, e -> r where instance ToBindersAbs CorePiType (EffTy CoreIR) CoreIR where toAbs (CorePiType _ _ bs effTy) = Abs bs effTy -instance ToBindersAbs CoreLamExpr (Block CoreIR) CoreIR where +instance ToBindersAbs CoreLamExpr (Expr CoreIR) CoreIR where toAbs (CoreLamExpr _ lam) = toAbs lam instance ToBindersAbs (Abs (Nest (Binder r)) body) body r where @@ -259,7 +260,7 @@ instance ToBindersAbs (Abs (Nest (Binder r)) body) body r where instance ToBindersAbs (PiType r) (EffTy r) r where toAbs (PiType bs effTy) = Abs bs effTy -instance ToBindersAbs (LamExpr r) (Block r) r where +instance ToBindersAbs (LamExpr r) (Expr r) r where toAbs (LamExpr bs body) = Abs bs body instance ToBindersAbs (TabPiType r) (Type r) r where @@ -277,17 +278,11 @@ instance ToBindersAbs TyConDef DataConDefs CoreIR where instance ToBindersAbs ClassDef (Abs (Nest CBinder) (ListE CorePiType)) CoreIR where toAbs (ClassDef _ _ _ _ bs scBs tys) = Abs bs (Abs scBs (ListE tys)) -instance ToBindersAbs (TopLam r) (Block r) r where +instance ToBindersAbs (TopLam r) (Expr r) r where toAbs (TopLam _ _ lam) = toAbs lam -- === GenericOp class === -class IsPrimOp (e::IR->E) where - toPrimOp :: e r n -> PrimOp r n - -instance IsPrimOp PrimOp where - toPrimOp x = x - class GenericOp (e::IR->E) where type OpConst e (r::IR) :: * fromOp :: e r n -> GenericOpRep (OpConst e r) r n @@ -407,13 +402,13 @@ data TypedHof r n = TypedHof (EffTy r n) (Hof r n) data Hof r n where For :: ForAnn -> IxType r n -> LamExpr r n -> Hof r n - While :: Block r n -> Hof r n + While :: Expr r n -> Hof r n RunReader :: Atom r n -> LamExpr r n -> Hof r n RunWriter :: Maybe (Atom r n) -> BaseMonoid r n -> LamExpr r n -> Hof r n RunState :: Maybe (Atom r n) -> Atom r n -> LamExpr r n -> Hof r n -- dest, initial value, body lambda - RunIO :: Block r n -> Hof r n - RunInit :: Block r n -> Hof r n - CatchException :: CType n -> Block CoreIR n -> Hof CoreIR n + RunIO :: Expr r n -> Hof r n + RunInit :: Expr r n -> Hof r n + CatchException :: CType n -> Expr CoreIR n -> Hof CoreIR n Linearize :: LamExpr CoreIR n -> Atom CoreIR n -> Hof CoreIR n Transpose :: LamExpr CoreIR n -> Atom CoreIR n -> Hof CoreIR n @@ -992,6 +987,55 @@ instance (ToBinding e1 c, ToBinding e2 c) => ToBinding (EitherE e1 e2) c where toBinding (LeftE e) = toBinding e toBinding (RightE e) = toBinding e +-- === ToAtom === + +class ToAtom (e::E) (r::IR) | e -> r where + toAtom :: e n -> Atom r n + +instance ToAtom (Atom r) r where + toAtom = id + +instance ToAtom (AtomVar r) r where + toAtom = Var + +instance ToAtom (Con r) r where + toAtom = Con + +instance ToAtom (Type CoreIR) CoreIR where + toAtom = TypeAsAtom + +-- === ToExpr === + +class ToExpr (e::E) (r::IR) | e -> r where + toExpr :: e n -> Expr r n + +instance ToExpr (Expr r) r where + toExpr = id + +instance ToExpr (Atom r) r where + toExpr = Atom + +instance ToExpr (AtomVar r) r where + toExpr = toExpr . toAtom + +instance ToExpr (PrimOp r) r where + toExpr = PrimOp + +instance ToExpr (MiscOp r) r where + toExpr = PrimOp . MiscOp + +instance ToExpr (MemOp r) r where + toExpr = PrimOp . MemOp + +instance ToExpr (VectorOp r) r where + toExpr = PrimOp . VectorOp + +instance ToExpr (TypedHof r) r where + toExpr = PrimOp . Hof + +instance ToExpr (DAMOp SimpIR) SimpIR where + toExpr = PrimOp . DAMOp + -- === Pattern synonyms === -- XXX: only use this pattern when you're actually expecting a type. If it's @@ -1092,24 +1136,15 @@ pattern EffKind = NewtypeTyCon EffectRowKind pattern FinConst :: Word32 -> Type CoreIR n pattern FinConst n = NewtypeTyCon (Fin (NatVal n)) -pattern NullaryLamExpr :: Block r n -> LamExpr r n +pattern NullaryLamExpr :: Expr r n -> LamExpr r n pattern NullaryLamExpr body = LamExpr Empty body -pattern UnaryLamExpr :: Binder r n l -> Block r l -> LamExpr r n +pattern UnaryLamExpr :: Binder r n l -> Expr r l -> LamExpr r n pattern UnaryLamExpr b body = LamExpr (UnaryNest b) body -pattern BinaryLamExpr :: Binder r n l1 -> Binder r l1 l2 -> Block r l2 -> LamExpr r n +pattern BinaryLamExpr :: Binder r n l1 -> Binder r l1 l2 -> Expr r l2 -> LamExpr r n pattern BinaryLamExpr b1 b2 body = LamExpr (BinaryNest b1 b2) body -pattern WithoutDecls :: e n -> WithDecls r e n -pattern WithoutDecls x = Abs Empty x - -exprBlock :: IRRep r => Block r n -> Maybe (Expr r n) -exprBlock (Abs (Nest (Let b (DeclBinding _ expr)) Empty) (Var (AtomVar n _))) - | n == binderName b = Just expr -exprBlock _ = Nothing -{-# INLINE exprBlock #-} - pattern MaybeTy :: Type r n -> Type r n pattern MaybeTy a = SumTy [UnitTy, a] @@ -1331,7 +1366,6 @@ instance IRRep r => RenameE (DAMOp r) instance IRRep r => AlphaEqE (DAMOp r) instance IRRep r => AlphaHashableE (DAMOp r) -instance IsPrimOp TypedHof where toPrimOp = Hof instance IRRep r => GenericE (TypedHof r) where type RepE (TypedHof r) = EffTy r `PairE` Hof r fromE (TypedHof effTy hof) = effTy `PairE` hof @@ -1349,14 +1383,14 @@ instance IRRep r => GenericE (Hof r) where type RepE (Hof r) = EitherE2 (EitherE6 {- For -} (LiftE ForAnn `PairE` IxType r `PairE` LamExpr r) - {- While -} (Block r) + {- While -} (Expr r) {- RunReader -} (Atom r `PairE` LamExpr r) {- RunWriter -} (MaybeE (Atom r) `PairE` BaseMonoid r `PairE` LamExpr r) {- RunState -} (MaybeE (Atom r) `PairE` Atom r `PairE` LamExpr r) - {- RunIO -} (Block r) + {- RunIO -} (Expr r) ) (EitherE4 - {- RunInit -} (Block r) - {- CatchException -} (WhenCore r (Type r `PairE` Block r)) + {- RunInit -} (Expr r) + {- CatchException -} (WhenCore r (Type r `PairE` Expr r)) {- Linearize -} (WhenCore r (LamExpr r `PairE` Atom r)) {- Transpose -} (WhenCore r (LamExpr r `PairE` Atom r))) @@ -1613,12 +1647,13 @@ instance IRRep r => RenameE (Type r) instance IRRep r => GenericE (Expr r) where type RepE (Expr r) = EitherE2 - ( EitherE5 + ( EitherE6 {- App -} (WhenCore r (EffTy r `PairE` Atom r `PairE` ListE (Atom r))) {- TabApp -} (Type r `PairE` Atom r `PairE` ListE (Atom r)) {- Case -} (Atom r `PairE` ListE (Alt r) `PairE` EffTy r) {- Atom -} (Atom r) {- TopApp -} (WhenSimp r (EffTy r `PairE` TopFunName `PairE` ListE (Atom r))) + {- Block -} (EffTy r `PairE` Block r) ) ( EitherE5 {- TabCon -} (MaybeE (WhenCore r Dict) `PairE` Type r `PairE` ListE (Atom r)) @@ -1632,6 +1667,7 @@ instance IRRep r => GenericE (Expr r) where Case e alts effTy -> Case0 $ Case2 (e `PairE` ListE alts `PairE` effTy) Atom x -> Case0 $ Case3 (x) TopApp et f xs -> Case0 $ Case4 (WhenIRE (et `PairE` f `PairE` ListE xs)) + Block et block -> Case0 $ Case5 (et `PairE` block) TabCon d ty xs -> Case1 $ Case0 (toMaybeE d `PairE` ty `PairE` ListE xs) PrimOp op -> Case1 $ Case1 op ApplyMethod et d i xs -> Case1 $ Case2 (WhenIRE (et `PairE` d `PairE` LiftE i `PairE` ListE xs)) @@ -1642,9 +1678,10 @@ instance IRRep r => GenericE (Expr r) where Case0 case0 -> case case0 of Case0 (WhenIRE (et `PairE` f `PairE` ListE xs)) -> App et f xs Case1 (t `PairE` f `PairE` ListE xs) -> TabApp t f xs - Case2 (e `PairE` ListE alts `PairE` effTy) -> Case e alts effTy + Case2 (e `PairE` ListE alts `PairE` effTy) -> Case e alts effTy Case3 (x) -> Atom x Case4 (WhenIRE (et `PairE` f `PairE` ListE xs)) -> TopApp et f xs + Case5 (et `PairE` block) -> Block et block _ -> error "impossible" Case1 case1 -> case case1 of Case0 (d `PairE` ty `PairE` ListE xs) -> TabCon (fromMaybeE d) ty xs @@ -1725,7 +1762,6 @@ instance GenericOp VectorOp where _ -> Nothing {-# INLINE toOp #-} -instance IsPrimOp VectorOp where toPrimOp = VectorOp instance IRRep r => GenericE (VectorOp r) where type RepE (VectorOp r) = GenericOpRep (OpConst VectorOp r) r fromE = fromEGenericOpRep @@ -1754,7 +1790,6 @@ instance GenericOp MemOp where _ -> Nothing {-# INLINE toOp #-} -instance IsPrimOp MemOp where toPrimOp = MemOp instance IRRep r => GenericE (MemOp r) where type RepE (MemOp r) = GenericOpRep (OpConst MemOp r) r fromE = fromEGenericOpRep @@ -1797,7 +1832,6 @@ instance GenericOp MiscOp where _ -> Nothing {-# INLINE toOp #-} -instance IsPrimOp MiscOp where toPrimOp = MiscOp instance IRRep r => GenericE (MiscOp r) where type RepE (MiscOp r) = GenericOpRep (OpConst MiscOp r) r fromE = fromEGenericOpRep @@ -2004,7 +2038,7 @@ instance Semigroup (Cache n) where Cache (y1<>x1) (y2<>x2) (y3<>x3) (y4<>x4) (x5<>y5) (x6<>y6) instance GenericE (LamExpr r) where - type RepE (LamExpr r) = Abs (Nest (Binder r)) (Block r) + type RepE (LamExpr r) = Abs (Nest (Binder r)) (Expr r) fromE (LamExpr b block) = Abs b block {-# INLINE fromE #-} toE (Abs b block) = LamExpr b block diff --git a/src/lib/Vectorize.hs b/src/lib/Vectorize.hs index 9f4c27e54..f18fd4faf 100644 --- a/src/lib/Vectorize.hs +++ b/src/lib/Vectorize.hs @@ -138,26 +138,21 @@ vectorizeLoopsDestBlock (Abs (destb:>destTy) body) = do destTy' <- renameM destTy withFreshBinder (getNameHint destb) destTy' \destb' -> do extendRenamer (destb @> binderName destb') do - Abs destb' <$> buildBlock (vectorizeLoopsBlock body) - -vectorizeLoopsBlock :: (Emits o) - => Block SimpIR i -> TopVectorizeM i o (SAtom o) -vectorizeLoopsBlock (Abs decls ans) = - vectorizeLoopsDecls decls $ renameM ans + Abs destb' <$> buildBlock (vectorizeLoopsExpr body) vectorizeLoopsDecls :: (Emits o) => Nest SDecl i i' -> TopVectorizeM i' o a -> TopVectorizeM i o a vectorizeLoopsDecls nest cont = case nest of Empty -> cont - Nest (Let b (DeclBinding ann expr)) rest -> do - v <- emitDecl (getNameHint b) ann =<< vectorizeLoopsExpr expr + Nest (Let b (DeclBinding _ expr)) rest -> do + v <- emitToVar =<< vectorizeLoopsExpr expr extendSubst (b @> atomVarName v) $ vectorizeLoopsDecls rest cont vectorizeLoopsLamExpr :: LamExpr SimpIR i -> TopVectorizeM i o (LamExpr SimpIR o) vectorizeLoopsLamExpr (LamExpr bs body) = case bs of - Empty -> LamExpr Empty <$> buildBlock (vectorizeLoopsBlock body) + Empty -> LamExpr Empty <$> buildBlock (vectorizeLoopsExpr body) Nest (b:>ty) rest -> do ty' <- renameM ty withFreshBinder (getNameHint b) ty' \b' -> do @@ -165,12 +160,13 @@ vectorizeLoopsLamExpr (LamExpr bs body) = case bs of LamExpr bs' body' <- vectorizeLoopsLamExpr $ LamExpr rest body return $ LamExpr (Nest b' bs') body' -vectorizeLoopsExpr :: (Emits o) => SExpr i -> TopVectorizeM i o (SExpr o) +vectorizeLoopsExpr :: (Emits o) => SExpr i -> TopVectorizeM i o (SAtom o) vectorizeLoopsExpr expr = do vectorByteWidth <- askVectorByteWidth narrowestTypeByteWidth <- getNarrowestTypeByteWidth =<< renameM expr let loopWidth = vectorByteWidth `div` narrowestTypeByteWidth case expr of + Block _ (Abs decls body) -> vectorizeLoopsDecls decls $ vectorizeLoopsExpr body PrimOp (DAMOp (Seq effs dir ixty dest body)) -> do sz <- simplifyIxSize =<< renameM ixty case sz of @@ -182,9 +178,8 @@ vectorizeLoopsExpr expr = do let vn = n `div` loopWidth body' <- vectorizeSeq loopWidth ixty body dest' <- renameM dest - seqOp <- mkSeq dir (IxType IdxRepTy (IxDictRawFin (IdxRepVal vn))) dest' body' - return $ PrimOp $ DAMOp seqOp) - else renameM expr) + emitExpr =<< mkSeq dir (IxType IdxRepTy (IxDictRawFin (IdxRepVal vn))) dest' body') + else renameM expr >>= emitExpr) `catchErr` \err -> do let msg = "In `vectorizeLoopsDecls`:\nExpr:\n" ++ pprint expr ctx = mempty { messageCtx = [msg] } @@ -198,8 +193,8 @@ vectorizeLoopsExpr expr = do lam <- buildEffLam noHint itemTy \hb refb -> extendRenamer (hb' @> atomVarName hb) do extendRenamer (refb' @> atomVarName refb) do - vectorizeLoopsBlock body - PrimOp . Hof <$> mkTypedHof (RunReader item' lam) + vectorizeLoopsExpr body + emitExpr =<< mkTypedHof (RunReader item' lam) PrimOp (Hof (TypedHof (EffTy _ ty) (RunWriter (Just dest) monoid (BinaryLamExpr hb' refb' body)))) -> do dest' <- renameM dest @@ -210,24 +205,24 @@ vectorizeLoopsExpr expr = do extendRenamer (hb' @> atomVarName hb) do extendRenamer (refb' @> atomVarName refb) do extendCommuteMap (atomVarName hb) commutativity do - vectorizeLoopsBlock body - PrimOp . Hof <$> mkTypedHof (RunWriter (Just dest') monoid' lam) - _ -> renameM expr + vectorizeLoopsExpr body + emitExpr =<< mkTypedHof (RunWriter (Just dest') monoid' lam) + _ -> renameM expr >>= emitExpr where - recurSeq :: (Emits o) => SExpr i -> TopVectorizeM i o (SExpr o) + recurSeq :: (Emits o) => SExpr i -> TopVectorizeM i o (SAtom o) recurSeq (PrimOp (DAMOp (Seq effs dir ixty dest body))) = do effs' <- renameM effs ixty' <- renameM ixty dest' <- renameM dest body' <- vectorizeLoopsLamExpr body - return $ PrimOp $ DAMOp $ Seq effs' dir ixty' dest' body' + emitExpr $ Seq effs' dir ixty' dest' body' recurSeq _ = error "Impossible" simplifyIxSize :: (EnvReader m, ScopableBuilder SimpIR m) => IxType SimpIR n -> m n (Maybe Word32) simplifyIxSize ixty = do sizeMethod <- buildBlock $ applyIxMethod (sink $ ixTypeDict ixty) Size [] - reduceBlock sizeMethod >>= \case + reduceExpr sizeMethod >>= \case Just (IdxRepVal n) -> return $ Just n _ -> return Nothing {-# INLINE simplifyIxSize #-} @@ -261,7 +256,7 @@ isAdditionMonoid monoid = do BaseMonoid { baseEmpty = (Con (Lit l)) , baseCombine = BinaryLamExpr (b1:>_) (b2:>_) body } <- Just monoid unless (_isZeroLit l) Nothing - PrimOp (BinOp op (Var b1') (Var b2')) <- exprBlock body + PrimOp (BinOp op (Var b1') (Var b2')) <- return body unless (op `elem` [P.IAdd, P.FAdd]) Nothing case (binderName b1, atomVarName b1', binderName b2, atomVarName b2') of -- Checking the raw names here because (i) I don't know how to convince the @@ -333,7 +328,7 @@ vectorizeSeq loopWidth ixty (UnaryLamExpr (b:>ty) body) = do -- probably be a separate pass. i <- applyIxMethod (sink $ ixTypeDict ixty') UnsafeFromOrdinal [iOrd] extendSubst (b @> VVal (ProdStability [Contiguous, ProdStability [Uniform]]) (PairVal i dest)) $ - vectorizeBlock body $> UnitVal + vectorizeExpr body $> UnitVal vectorizeSeq _ _ _ = error "expected a unary lambda expression" newtype VectorizeM i o a = @@ -365,7 +360,7 @@ vectorizeLamExpr :: LamExpr SimpIR i -> [Stability] vectorizeLamExpr (LamExpr bs body) argStabilities = case (bs, argStabilities) of (Empty, []) -> do LamExpr Empty <$> buildBlock (do - vectorizeBlock body >>= \case + vectorizeExpr body >>= \case (VVal _ ans) -> return ans (VRename v) -> Var <$> toAtomVar v) (Nest (b:>ty) rest, (stab:stabs)) -> do @@ -378,21 +373,10 @@ vectorizeLamExpr (LamExpr bs body) argStabilities = case (bs, argStabilities) of return $ LamExpr (Nest b' rest') body' _ -> error "Zip error" -vectorizeBlock :: Emits o => SBlock i -> VectorizeM i o (VAtom o) -vectorizeBlock block@(Abs decls (ans :: SAtom i')) = - addVectErrCtx "vectorizeBlock" ("Block:\n" ++ pprint block) $ - go decls - where - go :: Emits o => Nest SDecl i i' -> VectorizeM i o (VAtom o) - go = \case - Empty -> vectorizeAtom ans - Nest (Let b (DeclBinding _ expr)) rest -> do - v <- vectorizeExpr expr - extendSubst (b @> v) $ go rest - vectorizeExpr :: Emits o => SExpr i -> VectorizeM i o (VAtom o) vectorizeExpr expr = addVectErrCtx "vectorizeExpr" ("Expr:\n" ++ pprint expr) do case expr of + Block _ block -> vectorizeBlock block TabApp _ tbl [ix] -> do VVal Uniform tbl' <- vectorizeAtom tbl VVal Contiguous ix' <- vectorizeAtom ix @@ -401,13 +385,19 @@ vectorizeExpr expr = addVectErrCtx "vectorizeExpr" ("Expr:\n" ++ pprint expr) do vty <- getVectorType =<< case hoist tb a of HoistSuccess a' -> return a' HoistFailure _ -> throwVectErr "Can't vectorize dependent table application" - VVal Varying <$> emitOp (VectorOp $ VectorIdx tbl' ix' vty) + VVal Varying <$> emitExpr (VectorIdx tbl' ix' vty) tblTy -> do throwVectErr $ "bad type: " ++ pprint tblTy ++ "\ntbl' : " ++ pprint tbl' Atom atom -> vectorizeAtom atom PrimOp op -> vectorizePrimOp op _ -> throwVectErr $ "Cannot vectorize expr: " ++ pprint expr +vectorizeBlock :: Emits o => SBlock i -> VectorizeM i o (VAtom o) +vectorizeBlock (Abs Empty body) = vectorizeExpr body +vectorizeBlock (Abs (Nest (Let b (DeclBinding _ rhs)) rest) body) = do + v <- vectorizeExpr rhs + extendSubst (b @> v) $ vectorizeBlock (Abs rest body) + vectorizeDAMOp :: Emits o => DAMOp SimpIR i -> VectorizeM i o (VAtom o) vectorizeDAMOp op = case op of @@ -415,11 +405,11 @@ vectorizeDAMOp op = VVal vref ref <- vectorizeAtom ref' sval@(VVal vval val) <- vectorizeAtom val' VVal Uniform <$> case (vref, vval) of - (Uniform , Uniform ) -> emitExpr $ PrimOp $ DAMOp $ Place ref val + (Uniform , Uniform ) -> emitExpr $ Place ref val (Uniform , _ ) -> throwVectErr "Write conflict? This should never happen!" (Varying , _ ) -> throwVectErr "Vector scatter not implemented" - (Contiguous, Varying ) -> emitExpr $ PrimOp $ DAMOp $ Place ref val - (Contiguous, Contiguous) -> emitExpr . PrimOp . DAMOp . Place ref =<< ensureVarying sval + (Contiguous, Varying ) -> emitExpr $ Place ref val + (Contiguous, Contiguous) -> emitExpr . Place ref =<< ensureVarying sval _ -> throwVectErr "Not implemented yet" _ -> throwVectErr $ "Can't vectorize op: " ++ pprint op @@ -430,7 +420,7 @@ vectorizeRefOp ref' op = -- TODO A contiguous reference becomes a vector load producing a varying -- result. VVal Uniform ref <- vectorizeAtom ref' - VVal Uniform <$> emitOp (RefOp ref MAsk) + VVal Uniform <$> emitExpr (RefOp ref MAsk) MExtend basemonoid' x' -> do VVal refStab ref <- vectorizeAtom ref' VVal xStab x <- vectorizeAtom x' @@ -447,7 +437,7 @@ vectorizeRefOp ref' op = Contiguous -> do vectorizeBaseMonoid basemonoid' Varying xStab s -> throwVectErr $ "Cannot vectorize reference with loop-varying stability " ++ show s - VVal Uniform <$> emitOp (RefOp ref $ MExtend basemonoid x) + VVal Uniform <$> emitExpr (RefOp ref $ MExtend basemonoid x) IndexRef _ i' -> do VVal Uniform ref <- vectorizeAtom ref' VVal Contiguous i <- vectorizeAtom i' @@ -456,7 +446,7 @@ vectorizeRefOp ref' op = vty <- getVectorType =<< case hoist tb a of HoistSuccess a' -> return a' HoistFailure _ -> throwVectErr "Can't vectorize dependent table application" - VVal Contiguous <$> emitOp (VectorOp $ VectorSubref ref i vty) + VVal Contiguous <$> emitExpr (VectorSubref ref i vty) refTy -> do throwVectErr do "bad type: " ++ pprint refTy ++ "\nref' : " ++ pprint ref' @@ -482,7 +472,7 @@ vectorizePrimOp op = case op of sx@(VVal vx x) <- vectorizeAtom arg let v = case vx of Uniform -> Uniform; _ -> Varying x' <- if vx /= v then ensureVarying sx else return x - VVal v <$> emitOp (UnOp opk x') + VVal v <$> emitExpr (UnOp opk x') BinOp opk arg1 arg2 -> do sx@(VVal vx x) <- vectorizeAtom arg1 sy@(VVal vy y) <- vectorizeAtom arg2 @@ -493,7 +483,7 @@ vectorizePrimOp op = case op of _ -> Varying x' <- if v == Varying then ensureVarying sx else return x y' <- if v == Varying then ensureVarying sy else return y - VVal v <$> emitOp (BinOp opk x' y') + VVal v <$> emitExpr (BinOp opk x' y') MiscOp (CastOp tyArg arg) -> do ty <- vectorizeType tyArg VVal vx x <- vectorizeAtom arg @@ -502,28 +492,29 @@ vectorizePrimOp op = case op of Varying -> getVectorType ty Contiguous -> return ty ProdStability _ -> throwVectErr "Unexpected cast of product type" - VVal vx <$> emitOp (MiscOp $ CastOp ty' x) + VVal vx <$> emitExpr (CastOp ty' x) DAMOp op' -> vectorizeDAMOp op' RefOp ref op' -> vectorizeRefOp ref op' MemOp (PtrOffset arg1 arg2) -> do VVal Uniform ptr <- vectorizeAtom arg1 VVal Contiguous off <- vectorizeAtom arg2 - VVal Contiguous <$> emitOp (MemOp $ PtrOffset ptr off) + VVal Contiguous <$> emitExpr (PtrOffset ptr off) MemOp (PtrLoad arg) -> do VVal Contiguous ptr <- vectorizeAtom arg BaseTy (PtrType (addrSpace, a)) <- return $ getType ptr BaseTy av <- getVectorType $ BaseTy a - ptr' <- emitOp $ MiscOp $ CastOp (BaseTy $ PtrType (addrSpace, av)) ptr - VVal Varying <$> emitOp (MemOp $ PtrLoad ptr') + ptr' <- emitExpr $ CastOp (BaseTy $ PtrType (addrSpace, av)) ptr + VVal Varying <$> emitExpr (PtrLoad ptr') -- Vectorizing IO might not always be safe! Here, we depend on vectorizeOp -- being picky about the IO-inducing ops it supports, and expect it to -- complain about FFI calls and the like. Hof (TypedHof _ (RunIO body)) -> do -- TODO: buildBlockAux? Abs decls (LiftE vy `PairE` y) <- buildScoped do - VVal vy y <- vectorizeBlock body + VVal vy y <- vectorizeExpr body return $ PairE (LiftE vy) y - VVal vy <$> emitHof (RunIO $ Abs decls y) + block <- mkBlock (Abs decls y) + VVal vy <$> emitHof (RunIO block) _ -> throwVectErr $ "Can't vectorize op: " ++ pprint op vectorizeType :: SType i -> VectorizeM i o (SType o) @@ -571,16 +562,16 @@ ensureVarying (VVal s val) = case s of Varying -> return val Uniform -> do vty <- getVectorType $ getType val - emitOp $ VectorOp $ VectorBroadcast val vty + emitExpr $ VectorBroadcast val vty -- Note that the implementation of this case will depend on val's type. Contiguous -> do let ty = getType val vty <- getVectorType ty case ty of BaseTy (Scalar sbt) -> do - bval <- emitOp $ VectorOp $ VectorBroadcast val vty - iota <- emitOp $ VectorOp $ VectorIota vty - emitOp $ BinOp (if isIntegral sbt then IAdd else FAdd) bval iota + bval <- emitExpr $ VectorBroadcast val vty + iota <- emitExpr $ VectorIota vty + emitExpr $ BinOp (if isIntegral sbt then IAdd else FAdd) bval iota _ -> throwVectErr "Not implemented" ProdStability _ -> throwVectErr "Not implemented" ensureVarying (VRename v) = do From de88bf8bfcf164fc90d603c61539274455498f96 Mon Sep 17 00:00:00 2001 From: Dougal Date: Sun, 22 Oct 2023 21:06:15 -0400 Subject: [PATCH 05/41] Factor out the way Simplify handles ACase. --- src/lib/Simplify.hs | 209 +++++++++++++++++++++----------------------- 1 file changed, 98 insertions(+), 111 deletions(-) diff --git a/src/lib/Simplify.hs b/src/lib/Simplify.hs index 09c966548..b597ffa9f 100644 --- a/src/lib/Simplify.hs +++ b/src/lib/Simplify.hs @@ -104,6 +104,75 @@ tryAsDataAtom atom = do where notData = error $ "Not runtime-representable data: " ++ pprint atom +data WithSubst (e::E) (o::S) where + WithSubst :: Subst AtomSubstVal i o -> e i -> WithSubst e o + +data ConcreteCAtom (n::S) = + CCCon (WithSubst CAtom n) -- can't be Stuck or SimpInCore + | CCSimpInCore (SimpInCore n) -- can't be ACase + | CCNoInlineFun (CAtomVar n) (CType n) (CAtom n) + | CCFFIFun (CorePiType n) (TopFunName n) + +-- Yields to the continuation a term with a concrete CoreIR constructor, +-- or LiftSimpFun, liftSimp, or TabLam. +forceConstructor + :: Emits o + => CAtom i + -> (forall o' i'. (DExt o o', Emits o') => ConcreteCAtom o'-> SimplifyM i' o' (CAtom o')) + -> SimplifyM i o (CAtom o) +forceConstructor atom cont = withDistinct case atom of + Stuck stuck -> forceStuck stuck cont + SimpInCore lifted -> case lifted of + ACase e alts resultTy -> do + e' <- substM e + resultTy' <- substM resultTy + defuncCase e' resultTy' \i x -> do + Abs b body <- return $ alts !! i + extendSubst (b@>SubstVal x) do + forceConstructor body cont + _ -> do + lifted' <- substM lifted + cont $ CCSimpInCore lifted' + _ -> do + Distinct <- getDistinct + subst <- getSubst + cont $ CCCon $ WithSubst subst atom + +forceStuck + :: Emits o + => CStuck i + -> (forall o' i'. (DExt o o', Emits o') => ConcreteCAtom o'-> SimplifyM i' o' (CAtom o')) + -> SimplifyM i o (CAtom o) +forceStuck stuck cont = withDistinct case stuck of + StuckVar v -> lookupSubstM (atomVarName v) >>= \case + SubstVal x -> dropSubst $ forceConstructor x cont + Rename v' -> lookupAtomName v' >>= \case + LetBound (DeclBinding _ (Atom x)) -> dropSubst $ forceConstructor x cont + NoinlineFun t f -> do + v'' <- toAtomVar v' + cont $ CCNoInlineFun v'' t f + FFIFunBound t f -> cont $ CCFFIFun t f + _ -> error "shouldn't have other CVars left" + -- TODO: figure out how to de-dup these cases with their Expr counterpart + StuckProject _ i x -> do + ty <- substM $ getType stuck + forceStuck x \case + CCSimpInCore (LiftSimp _ x') -> do + x'' <- proj i x' + cont $ CCSimpInCore $ LiftSimp (sink ty) x'' + CCCon (WithSubst s con) -> withSubst s case con of + ProdVal xs -> forceConstructor (xs!!i) cont + DepPair l r _ -> forceConstructor ([l, r]!!i) cont + _ -> error "Can't project stuck term" + _ -> error "Can't project stuck term" + StuckUnwrap _ x -> forceStuck x \case + CCCon (WithSubst s con) -> withSubst s case con of + NewtypeCon _ x' -> forceConstructor x' cont + _ -> error "can't unwrap stuck term" + _ -> error "can't unwrap stuck term" + InstantiatedGiven _ _ _ -> error "shouldn't have this left" + SuperclassProj _ _ _ -> error "shouldn't have this left" + forceTabLam :: Emits n => TabLamExpr n -> SimplifyM i n (SAtom n) forceTabLam (PairE ixTy (Abs b ab)) = buildFor (getNameHint b) Fwd ixTy \v -> do @@ -315,8 +384,7 @@ simplifyExpr expr = confuseGHC >>= \_ -> case expr of simplifyApp ty' f xs' TabApp _ f xs -> do xs' <- mapM simplifyAtom xs - f' <- simplifyAtom f - simplifyTabApp f' xs' + simplifyTabApp f xs' Atom x -> simplifyAtom x PrimOp op -> simplifyOp op ApplyMethod (EffTy _ ty) dict i xs -> do @@ -379,6 +447,7 @@ defuncCaseCore scrut resultTy cont = do let xCoreTy = altBinderTys !! i x' <- liftSimpAtom (sink xCoreTy) x cont i x' + -- TODO: we should use forceConstructor here Nothing -> case trySelectBranch scrut of Just (i, arg) -> getDistinct >>= \Distinct -> cont i arg Nothing -> go scrut where @@ -449,61 +518,21 @@ simplifyAlt split ty cont = do simplifyApp :: forall i o. Emits o => CType o -> CAtom i -> [CAtom o] -> SimplifyM i o (CAtom o) -simplifyApp resultTy f xs = case f of - Lam (CoreLamExpr _ lam) -> fast lam - _ -> slow =<< simplifyAtomAndInline f - where - fast :: LamExpr CoreIR i' -> SimplifyM i' o (CAtom o) - fast lam = withInstantiated lam xs \body -> simplifyExpr body - - slow :: CAtom o -> SimplifyM i o (CAtom o) - slow = \case - Lam (CoreLamExpr _ lam) -> dropSubst $ fast lam - SimpInCore (ACase e alts _) -> dropSubst do - defuncCase e resultTy \i x -> do - Abs b body <- return $ alts !! i - extendSubst (b@>SubstVal x) do - xs' <- mapM sinkM xs - simplifyApp (sink resultTy) body xs' - SimpInCore (LiftSimpFun _ lam) -> do - xs' <- mapM toDataAtomIgnoreRecon xs - result <- instantiate lam xs' >>= emitExpr - liftSimpAtom resultTy result - Var v -> do - lookupAtomName (atomVarName v) >>= \case - NoinlineFun _ _ -> simplifyTopFunApp v xs - FFIFunBound _ f' -> do - xs' <- mapM toDataAtomIgnoreRecon xs - liftSimpAtom resultTy =<< naryTopApp f' xs' - b -> error $ "Should only have noinline functions left " ++ pprint b - atom -> error $ "Unexpected function: " ++ pprint atom - --- | Like `simplifyAtom`, but will try to inline function definitions found --- in the environment. The only exception is when we're going to differentiate --- and the function has a custom derivative rule defined. --- TODO(dougalm): do we still need this? -simplifyAtomAndInline :: CAtom i -> SimplifyM i o (CAtom o) -simplifyAtomAndInline atom = confuseGHC >>= \_ -> case atom of - Var v -> do - env <- getSubst - case env ! atomVarName v of - Rename v' -> doInline =<< toAtomVar v' - SubstVal (Var v') -> doInline v' - SubstVal x -> return x - -- This is a hack because we weren't normalize the unwrapping of - -- `unit_type_scale` in `plot.dx`. We need a better system for deciding how to - -- normalize and inline. - Stuck (StuckProject _ i x) -> do - x' <- simplifyStuck x >>= reduceProj i - dropSubst $ simplifyAtomAndInline x' - _ -> simplifyAtom atom >>= \case - Var v -> doInline v - ans -> return ans - where - doInline v = do - lookupAtomName (atomVarName v) >>= \case - LetBound (DeclBinding _ (Atom x)) -> dropSubst $ simplifyAtomAndInline x - _ -> return $ Var v +simplifyApp resultTy f xs = forceConstructor f \f' -> do + xs' <- mapM sinkM xs + case f' of + CCCon (WithSubst s (Lam (CoreLamExpr _ lam))) -> + withSubst s $ withInstantiated lam xs' \body -> + simplifyExpr body + CCSimpInCore (LiftSimpFun _ lam) -> do + xs'' <- mapM toDataAtomIgnoreRecon xs' + result <- instantiate lam xs'' >>= emitExpr + liftSimpAtom (sink resultTy) result + CCNoInlineFun v _ _ -> simplifyTopFunApp v xs' + CCFFIFun _ f'' -> do + xs'' <- mapM toDataAtomIgnoreRecon xs' + liftSimpAtom (sink resultTy) =<< naryTopApp f'' xs'' + _ -> error "not a function" simplifyTopFunApp :: Emits n => CAtomVar n -> [CAtom n] -> SimplifyM i n (CAtom n) simplifyTopFunApp fName xs = do @@ -547,33 +576,23 @@ specializedFunCoreDefinition (AppSpecialization f (Abs bs staticArgs)) = do naryApp f' staticArgs' simplifyTabApp :: forall i o. Emits o - => CAtom o -> [CAtom o] -> SimplifyM i o (CAtom o) -simplifyTabApp f [] = return f -simplifyTabApp f@(SimpInCore sic) xs = case sic of - TabLam _ _ -> do - case fromNaryTabLam (length xs) f of + => CAtom i -> [CAtom o] -> SimplifyM i o (CAtom o) +simplifyTabApp f [] = simplifyAtom f +simplifyTabApp f xs = forceConstructor f \case + CCSimpInCore sic@(TabLam _ _) -> do + case fromNaryTabLam (length xs) (SimpInCore sic) of Just (bsCount, ab) -> do - let (xsPref, xsRest) = splitAt bsCount xs + (xsPref, xsRest) <- splitAt bsCount <$> mapM sinkM xs xsPref' <- mapM toDataAtomIgnoreRecon xsPref block' <- instantiate ab xsPref' atom <- emitDecls block' - simplifyTabApp atom xsRest + dropSubst $ simplifyTabApp atom xsRest Nothing -> error "should never happen" - ACase e alts ty -> dropSubst do - resultTy <- typeOfTabApp ty xs - defuncCase e resultTy \i x -> do - Abs b body <- return $ alts !! i - extendSubst (b@>SubstVal x) do - xs' <- mapM sinkM xs - body' <- substM body - simplifyTabApp body' xs' - LiftSimp _ f' -> do - fTy <- return $ getType f - resultTy <- typeOfTabApp fTy xs - xs' <- mapM toDataAtomIgnoreRecon xs + CCSimpInCore (LiftSimp fTy f') -> do + resultTy <- typeOfTabApp fTy (sink<$>xs) + xs' <- mapM (toDataAtomIgnoreRecon . sink) xs liftSimpAtom resultTy =<< naryTabApp f' xs' - LiftSimpFun _ _ -> error "not implemented" -simplifyTabApp f _ = error $ "Unexpected table: " ++ pprint f + _ -> error "not a table" simplifyIxType :: IxType CoreIR o -> SimplifyM i o (IxType SimpIR o) simplifyIxType (IxType t ixDict) = do @@ -625,40 +644,8 @@ ixMethodType method absDict = do let allBs = extraArgBs >>> methodArgs return $ PiType allBs (EffTy Pure resultTy) --- TODO: do we even need this, or is it just a glorified `SubstM`? simplifyAtom :: CAtom i -> SimplifyM i o (CAtom o) -simplifyAtom atom = confuseGHC >>= \_ -> case atom of - Stuck e -> simplifyStuck e - Lam _ -> substM atom - DepPair x y ty -> DepPair <$> simplifyAtom x <*> simplifyAtom y <*> substM ty - Con con -> Con <$> traverseOp con substM simplifyAtom (error "unexpected lambda") - Eff eff -> Eff <$> substM eff - PtrVar t v -> PtrVar t <$> substM v - DictCon _ -> substM atom - NewtypeCon _ _ -> substM atom - SimpInCore _ -> substM atom - TypeAsAtom _ -> substM atom - -simplifyStuck :: CStuck i -> SimplifyM i o (CAtom o) -simplifyStuck = \case - StuckVar v -> simplifyVar v - StuckProject _ i x -> reduceProj i =<< simplifyStuck x - stuck -> substM (Stuck stuck) - -simplifyVar :: AtomVar CoreIR i -> SimplifyM i o (CAtom o) -simplifyVar v = do - env <- getSubst - case env ! atomVarName v of - SubstVal x -> return x - Rename v' -> do - AtomNameBinding bindingInfo <- lookupEnv v' - let ty = getType bindingInfo - case bindingInfo of - -- Functions get inlined only at application sites - LetBound (DeclBinding _ _) | isFun -> return $ Var $ AtomVar v' ty - where isFun = case ty of Pi _ -> True; _ -> False - LetBound (DeclBinding _ (Atom x)) -> dropSubst $ simplifyAtom x - _ -> return $ Var $ AtomVar v' ty +simplifyAtom = substM -- Assumes first order (args/results are "data", allowing newtypes), monormophic simplifyLam From d80f318b9ea90f32420c9aa49bc935ae2aed6324 Mon Sep 17 00:00:00 2001 From: Dougal Date: Mon, 23 Oct 2023 11:02:47 -0400 Subject: [PATCH 06/41] Add a `StuckTabApp` case to `Stuck` --- src/lib/Builder.hs | 12 ++++++++++++ src/lib/CheapReduction.hs | 24 +++++++++++++++++++++++- src/lib/CheckType.hs | 7 +++++++ src/lib/Imp.hs | 6 ++++++ src/lib/Inference.hs | 11 ----------- src/lib/Linearize.hs | 2 ++ src/lib/OccAnalysis.hs | 16 +++++++++++----- src/lib/PPrint.hs | 1 + src/lib/QueryTypePure.hs | 1 + src/lib/Simplify.hs | 8 ++++++++ src/lib/Transpose.hs | 4 ++-- src/lib/Types/Core.hs | 18 +++++++++++------- src/lib/Vectorize.hs | 2 ++ 13 files changed, 86 insertions(+), 26 deletions(-) diff --git a/src/lib/Builder.hs b/src/lib/Builder.hs index 52425566d..5574aa44c 100644 --- a/src/lib/Builder.hs +++ b/src/lib/Builder.hs @@ -129,6 +129,18 @@ buildScopedAssumeNoDecls cont = do _ -> error "Expected no decl emissions" {-# INLINE buildScopedAssumeNoDecls #-} +withReducibleEmissions + :: (ScopableBuilder r m, Builder r m, HasNamesE e, SubstE AtomSubstVal e) + => String + -> (forall o' . (Emits o', DExt o o') => m o' (e o')) + -> m o (e o) +withReducibleEmissions msg cont = do + withDecls <- buildScoped cont + reduceWithDecls withDecls >>= \case + Just t -> return t + _ -> throw TypeErr msg +{-# INLINE withReducibleEmissions #-} + -- === "Hoisting" top-level builder class === -- `emitHoistedEnv` lets you emit top env fragments, like cache entries or diff --git a/src/lib/CheapReduction.hs b/src/lib/CheapReduction.hs index c7fd96589..03d60730b 100644 --- a/src/lib/CheapReduction.hs +++ b/src/lib/CheapReduction.hs @@ -109,8 +109,14 @@ reduceExprM = \case case (ty, val) of (BaseTy (Scalar Word32Type), Con (Lit (Word64Lit v))) -> return $ Con $ Lit $ Word32Lit $ fromIntegral v _ -> empty + TabApp ty tab xs -> do + ty' <- substM ty + xs' <- mapM substM xs + tab' <- substM tab + case tab' of + Stuck tab'' -> return $ Stuck $ StuckTabApp ty' tab'' xs' + _ -> error "not a table" -- what about RepVal? TopApp _ _ _ -> empty - TabApp _ _ _ -> empty Case _ _ _ -> empty TabCon _ _ _ -> empty PrimOp _ -> empty @@ -188,6 +194,11 @@ typeOfApp (Pi piTy) xs = withSubstReaderT $ withInstantiated piTy xs \(EffTy _ ty) -> substM ty typeOfApp _ _ = error "expected a pi type" +typeOfTabApp :: (IRRep r, EnvReader m) => Type r n -> [Atom r n] -> m n (Type r n) +typeOfTabApp (TabPi piTy) xs = withSubstReaderT $ + withInstantiated piTy xs \ty -> substM ty +typeOfTabApp _ _ = error "expected a TabPi type" + repValAtom :: EnvReader m => SRepVal n -> m n (SAtom n) repValAtom (RepVal ty tree) = case ty of ProdTy ts -> case tree of @@ -220,6 +231,13 @@ reduceUnwrapM = \case _ -> error "expected a newtype" _ -> empty +reduceTabAppM :: IRRep r => Atom r o -> [Atom r o] -> ReducerM i o (Atom r o) +reduceTabAppM tab xs = case tab of + Stuck tab' -> do + ty <- typeOfTabApp (getType tab') xs + return $ Stuck $ StuckTabApp ty tab' xs + _ -> error $ "not a table" ++ pprint tab + unwrapNewtypeType :: EnvReader m => NewtypeTyCon n -> m n (NewtypeCon n, Type CoreIR n) unwrapNewtypeType = \case Nat -> return (NatCon, IdxRepTy) @@ -616,6 +634,10 @@ reduceStuck = \case StuckUnwrap _ x -> do x' <- reduceStuck x dropSubst $ reduceUnwrapM x' + StuckTabApp _ f xs -> do + f' <- reduceStuck f + xs' <- mapM substM xs + dropSubst $ reduceTabAppM f' xs' InstantiatedGiven _ f xs -> do xs' <- mapM substM xs f' <- reduceStuck f diff --git a/src/lib/CheckType.hs b/src/lib/CheckType.hs index 6f3b2093a..ad6698a9b 100644 --- a/src/lib/CheckType.hs +++ b/src/lib/CheckType.hs @@ -320,6 +320,13 @@ instance IRRep r => CheckableE r (Stuck r) where StuckProject resultTy i x -> do Project resultTy' i' (Stuck x') <- checkWithEffects Pure $ Project resultTy i (Stuck x) return $ StuckProject resultTy' i' x' + StuckTabApp reqTy f xs -> do + reqTy' <- reqTy |: TyKind + (f', tabTy) <- checkAndGetType f + xs' <- mapM checkE xs + ty' <- checkTabApp tabTy xs' + checkTypesEq reqTy' ty' + return $ StuckTabApp reqTy' f' xs' InstantiatedGiven resultTy given args -> do resultTy' <- resultTy |: TyKind (given', Pi piTy) <- checkAndGetType given diff --git a/src/lib/Imp.hs b/src/lib/Imp.hs index 630479799..9ab013f23 100644 --- a/src/lib/Imp.hs +++ b/src/lib/Imp.hs @@ -865,9 +865,15 @@ atomToRepVal x = RepVal (getType x) <$> go x where Stuck (StuckVar v) -> lookupAtomName (atomVarName v) >>= \case TopDataBound (RepVal _ tree) -> return tree _ -> error "should only have pointer and data atom names left" + -- TODO: I think we want to be able to rule this one out by insisting that + -- RepValAtom is itself part of Stuck and it can't represent a product. Stuck (StuckProject _ i val) -> do Branch ts <- go $ Stuck val return $ ts !! i + Stuck (StuckTabApp _ f xs) -> do + f' <- atomToRepVal $ Stuck f + RepVal _ t <- naryIndexRepVal f' (toList xs) + return t -- XXX: We used to have a function called `destToAtom` which loaded the value -- from the dest. This version is not that. It just lifts a dest into an atom of diff --git a/src/lib/Inference.hs b/src/lib/Inference.hs index b89c77bd1..276bd9236 100644 --- a/src/lib/Inference.hs +++ b/src/lib/Inference.hs @@ -1089,17 +1089,6 @@ checkSigmaDependent e@(WithSrcE ctx _) ty = addSrcContext ctx $ "Dependent functions can only be applied to fully evaluated expressions. " ++ "Bind the argument to a name before you apply the function." -withReducibleEmissions - :: Zonkable e - => String - -> (forall o' . (Emits o', DExt o o') => InfererM i o' (e o')) - -> InfererM i o (e o) -withReducibleEmissions msg cont = do - withDecls <- buildScoped cont - reduceWithDecls withDecls >>= \case - Just t -> return t - _ -> throw TypeErr msg - -- === sorting case alternatives === data IndexedAlt n = IndexedAlt CaseAltIndex (Alt CoreIR n) diff --git a/src/lib/Linearize.hs b/src/lib/Linearize.hs index 4fbae982c..426615649 100644 --- a/src/lib/Linearize.hs +++ b/src/lib/Linearize.hs @@ -334,7 +334,9 @@ linearizeAtom atom = case atom of activePrimalIdx v' >>= \case Nothing -> withZeroT $ return (Var v') Just idx -> return $ WithTangent (Var v') $ getTangentArg idx + -- TODO: buildScoped and reduce the results so we keep expression in non-ANF for type checking purposes Stuck (StuckProject ty i x) -> linearizeExpr $ Project ty i (Stuck x) + Stuck (StuckTabApp t f xs) -> linearizeExpr $ TabApp t (Stuck f) xs RepValAtom _ -> emitZeroT where emitZeroT = withZeroT $ renameM atom diff --git a/src/lib/OccAnalysis.hs b/src/lib/OccAnalysis.hs index 59b7a4384..711374df1 100644 --- a/src/lib/OccAnalysis.hs +++ b/src/lib/OccAnalysis.hs @@ -255,6 +255,11 @@ instance HasOCC SStuck where ty' <- occTy ty return $ StuckVar (AtomVar n ty') StuckProject t i x -> StuckProject <$> occ a t <*> pure i <*> occ a x + StuckTabApp t array ixs -> do + t' <- occTy t + (a', ixs') <- occIdxs a ixs + array' <- occ a' array + return $ StuckTabApp t' array' ixs' instance HasOCC SType where occ a ty = runOCCMVisitor a $ visitTypePartial ty @@ -360,7 +365,7 @@ instance HasOCC SExpr where return $ Block effTy' (Abs decls' ans') TabApp t array ixs -> do t' <- occTy t - (a', ixs') <- go a ixs + (a', ixs') <- occIdxs a ixs array' <- occ a' array return $ TabApp t' array' ixs' Case scrut alts (EffTy effs ty) -> do @@ -376,10 +381,11 @@ instance HasOCC SExpr where ref' <- occ a ref PrimOp . RefOp ref' <$> occ a op expr -> occGeneric a expr - where - go acc [] = return (acc, []) - go acc (ix:ixs) = do - (acc', ixs') <- go acc ixs + +occIdxs :: Access n -> [SAtom n] -> OCCM n (Access n, [SAtom n]) +occIdxs acc [] = return (acc, []) +occIdxs acc (ix:ixs) = do + (acc', ixs') <- occIdxs acc ixs (summ, ix') <- occurrenceAndSummary ix return (location summ acc', ix':ixs') diff --git a/src/lib/PPrint.hs b/src/lib/PPrint.hs index 7e8708892..968193a00 100644 --- a/src/lib/PPrint.hs +++ b/src/lib/PPrint.hs @@ -265,6 +265,7 @@ instance IRRep r => PrettyPrec (Stuck r n) where prettyPrec = \case StuckVar v -> atPrec ArgPrec $ p v StuckProject _ i v -> atPrec LowestPrec $ "StuckProject" <+> p i <+> p v + StuckTabApp _ f xs -> atPrec AppPrec $ pArg f <> "." <> pArg xs StuckUnwrap _ v -> atPrec LowestPrec $ "StuckUnwrap" <+> p v InstantiatedGiven _ v args -> atPrec LowestPrec $ "Given" <+> p v <+> p (toList args) SuperclassProj _ d' i -> atPrec LowestPrec $ "SuperclassProj" <+> p d' <+> p i diff --git a/src/lib/QueryTypePure.hs b/src/lib/QueryTypePure.hs index 99b4687b8..258bbb9b3 100644 --- a/src/lib/QueryTypePure.hs +++ b/src/lib/QueryTypePure.hs @@ -102,6 +102,7 @@ instance IRRep r => HasType r (Stuck r) where getType = \case StuckVar (AtomVar _ t) -> t StuckProject t _ _ -> t + StuckTabApp t _ _ -> t StuckUnwrap t _ -> t InstantiatedGiven t _ _ -> t SuperclassProj t _ _ -> t diff --git a/src/lib/Simplify.hs b/src/lib/Simplify.hs index b597ffa9f..0f71998f4 100644 --- a/src/lib/Simplify.hs +++ b/src/lib/Simplify.hs @@ -165,6 +165,14 @@ forceStuck stuck cont = withDistinct case stuck of DepPair l r _ -> forceConstructor ([l, r]!!i) cont _ -> error "Can't project stuck term" _ -> error "Can't project stuck term" + StuckTabApp _ f xs -> do + ty <- substM $ getType stuck + xs' <- forM xs \x -> toDataAtomIgnoreRecon =<< substM x + forceStuck f \case + CCSimpInCore (LiftSimp _ f') -> do + result <- naryTabApp f' (sink<$>xs') + cont $ CCSimpInCore $ LiftSimp (sink ty) result + _ -> error "not a table" -- what about table lambda? StuckUnwrap _ x -> forceStuck x \case CCCon (WithSubst s con) -> withSubst s case con of NewtypeCon _ x' -> forceConstructor x' cont diff --git a/src/lib/Transpose.hs b/src/lib/Transpose.hs index 0cde8255b..75e14ec7c 100644 --- a/src/lib/Transpose.hs +++ b/src/lib/Transpose.hs @@ -313,8 +313,8 @@ transposeAtom atom ct = case atom of return () LinRef ref -> emitCTToRef ref ct LinTrivial -> return () - Stuck (StuckProject _ _ _) -> undefined - -- Stuck (StuckProject _ i' x') -> do + Stuck (StuckProject _ _ _) -> error "not implemented" + Stuck (StuckTabApp _ _ _) -> error "not implemented" -- let (idxs, v) = asNaryProj i' x' -- lookupSubstM (atomVarName v) >>= \case -- RenameNonlin _ -> error "an error, probably" diff --git a/src/lib/Types/Core.hs b/src/lib/Types/Core.hs index 2ef17f6d9..95380bbca 100644 --- a/src/lib/Types/Core.hs +++ b/src/lib/Types/Core.hs @@ -74,6 +74,7 @@ data Type (r::IR) (n::S) where data Stuck (r::IR) (n::S) where StuckVar :: AtomVar r n -> Stuck r n StuckProject :: Type r n -> Int -> Stuck r n -> Stuck r n + StuckTabApp :: Type r n -> Stuck r n -> [Atom r n] -> Stuck r n StuckUnwrap :: CType n -> CStuck n -> Stuck CoreIR n InstantiatedGiven :: CType n -> CStuck n -> [CAtom n] -> Stuck CoreIR n SuperclassProj :: CType n -> Int -> CStuck n -> Stuck CoreIR n @@ -1552,26 +1553,29 @@ instance IRRep r => AlphaHashableE (Atom r) instance IRRep r => RenameE (Atom r) instance IRRep r => GenericE (Stuck r) where - type RepE (Stuck r) = EitherE5 + type RepE (Stuck r) = EitherE6 {- StuckVar -} (AtomVar r) {- StuckProject -} (Type r `PairE` LiftE Int `PairE` Stuck r) + {- StuckTabApp -} (Type r `PairE` Stuck r `PairE` ListE (Atom r)) {- StuckUnwrap -} (WhenCore r (CType `PairE` CStuck)) {- InstantiatedGiven -} (WhenCore r (CType `PairE` CStuck `PairE` ListE CAtom)) {- SuperclassProj -} (WhenCore r (CType `PairE` LiftE Int `PairE` CStuck)) fromE = \case StuckVar v -> Case0 v StuckProject t i e -> Case1 $ t `PairE` LiftE i `PairE` e - StuckUnwrap t e -> Case2 $ WhenIRE $ t `PairE` e - InstantiatedGiven t e xs -> Case3 $ WhenIRE $ t `PairE` e `PairE` ListE xs - SuperclassProj t i e -> Case4 $ WhenIRE $ t `PairE` LiftE i `PairE` e + StuckTabApp t f x -> Case2 $ t `PairE` f `PairE` ListE x + StuckUnwrap t e -> Case3 $ WhenIRE $ t `PairE` e + InstantiatedGiven t e xs -> Case4 $ WhenIRE $ t `PairE` e `PairE` ListE xs + SuperclassProj t i e -> Case5 $ WhenIRE $ t `PairE` LiftE i `PairE` e {-# INLINE fromE #-} toE = \case Case0 v -> StuckVar v Case1 (t `PairE` LiftE i `PairE` e) -> StuckProject t i e - Case2 (WhenIRE (t `PairE` e)) -> StuckUnwrap t e - Case3 (WhenIRE (t `PairE` e `PairE` ListE xs)) -> InstantiatedGiven t e xs - Case4 (WhenIRE (t `PairE` LiftE i `PairE` e)) -> SuperclassProj t i e + Case2 (t `PairE` f `PairE` ListE x) -> StuckTabApp t f x + Case3 (WhenIRE (t `PairE` e)) -> StuckUnwrap t e + Case4 (WhenIRE (t `PairE` e `PairE` ListE xs)) -> InstantiatedGiven t e xs + Case5 (WhenIRE (t `PairE` LiftE i `PairE` e)) -> SuperclassProj t i e _ -> error "impossible" {-# INLINE toE #-} diff --git a/src/lib/Vectorize.hs b/src/lib/Vectorize.hs index f18fd4faf..3be5058b1 100644 --- a/src/lib/Vectorize.hs +++ b/src/lib/Vectorize.hs @@ -536,6 +536,8 @@ vectorizeAtom atom = addVectErrCtx "vectorizeAtom" ("Atom:\n" ++ pprint atom) do _ -> throwVectErr "Invalid projection" x'' <- reduceProj i x' return $ VVal ov x'' + -- TODO: think about this case + StuckTabApp _ _ _ -> throwVectErr $ "Cannot vectorize atom: " ++ pprint atom Con (Lit l) -> return $ VVal Uniform $ Con $ Lit l _ -> do subst <- getSubst From fb544394fc9cfaf3ea5c77e4ec0cfcf5e42cee81 Mon Sep 17 00:00:00 2001 From: Dougal Date: Mon, 23 Oct 2023 22:35:34 -0400 Subject: [PATCH 07/41] Embrace the `Stuck` vs `Con` distinction everywhere. `Atom` and `Type` now have exactly two cases - a "constructor" case, which admits won't-fail pattern-matching once you know something about the term, and a `Stuck` case which includes variables and various sorts of stuck expressions. --- src/lib/Algebra.hs | 9 +- src/lib/Builder.hs | 238 +++++------ src/lib/CheapReduction.hs | 386 ++++++++--------- src/lib/CheckType.hs | 305 +++++++------- src/lib/Core.hs | 37 +- src/lib/Err.hs | 8 +- src/lib/Export.hs | 36 +- src/lib/Generalize.hs | 73 ++-- src/lib/Imp.hs | 171 ++++---- src/lib/Inference.hs | 667 +++++++++++++++++------------ src/lib/Inline.hs | 71 ++-- src/lib/JAX/ToSimp.hs | 12 +- src/lib/Linearize.hs | 74 ++-- src/lib/Lower.hs | 22 +- src/lib/MTL1.hs | 2 +- src/lib/OccAnalysis.hs | 56 +-- src/lib/Optimize.hs | 32 +- src/lib/PPrint.hs | 109 +++-- src/lib/QueryType.hs | 182 ++++---- src/lib/QueryTypePure.hs | 124 +++--- src/lib/RuntimePrint.hs | 80 ++-- src/lib/Simplify.hs | 679 +++++++++++++++--------------- src/lib/Subst.hs | 2 +- src/lib/TopLevel.hs | 10 +- src/lib/Transpose.hs | 80 ++-- src/lib/Types/Core.hs | 858 +++++++++++++++++++------------------- src/lib/Vectorize.hs | 66 +-- 27 files changed, 2198 insertions(+), 2191 deletions(-) diff --git a/src/lib/Algebra.hs b/src/lib/Algebra.hs index a8e125ecd..bf8462a83 100644 --- a/src/lib/Algebra.hs +++ b/src/lib/Algebra.hs @@ -20,6 +20,7 @@ import Data.Tuple (swap) import Builder hiding (sub, add, mul) import Core +import CheapReduction import Err import IRVariants import MTL1 @@ -139,8 +140,8 @@ exprAsPoly expr = liftBuilder $ runMaybeT1 $ runSubstReaderT idSubst $ exprAsPol atomAsPoly :: Atom SimpIR i -> BlockTraverserM i o (Polynomial o) atomAsPoly = \case - Var v -> atomVarAsPoly v - RepValAtom (RepVal _ (Leaf (IVar v' _))) -> impNameAsPoly v' + Stuck _ (Var v) -> atomVarAsPoly v + Stuck _ (RepValAtom (RepVal _ (Leaf (IVar v' _)))) -> impNameAsPoly v' IdxRepVal i -> return $ poly [((fromIntegral i) % 1, mono [])] _ -> empty @@ -207,10 +208,10 @@ emitMonomial :: Emits n => Monomial n -> BuilderM SimpIR n (Atom SimpIR n) emitMonomial (Monomial m) = do varAtoms <- forM (toList m) \(v, e) -> case v of LeftE v' -> do - v'' <- Var <$> toAtomVar v' + v'' <- toAtom <$> toAtomVar v' ipow v'' e RightE v' -> do - let atom = RepValAtom $ RepVal IdxRepTy (Leaf (IVar v' IIdxRepTy)) + atom <- mkStuck $ RepValAtom $ RepVal IdxRepTy (Leaf (IVar v' IIdxRepTy)) ipow atom e foldM imul (IdxRepVal 1) varAtoms diff --git a/src/lib/Builder.hs b/src/lib/Builder.hs index 5574aa44c..7a654911a 100644 --- a/src/lib/Builder.hs +++ b/src/lib/Builder.hs @@ -62,7 +62,7 @@ type Builder2 (r::IR) (m :: MonadKind2) = forall i. Builder r (m type ScopableBuilder2 (r::IR) (m :: MonadKind2) = forall i. ScopableBuilder r (m i) emitDecl :: (Builder r m, Emits n) => NameHint -> LetAnn -> Expr r n -> m n (AtomVar r n) -emitDecl _ _ (Atom (Var n)) = return n +emitDecl _ _ (Atom (Stuck _ (Var n))) = return n emitDecl hint ann expr = rawEmitDecl hint ann expr {-# INLINE emitDecl #-} @@ -82,12 +82,12 @@ emitExpr :: (Builder r m, ToExpr e r, Emits n) => e n -> m n (Atom r n) emitExpr e = case toExpr e of Atom x -> return x Block _ block -> emitDecls block >>= emitExpr - expr -> Var <$> emit expr + expr -> toAtom <$> emit expr {-# INLINE emitExpr #-} emitToVar :: (Builder r m, ToExpr e r, Emits n) => e n -> m n (AtomVar r n) emitToVar e = case toExpr e of - Atom (Var v) -> return v + Atom (Stuck _ (Var v)) -> return v expr -> emit expr {-# INLINE emitToVar #-} @@ -117,7 +117,7 @@ emitDecls' (Nest (Let b (DeclBinding ann expr)) rest) e = do emitExprToAtom :: (Builder r m, Emits n) => Expr r n -> m n (Atom r n) emitExprToAtom (Atom atom) = return atom -emitExprToAtom expr = Var <$> emit expr +emitExprToAtom expr = toAtom <$> emit expr {-# INLINE emitExprToAtom #-} buildScopedAssumeNoDecls :: (SinkableE e, ScopableBuilder r m) @@ -315,13 +315,6 @@ emitTopFunBinding hint def f = do emitSourceMap :: TopBuilder m => SourceMap n -> m n () emitSourceMap sm = emitLocalModuleEnv $ mempty {envSourceMap = sm} -emitSynthCandidates :: TopBuilder m => SynthCandidates n -> m n () -emitSynthCandidates sc = emitLocalModuleEnv $ mempty {envSynthCandidates = sc} - -addInstanceSynthCandidate :: TopBuilder m => ClassName n -> InstanceName n -> m n () -addInstanceSynthCandidate className instanceName = - emitSynthCandidates $ SynthCandidates (M.singleton className [instanceName]) - updateTransposeRelation :: (Mut n, TopBuilder m) => TopFunName n -> TopFunName n -> m n () updateTransposeRelation f1 f2 = updateTopEnv $ ExtendCache $ mempty { transpositionCache = eMapSingleton f1 f2 <> eMapSingleton f2 f1} @@ -709,7 +702,7 @@ buildLamExpr (Abs bs UnitE) cont = case bs of Empty -> LamExpr Empty <$> buildBlock (cont []) Nest b rest -> do Abs b' (LamExpr bs' body') <- buildAbs (getNameHint b) (binderType b) \v -> do - rest' <- applySubst (b@>SubstVal (Var v)) $ EmptyAbs rest + rest' <- applySubst (b@>SubstVal (toAtom v)) $ EmptyAbs rest buildLamExpr rest' \vs -> cont $ sink v : vs return $ LamExpr (Nest b' bs') body' @@ -746,9 +739,9 @@ buildCaseAlts scrut indexedAltBody = do injectAltResult :: EnvReader m => [SType n] -> Int -> Alt SimpIR n -> m n (Alt SimpIR n) injectAltResult sumTys con (Abs b body) = liftBuilder do buildAlt (binderType b) \v -> do - originalResult <- emitExpr =<< applySubst (b@>SubstVal (Var v)) body + originalResult <- emitExpr =<< applySubst (b@>SubstVal (toAtom v)) body (dataResult, nonDataResult) <- fromPairReduced originalResult - return $ PairVal dataResult $ Con $ SumCon (sinkList sumTys) con nonDataResult + return $ toAtom $ ProdCon [dataResult, Con $ SumCon (sinkList sumTys) con nonDataResult] -- TODO: consider a version with nonempty list of alternatives where we figure -- out the result type from one of the alts rather than providing it explicitly @@ -756,20 +749,20 @@ buildCase' :: (Emits n, ScopableBuilder r m) => Atom r n -> Type r n -> (forall l. (Emits l, DExt n l) => Int -> Atom r l -> m l (Atom r l)) -> m n (Expr r n) -buildCase' scrut resultTy indexedAltBody = do - case trySelectBranch scrut of - Just (i, arg) -> do - Distinct <- getDistinct - Atom <$> indexedAltBody i (sink arg) - Nothing -> do - scrutTy <- return $ getType scrut - altBinderTys <- caseAltsBinderTys scrutTy - (alts, effs) <- unzip <$> forM (enumerate altBinderTys) \(i, bTy) -> do - (Abs b' (body `PairE` eff')) <- buildAbs noHint bTy \x -> do - blk <- buildBlock $ indexedAltBody i $ Var $ sink x - return $ blk `PairE` getEffects blk - return (Abs b' body, ignoreHoistFailure $ hoist b' eff') - return $ Case scrut alts $ EffTy (mconcat effs) resultTy +buildCase' scrut resultTy indexedAltBody = case scrut of + Con con -> do + SumCon _ i arg <- return con + Distinct <- getDistinct + Atom <$> indexedAltBody i (sink arg) + Stuck _ _ -> do + scrutTy <- return $ getType scrut + altBinderTys <- caseAltsBinderTys scrutTy + (alts, effs) <- unzip <$> forM (enumerate altBinderTys) \(i, bTy) -> do + (Abs b' (body `PairE` eff')) <- buildAbs noHint bTy \x -> do + blk <- buildBlock $ indexedAltBody i $ toAtom $ sink x + return $ blk `PairE` getEffects blk + return (Abs b' body, ignoreHoistFailure $ hoist b' eff') + return $ Case scrut alts $ EffTy (mconcat effs) resultTy buildCase :: (Emits n, ScopableBuilder r m) => Atom r n -> Type r n @@ -783,8 +776,8 @@ buildEffLam -> (forall l. (Emits l, DExt n l) => AtomVar r l -> AtomVar r l -> m l (Atom r l)) -> m n (LamExpr r n) buildEffLam hint ty body = do - withFreshBinder noHint (TC HeapType) \h -> do - let ty' = RefTy (Var $ binderVar h) (sink ty) + withFreshBinder noHint (TyCon HeapType) \h -> do + let ty' = RefTy (toAtom $ binderVar h) (sink ty) withFreshBinder hint ty' \b -> do let ref = binderVar b hVar <- sinkM $ binderVar h @@ -809,20 +802,14 @@ buildFor :: (Emits n, ScopableBuilder r m) -> m n (Atom r n) buildFor hint dir ty body = buildForAnn hint dir ty body -buildMap :: (Emits n, ScopableBuilder r m) - => Atom r n - -> (forall l. (Emits l, DExt n l) => Atom r l -> m l (Atom r l)) - -> m n (Atom r n) +buildMap :: (Emits n, ScopableBuilder SimpIR m) + => SAtom n + -> (forall l. (Emits l, DExt n l) => SAtom l -> m l (SAtom l)) + -> m n (SAtom n) buildMap xs f = do - TabPi t <- return $ getType xs + TabPi t <- return $ getTyCon xs buildFor noHint Fwd (tabIxType t) \i -> - tabApp (sink xs) (Var i) >>= f - -unzipTab :: (Emits n, Builder r m) => Atom r n -> m n (Atom r n, Atom r n) -unzipTab tab = do - fsts <- liftEmitBuilder $ buildMap tab getFst - snds <- liftEmitBuilder $ buildMap tab getSnd - return (fsts, snds) + tabApp (sink xs) (toAtom i) >>= f emitRunWriter :: (Emits n, ScopableBuilder r m) @@ -882,12 +869,12 @@ buildRememberDest hint dest cont = do zeroAt :: (Emits n, SBuilder m) => SType n -> m n (SAtom n) zeroAt ty = liftEmitBuilder $ go ty where go :: Emits n => SType n -> BuilderM SimpIR n (SAtom n) - go = \case - BaseTy bt -> return $ Con $ Lit $ zeroLit bt - ProdTy tys -> ProdVal <$> mapM go tys - TabPi tabPi -> buildFor (getNameHint tabPi) Fwd (tabIxType tabPi) \i -> - go =<< instantiate (sink tabPi) [Var i] - _ -> unreachable + go (TyCon con) = case con of + BaseType bt -> return $ Con $ Lit $ zeroLit bt + ProdType tys -> toAtom . ProdCon <$> mapM go tys + TabPi tabPi -> buildFor (getNameHint tabPi) Fwd (tabIxType tabPi) \i -> + go =<< instantiate (sink tabPi) [toAtom i] + _ -> unreachable zeroLit bt = case bt of Scalar Float64Type -> Float64Lit 0.0 Scalar Float32Type -> Float32Lit 0.0 @@ -908,15 +895,15 @@ maybeTangentType ty = liftEnvReaderT $ maybeTangentType' ty maybeTangentType' :: IRRep r => Type r n -> EnvReaderT Maybe n (Type r n) maybeTangentType' ty = case ty of - TabTy d b bodyTy -> do - refreshAbs (Abs b bodyTy) \b' bodyTy' -> do - bodyTanTy <- rec bodyTy' - return $ TabTy d b' bodyTanTy - TC con -> case con of - BaseType (Scalar Float64Type) -> return $ TC con - BaseType (Scalar Float32Type) -> return $ TC con + TyCon con -> case con of + TabPi (TabPiType d b bodyTy) -> do + refreshAbs (Abs b bodyTy) \b' bodyTy' -> do + bodyTanTy <- rec bodyTy' + return $ TabTy d b' bodyTanTy + BaseType (Scalar Float64Type) -> return $ toType con + BaseType (Scalar Float32Type) -> return $ toType con BaseType _ -> return $ UnitTy - ProdType tys -> ProdTy <$> traverse rec tys + ProdType tys -> toType . ProdType <$> traverse rec tys _ -> empty _ -> empty where rec = maybeTangentType' @@ -924,52 +911,49 @@ maybeTangentType' ty = case ty of tangentBaseMonoidFor :: (Emits n, SBuilder m) => SType n -> m n (BaseMonoid SimpIR n) tangentBaseMonoidFor ty = do zero <- zeroAt ty - adder <- liftBuilder $ buildBinaryLamExpr (noHint, ty) (noHint, ty) \x y -> addTangent (Var x) (Var y) + adder <- liftBuilder $ buildBinaryLamExpr (noHint, ty) (noHint, ty) \x y -> + addTangent (toAtom x) (toAtom y) return $ BaseMonoid zero adder addTangent :: (Emits n, SBuilder m) => SAtom n -> SAtom n -> m n (SAtom n) addTangent x y = do - case getType x of + case getTyCon x of + BaseType (Scalar _) -> emitExpr $ BinOp FAdd x y + ProdType _ -> do + xs <- getUnpacked x + ys <- getUnpacked y + toAtom . ProdCon <$> zipWithM addTangent xs ys TabPi t -> liftEmitBuilder $ buildFor (getNameHint t) Fwd (tabIxType t) \i -> do - bindM2 addTangent (tabApp (sink x) (Var i)) (tabApp (sink y) (Var i)) - TC con -> case con of - BaseType (Scalar _) -> emitExpr $ BinOp FAdd x y - ProdType _ -> do - xs <- getUnpacked x - ys <- getUnpacked y - ProdVal <$> zipWithM addTangent xs ys - ty -> notTangent ty + bindM2 addTangent (tabApp (sink x) (toAtom i)) (tabApp (sink y) (toAtom i)) ty -> notTangent ty where notTangent ty = error $ "Not a tangent type: " ++ pprint ty symbolicTangentTy :: (EnvReader m, Fallible1 m) => CType n -> m n (CType n) symbolicTangentTy elTy = lookupSourceMap "SymbolicTangent" >>= \case Just (UTyConVar symTanName) -> do - return $ TypeCon "SymbolicTangent" symTanName $ - TyConParams [Explicit] [Type elTy] + return $ toType $ UserADTType "SymbolicTangent" symTanName $ + TyConParams [Explicit] [toAtom elTy] Nothing -> throw UnboundVarErr $ "Can't define a custom linearization with symbolic zeros: " ++ "the SymbolicTangent type is not in scope." Just _ -> throw TypeErr "SymbolicTangent should name a `data` type" symbolicTangentZero :: EnvReader m => SType n -> m n (SAtom n) -symbolicTangentZero argTy = return $ SumVal [UnitTy, argTy] 0 UnitVal +symbolicTangentZero argTy = return $ toAtom $ SumCon [UnitTy, argTy] 0 UnitVal symbolicTangentNonZero :: EnvReader m => SAtom n -> m n (SAtom n) symbolicTangentNonZero val = do ty <- return $ getType val - return $ SumVal [UnitTy, ty] 1 val + return $ toAtom $ SumCon [UnitTy, ty] 1 val -- === builder versions of common local ops === -fLitLike :: (Builder r m, Emits n) => Double -> Atom r n -> m n (Atom r n) -fLitLike x t = do - ty <- return $ getType t - case ty of - BaseTy (Scalar Float64Type) -> return $ Con $ Lit $ Float64Lit x - BaseTy (Scalar Float32Type) -> return $ Con $ Lit $ Float32Lit $ realToFrac x - _ -> error "Expected a floating point scalar" +fLitLike :: (SBuilder m, Emits n) => Double -> SAtom n -> m n (SAtom n) +fLitLike x t = case getTyCon t of + BaseType (Scalar Float64Type) -> return $ toAtom $ Lit $ Float64Lit x + BaseType (Scalar Float32Type) -> return $ toAtom $ Lit $ Float32Lit $ realToFrac x + _ -> error "Expected a floating point scalar" neg :: (Builder r m, Emits n) => Atom r n -> m n (Atom r n) neg x = emitExpr $ UnOp FNeg x @@ -1054,15 +1038,15 @@ getUnpacked atom = forM (productIdxs atom) \i -> proj i atom productIdxs :: IRRep r => Atom r n -> [Int] productIdxs atom = let positions = case getType atom of - ProdTy tys -> void tys - DepPairTy _ -> [(), ()] + TyCon (ProdType tys) -> void tys + TyCon (DepPairTy _) -> [(), ()] ty -> error $ "not a product type: " ++ pprint ty in fst <$> enumerate positions unwrapNewtype :: (Emits n, Builder CoreIR m) => CAtom n -> m n (CAtom n) -unwrapNewtype (NewtypeCon _ x) = return x +unwrapNewtype (Con (NewtypeCon _ x)) = return x unwrapNewtype x = case getType x of - NewtypeTyCon con -> do + TyCon (NewtypeTyCon con) -> do (_, ty) <- unwrapNewtypeType con emitExpr $ Unwrap ty x _ -> error "not a newtype" @@ -1070,9 +1054,11 @@ unwrapNewtype x = case getType x of proj ::(Builder r m, Emits n) => Int -> Atom r n -> m n (Atom r n) proj i = \case - ProdVal xs -> return $ xs !! i - DepPair l _ _ | i == 0 -> return l - DepPair _ r _ | i == 1 -> return r + Con con -> case con of + ProdCon xs -> return $ xs !! i + DepPair l _ _ | i == 0 -> return l + DepPair _ r _ | i == 1 -> return r + _ -> error "not a product" x -> do ty <- projType i x emitExpr $ Project ty i x @@ -1097,7 +1083,7 @@ projectStructRef i x = do {-# INLINE projectStructRef #-} getStructProjections :: EnvReader m => Int -> CType n -> m n [Projection] -getStructProjections i (NewtypeTyCon (UserADTType _ tyConName _)) = do +getStructProjections i (TyCon (NewtypeTyCon (UserADTType _ tyConName _))) = do TyConDef _ _ _ ~(StructFields fields) <- lookupTyCon tyConName return case fields of [_] | i == 0 -> [UnwrapNewtype] @@ -1145,7 +1131,7 @@ mkApp f xs = do et <- appEffTy (getType f) xs return $ App et f xs -mkTabApp :: (EnvReader m, IRRep r) => Atom r n -> [Atom r n] -> m n (Expr r n) +mkTabApp :: (EnvReader m, IRRep r) => Atom r n -> Atom r n -> m n (Expr r n) mkTabApp xs ixs = do ty <- typeOfTabApp (getType xs) ixs return $ TabApp ty xs ixs @@ -1155,28 +1141,18 @@ mkTopApp f xs = do resultTy <- typeOfTopApp f xs return $ TopApp resultTy f xs -mkApplyMethod :: EnvReader m => CAtom n -> Int -> [CAtom n] -> m n (CExpr n) +mkApplyMethod :: EnvReader m => CDict n -> Int -> [CAtom n] -> m n (CExpr n) mkApplyMethod d i xs = do resultTy <- typeOfApplyMethod d i xs - return $ ApplyMethod resultTy d i xs - -mkIxFin :: (EnvReader m, Fallible1 m) => CAtom n -> m n (DictCon n) -mkIxFin n = do - dictTy <- liftM DictTy $ ixDictType $ NewtypeTyCon $ Fin n - return $ IxFin dictTy n + return $ ApplyMethod resultTy (toAtom d) i xs -mkDataData :: (EnvReader m, Fallible1 m) => CType n -> m n (DictCon n) -mkDataData dataTy = do - dictTy <- DictTy <$> dataDictType dataTy - return $ DataData dictTy dataTy - -mkInstanceDict :: EnvReader m => InstanceName n -> [CAtom n] -> m n (DictCon n) +mkInstanceDict :: EnvReader m => InstanceName n -> [CAtom n] -> m n (CDict n) mkInstanceDict instanceName args = do instanceDef@(InstanceDef className _ _ _ _) <- lookupInstanceDef instanceName sourceName <- getSourceName <$> lookupClassDef className PairE (ListE params) _ <- instantiate instanceDef args - let ty = DictTy $ DictType sourceName className params - return $ InstanceDict ty instanceName args + let ty = toType $ DictType sourceName className params + return $ toDict $ InstanceDict ty instanceName args mkCase :: (EnvReader m, IRRep r) => Atom r n -> Type r n -> [Alt r n] -> m n (Expr r n) mkCase scrut resultTy alts = liftEnvReaderM do @@ -1210,21 +1186,18 @@ naryTopAppInlined f xs = do naryAppHinted :: (CBuilder m, Emits n) => NameHint -> CAtom n -> [CAtom n] -> m n (CAtom n) -naryAppHinted hint f xs = Var <$> (mkApp f xs >>= emitHinted hint) +naryAppHinted hint f xs = toAtom <$> (mkApp f xs >>= emitHinted hint) tabApp :: (Builder r m, Emits n) => Atom r n -> Atom r n -> m n (Atom r n) -tabApp x i = mkTabApp x [i] >>= emitExpr +tabApp x i = mkTabApp x i >>= emitExpr naryTabApp :: (Builder r m, Emits n) => Atom r n -> [Atom r n] -> m n (Atom r n) -naryTabApp = naryTabAppHinted noHint +naryTabApp f [] = return f +naryTabApp f (x:xs) = do + ans <- mkTabApp f x >>= emitExpr + naryTabApp ans xs {-# INLINE naryTabApp #-} -naryTabAppHinted :: (Builder r m, Emits n) - => NameHint -> Atom r n -> [Atom r n] -> m n (Atom r n) -naryTabAppHinted hint f xs = do - expr <- mkTabApp f xs - Var <$> emitHinted hint expr - indexRef :: (Builder r m, Emits n) => Atom r n -> Atom r n -> m n (Atom r n) indexRef ref i = emitExpr =<< mkIndexRef ref i @@ -1254,13 +1227,13 @@ mkProjRef ref i = do -- === index set type class === applyIxMethod :: (SBuilder m, Emits n) => IxDict SimpIR n -> IxMethod -> [SAtom n] -> m n (SAtom n) -applyIxMethod dict method args = case dict of +applyIxMethod (DictCon dict) method args = case dict of -- These cases are used in SimpIR and they work with IdxRepVal - IxDictRawFin n -> case method of + IxRawFin n -> case method of Size -> do [] <- return args; return n Ordinal -> do [i] <- return args; return i UnsafeFromOrdinal -> do [i] <- return args; return i - IxDictSpecialized _ d params -> do + IxSpecialized d params -> do SpecializedDict _ maybeFs <- lookupSpecDict d Just fs <- return maybeFs instantiate (fs !! fromEnum method) (params ++ args) >>= emitExpr @@ -1277,9 +1250,8 @@ indexSetSize (IxType _ dict) = applyIxMethod dict Size [] -- === core versions of index set type class === applyIxMethodCore :: (CBuilder m, Emits n) => IxMethod -> IxType CoreIR n -> [CAtom n] -> m n (CAtom n) -applyIxMethodCore method (IxType _ (IxDictAtom dict)) args = do +applyIxMethodCore method (IxType _ dict) args = emitExpr =<< mkApplyMethod dict (fromEnum method) args -applyIxMethodCore _ _ _ = error "not an ix type" -- === pseudo-prelude === @@ -1295,7 +1267,7 @@ emitIf :: (Emits n, ScopableBuilder r m) -> (forall l. (Emits l, DExt n l) => m l (Atom r l)) -> m n (Atom r n) emitIf predicate resultTy trueCase falseCase = do - predicate' <- emitExpr $ MiscOp $ ToEnum (SumTy [UnitTy, UnitTy]) predicate + predicate' <- emitExpr $ ToEnum (TyCon (SumType [UnitTy, UnitTy])) predicate buildCase predicate' resultTy \i _ -> case i of 0 -> falseCase @@ -1328,31 +1300,31 @@ isJustE x = liftEmitBuilder $ emitMaybeCase x BoolTy (return FalseAtom) (\_ -> return TrueAtom) -- Monoid a -> (n=>a) -> a -reduceE :: (Emits n, Builder r m) => BaseMonoid r n -> Atom r n -> m n (Atom r n) +reduceE :: (Emits n, SBuilder m) => BaseMonoid SimpIR n -> SAtom n -> m n (SAtom n) reduceE monoid xs = liftEmitBuilder do - TabPi tabPi <- return $ getType xs + TabPi tabPi <- return $ getTyCon xs let a = assumeConst tabPi getSnd =<< emitRunWriter noHint a monoid \_ ref -> buildFor noHint Fwd (sink $ tabIxType tabPi) \i -> do - x <- tabApp (sink xs) (Var i) - emitExpr $ PrimOp $ RefOp (sink $ Var ref) $ MExtend (sink monoid) x + x <- tabApp (sink xs) (toAtom i) + emitExpr $ PrimOp $ RefOp (sink $ toAtom ref) $ MExtend (sink monoid) x andMonoid :: (EnvReader m, IRRep r) => m n (BaseMonoid r n) andMonoid = liftM (BaseMonoid TrueAtom) $ liftBuilder $ buildBinaryLamExpr (noHint, BoolTy) (noHint, BoolTy) \x y -> - emitExpr $ BinOp BAnd (sink $ Var x) (Var y) + emitExpr $ BinOp BAnd (sink $ toAtom x) (toAtom y) -- (a-> {|eff} b) -> n=>a -> {|eff} (n=>b) -mapE :: (Emits n, ScopableBuilder r m) - => (forall l. (Emits l, DExt n l) => Atom r l -> m l (Atom r l)) - -> Atom r n -> m n (Atom r n) +mapE :: (Emits n, ScopableBuilder SimpIR m) + => (forall l. (Emits l, DExt n l) => SAtom l -> m l (SAtom l)) + -> SAtom n -> m n (SAtom n) mapE cont xs = do - TabPi tabPi <- return $ getType xs + TabPi tabPi <- return $ getTyCon xs buildFor (getNameHint tabPi) Fwd (tabIxType tabPi) \i -> do - tabApp (sink xs) (Var i) >>= cont + tabApp (sink xs) (toAtom i) >>= cont -- (n:Type) ?-> (a:Type) ?-> (xs : n=>Maybe a) : Maybe (n => a) = -catMaybesE :: (Emits n, Builder r m) => Atom r n -> m n (Atom r n) +catMaybesE :: (Emits n, SBuilder m) => SAtom n -> m n (SAtom n) catMaybesE maybes = do TabTy d n (MaybeTy a) <- return $ getType maybes justs <- liftEmitBuilder $ mapE isJustE maybes @@ -1391,7 +1363,7 @@ runMaybeWhile body = do emitWhile do ans <- body emitMaybeCase ans Word8Ty - (emit (PrimOp $ RefOp (sink $ Var ref) $ MPut TrueAtom) >> return FalseAtom) + (emit (toExpr $ RefOp (sink $ toAtom ref) $ MPut TrueAtom) >> return FalseAtom) (return) return UnitVal emitIf hadError (MaybeTy UnitTy) @@ -1456,7 +1428,7 @@ telescopicCapture bs e = do let vsTysSorted = toposortAnnVars $ zip vs vTys let vsSorted = map fst vsTysSorted ty <- liftEnvReaderM $ buildTelescopeTy vsTysSorted - valsSorted <- forM vsSorted \v -> Var <$> toAtomVar v + valsSorted <- forM vsSorted \v -> toAtom <$> toAtomVar v result <- buildTelescopeVal valsSorted ty reconAbs <- return $ ignoreHoistFailure $ hoist bs do case abstractFreeVarsNoAnn vsSorted e of @@ -1496,19 +1468,19 @@ buildTelescopeVal xsTop tyTop = fst <$> go tyTop xsTop where go ty rest = case ty of ProdTelescope tys -> do (xs, rest') <- return $ splitAt (length tys) rest - return (ProdVal xs, rest') + return (toAtom $ ProdCon xs, rest') DepTelescope ty1 (Abs b ty2) -> do (x1, ~(xDep : rest')) <- go ty1 rest ty2' <- applySubst (b@>SubstVal xDep) ty2 (x2, rest'') <- go ty2' rest' let depPairTy = DepPairType ExplicitDepPair b (telescopeTypeType ty2) - return (PairVal x1 (DepPair xDep x2 depPairTy), rest'') + return (toAtom $ ProdCon [x1, toAtom $ DepPair xDep x2 depPairTy], rest'') telescopeTypeType :: TelescopeType (AtomNameC r) (Type r) n -> Type r n -telescopeTypeType (ProdTelescope tys) = ProdTy tys +telescopeTypeType (ProdTelescope tys) = toType $ ProdType tys telescopeTypeType (DepTelescope lhs (Abs b rhs)) = do let lhs' = telescopeTypeType lhs - let rhs' = DepPairTy (DepPairType ExplicitDepPair b (telescopeTypeType rhs)) + let rhs' = toType $ DepPairTy (DepPairType ExplicitDepPair b (telescopeTypeType rhs)) PairTy lhs' rhs' unpackTelescope diff --git a/src/lib/CheapReduction.hs b/src/lib/CheapReduction.hs index 03d60730b..8fc403f03 100644 --- a/src/lib/CheapReduction.hs +++ b/src/lib/CheapReduction.hs @@ -9,18 +9,17 @@ module CheapReduction ( reduceWithDecls, reduceExpr - , instantiateTyConDef, dataDefRep, unwrapNewtypeType + , instantiateTyConDef, dataDefRep, unwrapNewtypeType, projType , NonAtomRenamer (..), Visitor (..), VisitGeneric (..) - , visitAtomPartial, visitTypePartial, visitAtomDefault, visitTypeDefault, Visitor2 + , visitAtomDefault, visitTypeDefault, Visitor2, mkStuck, mkStuckTy , visitBinders, visitPiDefault, visitAlt, toAtomVar, instantiate, withInstantiated , bindersToVars, bindersToAtoms, instantiateNames, withInstantiatedNames, assumeConst - , repValAtom, projType, reduceUnwrap, reduceProj, reduceSuperclassProj - , reduceInstantiateGiven, typeOfApp) + , repValAtom, reduceUnwrap, reduceProj, reduceSuperclassProj, typeOfApp + , reduceInstantiateGiven, queryStuckType, substMStuck, reduceTabApp, substStuck) where import Control.Applicative import Control.Monad.Writer.Strict hiding (Alt) -import Control.Monad.Reader import Data.Functor ((<&>)) import Data.Maybe (fromJust) @@ -56,15 +55,17 @@ reduceExpr :: (IRRep r, EnvReader m) => Expr r n -> m n (Maybe (Atom r n)) reduceExpr e = liftReducerM $ reduceExprM e {-# INLINE reduceExpr #-} +-- TODO: just let the caller use `liftReducerM` themselves directly? + reduceProj :: (IRRep r, EnvReader m) => Int -> Atom r n -> m n (Atom r n) reduceProj i x = liftM fromJust $ liftReducerM $ reduceProjM i x {-# INLINE reduceProj #-} -reduceUnwrap :: (IRRep r, EnvReader m) => Atom r n -> m n (Atom r n) +reduceUnwrap :: EnvReader m => CAtom n -> m n (CAtom n) reduceUnwrap x = liftM fromJust $ liftReducerM $ reduceUnwrapM x {-# INLINE reduceUnwrap #-} -reduceSuperclassProj :: EnvReader m => Int -> CAtom n -> m n (CAtom n) +reduceSuperclassProj :: EnvReader m => Int -> CDict n -> m n (CAtom n) reduceSuperclassProj i x = liftM fromJust $ liftReducerM $ reduceSuperclassProjM i x {-# INLINE reduceSuperclassProj #-} @@ -72,6 +73,10 @@ reduceInstantiateGiven :: EnvReader m => CAtom n -> [CAtom n] -> m n (CAtom n) reduceInstantiateGiven f xs = liftM fromJust $ liftReducerM $ reduceInstantiateGivenM f xs {-# INLINE reduceInstantiateGiven #-} +reduceTabApp :: (IRRep r, EnvReader m) => Atom r n -> Atom r n -> m n (Atom r n) +reduceTabApp f x = liftM fromJust $ liftReducerM $ reduceTabAppM f x +{-# INLINE reduceTabApp #-} + -- === internal === type ReducerM = SubstReaderT AtomSubstVal (EnvReaderT FallibleM) @@ -98,7 +103,7 @@ reduceExprM = \case explicitArgs' <- mapM substM explicitArgs dict' <- substM dict case dict' of - DictCon (InstanceDict _ instanceName args) -> dropSubst do + Con (DictConAtom (InstanceDict _ instanceName args)) -> dropSubst do def <- lookupInstanceDef instanceName withInstantiated def args \(PairE _ (InstanceBody _ methods)) -> do reduceApp (methods !! i) explicitArgs' @@ -107,15 +112,13 @@ reduceExprM = \case ty <- substM ty' val <- substM val' case (ty, val) of - (BaseTy (Scalar Word32Type), Con (Lit (Word64Lit v))) -> return $ Con $ Lit $ Word32Lit $ fromIntegral v + (TyCon (BaseType (Scalar Word32Type)), Con (Lit (Word64Lit v))) -> + return $ Con $ Lit $ Word32Lit $ fromIntegral v _ -> empty - TabApp ty tab xs -> do - ty' <- substM ty - xs' <- mapM substM xs + TabApp _ tab x -> do + x' <- substM x tab' <- substM tab - case tab' of - Stuck tab'' -> return $ Stuck $ StuckTabApp ty' tab'' xs' - _ -> error "not a table" -- what about RepVal? + reduceTabAppM tab' x' TopApp _ _ _ -> empty Case _ _ _ -> empty TabCon _ _ _ -> empty @@ -125,92 +128,112 @@ reduceApp :: CAtom i -> [CAtom o] -> ReducerM i o (CAtom o) reduceApp f xs = do f' <- substM f -- TODO: avoid double-subst case f' of - Lam lam -> dropSubst $ withInstantiated lam xs \body -> reduceExprM body - -- TODO: check ultrapure - Var v -> lookupAtomName (atomVarName v) >>= \case - LetBound (DeclBinding _ (Atom f'')) -> dropSubst $ reduceApp f'' xs - _ -> empty + Con (Lam lam) -> dropSubst $ withInstantiated lam xs \body -> reduceExprM body _ -> empty reduceProjM :: IRRep r => Int -> Atom r o -> ReducerM i o (Atom r o) reduceProjM i x = case x of - ProdVal xs -> return $ xs !! i - DepPair l _ _ | i == 0 -> return l - DepPair _ r _ | i == 1 -> return r - SimpInCore (LiftSimp _ simpAtom) -> do - simpAtom' <- dropSubst $ reduceProjM i simpAtom - resultTy <- getResultType - return $ SimpInCore $ LiftSimp resultTy simpAtom' - RepValAtom (RepVal _ tree) -> case tree of - Branch trees -> do - resultTy <- getResultType - repValAtom $ RepVal resultTy (trees!!i) - Leaf _ -> error "unexpected leaf" - Stuck e -> do - resultTy <- getResultType - return $ Stuck $ StuckProject resultTy i e - _ -> empty - where getResultType = projType i x - -reduceSuperclassProjM :: Int -> CAtom o -> ReducerM i o (CAtom o) + Con con -> case con of + ProdCon xs -> return $ xs !! i + DepPair l _ _ | i == 0 -> return l + DepPair _ r _ | i == 1 -> return r + _ -> error "not a product" + Stuck _ e -> mkStuck $ StuckProject i e + +reduceSuperclassProjM :: Int -> CDict o -> ReducerM i o (CAtom o) reduceSuperclassProjM superclassIx dict = case dict of DictCon (InstanceDict _ instanceName args) -> dropSubst do args' <- mapM substM args InstanceDef _ _ bs _ body <- lookupInstanceDef instanceName let InstanceBody superclasses _ = body instantiate (Abs bs (superclasses !! superclassIx)) args' - Stuck child' -> do - resultTy <- superclassProjType superclassIx (getType dict) - return $ Stuck $ SuperclassProj resultTy superclassIx child' + StuckDict _ child -> mkStuck $ SuperclassProj superclassIx child _ -> error "invalid superclass projection" reduceInstantiateGivenM :: CAtom o -> [CAtom o] -> ReducerM i o (CAtom o) reduceInstantiateGivenM f xs = case f of - Lam lam -> dropSubst $ withInstantiated lam xs \body -> reduceExprM body - Stuck f' -> do - resultTy <- typeOfApp (getType f) xs - return $ Stuck $ InstantiatedGiven resultTy f' xs + Con (Lam lam) -> dropSubst $ withInstantiated lam xs \body -> reduceExprM body + Stuck _ f' -> mkStuck $ InstantiatedGiven f' xs _ -> error "bad instantiation" +mkStuck:: (IRRep r, EnvReader m) => Stuck r n -> m n (Atom r n) +mkStuck x = do + ty <- queryStuckType x + return $ Stuck ty x + +mkStuckTy :: EnvReader m => CStuck n -> m n (CType n) +mkStuckTy x = do + ty <- queryStuckType x + return $ StuckTy ty x + +queryStuckType :: (IRRep r, EnvReader m) => Stuck r n -> m n (Type r n) +queryStuckType = \case + Var v -> return $ getType v + StuckProject i s -> projType i =<< mkStuck s + StuckTabApp f x -> do + f' <- mkStuck f + typeOfTabApp (getType f') x + PtrVar t _ -> return $ PtrTy t + RepValAtom repVal -> return $ getType repVal + StuckUnwrap s -> queryStuckType s >>= \case + TyCon (NewtypeTyCon con) -> snd <$> unwrapNewtypeType con + _ -> error "not a newtype" + InstantiatedGiven _ _ -> undefined + SuperclassProj i s -> superclassProjType i =<< queryStuckType s + LiftSimp t _ -> return t + LiftSimpFun t _ -> return $ TyCon $ Pi t + -- TabLam and ACase are just defunctionalization tools. The result type + -- in both cases should *not* be `Data`. + TabLam _ -> undefined + ACase _ _ resultTy -> return resultTy + projType :: (IRRep r, EnvReader m) => Int -> Atom r n -> m n (Type r n) projType i x = case getType x of - ProdTy xs -> return $ xs !! i - DepPairTy t | i == 0 -> return $ depPairLeftTy t - DepPairTy t | i == 1 -> do - liftReducerM (reduceProjM 0 x) >>= \case - Just xFst -> instantiate t [xFst] - Nothing -> err + TyCon con -> case con of + ProdType xs -> return $ xs !! i + DepPairTy t | i == 0 -> return $ depPairLeftTy t + DepPairTy t | i == 1 -> do + liftReducerM (reduceProjM 0 x) >>= \case + Just xFst -> instantiate t [xFst] + Nothing -> err + _ -> err _ -> err where err = error $ "Can't project type: " ++ pprint (getType x) superclassProjType :: EnvReader m => Int -> CType n -> m n (CType n) -superclassProjType i (DictTy (DictType _ className params)) = do - ClassDef _ _ _ _ bs superclasses _ <- lookupClassDef className - instantiate (Abs bs $ getSuperclassType REmpty superclasses i) params +superclassProjType i (TyCon (DictTy dictTy)) = case dictTy of + DictType _ className params -> do + ClassDef _ _ _ _ _ bs superclasses _ <- lookupClassDef className + instantiate (Abs bs $ getSuperclassType REmpty superclasses i) params + IxDictType t | i == 0 -> return $ toType $ DataDictType t + _ -> error "bad superclass projection" superclassProjType _ _ = error "bad superclass projection" +typeOfTabApp :: (IRRep r, EnvReader m) => Type r n -> Atom r n -> m n (Type r n) +typeOfTabApp (TyCon (TabPi piTy)) x = withSubstReaderT $ + withInstantiated piTy [x] \ty -> substM ty +typeOfTabApp _ _ = error "expected a TabPi type" + typeOfApp :: (IRRep r, EnvReader m) => Type r n -> [Atom r n] -> m n (Type r n) -typeOfApp (Pi piTy) xs = withSubstReaderT $ +typeOfApp (TyCon (Pi piTy)) xs = withSubstReaderT $ withInstantiated piTy xs \(EffTy _ ty) -> substM ty typeOfApp _ _ = error "expected a pi type" -typeOfTabApp :: (IRRep r, EnvReader m) => Type r n -> [Atom r n] -> m n (Type r n) -typeOfTabApp (TabPi piTy) xs = withSubstReaderT $ - withInstantiated piTy xs \ty -> substM ty -typeOfTabApp _ _ = error "expected a TabPi type" - -repValAtom :: EnvReader m => SRepVal n -> m n (SAtom n) +repValAtom :: EnvReader m => RepVal n -> m n (SAtom n) repValAtom (RepVal ty tree) = case ty of - ProdTy ts -> case tree of - Branch trees -> ProdVal <$> mapM repValAtom (zipWith RepVal ts trees) + TyCon (ProdType ts) -> case tree of + Branch trees -> toAtom <$> ProdCon <$> mapM repValAtom (zipWith RepVal ts trees) _ -> malformed - BaseTy _ -> case tree of + TyCon (BaseType _) -> case tree of Leaf x -> case x of - ILit l -> return $ Con $ Lit l + ILit l -> return $ toAtom $ Lit l _ -> fallback _ -> malformed + -- TODO: make sure this covers all the cases. Maybe only TabPi should hit the + -- fallback? This could be a place where we accidentally violate the `Stuck` + -- assumption _ -> fallback - where fallback = return $ RepValAtom $ RepVal ty tree + where fallback = return $ Stuck ty $ RepValAtom $ RepVal ty tree malformed = error "malformed repval" {-# INLINE repValAtom #-} @@ -218,24 +241,16 @@ depPairLeftTy :: DepPairType r n -> Type r n depPairLeftTy (DepPairType _ (_:>ty) _) = ty {-# INLINE depPairLeftTy #-} -reduceUnwrapM :: IRRep r => Atom r o -> ReducerM i o (Atom r o) +reduceUnwrapM :: CAtom o -> ReducerM i o (CAtom o) reduceUnwrapM = \case - NewtypeCon _ x -> return x - SimpInCore (LiftSimp (NewtypeTyCon t) x) -> do - t' <- snd <$> unwrapNewtypeType t - return $ SimpInCore $ LiftSimp t' x - Stuck e -> case getType e of - NewtypeTyCon t -> do - t' <- snd <$> unwrapNewtypeType t - return $ Stuck $ StuckUnwrap t' e - _ -> error "expected a newtype" - _ -> empty - -reduceTabAppM :: IRRep r => Atom r o -> [Atom r o] -> ReducerM i o (Atom r o) -reduceTabAppM tab xs = case tab of - Stuck tab' -> do - ty <- typeOfTabApp (getType tab') xs - return $ Stuck $ StuckTabApp ty tab' xs + Con con -> case con of + NewtypeCon _ x -> return x + _ -> error "not a newtype" + Stuck _ e -> mkStuck $ StuckUnwrap e + +reduceTabAppM :: IRRep r => Atom r o -> Atom r o -> ReducerM i o (Atom r o) +reduceTabAppM tab x = case tab of + Stuck _ tab' -> mkStuck (StuckTabApp tab' x) _ -> error $ "not a table" ++ pprint tab unwrapNewtypeType :: EnvReader m => NewtypeTyCon n -> m n (NewtypeCon n, Type CoreIR n) @@ -294,10 +309,10 @@ dataDefRep :: DataConDefs n -> CType n dataDefRep (ADTCons cons) = case cons of [] -> error "unreachable" -- There's no representation for a void type [DataConDef _ _ ty _] -> ty - tys -> SumTy $ tys <&> \(DataConDef _ _ ty _) -> ty + tys -> toType $ SumType $ tys <&> \(DataConDef _ _ ty _) -> ty dataDefRep (StructFields fields) = case map snd fields of [ty] -> ty - tys -> ProdTy tys + tys -> toType (ProdType tys) -- === traversable terms === @@ -336,20 +351,19 @@ traverseOpTerm => e r i -> m (e r o) traverseOpTerm e = traverseOp e visitGeneric visitGeneric visitGeneric -visitAtomDefault - :: (IRRep r, Visitor (m i o) r i o, AtomSubstReader v m, EnvReader2 m) - => Atom r i -> m i o (Atom r o) -visitAtomDefault atom = case atom of - Stuck _ -> atomSubstM atom - SimpInCore _ -> atomSubstM atom - _ -> visitAtomPartial atom - visitTypeDefault :: (IRRep r, Visitor (m i o) r i o, AtomSubstReader v m, EnvReader2 m) => Type r i -> m i o (Type r o) visitTypeDefault ty = case ty of - StuckTy _ -> atomSubstM ty - x -> visitTypePartial x + StuckTy _ _ -> atomSubstM ty + TyCon con -> TyCon <$> visitGeneric con + +visitAtomDefault + :: (IRRep r, Visitor (m i o) r i o, AtomSubstReader v m, EnvReader2 m) + => Atom r i -> m i o (Atom r o) +visitAtomDefault ty = case ty of + Stuck _ _ -> atomSubstM ty + Con con -> Con <$> visitGeneric con visitPiDefault :: (Visitor2 m r, IRRep r, FromName v, AtomSubstReader v m, EnvExtender2 m) @@ -372,44 +386,11 @@ visitBinders (Nest (b:>ty) bs) cont = do visitBinders bs \bs' -> cont $ Nest b' bs' --- XXX: This doesn't handle the `Stuck` or `SimpInCore` cases. These should be --- handled explicitly beforehand. TODO: split out these cases under a separate --- constructor, perhaps even a `hole` paremeter to `Atom` or part of `IR`. -visitAtomPartial :: (IRRep r, Visitor m r i o) => Atom r i -> m (Atom r o) -visitAtomPartial = \case - Stuck _ -> error "Not handled generically" - SimpInCore _ -> error "Not handled generically" - Con con -> Con <$> visitGeneric con - PtrVar t v -> PtrVar t <$> renameN v - DepPair x y t -> do - x' <- visitGeneric x - y' <- visitGeneric y - ~(DepPairTy t') <- visitGeneric $ DepPairTy t - return $ DepPair x' y' t' - Lam lam -> Lam <$> visitGeneric lam - Eff eff -> Eff <$> visitGeneric eff - DictCon d -> DictCon <$> visitGeneric d - NewtypeCon con x -> NewtypeCon <$> visitGeneric con <*> visitGeneric x - TypeAsAtom t -> TypeAsAtom <$> visitGeneric t - RepValAtom repVal -> RepValAtom <$> visitGeneric repVal - --- XXX: This doesn't handle the `Stuck` case. It should be handled explicitly --- beforehand. -visitTypePartial :: (IRRep r, Visitor m r i o) => Type r i -> m (Type r o) -visitTypePartial = \case - StuckTy _ -> error "Not handled generically" - NewtypeTyCon t -> NewtypeTyCon <$> visitGeneric t - Pi t -> Pi <$> visitGeneric t - TabPi t -> TabPi <$> visitGeneric t - TC t -> TC <$> visitGeneric t - DepPairTy t -> DepPairTy <$> visitGeneric t - DictTy t -> DictTy <$> visitGeneric t - instance IRRep r => VisitGeneric (Expr r) r where visitGeneric = \case Block _ _ -> error "not handled generically" TopApp et v xs -> TopApp <$> visitGeneric et <*> renameN v <*> mapM visitGeneric xs - TabApp t tab xs -> TabApp <$> visitType t <*> visitGeneric tab <*> mapM visitGeneric xs + TabApp t tab x -> TabApp <$> visitType t <*> visitGeneric tab <*> visitGeneric x -- TODO: should we reuse the original effects? Whether it's valid depends on -- the type-preservation requirements for a visitor. We should clarify what -- those are. @@ -477,17 +458,36 @@ instance IRRep r => VisitGeneric (EffectRow r) r where effs' <- eSetFromList <$> mapM visitGeneric (eSetToList effs) tailEffRow <- case tailVar of NoTail -> return $ EffectRow mempty NoTail - EffectRowTail v -> visitGeneric (Var v) <&> \case - Var v' -> EffectRow mempty (EffectRowTail v') - Eff r -> r + EffectRowTail v -> visitGeneric (toAtom v) <&> \case + Stuck _ (Var v') -> EffectRow mempty (EffectRowTail v') + Con (Eff r) -> r _ -> error "Not a valid effect substitution" return $ extendEffRow effs' tailEffRow -instance VisitGeneric DictCon CoreIR where +instance IRRep r => VisitGeneric (DictCon r) r where visitGeneric = \case - InstanceDict t v xs -> InstanceDict <$> visitGeneric t <*> renameN v <*> mapM visitGeneric xs - IxFin t x -> IxFin <$> visitGeneric t <*> visitGeneric x - DataData t dataTy -> DataData <$> visitGeneric t <*> visitGeneric dataTy + InstanceDict t v xs -> InstanceDict <$> visitGeneric t <*> renameN v <*> mapM visitGeneric xs + IxFin x -> IxFin <$> visitGeneric x + DataData dataTy -> DataData <$> visitGeneric dataTy + IxRawFin x -> IxRawFin <$> visitGeneric x + IxSpecialized v xs -> IxSpecialized <$> renameN v <*> mapM visitGeneric xs + +instance IRRep r => VisitGeneric (Con r) r where + visitGeneric = \case + Lit l -> return $ Lit l + ProdCon xs -> ProdCon <$> mapM visitGeneric xs + SumCon ty con arg -> SumCon <$> mapM visitGeneric ty <*> return con <*> visitGeneric arg + HeapVal -> return HeapVal + DepPair x y t -> do + x' <- visitGeneric x + y' <- visitGeneric y + ~(DepPairTy t') <- visitGeneric $ DepPairTy t + return $ DepPair x' y' t' + Lam lam -> Lam <$> visitGeneric lam + Eff eff -> Eff <$> visitGeneric eff + DictConAtom d -> DictConAtom <$> visitGeneric d + TyConAtom t -> TyConAtom <$> visitGeneric t + NewtypeCon con x -> NewtypeCon <$> visitGeneric con <*> visitGeneric x instance VisitGeneric NewtypeCon CoreIR where visitGeneric = \case @@ -505,17 +505,15 @@ instance VisitGeneric NewtypeTyCon CoreIR where instance VisitGeneric TyConParams CoreIR where visitGeneric (TyConParams expls xs) = TyConParams expls <$> mapM visitGeneric xs -instance IRRep r => VisitGeneric (IxDict r) r where - visitGeneric = \case - IxDictAtom x -> IxDictAtom <$> visitGeneric x - IxDictRawFin x -> IxDictRawFin <$> visitGeneric x - IxDictSpecialized t v xs -> IxDictSpecialized <$> visitGeneric t <*> renameN v <*> mapM visitGeneric xs instance IRRep r => VisitGeneric (IxType r) r where visitGeneric (IxType t d) = IxType <$> visitType t <*> visitGeneric d instance VisitGeneric DictType CoreIR where - visitGeneric (DictType n v xs) = DictType n <$> renameN v <*> mapM visitGeneric xs + visitGeneric = \case + DictType n v xs -> DictType n <$> renameN v <*> mapM visitGeneric xs + IxDictType t -> IxDictType <$> visitGeneric t + DataDictType t -> DataDictType <$> visitGeneric t instance VisitGeneric CoreLamExpr CoreIR where visitGeneric (CoreLamExpr t lam) = CoreLamExpr <$> visitGeneric t <*> visitGeneric lam @@ -538,7 +536,7 @@ instance IRRep r => VisitGeneric (DepPairType r) r where PiType (UnaryNest b') (EffTy Pure ty') -> DepPairType expl b' ty' _ -> error "not a dependent pair type" -instance VisitGeneric (RepVal SimpIR) SimpIR where +instance VisitGeneric RepVal SimpIR where visitGeneric (RepVal ty tree) = RepVal <$> visitGeneric ty <*> mapM renameIExpr tree where renameIExpr = \case ILit l -> return $ ILit l @@ -566,8 +564,25 @@ instance VisitGeneric DataConDef CoreIR where repTy' <- visitGeneric repTy return $ DataConDef sn (Abs bs' UnitE) repTy' ps -instance VisitGeneric (Con r) r where visitGeneric = traverseOpTerm -instance VisitGeneric (TC r) r where visitGeneric = traverseOpTerm +instance IRRep r => VisitGeneric (TyCon r) r where + visitGeneric = \case + BaseType bt -> return $ BaseType bt + ProdType tys -> ProdType <$> mapM visitGeneric tys + SumType tys -> SumType <$> mapM visitGeneric tys + RefType h t -> RefType <$> visitGeneric h <*> visitGeneric t + HeapType -> return HeapType + TabPi t -> TabPi <$> visitGeneric t + DepPairTy t -> DepPairTy <$> visitGeneric t + TypeKind -> return TypeKind + DictTy t -> DictTy <$> visitGeneric t + Pi t -> Pi <$> visitGeneric t + NewtypeTyCon t -> NewtypeTyCon <$> visitGeneric t + +instance IRRep r => VisitGeneric (Dict r) r where + visitGeneric = \case + StuckDict ty s -> fromJust <$> toMaybeDict <$> visitGeneric (Stuck ty s) + DictCon con -> DictCon <$> visitGeneric con + instance VisitGeneric (MiscOp r) r where visitGeneric = traverseOpTerm instance VisitGeneric (VectorOp r) r where visitGeneric = traverseOpTerm instance VisitGeneric (MemOp r) r where visitGeneric = traverseOpTerm @@ -582,22 +597,7 @@ bindersToVars bs = do mapM toAtomVar $ nestToNames bs bindersToAtoms :: (EnvReader m, IRRep r) => Nest (Binder r) n' n -> m n [Atom r n] -bindersToAtoms bs = liftM (Var <$>) $ bindersToVars bs - -newtype SubstVisitor i o a = SubstVisitor { runSubstVisitor :: Reader (Env o, Subst AtomSubstVal i o) a } - deriving (Functor, Applicative, Monad, MonadReader (Env o, Subst AtomSubstVal i o)) - -substV :: (Distinct o, SubstE AtomSubstVal e) => e i -> SubstVisitor i o (e o) -substV x = ask <&> \env -> substE env x - -instance Distinct o => NonAtomRenamer (SubstVisitor i o) i o where - renameN = substV - -instance (Distinct o, IRRep r) => Visitor (SubstVisitor i o) r i o where - visitType = substV - visitAtom = substV - visitLam = substV - visitPi = substV +bindersToAtoms bs = liftM (toAtom <$>) $ bindersToVars bs instance Color c => SubstE AtomSubstVal (AtomSubstVal c) where substE (_, env) (Rename name) = env ! name @@ -605,18 +605,26 @@ instance Color c => SubstE AtomSubstVal (AtomSubstVal c) where instance SubstV (SubstVal Atom) (SubstVal Atom) where +instance IRRep r => SubstE AtomSubstVal (IxDict r) where + substE es = \case + StuckDict _ e -> fromJust $ toMaybeDict $ substStuck es e + DictCon con -> DictCon $ substE es con + instance IRRep r => SubstE AtomSubstVal (Atom r) where substE es = \case - Stuck e -> substStuck es e - SimpInCore x -> SimpInCore (substE es x) - atom -> runReader (runSubstVisitor $ visitAtomPartial atom) es + Stuck _ e -> substStuck es e + Con con -> Con $ substE es con instance IRRep r => SubstE AtomSubstVal (Type r) where substE es = \case - StuckTy e -> case substStuck es e of - Type t -> t - _ -> error "bad substitution" - ty -> runReader (runSubstVisitor $ visitTypePartial ty) es + StuckTy _ e -> fromJust $ toMaybeType $ substStuck es e + TyCon con -> TyCon $ substE es con + +substMStuck :: (SubstReader AtomSubstVal m, EnvReader2 m, IRRep r) => Stuck r i -> m i o (Atom r o) +substMStuck stuck = do + subst <- getSubst + env <- unsafeGetEnv + withDistinct $ return $ substStuck (env, subst) stuck substStuck :: (IRRep r, Distinct o) => (Env o, Subst AtomSubstVal i o) -> Stuck r i -> Atom r o substStuck (env, subst) stuck = @@ -624,29 +632,33 @@ substStuck (env, subst) stuck = reduceStuck :: (IRRep r, Distinct o) => Stuck r i -> ReducerM i o (Atom r o) reduceStuck = \case - StuckVar (AtomVar v ty) -> do + Var (AtomVar v ty) -> do lookupSubstM v >>= \case - Rename v' -> Var . AtomVar v' <$> substM ty + Rename v' -> toAtom . AtomVar v' <$> substM ty SubstVal x -> return x - StuckProject _ i x -> do + StuckProject i x -> do x' <- reduceStuck x dropSubst $ reduceProjM i x' - StuckUnwrap _ x -> do + StuckUnwrap x -> do x' <- reduceStuck x dropSubst $ reduceUnwrapM x' - StuckTabApp _ f xs -> do + StuckTabApp f x -> do f' <- reduceStuck f - xs' <- mapM substM xs - dropSubst $ reduceTabAppM f' xs' - InstantiatedGiven _ f xs -> do + x' <- substM x + dropSubst $ reduceTabAppM f' x' + InstantiatedGiven f xs -> do xs' <- mapM substM xs f' <- reduceStuck f reduceInstantiateGivenM f' xs' - SuperclassProj _ superclassIx child -> do - child' <- reduceStuck child + SuperclassProj superclassIx child -> do + Just child' <- toMaybeDict <$> reduceStuck child reduceSuperclassProjM superclassIx child' - -instance SubstE AtomSubstVal SimpInCore + PtrVar ptrTy ptr -> mkStuck =<< PtrVar ptrTy <$> substM ptr + RepValAtom repVal -> mkStuck =<< RepValAtom <$> substM repVal + LiftSimp _ _ -> undefined + LiftSimpFun _ _ -> undefined + TabLam _ -> undefined + ACase _ _ _ -> undefined instance IRRep r => SubstE AtomSubstVal (EffectRow r) where substE env (EffectRow effs tailVar) = do @@ -657,8 +669,8 @@ instance IRRep r => SubstE AtomSubstVal (EffectRow r) where Rename v' -> do let v'' = runEnvReaderM (fst env) $ toAtomVar v' EffectRow mempty (EffectRowTail v'') - SubstVal (Var v') -> EffectRow mempty (EffectRowTail v') - SubstVal (Eff r) -> r + SubstVal (Stuck _ (Var v')) -> EffectRow mempty (EffectRowTail v') + SubstVal (Con (Eff r)) -> r _ -> error "Not a valid effect substitution" extendEffRow effs' tailEffRow @@ -668,21 +680,22 @@ instance SubstE AtomSubstVal SpecializationSpec where substE env (AppSpecialization (AtomVar f _) ab) = do let f' = case snd env ! f of Rename v -> runEnvReaderM (fst env) $ toAtomVar v - SubstVal (Var v) -> v + SubstVal (Stuck _ (Var v)) -> v _ -> error "bad substitution" AppSpecialization f' (substE env ab) instance SubstE AtomSubstVal EffectDef instance SubstE AtomSubstVal EffectOpType instance SubstE AtomSubstVal IExpr -instance IRRep r => SubstE AtomSubstVal (RepVal r) +instance SubstE AtomSubstVal RepVal instance SubstE AtomSubstVal TyConParams instance SubstE AtomSubstVal DataConDef instance IRRep r => SubstE AtomSubstVal (BaseMonoid r) instance IRRep r => SubstE AtomSubstVal (DAMOp r) instance IRRep r => SubstE AtomSubstVal (TypedHof r) instance IRRep r => SubstE AtomSubstVal (Hof r) -instance IRRep r => SubstE AtomSubstVal (TC r) +instance IRRep r => SubstE AtomSubstVal (TyCon r) +instance IRRep r => SubstE AtomSubstVal (DictCon r) instance IRRep r => SubstE AtomSubstVal (Con r) instance IRRep r => SubstE AtomSubstVal (MiscOp r) instance IRRep r => SubstE AtomSubstVal (VectorOp r) @@ -705,6 +718,5 @@ instance IRRep r => SubstE AtomSubstVal (DeclBinding r) instance IRRep r => SubstB AtomSubstVal (Decl r) instance SubstE AtomSubstVal NewtypeTyCon instance SubstE AtomSubstVal NewtypeCon -instance IRRep r => SubstE AtomSubstVal (IxDict r) instance IRRep r => SubstE AtomSubstVal (IxType r) instance SubstE AtomSubstVal DataConDefs diff --git a/src/lib/CheckType.hs b/src/lib/CheckType.hs index ad6698a9b..31db509cf 100644 --- a/src/lib/CheckType.hs +++ b/src/lib/CheckType.hs @@ -148,6 +148,14 @@ instance (CheckableB r b, CheckableE r e) => CheckableE r (Abs b e) where -- === type checking core === +checkStuck :: IRRep r => Type r i -> Stuck r i -> TyperM r i o (Type r o, Stuck r o) +checkStuck ty e = do + e' <- checkE e + ty' <- checkE ty + ty'' <- queryStuckType e' + checkTypesEq ty' ty'' + return (ty', e') + instance IRRep r => CheckableE r (TopLam r) where checkE (TopLam destFlag piTy lam) = do -- TODO: check destination-passing flag @@ -160,26 +168,8 @@ instance IRRep r => CheckableE r (AtomName r) where instance IRRep r => CheckableE r (Atom r) where checkE = \case - Stuck e -> Stuck <$> checkE e - Lam lam -> Lam <$> checkE lam - DepPair l r ty -> do - l' <- checkE l - ty' <- checkE ty - rTy <- checkInstantiation ty' [l'] - r' <- r |: rTy - return $ DepPair l' r' ty' - Con con -> Con <$> checkE con - Eff eff -> Eff <$> checkE eff - PtrVar t v -> PtrVar t <$> renameM v - -- TODO: check against cached type - DictCon con -> DictCon <$> checkE con - RepValAtom repVal -> RepValAtom <$> renameM repVal -- TODO: check - NewtypeCon con x -> do - (x', xTy) <- checkAndGetType x - con' <- typeCheckNewtypeCon con xTy - return $ NewtypeCon con' x' - SimpInCore x -> SimpInCore <$> checkE x - TypeAsAtom ty -> TypeAsAtom <$> checkE ty + Stuck ty e -> uncurry Stuck <$> checkStuck ty e + Con e -> Con <$> checkE e instance IRRep r => CheckableE r (AtomVar r) where checkE (AtomVar v t1) = do @@ -191,21 +181,8 @@ instance IRRep r => CheckableE r (AtomVar r) where instance IRRep r => CheckableE r (Type r) where checkE = \case - Pi t -> Pi <$> checkE t - TabPi t -> TabPi <$> checkE t - NewtypeTyCon t -> NewtypeTyCon <$> checkE t - TC t -> TC <$> checkE t - DepPairTy t -> DepPairTy <$> checkE t - DictTy (DictType sn className params) -> do - className' <- renameM className - ClassDef _ _ _ _ paramBs _ _ <- lookupClassDef className' - params' <- mapM checkE params - void $ checkInstantiation (Abs paramBs UnitE) params' - return $ DictTy (DictType sn className' params') - StuckTy e -> StuckTy <$> checkE e - -instance CheckableE CoreIR SimpInCore where - checkE x = renameM x -- TODO: check + StuckTy ty e -> uncurry StuckTy <$> checkStuck ty e + TyCon e -> TyCon <$> checkE e instance (ToBinding ann c, Color c, CheckableE r ann) => CheckableB r (BinderP c ann) where checkB (b:>ann) cont = do @@ -228,18 +205,18 @@ instance IRRep r => CheckableWithEffects r (Expr r) where App effTy f xs -> do effTy' <- checkEffTy allowedEffs effTy f' <- checkE f - Pi piTy <- return $ getType f' + TyCon (Pi piTy) <- return $ getType f' xs' <- mapM checkE xs effTy'' <- checkInstantiation piTy xs' checkAlphaEq effTy' effTy'' return $ App effTy' f' xs' - TabApp reqTy f xs -> do - reqTy' <- reqTy |: TyKind + TabApp reqTy f x -> do + reqTy' <- checkE reqTy (f', tabTy) <- checkAndGetType f - xs' <- mapM checkE xs - ty' <- checkTabApp tabTy xs' + x' <- checkE x + ty' <- checkTabApp tabTy x' checkTypesEq reqTy' ty' - return $ TabApp reqTy' f' xs' + return $ TabApp reqTy' f' x' TopApp effTy f xs -> do f' <- renameM f effTy' <- checkEffTy allowedEffs effTy @@ -259,7 +236,7 @@ instance IRRep r => CheckableWithEffects r (Expr r) where Case scrut alts effTy -> do effTy' <- checkEffTy allowedEffs effTy scrut' <- checkE scrut - altsBinderTys <- checkCaseAltsBinderTys $ getType scrut' + TyCon (SumType altsBinderTys) <- return $ getType scrut' assertEq (length altsBinderTys) (length alts) "" alts' <- parallelAffines $ (zip alts altsBinderTys) <&> \(Abs b body, reqBinderTy) -> do checkB b \b' -> do @@ -268,14 +245,14 @@ instance IRRep r => CheckableWithEffects r (Expr r) where return $ Case scrut' alts' effTy' ApplyMethod effTy dict i args -> do effTy' <- checkEffTy allowedEffs effTy - dict' <- checkE dict + Just dict' <- toMaybeDict <$> checkE dict args' <- mapM checkE args methodTy <- getMethodType dict' i effTy'' <- checkInstantiation methodTy args' checkAlphaEq effTy' effTy'' - return $ ApplyMethod effTy' dict' i args' + return $ ApplyMethod effTy' (toAtom dict') i args' TabCon maybeD ty xs -> do - ty'@(TabPi (TabPiType _ b restTy)) <- ty |: TyKind + ty'@(TyCon (TabPi (TabPiType _ b restTy))) <- checkE ty maybeD' <- mapM renameM maybeD -- TODO: check xs' <- case fromConstAbs (Abs b restTy) of HoistSuccess elTy -> forM xs (|: elTy) @@ -285,20 +262,14 @@ instance IRRep r => CheckableWithEffects r (Expr r) where HoistFailure _ -> forM xs checkE return $ TabCon maybeD' ty' xs' Project resultTy i x -> do - resultTy' <- resultTy |: TyKind - (x', xTy) <- checkAndGetType x - resultTy'' <- case xTy of - ProdTy tys -> return $ tys !! i - DepPairTy t | i == 0 -> return $ depPairLeftTy t - DepPairTy t | i == 1 -> do - xFst <- reduceProj 0 x' - checkInstantiation t [xFst] - _ -> throw TypeErr $ "Not a product type:" ++ pprint xTy + x' <-checkE x + resultTy' <- checkE resultTy + resultTy'' <- checkProject i x' checkTypesEq resultTy' resultTy'' return $ Project resultTy' i x' Unwrap resultTy x -> do - resultTy' <- resultTy |: TyKind - (x', NewtypeTyCon con) <- checkAndGetType x + resultTy' <- checkE resultTy + (x', TyCon (NewtypeTyCon con)) <- checkAndGetType x resultTy'' <- snd <$> unwrapNewtypeType con checkTypesEq resultTy' resultTy'' return $ Unwrap resultTy' x' @@ -308,58 +279,63 @@ instance CheckableE CoreIR TyConParams where instance IRRep r => CheckableE r (Stuck r) where checkE = \case - StuckVar name -> do + Var name -> do name' <- checkE name case getType name' of RawRefTy _ -> affineUsed $ atomVarName name' _ -> return () - return $ StuckVar name' - StuckUnwrap resultTy x -> do - Unwrap resultTy' (Stuck x') <- checkWithEffects Pure $ Unwrap resultTy (Stuck x) - return $ StuckUnwrap resultTy' x' - StuckProject resultTy i x -> do - Project resultTy' i' (Stuck x') <- checkWithEffects Pure $ Project resultTy i (Stuck x) - return $ StuckProject resultTy' i' x' - StuckTabApp reqTy f xs -> do - reqTy' <- reqTy |: TyKind - (f', tabTy) <- checkAndGetType f - xs' <- mapM checkE xs - ty' <- checkTabApp tabTy xs' - checkTypesEq reqTy' ty' - return $ StuckTabApp reqTy' f' xs' - InstantiatedGiven resultTy given args -> do - resultTy' <- resultTy |: TyKind - (given', Pi piTy) <- checkAndGetType given + return $ Var name' + StuckUnwrap x -> do + x' <- checkE x + TyCon (NewtypeTyCon _) <- queryStuckType x' + return $ StuckUnwrap x' + StuckProject i x -> do + x' <-checkE x + x'' <- mkStuck x' + void $ checkProject i x'' + return $ StuckProject i x' + StuckTabApp f x -> do + f' <- checkE f + tabTy <- queryStuckType f' + x' <- checkE x + void $ checkTabApp tabTy x' + return $ StuckTabApp f' x' + InstantiatedGiven given args -> do + given' <- checkE given + TyCon (Pi piTy) <- queryStuckType given' args' <- mapM checkE args - EffTy Pure ty <- checkInstantiation piTy args' - checkTypesEq resultTy' ty - return $ InstantiatedGiven resultTy' given' args' - SuperclassProj t i d -> SuperclassProj <$> checkE t <*> pure i <*> checkE d -- TODO: check index in range + EffTy Pure _ <- checkInstantiation piTy args' + return $ InstantiatedGiven given' args' + SuperclassProj i d -> SuperclassProj <$> pure i <*> checkE d -- TODO: check index in range + PtrVar t v -> PtrVar t <$> renameM v + RepValAtom repVal -> RepValAtom <$> renameM repVal -- TODO: check + LiftSimp t x -> LiftSimp <$> checkE t <*> renameM x -- TODO: check + LiftSimpFun t x -> LiftSimpFun <$> checkE t <*> renameM x -- TODO: check + ACase scrut alts resultTy -> ACase <$> renameM scrut <*> mapM renameM alts <*> checkE resultTy -- TODO: check + TabLam lam -> TabLam <$> renameM lam -- TODO: check depPairLeftTy :: DepPairType r n -> Type r n depPairLeftTy (DepPairType _ (_:>ty) _) = ty {-# INLINE depPairLeftTy #-} -instance CheckableE CoreIR DictCon where +instance IRRep r => CheckableE r (DictCon r) where checkE = \case InstanceDict ty instanceName args -> do - ty' <- ty |: TyKind + ty' <- checkE ty instanceName' <- renameM instanceName args' <- mapM checkE args instanceDef <- lookupInstanceDef instanceName' void $ checkInstantiation instanceDef args' return $ InstanceDict ty' instanceName' args' - IxFin ty n -> do - ty' <- ty |: TyKind - IxFin ty' <$> n |: NatTy - DataData ty dataTy -> do - ty' <- ty |: TyKind - DataData ty' <$> dataTy |: TyKind + IxFin n -> IxFin <$> n |: NatTy + DataData dataTy -> DataData <$> checkE dataTy + IxRawFin n -> IxRawFin <$> n |: IdxRepTy + IxSpecialized v params -> IxSpecialized <$> renameM v <*> mapM checkE params instance IRRep r => CheckableE r (DepPairType r) where checkE (DepPairType expl b ty) = do checkB b \b' -> do - ty' <- ty |: TyKind + ty' <- checkE ty return $ DepPairType expl b' ty' instance CheckableE CoreIR CorePiType where @@ -384,7 +360,7 @@ instance IRRep r => CheckableE r (TabPiType r) where checkE (TabPiType d b resultTy) = do d' <- checkE d checkB b \b' -> do - resultTy' <- resultTy|:TyKind + resultTy' <- checkE resultTy return $ TabPiType d' b' resultTy' instance (BindsNames b, CheckableB r b) => CheckableB r (Nest b) where @@ -402,25 +378,57 @@ instance CheckableE CoreIR CoreLamExpr where lamExpr' <- checkLamExpr (PiType bs effTy) lamExpr return $ CoreLamExpr (CorePiType expl expls bs effTy) lamExpr' -instance IRRep r => CheckableE r (TC r) where +instance IRRep r => CheckableE r (TyCon r) where checkE = \case BaseType b -> return $ BaseType b - ProdType tys -> ProdType <$> mapM (|:TyKind) tys - SumType cs -> SumType <$> mapM (|:TyKind) cs - RefType r a -> RefType <$> r|:TC HeapType <*> a|:TyKind + ProdType tys -> ProdType <$> mapM checkE tys + SumType cs -> SumType <$> mapM checkE cs + RefType r a -> RefType <$> r|:TyCon HeapType <*> checkE a TypeKind -> return TypeKind HeapType -> return HeapType + Pi t -> Pi <$> checkE t + TabPi t -> TabPi <$> checkE t + NewtypeTyCon t -> NewtypeTyCon <$> checkE t + DepPairTy t -> DepPairTy <$> checkE t + DictTy t -> DictTy <$> checkE t + + +instance CheckableE CoreIR DictType where + checkE = \case + DictType sn className params -> do + className' <- renameM className + ClassDef _ Nothing _ _ _ paramBs _ _ <- lookupClassDef className' + params' <- mapM checkE params + void $ checkInstantiation (Abs paramBs UnitE) params' + return $ DictType sn className' params' + IxDictType t -> IxDictType <$> checkE t + DataDictType t -> DataDictType <$> checkE t instance IRRep r => CheckableE r (Con r) where checkE = \case Lit l -> return $ Lit l ProdCon xs -> ProdCon <$> mapM checkE xs SumCon tys tag payload -> do - tys' <- mapM (|:TyKind) tys + tys' <- mapM checkE tys unless (0 <= tag && tag < length tys') $ throw TypeErr "Invalid SumType tag" payload' <- payload |: (tys' !! tag) return $ SumCon tys' tag payload' HeapVal -> return HeapVal + Lam lam -> Lam <$> checkE lam + DepPair l r ty -> do + l' <- checkE l + ty' <- checkE ty + rTy <- checkInstantiation ty' [l'] + r' <- r |: rTy + return $ DepPair l' r' ty' + Eff eff -> Eff <$> checkE eff + -- TODO: check against cached type + DictConAtom con -> DictConAtom <$> checkE con + NewtypeCon con x -> do + (x', xTy) <- checkAndGetType x + con' <- typeCheckNewtypeCon con xTy + return $ NewtypeCon con' x' + TyConAtom tyCon -> TyConAtom <$> checkE tyCon typeCheckNewtypeCon :: NewtypeCon i -> CType o -> TyperM CoreIR i o (NewtypeCon o) @@ -464,20 +472,20 @@ instance IRRep r => CheckableWithEffects r (PrimOp r) where BinOp binop x y -> do x' <- checkE x y' <- checkE y - TC (BaseType xTy) <- return $ getType x' - TC (BaseType yTy) <- return $ getType y' + TyCon (BaseType xTy) <- return $ getType x' + TyCon (BaseType yTy) <- return $ getType y' checkBinOp binop xTy yTy return $ BinOp binop x' y' UnOp unop x -> do x' <- checkE x - TC (BaseType xTy) <- return $ getType x' + TyCon (BaseType xTy) <- return $ getType x' checkUnOp unop xTy return $ UnOp unop x' MiscOp op -> MiscOp <$> checkWithEffects effs op MemOp op -> MemOp <$> checkWithEffects effs op DAMOp op -> DAMOp <$> checkWithEffects effs op RefOp ref m -> do - (ref', TC (RefType h s)) <- checkAndGetType ref + (ref', TyCon (RefType h s)) <- checkAndGetType ref m' <- case m of MGet -> declareEff effs (RWSEffect State h) $> MGet MPut x -> do @@ -491,22 +499,22 @@ instance IRRep r => CheckableWithEffects r (PrimOp r) where declareEff effs (RWSEffect Writer h) return $ MExtend b' x' IndexRef givenTy i -> do - givenTy' <- givenTy |: TyKind - TabPi tabTy <- return s + givenTy' <- checkE givenTy + TyCon (TabPi tabTy) <- return s i' <- checkE i eltTy' <- checkInstantiation tabTy [i'] - checkTypesEq givenTy' (TC $ RefType h eltTy') + checkTypesEq givenTy' (TyCon $ RefType h eltTy') return $ IndexRef givenTy' i' ProjRef givenTy p -> do - givenTy' <- givenTy |: TyKind + givenTy' <- checkE givenTy resultEltTy <- case p of ProjectProduct i -> do - ProdTy tys <- return s + TyCon (ProdType tys) <- return s return $ tys !! i UnwrapNewtype -> do - NewtypeTyCon tc <- return s + TyCon (NewtypeTyCon tc) <- return s snd <$> unwrapNewtypeType tc - checkTypesEq givenTy' (TC $ RefType h resultEltTy) + checkTypesEq givenTy' (TyCon $ RefType h resultEltTy) return $ ProjRef givenTy' p return $ RefOp ref' m' @@ -551,29 +559,29 @@ instance IRRep r => CheckableWithEffects r (MiscOp r) where x' <- checkE x y' <- y |: getType x' return $ Select p' x' y' - CastOp t@(StuckTy (StuckVar _)) e -> CastOp <$> (t|:TyKind) <*> renameM e + CastOp t@(StuckTy _ (Var _)) e -> CastOp <$> checkE t <*> renameM e CastOp destTy e -> do e' <- checkE e - destTy' <- destTy |: TyKind + destTy' <- checkE destTy checkValidCast (getType e') destTy' return $ CastOp destTy' e' - BitcastOp t@(StuckTy (StuckVar _)) e -> BitcastOp <$> (t|:TyKind) <*> renameM e + BitcastOp t@(StuckTy _ (Var _)) e -> BitcastOp <$> checkE t <*> renameM e BitcastOp destTy e -> do - destTy' <- destTy |: TyKind + destTy' <- checkE destTy e' <- checkE e let sourceTy = getType e' case (destTy', sourceTy) of (BaseTy dbt@(Scalar _), BaseTy sbt@(Scalar _)) | sizeOf sbt == sizeOf dbt -> return $ BitcastOp destTy' e' _ -> throw TypeErr $ "Invalid bitcast: " ++ pprint sourceTy ++ " -> " ++ pprint destTy - UnsafeCoerce t e -> UnsafeCoerce <$> t|:TyKind <*> renameM e - GarbageVal t -> GarbageVal <$> (t|:TyKind) + UnsafeCoerce t e -> UnsafeCoerce <$> checkE t <*> renameM e + GarbageVal t -> GarbageVal <$> checkE t SumTag x -> do x' <- checkE x void $ checkSomeSumType $ getType x' return $ SumTag x' ToEnum t x -> do - t' <- t |: TyKind + t' <- checkE t x' <- x |: Word8Ty cases <- checkSomeSumType t' forM_ cases \cty -> checkTypesEq cty UnitTy @@ -584,40 +592,40 @@ instance IRRep r => CheckableWithEffects r (MiscOp r) where x' <- checkE x BaseTy (Scalar _) <- return $ getType x' return $ ShowScalar x' - ThrowError ty -> ThrowError <$> (ty|:TyKind) + ThrowError ty -> ThrowError <$> checkE ty ThrowException ty -> ThrowException <$> do declareEff effs ExceptionEffect - ty|:TyKind + checkE ty checkSomeSumType :: IRRep r => Type r o -> TyperM r i o [Type r o] checkSomeSumType = \case - SumTy cases -> return cases - NewtypeTyCon con -> do - (_, SumTy cases) <- unwrapNewtypeType con + TyCon (SumType cases) -> return cases + TyCon (NewtypeTyCon con) -> do + (_, TyCon (SumType cases)) <- unwrapNewtypeType con return cases t -> error $ "not some sum type: " ++ pprint t instance IRRep r => CheckableE r (VectorOp r) where checkE = \case VectorBroadcast v ty -> do - ty'@(BaseTy (Vector _ sbt)) <- ty |: TyKind + ty'@(BaseTy (Vector _ sbt)) <- checkE ty v' <- v |: BaseTy (Scalar sbt) return $ VectorBroadcast v' ty' VectorIota ty -> do - ty'@(BaseTy (Vector _ _)) <- ty |: TyKind + ty'@(BaseTy (Vector _ _)) <- checkE ty return $ VectorIota ty' VectorIdx tbl i ty -> do tbl' <- checkE tbl TabTy _ b (BaseTy (Scalar sbt)) <- return $ getType tbl' i' <- i |: binderType b - ty'@(BaseTy (Vector _ sbt')) <- ty |: TyKind + ty'@(BaseTy (Vector _ sbt')) <- checkE ty unless (sbt == sbt') $ throw TypeErr "Scalar type mismatch" return $ VectorIdx tbl' i' ty' VectorSubref ref i ty -> do ref' <- checkE ref RefTy _ (TabTy _ b (BaseTy (Scalar sbt))) <- return $ getType ref' i' <- i |: binderType b - ty'@(BaseTy (Vector _ sbt')) <- ty |: TyKind + ty'@(BaseTy (Vector _ sbt')) <- checkE ty unless (sbt == sbt') $ throw TypeErr "Scalar type mismatch" return $ VectorSubref ref' i' ty' @@ -626,9 +634,9 @@ checkHof (EffTy effs reqTy) = \case For dir ixTy f -> do IxType t d <- checkE ixTy LamExpr (UnaryNest b) body <- return f - TabPi tabTy <- return reqTy + TyCon (TabPi tabTy) <- return reqTy checkBinderType t b \b' -> do - resultTy <- checkInstantiation (sink tabTy) [Var $ binderVar b'] + resultTy <- checkInstantiation (sink tabTy) [toAtom $ binderVar b'] body' <- checkWithEffTy (EffTy (sink effs) resultTy) body return $ For dir (IxType t d) (LamExpr (UnaryNest b') body') While body -> do @@ -641,7 +649,7 @@ checkHof (EffTy effs reqTy) = \case checkBinderType xTy b \b' -> do PairTy resultTy fLinTy <- sinkM reqTy body' <- checkWithEffTy (EffTy Pure resultTy) body - checkTypesEq fLinTy (Pi $ nonDepPiType [sink xTy] Pure resultTy) + checkTypesEq fLinTy (toType $ nonDepPiType [sink xTy] Pure resultTy) return $ Linearize (LamExpr (UnaryNest b') body') x' Transpose f x -> do (x', xTy) <- checkAndGetType x @@ -687,8 +695,10 @@ checkHof (EffTy effs reqTy) = \case CatchException reqTy' body -> do reqTy'' <- checkE reqTy' checkTypesEq reqTy reqTy'' - TypeCon _ _ (TyConParams _[Type ty]) <- return reqTy'' -- TODO: take more care in unpacking Maybe - body' <- checkWithEffTy (EffTy (extendEffect ExceptionEffect effs) ty) body + -- TODO: take more care in unpacking Maybe + TyCon (NewtypeTyCon (UserADTType _ _ (TyConParams _ [ty]))) <- return reqTy'' + Just ty' <- return $ toMaybeType ty + body' <- checkWithEffTy (EffTy (extendEffect ExceptionEffect effs) ty') body return $ CatchException reqTy'' body' instance IRRep r => CheckableWithEffects r (DAMOp r) where @@ -701,7 +711,7 @@ instance IRRep r => CheckableWithEffects r (DAMOp r) where (carry', carryTy') <- checkAndGetType carry let badCarry = throw TypeErr $ "Seq carry should be a product of raw references, got: " ++ pprint carryTy' case carryTy' of - ProdTy refTys -> forM_ refTys \case RawRefTy _ -> return (); _ -> badCarry + TyCon (ProdType refTys) -> forM_ refTys \case RawRefTy _ -> return (); _ -> badCarry _ -> badCarry let binderReqTy = PairTy (ixTypeType ixTy') carryTy' checkBinderType binderReqTy b \b' -> do @@ -715,7 +725,7 @@ instance IRRep r => CheckableWithEffects r (DAMOp r) where checkBinderType dTy b \b' -> do body' <- checkWithEffTy (EffTy (sink effAnn') UnitTy) body return $ RememberDest effAnn' d' $ LamExpr (UnaryNest b') body' - AllocDest ty -> AllocDest <$> ty|:TyKind + AllocDest ty -> AllocDest <$> checkE ty Place ref val -> do val' <- checkE val ref' <- ref |: RawRefTy (getType val') @@ -729,7 +739,7 @@ instance IRRep r => CheckableWithEffects r (DAMOp r) where checkLamExpr :: IRRep r => PiType r o -> LamExpr r i -> TyperM r i o (LamExpr r o) checkLamExpr piTy (LamExpr bs body) = checkB bs \bs' -> do - effTy <- checkInstantiation (sink piTy) (Var <$> bindersVars bs') + effTy <- checkInstantiation (sink piTy) (toAtom <$> bindersVars bs') body' <- checkWithEffTy effTy body return $ LamExpr bs' body' @@ -751,32 +761,27 @@ checkRWSAction -> RWS -> LamExpr r i -> TyperM r i o (LamExpr r o) checkRWSAction resultTy referentTy effs rws f = do BinaryLamExpr bH bR body <- return f - checkBinderType (TC HeapType) bH \bH' -> do - let h = Var $ binderVar bH' + checkBinderType (TyCon HeapType) bH \bH' -> do + let h = toAtom $ binderVar bH' let refTy = RefTy h (sink referentTy) checkBinderType refTy bR \bR' -> do let effs' = extendEffect (RWSEffect rws $ sink h) (sink effs) body' <- checkWithEffTy (EffTy effs' (sink resultTy)) body return $ BinaryLamExpr bH' bR' body' -checkCaseAltsBinderTys :: IRRep r => Type r n -> TyperM r i n [Type r n] -checkCaseAltsBinderTys ty = case ty of - SumTy types -> return types - NewtypeTyCon t -> case t of - UserADTType _ defName (TyConParams _ params) -> do - def <- lookupTyCon defName - ADTCons cons <- checkInstantiation def params - return [repTy | DataConDef _ _ repTy _ <- cons] - _ -> fail msg - _ -> fail msg - where msg = "Case analysis only supported on ADTs, not on " ++ pprint ty - -checkTabApp :: (IRRep r) => Type r o -> [Atom r o] -> TyperM r i o (Type r o) -checkTabApp ty [] = return ty -checkTabApp ty (i:rest) = do - TabPi tabTy <- return ty - resultTy <- checkInstantiation tabTy [i] - checkTabApp resultTy rest +checkProject :: (IRRep r) => Int -> Atom r o -> TyperM r i o (Type r o) +checkProject i x = case getType x of + TyCon (ProdType tys) -> return $ tys !! i + TyCon (DepPairTy t) | i == 0 -> return $ depPairLeftTy t + TyCon (DepPairTy t) | i == 1 -> do + xFst <- reduceProj 0 x + checkInstantiation t [xFst] + xTy -> throw TypeErr $ "Not a product type:" ++ pprint xTy + +checkTabApp :: (IRRep r) => Type r o -> Atom r o -> TyperM r i o (Type r o) +checkTabApp ty i = do + TyCon (TabPi tabTy) <- return ty + checkInstantiation tabTy [i] checkInstantiation :: forall r e body i o . @@ -824,7 +829,7 @@ checkFloatBaseType t = case t of "Expected a fixed-width scalar floating-point type, but found: " ++ pprint t checkValidCast :: (Fallible1 m, IRRep r) => Type r n -> Type r n -> m n () -checkValidCast (BaseTy l) (BaseTy r) = checkValidBaseCast l r +checkValidCast (TyCon (BaseType l)) (TyCon (BaseType r)) = checkValidBaseCast l r checkValidCast sourceTy destTy = throw TypeErr $ "Can't cast " ++ pprint sourceTy ++ " to " ++ pprint destTy @@ -905,7 +910,7 @@ instance IRRep r => CheckableE r (EffectRow r) where checkE (EffectRow effs effTail) = do effs' <- eSetFromList <$> forM (eSetToList effs) \eff -> case eff of RWSEffect rws v -> do - v' <- v |: TC HeapType + v' <- v |: TyCon HeapType return $ RWSEffect rws v' ExceptionEffect -> return ExceptionEffect IOEffect -> return IOEffect diff --git a/src/lib/Core.hs b/src/lib/Core.hs index 10c60999b..8f5ee3acb 100644 --- a/src/lib/Core.hs +++ b/src/lib/Core.hs @@ -428,33 +428,6 @@ fromNaryForExpr maxDepth = \case return (d + 1, LamExpr (Nest b bs) body2) _ -> Nothing -mkConsListTy :: [Type r n] -> Type r n -mkConsListTy = foldr PairTy UnitTy - -mkConsList :: [Atom r n] -> Atom r n -mkConsList = foldr PairVal UnitVal - -fromConsListTy :: (IRRep r, Fallible m) => Type r n -> m [Type r n] -fromConsListTy ty = case ty of - UnitTy -> return [] - PairTy t rest -> (t:) <$> fromConsListTy rest - _ -> throw CompilerErr $ "Not a pair or unit: " ++ show ty - --- ((...((ans & x{n}) & x{n-1})... & x2) & x1) -> (ans, [x1, ..., x{n}]) -fromLeftLeaningConsListTy :: (IRRep r, Fallible m) => Int -> Type r n -> m (Type r n, [Type r n]) -fromLeftLeaningConsListTy depth initTy = go depth initTy [] - where - go 0 ty xs = return (ty, reverse xs) - go remDepth ty xs = case ty of - PairTy lt rt -> go (remDepth - 1) lt (rt : xs) - _ -> throw CompilerErr $ "Not a pair: " ++ show xs - -fromConsList :: (IRRep r, Fallible m) => Atom r n -> m [Atom r n] -fromConsList xs = case xs of - UnitVal -> return [] - PairVal x rest -> (x:) <$> fromConsList rest - _ -> throw CompilerErr $ "Not a pair or unit: " ++ show xs - type BundleDesc = Int -- length bundleFold :: a -> (a -> a -> a) -> [a] -> (a, BundleDesc) @@ -465,16 +438,10 @@ bundleFold emptyVal pair els = case els of where (tb, td) = bundleFold emptyVal pair t mkBundleTy :: [Type r n] -> (Type r n, BundleDesc) -mkBundleTy = bundleFold UnitTy PairTy +mkBundleTy = bundleFold UnitTy (\x y -> TyCon (ProdType [x, y])) mkBundle :: [Atom r n] -> (Atom r n, BundleDesc) -mkBundle = bundleFold UnitVal PairVal - -trySelectBranch :: IRRep r => Atom r n -> Maybe (Int, Atom r n) -trySelectBranch e = case e of - SumVal _ i value -> Just (i, value) - NewtypeCon con e' | isSumCon con -> trySelectBranch e' - _ -> Nothing +mkBundle = bundleFold UnitVal (\x y -> Con (ProdCon [x, y])) freeAtomVarsList :: forall r e n. (IRRep r, HoistableE e) => e n -> [Name (AtomNameC r) n] freeAtomVarsList = freeVarsList diff --git a/src/lib/Err.hs b/src/lib/Err.hs index 6af6141c8..34e1a374c 100644 --- a/src/lib/Err.hs +++ b/src/lib/Err.hs @@ -59,8 +59,7 @@ data ErrType = NoErr | ZipErr | EscapedNameErr | ModuleImportErr - | MonadFailErr - | SearchFailure -- used as the identity for `Alternative` instances + | SearchFailure -- used as the identity for `Alternative` instances and for MonadFail deriving (Show, Eq) type SrcTextCtx = Maybe (Int, Text) -- Int is the offset in the source file @@ -320,7 +319,7 @@ layout = if unbounded then LayoutOptions Unbounded else defaultLayoutOptions -- === instances === instance MonadFail FallibleM where - fail s = throw MonadFailErr s + fail s = throw SearchFailure s {-# INLINE fail #-} instance Fallible Except where @@ -333,7 +332,7 @@ instance Fallible Except where {-# INLINE addErrCtx #-} instance MonadFail Except where - fail s = Failure $ Err CompilerErr mempty s + fail s = Failure $ Err SearchFailure mempty s {-# INLINE fail #-} instance Exception Err @@ -393,7 +392,6 @@ instance Pretty ErrType where ZipErr -> "Zipping error" EscapedNameErr -> "Leaked local variables:" ModuleImportErr -> "Module import error: " - MonadFailErr -> "MonadFail error (internal error)" SearchFailure -> "Search error (internal error)" instance Fallible m => Fallible (ReaderT r m) where diff --git a/src/lib/Export.hs b/src/lib/Export.hs index f7ab3184d..813874bba 100644 --- a/src/lib/Export.hs +++ b/src/lib/Export.hs @@ -45,14 +45,14 @@ prepareFunctionForExport :: (Mut n, Topper m) => CallingConvention -> CAtom n -> m n ExportNativeFunction prepareFunctionForExport cc f = do naryPi <- case getType f of - Pi piTy -> return piTy + TyCon (Pi piTy) -> return piTy _ -> throw TypeErr "Only first-order functions can be exported" sig <- liftExportSigM $ corePiToExportSig cc naryPi closedSig <- case hoistToTop sig of HoistFailure _ -> throw TypeErr $ "Types of exported functions have to be closed terms. Got: " ++ pprint naryPi HoistSuccess s -> return s - f' <- liftBuilder $ buildCoreLam naryPi \xs -> naryApp (sink f) (Var <$> xs) + f' <- liftBuilder $ buildCoreLam naryPi \xs -> naryApp (sink f) (toAtom <$> xs) fSimp <- simplifyTopFunction $ coreLamToTopLam f' fImp <- compileTopLevelFun cc fSimp nativeFun <- toCFunction "userFunc" fImp >>= emitObjFile >>= loadObject @@ -143,11 +143,11 @@ goResult :: IRRep r => Type r i Nest ExportResult o o' -> ExportSigM r i o' a) -> ExportSigM r i o a goResult ty cont = case ty of - ProdTy [one] -> + TyCon (ProdType [one]) -> goResult one cont - ProdTy (lty:rest) -> + TyCon (ProdType (lty:rest)) -> goResult lty \lres -> - goResult (ProdTy rest) \rres -> + goResult (TyCon (ProdType rest)) \rres -> cont $ lres >>> rres _ -> do ety <- toExportType ty @@ -157,7 +157,7 @@ goResult ty cont = case ty of toExportType :: IRRep r => Type r i -> ExportSigM r i o (ExportType o) toExportType ty = case ty of BaseTy (Scalar sbt) -> return $ ScalarType sbt - NewtypeTyCon Nat -> return $ ScalarType IdxRepScalarBaseTy + TyCon (NewtypeTyCon Nat) -> return $ ScalarType IdxRepScalarBaseTy TabTy _ _ _ -> parseTabTy ty >>= \case Nothing -> unsupported Just ety -> return ety @@ -168,22 +168,18 @@ toExportType ty = case ty of parseTabTy :: IRRep r => Type r i -> ExportSigM r i o (Maybe (ExportType o)) parseTabTy = go [] where - go :: forall r i o. IRRep r => [ExportDim o] -> Type r i - -> ExportSigM r i o (Maybe (ExportType o)) + go :: IRRep r => [ExportDim o] -> Type r i -> ExportSigM r i o (Maybe (ExportType o)) go shape = \case - BaseTy (Scalar sbt) -> return $ Just $ RectContArrayPtr sbt shape - NewtypeTyCon Nat -> return $ Just $ RectContArrayPtr IdxRepScalarBaseTy shape - TabTy d (b:>ixty) a -> do - maybeN <- case IxType ixty d of - IxType (NewtypeTyCon (Fin n)) _ -> return $ Just n - IxType _ (IxDictRawFin n) -> return $ Just n - _ -> return Nothing + TyCon (BaseType (Scalar sbt)) -> return $ Just $ RectContArrayPtr sbt shape + TyCon (NewtypeTyCon Nat) -> return $ Just $ RectContArrayPtr IdxRepScalarBaseTy shape + TyCon (TabPi (TabPiType d (b:>ixty) a)) -> do + maybeN <- fromIxFin $ IxType ixty d maybeDim <- case maybeN of - Just (Var v) -> do + Just (Stuck _ (Var v)) -> do s <- getSubst let (Rename v') = s ! atomVarName v return $ Just (ExportDimVar v') - Just (NewtypeCon NatCon (IdxRepVal s)) -> return $ Just (ExportDimLit $ fromIntegral s) + Just (Con (NewtypeCon NatCon (IdxRepVal s))) -> return $ Just (ExportDimLit $ fromIntegral s) Just (IdxRepVal s) -> return $ Just (ExportDimLit $ fromIntegral s) _ -> return Nothing case maybeDim of @@ -193,6 +189,12 @@ parseTabTy = go [] Nothing -> return Nothing _ -> return Nothing + fromIxFin :: IRRep r => IxType r i -> ExportSigM r i o (Maybe (Atom r i)) + fromIxFin = \case + IxType (TyCon (NewtypeTyCon (Fin n))) (DictCon (IxFin _)) -> return $ Just n + IxType _ (DictCon (IxRawFin n)) -> return $ Just n + _ -> return Nothing + data ArgVisibility = ImplicitArg | ExplicitArg data ExportDim n = diff --git a/src/lib/Generalize.hs b/src/lib/Generalize.hs index 58c0721d4..945552a7d 100644 --- a/src/lib/Generalize.hs +++ b/src/lib/Generalize.hs @@ -7,6 +7,7 @@ module Generalize (generalizeArgs, generalizeIxDict) where import Control.Monad +import Data.Maybe (fromJust) import Core import Err @@ -16,20 +17,18 @@ import IRVariants import QueryType import Name import Subst -import MTL1 import Types.Primitives type RolePiBinder = WithAttrB RoleExpl CBinder type RolePiBinders = Nest RolePiBinder -generalizeIxDict :: EnvReader m => Atom CoreIR n -> m n (Generalized CoreIR CAtom n) +generalizeIxDict :: EnvReader m => CDict n -> m n (Generalized CoreIR CDict n) generalizeIxDict dict = liftGeneralizerM do dict' <- sinkM dict dictTy <- return $ getType dict' dictTyGeneralized <- generalizeType dictTy - dictGeneralized <- liftEnvReaderM $ generalizeDict dictTyGeneralized dict' - return dictGeneralized --- {-# INLINE generalizeIxDict #-} + liftEnvReaderM $ generalizeDict dictTyGeneralized dict' +{-# INLINE generalizeIxDict #-} generalizeArgs ::EnvReader m => CorePiType n -> [Atom CoreIR n] -> m n (Generalized CoreIR (ListE CAtom) n) generalizeArgs fTy argsTop = liftGeneralizerM $ runSubstReaderT idSubst do @@ -40,13 +39,12 @@ generalizeArgs fTy argsTop = liftGeneralizerM $ runSubstReaderT idSubst do -> SubstReaderT AtomSubstVal GeneralizerM i n [Atom CoreIR n] go (Nest (WithAttrB expl b) bs) (arg:args) = do ty' <- substM $ binderType b - arg' <- case (ty', expl) of - (TyKind, _) -> liftSubstReaderT case arg of - Type t -> Type <$> generalizeType t - _ -> error "not a type" - (DictTy _, Inferred Nothing (Synth _)) -> generalizeDict ty' arg + arg' <- liftSubstReaderT case (ty', expl) of + (TyKind, _) -> toAtom <$> generalizeType (fromJust $ toMaybeType arg) + (TyCon (DictTy _), Inferred Nothing (Synth _)) -> + toAtom <$> generalizeDict ty' (fromJust $ toMaybeDict arg) _ -> isData ty' >>= \case - True -> liftM Var $ liftSubstReaderT $ emitGeneralizationParameter ty' arg + True -> toAtom <$> emitGeneralizationParameter ty' arg False -> do -- Unlike in `inferRoles` in `Inference`, it's ok to have non-data, -- non-type, non-dict arguments (e.g. a function). We just don't @@ -108,11 +106,9 @@ emitGeneralizationParameter ty val = GeneralizerM do -- Given a type (an Atom of type `Type`), abstracts over all data components generalizeType :: Type CoreIR n -> GeneralizerM n (Type CoreIR n) generalizeType ty = traverseTyParams ty \paramRole paramReqTy param -> case paramRole of - TypeParam -> Type <$> case param of - Type t -> generalizeType t - _ -> error "not a type" - DictParam -> generalizeDict paramReqTy param - DataParam -> Var <$> emitGeneralizationParameter paramReqTy param + TypeParam -> toAtom <$> generalizeType (fromJust $ toMaybeType param) + DictParam -> toAtom <$> generalizeDict paramReqTy (fromJust $ toMaybeDict param) + DataParam -> toAtom <$> emitGeneralizationParameter paramReqTy param -- === role-aware type traversal === @@ -125,27 +121,28 @@ traverseTyParams => CType n -> (forall l . DExt n l => ParamRole -> CType l -> CAtom l -> m l (CAtom l)) -> m n (CType n) -traverseTyParams ty f = getDistinct >>= \Distinct -> case ty of - DictTy (DictType sn name params) -> do - Abs paramRoles UnitE <- getClassRoleBinders name - params' <- traverseRoleBinders f paramRoles params - return $ DictTy $ DictType sn name params' - TabPi (TabPiType (IxDictAtom d) (b:>iTy) resultTy) -> do +traverseTyParams (StuckTy _ _) _ = error "shouldn't have StuckTy left" +traverseTyParams (TyCon ty) f = liftM TyCon $ getDistinct >>= \Distinct -> case ty of + DictTy dictTy -> DictTy <$> case dictTy of + DictType sn name params -> do + Abs paramRoles UnitE <- getClassRoleBinders name + params' <- traverseRoleBinders f paramRoles params + return $ DictType sn name params' + IxDictType t -> IxDictType <$> f' TypeParam TyKind t + DataDictType t -> DataDictType <$> f' TypeParam TyKind t + TabPi (TabPiType d (b:>iTy) resultTy) -> do iTy' <- f' TypeParam TyKind iTy - dictTy <- liftM ignoreExcept $ runFallibleT1 $ DictTy <$> ixDictType iTy' - d' <- f DictParam dictTy d + let dictTy = toType $ IxDictType iTy' + d' <- fromJust . toMaybeDict <$> f DictParam dictTy (toAtom d) withFreshBinder (getNameHint b) iTy' \(b':>_) -> do resultTy' <- applyRename (b@>binderName b') resultTy >>= (f' TypeParam TyKind) - return $ TabTy (IxDictAtom d') (b':>iTy') resultTy' - -- shouldn't need this once we can exclude IxDictFin and IxDictSpecialized from CoreI - TabPi t -> return $ TabPi t - TC tc -> TC <$> case tc of - BaseType b -> return $ BaseType b - ProdType tys -> ProdType <$> forM tys \t -> f' TypeParam TyKind t - RefType _ _ -> error "not implemented" -- how should we handle the ParamRole for the heap parameter? - SumType tys -> SumType <$> forM tys \t -> f' TypeParam TyKind t - TypeKind -> return TypeKind - HeapType -> return HeapType + return $ TabPi $ TabPiType d' (b':>iTy') resultTy' + BaseType b -> return $ BaseType b + ProdType tys -> ProdType <$> forM tys \t -> f' TypeParam TyKind t + RefType _ _ -> error "not implemented" -- how should we handle the ParamRole for the heap parameter? + SumType tys -> SumType <$> forM tys \t -> f' TypeParam TyKind t + TypeKind -> return TypeKind + HeapType -> return HeapType NewtypeTyCon con -> NewtypeTyCon <$> case con of Nat -> return Nat Fin n -> Fin <$> f DataParam NatTy n @@ -157,11 +154,7 @@ traverseTyParams ty f = getDistinct >>= \Distinct -> case ty of _ -> error $ "Not implemented: " ++ pprint ty where f' :: forall l . DExt n l => ParamRole -> CType l -> CType l -> m l (CType l) - f' r t x = fromType <$> f r t (Type x) - - fromType :: CAtom l -> CType l - fromType (Type t) = t - fromType x = error $ "not a type: " ++ pprint x + f' r t x = fromJust <$> toMaybeType <$> f r t (toAtom x) {-# INLINE traverseTyParams #-} traverseRoleBinders @@ -191,7 +184,7 @@ getDataDefRoleBinders def = do getClassRoleBinders :: EnvReader m => ClassName n -> m n (Abs RolePiBinders UnitE n) getClassRoleBinders def = do - ClassDef _ _ _ roleExpls bs _ _ <- lookupClassDef def + ClassDef _ _ _ _ roleExpls bs _ _ <- lookupClassDef def return $ Abs (zipAttrs roleExpls bs) UnitE {-# INLINE getClassRoleBinders #-} diff --git a/src/lib/Imp.hs b/src/lib/Imp.hs index 9ab013f23..e99e527a4 100644 --- a/src/lib/Imp.hs +++ b/src/lib/Imp.hs @@ -288,10 +288,10 @@ translateExpr expr = confuseGHC >>= \_ -> case expr of scalarArgs <- liftM toList $ mapM fromScalarAtom xs results <- impCall f scalarArgs restructureScalarOrPairType resultTy results - TabApp _ f' xs' -> do - xs <- mapM substM xs' + TabApp _ f' x' -> do + x <- substM x' f <- atomToRepVal =<< substM f' - repValAtom =<< naryIndexRepVal f (toList xs) + repValAtom =<< indexRepVal f x Atom x -> substM x PrimOp op -> toImpOp op Case e alts (EffTy _ unitResultTy) -> do @@ -299,11 +299,12 @@ translateExpr expr = confuseGHC >>= \_ -> case expr of case unitResultTy of UnitTy -> return () _ -> error $ "Unexpected returning Case in Imp " ++ pprint expr - case trySelectBranch e' of - Just (con, arg) -> do - Abs b body <- return $ alts !! con + case e' of + Con con -> do + SumCon _ i arg <- return con + Abs b body <- return $ alts !! i extendSubst (b @> SubstVal arg) $ translateExpr body - Nothing -> do + Stuck _ _ -> do RepVal sumTy (Branch (tag:xss)) <- atomToRepVal e' ts <- caseAltsBinderTys sumTy tag' <- repValAtom $ RepVal TagRepTy tag @@ -355,7 +356,7 @@ toImpRefOp refDest' m = do ans <- liftBuilderImp $ emitExpr (sink body') storeAtom accDest ans False -> case accTy of - TabPi t -> do + TyCon (TabPi t) -> do let ixTy = tabIxType t n <- indexSetSizeImp ixTy emitLoop noHint Fwd n \i -> do @@ -431,7 +432,7 @@ castPtrToVectorType ptr vty = do let PtrType (addrSpace, _) = getIType ptr cast ptr (PtrType (addrSpace, vty)) -toImpMiscOp :: Emits o => MiscOp SimpIR o -> SubstImpM i o (SAtom o) +toImpMiscOp :: forall i o . Emits o => MiscOp SimpIR o -> SubstImpM i o (SAtom o) toImpMiscOp op = case op of ThrowError resultTy -> do emitStatement IThrowError @@ -458,15 +459,14 @@ toImpMiscOp op = case op of returnIExprVal =<< emitInstr =<< (ISelect <$> fsa p <*> fsa x <*> fsa y) SumTag con -> case con of Con (SumCon _ tag _) -> return $ TagRepVal $ fromIntegral tag - RepValAtom dRepVal -> go dRepVal + Stuck _ (RepValAtom dRepVal) -> do + RepVal _ (Branch (tag:_)) <- return dRepVal + return $ toAtom $ RepVal (TagRepTy :: SType o) tag _ -> error $ "Not a data constructor: " ++ pprint con - where go dRepVal = do - RepVal _ (Branch (tag:_)) <- return dRepVal - return $ RepValAtom $ RepVal TagRepTy tag ToEnum ty i -> case ty of - SumTy cases -> do + TyCon (SumType cases) -> do i' <- fromScalarAtom i - return $ RepValAtom $ RepVal ty $ Branch $ Leaf i' : map (const (Branch [])) cases + return $ toAtom $ RepVal ty $ Branch $ Leaf i' : map (const (Branch [])) cases _ -> error $ "Not an enum: " ++ pprint ty OutputStream -> returnIExprVal =<< emitInstr IOutputStream ThrowException _ -> error "shouldn't have ThrowException left" -- also, should be replaced with user-defined errors @@ -567,7 +567,7 @@ toImpTypedHof (TypedHof (EffTy _ resultTy') hof) = do alphaEq xTy accTy >>= \case True -> storeAtom accDest x False -> case accTy of - TabPi t -> do + TyCon (TabPi t) -> do let ixTy = tabIxType t n <- indexSetSizeImp ixTy emitLoop noHint Fwd n \i -> do @@ -683,28 +683,27 @@ typeToTree :: EnvReader m => SType n -> m n (Tree (LeafType n)) typeToTree tyTop = return $ go REmpty tyTop where go :: RNest (TypeCtxLayer SimpIR) n l -> SType l -> Tree (LeafType n) - go ctx = \case - BaseTy b -> Leaf $ LeafType (unRNest ctx) b - TabTy d b bodyTy -> go (RNest ctx (TabCtx (PairB (LiftB d) b))) bodyTy - RefTy _ t -> go (RNest ctx RefCtx) t + go ctx (TyCon con) = case con of + BaseType b -> Leaf $ LeafType (unRNest ctx) b + TabPi (TabPiType d b bodyTy) -> go (RNest ctx (TabCtx (PairB (LiftB d) b))) bodyTy + RefType _ t -> go (RNest ctx RefCtx) t DepPairTy (DepPairType _ (b:>t1) (t2)) -> do let tree1 = rec t1 let tree2 = go (RNest ctx (DepPairCtx (JustB (b:>t1)))) t2 Branch [tree1, tree2] - ProdTy ts -> Branch $ map rec ts - SumTy ts -> do + ProdType ts -> Branch $ map rec ts + SumType ts -> do let tag = rec TagRepTy let xs = map rec ts Branch $ tag:xs - TC HeapType -> Branch [] - ty -> error $ "not implemented " ++ pprint ty + HeapType -> Branch [] where rec = go ctx traverseScalarRepTys :: EnvReader m => SType n -> (LeafType n -> m n a) -> m n (Tree a) traverseScalarRepTys ty f = traverse f =<< typeToTree ty {-# INLINE traverseScalarRepTys #-} -storeRepVal :: Emits n => Dest n -> SRepVal n -> SubstImpM i n () +storeRepVal :: Emits n => Dest n -> RepVal n -> SubstImpM i n () storeRepVal (Dest _ destTree) repVal@(RepVal _ valTree) = do leafTys <- valueToTree repVal forM_ (zipTrees (zipTrees leafTys destTree) valTree) \((leafTy, ptr), val) -> do @@ -713,16 +712,16 @@ storeRepVal (Dest _ destTree) repVal@(RepVal _ valTree) = do -- Like `typeToTree`, but when we additionally have the value, we can populate -- the existentially-hidden fields. -valueToTree :: EnvReader m => SRepVal n -> m n (Tree (LeafType n)) +valueToTree :: EnvReader m => RepVal n -> m n (Tree (LeafType n)) valueToTree (RepVal tyTop valTop) = do go REmpty tyTop valTop where go :: EnvReader m => RNest (TypeCtxLayer SimpIR) n l -> SType l -> Tree (IExpr n) -> m n (Tree (LeafType n)) - go ctx ty val = case ty of - BaseTy b -> return $ Leaf $ LeafType (unRNest ctx) b - TabTy d b bodyTy -> go (RNest ctx (TabCtx (PairB (LiftB d) b))) bodyTy val - RefTy _ t -> go (RNest ctx RefCtx) t val + go ctx (TyCon ty) val = case ty of + BaseType b -> return $ Leaf $ LeafType (unRNest ctx) b + TabPi (TabPiType d b bodyTy) -> go (RNest ctx (TabCtx (PairB (LiftB d) b))) bodyTy val + RefType _ t -> go (RNest ctx RefCtx) t val DepPairTy (DepPairType _ (b:>t1) (t2)) -> case val of Branch [v1, v2] -> do case allDepPairCtxs (unRNest ctx) of @@ -737,10 +736,10 @@ valueToTree (RepVal tyTop valTop) = do tree2 <- go (RNest ctx (DepPairCtx (JustB (b:>t1)))) t2 v2 return $ Branch [tree1, tree2] _ -> error "expected a branch" - ProdTy ts -> case val of + ProdType ts -> case val of Branch vals -> Branch <$> zipWithM rec ts vals _ -> error "expected a branch" - SumTy ts -> case val of + SumType ts -> case val of Branch (tagVal:vals) -> do tag <- rec TagRepTy tagVal results <- zipWithM rec ts vals @@ -831,7 +830,7 @@ isNull p = do nullPtrIExpr :: BaseType -> IExpr n nullPtrIExpr baseTy = ILit $ PtrLit (CPU, baseTy) NullPtr -loadRepVal :: (ImpBuilder m, Emits n) => Dest n -> m n (SRepVal n) +loadRepVal :: (ImpBuilder m, Emits n) => Dest n -> m n (RepVal n) loadRepVal (Dest valTy destTree) = do leafTys <- typeToTree valTy RepVal valTy <$> forM (zipTrees leafTys destTree) \(leafTy, ptr) -> do @@ -841,54 +840,55 @@ loadRepVal (Dest valTy destTree) = do _ -> return ptr {-# INLINE loadRepVal #-} -atomToRepVal :: Emits n => SAtom n -> SubstImpM i n (SRepVal n) +atomToRepVal :: Emits n => SAtom n -> SubstImpM i n (RepVal n) atomToRepVal x = RepVal (getType x) <$> go x where go :: Emits n => SAtom n -> SubstImpM i n (Tree (IExpr n)) - go atom = case atom of - RepValAtom dRepVal -> do - (RepVal _ tree) <- return dRepVal - return tree + go (Con con) = case con of DepPair lhs rhs _ -> do lhsTree <- go lhs rhsTree <- go rhs return $ Branch [lhsTree, rhsTree] - Con (Lit l) -> return $ Leaf $ ILit l - Con (ProdCon xs) -> Branch <$> mapM go xs - Con (SumCon cases tag payload) -> do + Lit l -> return $ Leaf $ ILit l + ProdCon xs -> Branch <$> mapM go xs + SumCon cases tag payload -> do tag' <- go $ TagRepVal $ fromIntegral tag xs <- forM (enumerate cases) \(i, t) -> if i == tag then go payload - else buildGarbageVal t <&> \(RepValAtom (RepVal _ tree)) -> tree + else buildGarbageVal t <&> \(Stuck _ (RepValAtom (RepVal _ tree))) -> tree return $ Branch $ tag':xs - Con HeapVal -> return $ Branch [] - PtrVar ty p -> return $ Leaf $ IPtrVar p ty - Stuck (StuckVar v) -> lookupAtomName (atomVarName v) >>= \case + HeapVal -> return $ Branch [] + go (Stuck _ stuck) = case stuck of + Var v -> lookupAtomName (atomVarName v) >>= \case TopDataBound (RepVal _ tree) -> return tree _ -> error "should only have pointer and data atom names left" + PtrVar ty p -> return $ Leaf $ IPtrVar p ty + RepValAtom dRepVal -> do + (RepVal _ tree) <- return dRepVal + return tree -- TODO: I think we want to be able to rule this one out by insisting that -- RepValAtom is itself part of Stuck and it can't represent a product. - Stuck (StuckProject _ i val) -> do - Branch ts <- go $ Stuck val + StuckProject i val -> do + Branch ts <- go =<< mkStuck val return $ ts !! i - Stuck (StuckTabApp _ f xs) -> do - f' <- atomToRepVal $ Stuck f - RepVal _ t <- naryIndexRepVal f' (toList xs) + StuckTabApp f x' -> do + f' <- atomToRepVal =<< mkStuck f + RepVal _ t <- indexRepVal f' x' return t -- XXX: We used to have a function called `destToAtom` which loaded the value -- from the dest. This version is not that. It just lifts a dest into an atom of -- type `Ref _`. destToAtom :: Dest n -> SAtom n -destToAtom (Dest valTy tree) = RepValAtom $ RepVal (RefTy (Con HeapVal) valTy) tree +destToAtom (Dest valTy tree) = toAtom $ RepVal (RefTy (Con HeapVal) valTy) tree atomToDest :: EnvReader m => SAtom n -> m n (Dest n) -atomToDest (RepValAtom val) = do +atomToDest (Stuck _ (RepValAtom val)) = do (RepVal ~(RefTy _ valTy) valTree) <- return val return $ Dest valTy valTree atomToDest atom = error $ "Expected a non-var atom of type `RawRef _`, got: " ++ pprint atom {-# INLINE atomToDest #-} -repValToList :: SRepVal n -> [IExpr n] +repValToList :: RepVal n -> [IExpr n] repValToList (RepVal _ tree) = toList tree -- TODO: augment with device, backend information as needed @@ -961,7 +961,7 @@ storeAtom dest x = storeRepVal dest =<< atomToRepVal x loadAtom :: Emits n => Dest n -> SubstImpM i n (SAtom n) loadAtom d = repValAtom =<< loadRepVal d -repValFromFlatList :: (TopBuilder m, Mut n) => SType n -> [LitVal] -> m n (SRepVal n) +repValFromFlatList :: (TopBuilder m, Mut n) => SType n -> [LitVal] -> m n (RepVal n) repValFromFlatList ty xs = do (litValTree, []) <- runStreamReaderT1 xs $ traverseScalarRepTys ty \_ -> fromJust <$> readStream @@ -977,7 +977,7 @@ litValToIExpr litval = case litval of buildGarbageVal :: Emits n => SType n -> SubstImpM i n (SAtom n) buildGarbageVal ty = - RepValAtom <$> RepVal ty <$> traverseScalarRepTys ty \leafTy -> do + toAtom <$> RepVal ty <$> traverseScalarRepTys ty \leafTy -> do case getIExprInterpretation leafTy of BufferPtr bufferTy -> allocBuffer Managed bufferTy RawValue b -> return $ ILit $ emptyLit b @@ -985,10 +985,10 @@ buildGarbageVal ty = -- === Operations on dests === indexDest :: Emits n => Dest n -> SAtom n -> SubstImpM i n (Dest n) -indexDest (Dest (TabPi tabTy) tree) i = do +indexDest (Dest (TyCon (TabPi tabTy)) tree) i = do eltTy <- instantiate tabTy [i] ord <- ordinalImp (tabIxType tabTy) i - leafTys <- typeToTree $ TabPi tabTy + leafTys <- typeToTree $ toType tabTy Dest eltTy <$> forM (zipTrees leafTys tree) \(leafTy, ptr) -> do BufferType ixStruct _ <- return $ getRefBufferType leafTy offset <- computeOffsetImp ixStruct ord @@ -996,23 +996,15 @@ indexDest (Dest (TabPi tabTy) tree) i = do indexDest _ _ = error "expected a reference to a table" {-# INLINE indexDest #-} --- TODO: direct n-ary version for efficiency? -naryIndexRepVal :: Emits n => RepVal SimpIR n -> [SAtom n] -> SubstImpM i n (RepVal SimpIR n) -naryIndexRepVal x [] = return x -naryIndexRepVal x (ix:ixs) = do - x' <- indexRepVal x ix - naryIndexRepVal x' ixs -{-# INLINE naryIndexRepVal #-} - -- TODO: de-dup with indexDest? indexRepValParam :: Emits n - => SRepVal n -> SAtom n -> (SType n -> SType n) + => RepVal n -> SAtom n -> (SType n -> SType n) -> (IExpr n -> SubstImpM i n (IExpr n)) - -> SubstImpM i n (SRepVal n) -indexRepValParam (RepVal (TabPi tabTy) vals) i tyFunc func = do + -> SubstImpM i n (RepVal n) +indexRepValParam (RepVal (TyCon (TabPi tabTy)) vals) i tyFunc func = do eltTy <- instantiate tabTy [i] ord <- ordinalImp (tabIxType tabTy) i - leafTys <- typeToTree (TabPi tabTy) + leafTys <- typeToTree (toType tabTy) vals' <- forM (zipTrees leafTys vals) \(leafTy, ptr) -> do BufferPtr (BufferType ixStruct _) <- return $ getIExprInterpretation leafTy offset <- computeOffsetImp ixStruct ord @@ -1028,14 +1020,11 @@ indexRepValParam (RepVal (TabPi tabTy) vals) i tyFunc func = do indexRepValParam _ _ _ _ = error "expected table type" {-# INLINE indexRepValParam #-} -indexRepVal :: Emits n - => RepVal SimpIR n -> SAtom n -> SubstImpM i n (RepVal SimpIR n) +indexRepVal :: Emits n => RepVal n -> SAtom n -> SubstImpM i n (RepVal n) indexRepVal rep i = indexRepValParam rep i id return {-# INLINE indexRepVal #-} -vectorIndexRepVal :: Emits n - => RepVal SimpIR n -> SAtom n -> SType n - -> SubstImpM i n (RepVal SimpIR n) +vectorIndexRepVal :: Emits n => RepVal n -> SAtom n -> SType n -> SubstImpM i n (RepVal n) vectorIndexRepVal rep i vty = -- Passing `const vty` here depends on knowing that `vectorIndexRepVal` is -- only called on references of scalar base type, so that the give `vty` is, @@ -1045,7 +1034,7 @@ vectorIndexRepVal rep i vty = {-# INLINE vectorIndexRepVal #-} projectDest :: Int -> Dest n -> Dest n -projectDest i (Dest (ProdTy tys) (Branch ds)) = +projectDest i (Dest (TyCon (ProdType tys)) (Branch ds)) = Dest (tys!!i) (ds!!i) projectDest _ (Dest ty _) = error $ "Can't project dest: " ++ pprint ty @@ -1104,7 +1093,7 @@ computeSizeGivenOrdinal computeSizeGivenOrdinal (PairB (LiftB d) (b:>t)) idxStruct = liftBuilder do withFreshBinder noHint IdxRepTy \bOrdinal -> Abs bOrdinal <$> buildBlock do - i <- unsafeFromOrdinal (sink $ IxType t d) $ Var $ sink $ binderVar bOrdinal + i <- unsafeFromOrdinal (sink $ IxType t d) $ toAtom $ sink $ binderVar bOrdinal idxStruct' <- applySubst (b@>SubstVal i) idxStruct elemCountPoly $ sink idxStruct' @@ -1358,8 +1347,8 @@ fromScalarAtom atom = atomToRepVal atom >>= \case Leaf x -> return x _ -> error $ "Not a scalar atom:" ++ pprint ty -toScalarAtom :: IExpr n -> SAtom n -toScalarAtom x = RepValAtom $ RepVal (BaseTy (getIType x)) (Leaf x) +toScalarAtom :: forall n. IExpr n -> SAtom n +toScalarAtom x = toAtom $ RepVal (BaseTy (getIType x) :: SType n) (Leaf x) liftBuilderImp :: (Emits n, SubstE AtomSubstVal e, SinkableE e) => (forall l. (Emits l, DExt n l) => BuilderM SimpIR l (e l)) @@ -1372,24 +1361,24 @@ liftBuilderImp cont = do -- === Type classes === ordinalImp :: Emits n => IxType SimpIR n -> SAtom n -> SubstImpM i n (IExpr n) -ordinalImp (IxType _ dict) i = fromScalarAtom =<< case dict of - IxDictRawFin _ -> return i - IxDictSpecialized _ d params -> do +ordinalImp (IxType _ (DictCon dict)) i = fromScalarAtom =<< case dict of + IxRawFin _ -> return i + IxSpecialized d params -> do appSpecializedIxMethod d Ordinal (params ++ [i]) unsafeFromOrdinalImp :: Emits n => IxType SimpIR n -> IExpr n -> SubstImpM i n (SAtom n) -unsafeFromOrdinalImp (IxType _ dict) i = do +unsafeFromOrdinalImp (IxType _ (DictCon dict)) i = do let i' = toScalarAtom i case dict of - IxDictRawFin _ -> return i' - IxDictSpecialized _ d params -> + IxRawFin _ -> return i' + IxSpecialized d params -> appSpecializedIxMethod d UnsafeFromOrdinal (params ++ [i']) indexSetSizeImp :: Emits n => IxType SimpIR n -> SubstImpM i n (IExpr n) -indexSetSizeImp (IxType _ dict) = do +indexSetSizeImp (IxType _ (DictCon dict)) = do fromScalarAtom =<< case dict of - IxDictRawFin n -> return n - IxDictSpecialized _ d params -> + IxRawFin n -> return n + IxSpecialized d params -> appSpecializedIxMethod d Size (params ++ []) appSpecializedIxMethod :: Emits n => SpecDictName n -> IxMethod -> [SAtom n] -> SubstImpM i n (SAtom n) @@ -1434,10 +1423,10 @@ abstractLinktimeObjects f = do isSingletonType :: Type SimpIR n -> Bool isSingletonType topTy = isJust $ checkIsSingleton topTy where - checkIsSingleton :: Type r n -> Maybe () - checkIsSingleton ty = case ty of + checkIsSingleton :: SType n -> Maybe () + checkIsSingleton (TyCon ty) = case ty of TabPi (TabPiType _ _ body) -> checkIsSingleton body - TC (ProdType tys) -> mapM_ checkIsSingleton tys + ProdType tys -> mapM_ checkIsSingleton tys _ -> Nothing singletonTypeVal :: (EnvReader m) @@ -1447,7 +1436,7 @@ singletonTypeVal ty = do if length tree == 0 then do -- The tree has 0 of these if the type is empty let tree' = fmap (const $ ILit $ Int32Lit 0) tree - return $ Just $ RepValAtom $ RepVal ty tree' + Just <$> mkStuck (RepValAtom $ RepVal ty tree') else return Nothing {-# INLINE singletonTypeVal #-} diff --git a/src/lib/Inference.hs b/src/lib/Inference.hs index 276bd9236..eae234af6 100644 --- a/src/lib/Inference.hs +++ b/src/lib/Inference.hs @@ -43,6 +43,7 @@ import Types.Core import Types.Imp import Types.Primitives import Types.Source +import qualified Types.OpNames as P import Util hiding (group) -- === Top-level interface === @@ -70,7 +71,7 @@ inferTopUDecl (UStructDecl tc def) result = do forM_ methods \(letAnn, methodName, methodDef) -> do method <- liftInfererM $ extendRenamer (tc@>tc') $ inferDotMethod tc' (Abs paramBs methodDef) - method' <- emitTopLet (getNameHint methodName) letAnn (Atom $ Lam method) + method' <- emitTopLet (getNameHint methodName) letAnn (Atom $ toAtom $ Lam method) updateTopEnv $ UpdateFieldDef tc' methodName (atomVarName method') UDeclResultDone <$> applyRename (tc @> tc') result inferTopUDecl (UDataDefDecl def tc dcs) result = do @@ -92,14 +93,15 @@ inferTopUDecl (UInterface paramBs methodTys className methodNames) result = do inferTopUDecl (UInstance className bs params methods maybeName expl) result = do let (InternalName _ _ className') = className def <- liftInfererM $ withRoleUBinders bs \(ZipB roleExpls bs') -> do - ClassDef _ _ _ _ paramBinders _ _ <- lookupClassDef (sink className') + ClassDef _ _ _ _ _ paramBinders _ _ <- lookupClassDef (sink className') params' <- checkInstanceParams paramBinders params body <- checkInstanceBody (sink className') params' methods return $ InstanceDef className' roleExpls bs' params' body UDeclResultDone <$> case maybeName of RightB UnitB -> do instanceName <- emitInstanceDef def - addInstanceSynthCandidate className' instanceName + ClassDef _ builtinName _ _ _ _ _ _ <- lookupClassDef className' + addInstanceSynthCandidate className' builtinName instanceName return result JustB instanceName' -> do instanceName <- emitInstanceDef def @@ -137,8 +139,7 @@ getInstanceType :: EnvReader m => InstanceDef n -> m n (CorePiType n) getInstanceType (InstanceDef className roleExpls bs params _) = liftEnvReaderM do refreshAbs (Abs bs (ListE params)) \bs' (ListE params') -> do className' <- sinkM className - ClassDef classSourceName _ _ _ _ _ _ <- lookupClassDef className' - let dTy = DictTy $ DictType classSourceName className' params' + dTy <- toType <$> dictType className' params' return $ CorePiType ImplicitApp (snd <$> roleExpls) bs' $ EffTy Pure dTy -- === Inferer monad === @@ -188,6 +189,14 @@ zonk e = do applySolverSubst s e {-# INLINE zonk #-} +zonkStuck :: CStuck n -> SolverM i n (CAtom n) +zonkStuck stuck = do + solverSubst <- getDiffState + Distinct <- getDistinct + env <- unsafeGetEnv + let subst = newSubst (lookupSolverSubst solverSubst) + return $ substStuck (env, subst) stuck + applySolverSubst :: (EnvReader m, Zonkable e) => SolverSubst n -> e n -> m n (e n) applySolverSubst subst e = do Distinct <- getDistinct @@ -210,7 +219,7 @@ withFreshBinderInf :: NameHint -> Explicitness -> CType o -> InfererCPSB CBinder withFreshBinderInf hint expl ty cont = withFreshBinder hint ty \b -> do givens <- case expl of - Inferred _ (Synth _) -> return [Var $ binderVar b] + Inferred _ (Synth _) -> return [toAtom $ binderVar b] _ -> return [] extendGivens givens $ cont b {-# INLINE withFreshBinderInf #-} @@ -296,7 +305,7 @@ withFreshDictVar withFreshDictVar dictTy synthIt cont = hasInferenceVars dictTy >>= \case False -> withDistinct $ synthIt dictTy >>= cont True -> withInferenceVar "_dict_" (DictBound dictTy) \v -> do - ans <- cont =<< (Var <$> toAtomVar v) + ans <- cont =<< (toAtom <$> toAtomVar v) dictTy' <- zonk $ sink dictTy dict <- synthIt dictTy' return (ans, dict) @@ -371,8 +380,8 @@ topDownPartial :: Emits o => PartialType o -> UExpr i -> InfererM i o (CAtom o) topDownPartial partialTy exprWithSrc@(WithSrcE pos expr) = addSrcContext pos $ case partialTy of PartialType partialPiTy -> case expr of - ULam lam -> Lam <$> checkULamPartial partialPiTy lam - _ -> Lam <$> etaExpandPartialPi partialPiTy \resultTy explicitArgs -> do + ULam lam -> toAtom <$> Lam <$> checkULamPartial partialPiTy lam + _ -> toAtom <$> Lam <$> etaExpandPartialPi partialPiTy \resultTy explicitArgs -> do expr' <- bottomUpExplicit exprWithSrc dropSubst $ checkOrInferApp expr' explicitArgs [] resultTy FullType ty -> topDownExplicit ty exprWithSrc @@ -384,7 +393,7 @@ etaExpandPartialPi -> InfererM i o (CoreLamExpr o) etaExpandPartialPi (PartialPiType appExpl expls bs effs reqTy) cont = do withFreshBindersInf expls (Abs bs (PairE effs reqTy)) \bs' (PairE effs' reqTy') -> do - let args = zip expls (Var <$> bindersVars bs') + let args = zip expls (toAtom <$> bindersVars bs') explicits <- return $ catMaybes $ args <&> \case (Explicit, arg) -> Just arg _ -> Nothing @@ -397,10 +406,10 @@ etaExpandPartialPi (PartialPiType appExpl expls bs effs reqTy) cont = do topDownExplicit :: forall i o. Emits o => CType o -> UExpr i -> InfererM i o (CAtom o) topDownExplicit reqTy exprWithSrc@(WithSrcE pos expr) = addSrcContext pos case expr of ULam lamExpr -> case reqTy of - Pi piTy -> Lam <$> checkULam lamExpr piTy + TyCon (Pi piTy) -> toAtom <$> Lam <$> checkULam lamExpr piTy _ -> throw TypeErr $ "Unexpected lambda. Expected: " ++ pprint reqTy UFor dir uFor -> case reqTy of - TabPi tabPiTy -> do + TyCon (TabPi tabPiTy) -> do lam@(UnaryLamExpr b' _) <- checkUForExpr uFor tabPiTy ixTy <- asIxType $ binderType b' emitHof $ For dir ixTy lam @@ -409,11 +418,11 @@ topDownExplicit reqTy exprWithSrc@(WithSrcE pos expr) = addSrcContext pos case e f' <- bottomUpExplicit f checkOrInferApp f' posArgs namedArgs (Check reqTy) UDepPair lhs rhs -> case reqTy of - DepPairTy ty@(DepPairType _ (_ :> lhsTy) _) -> do + TyCon (DepPairTy ty@(DepPairType _ (_ :> lhsTy) _)) -> do lhs' <- checkSigmaDependent lhs (FullType lhsTy) rhsTy <- instantiate ty [lhs'] rhs' <- topDown rhsTy rhs - return $ DepPair lhs' rhs' ty + return $ toAtom $ DepPair lhs' rhs' ty _ -> throw TypeErr $ "Unexpected dependent pair. Expected: " ++ pprint reqTy UCase scrut alts -> do scrut' <- bottomUp scrut @@ -423,7 +432,7 @@ topDownExplicit reqTy exprWithSrc@(WithSrcE pos expr) = addSrcContext pos case e UDo block -> withBlockDecls block \result -> topDownExplicit (sink reqTy) result UTabCon xs -> do case reqTy of - TabPi tabPiTy -> checkTabCon tabPiTy xs + TyCon (TabPi tabPiTy) -> checkTabCon tabPiTy xs _ -> throw TypeErr $ "Unexpected table constructor. Expected: " ++ pprint reqTy UNatLit x -> do let litVal = Con $ Lit $ Word64Lit $ fromIntegral x @@ -432,10 +441,10 @@ topDownExplicit reqTy exprWithSrc@(WithSrcE pos expr) = addSrcContext pos case e let litVal = Con $ Lit $ Int64Lit $ fromIntegral x applyFromLiteralMethod reqTy "from_integer" litVal UPrim UTuple xs -> case reqTy of - TyKind -> Type . ProdTy <$> mapM checkUType xs - ProdTy reqTys -> do + TyKind -> toAtom . ProdType <$> mapM checkUType xs + TyCon (ProdType reqTys) -> do when (length reqTys /= length xs) $ throw TypeErr "Tuple length mismatch" - ProdVal <$> forM (zip reqTys xs) \(reqTy', x) -> topDown reqTy' x + toAtom <$> ProdCon <$> forM (zip reqTys xs) \(reqTy', x) -> topDown reqTy' x _ -> throw TypeErr $ "Unexpected tuple. Expected: " ++ pprint reqTy UFieldAccess _ _ -> infer UVar _ -> infer @@ -475,11 +484,11 @@ bottomUpExplicit (WithSrcE pos expr) = addSrcContext pos case expr of FieldDotMethod method (TyConParams _ params) -> do method' <- toAtomVar method resultTy <- partialAppType (getType method') (params ++ [x']) - return $ SigmaPartialApp resultTy (Var method') (params ++ [x']) + return $ SigmaPartialApp resultTy (toAtom method') (params ++ [x']) Nothing -> throw TypeErr $ "Can't resolve field " ++ pprint field ++ " of type " ++ pprint ty ++ "\nKnown fields are: " ++ pprint (M.keys fields) - ULam lamExpr -> SigmaAtom Nothing <$> Lam <$> inferULam lamExpr + ULam lamExpr -> SigmaAtom Nothing <$> toAtom <$> inferULam lamExpr UFor dir uFor -> do lam@(UnaryLamExpr b' _) <- inferUForExpr uFor ixTy <- asIxType $ binderType b' @@ -494,18 +503,18 @@ bottomUpExplicit (WithSrcE pos expr) = addSrcContext pos case expr of -- TODO: check explicitness constraints withUBinders bs \(ZipB expls bs') -> do effTy' <- EffTy <$> checkUEffRow effs <*> checkUType ty - return $ SigmaAtom Nothing $ Type $ + return $ SigmaAtom Nothing $ toAtom $ Pi $ CorePiType appExpl expls bs' effTy' UTabPi (UTabPiExpr b ty) -> do Abs b' ty' <- withUBinder b \(WithAttrB _ b') -> liftM (Abs b') $ checkUType ty d <- getIxDict $ binderType b' let piTy = TabPiType d b' ty' - return $ SigmaAtom Nothing $ Type $ TabPi piTy + return $ SigmaAtom Nothing $ toAtom $ TabPi piTy UDepPairTy (UDepPairType expl b rhs) -> do withUBinder b \(WithAttrB _ b') -> do rhs' <- checkUType rhs - return $ SigmaAtom Nothing $ Type $ DepPairTy $ DepPairType expl b' rhs' + return $ SigmaAtom Nothing $ toAtom $ DepPairTy $ DepPairType expl b' rhs' UDepPair _ _ -> throw TypeErr $ "Can't infer the type of a dependent pair; please annotate its type" UCase scrut (alt:alts) -> do @@ -525,7 +534,7 @@ bottomUpExplicit (WithSrcE pos expr) = addSrcContext pos case expr of liftM (SigmaAtom Nothing) $ topDown ty' val UPrim UTuple xs -> do xs' <- forM xs \x -> bottomUp x - return $ SigmaAtom Nothing $ ProdVal xs' + return $ SigmaAtom Nothing $ Con $ ProdCon xs' UPrim UMonoLiteral [WithSrcE _ l] -> case l of UIntLit x -> return $ SigmaAtom Nothing $ Con $ Lit $ Int32Lit $ fromIntegral x UNatLit x -> return $ SigmaAtom Nothing $ Con $ Lit $ Word32Lit $ fromIntegral x @@ -540,9 +549,9 @@ bottomUpExplicit (WithSrcE pos expr) = addSrcContext pos case expr of UPrim prim xs -> do xs' <- forM xs \x -> do inferPrimArg x >>= \case - Var v -> lookupAtomName (atomVarName v) >>= \case + Stuck _ (Var v) -> lookupAtomName (atomVarName v) >>= \case LetBound (DeclBinding _ (Atom e)) -> return e - _ -> return $ Var v + _ -> return $ toAtom v x' -> return x' liftM (SigmaAtom Nothing) $ matchPrimApp prim xs' UNatLit _ -> throw TypeErr $ "Can't infer type of literal. Try an explicit annotation" @@ -567,7 +576,7 @@ matchReq Infer x = return x instantiateSigma :: Emits o => RequiredTy o -> SigmaAtom o -> InfererM i o (CAtom o) instantiateSigma reqTy sigmaAtom = case sigmaAtom of SigmaUVar _ _ _ -> case getType sigmaAtom of - Pi (CorePiType ImplicitApp expls bs (EffTy _ resultTy)) -> do + TyCon (Pi (CorePiType ImplicitApp expls bs (EffTy _ resultTy))) -> do bsConstrained <- buildConstraints (Abs bs resultTy) \_ resultTy' -> do case reqTy of Infer -> return [] @@ -628,8 +637,8 @@ applyFromLiteralMethod resultTy methodName litVal = Nothing -> error $ "prelude function not found: " ++ methodName Just ~(UMethodVar methodName') -> do MethodBinding className _ <- lookupEnv methodName' - dictTy <- DictTy <$> dictType className [Type resultTy] - d <- trySynthTerm dictTy Full + dictTy <- toType <$> dictType className [toAtom resultTy] + Just d <- toMaybeDict <$> trySynthTerm dictTy Full emitExpr =<< mkApplyMethod d 0 [litVal] -- atom that requires instantiation to become a rho type @@ -662,39 +671,43 @@ data FieldDef (n::S) = getFieldDefs :: CType n -> InfererM i n (M.Map FieldName' (FieldDef n)) getFieldDefs ty = case ty of - NewtypeTyCon (UserADTType _ tyName params) -> do - TyConBinding ~(Just tyDef) (DotMethods dotMethods) <- lookupEnv tyName - instantiateTyConDef tyDef params >>= \case - StructFields fields -> do - let projFields = enumerate fields <&> \(i, (field, _)) -> - [(FieldName field, FieldProj i), (FieldNum i, FieldProj i)] - let methodFields = M.toList dotMethods <&> \(field, f) -> - (FieldName field, FieldDotMethod f params) - return $ M.fromList $ concat projFields ++ methodFields - ADTCons _ -> noFields "" - RefTy _ valTy -> case valTy of - RefTy _ _ -> noFields "" - _ -> do - valFields <- getFieldDefs valTy - return $ M.filter isProj valFields - where isProj = \case - FieldProj _ -> True - _ -> False - ProdTy ts -> return $ M.fromList $ enumerate ts <&> \(i, _) -> (FieldNum i, FieldProj i) - TabPi _ -> noFields "\nArray indexing uses [] now." - _ -> noFields "" + StuckTy _ _ -> noFields "" + TyCon con -> case con of + NewtypeTyCon (UserADTType _ tyName params) -> do + TyConBinding ~(Just tyDef) (DotMethods dotMethods) <- lookupEnv tyName + instantiateTyConDef tyDef params >>= \case + StructFields fields -> do + let projFields = enumerate fields <&> \(i, (field, _)) -> + [(FieldName field, FieldProj i), (FieldNum i, FieldProj i)] + let methodFields = M.toList dotMethods <&> \(field, f) -> + (FieldName field, FieldDotMethod f params) + return $ M.fromList $ concat projFields ++ methodFields + ADTCons _ -> noFields "" + RefType _ valTy -> case valTy of + RefTy _ _ -> noFields "" + _ -> do + valFields <- getFieldDefs valTy + return $ M.filter isProj valFields + where isProj = \case + FieldProj _ -> True + _ -> False + ProdType ts -> return $ M.fromList $ enumerate ts <&> \(i, _) -> (FieldNum i, FieldProj i) + TabPi _ -> noFields "\nArray indexing uses [] now." + _ -> noFields "" where noFields s = throw TypeErr $ "Can't get fields for type " ++ pprint ty ++ s projectField :: Emits o => Int -> CAtom o -> InfererM i o (CAtom o) projectField i x = case getType x of - ProdTy _ -> proj i x - NewtypeTyCon _ -> projectStruct i x - RefTy _ valTy -> case valTy of - ProdTy _ -> getProjRef (ProjectProduct i) x - NewtypeTyCon _ -> projectStructRef i x + StuckTy _ _ -> bad + TyCon con -> case con of + ProdType _ -> proj i x + NewtypeTyCon _ -> projectStruct i x + RefType _ valTy -> case valTy of + TyCon (ProdType _) -> getProjRef (ProjectProduct i) x + TyCon (NewtypeTyCon _) -> projectStructRef i x + _ -> bad _ -> bad - _ -> bad where bad = error $ "bad projection: " ++ pprint (i, x) class PrettyE e => ExplicitArg (e::E) where @@ -732,7 +745,7 @@ checkOrInferApp checkOrInferApp f' posArgs namedArgs reqTy = do f <- maybeInterpretPunsAsTyCons reqTy f' case getType f of - Pi piTy@(CorePiType appExpl expls _ _) -> case appExpl of + TyCon (Pi piTy@(CorePiType appExpl expls _ _)) -> case appExpl of ExplicitApp -> do checkExplicitArity expls posArgs bsConstrained <- buildAppConstraints reqTy piTy @@ -770,7 +783,7 @@ inlineTypeAliases :: CAtomName n -> InfererM i n (CAtom n) inlineTypeAliases v = do lookupAtomName v >>= \case LetBound (DeclBinding InlineLet (Atom e)) -> return e - _ -> Var <$> toAtomVar v + _ -> toAtom <$> toAtomVar v applySigmaAtom :: Emits o => SigmaAtom o -> [CAtom o] -> InfererM i o (CAtom o) applySigmaAtom (SigmaAtom _ f) args = emitExprWithEffects =<< mkApp f args @@ -781,7 +794,7 @@ applySigmaAtom (SigmaUVar _ _ f) args = case f of UTyConVar f' -> do TyConDef sn roleExpls _ _ <- lookupTyCon f' let expls = snd <$> roleExpls - return $ Type $ NewtypeTyCon $ UserADTType sn f' (TyConParams expls args) + return $ toAtom $ UserADTType sn f' (TyConParams expls args) UDataConVar v -> do (tyCon, i) <- lookupDataCon v applyDataCon tyCon i args @@ -790,17 +803,23 @@ applySigmaAtom (SigmaUVar _ _ f) args = case f of -- interpret as a data constructor by default (params, dataArgs) <- splitParamPrefix tc args repVal <- makeStructRepVal tc dataArgs - return $ NewtypeCon (UserADTData sn tc params) repVal - UClassVar f' -> do - ClassDef sourceName _ _ _ _ _ _ <- lookupClassDef f' - return $ Type $ DictTy $ DictType sourceName f' args + return $ toAtom $ NewtypeCon (UserADTData sn tc params) repVal + UClassVar f' -> do + ClassDef sourceName builtinName _ _ _ _ _ _ <- lookupClassDef f' + return $ toAtom case builtinName of + Just Ix -> IxDictType singleTyParam + Just Data -> DataDictType singleTyParam + Nothing -> DictType sourceName f' args + where singleTyParam = case args of + [p] -> fromJust $ toMaybeType p + _ -> error "not a single type param" UMethodVar f' -> do MethodBinding className methodIdx <- lookupEnv f' - ClassDef _ _ _ _ paramBs _ _ <- lookupClassDef className + ClassDef _ _ _ _ _ paramBs _ _ <- lookupClassDef className let numParams = nestLength paramBs -- params aren't needed because they're already implied by the dict argument let (dictArg:args') = drop numParams args - emitExprWithEffects =<< mkApplyMethod dictArg methodIdx args' + emitExprWithEffects =<< mkApplyMethod (fromJust $ toMaybeDict dictArg) methodIdx args' applySigmaAtom (SigmaPartialApp _ f prevArgs) args = emitExprWithEffects =<< mkApp f (prevArgs ++ args) @@ -821,24 +840,24 @@ applyDataCon tc conIx topArgs = do repVal <- return case conDefs of [] -> error "unreachable" [_] -> conProd - _ -> SumVal conTys conIx conProd + _ -> Con $ SumCon conTys conIx conProd where conTys = conDefs <&> \(DataConDef _ _ rty _) -> rty - return $ NewtypeCon (UserADTData sn tc params) repVal + return $ toAtom $ NewtypeCon (UserADTData sn tc params) repVal where wrap :: EnvReader m => CType n -> [CAtom n] -> m n (CAtom n) wrap _ [arg] = return $ arg wrap rty args = case rty of - ProdTy tys -> + TyCon (ProdType tys) -> if nargs == ntys - then return $ ProdVal args - else ProdVal . (curArgs ++) . (:[]) <$> wrap (last tys) remArgs + then return $ Con $ ProdCon args + else Con . ProdCon . (curArgs ++) . (:[]) <$> wrap (last tys) remArgs where nargs = length args; ntys = length tys (curArgs, remArgs) = splitAt (ntys - 1) args - DepPairTy dpt@(DepPairType _ b rty') -> do + TyCon (DepPairTy dpt@(DepPairType _ b rty')) -> do rty'' <- applySubst (b@>SubstVal h) rty' ans <- wrap rty'' t - return $ DepPair h ans dpt + return $ toAtom $ DepPair h ans dpt where h:t = args _ -> error $ "Unexpected data con representation type: " ++ pprint rty @@ -871,7 +890,7 @@ buildConstraints -> InfererM i o (ConstrainedBinders o) buildConstraints ab cont = liftEnvReaderM do refreshAbs ab \bs e -> do - cs <- cont (Var <$> bindersVars bs) e + cs <- cont (toAtom <$> bindersVars bs) e return (getDependence (Abs bs e), Abs bs $ ListE cs) where getDependence :: HasNamesE e => Abs (Nest CBinder) e n -> [IsDependent] @@ -934,7 +953,7 @@ inferMixedArgs fSourceName explsTop (dependenceTop, bsAbs) argsTop@(_, namedArgs then do let desc = (fSourceName, "_") withFreshUnificationVar (ImplicitArgInfVar desc) argTy \v -> - cont (Var v) (argsRest, namedArgs) + cont (toAtom v) (argsRest, namedArgs) else do arg' <- checkOrInferExplicitArg isDependent arg argTy withDistinct $ cont arg' (argsRest, namedArgs) @@ -945,7 +964,7 @@ inferMixedArgs fSourceName explsTop (dependenceTop, bsAbs) argsTop@(_, namedArgs arg' <- checkOrInferExplicitArg isDependent arg argTy withDistinct $ cont arg' args Nothing -> case infMech of - Unify -> withFreshUnificationVar (ImplicitArgInfVar desc) argTy \v -> cont (Var v) args + Unify -> withFreshUnificationVar (ImplicitArgInfVar desc) argTy \v -> cont (toAtom v) args Synth _ -> withDict argTy \d -> cont d args checkOrInferExplicitArg :: Emits oo => Bool -> arg i -> CType oo -> SolverM i oo (CAtom oo) @@ -955,7 +974,7 @@ inferMixedArgs fSourceName explsTop (dependenceTop, bsAbs) argsTop@(_, namedArgs True -> checkExplicitDependentArg arg partialTy False -> checkExplicitNonDependentArg arg partialTy Nothing -> inferExplicitArg arg - constrainTypesEq argTy (getType arg') + constrainEq argTy (getType arg') return arg' lookupNamedArg :: MixedArgs x -> Maybe SourceName -> Maybe x @@ -964,7 +983,7 @@ inferMixedArgs fSourceName explsTop (dependenceTop, bsAbs) argsTop@(_, namedArgs withoutInfVarsPartial :: CType n -> InfererM i n (Maybe (PartialType n)) withoutInfVarsPartial = \case - Pi piTy -> + TyCon (Pi piTy) -> withoutInfVars piTy >>= \case Just piTy' -> return $ Just $ PartialType $ piAsPartialPi piTy' Nothing -> withoutInfVars $ PartialType $ piAsPartialPiDropResultTy piTy @@ -1000,13 +1019,21 @@ inferPrimArg x = do matchPrimApp :: Emits o => PrimName -> [CAtom o] -> InfererM i o (CAtom o) matchPrimApp = \case - UNat -> \case ~[] -> return $ Type $ NewtypeTyCon Nat - UFin -> \case ~[n] -> return $ Type $ NewtypeTyCon (Fin n) - UEffectRowKind -> \case ~[] -> return $ Type $ NewtypeTyCon EffectRowKind - UBaseType b -> \case ~[] -> return $ Type $ TC $ BaseType b - UNatCon -> \case ~[x] -> return $ NewtypeCon NatCon x - UPrimTC op -> \x -> Type . TC <$> matchGenericOp (Right op) x - UCon op -> \x -> Con <$> matchGenericOp (Right op) x + UNat -> \case ~[] -> return $ toAtom $ NewtypeTyCon Nat + UFin -> \case ~[n] -> return $ toAtom $ NewtypeTyCon (Fin n) + UEffectRowKind -> \case ~[] -> return $ toAtom $ NewtypeTyCon EffectRowKind + UBaseType b -> \case ~[] -> return $ toAtomR $ BaseType b + UNatCon -> \case ~[x] -> return $ toAtom $ NewtypeCon NatCon x + UPrimTC tc -> case tc of + P.ProdType -> \ts -> return $ toAtom $ ProdType $ map (fromJust . toMaybeType) ts + P.SumType -> \ts -> return $ toAtom $ SumType $ map (fromJust . toMaybeType) ts + P.RefType -> \case ~[h, a] -> return $ toAtom $ RefType h (fromJust $ toMaybeType a) + P.TypeKind -> \case ~[] -> return $ Con $ TyConAtom $ TypeKind + P.HeapType -> \case ~[] -> return $ Con $ TyConAtom $ HeapType + UCon con -> case con of + P.ProdCon -> \xs -> return $ toAtom $ ProdCon xs + P.HeapVal -> \case ~[] -> return $ toAtom HeapVal + P.SumCon _ -> error "not supported" UMiscOp op -> \x -> emitExpr =<< MiscOp <$> matchGenericOp op x UMemOp op -> \x -> emitExpr =<< MemOp <$> matchGenericOp op x UBinOp op -> \case ~[x, y] -> emitExpr $ BinOp op x y @@ -1015,7 +1042,7 @@ matchPrimApp = \case UMGet -> \case ~[r] -> emitExpr $ RefOp r MGet UMPut -> \case ~[r, x] -> emitExpr $ RefOp r $ MPut x UIndexRef -> \case ~[r, i] -> indexRef r i - UApplyMethod i -> \case ~(d:args) -> emitExpr =<< mkApplyMethod d i args + UApplyMethod i -> \case ~(d:args) -> emitExpr =<< mkApplyMethod (fromJust $ toMaybeDict d) i args ULinearize -> \case ~[f, x] -> do f' <- lam1 f; emitHof $ Linearize f' x UTranspose -> \case ~[f, x] -> do f' <- lam1 f; emitHof $ Transpose f' x URunReader -> \case ~[x, f] -> do f' <- lam2 f; emitHof $ RunReader x f' @@ -1051,13 +1078,13 @@ matchPrimApp = \case (tyArgs, dataArgs) <- partitionEithers <$> forM xs \x -> do case getType x of TyKind -> do - Type x' <- return x + Just x' <- return $ toMaybeType x return $ Left x' _ -> return $ Right x return $ fromJust $ toOp $ GenericOpRep op tyArgs dataArgs [] pattern ExplicitCoreLam :: Nest CBinder n l -> CExpr l -> CAtom n -pattern ExplicitCoreLam bs body <- Lam (CoreLamExpr _ (LamExpr bs body)) +pattern ExplicitCoreLam bs body <- Con (Lam (CoreLamExpr _ (LamExpr bs body))) -- === n-ary applications === @@ -1065,12 +1092,12 @@ inferTabApp :: Emits o => SrcPosCtx -> CAtom o -> [UExpr i] -> InfererM i o (CAt inferTabApp tabCtx tab args = addSrcContext tabCtx do tabTy <- return $ getType tab args' <- inferNaryTabAppArgs tabTy args - emitExpr =<< mkTabApp tab args' + naryTabApp tab args' inferNaryTabAppArgs :: Emits o => CType o -> [UExpr i] -> InfererM i o [CAtom o] inferNaryTabAppArgs _ [] = return [] inferNaryTabAppArgs tabTy (arg:rest) = case tabTy of - TabPi (TabPiType _ b resultTy) -> do + TyCon (TabPi (TabPiType _ b resultTy)) -> do let ixTy = binderType b let isDependent = binderName b `isFreeIn` resultTy arg' <- if isDependent @@ -1099,20 +1126,13 @@ instance SinkableE IndexedAlt where buildNthOrderedAlt :: (Emits n, Builder CoreIR m) => [IndexedAlt n] -> CType n -> CType n -> Int -> CAtom n -> m n (CAtom n) -buildNthOrderedAlt alts scrutTy resultTy i v = do - case lookup (nthCaseAltIdx scrutTy i) [(idx, alt) | IndexedAlt idx alt <- alts] of +buildNthOrderedAlt alts _ resultTy i v = do + case lookup i [(idx, alt) | IndexedAlt idx alt <- alts] of Nothing -> do resultTy' <- sinkM resultTy emitExpr $ ThrowError resultTy' Just alt -> applyAbs alt (SubstVal v) >>= emitExpr --- converts from the ordinal index used in the core IR to the more complicated --- `CaseAltIndex` used in the surface IR. -nthCaseAltIdx :: CType n -> Int -> CaseAltIndex -nthCaseAltIdx ty i = case ty of - TypeCon _ _ _ -> i - _ -> error $ "can't pattern-match on: " <> pprint ty - buildMonomorphicCase :: (Emits n, ScopableBuilder CoreIR m) => [IndexedAlt n] -> CAtom n -> CType n -> m n (CAtom n) @@ -1130,7 +1150,7 @@ buildSortedCase :: (Fallible1 m, Builder CoreIR m, Emits n) buildSortedCase scrut alts resultTy = do scrutTy <- return $ getType scrut case scrutTy of - TypeCon _ defName _ -> do + TyCon (NewtypeTyCon (UserADTType _ defName _)) -> do TyConDef _ _ _ (ADTCons cons) <- lookupTyCon defName case cons of [] -> error "case of void?" @@ -1139,7 +1159,9 @@ buildSortedCase scrut alts resultTy = do let [IndexedAlt _ alt] = alts scrut' <- unwrapNewtype scrut emitExpr =<< applyAbs alt (SubstVal scrut') - _ -> liftEmitBuilder $ buildMonomorphicCase alts scrut resultTy + _ -> do + scrut' <- unwrapNewtype scrut + liftEmitBuilder $ buildMonomorphicCase alts scrut' resultTy _ -> fail $ "Unexpected case expression type: " <> pprint scrutTy -- TODO: cache this with the instance def (requires a recursive binding) @@ -1148,10 +1170,10 @@ instanceFun instanceName appExpl = do InstanceDef _ expls bs _ _ <- lookupInstanceDef instanceName liftEnvReaderM $ refreshAbs (Abs bs UnitE) \bs' UnitE -> do args <- mapM toAtomVar $ nestToNames bs' - result <- DictCon <$> mkInstanceDict (sink instanceName) (Var <$> args) + result <- toAtom <$> mkInstanceDict (sink instanceName) (toAtom <$> args) let effTy = EffTy Pure (getType result) let piTy = CorePiType appExpl (snd<$>expls) bs' effTy - return $ Lam $ CoreLamExpr piTy (LamExpr bs' $ Atom result) + return $ toAtom $ CoreLamExpr piTy (LamExpr bs' $ Atom result) checkMaybeAnnExpr :: Emits o => Maybe (UType i) -> UExpr i -> InfererM i o (CAtom o) checkMaybeAnnExpr ty expr = confuseGHC >>= \_ -> case ty of @@ -1184,7 +1206,7 @@ inferDotMethod tc (Abs uparamBs (Abs selfB lam)) = do withFreshBindersInf expls (Abs paramBs UnitE) \paramBs' UnitE -> do let paramVs = bindersVars paramBs' extendRenamer (uparamBs @@> (atomVarName <$> paramVs)) do - let selfTy = NewtypeTyCon $ UserADTType sn (sink tc) (TyConParams expls (Var <$> paramVs)) + let selfTy = toType $ UserADTType sn (sink tc) (TyConParams expls (toAtom <$> paramVs)) withFreshBinderInf "self" Explicit selfTy \selfB' -> do lam' <- extendRenamer (selfB @> binderName selfB') $ inferULam lam return $ prependCoreLamExpr (expls ++ [Explicit]) (paramBs' >>> UnaryNest selfB') lam' @@ -1213,7 +1235,7 @@ dataConRepTy (Abs topBs UnitE) = case topBs of Empty -> case revAcc of [] -> error "should never happen" [ty] -> (ty, [projIdxs]) - _ -> ( ProdTy $ reverse revAcc + _ -> ( toType $ ProdType $ reverse revAcc , iota (length revAcc) <&> \i -> ProjectProduct i:projIdxs ) Nest b bs -> case hoist b (EmptyAbs bs) of HoistSuccess (Abs bs' UnitE) -> go (binderType b:revAcc) projIdxs bs' @@ -1222,11 +1244,11 @@ dataConRepTy (Abs topBs UnitE) = case topBs of accSize = length revAcc (fullTy, depTyIdxs) = case revAcc of [] -> (depTy, []) - _ -> (ProdTy $ reverse revAcc ++ [depTy], [ProjectProduct accSize]) + _ -> (toType $ ProdType $ reverse revAcc ++ [depTy], [ProjectProduct accSize]) (tailTy, tailIdxs) = go [] (ProjectProduct 1 : (depTyIdxs ++ projIdxs)) bs idxs = (iota accSize <&> \i -> ProjectProduct i : projIdxs) ++ ((ProjectProduct 0 : (depTyIdxs ++ projIdxs)) : tailIdxs) - depTy = DepPairTy $ DepPairType ExplicitDepPair b tailTy + depTy = toType $ DepPairTy $ DepPairType ExplicitDepPair b tailTy inferClassDef :: SourceName -> [SourceName] -> Nest UAnnBinder i i' -> [UType i'] @@ -1238,7 +1260,7 @@ inferClassDef className methodNames paramBs methodTys = do _ -> Just $ Just $ getSourceName b methodTys' <- forM methodTys \m -> do checkUType m >>= \case - Pi t -> return t + TyCon (Pi t) -> return t t -> return $ CorePiType ImplicitApp [] Empty (EffTy Pure t) PairB paramBs'' superclassBs <- partitionBinders (zipAttrs roleExpls paramBs') $ \b@(WithAttrB (_, expl) b') -> case expl of @@ -1246,7 +1268,13 @@ inferClassDef className methodNames paramBs methodTys = do Inferred _ Unify -> throw TypeErr "Interfaces can't have implicit parameters" Inferred _ (Synth _) -> return $ RightB b' let (roleExpls', paramBs''') = unzipAttrs paramBs'' - return $ ClassDef className methodNames paramNames roleExpls' paramBs''' superclassBs methodTys' + builtinName <- case className of + -- TODO: this is hacky. Let's just make the Ix class, including its + -- methods, fully built-in instead of prelude-defined. + "Ix" -> return $ Just Ix + "Data" -> return $ Just Data + _ -> return Nothing + return $ ClassDef className builtinName methodNames paramNames roleExpls' paramBs''' superclassBs methodTys' withUBinder :: UAnnBinder i i' -> InfererCPSB2 (WithExpl CBinder) i i' o a withUBinder (UAnnBinder expl b ann cs) cont = do @@ -1305,7 +1333,7 @@ inferAnn ann cs = case ann of UNoAnn -> case cs of WithSrcE _ (UVar ~(InternalName _ _ v)):_ -> do renameM v >>= getUVarType >>= \case - Pi (CorePiType ExplicitApp [Explicit] (UnaryNest (_:>ty)) _) -> return ty + TyCon (Pi (CorePiType ExplicitApp [Explicit] (UnaryNest (_:>ty)) _)) -> return ty ty -> throw TypeErr $ "Constraint should be a unary function. Got: " ++ pprint ty _ -> throw TypeErr "Type annotation or constraint required" @@ -1393,7 +1421,7 @@ piAsPartialPi (CorePiType appExpl expls bs (EffTy effs ty)) = PartialPiType appExpl expls bs effs (Check ty) typeAsPartialType :: CType n -> PartialType n -typeAsPartialType (Pi piTy) = PartialType $ piAsPartialPi piTy +typeAsPartialType (TyCon (Pi piTy)) = PartialType $ piAsPartialPi piTy typeAsPartialType ty = FullType ty piAsPartialPiDropResultTy :: CorePiType n -> PartialPiType n @@ -1415,7 +1443,7 @@ checkInstanceBody :: ClassName o -> [CAtom o] -> [UMethodDef i] -> InfererM i o (InstanceBody o) checkInstanceBody className params methods = do - ClassDef _ methodNames _ _ paramBs scBs methodTys <- lookupClassDef className + ClassDef _ _ methodNames _ _ paramBs scBs methodTys <- lookupClassDef className Abs scBs' methodTys' <- applySubst (paramBs @@> (SubstVal <$> params)) $ Abs scBs $ ListE methodTys superclassTys <- superclassDictTys scBs' superclassDicts <- mapM (flip trySynthTerm Full) superclassTys @@ -1439,9 +1467,9 @@ checkMethodDef className methodTys (WithSrcE src m) = addSrcContext src do UMethodDef ~(InternalName _ sourceName v) rhs <- return m MethodBinding className' i <- renameM v >>= lookupEnv when (className /= className') do - ClassBinding (ClassDef classSourceName _ _ _ _ _ _) <- lookupEnv className - throw TypeErr $ pprint sourceName ++ " is not a method of " ++ pprint classSourceName - (i,) <$> Lam <$> checkULam rhs (methodTys !! i) + ClassBinding classDef <- lookupEnv className + throw TypeErr $ pprint sourceName ++ " is not a method of " ++ getSourceName classDef + (i,) <$> toAtom <$> Lam <$> checkULam rhs (methodTys !! i) checkUEffRow :: UEffectRow i -> InfererM i o (EffectRow CoreIR o) checkUEffRow (UEffectRow effs t) = do @@ -1458,8 +1486,8 @@ checkUEff :: UEffect i -> InfererM i o (Effect CoreIR o) checkUEff eff = case eff of URWSEffect rws (~(SIInternalName _ region _ _)) -> do region' <- renameM region >>= toAtomVar - expectEq (TC HeapType) (getType region') - return $ RWSEffect rws (Var region') + expectEq (TyCon HeapType) (getType region') + return $ RWSEffect rws (toAtom region') UExceptionEffect -> return ExceptionEffect UIOEffect -> return IOEffect @@ -1497,7 +1525,7 @@ checkCasePat (WithSrcB pos pat) scrutineeTy cont = addSrcContext pos $ case pat withFreshBinderInf noHint Explicit repTy \b -> Abs b <$> do buildBlock do args <- forM idxs \projs -> do - ans <- applyProjectionsReduced (init projs) (sink $ Var $ binderVar b) + ans <- applyProjectionsReduced (init projs) (sink $ toAtom $ binderVar b) emit $ Atom ans bindLetPats ps args $ cont _ -> throw TypeErr $ "Case patterns must start with a data constructor or variant pattern" @@ -1510,7 +1538,7 @@ inferParams ty dataDefName = do Explicit -> Inferred Nothing Unify expl -> expl paramBsAbs <- buildConstraints (Abs paramBs UnitE) \params _ -> do - let ty' = TypeCon sourceName (sink dataDefName) $ TyConParams paramExpls params + let ty' = toType $ UserADTType sourceName (sink dataDefName) $ TyConParams paramExpls params return [TypeConstraint (sink ty) ty'] args <- inferMixedArgs sourceName inferenceExpls paramBsAbs emptyMixedArgs return $ TyConParams paramExpls args @@ -1534,19 +1562,19 @@ bindLetPat (WithSrcB pos pat) v cont = addSrcContext pos $ case pat of UPatProd ps -> do let n = nestLength ps case getType v of - ProdTy ts | length ts == n -> return () + TyCon (ProdType ts) | length ts == n -> return () ty -> throw TypeErr $ "Expected a product type but got: " ++ pprint ty - xs <- forM (iota n) \i -> proj i (Var v) >>= emitInline + xs <- forM (iota n) \i -> proj i (toAtom v) >>= emitInline bindLetPats ps xs cont UPatDepPair (PairB p1 p2) -> do case getType v of - DepPairTy _ -> return () + TyCon (DepPairTy _) -> return () ty -> throw TypeErr $ "Expected a dependent pair, but got: " ++ pprint ty -- XXX: we're careful here to reduce the projection because of the dependent -- types. We do the same in the `UPatCon` case. - x1 <- reduceProj 0 (Var v) >>= emitInline + x1 <- reduceProj 0 (toAtom v) >>= emitInline bindLetPat p1 x1 do - x2 <- getSnd (sink $ Var v) >>= emitInline + x2 <- getSnd (sink $ toAtom v) >>= emitInline bindLetPat p2 x2 do cont UPatCon ~(InternalName _ _ conName) ps -> do @@ -1557,22 +1585,22 @@ bindLetPat (WithSrcB pos pat) v cont = addSrcContext pos $ case pat of when (length idxss /= nestLength ps) $ throw TypeErr $ "Unexpected number of pattern binders. Expected " ++ show (length idxss) ++ " got " ++ show (nestLength ps) - void $ inferParams (getType $ Var v) dataDefName - xs <- forM idxss \idxs -> applyProjectionsReduced idxs (Var v) >>= emitInline + void $ inferParams (getType $ toAtom v) dataDefName + xs <- forM idxss \idxs -> applyProjectionsReduced idxs (toAtom v) >>= emitInline bindLetPats ps xs cont _ -> throw TypeErr $ "sum type constructor in can't-fail pattern" UPatTable ps -> do let n = fromIntegral (nestLength ps) :: Word32 case getType v of - TabPi (TabPiType _ (_:>FinConst n') _) | n == n' -> return () + TyCon (TabPi (TabPiType _ (_:>FinConst n') _)) | n == n' -> return () ty -> throw TypeErr $ "Expected a Fin " ++ show n ++ " table type but got: " ++ pprint ty xs <- forM [0 .. n - 1] \i -> do - emit =<< mkTabApp (Var v) [NewtypeCon (FinCon (NatVal n)) (NatVal $ fromIntegral i)] + emit =<< mkTabApp (toAtom v) (toAtom $ NewtypeCon (FinCon (NatVal n)) (NatVal $ fromIntegral i)) bindLetPats ps xs cont checkUType :: UType i -> InfererM i o (CType o) checkUType t = do - Type t' <- checkUParam TyKind t + Just t' <- toMaybeType <$> checkUParam TyKind t return t' checkUParam :: Kind CoreIR o -> UType i -> InfererM i o (CAtom o) @@ -1590,8 +1618,8 @@ inferTabCon xs = do ixTy <- asIxType finTy let tabTy = ixTy ==> elemTy xs' <- forM xs \x -> topDown elemTy x - dTy <- DictTy <$> dataDictType elemTy - dataDict <- trySynthTerm dTy Full + let dTy = toType $ DataDictType elemTy + Just dataDict <- toMaybeDict <$> trySynthTerm dTy Full emitExpr $ TabCon (Just $ WhenIRE dataDict) tabTy xs' checkTabCon :: forall i o. Emits o => TabPiType CoreIR o -> [UExpr i] -> InfererM i o (CAtom o) @@ -1600,18 +1628,18 @@ checkTabCon tabTy@(TabPiType _ b elemTy) xs = do let finTy = FinConst n expectEq (binderType b) finTy xs' <- forM (enumerate xs) \(i, x) -> do - let i' = NewtypeCon (FinCon (NatVal n)) (NatVal $ fromIntegral i) :: CAtom o + let i' = toAtom (NewtypeCon (FinCon (NatVal n)) (NatVal $ fromIntegral i)) :: CAtom o elemTy' <- applySubst (b@>SubstVal i') elemTy topDown elemTy' x dTy <- case hoist b elemTy of - HoistSuccess elemTy' -> DictTy <$> dataDictType elemTy' + HoistSuccess elemTy' -> return $ toType $ DataDictType elemTy' HoistFailure _ -> ignoreExcept <$> liftEnvReaderT do withFreshBinder noHint finTy \b' -> do elemTy' <- applyRename (b@>binderName b') elemTy - dTy <- DictTy <$> dataDictType elemTy' - return $ Pi $ CorePiType ImplicitApp [Inferred Nothing Unify] (UnaryNest b') (EffTy Pure dTy) - dataDict <- trySynthTerm dTy Full - emitExpr $ TabCon (Just $ WhenIRE dataDict) (TabPi tabTy) xs' + let dTy = toType $ DataDictType elemTy' + return $ toType $ CorePiType ImplicitApp [Inferred Nothing Unify] (UnaryNest b') (EffTy Pure dTy) + Just dataDict <- toMaybeDict <$> trySynthTerm dTy Full + emitExpr $ TabCon (Just $ WhenIRE dataDict) (TyCon (TabPi tabTy)) xs' addEffects :: EffectRow CoreIR o -> InfererM i o () addEffects Pure = return () @@ -1622,9 +1650,7 @@ addEffects eff = do Failure _ -> expectEq (Eff effsAllowed) (Eff eff) getIxDict :: CType o -> InfererM i o (IxDict CoreIR o) -getIxDict t = do - dictTy <- DictTy <$> ixDictType t - IxDictAtom <$> trySynthTerm dictTy Full +getIxDict t = fromJust <$> toMaybeDict <$> trySynthTerm (toType $ IxDictType t) Full asIxType :: CType o -> InfererM i o (IxType CoreIR o) asIxType ty = IxType ty <$> getIxDict ty @@ -1643,7 +1669,7 @@ lookupSolverSubst (SolverSubst m) name = applyConstraint :: Constraint o -> SolverM i o () applyConstraint = \case - TypeConstraint t1 t2 -> constrainTypesEq t1 t2 + TypeConstraint t1 t2 -> constrainEq t1 t2 EffectConstraint r1 r2' -> do -- r1 shouldn't have inference variables. And we can't infer anything about -- any inference variables in r2's explicit effects because we don't know @@ -1653,21 +1679,18 @@ applyConstraint = \case "\nRequested effects: " ++ pprint r2 case checkExtends r1 r2 of Success () -> return () - Failure _ -> addContext msg $ searchFailureAsTypeErr do + Failure _ -> searchFailureAsTypeErr msg do EffectRow effs1 t1 <- return r1 EffectRow effs2 (EffectRowTail v2) <- return r2 guard =<< isUnificationName (atomVarName v2) guard $ null (eSetToList $ effs2 `eSetDifference` effs1) let extras1 = effs1 `eSetDifference` effs2 - extendSolution (atomVarName v2) (Eff $ EffectRow extras1 t1) - -constrainTypesEq :: CType o -> CType o -> SolverM i o () -constrainTypesEq t1 t2 = constrainEq (Type t1) (Type t2) -- TODO: use a type class instead? + extendSolution v2 (toAtom $ EffectRow extras1 t1) -constrainEq :: CAtom o -> CAtom o -> SolverM i o () +constrainEq :: ToAtom e CoreIR => e o -> e o -> SolverM i o () constrainEq t1 t2 = do - t1' <- zonk t1 - t2' <- zonk t2 + t1' <- zonk $ toAtom t1 + t2' <- zonk $ toAtom t2 msg <- liftEnvReaderM $ do ab <- renameForPrinting $ PairE t1' t2' return $ canonicalizeForPrinting ab \(Abs infVars (PairE t1Pretty t2Pretty)) -> @@ -1676,75 +1699,153 @@ constrainEq t1 t2 = do ++ (case infVars of Empty -> "" _ -> "\n(Solving for: " ++ pprint (nestToList pprint infVars) ++ ")") - void $ addContext msg $ unify t1' t2' - -class (AlphaEqE e, Zonkable e) => Unifiable (e::E) where - unifyZonked :: e n -> e n -> SolverM i n () - -unify :: Unifiable e => e n -> e n -> SolverM i n () -unify e1 e2 = do - e1' <- zonk e1 - e2' <- zonk e2 - searchFailureAsTypeErr $ unifyZonked e1' e2' -{-# INLINE unify #-} -{-# SCC unify #-} + void $ searchFailureAsTypeErr msg $ unify t1' t2' -searchFailureAsTypeErr :: SolverM i n a -> SolverM i n a -searchFailureAsTypeErr cont = cont <|> throw TypeErr "" +searchFailureAsTypeErr :: String -> SolverM i n a -> SolverM i n a +searchFailureAsTypeErr msg cont = cont <|> throw TypeErr msg {-# INLINE searchFailureAsTypeErr #-} +class AlphaEqE e => Unifiable (e::E) where + unify :: e n -> e n -> SolverM i n () + +instance Unifiable (Stuck CoreIR) where + unify s1 s2 = do + x1 <- zonkStuck s1 + x2 <- zonkStuck s2 + case (x1, x2) of + (Con c, Con c') -> unify c c' + (Stuck _ s, Stuck _ s') -> unifyStuckZonked s s' + (Stuck _ s, Con c) -> unifyStuckConZonked s c + (Con c, Stuck _ s) -> unifyStuckConZonked s c + +-- assumes both `CStuck` args are zonked +unifyStuckZonked :: CStuck n -> CStuck n -> SolverM i n () +unifyStuckZonked s1 s2 = do + x1 <- mkStuck s1 + x2 <- mkStuck s2 + case (s1, s2) of + (Var v1, Var v2) -> do + if atomVarName v1 == atomVarName v2 + then return () + else extendSolution v2 x1 <|> extendSolution v1 x2 + (_, Var v2) -> extendSolution v2 x1 + (Var v1, _) -> extendSolution v1 x2 + (_, _) -> unifyEq s1 s2 + +unifyStuckConZonked :: CStuck n -> Con CoreIR n -> SolverM i n () +unifyStuckConZonked s x = case s of + Var v -> extendSolution v (Con x) + _ -> empty + +unifyStuckCon :: CStuck n -> Con CoreIR n -> SolverM i n () +unifyStuckCon s con = do + x <- zonkStuck s + case x of + Stuck _ s' -> unifyStuckConZonked s' con + Con con' -> unify con' con + instance Unifiable (Atom CoreIR) where - unifyZonked e1 e2 = confuseGHC >>= \_ -> case sameConstructor e1 e2 of - False -> case (e1, e2) of - (t, Var (AtomVar v _)) -> extendSolution v t - (Var (AtomVar v _), t) -> extendSolution v t - _ -> empty - True -> case (e1, e2) of - (Var (AtomVar v' _), Var (AtomVar v _)) -> - if v == v' then return () else extendSolution v e1 <|> extendSolution v' e2 - (Eff eff, Eff eff') -> unify eff eff' - (Type t, Type t') -> case (t, t') of - (Pi piTy, Pi piTy') -> unify piTy piTy' - (TabPi piTy, TabPi piTy') -> unifyTabPiType piTy piTy' - (TC con, TC con') -> do - GenericOpRep name ts xs [] <- return $ fromEGenericOpRep con - GenericOpRep name' ts' xs' [] <- return $ fromEGenericOpRep con' - guard $ name == name' && length ts == length ts' && length xs == length xs' - zipWithM_ unify (Type <$> ts) (Type <$> ts') - zipWithM_ unify xs xs' - (DictTy d, DictTy d') -> unify d d' - (NewtypeTyCon con, NewtypeTyCon con') -> unify con con' - _ -> unifyEq t t' - _ -> unifyEq e1 e2 + unify (Con c) (Con c') = unify c c' + unify (Stuck _ s) (Stuck _ s') = unify s s' + unify (Stuck _ s) (Con c) = unifyStuckCon s c + unify (Con c) (Stuck _ s) = unifyStuckCon s c + +-- TODO: do this directly rather than going via `CAtom` using `Type`. We just +-- need to deal with `Stuck`. +instance Unifiable (Type CoreIR) where + unify (TyCon c) (TyCon c') = unify c c' + unify (StuckTy _ s) (StuckTy _ s') = unify s s' + unify (StuckTy _ s) (TyCon c) = unifyStuckCon s (TyConAtom c) + unify (TyCon c) (StuckTy _ s) = unifyStuckCon s (TyConAtom c) + +instance Unifiable (Con CoreIR) where + unify e1 e2 = case e1 of + ( Lit x ) -> do + { Lit x' <- matchit; guard (x == x')} + ( ProdCon xs ) -> do + { ProdCon xs' <- matchit; unifyLists xs xs'} + ( SumCon ts i x ) -> do + { SumCon ts' i' x' <- matchit; unifyLists ts ts'; guard (i==i'); unify x x'} + ( DepPair t x y ) -> do + { DepPair t' x' y' <- matchit; unify t t'; unify x x'; unify y y'} + ( HeapVal ) -> do + { HeapVal <- matchit; return ()} + ( Eff eff ) -> do + { Eff eff' <- matchit; unify eff eff'} + ( Lam lam ) -> do + { Lam lam' <- matchit; unifyEq lam lam'} + ( NewtypeCon con x ) -> do + { NewtypeCon con' x' <- matchit; unifyEq con con'; unify x x'} + ( TyConAtom t ) -> do + { TyConAtom t' <- matchit; unify t t'} + ( DictConAtom d ) -> do + { DictConAtom d' <- matchit; unifyEq d d'} + where matchit = return e2 + +instance Unifiable (TyCon CoreIR) where + unify t1 t2 = case t1 of + ( BaseType b ) -> do + { BaseType b' <- matchit; guard $ b == b'} + ( HeapType ) -> do + { HeapType <- matchit; return () } + ( TypeKind ) -> do + { TypeKind <- matchit; return () } + ( Pi piTy ) -> do + { Pi piTy' <- matchit; unify piTy piTy'} + ( TabPi piTy) -> do + { TabPi piTy' <- matchit; unify piTy piTy'} + ( DictTy d ) -> do + { DictTy d' <- matchit; unify d d'} + ( NewtypeTyCon con ) -> do + { NewtypeTyCon con' <- matchit; unify con con'} + ( SumType ts ) -> do + { SumType ts' <- matchit; unifyLists ts ts'} + ( ProdType ts ) -> do + { ProdType ts' <- matchit; unifyLists ts ts'} + ( RefType h t ) -> do + { RefType h' t' <- matchit; unify h h'; unify t t'} + ( DepPairTy t ) -> do + { DepPairTy t' <- matchit; unify t t'} + where matchit = return t2 + +unifyLists :: Unifiable e => [e n] -> [e n] -> SolverM i n () +unifyLists [] [] = return () +unifyLists (x:xs) (y:ys) = unify x y >> unifyLists xs ys +unifyLists _ _ = empty instance Unifiable DictType where - unifyZonked (DictType _ c params) (DictType _ c' params') = - guard (c == c') >> zipWithM_ unify params params' - {-# INLINE unifyZonked #-} + unify d1 d2 = case d1 of + ( DictType _ c params )-> do + { DictType _ c' params' <- matchit; guard (c == c'); unifyLists params params'} + ( IxDictType t ) -> do + { IxDictType t' <- matchit; unify t t'} + ( DataDictType t ) -> do + { DataDictType t' <- matchit; unify t t'} + where matchit = return d2 + {-# INLINE unify #-} instance Unifiable NewtypeTyCon where - unifyZonked e1 e2 = case (e1, e2) of - (Nat, Nat) -> return () - (Fin n, Fin n') -> unify n n' - (EffectRowKind, EffectRowKind) -> return () - (UserADTType _ c params, UserADTType _ c' params') -> guard (c == c') >> unify params params' - _ -> empty + unify e1 e2 = case e1 of + ( Nat ) -> do + { Nat <- matchit; return ()} + ( Fin n ) -> do + { Fin n' <- matchit; unify n n'} + ( EffectRowKind ) -> do + { EffectRowKind <- matchit; return ()} + ( UserADTType _ c params ) -> do + { UserADTType _ c' params' <- matchit; guard (c == c') >> unify params params' } + where matchit = return e2 instance Unifiable (IxType CoreIR) where -- We ignore the dictionaries because we assume coherence - unifyZonked (IxType t _) (IxType t' _) = unifyZonked t t' - --- TODO: do this directly rather than going via `CAtom` using `Type`. We just --- need to deal with `TyVar`. -instance Unifiable (Type CoreIR) where - unifyZonked t t' = unifyZonked (Type t) (Type t') + unify (IxType t _) (IxType t' _) = unify t t' instance Unifiable TyConParams where -- We ignore the dictionaries because we assume coherence - unifyZonked ps ps' = zipWithM_ unify (ignoreSynthParams ps) (ignoreSynthParams ps') + unify ps ps' = zipWithM_ unify (ignoreSynthParams ps) (ignoreSynthParams ps') instance Unifiable (EffectRow CoreIR) where - unifyZonked x1 x2 = + unify x1 x2 = unifyDirect x1 x2 <|> unifyDirect x2 x1 <|> unifyZip x1 x2 @@ -1754,7 +1855,7 @@ instance Unifiable (EffectRow CoreIR) where unifyDirect r@(EffectRow effs' mv') (EffectRow effs (EffectRowTail v)) | null (eSetToList effs) = case mv' of EffectRowTail v' | v == v' -> guard $ null $ eSetToList effs' - _ -> extendSolution (atomVarName v) (Eff r) + _ -> extendSolution v (Con $ Eff r) unifyDirect _ _ = empty {-# INLINE unifyDirect #-} @@ -1783,7 +1884,7 @@ unifyEq e1 e2 = guard =<< alphaEq e1 e2 {-# INLINE unifyEq #-} instance Unifiable CorePiType where - unifyZonked (CorePiType appExpl1 expls1 bsTop1 effTy1) + unify (CorePiType appExpl1 expls1 bsTop1 effTy1) (CorePiType appExpl2 expls2 bsTop2 effTy2) = do unless (appExpl1 == appExpl2) empty unless (expls1 == expls2) empty @@ -1803,16 +1904,25 @@ instance Unifiable CorePiType where return UnitE go _ _ = empty -unifyTabPiType :: TabPiType CoreIR n -> TabPiType CoreIR n -> SolverM i n () -unifyTabPiType (TabPiType _ b1 ty1) (TabPiType _ b2 ty2) = do - let ann1 = binderType b1 - let ann2 = binderType b2 - unify ann1 ann2 - void $ withFreshSkolemName ann1 \v -> do - ty1' <- applyRename (b1@>atomVarName v) ty1 - ty2' <- applyRename (b2@>atomVarName v) ty2 - unify ty1' ty2' - return UnitE +instance Unifiable (TabPiType CoreIR) where + unify (TabPiType _ b1 ty1) (TabPiType _ b2 ty2) = + unify (Abs b1 ty1) (Abs b2 ty2) + +instance Unifiable (DepPairType CoreIR) where + unify (DepPairType expl1 b1 rhs1) (DepPairType expl2 b2 rhs2) = do + guard $ expl1 == expl2 + unify (Abs b1 rhs1) (Abs b2 rhs2) + +instance Unifiable (Abs CBinder CType) where + unify (Abs b1 ty1) (Abs b2 ty2) = do + let ann1 = binderType b1 + let ann2 = binderType b2 + unify ann1 ann2 + void $ withFreshSkolemName ann1 \v -> do + ty1' <- applyRename (b1@>atomVarName v) ty1 + ty2' <- applyRename (b2@>atomVarName v) ty2 + unify ty1' ty2' + return UnitE withFreshSkolemName :: Zonkable e => Kind CoreIR o @@ -1828,8 +1938,8 @@ withFreshSkolemName ty cont = diffStateT1 \s -> do return (ans, diff') {-# INLINE withFreshSkolemName #-} -extendSolution :: CAtomName n -> CAtom n -> SolverM i n () -extendSolution v t = +extendSolution :: CAtomVar n -> CAtom n -> SolverM i n () +extendSolution (AtomVar v _) t = isUnificationName v >>= \case True -> do when (v `isFreeIn` t) $ throw TypeErr $ "Occurs check failure: " ++ pprint (v, t) @@ -1884,7 +1994,7 @@ makeStructRepVal tyConName args = do [_] -> case args of [arg] -> return arg _ -> error "wrong number of args" - _ -> return $ ProdVal args + _ -> return $ Con $ ProdCon args -- === dictionary synthesis === @@ -1895,7 +2005,7 @@ makeStructRepVal tyConName args = do -- valid to implement `generalizeDict` by re-synthesizing the whole dictionary, -- but we know that the derivation tree has to be the same, so we take the -- shortcut of just generalizing the data parameters. -generalizeDict :: EnvReader m => CType n -> Dict n -> m n (Dict n) +generalizeDict :: EnvReader m => CType n -> CDict n -> m n (CDict n) generalizeDict ty dict = do result <- liftEnvReaderT $ liftInfererM $ generalizeDictRec ty dict case result of @@ -1903,20 +2013,21 @@ generalizeDict ty dict = do ++ " to " ++ pprint ty ++ " because " ++ pprint e Success ans -> return ans -generalizeDictRec :: CType n -> Dict n -> InfererM i n (Dict n) +generalizeDictRec :: CType n -> CDict n -> InfererM i n (CDict n) generalizeDictRec targetTy (DictCon dict) = case dict of InstanceDict _ instanceName args -> do InstanceDef _ roleExpls bs _ _ <- lookupInstanceDef instanceName liftSolverM $ generalizeInstanceArgs roleExpls bs args \args' -> do - d <- DictCon <$> mkInstanceDict (sink instanceName) args' - constrainEq (sink $ Type targetTy) (Type $ getType d) + d <- mkInstanceDict (sink instanceName) args' + constrainEq (sink $ toAtom targetTy) (toAtom $ getType d) return d - IxFin _ _ -> case targetTy of - DictTy (DictType "Ix" _ [Type (NewtypeTyCon (Fin n))]) -> DictCon <$> mkIxFin n - _ -> error $ "not an Ix(Fin _) dict: " ++ pprint targetTy - DataData _ _ -> case targetTy of - DictTy (DictType "Data" _ [Type t]) -> DictCon <$> mkDataData t - _ -> error "not a data dict" + IxFin _ -> do + TyCon (DictTy (IxDictType (TyCon (NewtypeTyCon (Fin n))))) <- return targetTy + return $ DictCon $ IxFin n + DataData _ -> do + TyCon (DictTy (DataDictType t')) <- return targetTy + return $ DictCon $ DataData t' + IxRawFin _ -> error "not a simplified dict" generalizeDictRec _ _ = error "not a simplified dict" generalizeInstanceArgs @@ -1941,9 +2052,12 @@ generalizeInstanceArg role ty arg cont = case role of -- that it's valid to implement `generalizeDict` by synthesizing an entirely -- fresh dictionary, and if we were to do that, we would infer this type -- parameter exactly as we do here, using inference. - TypeParam -> withFreshUnificationVarNoEmits MiscInfVar TyKind \v -> cont $ Var v - DictParam -> withFreshDictVarNoEmits ty (\ty' -> lift11 $ generalizeDictRec ty' (sink arg)) cont - DataParam -> withFreshUnificationVarNoEmits MiscInfVar ty \v -> cont $ Var v + TypeParam -> withFreshUnificationVarNoEmits MiscInfVar TyKind \v -> cont $ toAtom v + DictParam -> withFreshDictVarNoEmits ty ( + \ty' -> case toMaybeDict (sink arg) of + Just d -> liftM toAtom $ lift11 $ generalizeDictRec ty' d + _ -> error "not a dict") cont + DataParam -> withFreshUnificationVarNoEmits MiscInfVar ty \v -> cont $ toAtom v emitInstanceDef :: (Mut n, TopBuilder m) => InstanceDef n -> m n (Name InstanceNameC n) emitInstanceDef instanceDef@(InstanceDef className _ _ _ _) = do @@ -2000,8 +2114,9 @@ getSynthType x = ignoreExcept $ typeAsSynthType (getType x) typeAsSynthType :: CType n -> Except (SynthType n) typeAsSynthType = \case - DictTy dictTy -> return $ SynthDictType dictTy - Pi (CorePiType ImplicitApp expls bs (EffTy Pure (DictTy d))) -> return $ SynthPiType (expls, Abs bs d) + TyCon (DictTy dictTy) -> return $ SynthDictType dictTy + TyCon (Pi (CorePiType ImplicitApp expls bs (EffTy Pure (TyCon (DictTy d))))) -> + return $ SynthPiType (expls, Abs bs d) ty -> Failure $ Err TypeErr mempty $ "Can't synthesize terms of type: " ++ pprint ty {-# SCC typeAsSynthType #-} @@ -2041,7 +2156,8 @@ getSuperclassClosurePure env givens newGivens = superclasses <- case synthTy of SynthPiType _ -> return [] SynthDictType dTy -> getSuperclassTys dTy - forM (enumerate superclasses) \(i, _) -> reduceSuperclassProj i synthExpr + forM (enumerate superclasses) \(i, _) -> do + reduceSuperclassProj i $ fromJust (toMaybeDict synthExpr) synthTerm :: SynthType n -> RequiredMethodAccess -> InfererM i n (SynthAtom n) synthTerm targetTy reqMethodAccess = confuseGHC >>= \_ -> case targetTy of @@ -2051,16 +2167,16 @@ synthTerm targetTy reqMethodAccess = confuseGHC >>= \_ -> case targetTy of Abs bs' synthExpr <- return ab' let piTy = CorePiType ImplicitApp expls bs' (EffTy Pure (getType synthExpr)) let lamExpr = LamExpr bs' (Atom synthExpr) - return $ Lam $ CoreLamExpr piTy lamExpr + return $ toAtom $ Lam $ CoreLamExpr piTy lamExpr SynthDictType dictTy -> case dictTy of - DictType "Ix" _ [Type (NewtypeTyCon (Fin n))] -> return $ DictCon $ IxFin (DictTy dictTy) n - DictType "Data" _ [Type t] -> do + IxDictType (TyCon (NewtypeTyCon (Fin n))) -> return $ toAtom $ IxFin n + DataDictType t -> do void (synthDictForData dictTy <|> synthDictFromGiven dictTy) - return $ DictCon $ DataData (DictTy dictTy) t + return $ toAtom $ DataData t _ -> do dict <- synthDictFromInstance dictTy <|> synthDictFromGiven dictTy case dict of - DictCon (InstanceDict _ instanceName _) -> do + Con (DictConAtom (InstanceDict _ instanceName _)) -> do isReqMethodAccessAllowed <- reqMethodAccess `isMethodAccessAllowedBy` instanceName if isReqMethodAccessAllowed then return dict @@ -2072,7 +2188,7 @@ isMethodAccessAllowedBy :: EnvReader m => RequiredMethodAccess -> InstanceName isMethodAccessAllowedBy access instanceName = do InstanceDef className _ _ _ (InstanceBody _ methods) <- lookupInstanceDef instanceName let numInstanceMethods = length methods - ClassDef _ _ _ _ _ _ methodTys <- lookupClassDef className + ClassDef _ _ _ _ _ _ _ methodTys <- lookupClassDef className let numClassMethods = length methodTys case access of Full -> return $ numClassMethods == numInstanceMethods @@ -2091,24 +2207,34 @@ synthDictFromGiven targetTy = do reduceInstantiateGiven given args synthDictFromInstance :: DictType n -> InfererM i n (SynthAtom n) -synthDictFromInstance targetTy@(DictType _ targetClass _) = do - instances <- getInstanceDicts targetClass +synthDictFromInstance targetTy = do + instances <- getInstanceDicts targetTy asum $ instances <&> \candidate -> typeErrAsSearchFailure do - CorePiType _ expls bs (EffTy _ (DictTy candidateTy)) <- lookupInstanceTy candidate + CorePiType _ expls bs (EffTy _ (TyCon (DictTy candidateTy))) <- lookupInstanceTy candidate args <- instantiateSynthArgs targetTy (expls, Abs bs candidateTy) - return $ DictCon $ InstanceDict (DictTy targetTy) candidate args - -getInstanceDicts :: EnvReader m => ClassName n -> m n [InstanceName n] -getInstanceDicts name = do - env <- withEnv moduleEnv - return $ M.findWithDefault [] name $ instanceDicts $ envSynthCandidates env -{-# INLINE getInstanceDicts #-} + return $ toAtom $ InstanceDict (toType targetTy) candidate args + +getInstanceDicts :: EnvReader m => DictType n -> m n [InstanceName n] +getInstanceDicts dictTy = do + env <- withEnv (envSynthCandidates . moduleEnv) + case dictTy of + DictType _ name _ -> return $ M.findWithDefault [] name $ instanceDicts env + IxDictType _ -> return $ ixInstances env + DataDictType _ -> return [] + +addInstanceSynthCandidate :: TopBuilder m => ClassName n -> Maybe BuiltinClassName -> InstanceName n -> m n () +addInstanceSynthCandidate className maybeBuiltin instanceName = do + sc <- return case maybeBuiltin of + Nothing -> mempty {instanceDicts = M.singleton className [instanceName] } + Just Ix -> mempty {ixInstances = [instanceName]} + Just Data -> mempty + emitLocalModuleEnv $ mempty {envSynthCandidates = sc} instantiateSynthArgs :: DictType n -> SynthPiType n -> InfererM i n [CAtom n] instantiateSynthArgs target (expls, synthPiTy) = do liftM fromListE $ withReducibleEmissions "dict args" do bsConstrained <- buildConstraints (sink synthPiTy) \_ resultTy -> do - return [TypeConstraint (DictTy $ sink target) (DictTy resultTy)] + return [TypeConstraint (TyCon $ DictTy $ sink target) (TyCon $ DictTy resultTy)] ListE <$> inferMixedArgs "dict" expls bsConstrained emptyMixedArgs emptyMixedArgs :: MixedArgs (CAtom n) @@ -2121,34 +2247,33 @@ typeErrAsSearchFailure cont = cont `catchErr` \err@(Err errTy _ _) -> do _ -> throwErr err synthDictForData :: forall i n. DictType n -> InfererM i n (SynthAtom n) -synthDictForData dictTy@(DictType "Data" dName [Type ty]) = case ty of +synthDictForData dictTy@(DataDictType ty) = case ty of -- TODO Deduplicate vs CheckType.checkDataLike - -- The "Var" case is different - TyVar _ -> synthDictFromGiven dictTy - TabPi (TabPiType _ b eltTy) -> recurBinder (Abs b eltTy) >> success - DepPairTy (DepPairType _ b@(_:>l) r) -> do - recur l >> recurBinder (Abs b r) >> success - NewtypeTyCon nt -> do - (_, ty') <- unwrapNewtypeType nt - recur ty' >> success - TC con -> case con of - BaseType _ -> success - ProdType as -> mapM_ recur as >> success - SumType cs -> mapM_ recur cs >> success - RefType _ _ -> success - HeapType -> success - _ -> notData - _ -> notData + -- The "Stuck" case is different + StuckTy _ _ -> synthDictFromGiven dictTy + TyCon con -> case con of + TabPi (TabPiType _ b eltTy) -> recurBinder (Abs b eltTy) >> success + DepPairTy (DepPairType _ b@(_:>l) r) -> do + recur l >> recurBinder (Abs b r) >> success + NewtypeTyCon nt -> do + (_, ty') <- unwrapNewtypeType nt + recur ty' >> success + BaseType _ -> success + ProdType as -> mapM_ recur as >> success + SumType cs -> mapM_ recur cs >> success + RefType _ _ -> success + HeapType -> success + _ -> notData where - recur ty' = synthDictForData $ DictType "Data" dName [Type ty'] + recur ty' = synthDictForData $ DataDictType ty' recurBinder :: Abs CBinder CType n -> InfererM i n (SynthAtom n) recurBinder (Abs b body) = withFreshBinderInf noHint Explicit (binderType b) \b' -> do body' <- applyRename (b@>binderName b') body - ans <- synthDictForData $ DictType "Data" (sink dName) [Type body'] + ans <- synthDictForData $ DataDictType (toType body') return $ ignoreHoistFailure $ hoist b' ans notData = empty - success = return $ DictCon $ DataData (DictTy dictTy) ty + success = return $ toAtom $ DataData ty synthDictForData dictTy = error $ "Malformed Data dictTy " ++ pprint dictTy instance GenericE Givens where @@ -2182,7 +2307,7 @@ buildBlockInfWithRecon cont = do asFFIFunType :: EnvReader m => CType n -> m n (Maybe (IFunType, CorePiType n)) asFFIFunType ty = return do - Pi piTy <- return ty + TyCon (Pi piTy) <- return ty impTy <- checkFFIFunTypeM piTy return (impTy, piTy) @@ -2250,7 +2375,7 @@ instance SinkableE SigmaAtom instance SubstE AtomSubstVal SigmaAtom where substE env (SigmaAtom sn x) = SigmaAtom sn $ substE env x substE env (SigmaUVar sn ty uvar) = case uvar of - UAtomVar v -> substE env $ SigmaAtom (Just sn) $ Var (AtomVar v ty) + UAtomVar v -> substE env $ SigmaAtom (Just sn) $ toAtom (AtomVar v ty) UTyConVar v -> SigmaUVar sn ty' $ UTyConVar $ substE env v UDataConVar v -> SigmaUVar sn ty' $ UDataConVar $ substE env v UPunVar v -> SigmaUVar sn ty' $ UPunVar $ substE env v diff --git a/src/lib/Inline.hs b/src/lib/Inline.hs index 1a271a34f..da14eeb94 100644 --- a/src/lib/Inline.hs +++ b/src/lib/Inline.hs @@ -87,7 +87,7 @@ inlineDeclsSubst = \case -- See NoteSecretsSubtlety if presInfo == UsedOnce then do let substVal = case expr' of - Atom (Var name') -> Rename $ atomVarName name' + Atom (Stuck _ (Var name')) -> Rename $ atomVarName name' _ -> SubstVal (DoneEx expr') extendSubst (b @> substVal) $ inlineDeclsSubst rest else do @@ -108,7 +108,7 @@ inlineDeclsSubst = \case resolveWorkConservation NoInlineLet _ = NoInline -- Quick hack to always unconditionally inline renames, until we get -- a better story about measuring the sizes of atoms and expressions. - resolveWorkConservation (OccInfoPure _) (Atom (Var _)) = UsedOnce + resolveWorkConservation (OccInfoPure _) (Atom (Stuck _ (Var _))) = UsedOnce resolveWorkConservation (OccInfoPure (UsageInfo s (ixDepth, d))) expr | d <= One = case ixDepthExpr expr >= ixDepth of True -> if s <= One then UsedOnce else UsedMulti @@ -193,7 +193,7 @@ preInlineUnconditionally = \case -- instead of emitting the binding. data Context (from::E) (to::E) (o::S) where Stop :: Context e e o - TabAppCtx :: [SAtom i] -> Subst InlineSubstVal i o + TabAppCtx :: SAtom i -> Subst InlineSubstVal i o -> Context SExpr e o -> Context SExpr e o CaseCtx :: [SAlt i] -> SType i -> EffectRow SimpIR i -> Subst InlineSubstVal i o @@ -218,9 +218,9 @@ instance Emits o => Visitor (InlineM i o) SimpIR i o where inlineExpr :: Emits o => Context SExpr e o -> SExpr i -> InlineM i o (e o) inlineExpr ctx = \case Atom atom -> inlineAtom ctx atom - TabApp _ tbl ixs -> do + TabApp _ tbl ix -> do s <- getSubst - inlineAtom (TabAppCtx ixs s ctx) tbl + inlineAtom (TabAppCtx ix s ctx) tbl Case scrut alts (EffTy effs resultTy) -> do s <- getSubst inlineAtom (CaseCtx alts resultTy effs s ctx) scrut @@ -231,11 +231,22 @@ inlineExpr ctx = \case inlineAtom :: Emits o => Context SExpr e o -> SAtom i -> InlineM i o (e o) inlineAtom ctx = \case - Stuck (StuckVar name) -> inlineName ctx name - Stuck (StuckProject _ i x) -> do - ans <- proj i =<< inline Stop (Stuck x) + Stuck _ stuck -> inlineStuck ctx stuck + Con con -> (toExpr <$> visitGeneric con) >>= reconstruct ctx + +inlineStuck :: Emits o => Context SExpr e o -> SStuck i -> InlineM i o (e o) +inlineStuck ctx = \case + Var name -> inlineName ctx name + StuckProject i x -> do + ans <- proj i =<< emitExprToAtom =<< inlineStuck Stop x reconstruct ctx $ Atom ans - atom -> (Atom <$> visitAtomPartial atom) >>= reconstruct ctx + StuckTabApp _ _ -> error "not implemented" + PtrVar t p -> do + s <- mkStuck =<< (PtrVar t <$> substM p) + reconstruct ctx (toExpr s) + RepValAtom repVal -> do + s <- mkStuck =<< (RepValAtom <$> visitGeneric repVal) + reconstruct ctx (toExpr s) inlineName :: Emits o => Context SExpr e o -> SAtomVar i -> InlineM i o (e o) inlineName ctx name = @@ -250,7 +261,7 @@ inlineName ctx name = -- (expr', presInfo) | inline presInfo expr' ctx -> inline -- no info -> do not inline (as now) v <- toAtomVar name' - reconstruct ctx (Atom $ Var v) + reconstruct ctx (toExpr v) SubstVal (DoneEx expr) -> dropSubst $ inlineExpr ctx expr SubstVal (SuspEx expr s') -> withSubst s' $ inlineExpr ctx expr @@ -261,7 +272,7 @@ instance Inlinable SAtom where inline ctx a = inlineAtom (EmitToAtomCtx ctx) a instance Inlinable SType where - inline ctx ty = visitTypePartial ty >>= reconstruct ctx + inline ctx (TyCon ty) = (TyCon <$> visitGeneric ty) >>= reconstruct ctx instance Inlinable SLam where inline ctx (LamExpr bs body) = do @@ -291,7 +302,7 @@ instance Inlinable (PiType SimpIR) where reconstruct :: Emits o => Context e1 e2 o -> e1 o -> InlineM i o (e2 o) reconstruct ctx e = case ctx of Stop -> return e - TabAppCtx ixs s ctx' -> withSubst s $ reconstructTabApp ctx' e ixs + TabAppCtx ix s ctx' -> withSubst s $ reconstructTabApp ctx' e ix CaseCtx alts resultTy effs s ctx' -> withSubst s $ reconstructCase ctx' e alts resultTy effs EmitToAtomCtx ctx' -> emitExprToAtom e >>= reconstruct ctx' @@ -299,24 +310,17 @@ reconstruct ctx e = case ctx of {-# INLINE reconstruct #-} reconstructTabApp :: Emits o - => Context SExpr e o -> SExpr o -> [SAtom i] -> InlineM i o (e o) -reconstructTabApp ctx expr [] = do - reconstruct ctx expr -reconstructTabApp ctx expr ixs = - case fromNaryForExpr (length ixs) expr of - Just (bsCount, LamExpr bs body) -> do - -- See NoteReconstructTabAppDecisions - let (ixsPref, ixsRest) = splitAt bsCount ixs - ixsPref' <- mapM (inline $ EmitToNameCtx Stop) ixsPref - let ixsPref'' = [v | AtomVar v _ <- ixsPref'] - s <- getSubst - let moreSubst = bs @@> map Rename ixsPref'' - dropSubst $ extendSubst moreSubst do - inlineExpr (TabAppCtx ixsRest s ctx) body - Nothing -> do - array' <- emitExprToAtom expr - ixs' <- mapM (inline Stop) ixs - reconstruct ctx =<< mkTabApp array' ixs' + => Context SExpr e o -> SExpr o -> SAtom i -> InlineM i o (e o) +reconstructTabApp ctx expr i = case expr of + PrimOp (Hof (TypedHof _ (For _ _ (UnaryLamExpr b body)))) -> do + -- See NoteReconstructTabAppDecisions + AtomVar i' _ <- inline (EmitToNameCtx Stop) i + dropSubst $ extendSubst (b@>Rename i') do + inlineExpr ctx body + _ -> do + array' <- emitExprToAtom expr + i' <- inline Stop i + reconstruct ctx =<< mkTabApp array' i' reconstructCase :: Emits o => Context SExpr e o -> SExpr o -> [SAlt i] -> SType i -> EffectRow SimpIR i @@ -340,12 +344,13 @@ reconstructCase ctx scrutExpr alts resultTy effs = -- context `ctx` into the selected alternative if the optimization fires, -- but leave it around the whole reconstructed `Case` if it doesn't. scrut <- emitExprToAtom scrutExpr - case trySelectBranch scrut of - Just (i, val) -> do + case scrut of + Con con -> do + SumCon _ i val <- return con Abs b body <- return $ alts !! i extendSubst (b @> (SubstVal $ DoneEx $ Atom val)) do inlineExpr ctx body - Nothing -> do + Stuck _ _ -> do alts' <- mapM visitAlt alts resultTy' <- inline Stop resultTy effs' <- inline Stop effs diff --git a/src/lib/JAX/ToSimp.hs b/src/lib/JAX/ToSimp.hs index e2e183955..e7b942f6e 100644 --- a/src/lib/JAX/ToSimp.hs +++ b/src/lib/JAX/ToSimp.hs @@ -66,10 +66,10 @@ simplifyJTy JArrayName{shape, dtype} = go shape $ simplifyDType dtype where simplifyDType :: DType -> Type r n simplifyDType = \case - F64 -> BaseTy $ P.Scalar P.Float64Type - F32 -> BaseTy $ P.Scalar P.Float32Type - I64 -> BaseTy $ P.Scalar P.Int64Type - I32 -> BaseTy $ P.Scalar P.Int32Type + F64 -> TyCon $ BaseType $ P.Scalar P.Float64Type + F32 -> TyCon $ BaseType $ P.Scalar P.Float32Type + I64 -> TyCon $ BaseType $ P.Scalar P.Int64Type + I32 -> TyCon $ BaseType $ P.Scalar P.Int32Type simplifyEqns :: Emits o => Nest JEqn i i' -> JaxSimpM i' o a -> JaxSimpM i o a simplifyEqns eqn cont = do @@ -104,7 +104,7 @@ simplifyAtom = \case SubstVal x -> return (x, ty) Rename nm' -> do nm'' <- toAtomVar nm' - return (Var nm'', ty) + return (toAtom nm'', ty) -- TODO In Jax, literals can presumably include (large) arrays. How should we -- represent them here? JLiteral (JLit {..}) -> return (Con (Lit (P.Float32Lit 0.0)), ty) @@ -124,5 +124,5 @@ unaryExpandRank op arg JArrayName{shape} = go arg shape where go arg' = \case [] -> emitExprToAtom $ PrimOp (UnOp op arg') (DimSize sz:rest) -> buildFor noHint P.Fwd (litFinIxTy sz) \i -> do - ixed <- mkTabApp (sink arg') [Var i] >>= emitExprToAtom + ixed <- mkTabApp (sink arg') (toAtom i) >>= emitExprToAtom go ixed rest diff --git a/src/lib/Linearize.hs b/src/lib/Linearize.hs index 426615649..3f1347bad 100644 --- a/src/lib/Linearize.hs +++ b/src/lib/Linearize.hs @@ -8,7 +8,6 @@ module Linearize (linearize, linearizeTopLam) where import Control.Category ((>>>)) import Control.Monad.Reader -import Data.Foldable (toList) import Data.Functor import Data.List (elemIndex) import Data.Maybe (catMaybes, isJust) @@ -83,7 +82,7 @@ extendActivePrimalss vs = local \primals -> primals { activeVars = activeVars primals ++ vs } getTangentArg :: Int -> TangentM o (Atom SimpIR o) -getTangentArg idx = asks \(TangentArgs vs) -> Var $ vs !! idx +getTangentArg idx = asks \(TangentArgs vs) -> toAtom $ vs !! idx extendTangentArgs :: SAtomVar n -> TangentM n a -> TangentM n a extendTangentArgs v m = local (\(TangentArgs vs) -> TangentArgs $ vs ++ [v]) m @@ -190,17 +189,17 @@ getTangentArgTys topVs = go mempty topVs where -- like this, but there's nothing to prevent users writing programs that -- sling around heap variables by themselves. We should try to do something -- better... - TC HeapType -> do - withFreshBinder (getNameHint v) (TC HeapType) \hb -> do + TyCon HeapType -> do + withFreshBinder (getNameHint v) (TyCon HeapType) \hb -> do let newHeapMap = sink heapMap <> eMapSingleton (sink (atomVarName v)) (binderVar hb) Abs bs UnitE <- go newHeapMap $ sinkList vs return $ EmptyAbs $ Nest hb bs - RefTy (Var h) referentTy -> do + RefTy (Stuck _ (Var h)) referentTy -> do case lookupEMap heapMap (atomVarName h) of Nothing -> error "shouldn't happen?" Just h' -> do tt <- tangentType referentTy - let refTy = RefTy (Var h') tt + let refTy = RefTy (toAtom h') tt withFreshBinder (getNameHint v) refTy \refb -> do Abs bs UnitE <- go (sink heapMap) $ sinkList vs return $ EmptyAbs $ Nest refb bs @@ -300,12 +299,12 @@ linearizeTopLam (TopLam False _ (LamExpr bs body)) actives = do let primalFun = LamExpr bs' body' ObligateRecon ty (Abs bsRecon (LamExpr bsTangent tangentBody)) <- return linLamAbs tangentFun <- withFreshBinder "residuals" ty \bResidual -> do - xs <- unpackTelescope bsRecon $ Var $ binderVar bResidual + xs <- unpackTelescope bsRecon $ toAtom $ binderVar bResidual Abs bsTangent' UnitE <- applySubst (bsRecon @@> map SubstVal xs) (Abs bsTangent UnitE) - tangentTy <- ProdTy <$> typesFromNonDepBinderNest bsTangent' + tangentTy <- TyCon <$> ProdType <$> typesFromNonDepBinderNest bsTangent' withFreshBinder "t" tangentTy \bTangent -> do tangentBody' <- buildBlock do - ts <- getUnpacked $ Var $ sink $ binderVar bTangent + ts <- getUnpacked $ toAtom $ sink $ binderVar bTangent let substFrag = bsRecon @@> map (SubstVal . sink) xs <.> bsTangent @@> map (SubstVal . sink) ts emitExpr =<< applySubst substFrag tangentBody @@ -325,21 +324,21 @@ linearizeLambdaApp (UnaryLamExpr b body) x = do linearizeLambdaApp _ _ = error "not implemented" linearizeAtom :: Emits o => Atom SimpIR i -> LinM i o SAtom SAtom -linearizeAtom atom = case atom of - Con con -> linearizePrimCon con - DepPair _ _ _ -> notImplemented - PtrVar _ _ -> emitZeroT - Stuck (StuckVar v) -> do +linearizeAtom (Con con) = linearizePrimCon con +linearizeAtom atom@(Stuck _ stuck) = case stuck of + PtrVar _ _ -> emitZeroT + Var v -> do v' <- renameM v activePrimalIdx v' >>= \case - Nothing -> withZeroT $ return (Var v') - Just idx -> return $ WithTangent (Var v') $ getTangentArg idx + Nothing -> withZeroT $ return (toAtom v') + Just idx -> return $ WithTangent (toAtom v') $ getTangentArg idx -- TODO: buildScoped and reduce the results so we keep expression in non-ANF for type checking purposes - Stuck (StuckProject ty i x) -> linearizeExpr $ Project ty i (Stuck x) - Stuck (StuckTabApp t f xs) -> linearizeExpr $ TabApp t (Stuck f) xs + StuckProject _ _ -> undefined + StuckTabApp _ _ -> undefined RepValAtom _ -> emitZeroT where emitZeroT = withZeroT $ renameM atom + linearizeDecls :: Emits o => Nest SDecl i i' -> LinM i' o e1 e2 -> LinM i o e1 e2 linearizeDecls Empty cont = cont -- TODO: as an optimization, don't bother extending the tangent args if the @@ -385,14 +384,12 @@ linearizeExpr expr = case expr of (ans, residuals) <- fromPair =<< naryTopApp fPrimal xs' return $ WithTangent ans do ts' <- forM (catMaybes ts) \(WithTangent UnitE t) -> t - naryTopApp (sink fTan) (sinkList xs' ++ [sink residuals, ProdVal ts']) + naryTopApp (sink fTan) (sinkList xs' ++ [sink residuals, Con $ ProdCon ts']) where unitLike :: e n -> UnitE n unitLike _ = UnitE - TabApp _ x idxs -> do - zipLin (linearizeAtom x) (pureLin $ ListE $ toList idxs) `bindLin` - \(PairE x' (ListE idxs')) -> naryTabApp x' idxs' - PrimOp op -> linearizeOp op + TabApp _ x i -> zipLin (linearizeAtom x) (pureLin i) `bindLin` \(PairE x' i') -> tabApp x' i' + PrimOp op -> linearizeOp op Case e alts (EffTy effs resultTy) -> do e' <- renameM e effs' <- renameM effs @@ -408,7 +405,7 @@ linearizeExpr expr = case expr of let tys = recons <&> \(ObligateRecon t _) -> t alts'' <- forM (enumerate alts') \(i, alt) -> do injectAltResult tys i alt - let fullResultTy = PairTy resultTy' $ SumTy tys + let fullResultTy = PairTy resultTy' $ TyCon $ SumType tys result <- emitExpr $ Case e' alts'' (EffTy effs' fullResultTy) (primal, residualss) <- fromPair result resultTangentType <- tangentType resultTy' @@ -434,15 +431,15 @@ linearizeOp op = case op of Hof (TypedHof _ e) -> linearizeHof e DAMOp _ -> error "shouldn't occur here" RefOp ref m -> case m of - MAsk -> linearizeAtom ref `bindLin` \ref' -> liftM Var $ emit $ PrimOp $ RefOp ref' MAsk + MAsk -> linearizeAtom ref `bindLin` \ref' -> liftM toAtom $ emit $ PrimOp $ RefOp ref' MAsk MExtend monoid x -> do -- TODO: check that we're dealing with a +/0 monoid monoid' <- renameM monoid zipLin (linearizeAtom ref) (linearizeAtom x) `bindLin` \(PairE ref' x') -> - liftM Var $ emit $ PrimOp $ RefOp ref' $ MExtend (sink monoid') x' - MGet -> linearizeAtom ref `bindLin` \ref' -> liftM Var $ emit $ PrimOp $ RefOp ref' MGet + liftM toAtom $ emit $ PrimOp $ RefOp ref' $ MExtend (sink monoid') x' + MGet -> linearizeAtom ref `bindLin` \ref' -> liftM toAtom $ emit $ PrimOp $ RefOp ref' MGet MPut x -> zipLin (linearizeAtom ref) (linearizeAtom x) `bindLin` \(PairE ref' x') -> - liftM Var $ emit $ PrimOp $ RefOp ref' $ MPut x' + liftM toAtom $ emit $ PrimOp $ RefOp ref' $ MPut x' IndexRef _ i -> do zipLin (la ref) (pureLin i) `bindLin` \(PairE ref' i') -> emitExpr =<< mkIndexRef ref' i' @@ -454,7 +451,7 @@ linearizeOp op = case op of MiscOp miscOp -> linearizeMiscOp miscOp VectorOp _ -> error "not implemented" where - emitZeroT = withZeroT $ liftM Var $ emit =<< renameM (PrimOp op) + emitZeroT = withZeroT $ liftM toAtom $ emit =<< renameM (PrimOp op) la = linearizeAtom linearizeMiscOp :: Emits o => MiscOp SimpIR i -> LinM i o SAtom SAtom @@ -489,7 +486,7 @@ linearizeMiscOp op = case op of ShowAny _ -> error "Shouldn't have ShowAny in simplified IR" ShowScalar _ -> error "Shouldn't have ShowScalar in simplified IR" where - emitZeroT = withZeroT $ liftM Var $ emit =<< renameM (PrimOp $ MiscOp op) + emitZeroT = withZeroT $ liftM toAtom $ emit =<< renameM (PrimOp $ MiscOp op) la = linearizeAtom linearizeUnOp :: Emits o => UnOp -> Atom SimpIR i -> LinM i o SAtom SAtom @@ -566,22 +563,23 @@ linearizeBinOp op x' y' = do referToPrimal :: (Builder SimpIR m, Emits l, DExt n l) => SAtom n -> m l (SAtom l) referToPrimal x = do case x of - Var v -> lookupEnv (atomVarName $ sink v) >>= \case + Stuck _ (Var v) -> lookupEnv (atomVarName $ sink v) >>= \case AtomNameBinding (LetBound (DeclBinding PlainLet (Atom atom))) -> referToPrimal atom - AtomNameBinding (LetBound (DeclBinding PlainLet (TabApp _ tab is))) -> do + AtomNameBinding (LetBound (DeclBinding PlainLet (TabApp _ tab i))) -> do tab' <- referToPrimal tab - is' <- mapM referToPrimal is - emitExpr =<< mkTabApp tab' is' + i' <- referToPrimal i + emitExpr =<< mkTabApp tab' i' _ -> sinkM x _ -> sinkM x linearizePrimCon :: Emits o => Con SimpIR i -> LinM i o SAtom SAtom linearizePrimCon con = case con of Lit _ -> emitZeroT - ProdCon xs -> fmapLin (ProdVal . fromComposeE) $ seqLin (fmap linearizeAtom xs) + ProdCon xs -> fmapLin (Con . ProdCon . fromComposeE) $ seqLin (fmap linearizeAtom xs) SumCon _ _ _ -> notImplemented HeapVal -> emitZeroT + DepPair _ _ _ -> notImplemented where emitZeroT = withZeroT $ renameM $ Con con linearizeHof :: Emits o => Hof SimpIR i -> LinM i o SAtom SAtom @@ -605,7 +603,7 @@ linearizeHof hof = case hof of Abs ib'' (Abs bs linLam') <- sinkM (Abs ib' reconAbs) withSubstReaderT $ buildFor noHint d (sink ixTy) \i' -> do extendSubst (ib''@> Rename (atomVarName i')) do - residuals' <- tabApp (sink primalsAux) (Var i') >>= getSnd >>= unpackTelescope bs + residuals' <- tabApp (sink primalsAux) (toAtom i') >>= getSnd >>= unpackTelescope bs extendSubst (bs @@> (SubstVal <$> residuals')) $ applyLinLam linLam' RunReader r lam -> do @@ -658,14 +656,14 @@ linearizeHof hof = case hof of linearizeEffectFun :: RWS -> SLam i -> PrimalM i o (SLam o, LinLamAbs o) linearizeEffectFun rws (BinaryLamExpr hB refB body) = do - withFreshBinder noHint (TC HeapType) \h -> do + withFreshBinder noHint (TyCon HeapType) \h -> do bTy <- extendSubst (hB@>binderName h) $ renameM $ binderType refB withFreshBinder noHint bTy \b -> do let ref = binderVar b hVar <- sinkM $ binderVar h (body', linLam) <- extendActiveSubst hB hVar $ extendActiveSubst refB ref $ -- TODO: maybe we should check whether we need to extend the active effects - extendActiveEffs (RWSEffect rws (Var hVar)) do + extendActiveEffs (RWSEffect rws (toAtom hVar)) do linearizeExprDefunc body -- TODO: this assumes that references aren't returned. Our type system -- ensures that such references can never be *used* once the effect runner diff --git a/src/lib/Lower.hs b/src/lib/Lower.hs index 99fd67ef7..6acf8eac5 100644 --- a/src/lib/Lower.hs +++ b/src/lib/Lower.hs @@ -69,7 +69,7 @@ lowerFullySequential wantDestStyle (TopLam False piTy (LamExpr bs body)) = liftE EffTy _ resultTy <- instantiate (sink piTy) xs let resultDestTy = RawRefTy resultTy withFreshBinder "ans" resultDestTy \destBinder -> do - let dest = Var $ binderVar destBinder + let dest = toAtom $ binderVar destBinder LamExpr (bs' >>> UnaryNest destBinder) <$> buildBlock do lowerExpr (Just (sink dest)) body' $> UnitVal False -> LamExpr bs' <$> buildBlock (lowerExpr Nothing body') @@ -104,17 +104,17 @@ lowerFor ansTy maybeDest dir ixTy (UnaryLamExpr (ib:>ty) body) = do case isSingletonType ansTy of True -> do body' <- buildUnaryLamExpr noHint (PairTy ty' UnitTy) \b' -> do - (i, _) <- fromPair $ Var b' + (i, _) <- fromPair $ toAtom b' extendSubst (ib @> SubstVal i) $ lowerExpr Nothing body $> UnitVal void $ emitSeq dir ixTy' UnitVal body' fromJust <$> singletonTypeVal ansTy False -> do - initDest <- ProdVal . (:[]) <$> case maybeDest of + initDest <- Con . ProdCon . (:[]) <$> case maybeDest of Just d -> return d Nothing -> emitExpr $ AllocDest ansTy let destTy = getType initDest body' <- buildUnaryLamExpr noHint (PairTy ty' destTy) \b' -> do - (i, destProd) <- fromPair $ Var b' + (i, destProd) <- fromPair $ toAtom b' dest <- proj 0 destProd idest <- emitExpr =<< mkIndexRef dest i extendSubst (ib @> SubstVal i) $ lowerExpr (Just idest) body $> UnitVal @@ -124,12 +124,12 @@ lowerFor _ _ _ _ _ = error "expected a unary lambda expression" lowerTabCon :: Emits o => OptDest o -> SType i -> [SAtom i] -> LowerM i o (SAtom o) lowerTabCon maybeDest tabTy elems = do - TabPi tabTy' <- substM tabTy + TyCon (TabPi tabTy') <- substM tabTy dest <- case maybeDest of Just d -> return d - Nothing -> emitExpr $ AllocDest $ TabPi tabTy' + Nothing -> emitExpr $ AllocDest $ TyCon $ TabPi tabTy' Abs bord ufoBlock <- buildAbs noHint IdxRepTy \ord -> do - buildBlock $ unsafeFromOrdinal (sink $ tabIxType tabTy') $ Var $ sink ord + buildBlock $ unsafeFromOrdinal (sink $ tabIxType tabTy') $ toAtom $ sink ord -- This is emitting a chain of RememberDest ops to force `dest` to be used -- linearly, and to force reads of the `Freeze dest'` result not to be -- reordered in front of the writes. @@ -141,7 +141,7 @@ lowerTabCon maybeDest tabTy elems = do i <- dropSubst $ extendSubst (bord@>SubstVal (IdxRepVal (fromIntegral ord))) $ lowerExpr Nothing ufoBlock carried_dest <- buildRememberDest "dest" incoming_dest \local_dest -> do - idest <- indexRef (Var local_dest) (sink i) + idest <- indexRef (toAtom local_dest) (sink i) place idest =<< visitAtom e return UnitVal go carried_dest rest @@ -163,7 +163,7 @@ lowerCase maybeDest scrut alts resultTy = do buildAbs (getNameHint b) ty' \b' -> extendSubst (b @> Rename (atomVarName b')) $ buildBlock do - lowerExpr (Just (Var $ sink $ local_dest)) body $> UnitVal + lowerExpr (Just (toAtom $ sink $ local_dest)) body $> UnitVal void $ mkCase (sink scrut') UnitTy alts' >>= emitExpr return UnitVal emitExpr $ Freeze dest' @@ -202,7 +202,7 @@ lookupDest dests = fmap fromLiftE . flip lookupNameMapE dests -- XXX: When adding more cases, be careful about potentially repeated vars in the output! decomposeDest :: Emits o => Dest o -> SExpr i' -> LowerM i o (Maybe (DestAssignment i' o)) decomposeDest dest = \case - Atom (Stuck (StuckVar v)) -> + Atom (Stuck _ (Var v)) -> return $ Just $ singletonNameMapE (atomVarName v) $ LiftE dest _ -> return Nothing @@ -243,7 +243,7 @@ lowerExpr dest expr = case expr of -- But we have to emit explicit writes, for all the vars that are not defined in decls! forM_ (toListNameMapE $ hoistNameMap decls destMap) \(n, (LiftE d)) -> do x <- case s ! n of - Rename v -> Var <$> toAtomVar v + Rename v -> toAtom <$> toAtomVar v SubstVal a -> return a place d x withSubst s' (substM result) >>= emitExpr diff --git a/src/lib/MTL1.hs b/src/lib/MTL1.hs index d6c6f8a9d..bb14ca55c 100644 --- a/src/lib/MTL1.hs +++ b/src/lib/MTL1.hs @@ -317,7 +317,7 @@ runFallibleT1 m = {-# INLINE runFallibleT1 #-} instance Monad1 m => MonadFail (FallibleT1 m n) where - fail s = throw MonadFailErr s + fail s = throw SearchFailure s {-# INLINE fail #-} instance Monad1 m => Fallible (FallibleT1 m n) where diff --git a/src/lib/OccAnalysis.hs b/src/lib/OccAnalysis.hs index 711374df1..fcf04cdf2 100644 --- a/src/lib/OccAnalysis.hs +++ b/src/lib/OccAnalysis.hs @@ -194,17 +194,18 @@ summaryExpr = \case summary :: SAtom n -> OCCM n (IxExpr n) summary atom = case atom of - Var v -> ixExpr $ atomVarName v - Con c -> constructor c - _ -> unknown atom + Stuck _ stuck -> case stuck of + Var v -> ixExpr $ atomVarName v + _ -> unknown atom + Con c -> case c of + -- TODO Represent the actual literal value? + Lit _ -> return $ Deterministic [] + ProdCon elts -> Product <$> mapM summary elts + SumCon _ tag payload -> Inject tag <$> summary payload + HeapVal -> invalid "HeapVal" + DepPair _ _ _ -> error "not implemented" where invalid tag = error $ "Unexpected indexing by " ++ tag - constructor = \case - -- TODO Represent the actual literal value? - Lit _ -> return $ Deterministic [] - ProdCon elts -> Product <$> mapM summary elts - SumCon _ tag payload -> Inject tag <$> summary payload - HeapVal -> invalid "HeapVal" unknown :: HoistableE e => e n -> OCCM n (IxExpr n) unknown _ = return IxAll @@ -245,24 +246,25 @@ class HasOCC (e::E) where instance HasOCC SAtom where occ a = \case - Stuck e -> Stuck <$> occ a e - atom -> runOCCMVisitor a $ visitAtomPartial atom + Stuck t e -> Stuck <$> occ a t <*> occ a e + Con con -> liftM Con $ runOCCMVisitor a $ visitGeneric con instance HasOCC SStuck where occ a = \case - StuckVar (AtomVar n ty) -> do + Var (AtomVar n ty) -> do modify (<> FV (singletonNameMapE n $ AccessInfo One a)) ty' <- occTy ty - return $ StuckVar (AtomVar n ty') - StuckProject t i x -> StuckProject <$> occ a t <*> pure i <*> occ a x - StuckTabApp t array ixs -> do - t' <- occTy t - (a', ixs') <- occIdxs a ixs + return $ Var (AtomVar n ty') + StuckProject i x -> StuckProject <$> pure i <*> occ a x + StuckTabApp array ixs -> do + (a', ixs') <- occIdx a ixs array' <- occ a' array - return $ StuckTabApp t' array' ixs' + return $ StuckTabApp array' ixs' + PtrVar t p -> return $ PtrVar t p + RepValAtom x -> return $ RepValAtom x instance HasOCC SType where - occ a ty = runOCCMVisitor a $ visitTypePartial ty + occ a (TyCon con) = liftM TyCon $ runOCCMVisitor a $ visitGeneric con -- TODO What, actually, is the right thing to do for type annotations? Do we -- want a rule like "we never inline into type annotations", or such? For @@ -363,11 +365,11 @@ instance HasOCC SExpr where effTy' <- occ a effTy Abs decls' ans' <- occNest a (Abs decls ans) return $ Block effTy' (Abs decls' ans') - TabApp t array ixs -> do + TabApp t array ix -> do t' <- occTy t - (a', ixs') <- occIdxs a ixs + (a', ix') <- occIdx a ix array' <- occ a' array - return $ TabApp t' array' ixs' + return $ TabApp t' array' ix' Case scrut alts (EffTy effs ty) -> do scrut' <- occ accessOnce scrut scrutIx <- summary scrut @@ -382,12 +384,10 @@ instance HasOCC SExpr where PrimOp . RefOp ref' <$> occ a op expr -> occGeneric a expr -occIdxs :: Access n -> [SAtom n] -> OCCM n (Access n, [SAtom n]) -occIdxs acc [] = return (acc, []) -occIdxs acc (ix:ixs) = do - (acc', ixs') <- occIdxs acc ixs - (summ, ix') <- occurrenceAndSummary ix - return (location summ acc', ix':ixs') +occIdx :: Access n -> SAtom n -> OCCM n (Access n, SAtom n) +occIdx acc ix = do + (summ, ix') <- occurrenceAndSummary ix + return (location summ acc, ix') -- Arguments: Usage of the return value, summary of the scrutinee, the -- alternative itself. diff --git a/src/lib/Optimize.hs b/src/lib/Optimize.hs index 3c187d58c..dd1d0aaae 100644 --- a/src/lib/Optimize.hs +++ b/src/lib/Optimize.hs @@ -41,7 +41,7 @@ optimize = dceTop -- Clean up user code peepholeOp :: PrimOp SimpIR o -> EnvReaderM o (SExpr o) peepholeOp op = case op of - MiscOp (CastOp (BaseTy (Scalar sTy)) (Con (Lit l))) -> return $ case foldCast sTy l of + MiscOp (CastOp (TyCon (BaseType (Scalar sTy))) (Con (Lit l))) -> return $ case foldCast sTy l of Just l' -> lit l' Nothing -> noop -- TODO: Support more unary and binary ops. @@ -69,9 +69,9 @@ peepholeOp op = case op of BinOp BAnd (Con (Lit (Word8Lit lv))) (Con (Lit (Word8Lit rv))) -> return $ lit $ Word8Lit $ lv .&. rv MiscOp (ToEnum ty (Con (Lit (Word8Lit tag)))) -> case ty of - SumTy cases -> return $ Atom $ SumVal cases (fromIntegral tag) UnitVal + TyCon (SumType cases) -> return $ toExpr $ SumCon cases (fromIntegral tag) UnitVal _ -> error "Ill typed ToEnum?" - MiscOp (SumTag (SumVal _ tag _)) -> return $ lit $ Word8Lit $ fromIntegral tag + MiscOp (SumTag (Con (SumCon _ tag _))) -> return $ lit $ Word8Lit $ fromIntegral tag _ -> return noop where noop = PrimOp op @@ -187,7 +187,7 @@ foldCast sTy l = case sTy of peepholeExpr :: SExpr o -> EnvReaderM o (SExpr o) peepholeExpr expr = case expr of PrimOp op -> peepholeOp op - TabApp _ (Var (AtomVar t _)) [IdxRepVal ord] -> + TabApp _ (Stuck _ (Var (AtomVar t _))) (IdxRepVal ord) -> lookupAtomName t <&> \case LetBound (DeclBinding ann (TabCon Nothing tabTy elems)) | ann /= NoInlineLet && isFinTabTy tabTy-> @@ -202,7 +202,7 @@ peepholeExpr expr = case expr of -- Think, partial evaluation of threefry. _ -> return expr where isFinTabTy = \case - TabPi (TabPiType (IxDictRawFin _) _ _) -> True + TyCon (TabPi (TabPiType (DictCon (IxRawFin _)) _ _)) -> True _ -> False -- === Loop unrolling === @@ -242,7 +242,7 @@ ulExpr :: Emits o => SExpr i -> ULM i o (SAtom o) ulExpr expr = case expr of PrimOp (Hof (TypedHof _ (For Fwd ixTy body))) -> case ixTypeDict ixTy of - IxDictRawFin (IdxRepVal n) -> do + DictCon (IxRawFin (IdxRepVal n)) -> do (body', bodyCost) <- withLocalAccounting $ visitLamEmits body -- We add n (in the form of (... + 1) * n) for the cost of the TabCon reconstructing the result. case (bodyCost + 1) * (fromIntegral n) <= unrollBlowupThreshold of @@ -253,7 +253,7 @@ ulExpr expr = case expr of inc $ fromIntegral n -- To account for the TabCon we emit below getLamExprType body' >>= \case PiType (UnaryNest (tb:>_)) (EffTy _ valTy) -> do - let tabTy = TabPi $ TabPiType (IxDictRawFin (IdxRepVal n)) (tb:>IdxRepTy) valTy + let tabTy = toType $ TabPiType (DictCon $ IxRawFin (IdxRepVal n)) (tb:>IdxRepTy) valTy emitExpr $ TabCon Nothing tabTy vals _ -> error "Expected `for` body to have a Pi type" _ -> error "Expected `for` body to be a lambda expression" @@ -306,7 +306,7 @@ hoistLoopInvariant lam = liftLamExpr lam hoistLoopInvariantExpr licmExpr :: Emits o => SExpr i -> LICMM i o (SAtom o) licmExpr = \case - PrimOp (DAMOp (Seq _ dir ix (ProdVal dests) (LamExpr (UnaryNest b) body))) -> do + PrimOp (DAMOp (Seq _ dir ix (Con (ProdCon dests)) (LamExpr (UnaryNest b) body))) -> do ix' <- substM ix dests' <- mapM visitAtom dests let numCarriesOriginal = length dests' @@ -319,15 +319,15 @@ licmExpr = \case PairE (ListE extraDests) ab <- emitDecls $ Abs hdecls destsAndBody extraDests' <- mapM toAtomVar extraDests -- Append the destinations of hoisted Allocs as loop carried values. - let dests'' = ProdVal $ dests' ++ (Var <$> extraDests') + let dests'' = Con $ ProdCon $ dests' ++ (toAtom <$> extraDests') let carryTy = getType dests'' let lbTy = case ix' of IxType ixTy _ -> PairTy ixTy carryTy extraDestsTyped <- forM extraDests' \(AtomVar d t) -> return (d, t) Abs extraDestBs (Abs lb bodyAbs) <- return $ abstractFreeVars extraDestsTyped ab body' <- withFreshBinder noHint lbTy \lb' -> do - (oldIx, allCarries) <- fromPairReduced $ Var $ binderVar lb' + (oldIx, allCarries) <- fromPairReduced $ toAtom $ binderVar lb' (oldCarries, newCarries) <- splitAt numCarriesOriginal <$> getUnpackedReduced allCarries - let oldLoopBinderVal = PairVal oldIx (ProdVal oldCarries) + let oldLoopBinderVal = Con $ ProdCon [oldIx, Con $ ProdCon oldCarries] let s = extraDestBs @@> map SubstVal newCarries <.> lb @> SubstVal oldLoopBinderVal block <- mkBlock =<< applySubst s bodyAbs return $ UnaryLamExpr lb' block @@ -419,11 +419,13 @@ instance Color c => HasDCE (Name c) where dce n = modify (<> FV (freeVarsE n)) $> n instance HasDCE SAtom where - dce = \case - Stuck e -> modify (<> FV (freeVarsE e)) $> Stuck e - atom -> visitAtomPartial atom + dce atom = case atom of + Stuck _ _ -> modify (<> FV (freeVarsE atom)) $> atom + Con con -> Con <$> visitGeneric con + +instance HasDCE SType where + dce (TyCon e) = TyCon <$> visitGeneric e -instance HasDCE SType where dce = visitTypePartial instance HasDCE (PiType SimpIR) where dce (PiType bs effTy) = do dceBinders bs effTy \bs' effTy' -> PiType bs' <$> dce effTy' diff --git a/src/lib/PPrint.hs b/src/lib/PPrint.hs index 968193a00..241894483 100644 --- a/src/lib/PPrint.hs +++ b/src/lib/PPrint.hs @@ -161,7 +161,7 @@ instance IRRep r => PrettyPrec (Expr r n) where Block _ (Abs decls body) -> atPrec AppPrec $ prettyBlock decls body App _ f xs -> atPrec AppPrec $ pApp f <+> spaced (toList xs) TopApp _ f xs -> atPrec AppPrec $ pApp f <+> spaced (toList xs) - TabApp _ f xs -> atPrec AppPrec $ pApp f <> "." <> dotted (toList xs) + TabApp _ f x -> atPrec AppPrec $ pApp f <> brackets (p x) Case e alts (EffTy effs _) -> prettyPrecCase "case" e alts effs TabCon _ _ es -> atPrec ArgPrec $ list $ pApp <$> es PrimOp op -> prettyPrec op @@ -209,21 +209,24 @@ instance IRRep r => PrettyPrec (LamExpr r n) where instance IRRep r => Pretty (IxType r n) where pretty (IxType ty dict) = parens $ "IxType" <+> pretty ty <> prettyIxDict dict -instance Pretty (DictCon n) where - pretty d = case d of - InstanceDict _ name args -> "Instance" <+> p name <+> p args - IxFin _ n -> "Ix (Fin" <+> p n <> ")" - DataData _ a -> "Data " <+> p a +instance IRRep r => Pretty (Dict r n) where + pretty = \case + DictCon con -> pretty con + StuckDict _ stuck -> pretty stuck -instance IRRep r => Pretty (IxDict r n) where +instance IRRep r => Pretty (DictCon r n) where pretty = \case - IxDictAtom x -> p x - IxDictRawFin n -> "Ix (RawFin " <> p n <> ")" - IxDictSpecialized _ d xs -> p d <+> p xs + InstanceDict _ name args -> "Instance" <+> p name <+> p args + IxFin n -> "Ix (Fin" <+> p n <> ")" + DataData a -> "Data " <+> p a + IxRawFin n -> "Ix (RawFin " <> p n <> ")" + IxSpecialized d xs -> p d <+> p xs instance Pretty (DictType n) where - pretty (DictType classSourceName _ params) = - p classSourceName <+> spaced params + pretty = \case + DictType classSourceName _ params -> p classSourceName <+> spaced params + IxDictType ty -> "Ix" <+> p ty + DataDictType ty -> "Data" <+> p ty instance IRRep r => Pretty (DepPairType r n) where pretty = prettyFromPrettyPrec instance IRRep r => PrettyPrec (DepPairType r n) where @@ -236,49 +239,32 @@ instance Pretty (CoreLamExpr n) where instance IRRep r => Pretty (Atom r n) where pretty = prettyFromPrettyPrec instance IRRep r => PrettyPrec (Atom r n) where prettyPrec atom = case atom of - Stuck e -> prettyPrec e - Lam lam -> atPrec LowestPrec $ p lam - DepPair x y _ -> atPrec ArgPrec $ align $ group $ - parens $ p x <+> ",>" <+> p y Con e -> prettyPrec e - Eff e -> atPrec ArgPrec $ p e - PtrVar _ v -> atPrec ArgPrec $ p v - DictCon d -> atPrec LowestPrec $ p d - RepValAtom x -> atPrec LowestPrec $ pretty x - NewtypeCon con x -> prettyPrecNewtype con x - SimpInCore x -> prettyPrec x - TypeAsAtom ty -> prettyPrec ty + Stuck _ e -> prettyPrec e instance IRRep r => Pretty (Type r n) where pretty = prettyFromPrettyPrec instance IRRep r => PrettyPrec (Type r n) where prettyPrec = \case - Pi piType -> atPrec LowestPrec $ align $ p piType - TabPi piType -> atPrec LowestPrec $ align $ p piType - DepPairTy ty -> prettyPrec ty - TC e -> prettyPrec e - DictTy t -> atPrec LowestPrec $ p t - NewtypeTyCon con -> prettyPrec con - StuckTy e -> prettyPrec e + TyCon e -> prettyPrec e + StuckTy _ e -> prettyPrec e instance IRRep r => Pretty (Stuck r n) where pretty = prettyFromPrettyPrec instance IRRep r => PrettyPrec (Stuck r n) where prettyPrec = \case - StuckVar v -> atPrec ArgPrec $ p v - StuckProject _ i v -> atPrec LowestPrec $ "StuckProject" <+> p i <+> p v - StuckTabApp _ f xs -> atPrec AppPrec $ pArg f <> "." <> pArg xs - StuckUnwrap _ v -> atPrec LowestPrec $ "StuckUnwrap" <+> p v - InstantiatedGiven _ v args -> atPrec LowestPrec $ "Given" <+> p v <+> p (toList args) - SuperclassProj _ d' i -> atPrec LowestPrec $ "SuperclassProj" <+> p d' <+> p i - -instance Pretty (SimpInCore n) where pretty = prettyFromPrettyPrec -instance PrettyPrec (SimpInCore n) where - prettyPrec = \case + Var v -> atPrec ArgPrec $ p v + StuckProject i v -> atPrec LowestPrec $ "StuckProject" <+> p i <+> p v + StuckTabApp f xs -> atPrec AppPrec $ pArg f <> "." <> pArg xs + StuckUnwrap v -> atPrec LowestPrec $ "StuckUnwrap" <+> p v + InstantiatedGiven v args -> atPrec LowestPrec $ "Given" <+> p v <+> p (toList args) + SuperclassProj d' i -> atPrec LowestPrec $ "SuperclassProj" <+> p d' <+> p i + PtrVar _ v -> atPrec ArgPrec $ p v + RepValAtom x -> atPrec LowestPrec $ pretty x + ACase e alts _ -> atPrec AppPrec $ "acase" <+> p e <+> p alts LiftSimp ty x -> atPrec ArgPrec $ " p x <+> " : " <+> p ty <+> ">" LiftSimpFun ty x -> atPrec ArgPrec $ " p x <+> " : " <+> p ty <+> ">" - ACase e alts _ -> atPrec AppPrec $ "acase" <+> p e <+> p alts - TabLam _ _ -> atPrec AppPrec $ "tablam" + TabLam lam -> atPrec AppPrec $ "tablam" <+> p lam -instance IRRep r => Pretty (RepVal r n) where +instance Pretty (RepVal n) where pretty (RepVal ty tree) = " p tree <+> ":" <+> p ty <> ">" instance Pretty a => Pretty (Tree a) where @@ -326,14 +312,9 @@ withExplParens (Inferred _ (Synth _)) x = brackets x instance IRRep r => Pretty (TabPiType r n) where pretty (TabPiType dict (b:>ty) body) = let prettyBody = case body of - Pi subpi -> pretty subpi + TyCon (Pi subpi) -> pretty subpi _ -> pLowest body - prettyBinder = case dict of - IxDictRawFin n -> if binderName b `isFreeIn` body - then parens $ p b <> ":" <> prettyTy - else prettyTy - where prettyTy = "RawFin" <+> p n - _ -> prettyBinderHelper (b:>ty) body + prettyBinder = prettyBinderHelper (b:>ty) body in prettyBinder <> prettyIxDict dict <> (group $ line <> "=>" <+> prettyBody) -- A helper to let us turn dict printing on and off. We mostly want it off to @@ -442,7 +423,7 @@ instance Pretty (DataConDef n) where p name <+> ":" <+> p repTy instance Pretty (ClassDef n) where - pretty (ClassDef classSourceName methodNames _ _ params superclasses methodTys) = + pretty (ClassDef classSourceName _ methodNames _ _ params superclasses methodTys) = "Class:" <+> pretty classSourceName <+> pretty methodNames <> indented ( line <> "parameter binders:" <+> pretty params <> @@ -682,9 +663,6 @@ instance PrettyPrec (UPat' n l) where spaced :: (Foldable f, Pretty a) => f a -> Doc ann spaced xs = hsep $ map p $ toList xs -dotted :: (Foldable f, Pretty a) => f a -> Doc ann -dotted xs = fold $ punctuate "." $ map p $ toList xs - commaSep :: (Foldable f, Pretty a) => f a -> Doc ann commaSep xs = fold $ punctuate "," $ map p $ toList xs @@ -820,8 +798,8 @@ instance PrettyPrec ScalarBaseType where Word32Type -> "Word32" Word64Type -> "Word64" -instance IRRep r => Pretty (TC r n) where pretty = prettyFromPrettyPrec -instance IRRep r => PrettyPrec (TC r n) where +instance IRRep r => Pretty (TyCon r n) where pretty = prettyFromPrettyPrec +instance IRRep r => PrettyPrec (TyCon r n) where prettyPrec con = case con of BaseType b -> prettyPrec b ProdType [] -> atPrec ArgPrec $ "()" @@ -832,6 +810,11 @@ instance IRRep r => PrettyPrec (TC r n) where RefType h a -> atPrec AppPrec $ pAppArg "Ref" [h] <+> p a TypeKind -> atPrec ArgPrec "Type" HeapType -> atPrec ArgPrec "Heap" + Pi piType -> atPrec LowestPrec $ align $ p piType + TabPi piType -> atPrec LowestPrec $ align $ p piType + DepPairTy ty -> prettyPrec ty + DictTy t -> atPrec LowestPrec $ p t + NewtypeTyCon con' -> prettyPrec con' prettyPrecNewtype :: NewtypeCon n -> CAtom n -> DocPrec ann prettyPrecNewtype con x = case (con, x) of @@ -866,6 +849,13 @@ instance IRRep r => PrettyPrec (Con r n) where SumCon _ tag payload -> atPrec ArgPrec $ "(" <> p tag <> "|" <+> pApp payload <+> "|)" HeapVal -> atPrec ArgPrec "HeapValue" + Lam lam -> atPrec LowestPrec $ p lam + DepPair x y _ -> atPrec ArgPrec $ align $ group $ + parens $ p x <+> ",>" <+> p y + Eff e -> atPrec ArgPrec $ p e + DictConAtom d -> atPrec LowestPrec $ p d + NewtypeCon con x -> prettyPrecNewtype con x + TyConAtom ty -> prettyPrec ty instance IRRep r => Pretty (PrimOp r n) where pretty = prettyFromPrettyPrec instance IRRep r => PrettyPrec (PrimOp r n) where @@ -932,12 +922,9 @@ instance IRRep r => PrettyPrec (Hof r n) where instance IRRep r => Pretty (DAMOp r n) where pretty = prettyFromPrettyPrec instance IRRep r => PrettyPrec (DAMOp r n) where prettyPrec op = atPrec LowestPrec case op of - Seq _ ann d c lamExpr -> case lamExpr of + Seq _ ann _ c lamExpr -> case lamExpr of UnaryLamExpr b body -> do - let rawFinPretty = case d of - IxType _ (IxDictRawFin n) -> parens $ "RawFin" <+> p n - _ -> mempty - "seq" <+> rawFinPretty <+> pApp ann <+> pApp c <+> prettyLam (p b <> ".") body + "seq" <+> pApp ann <+> pApp c <+> prettyLam (p b <> ".") body _ -> p (show op) -- shouldn't happen, but crashing pretty printers make debugging hard RememberDest _ x y -> "rememberDest" <+> pArg x <+> pArg y Place r v -> pApp r <+> "r:=" <+> pApp v diff --git a/src/lib/QueryType.hs b/src/lib/QueryType.hs index 424ec295e..5210014cf 100644 --- a/src/lib/QueryType.hs +++ b/src/lib/QueryType.hs @@ -9,6 +9,7 @@ module QueryType (module QueryType, module QueryTypePure, toAtomVar) where import Control.Category ((>>>)) import Control.Monad import Data.List (elemIndex) +import Data.Maybe (fromJust) import Data.Functor ((<&>)) import Types.Primitives @@ -35,8 +36,8 @@ sourceNameType v = do caseAltsBinderTys :: (EnvReader m, IRRep r) => Type r n -> m n [Type r n] caseAltsBinderTys ty = case ty of - SumTy types -> return types - NewtypeTyCon t -> case t of + TyCon (SumType types) -> return types -- need this case? + TyCon (NewtypeTyCon t) -> case t of UserADTType _ defName params -> do def <- lookupTyCon defName ~(ADTCons cons) <- instantiateTyConDef def params @@ -55,16 +56,13 @@ piTypeWithoutDest (PiType bsRefB _) = PiType bs $ EffTy Pure ansTy -- XXX: we ignore the effects here _ -> error "expected trailing dest binder" -typeOfTabApp :: (IRRep r, EnvReader m) => Type r n -> [Atom r n] -> m n (Type r n) -typeOfTabApp t [] = return t -typeOfTabApp (TabPi tabTy) (i:rest) = do - resultTy <- instantiate tabTy [i] - typeOfTabApp resultTy rest +typeOfTabApp :: (IRRep r, EnvReader m) => Type r n -> Atom r n -> m n (Type r n) +typeOfTabApp (TyCon (TabPi tabTy)) i = instantiate tabTy [i] typeOfTabApp ty _ = error $ "expected a table type. Got: " ++ pprint ty -typeOfApplyMethod :: EnvReader m => CAtom n -> Int -> [CAtom n] -> m n (EffTy CoreIR n) +typeOfApplyMethod :: EnvReader m => CDict n -> Int -> [CAtom n] -> m n (EffTy CoreIR n) typeOfApplyMethod d i args = do - ty <- Pi <$> getMethodType d i + ty <- toType <$> getMethodType d i appEffTy ty args typeOfTopApp :: EnvReader m => TopFunName n -> [SAtom n] -> m n (EffTy SimpIR n) @@ -73,33 +71,33 @@ typeOfTopApp f xs = do instantiate piTy xs typeOfIndexRef :: (EnvReader m, Fallible1 m, IRRep r) => Type r n -> Atom r n -> m n (Type r n) -typeOfIndexRef (TC (RefType h s)) i = do - TabPi tabPi <- return s +typeOfIndexRef (TyCon (RefType h s)) i = do + TyCon (TabPi tabPi) <- return s eltTy <- instantiate tabPi [i] - return $ TC $ RefType h eltTy + return $ toType $ RefType h eltTy typeOfIndexRef _ _ = error "expected a ref type" typeOfProjRef :: EnvReader m => Type r n -> Projection -> m n (Type r n) -typeOfProjRef (TC (RefType h s)) p = do - TC . RefType h <$> case p of +typeOfProjRef (TyCon (RefType h s)) p = do + toType . RefType h <$> case p of ProjectProduct i -> do - ~(ProdTy tys) <- return s + ~(TyCon (ProdType tys)) <- return s return $ tys !! i UnwrapNewtype -> do case s of - NewtypeTyCon tc -> snd <$> unwrapNewtypeType tc + TyCon (NewtypeTyCon tc) -> snd <$> unwrapNewtypeType tc _ -> error "expected a newtype" typeOfProjRef _ _ = error "expected a reference" appEffTy :: (IRRep r, EnvReader m) => Type r n -> [Atom r n] -> m n (EffTy r n) -appEffTy (Pi piTy) xs = instantiate piTy xs +appEffTy (TyCon (Pi piTy)) xs = instantiate piTy xs appEffTy t _ = error $ "expected a pi type, got: " ++ pprint t partialAppType :: (IRRep r, EnvReader m) => Type r n -> [Atom r n] -> m n (Type r n) -partialAppType (Pi (CorePiType appExpl expls bs effTy)) xs = do +partialAppType (TyCon (Pi (CorePiType appExpl expls bs effTy))) xs = do (_, expls2) <- return $ splitAt (length xs) expls PairB bs1 bs2 <- return $ splitNestAt (length xs) bs - instantiate (Abs bs1 (Pi $ CorePiType appExpl expls2 bs2 effTy)) xs + instantiate (Abs bs1 (toType $ CorePiType appExpl expls2 bs2 effTy)) xs partialAppType _ _ = error "expected a pi type" effTyOfHof :: (EnvReader m, IRRep r) => Hof r n -> m n (EffTy r n) @@ -114,7 +112,7 @@ typeOfHof = \case Linearize f _ -> getLamExprType f >>= \case PiType (UnaryNest (binder:>a)) (EffTy Pure b) -> do let b' = ignoreHoistFailure $ hoist binder b - let fLinTy = Pi $ nonDepPiType [a] Pure b' + let fLinTy = toType $ nonDepPiType [a] Pure b' return $ PairTy b' fLinTy _ -> error "expected a unary pi type" Transpose f _ -> getLamExprType f >>= \case @@ -151,7 +149,7 @@ deleteEff eff (EffectRow effs t) = EffectRow (effs `eSetDifference` eSetSingleto getMethodIndex :: EnvReader m => ClassName n -> SourceName -> m n Int getMethodIndex className methodSourceName = do - ClassDef _ methodNames _ _ _ _ _ <- lookupClassDef className + ClassDef _ _ methodNames _ _ _ _ _ <- lookupClassDef className case elemIndex methodSourceName methodNames of Nothing -> error $ methodSourceName ++ " is not a method of " ++ pprint className Just i -> return i @@ -164,39 +162,52 @@ getUVarType = \case UDataConVar v -> getDataConNameType v UPunVar v -> getStructDataConType v UClassVar v -> do - ClassDef _ _ _ roleExpls bs _ _ <- lookupClassDef v - return $ Pi $ CorePiType ExplicitApp (map snd roleExpls) bs $ EffTy Pure TyKind + ClassDef _ _ _ _ roleExpls bs _ _ <- lookupClassDef v + return $ toType $ CorePiType ExplicitApp (map snd roleExpls) bs $ EffTy Pure TyKind UMethodVar v -> getMethodNameType v getMethodNameType :: EnvReader m => MethodName n -> m n (CType n) getMethodNameType v = liftEnvReaderM $ lookupEnv v >>= \case MethodBinding className i -> do - ClassDef _ _ paramNames _ paramBs scBinders methodTys <- lookupClassDef className + ClassDef _ _ _ paramNames _ paramBs scBinders methodTys <- lookupClassDef className refreshAbs (Abs paramBs $ Abs scBinders (methodTys !! i)) \paramBs' absPiTy -> do - let params = Var <$> bindersVars paramBs' - dictTy <- DictTy <$> dictType (sink className) params + let params = toAtom <$> bindersVars paramBs' + dictTy <- toType <$> dictType (sink className) params withFreshBinder noHint dictTy \dictB -> do - scDicts <- getSuperclassDicts (Var $ binderVar dictB) + scDicts <- getSuperclassDicts (toDict $ binderVar dictB) CorePiType appExpl methodExpls methodBs effTy <- instantiate (sink absPiTy) scDicts let paramExpls = paramNames <&> \name -> Inferred name Unify let expls = paramExpls <> [Inferred Nothing (Synth $ Partial $ succ i)] <> methodExpls - return $ Pi $ CorePiType appExpl expls (paramBs' >>> UnaryNest dictB >>> methodBs) effTy - -getMethodType :: EnvReader m => Dict n -> Int -> m n (CorePiType n) -getMethodType dict i = liftEnvReaderM $ withSubstReaderT do - ~(DictTy (DictType _ className params)) <- return $ getType dict - superclassDicts <- getSuperclassDicts dict - classDef <- lookupClassDef className - withInstantiated classDef params \ab -> do - withInstantiated ab superclassDicts \(ListE methodTys) -> - substM $ methodTys !! i + return $ toType $ CorePiType appExpl expls (paramBs' >>> UnaryNest dictB >>> methodBs) effTy + +getMethodType :: EnvReader m => CDict n -> Int -> m n (CorePiType n) +getMethodType dict i = do + ~(TyCon (DictTy dictTy)) <- return $ getType dict + case dictTy of + DictType _ className params -> liftEnvReaderM $ withSubstReaderT do + superclassDicts <- getSuperclassDicts dict + classDef <- lookupClassDef className + withInstantiated classDef params \ab -> do + withInstantiated ab superclassDicts \(ListE methodTys) -> + substM $ methodTys !! i + IxDictType ixTy -> liftEnvReaderM case i of + 0 -> mkCorePiType [] NatTy -- size' : () -> Nat + 1 -> mkCorePiType [ixTy] NatTy -- ordinal : (n) -> Nat + 2 -> mkCorePiType [NatTy] ixTy -- unsafe_from_ordinal : (Nat) -> n + _ -> error "Ix only has three methods" + DataDictType _ -> error "Data class has no methods" + +mkCorePiType :: EnvReader m => [CType n] -> CType n -> m n (CorePiType n) +mkCorePiType argTys resultTy = liftEnvReaderM $ withFreshBinders argTys \bs _ -> do + expls <- return $ nestToList (const Explicit) bs + return $ CorePiType ExplicitApp expls bs (EffTy Pure (sink resultTy)) getTyConNameType :: EnvReader m => TyConName n -> m n (Type CoreIR n) getTyConNameType v = do TyConDef _ expls bs _ <- lookupTyCon v case bs of Empty -> return TyKind - _ -> return $ Pi $ CorePiType ExplicitApp (snd <$> expls) bs $ EffTy Pure TyKind + _ -> return $ toType $ CorePiType ExplicitApp (snd <$> expls) bs $ EffTy Pure TyKind getDataConNameType :: EnvReader m => DataConName n -> m n (Type CoreIR n) getDataConNameType dataCon = liftEnvReaderM $ withSubstReaderT do @@ -208,9 +219,9 @@ getDataConNameType dataCon = liftEnvReaderM $ withSubstReaderT do refreshAbs ab \dataBs UnitE -> do let appExpl = case dataBs of Empty -> ImplicitApp _ -> ExplicitApp - let resultTy = NewtypeTyCon $ UserADTType (getSourceName tyConDef) (sink tyCon) (sink params) + let resultTy = toType $ UserADTType (getSourceName tyConDef) (sink tyCon) (sink params) let dataExpls = nestToList (const $ Explicit) dataBs - return $ Pi $ CorePiType appExpl (expls <> dataExpls) (paramBs' >>> dataBs) (EffTy Pure resultTy) + return $ toType $ CorePiType appExpl (expls <> dataExpls) (paramBs' >>> dataBs) (EffTy Pure resultTy) getStructDataConType :: EnvReader m => TyConName n -> m n (CType n) getStructDataConType tyCon = liftEnvReaderM $ withSubstReaderT do @@ -218,10 +229,10 @@ getStructDataConType tyCon = liftEnvReaderM $ withSubstReaderT do buildDataConType tyConDef \expls paramBs' paramVs params -> do withInstantiatedNames tyConDef paramVs \(StructFields fields) -> do fieldTys <- forM fields \(_, t) -> renameM t - let resultTy = NewtypeTyCon $ UserADTType (getSourceName tyConDef) (sink tyCon) params + let resultTy = toType $ UserADTType (getSourceName tyConDef) (sink tyCon) params Abs dataBs resultTy' <- return $ typesAsBinderNest fieldTys resultTy let dataExpls = nestToList (const Explicit) dataBs - return $ Pi $ CorePiType ExplicitApp (expls <> dataExpls) (paramBs' >>> dataBs) (EffTy Pure resultTy') + return $ toType $ CorePiType ExplicitApp (expls <> dataExpls) (paramBs' >>> dataBs) (EffTy Pure resultTy') buildDataConType :: (EnvReader m, EnvExtender m) @@ -236,44 +247,29 @@ buildDataConType (TyConDef _ roleExpls bs _) cont = do refreshAbs (Abs bs UnitE) \bs' UnitE -> do let vs = nestToNames bs' vs' <- mapM toAtomVar vs - cont expls' bs' vs $ TyConParams expls (Var <$> vs') + cont expls' bs' vs $ TyConParams expls (toAtom <$> vs') makeTyConParams :: EnvReader m => TyConName n -> [CAtom n] -> m n (TyConParams n) makeTyConParams tc params = do TyConDef _ expls _ _ <- lookupTyCon tc return $ TyConParams (map snd expls) params -getDataClassName :: (Fallible1 m, EnvReader m) => m n (ClassName n) -getDataClassName = lookupSourceMap "Data" >>= \case - Nothing -> throw CompilerErr $ "Data interface needed but not defined!" - Just (UClassVar v) -> return v - Just _ -> error "not a class var" - -dataDictType :: (Fallible1 m, EnvReader m) => CType n -> m n (DictType n) -dataDictType ty = do - dataClassName <- getDataClassName - dictType dataClassName [Type ty] - -getIxClassName :: (Fallible1 m, EnvReader m) => m n (ClassName n) -getIxClassName = lookupSourceMap "Ix" >>= \case - Nothing -> throw CompilerErr $ "Ix interface needed but not defined!" - Just (UClassVar v) -> return v - Just _ -> error "not a class var" - dictType :: EnvReader m => ClassName n -> [CAtom n] -> m n (DictType n) dictType className params = do - ClassDef sourceName _ _ _ _ _ _ <- lookupClassDef className - return $ DictType sourceName className params - -ixDictType :: (Fallible1 m, EnvReader m) => CType n -> m n (DictType n) -ixDictType ty = do - ixClassName <- getIxClassName - dictType ixClassName [Type ty] + ClassDef sourceName builtinName _ _ _ _ _ _ <- lookupClassDef className + return case builtinName of + Just Ix -> IxDictType singleTyParam + Just Data -> DataDictType singleTyParam + Nothing -> DictType sourceName className params + where singleTyParam = case params of + [p] -> fromJust $ toMaybeType p + _ -> error "not a single type param" makePreludeMaybeTy :: EnvReader m => CType n -> m n (CType n) makePreludeMaybeTy ty = do ~(Just (UTyConVar tyConName)) <- lookupSourceMap "Maybe" - return $ TypeCon "Maybe" tyConName $ TyConParams [Explicit] [Type ty] + let params = TyConParams [Explicit] [toAtom ty] + return $ toType $ UserADTType "Maybe" tyConName params -- === computing effects === @@ -285,7 +281,7 @@ rwsFunEffects :: (IRRep r, EnvReader m) => RWS -> LamExpr r n -> m n (EffectRow rwsFunEffects rws f = getLamExprType f >>= \case PiType (BinaryNest h ref) et -> do let effs' = ignoreHoistFailure $ hoist ref (etEff et) - let hVal = Var $ AtomVar (binderName h) (TC HeapType) + let hVal = toAtom $ AtomVar (binderName h) (TyCon HeapType) let effs'' = deleteEff (RWSEffect rws hVal) effs' return $ ignoreHoistFailure $ hoist h effs'' _ -> error "Expected a binary function type" @@ -305,19 +301,22 @@ getTypeRWSAction f = getLamExprType f >>= \case _ -> error "expected a ref" _ -> error "expected a pi type" -getSuperclassDicts :: EnvReader m => CAtom n -> m n ([CAtom n]) +getSuperclassDicts :: EnvReader m => CDict n -> m n ([CAtom n]) getSuperclassDicts dict = do case getType dict of - DictTy dTy -> do + TyCon (DictTy dTy) -> do ts <- getSuperclassTys dTy forM (enumerate ts) \(i, _) -> reduceSuperclassProj i dict _ -> error "expected a dict type" getSuperclassTys :: EnvReader m => DictType n -> m n [CType n] -getSuperclassTys (DictType _ className params) = do - ClassDef _ _ _ _ bs superclasses _ <- lookupClassDef className - forM [0 .. nestLength superclasses - 1] \i -> do - instantiate (Abs bs $ getSuperclassType REmpty superclasses i) params +getSuperclassTys = \case + DictType _ className params -> do + ClassDef _ _ _ _ _ bs superclasses _ <- lookupClassDef className + forM [0 .. nestLength superclasses - 1] \i -> do + instantiate (Abs bs $ getSuperclassType REmpty superclasses i) params + DataDictType _ -> return [] + IxDictType ty -> return [toType $ DataDictType ty] getTypeTopFun :: EnvReader m => TopFunName n -> m n (PiType SimpIR n) getTypeTopFun f = lookupTopFun f >>= \case @@ -336,10 +335,10 @@ liftIFunType (IFunType _ argTys resultTys) = liftEnvReaderM $ go argTys where [] -> return $ PiType Empty (EffTy (OneEffect IOEffect) resultTy) where resultTy = case resultTys of [] -> UnitTy - [t] -> BaseTy t - [t1, t2] -> PairTy (BaseTy t1) (BaseTy t2) + [t] -> toType $ BaseType t + [t1, t2] -> TyCon (ProdType [toType $ BaseType t1, toType $ BaseType t2]) _ -> error $ "Not a valid FFI return type: " ++ pprint resultTys - t:ts -> withFreshBinder noHint (BaseTy t) \b -> do + t:ts -> withFreshBinder noHint (toType $ BaseType t) \b -> do PiType bs effTy <- go ts return $ PiType (Nest b bs) effTy @@ -354,24 +353,25 @@ isData ty = do checkDataLike :: Type CoreIR i -> SubstReaderT Name FallibleEnvReaderM i o () checkDataLike ty = case ty of - TyVar _ -> notData - TabPi (TabPiType _ b eltTy) -> do - renameBinders b \_ -> - checkDataLike eltTy - DepPairTy (DepPairType _ b@(_:>l) r) -> do - recur l - renameBinders b \_ -> checkDataLike r - NewtypeTyCon nt -> do - (_, ty') <- unwrapNewtypeType =<< renameM nt - dropSubst $ recur ty' - TC con -> case con of + StuckTy _ _ -> notData + TyCon con -> case con of + TabPi (TabPiType _ b eltTy) -> do + renameBinders b \_ -> + checkDataLike eltTy + DepPairTy (DepPairType _ b@(_:>l) r) -> do + recur l + renameBinders b \_ -> checkDataLike r + NewtypeTyCon nt -> do + (_, ty') <- unwrapNewtypeType =<< renameM nt + dropSubst $ recur ty' BaseType _ -> return () ProdType as -> mapM_ recur as SumType cs -> mapM_ recur cs RefType _ _ -> return () HeapType -> return () - _ -> notData - _ -> notData + TypeKind -> notData + DictTy _ -> notData + Pi _ -> notData where recur = checkDataLike notData = throw TypeErr $ pprint ty diff --git a/src/lib/QueryTypePure.hs b/src/lib/QueryTypePure.hs index 258bbb9b3..f21a066eb 100644 --- a/src/lib/QueryTypePure.hs +++ b/src/lib/QueryTypePure.hs @@ -17,6 +17,9 @@ class HasType (r::IR) (e::E) | e -> r where class HasEffects (e::E) (r::IR) | e -> r where getEffects :: e n -> EffectRow r n +getTyCon :: HasType SimpIR e => e n -> TyCon SimpIR n +getTyCon e = con where TyCon con = getType e + isPure :: (IRRep r, HasEffects e r) => e n -> Bool isPure e = case getEffects e of Pure -> True @@ -32,8 +35,8 @@ instance IRRep r => HasType r (AtomBinding r) where SolverBound (SkolemBound ty) -> ty SolverBound (DictBound ty) -> ty NoinlineFun ty _ -> ty - TopDataBound (RepVal ty _) -> ty - FFIFunBound piTy _ -> Pi piTy + TopDataBound e -> getType e + FFIFunBound piTy _ -> TyCon $ Pi piTy litType :: LitVal -> BaseType litType v = case v of @@ -69,66 +72,48 @@ instance IRRep r => HasType r (AtomVar r) where {-# INLINE getType #-} instance IRRep r => HasType r (Atom r) where - getType atom = case atom of - Stuck e -> getType e - Lam (CoreLamExpr piTy _) -> Pi piTy - DepPair _ _ ty -> DepPairTy ty - Con con -> getType con - Eff _ -> EffKind - PtrVar t _ -> PtrTy t - DictCon d -> getType d - NewtypeCon con _ -> getNewtypeType con - RepValAtom (RepVal ty _) -> ty - SimpInCore x -> getType x - TypeAsAtom ty -> getType ty - -instance HasType CoreIR DictCon where getType = \case - InstanceDict t _ _ -> t - IxFin t _ -> t - DataData t _ -> t + Stuck t _ -> t + Con e -> getType e -instance IRRep r => HasType r (Type r) where +instance HasType CoreIR (Dict CoreIR) where getType = \case - NewtypeTyCon con -> getType con - Pi _ -> TyKind - TabPi _ -> TyKind - DepPairTy _ -> TyKind - TC _ -> TyKind - DictTy _ -> TyKind - StuckTy e -> getType e - -instance IRRep r => HasType r (Stuck r) where + StuckDict t _ -> t + DictCon e -> getType e + +instance HasType CoreIR (DictCon CoreIR) where getType = \case - StuckVar (AtomVar _ t) -> t - StuckProject t _ _ -> t - StuckTabApp t _ _ -> t - StuckUnwrap t _ -> t - InstantiatedGiven t _ _ -> t - SuperclassProj t _ _ -> t - -instance HasType CoreIR SimpInCore where + InstanceDict t _ _ -> t + DataData t -> toType $ DataDictType t + IxFin n -> toType $ IxDictType (FinTy n) + IxRawFin _ -> toType $ IxDictType IdxRepTy + +instance HasType CoreIR CType where getType = \case - LiftSimp t _ -> t - LiftSimpFun piTy _ -> Pi $ piTy - TabLam t _ -> TabPi $ t - ACase _ _ t -> t + TyCon _ -> TyKind + StuckTy t _ -> t instance HasType CoreIR NewtypeTyCon where getType _ = TyKind getNewtypeType :: NewtypeCon n -> CType n getNewtypeType con = case con of - NatCon -> NewtypeTyCon Nat - FinCon n -> NewtypeTyCon $ Fin n - UserADTData sn d params -> NewtypeTyCon $ UserADTType sn d params + NatCon -> TyCon $ NewtypeTyCon Nat + FinCon n -> TyCon $ NewtypeTyCon $ Fin n + UserADTData sn d xs -> TyCon $ NewtypeTyCon $ UserADTType sn d xs instance IRRep r => HasType r (Con r) where getType = \case - Lit l -> BaseTy $ litType l - ProdCon xs -> ProdTy $ map getType xs - SumCon tys _ _ -> SumTy tys - HeapVal -> TC HeapType + Lit l -> toType $ BaseType $ litType l + ProdCon xs -> toType $ ProdType $ map getType xs + SumCon tys _ _ -> toType $ SumType tys + HeapVal -> toType HeapType + Lam (CoreLamExpr piTy _) -> toType $ Pi piTy + DepPair _ _ ty -> toType $ DepPairTy ty + Eff _ -> EffKind + DictConAtom d -> getType d + NewtypeCon con _ -> getNewtypeType con + TyConAtom _ -> TyKind getSuperclassType :: RNest CBinder n l -> Nest CBinder l l' -> Int -> CType n getSuperclassType _ Empty = error "bad index" @@ -150,6 +135,9 @@ instance IRRep r => HasType r (Expr r) where Project t _ _ -> t Unwrap t _ -> t +instance HasType SimpIR RepVal where + getType (RepVal ty _) = ty + instance IRRep r => HasType r (DAMOp r) where getType = \case AllocDest ty -> RawRefTy ty @@ -162,15 +150,15 @@ instance IRRep r => HasType r (DAMOp r) where instance IRRep r => HasType r (PrimOp r) where getType primOp = case primOp of - BinOp op x _ -> TC $ BaseType $ typeBinOp op $ getTypeBaseType x - UnOp op x -> TC $ BaseType $ typeUnOp op $ getTypeBaseType x + BinOp op x _ -> TyCon $ BaseType $ typeBinOp op $ getTypeBaseType x + UnOp op x -> TyCon $ BaseType $ typeUnOp op $ getTypeBaseType x Hof (TypedHof (EffTy _ ty) _) -> ty MemOp op -> getType op MiscOp op -> getType op VectorOp op -> getType op DAMOp op -> getType op RefOp ref m -> case getType ref of - TC (RefType _ s) -> case m of + TyCon (RefType _ s) -> case m of MGet -> s MPut _ -> UnitTy MAsk -> s @@ -181,7 +169,7 @@ instance IRRep r => HasType r (PrimOp r) where getTypeBaseType :: (IRRep r, HasType r e) => e n -> BaseType getTypeBaseType e = case getType e of - TC (BaseType b) -> b + TyCon (BaseType b) -> b ty -> error $ "Expected a base type. Got: " ++ show ty instance IRRep r => HasType r (MemOp r) where @@ -191,7 +179,7 @@ instance IRRep r => HasType r (MemOp r) where PtrOffset arr _ -> getType arr PtrLoad ptr -> do let PtrTy (_, t) = getType ptr - BaseTy t + toType $ BaseType t PtrStore _ _ -> UnitTy instance IRRep r => HasType r (VectorOp r) where @@ -200,7 +188,7 @@ instance IRRep r => HasType r (VectorOp r) where VectorIota vty -> vty VectorIdx _ _ vty -> vty VectorSubref ref _ vty -> case getType ref of - TC (RefType h _) -> TC $ RefType h vty + TyCon (RefType h _) -> TyCon $ RefType h vty ty -> error $ "Not a reference type: " ++ show ty instance IRRep r => HasType r (MiscOp r) where @@ -214,20 +202,20 @@ instance IRRep r => HasType r (MiscOp r) where GarbageVal t -> t SumTag _ -> TagRepTy ToEnum t _ -> t - OutputStream -> BaseTy $ hostPtrTy $ Scalar Word8Type + OutputStream -> toType $ BaseType $ hostPtrTy $ Scalar Word8Type where hostPtrTy ty = PtrType (CPU, ty) ShowAny _ -> rawStrType -- TODO: constrain `ShowAny` to have `HasCore r` - ShowScalar _ -> PairTy IdxRepTy $ rawFinTabType (IdxRepVal showStringBufferSize) CharRepTy + ShowScalar _ -> toType $ ProdType [IdxRepTy, rawFinTabType (IdxRepVal showStringBufferSize) CharRepTy] rawStrType :: IRRep r => Type r n rawStrType = case newName "n" of Abs b v -> do - let tabTy = rawFinTabType (Var $ AtomVar v IdxRepTy) CharRepTy - DepPairTy $ DepPairType ExplicitDepPair (b:>IdxRepTy) tabTy + let tabTy = rawFinTabType (toAtom $ AtomVar v IdxRepTy) CharRepTy + TyCon $ DepPairTy $ DepPairType ExplicitDepPair (b:>IdxRepTy) tabTy -- `n` argument is IdxRepVal, not Nat rawFinTabType :: IRRep r => Atom r n -> Type r n -> Type r n -rawFinTabType n eltTy = IxType IdxRepTy (IxDictRawFin n) ==> eltTy +rawFinTabType n eltTy = IxType IdxRepTy (DictCon (IxRawFin n)) ==> eltTy tabIxType :: TabPiType r n -> IxType r n tabIxType (TabPiType d (_:>t) _) = IxType t d @@ -255,21 +243,13 @@ coreLamToTopLam :: CoreLamExpr n -> TopLam CoreIR n coreLamToTopLam (CoreLamExpr ty f) = TopLam False (corePiTypeToPiType ty) f (==>) :: IRRep r => IxType r n -> Type r n -> Type r n -a ==> b = TabPi $ nonDepTabPiType a b +a ==> b = TyCon $ TabPi $ nonDepTabPiType a b -litFinIxTy :: Int -> IxType r n +litFinIxTy :: Int -> IxType SimpIR n litFinIxTy n = finIxTy $ IdxRepVal $ fromIntegral n -finIxTy :: Atom r n -> IxType r n -finIxTy n = IxType IdxRepTy (IxDictRawFin n) - -ixTyFromDict :: IRRep r => IxDict r n -> IxType r n -ixTyFromDict ixDict = flip IxType ixDict $ case ixDict of - IxDictAtom dict -> case getType dict of - DictTy (DictType "Ix" _ [Type iTy]) -> iTy - _ -> error $ "Not an Ix dict: " ++ show dict - IxDictRawFin _ -> IdxRepTy - IxDictSpecialized n _ _ -> n +finIxTy :: Atom SimpIR n -> IxType SimpIR n +finIxTy n = IxType IdxRepTy (DictCon (IxRawFin n)) -- === querying effects implementation === @@ -316,7 +296,7 @@ instance IRRep r => HasEffects (PrimOp r) r where ShowAny _ -> Pure ShowScalar _ -> Pure RefOp ref m -> case getType ref of - TC (RefType h _) -> case m of + TyCon (RefType h _) -> case m of MGet -> OneEffect (RWSEffect State h) MPut _ -> OneEffect (RWSEffect State h) MAsk -> OneEffect (RWSEffect Reader h) diff --git a/src/lib/RuntimePrint.hs b/src/lib/RuntimePrint.hs index 593a3fc9c..f073ea0c9 100644 --- a/src/lib/RuntimePrint.hs +++ b/src/lib/RuntimePrint.hs @@ -58,26 +58,29 @@ emitCharLit c = emitChar $ charRepVal c showAnyRec :: forall n. Emits n => CAtom n -> Print n showAnyRec atom = case getType atom of - -- hack to print chars nicely. TODO: make `Char` a newtype - TC t -> case t of - BaseType bt -> case bt of - Vector _ _ -> error "not implemented" - PtrType _ -> printTypeOnly "pointer" - Scalar _ -> do - (n, tab) <- fromPair =<< emitExpr (PrimOp $ MiscOp $ ShowScalar atom) - logicalTabTy <- finTabTyCore (NewtypeCon NatCon n) CharRepTy - tab' <- emitExpr $ PrimOp $ MiscOp $ UnsafeCoerce logicalTabTy tab - emitCharTab tab' - -- TODO: we could do better than this but it's not urgent because raw sum types - -- aren't user-facing. - SumType _ -> printAsConstant - RefType _ _ -> printTypeOnly "reference" - HeapType -> printAsConstant - ProdType _ -> do - xs <- getUnpacked atom - parens $ sepBy ", " $ map rec xs - -- TODO: traverse the type and print out data components - TypeKind -> printAsConstant + TyCon con -> showAnyTyCon con atom + StuckTy _ e -> error $ "unexpected stuck type expression: " ++ pprint e + +showAnyTyCon :: forall n. Emits n => TyCon CoreIR n -> CAtom n -> Print n +showAnyTyCon tyCon atom = case tyCon of + BaseType bt -> case bt of + Vector _ _ -> error "not implemented" + PtrType _ -> printTypeOnly "pointer" + Scalar _ -> do + (n, tab) <- fromPair =<< emitExpr (ShowScalar atom) + logicalTabTy <- finTabTyCore (Con $ NewtypeCon NatCon n) CharRepTy + tab' <- emitExpr $ UnsafeCoerce logicalTabTy tab + emitCharTab tab' + -- TODO: we could do better than this but it's not urgent because raw sum types + -- aren't user-facing. + SumType _ -> printAsConstant + RefType _ _ -> printTypeOnly "reference" + HeapType -> printAsConstant + ProdType _ -> do + xs <- getUnpacked atom + parens $ sepBy ", " $ map rec xs + -- TODO: traverse the type and print out data components + TypeKind -> printAsConstant Pi _ -> printTypeOnly "function" TabPi _ -> brackets $ forEachTabElt atom \iOrd x -> do isFirst <- ieq iOrd (NatVal 0) @@ -88,11 +91,11 @@ showAnyRec atom = case getType atom of Nat -> do n <- unwrapNewtype atom -- Cast to Int so that it prints in decimal instead of hex - let intTy = TC (BaseType (Scalar Int64Type)) - emitExpr (PrimOp $ MiscOp $ CastOp intTy n) >>= rec + let intTy = toType $ BaseType (Scalar Int64Type) + emitExpr (CastOp intTy n) >>= rec EffectRowKind -> printAsConstant -- hack to print strings nicely. TODO: make `Char` a newtype - UserADTType "List" _ (TyConParams [Explicit] [Type Word8Ty]) -> do + UserADTType "List" _ (TyConParams [Explicit] [Con (TyConAtom (BaseType (Scalar (Word8Type))))]) -> do charTab <- applyProjections [ProjectProduct 1, UnwrapNewtype] atom emitCharLit '"' emitCharTab charTab @@ -127,7 +130,6 @@ showAnyRec atom = case getType atom of -- Done well, this could let you inspect the results of dictionary synthesis -- and maybe even debug synthesis failures. DictTy _ -> printAsConstant - StuckTy e -> error $ "unexpected stuck type expression: " ++ pprint e where rec :: Emits n' => CAtom n' -> Print n' rec = showAnyRec @@ -161,18 +163,18 @@ withBuffer => (forall l . (Emits l, DExt n l) => CAtom l -> BuilderM CoreIR l ()) -> BuilderM CoreIR n (CAtom n) withBuffer cont = do - lam <- withFreshBinder "h" (TC HeapType) \h -> do - bufTy <- bufferTy (Var $ binderVar h) + lam <- withFreshBinder "h" (TyCon HeapType) \h -> do + bufTy <- bufferTy (toAtom $ binderVar h) withFreshBinder "buf" bufTy \b -> do - let eff = OneEffect (RWSEffect State (Var $ sink $ binderVar h)) + let eff = OneEffect (RWSEffect State (toAtom $ sink $ binderVar h)) body <- buildBlock do - cont $ sink $ Var $ binderVar b + cont $ sink $ toAtom $ binderVar b return UnitVal let binders = BinaryNest h b let expls = [Inferred Nothing Unify, Explicit] let piTy = CorePiType ExplicitApp expls binders $ EffTy eff UnitTy let lam = LamExpr (BinaryNest h b) body - return $ Lam $ CoreLamExpr piTy lam + return $ toAtom $ CoreLamExpr piTy lam applyPreludeFunction "with_stack_internal" [lam] bufferTy :: EnvReader m => CAtom n -> m n (CType n) @@ -184,7 +186,7 @@ bufferTy h = do extendBuffer :: (Emits n, CBuilder m) => CAtom n -> CAtom n -> m n () extendBuffer buf tab = do RefTy h _ <- return $ getType buf - TabPi t <- return $ getType tab + TyCon (TabPi t) <- return $ getType tab n <- applyIxMethodCore Size (tabIxType t) [] void $ applyPreludeFunction "stack_extend_internal" [n, h, buf, tab] @@ -200,15 +202,13 @@ stringLitAsCharTab s = do emitExpr $ TabCon Nothing t (map charRepVal s) finTabTyCore :: (Fallible1 m, EnvReader m) => CAtom n -> CType n -> m n (CType n) -finTabTyCore n eltTy = do - d <- DictCon <$> mkIxFin n - return $ IxType (FinTy n) (IxDictAtom d) ==> eltTy +finTabTyCore n eltTy = return $ IxType (FinTy n) (DictCon $ IxFin n) ==> eltTy getPreludeFunction :: EnvReader m => String -> m n (CAtom n) getPreludeFunction sourceName = do lookupSourceMap sourceName >>= \case Just uvar -> case uvar of - UAtomVar v -> Var <$> toAtomVar v + UAtomVar v -> toAtom <$> toAtomVar v _ -> notfound Nothing -> notfound where notfound = error $ "Function not defined: " ++ sourceName @@ -218,14 +218,14 @@ applyPreludeFunction name args = do f <- getPreludeFunction name naryApp f args -strType :: EnvReader m => m n (CType n) -strType = constructPreludeType "List" $ TyConParams [Explicit] [Type CharRepTy] +strType :: forall n m. EnvReader m => m n (CType n) +strType = constructPreludeType "List" $ TyConParams [Explicit] [toAtom (CharRepTy :: CType n)] constructPreludeType :: EnvReader m => String -> TyConParams n -> m n (CType n) constructPreludeType sourceName params = do lookupSourceMap sourceName >>= \case Just uvar -> case uvar of - UTyConVar v -> return $ TypeCon sourceName v params + UTyConVar v -> return $ toType $ UserADTType sourceName v params _ -> notfound Nothing -> notfound where notfound = error $ "Type constructor not defined: " ++ sourceName @@ -236,10 +236,10 @@ forEachTabElt -> (forall l. (Emits l, DExt n l) => CAtom l -> CAtom l -> m l ()) -> m n () forEachTabElt tab cont = do - TabPi t <- return $ getType tab + TyCon (TabPi t) <- return $ getType tab let ixTy = tabIxType t void $ buildFor "i" Fwd ixTy \i -> do - x <- tabApp (sink tab) (Var i) - i' <- applyIxMethodCore Ordinal (sink ixTy) [Var i] + x <- tabApp (sink tab) (toAtom i) + i' <- applyIxMethodCore Ordinal (sink ixTy) [toAtom i] cont i' x return $ UnitVal diff --git a/src/lib/Simplify.hs b/src/lib/Simplify.hs index 0f71998f4..e98d977fa 100644 --- a/src/lib/Simplify.hs +++ b/src/lib/Simplify.hs @@ -10,7 +10,6 @@ module Simplify ( simplifyTopBlock, simplifyTopFunction, ReconstructAtom (..), applyReconTop, linearizeTopFun, SimplifiedTopLam (..)) where -import Control.Applicative import Control.Category ((>>>)) import Control.Monad import Control.Monad.Reader @@ -68,211 +67,206 @@ tryAsDataAtom atom = do isData ty >>= \case False -> return Nothing True -> Just <$> do - repAtom <- go atom + repAtom <- dropSubst $ toDataAtom atom return (repAtom, ty) - where - go :: Emits n => CAtom n -> SimplifyM i n (SAtom n) - go = \case - Stuck e -> case e of - StuckVar v -> lookupAtomName (atomVarName v) >>= \case - LetBound (DeclBinding _ (Atom x)) -> go x - _ -> error "Shouldn't have irreducible top names left" - StuckUnwrap _ x -> go (Stuck x) - -- TODO: do we need to think about a case like `fst (1, \x.x)`, where - -- the projection is data but the argument isn't? - StuckProject _ i x -> reduceProj i =<< go (Stuck x) - _ -> notData - Con con -> Con <$> case con of - Lit v -> return $ Lit v - ProdCon xs -> ProdCon <$> mapM go xs - SumCon tys tag x -> SumCon <$> mapM getRepType tys <*> pure tag <*> go x - HeapVal -> return HeapVal - PtrVar t v -> return $ PtrVar t v - DepPair x y ty -> do - DepPairTy ty' <- getRepType $ DepPairTy ty - DepPair <$> go x <*> go y <*> pure ty' - NewtypeCon _ x -> go x - SimpInCore x -> case x of - LiftSimp _ x' -> return x' - LiftSimpFun _ _ -> notData - TabLam _ tabLam -> forceTabLam tabLam - ACase scrut alts resultTy -> forceACase scrut alts resultTy - Lam _ -> notData - DictCon _ -> notData - Eff _ -> notData - TypeAsAtom _ -> notData - where - notData = error $ "Not runtime-representable data: " ++ pprint atom data WithSubst (e::E) (o::S) where WithSubst :: Subst AtomSubstVal i o -> e i -> WithSubst e o +type ACase = SStuck `PairE` ListE (Abs SBinder CAtom) `PairE` CType + data ConcreteCAtom (n::S) = - CCCon (WithSubst CAtom n) -- can't be Stuck or SimpInCore - | CCSimpInCore (SimpInCore n) -- can't be ACase + CCCon (WithSubst (Con CoreIR) n) + | CCLiftSimp (CType n) (Stuck SimpIR n) + | CCFun (ConcreteCFun n) + | CCTabLam (WithSubst TabLamExpr n) + | CCACase (WithSubst ACase n) + +data ConcreteCFun (n::S) = + CCLiftSimpFun (CorePiType n) (LamExpr SimpIR n) | CCNoInlineFun (CAtomVar n) (CType n) (CAtom n) | CCFFIFun (CorePiType n) (TopFunName n) --- Yields to the continuation a term with a concrete CoreIR constructor, --- or LiftSimpFun, liftSimp, or TabLam. -forceConstructor - :: Emits o - => CAtom i - -> (forall o' i'. (DExt o o', Emits o') => ConcreteCAtom o'-> SimplifyM i' o' (CAtom o')) - -> SimplifyM i o (CAtom o) -forceConstructor atom cont = withDistinct case atom of - Stuck stuck -> forceStuck stuck cont - SimpInCore lifted -> case lifted of - ACase e alts resultTy -> do - e' <- substM e - resultTy' <- substM resultTy - defuncCase e' resultTy' \i x -> do - Abs b body <- return $ alts !! i - extendSubst (b@>SubstVal x) do - forceConstructor body cont - _ -> do - lifted' <- substM lifted - cont $ CCSimpInCore lifted' - _ -> do - Distinct <- getDistinct +forceConstructor :: CAtom i -> SimplifyM i o (ConcreteCAtom o) +forceConstructor atom = withDistinct case atom of + Stuck _ stuck -> forceStuck stuck + Con con -> do subst <- getSubst - cont $ CCCon $ WithSubst subst atom + return $ CCCon $ WithSubst subst con -forceStuck - :: Emits o - => CStuck i - -> (forall o' i'. (DExt o o', Emits o') => ConcreteCAtom o'-> SimplifyM i' o' (CAtom o')) - -> SimplifyM i o (CAtom o) -forceStuck stuck cont = withDistinct case stuck of - StuckVar v -> lookupSubstM (atomVarName v) >>= \case - SubstVal x -> dropSubst $ forceConstructor x cont +forceStuck :: forall i o . CStuck i -> SimplifyM i o (ConcreteCAtom o) +forceStuck stuck = withDistinct case stuck of + Var v -> lookupSubstM (atomVarName v) >>= \case + SubstVal x -> dropSubst $ forceConstructor x Rename v' -> lookupAtomName v' >>= \case - LetBound (DeclBinding _ (Atom x)) -> dropSubst $ forceConstructor x cont + LetBound (DeclBinding _ (Atom x)) -> dropSubst $ forceConstructor x NoinlineFun t f -> do v'' <- toAtomVar v' - cont $ CCNoInlineFun v'' t f - FFIFunBound t f -> cont $ CCFFIFun t f + return $ CCFun $ CCNoInlineFun v'' t f + FFIFunBound t f -> return $ CCFun $ CCFFIFun t f _ -> error "shouldn't have other CVars left" - -- TODO: figure out how to de-dup these cases with their Expr counterpart - StuckProject _ i x -> do - ty <- substM $ getType stuck - forceStuck x \case - CCSimpInCore (LiftSimp _ x') -> do - x'' <- proj i x' - cont $ CCSimpInCore $ LiftSimp (sink ty) x'' - CCCon (WithSubst s con) -> withSubst s case con of - ProdVal xs -> forceConstructor (xs!!i) cont - DepPair l r _ -> forceConstructor ([l, r]!!i) cont - _ -> error "Can't project stuck term" - _ -> error "Can't project stuck term" - StuckTabApp _ f xs -> do - ty <- substM $ getType stuck - xs' <- forM xs \x -> toDataAtomIgnoreRecon =<< substM x - forceStuck f \case - CCSimpInCore (LiftSimp _ f') -> do - result <- naryTabApp f' (sink<$>xs') - cont $ CCSimpInCore $ LiftSimp (sink ty) result - _ -> error "not a table" -- what about table lambda? - StuckUnwrap _ x -> forceStuck x \case + LiftSimp _ x -> do + -- the subst should be rename-only for `x`. We should make subst IR-specific + s <- getSubst + let s' = newSubst \v -> case s ! v of + SubstVal _ -> error "subst should be rename-only for SimpIR vars" -- TODO: make subst IR-specific + Rename v' -> v' + x' <- runSubstReaderT s' $ renameM x + returnLifted x' + -- We "thunk" ACase rather than forcing it because different use-cases require different ways to force it + ACase e alts resultTy -> do + subst <- getSubst + return $ CCACase $ WithSubst subst $ e `PairE` ListE alts `PairE` resultTy + TabLam e -> do + subst <- getSubst + return $ CCTabLam $ WithSubst subst e + StuckProject i x -> forceStuck x >>= \case + CCLiftSimp _ x' -> returnLifted $ StuckProject i x' CCCon (WithSubst s con) -> withSubst s case con of - NewtypeCon _ x' -> forceConstructor x' cont - _ -> error "can't unwrap stuck term" - _ -> error "can't unwrap stuck term" - InstantiatedGiven _ _ _ -> error "shouldn't have this left" - SuperclassProj _ _ _ -> error "shouldn't have this left" - -forceTabLam :: Emits n => TabLamExpr n -> SimplifyM i n (SAtom n) -forceTabLam (PairE ixTy (Abs b ab)) = - buildFor (getNameHint b) Fwd ixTy \v -> do - result <- applyRename (b@>(atomVarName v)) ab >>= emitDecls - toDataAtomIgnoreRecon result - -type NaryTabLamExpr = Abs (Nest SBinder) (Abs (Nest SDecl) CAtom) - -fromNaryTabLam :: Int -> CAtom n -> Maybe (Int, NaryTabLamExpr n) -fromNaryTabLam maxDepth | maxDepth <= 0 = error "expected positive number of args" -fromNaryTabLam maxDepth = \case - SimpInCore (TabLam _ (PairE _ (Abs b body))) -> - extend <|> (Just $ (1, Abs (Nest b Empty) body)) - where - extend = case body of - Abs Empty lam | maxDepth > 1 -> do - (d, Abs (Nest b2 bs2) body2) <- fromNaryTabLam (maxDepth - 1) lam - return $ (d + 1, Abs (Nest b (Nest b2 bs2)) body2) - _ -> Nothing - _ -> Nothing - -forceACase :: Emits n => SAtom n -> [Abs SBinder CAtom n] -> CType n -> SimplifyM i n (SAtom n) -forceACase scrut alts resultTy = do - resultTy' <- getRepType resultTy - buildCase scrut resultTy' \i arg -> do - Abs b result <- return $ alts !! i - applySubst (b@>SubstVal arg) result >>= toDataAtomIgnoreRecon + ProdCon xs -> forceConstructor (xs!!i) + DepPair l r _ -> forceConstructor ([l, r]!!i) + _ -> error "not a product" + CCACase x' -> pushUnderACase x' \x'' -> reduceProj i x'' + CCFun _ -> error "not a product" + CCTabLam _ -> error "not a product" + StuckTabApp f x -> forceStuck f >>= \case + CCLiftSimp _ f' -> do + x' <- toDataAtom x + returnLifted $ StuckTabApp f' x' + CCTabLam (WithSubst s (PairE _ (Abs b body))) -> do + x' <- toDataAtom x + result <- withSubst s $ extendSubst (b@>SubstVal x') $ substM body + dropSubst $ forceConstructor result + CCACase f' -> pushUnderACase f' \f'' -> reduceTabApp f'' =<< substM x + CCCon _ -> error "not a table" + CCFun _ -> error "not a table" + StuckUnwrap x -> forceStuck x >>= \case + CCCon (WithSubst s con) -> case con of + NewtypeCon _ x' -> withSubst s $ forceConstructor x' + _ -> error "not a newtype" + CCLiftSimp _ x' -> returnLifted x' + CCACase x' -> pushUnderACase x' \x'' -> reduceUnwrap x'' + CCFun _ -> error "not a newtype" + CCTabLam _ -> error "not a newtype" + InstantiatedGiven _ _ -> error "shouldn't have this left" + SuperclassProj _ _ -> error "shouldn't have this left" + PtrVar ty p -> do + p' <- substM p + returnLifted $ PtrVar ty p' + LiftSimpFun t f -> CCFun <$> (CCLiftSimpFun <$> substM t <*> substM f) + where + returnLifted :: SStuck o -> SimplifyM i o (ConcreteCAtom o) + returnLifted s = do + resultTy <- getType <$> substMStuck stuck + return $ CCLiftSimp resultTy s + +pushUnderACase + :: WithSubst ACase o + -> (forall o'. DExt o o' => CAtom o' -> SimplifyM i o' (CAtom o')) + -> SimplifyM i o (ConcreteCAtom o) +pushUnderACase _ _ = undefined +-- pushUnderACase (WithSubst s (scrut `PairE` ListE alts `PairE` resultTy)) cont = undefined +-- TODO: make a buildACase to use here and elsewhere in Simplify. Maybe in CheapReduce too? + + +forceACase + :: Emits o => WithSubst ACase o + -> (forall o'. (Emits o', DExt o o') => ConcreteCAtom o' -> SimplifyM i o' (CAtom o')) + -> SimplifyM i o (CAtom o) +forceACase (WithSubst subst (scrut `PairE` ListE alts `PairE` resultTy)) cont = do + resultTy' <- withSubst subst $ substM resultTy + scrut' <- withSubst subst $ substMStuck scrut + defuncCase scrut' resultTy' \i x -> do + Abs b body <- return $ alts !! i + body' <- withSubst (sink subst) $ extendSubst (b@>SubstVal x) $ forceConstructor body + cont body' tryGetRepType :: Type CoreIR n -> SimplifyM i n (Maybe (SType n)) tryGetRepType t = isData t >>= \case False -> return Nothing - True -> Just <$> getRepType t - -getRepType :: Type CoreIR n -> SimplifyM i n (SType n) -getRepType ty = go ty where - go :: Type CoreIR n -> SimplifyM i n (SType n) - go = \case - TC con -> TC <$> case con of - BaseType b -> return $ BaseType b - ProdType ts -> ProdType <$> mapM go ts - SumType ts -> SumType <$> mapM go ts - RefType h a -> RefType <$> toDataAtomIgnoreReconAssumeNoDecls h <*> go a - TypeKind -> error $ notDataType - HeapType -> return $ HeapType - DepPairTy (DepPairType expl b@(_:>l) r) -> do - l' <- go l - withFreshBinder (getNameHint b) l' \b' -> do - x <- liftSimpAtom (sink l) (Var $ binderVar b') - r' <- go =<< applySubst (b@>SubstVal x) r - return $ DepPairTy $ DepPairType expl b' r' - TabPi tabTy -> do - let ixTy = tabIxType tabTy - IxType t' d' <- simplifyIxType ixTy - withFreshBinder (getNameHint tabTy) t' \b' -> do - x <- liftSimpAtom (sink $ ixTypeType ixTy) (Var $ binderVar b') - bodyTy' <- go =<< instantiate (sink tabTy) [x] - return $ TabPi $ TabPiType d' b' bodyTy' - NewtypeTyCon con -> do - (_, ty') <- unwrapNewtypeType con - go ty' - Pi _ -> error notDataType - DictTy _ -> error notDataType - StuckTy _ -> error "Shouldn't have stuck expressions in CoreIR IR with SimpIR builder names" - where notDataType = "Not a type of runtime-representable data: " ++ pprint ty - -toDataAtom :: Emits n => CAtom n -> SimplifyM i n (SAtom n, Type CoreIR n) -toDataAtom x = tryAsDataAtom x >>= \case - Just x' -> return x' - Nothing -> error $ "Not a data atom: " ++ pprint x - -simplifyDataAtom :: Emits o => CAtom i -> SimplifyM i o (SAtom o) -simplifyDataAtom x = toDataAtomIgnoreRecon =<< simplifyAtom x - -toDataAtomIgnoreRecon :: Emits n => CAtom n -> SimplifyM i n (SAtom n) -toDataAtomIgnoreRecon x = fst <$> toDataAtom x - -toDataAtomIgnoreReconAssumeNoDecls :: CAtom n -> SimplifyM i n (SAtom n) -toDataAtomIgnoreReconAssumeNoDecls x = do - Abs decls result <- buildScoped $ fst <$> toDataAtom (sink x) + True -> Just <$> dropSubst (getRepType t) + +getRepType :: Type CoreIR i -> SimplifyM i o (SType o) +getRepType (StuckTy _ stuck) = + substMStuck stuck >>= \case + Stuck _ _ -> error "shouldn't have stuck CType after substitution" + Con (TyConAtom tyCon) -> dropSubst $ getRepType (TyCon tyCon) + Con _ -> error "not a type" +getRepType (TyCon con) = case con of + BaseType b -> return $ toType $ BaseType b + ProdType ts -> toType . ProdType <$> mapM getRepType ts + SumType ts -> toType . SumType <$> mapM getRepType ts + RefType h a -> toType <$> (RefType <$> toDataAtomAssumeNoDecls h <*> getRepType a) + HeapType -> return $ toType HeapType + DepPairTy (DepPairType expl b r) -> do + withSimplifiedBinder b \b' -> do + r' <- getRepType r + return $ toType $ DepPairType expl b' r' + TabPi (TabPiType ixDict b r) -> do + ixDict' <- simplifyIxDict ixDict + withSimplifiedBinder b \b' -> do + r' <- getRepType r + return $ toType $ TabPi $ TabPiType ixDict' b' r' + NewtypeTyCon con' -> do + (_, ty') <- unwrapNewtypeType =<< substM con' + dropSubst $ getRepType ty' + Pi _ -> error notDataType + DictTy _ -> error notDataType + TypeKind -> error notDataType + where notDataType = "Not a type of runtime-representable data" + +toDataAtom :: CAtom i -> SimplifyM i o (SAtom o) +toDataAtom (Con con) = case con of + Lit v -> return $ toAtom $ Lit v + ProdCon xs -> toAtom . ProdCon <$> mapM rec xs + SumCon tys tag x -> toAtom <$> (SumCon <$> mapM getRepType tys <*> pure tag <*> rec x) + HeapVal -> return $ toAtom HeapVal + DepPair x y ty -> do + TyCon (DepPairTy ty') <- getRepType $ TyCon $ DepPairTy ty + toAtom <$> (DepPair <$> rec x <*> rec y <*> pure ty') + NewtypeCon _ x -> rec x + Lam _ -> notData + DictConAtom _ -> notData + Eff _ -> notData + TyConAtom _ -> notData + where + rec = toDataAtom + notData = error $ "Not runtime-representable data" +toDataAtom (Stuck _ stuck) = forceStuck stuck >>= \case + CCCon (WithSubst s con) -> withSubst s $ toDataAtom (Con con) + CCLiftSimp _ e -> mkStuck e + CCFun _ -> notData + CCACase _ -> notData -- TODO: make sure we observe this invariant" + CCTabLam _ -> notData -- TODO: make sure we observe this invariant" + where notData = error $ "Not runtime-representable data" + +toDataAtomAssumeNoDecls :: CAtom i -> SimplifyM i o (SAtom o) +toDataAtomAssumeNoDecls x = do + Abs decls result <- buildScoped $ toDataAtom x case decls of Empty -> return result _ -> error "unexpected decls" +withSimplifiedBinder + :: CBinder i i' + -> (forall o'. DExt o o' => Binder SimpIR o o' -> SimplifyM i' o' a) + -> SimplifyM i o a +withSimplifiedBinder (b:>ty) cont = do + tySimp <- getRepType ty + tyCore <- substM ty + withFreshBinder (getNameHint b) tySimp \b' -> do + x <- liftSimpAtom (sink tyCore) (toAtom $ binderVar b') + extendSubst (b@>SubstVal x) $ cont b' + withSimplifiedBinders :: Nest (Binder CoreIR) o any -> (forall o'. DExt o o' => Nest (Binder SimpIR) o o' -> [CAtom o'] -> SimplifyM i o' a) -> SimplifyM i o a withSimplifiedBinders Empty cont = getDistinct >>= \Distinct -> cont Empty [] withSimplifiedBinders (Nest (bCore:>ty) bsCore) cont = do - simpTy <- getRepType ty + simpTy <- dropSubst $ getRepType ty withFreshBinder (getNameHint bCore) simpTy \bSimp -> do - x <- liftSimpAtom (sink ty) (Var $ binderVar bSimp) + x <- liftSimpAtom (sink ty) (toAtom $ binderVar bSimp) -- TODO: carry a substitution instead of doing N^2 work like this Abs bsCore' UnitE <- applySubst (bCore@>SubstVal x) (EmptyAbs bsCore) withSimplifiedBinders bsCore' \bsSimp xs -> @@ -388,22 +382,24 @@ simplifyExpr expr = confuseGHC >>= \_ -> case expr of Block _ (Abs decls body) -> simplifyDecls decls $ simplifyExpr body App (EffTy _ ty) f xs -> do ty' <- substM ty + f' <- forceConstructor f xs' <- mapM simplifyAtom xs - simplifyApp ty' f xs' - TabApp _ f xs -> do - xs' <- mapM simplifyAtom xs - simplifyTabApp f xs' + simplifyApp ty' f' xs' + TabApp _ f x -> withDistinct do + x' <- simplifyAtom x + f' <- forceConstructor f + simplifyTabApp f' x' Atom x -> simplifyAtom x PrimOp op -> simplifyOp op ApplyMethod (EffTy _ ty) dict i xs -> do ty' <- substM ty xs' <- mapM simplifyAtom xs - dict' <- simplifyAtom dict + Just dict' <- toMaybeDict <$> simplifyAtom dict applyDictMethod ty' dict' i xs' TabCon _ ty xs -> do ty' <- substM ty - tySimp <- getRepType ty' - xs' <- forM xs \x -> simplifyDataAtom x + tySimp <- getRepType ty + xs' <- forM xs \x -> toDataAtom x liftSimpAtom ty' =<< emitExpr (TabCon Nothing tySimp xs') Case scrut alts (EffTy _ resultTy) -> do scrut' <- simplifyAtom scrut @@ -427,17 +423,17 @@ requireReduced expr = reduceExpr expr >>= \case simplifyRefOp :: Emits o => RefOp CoreIR i -> SAtom o -> SimplifyM i o (SAtom o) simplifyRefOp op ref = case op of MExtend (BaseMonoid em cb) x -> do - em' <- simplifyDataAtom em - x' <- simplifyDataAtom x + em' <- toDataAtom em + x' <- toDataAtom x (cb', CoerceReconAbs) <- simplifyLam cb emitRefOp $ MExtend (BaseMonoid em' cb') x' MGet -> emitExpr $ RefOp ref MGet MPut x -> do - x' <- simplifyDataAtom x + x' <- toDataAtom x emitRefOp $ MPut x' MAsk -> emitRefOp MAsk IndexRef _ x -> do - x' <- simplifyDataAtom x + x' <- toDataAtom x emitExpr =<< mkIndexRef ref x' ProjRef _ (ProjectProduct i) -> emitExpr =<< mkProjRef ref (ProjectProduct i) ProjRef _ UnwrapNewtype -> return ref @@ -455,50 +451,41 @@ defuncCaseCore scrut resultTy cont = do let xCoreTy = altBinderTys !! i x' <- liftSimpAtom (sink xCoreTy) x cont i x' - -- TODO: we should use forceConstructor here - Nothing -> case trySelectBranch scrut of - Just (i, arg) -> getDistinct >>= \Distinct -> cont i arg - Nothing -> go scrut where - go = \case - SimpInCore (ACase scrutSimp alts _) -> do - defuncCase scrutSimp resultTy \i x -> do - Abs altb altAtom <- return $ alts !! i - altAtom' <- applySubst (altb @> SubstVal x) altAtom - cont i altAtom' - NewtypeCon con scrut' | isSumCon con -> go scrut' - _ -> nope - nope = error $ "Don't know how to scrutinize non-data " ++ pprint scrut + Nothing -> case scrut of + Con (SumCon _ i arg) -> getDistinct >>= \Distinct -> cont i arg + _ -> error $ "Don't know how to scrutinize non-data " ++ pprint scrut defuncCase :: Emits o => Atom SimpIR o -> Type CoreIR o -> (forall o'. (Emits o', DExt o o') => Int -> SAtom o' -> SimplifyM i o' (CAtom o')) -> SimplifyM i o (CAtom o) defuncCase scrut resultTy cont = do - case trySelectBranch scrut of - Just (i, arg) -> getDistinct >>= \Distinct -> cont i arg - Nothing -> do - scrutTy <- return $ getType scrut - altBinderTys <- caseAltsBinderTys scrutTy + case scrut of + Con (SumCon _ i arg) -> getDistinct >>= \Distinct -> cont i arg + Con _ -> error "scrutinee must be a sum type" + Stuck _ _ -> do + altBinderTys <- caseAltsBinderTys (getType scrut) tryGetRepType resultTy >>= \case Just resultTyData -> do alts' <- forM (enumerate altBinderTys) \(i, bTy) -> do - buildAbs noHint bTy \x -> do - buildBlock $ cont i (sink $ Var x) >>= toDataAtomIgnoreRecon + buildAbs noHint bTy \x -> buildBlock do + ans <- cont i (toAtom $ sink x) + dropSubst $ toDataAtom ans caseExpr <- mkCase scrut resultTyData alts' emitExpr caseExpr >>= liftSimpAtom resultTy Nothing -> do split <- splitDataComponents resultTy (alts', closureTys, recons) <- unzip3 <$> forM (enumerate altBinderTys) \(i, bTy) -> do simplifyAlt split bTy $ cont i - let closureSumTy = SumTy closureTys + let closureSumTy = TyCon $ SumType closureTys let newNonDataTy = nonDataTy split alts'' <- forM (enumerate alts') \(i, alt) -> injectAltResult closureTys i alt caseExpr <- mkCase scrut (PairTy (dataTy split) closureSumTy) alts'' caseResult <- emitExpr $ caseExpr (dataVal, sumVal) <- fromPair caseResult reconAlts <- forM (zip closureTys recons) \(ty, recon) -> - buildAbs noHint ty \v -> applyRecon (sink recon) (Var v) - let nonDataVal = SimpInCore $ ACase sumVal reconAlts newNonDataTy + buildAbs noHint ty \v -> applyRecon (sink recon) (toAtom v) + nonDataVal <- mkACase sumVal reconAlts newNonDataTy Distinct <- getDistinct fromSplit split dataVal nonDataVal @@ -509,7 +496,7 @@ simplifyAlt -> SimplifyM i o (Alt SimpIR o, SType o, ReconstructAtom o) simplifyAlt split ty cont = do withFreshBinder noHint ty \b -> do - ab <- buildScoped $ cont $ sink $ Var $ binderVar b + ab <- buildScoped $ cont $ sink $ toAtom $ binderVar b (body, recon) <- refreshAbs ab \decls result -> do let locals = toScopeFrag b >>> toScopeFrag decls -- TODO: this might be too cautious. The type only needs to @@ -524,34 +511,34 @@ simplifyAlt split ty cont = do let nonDataType' = ignoreHoistFailure $ hoist b nonDataType return (Abs b body', nonDataType', recon) -simplifyApp :: forall i o. Emits o - => CType o -> CAtom i -> [CAtom o] -> SimplifyM i o (CAtom o) -simplifyApp resultTy f xs = forceConstructor f \f' -> do - xs' <- mapM sinkM xs - case f' of - CCCon (WithSubst s (Lam (CoreLamExpr _ lam))) -> - withSubst s $ withInstantiated lam xs' \body -> - simplifyExpr body - CCSimpInCore (LiftSimpFun _ lam) -> do - xs'' <- mapM toDataAtomIgnoreRecon xs' - result <- instantiate lam xs'' >>= emitExpr - liftSimpAtom (sink resultTy) result - CCNoInlineFun v _ _ -> simplifyTopFunApp v xs' - CCFFIFun _ f'' -> do - xs'' <- mapM toDataAtomIgnoreRecon xs' - liftSimpAtom (sink resultTy) =<< naryTopApp f'' xs'' +simplifyApp :: Emits o => CType o -> ConcreteCAtom o -> [CAtom o] -> SimplifyM i o (CAtom o) +simplifyApp resultTy f xs = case f of + CCCon (WithSubst s con) -> case con of + Lam (CoreLamExpr _ lam) -> withSubst s $ withInstantiated lam xs \body -> simplifyExpr body _ -> error "not a function" + CCFun ccFun -> case ccFun of + CCLiftSimpFun _ lam -> do + xs' <- dropSubst $ mapM toDataAtom xs + result <- instantiate lam xs' >>= emitExpr + liftSimpAtom resultTy result + CCNoInlineFun v _ _ -> simplifyTopFunApp v xs + CCFFIFun _ f' -> do + xs' <- dropSubst $ mapM toDataAtom xs + liftSimpAtom resultTy =<< naryTopApp f' xs' + CCACase aCase -> forceACase aCase \f' -> simplifyApp (sink resultTy) f' (sink <$> xs) + CCTabLam _ -> error "not a function" + CCLiftSimp _ _ -> error "not a function" simplifyTopFunApp :: Emits n => CAtomVar n -> [CAtom n] -> SimplifyM i n (CAtom n) simplifyTopFunApp fName xs = do - fTy@(Pi piTy) <- return $ getType fName + fTy@(TyCon (Pi piTy)) <- return $ getType fName resultTy <- typeOfApp fTy xs isData resultTy >>= \case True -> do (xsGeneralized, runtimeArgs) <- generalizeArgs piTy xs let spec = AppSpecialization fName xsGeneralized Just specializedFunction <- getSpecializedFunction spec >>= emitHoistedEnv - runtimeArgs' <- mapM toDataAtomIgnoreRecon runtimeArgs + runtimeArgs' <- dropSubst $ mapM toDataAtom runtimeArgs liftSimpAtom resultTy =<< naryTopApp specializedFunction runtimeArgs' False -> -- TODO: we should probably just fall back to inlining in this case, @@ -583,40 +570,38 @@ specializedFunCoreDefinition (AppSpecialization f (Abs bs staticArgs)) = do ListE staticArgs' <- applyRename (bs@@>(atomVarName <$> runtimeArgs)) staticArgs naryApp f' staticArgs' -simplifyTabApp :: forall i o. Emits o - => CAtom i -> [CAtom o] -> SimplifyM i o (CAtom o) -simplifyTabApp f [] = simplifyAtom f -simplifyTabApp f xs = forceConstructor f \case - CCSimpInCore sic@(TabLam _ _) -> do - case fromNaryTabLam (length xs) (SimpInCore sic) of - Just (bsCount, ab) -> do - (xsPref, xsRest) <- splitAt bsCount <$> mapM sinkM xs - xsPref' <- mapM toDataAtomIgnoreRecon xsPref - block' <- instantiate ab xsPref' - atom <- emitDecls block' - dropSubst $ simplifyTabApp atom xsRest - Nothing -> error "should never happen" - CCSimpInCore (LiftSimp fTy f') -> do - resultTy <- typeOfTabApp fTy (sink<$>xs) - xs' <- mapM (toDataAtomIgnoreRecon . sink) xs - liftSimpAtom resultTy =<< naryTabApp f' xs' +simplifyTabApp ::Emits o => ConcreteCAtom o -> CAtom o -> SimplifyM i o (CAtom o) +simplifyTabApp f x = case f of + CCLiftSimp fTy f' -> do + f'' <- mkStuck f' + resultTy <- typeOfTabApp fTy x + x' <- dropSubst $ toDataAtom x + liftSimpAtom resultTy =<< tabApp f'' x' + CCACase aCase -> forceACase aCase \f' -> simplifyTabApp f' (sink x) + CCTabLam (WithSubst s (PairE _ (Abs b ab))) -> do + x' <- dropSubst $ toDataAtom x + withSubst s $ extendSubst (b@>(SubstVal x')) $ substM ab _ -> error "not a table" -simplifyIxType :: IxType CoreIR o -> SimplifyM i o (IxType SimpIR o) -simplifyIxType (IxType t ixDict) = do - t' <- getRepType t - IxType t' <$> case ixDict of - IxDictAtom (DictCon (IxFin _ n)) -> do - n' <- toDataAtomIgnoreReconAssumeNoDecls n - return $ IxDictRawFin n' - IxDictAtom d -> do - (dictAbs, params) <- generalizeIxDict d - params' <- mapM toDataAtomIgnoreReconAssumeNoDecls params - sdName <- requireIxDictCache dictAbs - return $ IxDictSpecialized t' sdName params' - IxDictRawFin n -> do - n' <- toDataAtomIgnoreReconAssumeNoDecls n - return $ IxDictRawFin n' +simplifyIxDict :: Dict CoreIR i -> SimplifyM i o (SDict o) +simplifyIxDict (StuckDict _ stuck) = forceStuck stuck >>= \case + CCCon (WithSubst s con) -> case con of + DictConAtom con' -> withSubst s $ simplifyIxDict (DictCon con') + _ -> error "not a dict" + CCLiftSimp _ _ -> error "not a dict" + CCFun _ -> error "not a dict" + CCTabLam _ -> error "not a dict" + CCACase _ -> error "not implemented" -- TODO: consider what to do about this +simplifyIxDict (DictCon con) = case con of + IxFin n -> DictCon <$> IxRawFin <$> toDataAtomAssumeNoDecls n + IxRawFin n -> DictCon <$> IxRawFin <$> toDataAtomAssumeNoDecls n + InstanceDict _ _ _ -> do + d <- DictCon <$> substM con + (dictAbs, params) <- generalizeIxDict d + params' <- dropSubst $ mapM toDataAtomAssumeNoDecls params + sdName <- requireIxDictCache dictAbs + return $ DictCon $ IxSpecialized sdName params' + DataData _ -> error "not an Ix dict" requireIxDictCache :: (HoistingTopBuilder TopEnvFrag m) => AbsDict n -> m n (Name SpecializedDictNameC n) @@ -642,7 +627,7 @@ simplifyDictMethod absDict@(Abs bs dict) method = do lamExpr <- liftBuilder $ buildTopLamFromPi ty \allArgs -> do let (extraArgs, methodArgs) = splitAt (nestLength bs) allArgs dict' <- applyRename (bs @@> (atomVarName <$> extraArgs)) dict - emitExpr =<< mkApplyMethod dict' (fromEnum method) (Var <$> methodArgs) + emitExpr =<< mkApplyMethod dict' (fromEnum method) (toAtom <$> methodArgs) simplifyTopFunction lamExpr ixMethodType :: IxMethod -> AbsDict n -> EnvReaderM n (PiType CoreIR n) @@ -660,14 +645,9 @@ simplifyLam :: LamExpr CoreIR i -> SimplifyM i o (LamExpr SimpIR o, Abs (Nest (AtomNameBinder SimpIR)) ReconstructAtom o) simplifyLam (LamExpr bsTop body) = case bsTop of - Nest (b:>ty) bs -> do - ty' <- substM ty - tySimp <- getRepType ty' - withFreshBinder (getNameHint b) tySimp \b''@(b':>_) -> do - x <- liftSimpAtom (sink ty') (Var $ binderVar b'') - extendSubst (b@>SubstVal x) do - (LamExpr bs' body', Abs bsRecon recon) <- simplifyLam $ LamExpr bs body - return (LamExpr (Nest (b':>tySimp) bs') body', Abs (Nest b' bsRecon) recon) + Nest b bs -> withSimplifiedBinder b \b'@(b'':>_) -> do + (LamExpr bs' body', Abs bsRecon recon) <- simplifyLam $ LamExpr bs body + return (LamExpr (Nest b' bs') body', Abs (Nest b'' bsRecon) recon) Empty -> do SimplifiedBlock body' recon <- buildSimplifiedBlock $ simplifyExpr body return (LamExpr Empty body', Abs Empty recon) @@ -681,25 +661,25 @@ data SplitDataNonData n = SplitDataNonData -- bijection between that type and a (data, non-data) pair type. splitDataComponents :: Type CoreIR n -> SimplifyM i n (SplitDataNonData n) splitDataComponents = \case - ProdTy tys -> do + TyCon (ProdType tys) -> do splits <- mapM splitDataComponents tys return $ SplitDataNonData - { dataTy = ProdTy $ map dataTy splits - , nonDataTy = ProdTy $ map nonDataTy splits + { dataTy = TyCon $ ProdType $ map dataTy splits + , nonDataTy = TyCon $ ProdType $ map nonDataTy splits , toSplit = \xProd -> do xs <- getUnpackedReduced xProd (ys, zs) <- unzip <$> forM (zip xs splits) \(x, split) -> toSplit split x - return (ProdVal ys, ProdVal zs) + return (Con $ ProdCon ys, Con $ ProdCon zs) , fromSplit = \xsProd ysProd -> do xs <- getUnpackedReduced xsProd ys <- getUnpackedReduced ysProd zs <- forM (zip (zip xs ys) splits) \((x, y), split) -> fromSplit split x y - return $ ProdVal zs } + return $ Con $ ProdCon zs } ty -> tryGetRepType ty >>= \case Just repTy -> return $ SplitDataNonData { dataTy = repTy , nonDataTy = UnitTy - , toSplit = \x -> (,UnitVal) <$> toDataAtomIgnoreReconAssumeNoDecls x + , toSplit = \x -> (,UnitVal) <$> (dropSubst $ toDataAtomAssumeNoDecls x) , fromSplit = \x _ -> liftSimpAtom (sink ty) x } Nothing -> return $ SplitDataNonData { dataTy = UnitTy @@ -738,11 +718,11 @@ simplifyOp op = case op of MemOp op' -> simplifyGenericOp op' VectorOp op' -> simplifyGenericOp op' RefOp ref eff -> do - ref' <- simplifyDataAtom ref + ref' <- toDataAtom ref liftResult =<< simplifyRefOp eff ref' BinOp binop x' y' -> do - x <- simplifyDataAtom x' - y <- simplifyDataAtom y' + x <- toDataAtom x' + y <- toDataAtom y' liftResult =<< case binop of ISub -> isub x y IAdd -> iadd x y @@ -752,13 +732,13 @@ simplifyOp op = case op of ICmp Equal -> ieq x y _ -> emitExpr $ BinOp binop x y UnOp unOp x' -> do - x <- simplifyDataAtom x' + x <- toDataAtom x' liftResult =<< emitExpr (UnOp unOp x) MiscOp op' -> case op' of Select c' x' y' -> do - c <- simplifyDataAtom c' - x <- simplifyDataAtom x' - y <- simplifyDataAtom y' + c <- toDataAtom c' + x <- toDataAtom x' + y <- toDataAtom y' liftResult =<< select c x y ShowAny x' -> do x <- simplifyAtom x' @@ -776,10 +756,7 @@ simplifyGenericOp -> SimplifyM i o (CAtom o) simplifyGenericOp op = do ty <- substM $ getType op - op' <- traverseOp op - (substM >=> getRepType) - (simplifyAtom >=> toDataAtomIgnoreRecon) - (error "shouldn't have lambda left") + op' <- traverseOp op getRepType toDataAtom (error "shouldn't have lambda left") result <- liftEnvReaderM (peepholeExpr $ toExpr op') >>= emitExprToAtom liftSimpAtom ty result {-# INLINE simplifyGenericOp #-} @@ -787,7 +764,7 @@ simplifyGenericOp op = do pattern CoerceReconAbs :: Abs (Nest b) ReconstructAtom n pattern CoerceReconAbs <- Abs _ (CoerceRecon _) -applyDictMethod :: Emits o => CType o -> CAtom o -> Int -> [CAtom o] -> SimplifyM i o (CAtom o) +applyDictMethod :: Emits o => CType o -> CDict o -> Int -> [CAtom o] -> SimplifyM i o (CAtom o) applyDictMethod resultTy d i methodArgs = case d of DictCon (InstanceDict _ instanceName instanceArgs) -> dropSubst do instanceArgs' <- mapM simplifyAtom instanceArgs @@ -795,8 +772,9 @@ applyDictMethod resultTy d i methodArgs = case d of withInstantiated instanceDef instanceArgs' \(PairE _ body) -> do let InstanceBody _ methods = body let method = methods !! i - simplifyApp resultTy method methodArgs - DictCon (IxFin _ n) -> applyIxFinMethod (toEnum i) n methodArgs + method' <- forceConstructor method + simplifyApp resultTy method' methodArgs + DictCon (IxFin n) -> applyIxFinMethod (toEnum i) n methodArgs d' -> error $ "Not a simplified dict: " ++ pprint d' where applyIxFinMethod :: EnvReader m => IxMethod -> CAtom n -> [CAtom n] -> m n (CAtom n) @@ -804,31 +782,29 @@ applyDictMethod resultTy d i methodArgs = case d of case (method, args) of (Size, []) -> return n -- result : Nat (Ordinal, [ix]) -> reduceUnwrap ix -- result : Nat - (UnsafeFromOrdinal, [ix]) -> return $ NewtypeCon (FinCon n) ix + (UnsafeFromOrdinal, [ix]) -> return $ toAtom $ NewtypeCon (FinCon n) ix _ -> error "bad ix args" simplifyHof :: Emits o => CType o -> Hof CoreIR i -> SimplifyM i o (CAtom o) simplifyHof resultTy = \case - For d ixTypeCore' lam -> do + For d (IxType ixTy ixDict) lam -> do (lam', Abs (UnaryNest bIx) recon) <- simplifyLam lam - ixTypeCore <- substM ixTypeCore' - ixTypeSimp <- simplifyIxType ixTypeCore - ans <- emitHof $ For d ixTypeSimp lam' + ixTy' <- getRepType ixTy + ixDict' <- simplifyIxDict ixDict + ans <- emitHof $ For d (IxType ixTy' ixDict') lam' case recon of CoerceRecon _ -> liftSimpAtom resultTy ans LamRecon (Abs bsClosure reconResult) -> do - TabPi resultTabTy <- return resultTy - liftM (SimpInCore . TabLam resultTabTy) $ - PairE ixTypeSimp <$> buildAbs noHint (ixTypeType ixTypeSimp) \i -> buildScoped do - i' <- sinkM i - xs <- unpackTelescope bsClosure =<< tabApp (sink ans) (Var i') - applySubst (bIx@>Rename (atomVarName i') <.> bsClosure @@> map SubstVal xs) reconResult + ab <- buildAbs noHint ixTy' \i -> do + xs <- unpackTelescope bsClosure =<< reduceTabApp (sink ans) (toAtom i) + applySubst (bIx@>Rename (atomVarName i) <.> bsClosure @@> map SubstVal xs) reconResult + mkStuck $ TabLam $ IxType ixTy' ixDict' `PairE` ab While body -> do SimplifiedBlock body' (CoerceRecon _) <- buildSimplifiedBlock $ simplifyExpr body result <- emitHof $ While body' liftSimpAtom resultTy result RunReader r lam -> do - r' <- simplifyDataAtom r + r' <- toDataAtom r (lam', Abs b recon) <- simplifyLam lam ans <- emitHof $ RunReader r' lam' let recon' = ignoreHoistFailure $ hoist b recon @@ -836,7 +812,7 @@ simplifyHof resultTy = \case RunWriter Nothing (BaseMonoid e combine) lam -> do LamExpr (BinaryNest h (_:>RefTy _ wTy)) _ <- return lam wTy' <- substM $ ignoreHoistFailure $ hoist h wTy - e' <- simplifyDataAtom e + e' <- toDataAtom e (combine', CoerceReconAbs) <- simplifyLam combine (lam', Abs b recon) <- simplifyLam lam (ans, w) <- fromPair =<< emitHof (RunWriter Nothing (BaseMonoid e' combine') lam') @@ -846,7 +822,8 @@ simplifyHof resultTy = \case return $ PairVal ans' w' RunWriter _ _ _ -> error "Shouldn't see a RunWriter with a dest in Simplify" RunState Nothing s lam -> do - (s', sTy) <- toDataAtom =<< simplifyAtom s + s' <- toDataAtom s + sTy <- substM $ getType s (lam', Abs b recon) <- simplifyLam lam resultPair <- emitHof $ RunState Nothing s' lam' (ans, sOut) <- fromPair resultPair @@ -864,7 +841,7 @@ simplifyHof resultTy = \case ans <- emitHof $ RunInit body' applyRecon recon ans Linearize lam x -> do - x' <- simplifyDataAtom x + x' <- toDataAtom x -- XXX: we're ignoring the result type here, which only makes sense if we're -- dealing with functions on simple types. (lam', recon) <- simplifyLam lam @@ -876,7 +853,7 @@ simplifyHof resultTy = \case return $ PairVal result' linFun' Transpose lam x -> do (lam', CoerceReconAbs) <- simplifyLam lam - x' <- simplifyDataAtom x + x' <- toDataAtom x result <- transpose lam' x' liftSimpAtom resultTy result CatchException _ body-> do @@ -892,19 +869,26 @@ simplifyHof resultTy = \case -- takes an internal SimpIR Maybe to a CoreIR "prelude Maybe" fmapMaybe - :: (EnvReader m, EnvExtender m) - => SAtom n -> (forall l. DExt n l => SAtom l -> m l (CAtom l)) - -> m n (CAtom n) + :: SAtom n -> (forall l. DExt n l => SAtom l -> SimplifyM i l (CAtom l)) + -> SimplifyM i n (CAtom n) fmapMaybe scrut f = do ~(MaybeTy justTy) <- return $ getType scrut (justAlt, resultJustTy) <- withFreshBinder noHint justTy \b -> do - result <- f (Var $ binderVar b) + result <- f (toAtom $ binderVar b) resultTy <- return $ ignoreHoistFailure $ hoist b (getType result) result' <- preludeJustVal result return (Abs b result', resultTy) nothingAlt <- buildAbs noHint UnitTy \_ -> preludeNothingVal $ sink resultJustTy resultMaybeTy <- makePreludeMaybeTy resultJustTy - return $ SimpInCore $ ACase scrut [nothingAlt, justAlt] resultMaybeTy + mkACase scrut [nothingAlt, justAlt] resultMaybeTy + +mkACase :: SAtom n -> [Abs SBinder CAtom n] -> CType n -> SimplifyM i n (CAtom n) +mkACase scrut alts resultTy = case scrut of + Con (SumCon _ i arg) -> do + Abs b body <- return $ alts !! i + applySubst (b@>SubstVal arg) body + Con _ -> error "not a sum type" + Stuck _ scrut' -> mkStuck $ ACase scrut' alts resultTy -- This is wrong! The correct implementation is below. And yet there's some -- compensatory bug somewhere that means that the wrong answer works and the @@ -918,53 +902,41 @@ preludeJustVal x = return x preludeNothingVal :: EnvReader m => CType n -> m n (CAtom n) preludeNothingVal ty = do con <- preludeMaybeNewtypeCon ty - return $ NewtypeCon con (NothingAtom ty) + return $ Con $ NewtypeCon con (NothingAtom ty) preludeMaybeNewtypeCon :: EnvReader m => CType n -> m n (NewtypeCon n) preludeMaybeNewtypeCon ty = do ~(Just (UTyConVar tyConName)) <- lookupSourceMap "Maybe" TyConDef sn _ _ _ <- lookupTyCon tyConName - let params = TyConParams [Explicit] [Type ty] + let params = TyConParams [Explicit] [toAtom ty] return $ UserADTData sn tyConName params liftSimpAtom :: EnvReader m => Type CoreIR n -> SAtom n -> m n (CAtom n) -liftSimpAtom ty simpAtom = case simpAtom of - Stuck _ -> justLift - RepValAtom _ -> justLift -- TODO(dougalm): should we make more effort to pull out products etc? - _ -> do - (cons , ty') <- unwrapLeadingNewtypesType ty - atom <- case (ty', simpAtom) of - (BaseTy _ , Con (Lit v)) -> return $ Con $ Lit v - (ProdTy tys, Con (ProdCon xs)) -> Con . ProdCon <$> zipWithM rec tys xs - (SumTy tys, Con (SumCon _ i x)) -> Con . SumCon tys i <$> rec (tys!!i) x - (DepPairTy dpt@(DepPairType _ (b:>t1) t2), DepPair x1 x2 _) -> do - x1' <- rec t1 x1 - t2' <- applySubst (b@>SubstVal x1') t2 - x2' <- rec t2' x2 - return $ DepPair x1' x2' dpt - _ -> error $ "can't lift " <> pprint simpAtom <> " to " <> pprint ty' - return $ wrapNewtypesData cons atom +liftSimpAtom (StuckTy _ _) _ = error "Can't lift stuck type" +liftSimpAtom ty@(TyCon tyCon) simpAtom = case simpAtom of + Stuck _ stuck -> return $ Stuck ty $ LiftSimp ty stuck + Con con -> Con <$> case (tyCon, con) of + (NewtypeTyCon newtypeCon, _) -> do + (dataCon, repTy) <- unwrapNewtypeType newtypeCon + cAtom <- rec repTy (Con con) + return $ NewtypeCon dataCon cAtom + (BaseType _ , Lit v) -> return $ Lit v + (ProdType tys, ProdCon xs) -> ProdCon <$> zipWithM rec tys xs + (SumType tys, SumCon _ i x) -> SumCon tys i <$> rec (tys!!i) x + (DepPairTy dpt@(DepPairType _ (b:>t1) t2), DepPair x1 x2 _) -> do + x1' <- rec t1 x1 + t2' <- applySubst (b@>SubstVal x1') t2 + x2' <- rec t2' x2 + return $ DepPair x1' x2' dpt + _ -> error $ "can't lift " <> pprint simpAtom <> " to " <> pprint ty where rec = liftSimpAtom - justLift = return $ SimpInCore $ LiftSimp ty simpAtom {-# INLINE liftSimpAtom #-} -unwrapLeadingNewtypesType :: EnvReader m => CType n -> m n ([NewtypeCon n], CType n) -unwrapLeadingNewtypesType = \case - NewtypeTyCon tyCon -> do - (dataCon, ty) <- unwrapNewtypeType tyCon - (dataCons, ty') <- unwrapLeadingNewtypesType ty - return (dataCon:dataCons, ty') - ty -> return ([], ty) - liftSimpFun :: EnvReader m => Type CoreIR n -> LamExpr SimpIR n -> m n (CAtom n) -liftSimpFun (Pi piTy) f = return $ SimpInCore $ LiftSimpFun piTy f +liftSimpFun (TyCon (Pi piTy)) f = mkStuck $ LiftSimpFun piTy f liftSimpFun _ _ = error "not a pi type" -wrapNewtypesData :: [NewtypeCon n] -> CAtom n-> CAtom n -wrapNewtypesData [] x = x -wrapNewtypesData (c:cs) x = NewtypeCon c $ wrapNewtypesData cs x - -- === simplifying custom linearizations === linearizeTopFun :: (Mut n, Fallible1 m, TopBuilder m) => LinearizationSpec n -> m n (TopFunName n, TopFunName n) @@ -1014,18 +986,20 @@ simplifyCustomLinearization (Abs runtimeBs staticArgs) actives rule = do Abs runtimeBs' <$> buildScoped do ListE staticArgs' <- instantiate (sink $ Abs runtimeBs staticArgs) (sink <$> runtimeArgs) fCustom' <- sinkM fCustom + -- TODO: give a HasType instance to ConcreteCAtom resultTy <- typeOfApp (getType fCustom') staticArgs' - pairResult <- dropSubst $ simplifyApp resultTy fCustom' staticArgs' + fCustom'' <- dropSubst $ forceConstructor fCustom' + pairResult <- dropSubst $ simplifyApp resultTy fCustom'' staticArgs' (primalResult, fLin) <- fromPairReduced pairResult - primalResult' <- toDataAtomIgnoreRecon primalResult + primalResult' <- dropSubst $ toDataAtom primalResult let explicitPrimalArgs = drop nImplicit staticArgs' allTangentTys <- forM explicitPrimalArgs \primalArg -> do - tangentType =<< getRepType (getType primalArg) + tangentType =<< dropSubst (getRepType (getType primalArg)) let actives' = drop (length actives - nExplicit) actives activeTangentTys <- catMaybes <$> forM (zip allTangentTys actives') \(t, active) -> return case active of True -> Just t; False -> Nothing - fLin' <- buildUnaryLamExpr "t" (ProdTy activeTangentTys) \activeTangentArg -> do - activeTangentArgs <- getUnpacked $ Var activeTangentArg + fLin' <- buildUnaryLamExpr "t" (toType $ ProdType activeTangentTys) \activeTangentArg -> do + activeTangentArgs <- getUnpacked $ toAtom activeTangentArg ListE allTangentTys' <- sinkM $ ListE allTangentTys tangentArgs <- buildTangentArgs zeros (zip allTangentTys' actives') activeTangentArgs -- TODO: we're throwing away core type information here. Once we @@ -1034,12 +1008,13 @@ simplifyCustomLinearization (Abs runtimeBs staticArgs) actives rule = do -- a custom linearization defined for a function on ADTs will -- not work. fLin' <- sinkM fLin - Pi (CorePiType _ _ bs _) <- return $ getType fLin' + TyCon (Pi (CorePiType _ _ bs _)) <- return $ getType fLin' let tangentCoreTys = fromNonDepNest bs tangentArgs' <- zipWithM liftSimpAtom tangentCoreTys tangentArgs resultTyTangent <- typeOfApp (getType fLin') tangentArgs' - tangentResult <- dropSubst $ simplifyApp resultTyTangent fLin' tangentArgs' - toDataAtomIgnoreRecon tangentResult + fLin'' <- dropSubst $ forceConstructor fLin' + tangentResult <- dropSubst $ simplifyApp resultTyTangent fLin'' tangentArgs' + dropSubst $ toDataAtom tangentResult return $ PairE primalResult' fLin' PairE primalFun tangentFun <- defuncLinearized linearized primalFun' <- asTopLam primalFun @@ -1082,7 +1057,7 @@ defuncLinearized ab = liftBuilder $ refreshAbs ab \bs ab' -> do return (Abs decls (PairVal primalResult residuals), reconAbs, residualsTangentsBs') primalFun <- LamExpr bs <$> mkBlock declsAndResult LamExpr residualAndTangentBs tangentBody <- buildLamExpr residualsTangentsBs \(residuals:tangents) -> do - LamExpr tangentBs' body <- applyReconAbs (sink reconAbs) (Var residuals) + LamExpr tangentBs' body <- applyReconAbs (sink reconAbs) (toAtom residuals) applyRename (tangentBs' @@> (atomVarName <$> tangents)) body >>= emitExpr let tangentFun = LamExpr (bs >>> residualAndTangentBs) tangentBody return $ PairE primalFun tangentFun @@ -1160,7 +1135,7 @@ exceptToMaybeExpr expr = case expr of False -> do v <- emit expr' let ty = getType v - return $ JustAtom ty (Var v) + return $ JustAtom ty (toAtom v) hasExceptions :: SExpr n -> Bool hasExceptions expr = case getEffects expr of diff --git a/src/lib/Subst.hs b/src/lib/Subst.hs index 038362357..0983908ad 100644 --- a/src/lib/Subst.hs +++ b/src/lib/Subst.hs @@ -287,7 +287,7 @@ toAtomVar v = do lookupAtomSubst :: (IRRep r, SubstReader AtomSubstVal m, EnvReader2 m) => AtomName r i -> m i o (Atom r o) lookupAtomSubst v = do lookupSubstM v >>= \case - Rename v' -> Var <$> toAtomVar v' + Rename v' -> toAtom <$> toAtomVar v' SubstVal x -> return x atomSubstM :: (AtomSubstReader v m, EnvReader2 m, SinkableE e, SubstE AtomSubstVal e) diff --git a/src/lib/TopLevel.hs b/src/lib/TopLevel.hs index 87f4c4a19..8511559f5 100644 --- a/src/lib/TopLevel.hs +++ b/src/lib/TopLevel.hs @@ -309,7 +309,7 @@ evalSourceBlock' mname block = case sbContents block of impl <- case expr of WithSrcE _ (UVar _) -> renameSourceNamesUExpr expr >>= \case - WithSrcE _ (UVar (InternalName _ _ (UAtomVar v))) -> Var <$> toAtomVar v + WithSrcE _ (UVar (InternalName _ _ (UAtomVar v))) -> toAtom <$> toAtomVar v _ -> error "Expected a variable" _ -> evalUExpr expr fType <- getType <$> toAtomVar fname' @@ -794,7 +794,7 @@ getBenchRequirement block = case sbLogLevel block of getDexString :: (MonadIO1 m, EnvReader m, Fallible1 m) => Val CoreIR n -> m n String getDexString val = do -- TODO: use a `ByteString` instead of `String` - SimpInCore (LiftSimp _ (RepValAtom (RepVal _ tree))) <- return val + Stuck _ (LiftSimp _ (RepValAtom (RepVal _ tree))) <- return val Branch [Leaf (IIdxRepVal n), Leaf (IPtrVar ptrName _)] <- return tree PtrBinding (CPU, Scalar Word8Type) (PtrLitVal ptr) <- lookupEnv ptrName liftIO $ peekCStringLen (castPtr ptr, fromIntegral n) @@ -923,7 +923,7 @@ instance Generic TopStateEx where getLinearizationType :: SymbolicZeros -> CType n -> EnvReaderT Except n (Int, Int, CType n) getLinearizationType zeros = \case - Pi (CorePiType ExplicitApp expls bs (EffTy Pure resultTy)) -> do + TyCon (Pi (CorePiType ExplicitApp expls bs (EffTy Pure resultTy))) -> do (numIs, numEs) <- getNumImplicits expls refreshAbs (Abs bs resultTy) \bs' resultTy' -> do PairB _ bsE <- return $ splitNestAt numIs bs' @@ -936,9 +936,9 @@ getLinearizationType zeros = \case resultTanTy <- maybeTangentType resultTy' >>= \case Just rtt -> return rtt Nothing -> throw TypeErr $ "No tangent type for: " ++ pprint resultTy' - let tanFunTy = Pi $ nonDepPiType argTanTys Pure resultTanTy + let tanFunTy = toType $ Pi $ nonDepPiType argTanTys Pure resultTanTy let fullTy = CorePiType ExplicitApp expls bs' $ EffTy Pure (PairTy resultTy' tanFunTy) - return (numIs, numEs, Pi fullTy) + return (numIs, numEs, toType $ Pi fullTy) _ -> throw TypeErr $ "Can't define a custom linearization for implicit or impure functions" where getNumImplicits :: Fallible m => [Explicitness] -> m (Int, Int) diff --git a/src/lib/Transpose.hs b/src/lib/Transpose.hs index 75e14ec7c..a296087fb 100644 --- a/src/lib/Transpose.hs +++ b/src/lib/Transpose.hs @@ -47,7 +47,7 @@ transposeTopFun (TopLam False _ lam) = liftBuilder $ runTransposeM do refreshBinders bsNonlin \bsNonlin' substFrag -> extendRenamer substFrag do outTy' <- applyRename (bsNonlin''@@> nestToNames bsNonlin') outTy withFreshBinder "ct" outTy' \bCT -> do - let ct = Var $ binderVar bCT + let ct = toAtom $ binderVar bCT body' <- buildBlock do inTy <- substNonlin $ binderType bLin withAccumulator inTy \refSubstVal -> @@ -133,7 +133,7 @@ withAccumulator ty cont = do Nothing -> do baseMonoid <- tangentBaseMonoidFor ty getSnd =<< emitRunWriter noHint ty baseMonoid \_ ref -> - cont (LinRef $ Var ref) >> return UnitVal + cont (LinRef $ toAtom ref) >> return UnitVal Just val -> do -- If the accumulator's type is inhabited by just one value, we -- don't need any actual accumulation, and can just return that @@ -202,26 +202,21 @@ transposeExpr expr ct = case expr of transposeAtom xLin ct' -- TODO: Instead, should we handle table application like nonlinear -- expressions, where we just project the reference? - TabApp _ x is -> do - is' <- mapM substNonlin is + TabApp _ x i -> do + i' <- substNonlin i case x of - Stuck (StuckVar v) -> do - lookupSubstM (atomVarName v) >>= \case - RenameNonlin _ -> error "shouldn't happen" - LinRef ref -> do - refProj <- naryIndexRef ref (toList is') - emitCTToRef refProj ct - LinTrivial -> return () - Stuck (StuckProject _ _ _) -> undefined - -- ProjectElt _ i' x' -> do - -- let (idxs, v) = asNaryProj i' x' - -- lookupSubstM (atomVarName v) >>= \case - -- RenameNonlin _ -> error "an error, probably" - -- LinRef ref -> do - -- ref' <- getNaryProjRef (toList idxs) ref - -- refProj <- naryIndexRef ref' (toList is') - -- emitCTToRef refProj ct - -- LinTrivial -> return () + Stuck _ stuck -> case stuck of + Var v -> do + lookupSubstM (atomVarName v) >>= \case + RenameNonlin _ -> error "shouldn't happen" + LinRef ref -> do + refProj <- indexRef ref i' + emitCTToRef refProj ct + LinTrivial -> return () + StuckProject _ _ -> undefined + StuckTabApp _ _ -> undefined + PtrVar _ _ -> error "not tangent" + RepValAtom _ -> error "not tangent" _ -> error $ "shouldn't occur: " ++ pprint x PrimOp op -> transposeOp op ct Case e alts _ -> do @@ -303,26 +298,26 @@ transposeMiscOp op _ = case op of transposeAtom :: HasCallStack => Emits o => SAtom i -> SAtom o -> TransposeM i o () transposeAtom atom ct = case atom of - Con con -> transposeCon con ct - DepPair _ _ _ -> notImplemented - PtrVar _ _ -> notTangent - Stuck (StuckVar v) -> do - lookupSubstM (atomVarName v) >>= \case - RenameNonlin _ -> - -- XXX: we seem to need this case, but it feels like it should be an error! - return () - LinRef ref -> emitCTToRef ref ct - LinTrivial -> return () - Stuck (StuckProject _ _ _) -> error "not implemented" - Stuck (StuckTabApp _ _ _) -> error "not implemented" - -- let (idxs, v) = asNaryProj i' x' - -- lookupSubstM (atomVarName v) >>= \case - -- RenameNonlin _ -> error "an error, probably" - -- LinRef ref -> do - -- ref' <- applyProjectionsRef (toList idxs) ref - -- emitCTToRef ref' ct - -- LinTrivial -> return () - RepValAtom _ -> error "not implemented" + Con con -> transposeCon con ct + Stuck _ stuck -> case stuck of + PtrVar _ _ -> notTangent + Var v -> do + lookupSubstM (atomVarName v) >>= \case + RenameNonlin _ -> + -- XXX: we seem to need this case, but it feels like it should be an error! + return () + LinRef ref -> emitCTToRef ref ct + LinTrivial -> return () + StuckProject _ _ -> error "not implemented" + StuckTabApp _ _ -> error "not implemented" + -- let (idxs, v) = asNaryProj i' x' + -- lookupSubstM (atomVarName v) >>= \case + -- RenameNonlin _ -> error "an error, probably" + -- LinRef ref -> do + -- ref' <- applyProjectionsRef (toList idxs) ref + -- emitCTToRef ref' ct + -- LinTrivial -> return () + RepValAtom _ -> error "not implemented" where notTangent = error $ "Not a tangent atom: " ++ pprint atom transposeHof :: Emits o => Hof SimpIR i -> SAtom o -> TransposeM i o () @@ -331,7 +326,7 @@ transposeHof hof ct = case hof of UnaryLamExpr b body <- return lam ixTy <- substNonlin ixTy' void $ buildForAnn (getNameHint b) (flipDir ann) ixTy \i -> do - ctElt <- tabApp (sink ct) (Var i) + ctElt <- tabApp (sink ct) (toAtom i) extendSubst (b@>RenameNonlin (atomVarName i)) $ transposeExpr body ctElt return UnitVal RunState Nothing s (BinaryLamExpr hB refB body) -> do @@ -368,6 +363,7 @@ transposeCon con ct = case con of ProdCon xs -> forM_ (enumerate xs) \(i, x) -> proj i ct >>= transposeAtom x SumCon _ _ _ -> notImplemented HeapVal -> notTangent + DepPair _ _ _ -> notImplemented where notTangent = error $ "Not a tangent atom: " ++ pprint (Con con) notImplemented :: HasCallStack => a diff --git a/src/lib/Types/Core.hs b/src/lib/Types/Core.hs index 95380bbca..06f2cbc86 100644 --- a/src/lib/Types/Core.hs +++ b/src/lib/Types/Core.hs @@ -48,76 +48,91 @@ import Types.Imp -- === core IR === data Atom (r::IR) (n::S) where - Con :: Con r n -> Atom r n - Stuck :: Stuck r n -> Atom r n - PtrVar :: PtrType -> PtrName n -> Atom r n - DepPair :: Atom r n -> Atom r n -> DepPairType r n -> Atom r n - -- === CoreIR only === - Lam :: CoreLamExpr n -> Atom CoreIR n - Eff :: EffectRow CoreIR n -> Atom CoreIR n - DictCon :: DictCon n -> Atom CoreIR n - NewtypeCon :: NewtypeCon n -> Atom CoreIR n -> Atom CoreIR n - TypeAsAtom :: Type CoreIR n -> Atom CoreIR n - -- === Shims between IRs === - SimpInCore :: SimpInCore n -> Atom CoreIR n - RepValAtom :: RepVal SimpIR n -> Atom SimpIR n + Con :: Con r n -> Atom r n + Stuck :: Type r n -> Stuck r n -> Atom r n + deriving (Show, Generic) data Type (r::IR) (n::S) where - TC :: TC r n -> Type r n - TabPi :: TabPiType r n -> Type r n - DepPairTy :: DepPairType r n -> Type r n - StuckTy :: Stuck CoreIR n -> Type CoreIR n - DictTy :: DictType n -> Type CoreIR n - Pi :: CorePiType n -> Type CoreIR n - NewtypeTyCon :: NewtypeTyCon n -> Type CoreIR n + TyCon :: TyCon r n -> Type r n + StuckTy :: CType n -> CStuck n -> Type CoreIR n -data Stuck (r::IR) (n::S) where - StuckVar :: AtomVar r n -> Stuck r n - StuckProject :: Type r n -> Int -> Stuck r n -> Stuck r n - StuckTabApp :: Type r n -> Stuck r n -> [Atom r n] -> Stuck r n - StuckUnwrap :: CType n -> CStuck n -> Stuck CoreIR n - InstantiatedGiven :: CType n -> CStuck n -> [CAtom n] -> Stuck CoreIR n - SuperclassProj :: CType n -> Int -> CStuck n -> Stuck CoreIR n +data Dict (r::IR) (n::S) where + DictCon :: DictCon r n -> Dict r n + StuckDict :: CType n -> CStuck n -> Dict CoreIR n -pattern Var :: AtomVar r n -> Atom r n -pattern Var v = Stuck (StuckVar v) +data Con (r::IR) (n::S) where + Lit :: LitVal -> Con r n + ProdCon :: [Atom r n] -> Con r n + SumCon :: [Type r n] -> Int -> Atom r n -> Con r n -- type, tag, payload + HeapVal :: Con r n + DepPair :: Atom r n -> Atom r n -> DepPairType r n -> Con r n + Lam :: CoreLamExpr n -> Con CoreIR n + Eff :: EffectRow CoreIR n -> Con CoreIR n + NewtypeCon :: NewtypeCon n -> Atom CoreIR n -> Con CoreIR n + DictConAtom :: DictCon CoreIR n -> Con CoreIR n + TyConAtom :: TyCon CoreIR n -> Con CoreIR n -pattern TyVar :: AtomVar CoreIR n -> Type CoreIR n -pattern TyVar v = StuckTy (StuckVar v) +data Stuck (r::IR) (n::S) where + Var :: AtomVar r n -> Stuck r n + StuckProject :: Int -> Stuck r n -> Stuck r n + StuckTabApp :: Stuck r n -> Atom r n -> Stuck r n + PtrVar :: PtrType -> PtrName n -> Stuck r n + RepValAtom :: RepVal n -> Stuck SimpIR n + StuckUnwrap :: CStuck n -> Stuck CoreIR n + InstantiatedGiven :: CStuck n -> [CAtom n] -> Stuck CoreIR n + SuperclassProj :: Int -> CStuck n -> Stuck CoreIR n + LiftSimp :: CType n -> Stuck SimpIR n -> Stuck CoreIR n + LiftSimpFun :: CorePiType n -> LamExpr SimpIR n -> Stuck CoreIR n + -- TabLam and ACase are just defunctionalization tools. The result type + -- in both cases should *not* be `Data`. + TabLam :: TabLamExpr n -> Stuck CoreIR n + ACase :: SStuck n -> [Abs SBinder CAtom n] -> CType n -> Stuck CoreIR n + +data TyCon (r::IR) (n::S) where + BaseType :: BaseType -> TyCon r n + ProdType :: [Type r n] -> TyCon r n + SumType :: [Type r n] -> TyCon r n + RefType :: Atom r n -> Type r n -> TyCon r n + HeapType :: TyCon r n + TabPi :: TabPiType r n -> TyCon r n + DepPairTy :: DepPairType r n -> TyCon r n + TypeKind :: TyCon CoreIR n + DictTy :: DictType n -> TyCon CoreIR n + Pi :: CorePiType n -> TyCon CoreIR n + NewtypeTyCon :: NewtypeTyCon n -> TyCon CoreIR n data AtomVar (r::IR) (n::S) = AtomVar { atomVarName :: AtomName r n , atomVarType :: Type r n } deriving (Show, Generic) -type TabLamExpr = PairE (IxType SimpIR) (Abs (Binder SimpIR) (Abs (Nest SDecl) CAtom)) -data SimpInCore (n::S) = - LiftSimp (CType n) (SAtom n) - | LiftSimpFun (CorePiType n) (LamExpr SimpIR n) - | TabLam (TabPiType CoreIR n) (TabLamExpr n) - | ACase (SAtom n) [Abs SBinder CAtom n] (CType n) - deriving (Show, Generic) - -deriving instance IRRep r => Show (Atom r n) -deriving instance IRRep r => Show (Type r n) -deriving instance IRRep r => Show (Stuck r n) - -deriving via WrapE (Atom r) n instance IRRep r => Generic (Atom r n) -deriving via WrapE (Type r) n instance IRRep r => Generic (Type r n) -deriving via WrapE (Stuck r) n instance IRRep r => Generic (Stuck r n) - +deriving instance IRRep r => Show (DictCon r n) +deriving instance IRRep r => Show (Dict r n) +deriving instance IRRep r => Show (Con r n) +deriving instance IRRep r => Show (TyCon r n) +deriving instance IRRep r => Show (Type r n) +deriving instance IRRep r => Show (Stuck r n) + +deriving via WrapE (DictCon r) n instance IRRep r => Generic (DictCon r n) +deriving via WrapE (Dict r) n instance IRRep r => Generic (Dict r n) +deriving via WrapE (Con r) n instance IRRep r => Generic (Con r n) +deriving via WrapE (TyCon r) n instance IRRep r => Generic (TyCon r n) +deriving via WrapE (Type r) n instance IRRep r => Generic (Type r n) +deriving via WrapE (Stuck r) n instance IRRep r => Generic (Stuck r n) + +-- TODO: factor out the EffTy and maybe merge with PrimOp data Expr r n where Block :: EffTy r n -> Block r n -> Expr r n - TopApp :: EffTy SimpIR n -> TopFunName n -> [SAtom n] -> Expr SimpIR n - TabApp :: Type r n -> Atom r n -> [Atom r n] -> Expr r n - Case :: Atom r n -> [Alt r n] -> EffTy r n -> Expr r n - Atom :: Atom r n -> Expr r n - TabCon :: Maybe (WhenCore r Dict n) -> Type r n -> [Atom r n] -> Expr r n - PrimOp :: PrimOp r n -> Expr r n - Project :: Type r n -> Int -> Atom r n -> Expr r n - App :: EffTy CoreIR n -> CAtom n -> [CAtom n] -> Expr CoreIR n - Unwrap :: CType n -> CAtom n -> Expr CoreIR n - ApplyMethod :: EffTy CoreIR n -> CAtom n -> Int -> [CAtom n] -> Expr CoreIR n + TopApp :: EffTy SimpIR n -> TopFunName n -> [SAtom n] -> Expr SimpIR n + TabApp :: Type r n -> Atom r n -> Atom r n -> Expr r n + Case :: Atom r n -> [Alt r n] -> EffTy r n -> Expr r n + Atom :: Atom r n -> Expr r n + TabCon :: Maybe (WhenCore r (Dict CoreIR) n) -> Type r n -> [Atom r n] -> Expr r n + PrimOp :: PrimOp r n -> Expr r n + Project :: Type r n -> Int -> Atom r n -> Expr r n + App :: EffTy CoreIR n -> CAtom n -> [CAtom n] -> Expr CoreIR n + Unwrap :: CType n -> CAtom n -> Expr CoreIR n + ApplyMethod :: EffTy CoreIR n -> CAtom n -> Int -> [CAtom n] -> Expr CoreIR n deriving instance IRRep r => Show (Expr r n) deriving via WrapE (Expr r) n instance IRRep r => Generic (Expr r n) @@ -198,20 +213,10 @@ data TopLam (r::IR) (n::S) = TopLam IsDestLam (PiType r n) (LamExpr r n) data LamExpr (r::IR) (n::S) where LamExpr :: Nest (Binder r) n l -> Expr r l -> LamExpr r n -data CoreLamExpr (n::S) = CoreLamExpr (CorePiType n) (LamExpr CoreIR n) +data CoreLamExpr (n::S) = CoreLamExpr (CorePiType n) (LamExpr CoreIR n) deriving (Show, Generic) -data IxDict r n where - IxDictAtom :: Atom CoreIR n -> IxDict CoreIR n - -- TODO: make these two only available in SimpIR (currently we can't do that - -- because we need CoreIR to be a superset of SimpIR) - -- IxDictRawFin is used post-simplification. It behaves like `Fin`, but - -- it's parameterized by a newtyped-stripped `IxRepVal` instead of `Nat`, and - -- it describes indices of type `IxRepVal`. - IxDictRawFin :: Atom r n -> IxDict r n - IxDictSpecialized :: SType n -> SpecDictName n -> [SAtom n] -> IxDict SimpIR n - -deriving instance IRRep r => Show (IxDict r n) -deriving via WrapE (IxDict r) n instance IRRep r => Generic (IxDict r n) +type TabLamExpr = PairE (IxType SimpIR) (Abs SBinder CAtom) +type IxDict = Dict data IxMethod = Size | Ordinal | UnsafeFromOrdinal deriving (Show, Generic, Enum, Bounded, Eq) @@ -235,7 +240,6 @@ data DepPairType (r::IR) (n::S) where type Val = Atom type Kind = Type -type Dict = Atom CoreIR -- A nest where the annotation of a binder cannot depend on the binders -- introduced before it. You can think of it as introducing a bunch of @@ -277,7 +281,7 @@ instance ToBindersAbs TyConDef DataConDefs CoreIR where toAbs (TyConDef _ _ bs body) = Abs bs body instance ToBindersAbs ClassDef (Abs (Nest CBinder) (ListE CorePiType)) CoreIR where - toAbs (ClassDef _ _ _ _ bs scBs tys) = Abs bs (Abs scBs (ListE tys)) + toAbs (ClassDef _ _ _ _ _ bs scBs tys) = Abs bs (Abs scBs (ListE tys)) instance ToBindersAbs (TopLam r) (Expr r) r where toAbs (TopLam _ _ lam) = toAbs lam @@ -330,22 +334,6 @@ traverseOp op fType fAtom fLam = do -- === Various ops === -data TC (r::IR) (n::S) where - BaseType :: BaseType -> TC r n - ProdType :: [Type r n] -> TC r n - SumType :: [Type r n] -> TC r n - RefType :: Atom r n -> Type r n -> TC r n - TypeKind :: TC r n -- TODO: `HasCore r` constraint - HeapType :: TC r n - deriving (Show, Generic) - -data Con (r::IR) (n::S) where - Lit :: LitVal -> Con r n - ProdCon :: [Atom r n] -> Con r n - SumCon :: [Type r n] -> Int -> Atom r n -> Con r n -- type, tag, payload - HeapVal :: Con r n - deriving (Show, Generic) - data PrimOp (r::IR) (n::S) where UnOp :: P.UnOp -> Atom r n -> PrimOp r n BinOp :: P.BinOp -> Atom r n -> Atom r n -> PrimOp r n @@ -438,6 +426,7 @@ data RefOp r n = type CAtom = Atom CoreIR type CType = Type CoreIR +type CDict = Dict CoreIR type CStuck = Stuck CoreIR type CBinder = Binder CoreIR type CExpr = Expr CoreIR @@ -450,6 +439,7 @@ type CTopLam = TopLam CoreIR type SAtom = Atom SimpIR type SType = Type SimpIR +type SDict = Dict SimpIR type SStuck = Stuck SimpIR type SExpr = Expr SimpIR type SBlock = Block SimpIR @@ -459,7 +449,6 @@ type SDecls = Decls SimpIR type SAtomName = AtomName SimpIR type SAtomVar = AtomVar SimpIR type SBinder = Binder SimpIR -type SRepVal = RepVal SimpIR type SLam = LamExpr SimpIR type STopLam = TopLam SimpIR @@ -479,9 +468,6 @@ data NewtypeTyCon (n::S) = | UserADTType SourceName (TyConName n) (TyConParams n) deriving (Show, Generic) -pattern TypeCon :: SourceName -> TyConName n -> TyConParams n -> CType n -pattern TypeCon s d xs = NewtypeTyCon (UserADTType s d xs) - isSumCon :: NewtypeCon n -> Bool isSumCon = \case UserADTData _ _ _ -> True @@ -494,6 +480,7 @@ type RoleExpl = (ParamRole, Explicitness) data ClassDef (n::S) where ClassDef :: SourceName -- name of class + -> Maybe BuiltinClassName -> [SourceName] -- method source names -> [Maybe SourceName] -- parameter source names -> [RoleExpl] -- parameter info @@ -502,6 +489,8 @@ data ClassDef (n::S) where -> [CorePiType n3] -- method types -> ClassDef n1 +data BuiltinClassName = Data | Ix deriving (Show, Generic, Eq) + data InstanceDef (n::S) where InstanceDef :: ClassName n1 @@ -517,17 +506,23 @@ data InstanceBody (n::S) = [CAtom n] -- method definitions deriving (Show, Generic) -data DictType (n::S) = DictType SourceName (ClassName n) [CAtom n] - deriving (Show, Generic) - -data DictCon (n::S) = - InstanceDict (CType n) (InstanceName n) [CAtom n] - -- Special case for `Ix (Fin n)` (TODO: a more general mechanism for built-in classes and instances) - | IxFin (CType n) (CAtom n) - -- Special case for `Data ` - | DataData (CType n) (CType n) +data DictType (n::S) = + DictType SourceName (ClassName n) [CAtom n] + | IxDictType (CType n) + | DataDictType (CType n) deriving (Show, Generic) +data DictCon (r::IR) (n::S) where + InstanceDict :: CType n -> InstanceName n -> [CAtom n] -> DictCon CoreIR n + -- Special case for `Data ` + DataData :: CType n -> DictCon CoreIR n + IxFin :: CAtom n -> DictCon CoreIR n + -- IxRawFin is like `Fin`, but it's parameterized by a newtyped-stripped + -- `IxRepVal` instead of `Nat`, and it describes indices of type `IxRepVal`. + -- TODO: make is SimpIR-only + IxRawFin :: Atom r n -> DictCon r n + IxSpecialized :: SpecDictName n -> [SAtom n] -> DictCon SimpIR n + -- TODO: Use an IntMap newtype CustomRules (n::S) = CustomRules { customRulesMap :: M.Map (AtomName CoreIR n) (AtomRules n) } @@ -539,7 +534,7 @@ data AtomRules (n::S) = -- === Runtime representations === -data RepVal (r::IR) (n::S) = RepVal (Type r n) (Tree (IExpr n)) +data RepVal (n::S) = RepVal (SType n) (Tree (IExpr n)) deriving (Show, Generic) -- === envs and modules === @@ -630,7 +625,8 @@ data TopEnvUpdate n = -- TODO: we could add a lot more structure for querying by dict type, caching, etc. data SynthCandidates n = SynthCandidates - { instanceDicts :: M.Map (ClassName n) [InstanceName n] } + { instanceDicts :: M.Map (ClassName n) [InstanceName n] + , ixInstances :: [InstanceName n] } deriving (Show, Generic) emptyImportStatus :: ImportStatus n @@ -795,7 +791,7 @@ newtype TopFunLowerings (n::S) = TopFunLowerings data AtomBinding (r::IR) (n::S) where LetBound :: DeclBinding r n -> AtomBinding r n MiscBound :: Type r n -> AtomBinding r n - TopDataBound :: RepVal SimpIR n -> AtomBinding SimpIR n + TopDataBound :: RepVal n -> AtomBinding SimpIR n SolverBound :: SolverBinding n -> AtomBinding CoreIR n NoinlineFun :: CType n -> CAtom n -> AtomBinding CoreIR n FFIFunBound :: CorePiType n -> TopFunName n -> AtomBinding CoreIR n @@ -919,7 +915,7 @@ instance IRRep r => Store (Effect r n) type Generalized (r::IR) (e::E) (n::S) = (Abstracted r e n, [Atom r n]) type Abstracted (r::IR) (e::E) = Abs (Nest (Binder r)) e -type AbsDict = Abstracted CoreIR Dict +type AbsDict = Abstracted CoreIR (Dict CoreIR) data SpecializedDictDef n = SpecializedDict @@ -993,67 +989,90 @@ instance (ToBinding e1 c, ToBinding e2 c) => ToBinding (EitherE e1 e2) c where class ToAtom (e::E) (r::IR) | e -> r where toAtom :: e n -> Atom r n -instance ToAtom (Atom r) r where - toAtom = id - +instance ToAtom (Atom r) r where toAtom = id +instance ToAtom (Con r) r where toAtom = Con +instance ToAtom (TyCon CoreIR) CoreIR where toAtom = Con . TyConAtom +instance ToAtom (DictCon CoreIR) CoreIR where toAtom = Con . DictConAtom +instance ToAtom (EffectRow CoreIR) CoreIR where toAtom = Con . Eff +instance ToAtom CoreLamExpr CoreIR where toAtom = Con . Lam +instance ToAtom DictType CoreIR where toAtom = Con . TyConAtom . DictTy +instance ToAtom NewtypeTyCon CoreIR where toAtom = Con . TyConAtom . NewtypeTyCon instance ToAtom (AtomVar r) r where - toAtom = Var - -instance ToAtom (Con r) r where - toAtom = Con - + toAtom (AtomVar v ty) = Stuck ty (Var (AtomVar v ty)) +instance ToAtom RepVal SimpIR where + toAtom (RepVal ty tree) = Stuck ty $ RepValAtom $ RepVal ty tree instance ToAtom (Type CoreIR) CoreIR where - toAtom = TypeAsAtom + toAtom = \case + TyCon con -> Con $ TyConAtom con + StuckTy t s -> Stuck t s +instance ToAtom (Dict CoreIR) CoreIR where + toAtom = \case + DictCon d -> Con $ DictConAtom d + StuckDict t s -> Stuck t s + +-- This can help avoid ambiguous `r` parameter with ToAtom +toAtomR :: ToAtom (e r) r => e r n -> Atom r n +toAtomR = toAtom + +-- === ToType === + +class ToType (e::E) (r::IR) | e -> r where + toType :: e n -> Type r n + +instance ToType (Type r) r where toType = id +instance ToType (TyCon r) r where toType = TyCon +instance ToType (TabPiType r) r where toType = TyCon . TabPi +instance ToType (DepPairType r) r where toType = TyCon . DepPairTy +instance ToType CorePiType CoreIR where toType = TyCon . Pi +instance ToType DictType CoreIR where toType = TyCon . DictTy +instance ToType NewtypeTyCon CoreIR where toType = TyCon . NewtypeTyCon + +toMaybeType :: CAtom n -> Maybe (CType n) +toMaybeType = \case + Stuck t s -> Just $ StuckTy t s + Con (TyConAtom t) -> Just $ TyCon t + _ -> Nothing + +-- === ToDict === + +class ToDict (e::E) (r::IR) | e -> r where + toDict :: e n -> Dict r n + +instance ToDict (Dict r) r where toDict = id +instance ToDict (DictCon r) r where toDict = DictCon +instance ToDict CAtomVar CoreIR where + toDict (AtomVar v ty) = StuckDict ty (Var (AtomVar v ty)) + +toMaybeDict :: CAtom n -> Maybe (CDict n) +toMaybeDict = \case + Stuck t s -> Just $ StuckDict t s + Con (DictConAtom d) -> Just $ DictCon d + _ -> Nothing -- === ToExpr === class ToExpr (e::E) (r::IR) | e -> r where toExpr :: e n -> Expr r n -instance ToExpr (Expr r) r where - toExpr = id - -instance ToExpr (Atom r) r where - toExpr = Atom - -instance ToExpr (AtomVar r) r where - toExpr = toExpr . toAtom - -instance ToExpr (PrimOp r) r where - toExpr = PrimOp - -instance ToExpr (MiscOp r) r where - toExpr = PrimOp . MiscOp - -instance ToExpr (MemOp r) r where - toExpr = PrimOp . MemOp - -instance ToExpr (VectorOp r) r where - toExpr = PrimOp . VectorOp - -instance ToExpr (TypedHof r) r where - toExpr = PrimOp . Hof - -instance ToExpr (DAMOp SimpIR) SimpIR where - toExpr = PrimOp . DAMOp +instance ToExpr (Expr r) r where toExpr = id +instance ToExpr (Atom r) r where toExpr = Atom +instance ToExpr (Con r) r where toExpr = Atom . Con +instance ToExpr (AtomVar r) r where toExpr = toExpr . toAtom +instance ToExpr (PrimOp r) r where toExpr = PrimOp +instance ToExpr (MiscOp r) r where toExpr = PrimOp . MiscOp +instance ToExpr (MemOp r) r where toExpr = PrimOp . MemOp +instance ToExpr (VectorOp r) r where toExpr = PrimOp . VectorOp +instance ToExpr (TypedHof r) r where toExpr = PrimOp . Hof +instance ToExpr (DAMOp SimpIR) SimpIR where toExpr = PrimOp . DAMOp -- === Pattern synonyms === --- XXX: only use this pattern when you're actually expecting a type. If it's --- a Var, it doesn't check whether it's a type. -pattern Type :: CType n -> CAtom n -pattern Type t <- ((\case Stuck e -> Just (StuckTy e) - TypeAsAtom t -> Just t - _ -> Nothing) -> Just t) - where Type (StuckTy e) = Stuck e - Type t = TypeAsAtom t - pattern IdxRepScalarBaseTy :: ScalarBaseType pattern IdxRepScalarBaseTy = Word32Type -- Type used to represent indices and sizes at run-time pattern IdxRepTy :: Type r n -pattern IdxRepTy = TC (BaseType (Scalar Word32Type)) +pattern IdxRepTy = TyCon (BaseType (Scalar Word32Type)) pattern IdxRepVal :: Word32 -> Atom r n pattern IdxRepVal x = Con (Lit (Word32Lit x)) @@ -1066,7 +1085,7 @@ pattern IIdxRepTy = Scalar Word32Type -- Type used to represent sum type tags at run-time pattern TagRepTy :: Type r n -pattern TagRepTy = TC (BaseType (Scalar Word8Type)) +pattern TagRepTy = TyCon (BaseType (Scalar Word8Type)) pattern TagRepVal :: Word8 -> Atom r n pattern TagRepVal x = Con (Lit (Word8Lit x)) @@ -1078,64 +1097,52 @@ charRepVal :: Char -> Atom r n charRepVal c = Con (Lit (Word8Lit (fromIntegral $ fromEnum c))) pattern Word8Ty :: Type r n -pattern Word8Ty = TC (BaseType (Scalar Word8Type)) - -pattern ProdTy :: [Type r n] -> Type r n -pattern ProdTy tys = TC (ProdType tys) - -pattern ProdVal :: [Atom r n] -> Atom r n -pattern ProdVal xs = Con (ProdCon xs) - -pattern SumTy :: [Type r n] -> Type r n -pattern SumTy cs = TC (SumType cs) - -pattern SumVal :: [Type r n] -> Int -> Atom r n -> Atom r n -pattern SumVal tys tag payload = Con (SumCon tys tag payload) +pattern Word8Ty = TyCon (BaseType (Scalar Word8Type)) pattern PairVal :: Atom r n -> Atom r n -> Atom r n pattern PairVal x y = Con (ProdCon [x, y]) pattern PairTy :: Type r n -> Type r n -> Type r n -pattern PairTy x y = TC (ProdType [x, y]) +pattern PairTy x y = TyCon (ProdType [x, y]) pattern UnitVal :: Atom r n pattern UnitVal = Con (ProdCon []) pattern UnitTy :: Type r n -pattern UnitTy = TC (ProdType []) +pattern UnitTy = TyCon (ProdType []) pattern BaseTy :: BaseType -> Type r n -pattern BaseTy b = TC (BaseType b) +pattern BaseTy b = TyCon (BaseType b) pattern PtrTy :: PtrType -> Type r n -pattern PtrTy ty = BaseTy (PtrType ty) +pattern PtrTy ty = TyCon (BaseType (PtrType ty)) pattern RefTy :: Atom r n -> Type r n -> Type r n -pattern RefTy r a = TC (RefType r a) +pattern RefTy r a = TyCon (RefType r a) pattern RawRefTy :: Type r n -> Type r n -pattern RawRefTy a = TC (RefType (Con HeapVal) a) +pattern RawRefTy a = TyCon (RefType (Con HeapVal) a) pattern TabTy :: IxDict r n -> Binder r n l -> Type r l -> Type r n -pattern TabTy d b body = TabPi (TabPiType d b body) +pattern TabTy d b body = TyCon (TabPi (TabPiType d b body)) pattern FinTy :: Atom CoreIR n -> Type CoreIR n -pattern FinTy n = NewtypeTyCon (Fin n) +pattern FinTy n = TyCon (NewtypeTyCon (Fin n)) pattern NatTy :: Type CoreIR n -pattern NatTy = NewtypeTyCon Nat +pattern NatTy = TyCon (NewtypeTyCon Nat) pattern NatVal :: Word32 -> Atom CoreIR n -pattern NatVal n = NewtypeCon NatCon (IdxRepVal n) +pattern NatVal n = Con (NewtypeCon NatCon (IdxRepVal n)) -pattern TyKind :: Kind r n -pattern TyKind = TC TypeKind +pattern TyKind :: Kind CoreIR n +pattern TyKind = TyCon TypeKind pattern EffKind :: Kind CoreIR n -pattern EffKind = NewtypeTyCon EffectRowKind +pattern EffKind = TyCon (NewtypeTyCon EffectRowKind) pattern FinConst :: Word32 -> Type CoreIR n -pattern FinConst n = NewtypeTyCon (Fin (NatVal n)) +pattern FinConst n = TyCon (NewtypeTyCon (Fin (NatVal n))) pattern NullaryLamExpr :: Expr r n -> LamExpr r n pattern NullaryLamExpr body = LamExpr Empty body @@ -1147,13 +1154,13 @@ pattern BinaryLamExpr :: Binder r n l1 -> Binder r l1 l2 -> Expr r l2 -> LamExpr pattern BinaryLamExpr b1 b2 body = LamExpr (BinaryNest b1 b2) body pattern MaybeTy :: Type r n -> Type r n -pattern MaybeTy a = SumTy [UnitTy, a] +pattern MaybeTy a = TyCon (SumType [UnitTy, a]) pattern NothingAtom :: Type r n -> Atom r n -pattern NothingAtom a = SumVal [UnitTy, a] 0 UnitVal +pattern NothingAtom a = Con (SumCon [UnitTy, a] 0 UnitVal) pattern JustAtom :: Type r n -> Atom r n -> Atom r n -pattern JustAtom a x = SumVal [UnitTy, a] 1 x +pattern JustAtom a x = Con (SumCon [UnitTy, a] 1 x) pattern BoolTy :: Type r n pattern BoolTy = Word8Ty @@ -1175,16 +1182,16 @@ instance HoistableE AtomRules instance AlphaEqE AtomRules instance RenameE AtomRules -instance IRRep r => GenericE (RepVal r) where - type RepE (RepVal r) = PairE (Type r) (ComposeE Tree IExpr) +instance GenericE RepVal where + type RepE RepVal= PairE SType (ComposeE Tree IExpr) fromE (RepVal ty tree) = ty `PairE` ComposeE tree toE (ty `PairE` ComposeE tree) = RepVal ty tree -instance IRRep r => SinkableE (RepVal r) -instance IRRep r => RenameE (RepVal r) -instance IRRep r => HoistableE (RepVal r) -instance IRRep r => AlphaHashableE (RepVal r) -instance IRRep r => AlphaEqE (RepVal r) +instance SinkableE RepVal +instance RenameE RepVal +instance HoistableE RepVal +instance AlphaHashableE RepVal +instance AlphaEqE RepVal instance GenericE CustomRules where type RepE CustomRules = ListE (PairE (AtomName CoreIR) AtomRules) @@ -1462,88 +1469,15 @@ instance IRRep r => RenameE (RefOp r) instance IRRep r => AlphaEqE (RefOp r) instance IRRep r => AlphaHashableE (RefOp r) -instance GenericE SimpInCore where - type RepE SimpInCore = EitherE4 - {- LiftSimp -} (CType `PairE` SAtom) - {- LiftSimpFun -} (CorePiType `PairE` LamExpr SimpIR) - {- TabLam -} (TabPiType CoreIR `PairE` TabLamExpr) - {- ACase -} (SAtom `PairE` ListE (Abs SBinder CAtom) `PairE` CType) +instance IRRep r => GenericE (Atom r) where + type RepE (Atom r) = EitherE (PairE (Type r) (Stuck r)) (Con r) fromE = \case - LiftSimp ty x -> Case0 $ ty `PairE` x - LiftSimpFun ty x -> Case1 $ ty `PairE` x - TabLam ty lam -> Case2 $ ty `PairE` lam - ACase scrut alts resultTy -> Case3 $ scrut `PairE` ListE alts `PairE` resultTy + Stuck t x -> LeftE (PairE t x) + Con x -> RightE x {-# INLINE fromE #-} - toE = \case - Case0 (ty `PairE` x) -> LiftSimp ty x - Case1 (ty `PairE` x) -> LiftSimpFun ty x - Case2 (ty `PairE` lam) -> TabLam ty lam - Case3 (x `PairE` ListE alts `PairE` ty) -> ACase x alts ty - _ -> error "impossible" - {-# INLINE toE #-} - -instance SinkableE SimpInCore -instance HoistableE SimpInCore -instance RenameE SimpInCore -instance AlphaEqE SimpInCore -instance AlphaHashableE SimpInCore - -instance IRRep r => GenericE (Atom r) where - -- As tempting as it might be to reorder cases here, the current permutation - -- was chosen as to make GHC inliner confident enough to simplify through - -- toE/fromE entirely. If you wish to modify the order, please consult the - -- GHC Core dump to make sure you haven't regressed this optimization. - type RepE (Atom r) = EitherE3 - (EitherE3 - {- Stuck -} (Stuck r) - {- Lam -} (WhenCore r CoreLamExpr) - {- DepPair -} (Atom r `PairE` Atom r `PairE` DepPairType r) - ) (EitherE3 - {- DictCon -} (WhenCore r DictCon) - {- NewtypeCon -} (WhenCore r (NewtypeCon `PairE` Atom r)) - {- Con -} (Con r) - ) (EitherE5 - {- Eff -} ( WhenCore r (EffectRow r)) - {- PtrVar -} (LiftE PtrType `PairE` PtrName) - {- RepValAtom -} ( WhenSimp r (RepVal r)) - {- SimpInCore -} ( WhenCore r SimpInCore) - {- TypeAsAtom -} ( WhenCore r (Type CoreIR)) - ) - - fromE atom = case atom of - Stuck x -> Case0 (Case0 x) - Lam lamExpr -> Case0 (Case1 (WhenIRE lamExpr)) - DepPair l r ty -> Case0 (Case2 $ l `PairE` r `PairE` ty) - DictCon d -> Case1 $ Case0 $ WhenIRE d - NewtypeCon c x -> Case1 $ Case1 $ WhenIRE (c `PairE` x) - Con con -> Case1 $ Case2 con - Eff effs -> Case2 $ Case0 $ WhenIRE effs - PtrVar t v -> Case2 $ Case1 $ LiftE t `PairE` v - RepValAtom rv -> Case2 $ Case2 $ WhenIRE $ rv - SimpInCore x -> Case2 $ Case3 $ WhenIRE x - TypeAsAtom t -> Case2 $ Case4 $ WhenIRE t - {-# INLINE fromE #-} - - toE atom = case atom of - Case0 val -> case val of - Case0 e -> Stuck e - Case1 (WhenIRE (lamExpr)) -> Lam lamExpr - Case2 (l `PairE` r `PairE` ty) -> DepPair l r ty - _ -> error "impossible" - Case1 val -> case val of - Case0 (WhenIRE d) -> DictCon d - Case1 (WhenIRE (c `PairE` x)) -> NewtypeCon c x - Case2 con -> Con con - _ -> error "impossible" - Case2 val -> case val of - Case0 (WhenIRE effs) -> Eff effs - Case1 (LiftE t `PairE` v) -> PtrVar t v - Case2 (WhenIRE rv) -> RepValAtom rv - Case3 (WhenIRE x) -> SimpInCore x - Case4 (WhenIRE t) -> TypeAsAtom t - _ -> error "impossible" - _ -> error "impossible" + LeftE (PairE t x) -> Stuck t x + RightE x -> Con x {-# INLINE toE #-} instance IRRep r => SinkableE (Atom r) @@ -1553,29 +1487,55 @@ instance IRRep r => AlphaHashableE (Atom r) instance IRRep r => RenameE (Atom r) instance IRRep r => GenericE (Stuck r) where - type RepE (Stuck r) = EitherE6 - {- StuckVar -} (AtomVar r) - {- StuckProject -} (Type r `PairE` LiftE Int `PairE` Stuck r) - {- StuckTabApp -} (Type r `PairE` Stuck r `PairE` ListE (Atom r)) - {- StuckUnwrap -} (WhenCore r (CType `PairE` CStuck)) - {- InstantiatedGiven -} (WhenCore r (CType `PairE` CStuck `PairE` ListE CAtom)) - {- SuperclassProj -} (WhenCore r (CType `PairE` LiftE Int `PairE` CStuck)) + type RepE (Stuck r) = EitherE2 + (EitherE6 + {- Var -} (AtomVar r) + {- StuckProject -} (LiftE Int `PairE` Stuck r) + {- StuckTabApp -} (Stuck r `PairE` Atom r) + {- StuckUnwrap -} (WhenCore r (CStuck)) + {- InstantiatedGiven -} (WhenCore r (CStuck `PairE` ListE CAtom)) + {- SuperclassProj -} (WhenCore r (LiftE Int `PairE` CStuck)) + ) (EitherE6 + {- PtrVar -} (LiftE PtrType `PairE` PtrName) + {- RepValAtom -} (WhenSimp r RepVal) + {- LiftSimp -} (WhenCore r (CType `PairE` SStuck)) + {- LiftSimpFun -} (WhenCore r (CorePiType `PairE` LamExpr SimpIR)) + {- TabLam -} (WhenCore r TabLamExpr) + {- ACase -} (WhenCore r (SStuck `PairE` ListE (Abs SBinder CAtom) `PairE` CType)) + ) + fromE = \case - StuckVar v -> Case0 v - StuckProject t i e -> Case1 $ t `PairE` LiftE i `PairE` e - StuckTabApp t f x -> Case2 $ t `PairE` f `PairE` ListE x - StuckUnwrap t e -> Case3 $ WhenIRE $ t `PairE` e - InstantiatedGiven t e xs -> Case4 $ WhenIRE $ t `PairE` e `PairE` ListE xs - SuperclassProj t i e -> Case5 $ WhenIRE $ t `PairE` LiftE i `PairE` e + Var v -> Case0 $ Case0 v + StuckProject i e -> Case0 $ Case1 $ LiftE i `PairE` e + StuckTabApp f x -> Case0 $ Case2 $ f `PairE` x + StuckUnwrap e -> Case0 $ Case3 $ WhenIRE $ e + InstantiatedGiven e xs -> Case0 $ Case4 $ WhenIRE $ e `PairE` ListE xs + SuperclassProj i e -> Case0 $ Case5 $ WhenIRE $ LiftE i `PairE` e + PtrVar t p -> Case1 $ Case0 $ LiftE t `PairE` p + RepValAtom r -> Case1 $ Case1 $ WhenIRE r + LiftSimp t x -> Case1 $ Case2 $ WhenIRE $ t `PairE` x + LiftSimpFun t lam -> Case1 $ Case3 $ WhenIRE $ t `PairE` lam + TabLam lam -> Case1 $ Case4 $ WhenIRE lam + ACase s alts ty -> Case1 $ Case5 $ WhenIRE $ s `PairE` ListE alts `PairE` ty {-# INLINE fromE #-} toE = \case - Case0 v -> StuckVar v - Case1 (t `PairE` LiftE i `PairE` e) -> StuckProject t i e - Case2 (t `PairE` f `PairE` ListE x) -> StuckTabApp t f x - Case3 (WhenIRE (t `PairE` e)) -> StuckUnwrap t e - Case4 (WhenIRE (t `PairE` e `PairE` ListE xs)) -> InstantiatedGiven t e xs - Case5 (WhenIRE (t `PairE` LiftE i `PairE` e)) -> SuperclassProj t i e + Case0 con -> case con of + Case0 v -> Var v + Case1 (LiftE i `PairE` e) -> StuckProject i e + Case2 (f `PairE` x) -> StuckTabApp f x + Case3 (WhenIRE e) -> StuckUnwrap e + Case4 (WhenIRE (e `PairE` ListE xs)) -> InstantiatedGiven e xs + Case5 (WhenIRE (LiftE i `PairE` e)) -> SuperclassProj i e + _ -> error "impossible" + Case1 con -> case con of + Case0 (LiftE t `PairE` p) -> PtrVar t p + Case1 (WhenIRE r) -> RepValAtom r + Case2 (WhenIRE (t `PairE` x)) -> LiftSimp t x + Case3 (WhenIRE (t `PairE` lam)) -> LiftSimpFun t lam + Case4 (WhenIRE lam) -> TabLam lam + Case5 (WhenIRE (s `PairE` ListE alts `PairE` ty)) -> ACase s alts ty + _ -> error "impossible" _ -> error "impossible" {-# INLINE toE #-} @@ -1601,7 +1561,6 @@ instance Eq (AtomVar r n) where instance IRRep r => SinkableE (AtomVar r) instance IRRep r => HoistableE (AtomVar r) - -- We ignore the type annotation because it should be determined by the var instance IRRep r => AlphaEqE (AtomVar r) where alphaEqE (AtomVar v _) (AtomVar v' _) = alphaEqE v v' @@ -1613,34 +1572,14 @@ instance IRRep r => AlphaHashableE (AtomVar r) where instance IRRep r => RenameE (AtomVar r) instance IRRep r => GenericE (Type r) where - type RepE (Type r) = EitherE7 - {- StuckTy -} (WhenCore r CStuck) - {- Pi -} (WhenCore r CorePiType) - {- TabPi -} (TabPiType r) - {- DepPairTy -} (DepPairType r) - {- DictTy -} (WhenCore r DictType) - {- NewtypeTyCon -} (WhenCore r NewtypeTyCon) - {- TC -} (TC r) - + type RepE (Type r) = EitherE (WhenCore r (PairE (Type r) (Stuck r))) (TyCon r) fromE = \case - StuckTy e -> Case0 $ WhenIRE e - Pi t -> Case1 $ WhenIRE t - TabPi t -> Case2 t - DepPairTy t -> Case3 t - DictTy d -> Case4 $ WhenIRE d - NewtypeTyCon t -> Case5 $ WhenIRE t - TC con -> Case6 $ con + StuckTy t x -> LeftE (WhenIRE (PairE t x)) + TyCon x -> RightE x {-# INLINE fromE #-} - toE = \case - Case0 (WhenIRE e) -> StuckTy e - Case1 (WhenIRE t) -> Pi t - Case2 t -> TabPi t - Case3 t -> DepPairTy t - Case4 (WhenIRE d) -> DictTy d - Case5 (WhenIRE t) -> NewtypeTyCon t - Case6 con -> TC con - _ -> error "impossible" + LeftE (WhenIRE (PairE t x)) -> StuckTy t x + RightE x -> TyCon x {-# INLINE toE #-} instance IRRep r => SinkableE (Type r) @@ -1653,21 +1592,21 @@ instance IRRep r => GenericE (Expr r) where type RepE (Expr r) = EitherE2 ( EitherE6 {- App -} (WhenCore r (EffTy r `PairE` Atom r `PairE` ListE (Atom r))) - {- TabApp -} (Type r `PairE` Atom r `PairE` ListE (Atom r)) + {- TabApp -} (Type r `PairE` Atom r `PairE` Atom r) {- Case -} (Atom r `PairE` ListE (Alt r) `PairE` EffTy r) {- Atom -} (Atom r) {- TopApp -} (WhenSimp r (EffTy r `PairE` TopFunName `PairE` ListE (Atom r))) {- Block -} (EffTy r `PairE` Block r) ) ( EitherE5 - {- TabCon -} (MaybeE (WhenCore r Dict) `PairE` Type r `PairE` ListE (Atom r)) + {- TabCon -} (MaybeE (WhenCore r (Dict CoreIR)) `PairE` Type r `PairE` ListE (Atom r)) {- PrimOp -} (PrimOp r) {- ApplyMethod -} (WhenCore r (EffTy r `PairE` Atom r `PairE` LiftE Int `PairE` ListE (Atom r))) {- Project -} (Type r `PairE` LiftE Int `PairE` Atom r) {- Unwrap -} (WhenCore r (CType `PairE` CAtom))) fromE = \case App et f xs -> Case0 $ Case0 (WhenIRE (et `PairE` f `PairE` ListE xs)) - TabApp t f xs -> Case0 $ Case1 (t `PairE` f `PairE` ListE xs) + TabApp t f x -> Case0 $ Case1 (t `PairE` f `PairE` x) Case e alts effTy -> Case0 $ Case2 (e `PairE` ListE alts `PairE` effTy) Atom x -> Case0 $ Case3 (x) TopApp et f xs -> Case0 $ Case4 (WhenIRE (et `PairE` f `PairE` ListE xs)) @@ -1681,7 +1620,7 @@ instance IRRep r => GenericE (Expr r) where toE = \case Case0 case0 -> case case0 of Case0 (WhenIRE (et `PairE` f `PairE` ListE xs)) -> App et f xs - Case1 (t `PairE` f `PairE` ListE xs) -> TabApp t f xs + Case1 (t `PairE` f `PairE` x) -> TabApp t f x Case2 (e `PairE` ListE alts `PairE` effTy) -> Case e alts effTy Case3 (x) -> Atom x Case4 (WhenIRE (et `PairE` f `PairE` ListE xs)) -> TopApp et f xs @@ -1846,27 +1785,49 @@ instance IRRep r => AlphaEqE (MiscOp r) instance IRRep r => AlphaHashableE (MiscOp r) instance IRRep r => RenameE (MiscOp r) -instance GenericOp Con where - type OpConst Con r = Either LitVal P.Con - fromOp = \case - Lit l -> GenericOpRep (Left l) [] [] [] - ProdCon xs -> GenericOpRep (Right P.ProdCon) [] xs [] - SumCon tys i x -> GenericOpRep (Right (P.SumCon i)) tys [x] [] - HeapVal -> GenericOpRep (Right P.HeapVal) [] [] [] - {-# INLINE fromOp #-} - - toOp = \case - GenericOpRep (Left l) [] [] [] -> Just $ Lit l - GenericOpRep (Right P.ProdCon) [] xs [] -> Just $ ProdCon xs - GenericOpRep (Right (P.SumCon i)) tys [x] [] -> Just $ SumCon tys i x - GenericOpRep (Right P.HeapVal) [] [] [] -> Just $ HeapVal - _ -> Nothing - {-# INLINE toOp #-} - instance IRRep r => GenericE (Con r) where - type RepE (Con r) = GenericOpRep (OpConst Con r) r - fromE = fromEGenericOpRep - toE = toEGenericOpRep + type RepE (Con r) = EitherE2 + (EitherE5 + {- Lit -} (LiftE LitVal) + {- ProdCon -} (ListE (Atom r)) + {- SumCon -} (ListE (Type r) `PairE` LiftE Int `PairE` Atom r) + {- HeapVal -} UnitE + {- DepPair -} (Atom r `PairE` Atom r `PairE` DepPairType r)) + (WhenCore r (EitherE5 + {- Lam -} CoreLamExpr + {- Eff -} (EffectRow CoreIR) + {- NewtypeCon -} (NewtypeCon `PairE` CAtom) + {- DictConAtom -} (DictCon CoreIR) + {- TyConAtom -} (TyCon CoreIR))) + fromE = \case + Lit l -> Case0 $ Case0 $ LiftE l + ProdCon xs -> Case0 $ Case1 $ ListE xs + SumCon ts i x -> Case0 $ Case2 $ ListE ts `PairE` LiftE i `PairE` x + HeapVal -> Case0 $ Case3 $ UnitE + DepPair x y t -> Case0 $ Case4 $ x `PairE` y `PairE` t + Lam lam -> Case1 $ WhenIRE $ Case0 lam + Eff effs -> Case1 $ WhenIRE $ Case1 effs + NewtypeCon con x -> Case1 $ WhenIRE $ Case2 $ con `PairE` x + DictConAtom con -> Case1 $ WhenIRE $ Case3 con + TyConAtom tc -> Case1 $ WhenIRE $ Case4 tc + {-# INLINE fromE #-} + toE = \case + Case0 con -> case con of + Case0 (LiftE l) -> Lit l + Case1 (ListE xs) -> ProdCon xs + Case2 (ListE ts `PairE` LiftE i `PairE` x) -> SumCon ts i x + Case3 UnitE -> HeapVal + Case4 (x `PairE` y `PairE` t) -> DepPair x y t + _ -> error "impossible" + Case1 (WhenIRE con) -> case con of + Case0 lam -> Lam lam + Case1 effs -> Eff effs + Case2 (con' `PairE` x) -> NewtypeCon con' x + Case3 con' -> DictConAtom con' + Case4 tc -> TyConAtom tc + _ -> error "impossible" + _ -> error "impossible" + {-# INLINE toE #-} instance IRRep r => SinkableE (Con r) instance IRRep r => HoistableE (Con r) @@ -1874,36 +1835,61 @@ instance IRRep r => AlphaEqE (Con r) instance IRRep r => AlphaHashableE (Con r) instance IRRep r => RenameE (Con r) -instance GenericOp TC where - type OpConst TC r = Either BaseType P.TC - fromOp = \case - BaseType b -> GenericOpRep (Left b) [] [] [] - ProdType ts -> GenericOpRep (Right P.ProdType) ts [] [] - SumType ts -> GenericOpRep (Right P.SumType) ts [] [] - RefType h t -> GenericOpRep (Right P.RefType) [t] [h] [] - TypeKind -> GenericOpRep (Right P.TypeKind) [] [] [] - HeapType -> GenericOpRep (Right P.HeapType) [] [] [] - {-# INLINE fromOp #-} - - toOp = \case - GenericOpRep (Left b) [] [] [] -> Just (BaseType b) - GenericOpRep (Right P.ProdType) ts [] [] -> Just (ProdType ts) - GenericOpRep (Right P.SumType) ts [] [] -> Just (SumType ts) - GenericOpRep (Right P.RefType) [t] [h] [] -> Just (RefType h t) - GenericOpRep (Right P.TypeKind) [] [] [] -> Just TypeKind - GenericOpRep (Right P.HeapType) [] [] [] -> Just HeapType - GenericOpRep _ _ _ _ -> Nothing - {-# INLINE toOp #-} +instance IRRep r => GenericE (TyCon r) where + type RepE (TyCon r) = EitherE3 + (EitherE4 + {- BaseType -} (LiftE BaseType) + {- ProdType -} (ListE (Type r)) + {- SumType -} (ListE (Type r)) + {- RefType -} (Atom r `PairE` Type r)) + (EitherE4 + {- HeapType -} UnitE + {- TabPi -} (TabPiType r) + {- DepPairTy -} (DepPairType r) + {- TypeKind -} (WhenCore r UnitE)) + (EitherE3 + {- DictTy -} (WhenCore r DictType) + {- Pi -} (WhenCore r CorePiType) + {- NewtypeTyCon -} (WhenCore r NewtypeTyCon)) + fromE = \case + BaseType b -> Case0 (Case0 (LiftE b)) + ProdType ts -> Case0 (Case1 (ListE ts)) + SumType ts -> Case0 (Case2 (ListE ts)) + RefType h t -> Case0 (Case3 (h `PairE` t)) + HeapType -> Case1 (Case0 UnitE) + TabPi t -> Case1 (Case1 t) + DepPairTy t -> Case1 (Case2 t) + TypeKind -> Case1 (Case3 (WhenIRE UnitE)) + DictTy t -> Case2 (Case0 (WhenIRE t)) + Pi t -> Case2 (Case1 (WhenIRE t)) + NewtypeTyCon t -> Case2 (Case2 (WhenIRE t)) + {-# INLINE fromE #-} + toE = \case + Case0 c -> case c of + Case0 (LiftE b ) -> BaseType b + Case1 (ListE ts) -> ProdType ts + Case2 (ListE ts) -> SumType ts + Case3 (h `PairE` t) -> RefType h t + _ -> error "impossible" + Case1 c -> case c of + Case0 UnitE -> HeapType + Case1 t -> TabPi t + Case2 t -> DepPairTy t + Case3 (WhenIRE UnitE) -> TypeKind + _ -> error "impossible" + Case2 c -> case c of + Case0 (WhenIRE t) -> DictTy t + Case1 (WhenIRE t) -> Pi t + Case2 (WhenIRE t) -> NewtypeTyCon t + _ -> error "impossible" + _ -> error "impossible" + {-# INLINE toE #-} -instance IRRep r => GenericE (TC r) where - type RepE (TC r) = GenericOpRep (OpConst TC r) r - fromE = fromEGenericOpRep - toE = toEGenericOpRep -instance IRRep r => SinkableE (TC r) -instance IRRep r => HoistableE (TC r) -instance IRRep r => AlphaEqE (TC r) -instance IRRep r => AlphaHashableE (TC r) -instance IRRep r => RenameE (TC r) +instance IRRep r => SinkableE (TyCon r) +instance IRRep r => HoistableE (TyCon r) +instance IRRep r => AlphaEqE (TyCon r) +instance IRRep r => AlphaHashableE (TyCon r) +instance IRRep r => RenameE (TyCon r) instance IRRep r => GenericB (NonDepNest r ann) where type RepB (NonDepNest r ann) = (LiftB (ListE ann)) `PairB` Nest (AtomNameBinder r) @@ -1922,13 +1908,13 @@ deriving instance (Show (ann n)) => IRRep r => Show (NonDepNest r ann n l) instance GenericE ClassDef where type RepE ClassDef = - LiftE (SourceName, [SourceName], [Maybe SourceName], [RoleExpl]) + LiftE (SourceName, Maybe BuiltinClassName, [SourceName], [Maybe SourceName], [RoleExpl]) `PairE` Abs (Nest CBinder) (Abs (Nest CBinder) (ListE CorePiType)) - fromE (ClassDef name names paramNames roleExpls b scs tys) = - LiftE (name, names, paramNames, roleExpls) `PairE` Abs b (Abs scs (ListE tys)) + fromE (ClassDef name builtin names paramNames roleExpls b scs tys) = + LiftE (name, builtin, names, paramNames, roleExpls) `PairE` Abs b (Abs scs (ListE tys)) {-# INLINE fromE #-} - toE (LiftE (name, names, paramNames, roleExpls) `PairE` Abs b (Abs scs (ListE tys))) = - ClassDef name names paramNames roleExpls b scs tys + toE (LiftE (name, builtin, names, paramNames, roleExpls) `PairE` Abs b (Abs scs (ListE tys))) = + ClassDef name builtin names paramNames roleExpls b scs tys {-# INLINE toE #-} instance SinkableE ClassDef @@ -1939,7 +1925,7 @@ instance RenameE ClassDef deriving instance Show (ClassDef n) deriving via WrapE ClassDef n instance Generic (ClassDef n) instance HasSourceName (ClassDef n) where - getSourceName = \case ClassDef name _ _ _ _ _ _ -> name + getSourceName = \case ClassDef name _ _ _ _ _ _ _ -> name instance GenericE InstanceDef where type RepE InstanceDef = @@ -1969,11 +1955,19 @@ instance AlphaHashableE InstanceBody instance RenameE InstanceBody instance GenericE DictType where - type RepE DictType = LiftE SourceName `PairE` ClassName `PairE` ListE CAtom - fromE (DictType sourceName className params) = - LiftE sourceName `PairE` className `PairE` ListE params - toE (LiftE sourceName `PairE` className `PairE` ListE params) = - DictType sourceName className params + type RepE DictType = EitherE3 + {- DictType -} (LiftE SourceName `PairE` ClassName `PairE` ListE CAtom) + {- IxDictType -} CType + {- DataDictType -} CType + fromE = \case + DictType sourceName className params -> Case0 $ LiftE sourceName `PairE` className `PairE` ListE params + IxDictType ty -> Case1 ty + DataDictType ty -> Case2 ty + toE = \case + Case0 (LiftE sourceName `PairE` className `PairE` ListE params) -> DictType sourceName className params + Case1 ty -> IxDictType ty + Case2 ty -> DataDictType ty + _ -> error "impossible" instance SinkableE DictType instance HoistableE DictType @@ -1981,26 +1975,49 @@ instance AlphaEqE DictType instance AlphaHashableE DictType instance RenameE DictType -instance GenericE DictCon where - type RepE DictCon = EitherE3 - {- InstanceDict -} (CType `PairE` PairE InstanceName (ListE CAtom)) - {- IxFin -} (CType `PairE` CAtom) - {- DataData -} (CType `PairE` CType) - fromE d = case d of - InstanceDict t v args -> Case0 $ t `PairE` PairE v (ListE args) - IxFin t x -> Case1 $ t `PairE` x - DataData t ty -> Case2 $ t `PairE` ty - toE d = case d of - Case0 (t `PairE` (PairE v (ListE args))) -> InstanceDict t v args - Case1 (t `PairE` x) -> IxFin t x - Case2 (t `PairE` ty) -> DataData t ty +instance IRRep r => GenericE (Dict r) where + type RepE (Dict r) = EitherE (WhenCore r (PairE (Type r) (Stuck r))) (DictCon r) + fromE = \case + StuckDict t d -> LeftE (WhenIRE (PairE t d)) + DictCon d -> RightE d + {-# INLINE fromE #-} + toE = \case + LeftE (WhenIRE (PairE t d)) -> StuckDict t d + RightE d -> DictCon d + {-# INLINE toE #-} + +instance IRRep r => SinkableE (Dict r) +instance IRRep r => HoistableE (Dict r) +instance IRRep r => AlphaEqE (Dict r) +instance IRRep r => AlphaHashableE (Dict r) +instance IRRep r => RenameE (Dict r) + +instance IRRep r => GenericE (DictCon r) where + type RepE (DictCon r) = EitherE5 + {- InstanceDict -} (WhenCore r (CType `PairE` PairE InstanceName (ListE CAtom))) + {- IxFin -} (WhenCore r CAtom) + {- DataData -} (WhenCore r CType) + {- IxRawFin -} (Atom r) + {- IxSpecialized -} (WhenSimp r (SpecDictName `PairE` ListE SAtom)) + fromE = \case + InstanceDict t v args -> Case0 $ WhenIRE $ t `PairE` PairE v (ListE args) + IxFin x -> Case1 $ WhenIRE $ x + DataData ty -> Case2 $ WhenIRE $ ty + IxRawFin n -> Case3 $ n + IxSpecialized d xs -> Case4 $ WhenIRE $ d `PairE` ListE xs + toE = \case + Case0 (WhenIRE (t `PairE` (PairE v (ListE args)))) -> InstanceDict t v args + Case1 (WhenIRE x) -> IxFin x + Case2 (WhenIRE ty) -> DataData ty + Case3 n -> IxRawFin n + Case4 (WhenIRE (d `PairE` ListE xs)) -> IxSpecialized d xs _ -> error "impossible" -instance SinkableE DictCon -instance HoistableE DictCon -instance AlphaEqE DictCon -instance AlphaHashableE DictCon -instance RenameE DictCon +instance IRRep r => SinkableE (DictCon r) +instance IRRep r => HoistableE (DictCon r) +instance IRRep r => AlphaEqE (DictCon r) +instance IRRep r => AlphaHashableE (DictCon r) +instance IRRep r => RenameE (DictCon r) instance GenericE Cache where type RepE Cache = @@ -2068,8 +2085,6 @@ instance HoistableE CoreLamExpr instance AlphaEqE CoreLamExpr instance AlphaHashableE CoreLamExpr instance RenameE CoreLamExpr -deriving instance Show (CoreLamExpr n) -deriving via WrapE CoreLamExpr n instance Generic (CoreLamExpr n) instance GenericE CorePiType where type RepE CorePiType = LiftE (AppExplicitness, [Explicitness]) `PairE` Abs (Nest CBinder) (EffTy CoreIR) @@ -2086,30 +2101,6 @@ instance RenameE CorePiType deriving instance Show (CorePiType n) deriving via WrapE CorePiType n instance Generic (CorePiType n) -instance IRRep r => GenericE (IxDict r) where - type RepE (IxDict r) = - EitherE3 - (WhenCore r (Atom r)) - (Atom r) - (WhenSimp r (Type r `PairE` SpecDictName `PairE` ListE (Atom r))) - fromE = \case - IxDictAtom x -> Case0 $ WhenIRE x - IxDictRawFin n -> Case1 $ n - IxDictSpecialized t d xs -> Case2 $ WhenIRE $ t `PairE` d `PairE` ListE xs - {-# INLINE fromE #-} - toE = \case - Case0 (WhenIRE x) -> IxDictAtom x - Case1 (n) -> IxDictRawFin n - Case2 (WhenIRE (t `PairE` d `PairE` ListE xs)) -> IxDictSpecialized t d xs - _ -> error "impossible" - {-# INLINE toE #-} - -instance IRRep r => SinkableE (IxDict r) -instance IRRep r => HoistableE (IxDict r) -instance IRRep r => RenameE (IxDict r) -instance IRRep r => AlphaEqE (IxDict r) -instance IRRep r => AlphaHashableE (IxDict r) - instance IRRep r => GenericE (IxType r) where type RepE (IxType r) = PairE (Type r) (IxDict r) fromE (IxType ty d) = PairE ty d @@ -2183,11 +2174,12 @@ deriving via WrapE (DepPairType r) n instance IRRep r => Generic (DepPairType r instance GenericE SynthCandidates where type RepE SynthCandidates = ListE (PairE ClassName (ListE InstanceName)) - fromE (SynthCandidates ys) = ListE ys' - where ys' = map (\(k,vs) -> PairE k (ListE vs)) (M.toList ys) + `PairE` ListE InstanceName + fromE (SynthCandidates xs ys) = ListE xs' `PairE` ListE ys + where xs' = map (\(k,vs) -> PairE k (ListE vs)) (M.toList xs) {-# INLINE fromE #-} - toE (ListE ys) = SynthCandidates ys' - where ys' = M.fromList $ map (\(PairE k (ListE vs)) -> (k,vs)) ys + toE (ListE xs `PairE` ListE ys) = SynthCandidates xs' ys + where xs' = M.fromList $ map (\(PairE k (ListE vs)) -> (k,vs)) xs {-# INLINE toE #-} instance SinkableE SynthCandidates @@ -2204,7 +2196,7 @@ instance IRRep r => GenericE (AtomBinding r) where (WhenCore r SolverBinding) -- SolverBound ) (EitherE3 (WhenCore r (PairE CType CAtom)) -- NoinlineFun - (WhenSimp r (RepVal SimpIR)) -- TopDataBound + (WhenSimp r RepVal) -- TopDataBound (WhenCore r (CorePiType `PairE` TopFunName)) -- FFIFunBound ) @@ -2520,10 +2512,11 @@ instance IRRep r => BindsOneName (Decl r) (AtomNameC r) where {-# INLINE binderName #-} instance Semigroup (SynthCandidates n) where - SynthCandidates xs <> SynthCandidates xs' = SynthCandidates (M.unionWith (<>) xs xs') + SynthCandidates xs ys <> SynthCandidates xs' ys' = + SynthCandidates (M.unionWith (<>) xs xs') (ys <> ys') instance Monoid (SynthCandidates n) where - mempty = SynthCandidates mempty + mempty = SynthCandidates mempty mempty instance GenericB EnvFrag where type RepB EnvFrag = RecSubstFrag Binding @@ -2787,22 +2780,22 @@ instance Monoid (LoadedObjects n) where instance Hashable InfVarDesc instance Hashable IxMethod instance Hashable ParamRole +instance Hashable BuiltinClassName instance Hashable a => Hashable (EvalStatus a) instance IRRep r => Store (MiscOp r n) instance IRRep r => Store (VectorOp r n) instance IRRep r => Store (MemOp r n) -instance IRRep r => Store (TC r n) +instance IRRep r => Store (TyCon r n) instance IRRep r => Store (Con r n) instance IRRep r => Store (PrimOp r n) -instance IRRep r => Store (RepVal r n) +instance Store (RepVal n) instance IRRep r => Store (Type r n) instance IRRep r => Store (EffTy r n) instance IRRep r => Store (Stuck r n) instance IRRep r => Store (Atom r n) instance IRRep r => Store (AtomVar r n) instance IRRep r => Store (Expr r n) -instance Store (SimpInCore n) instance Store (SolverBinding n) instance IRRep r => Store (AtomBinding r n) instance Store (SpecializationSpec n) @@ -2821,11 +2814,12 @@ instance Store (CoreLamExpr n) instance IRRep r => Store (TabPiType r n) instance IRRep r => Store (DepPairType r n) instance Store (AtomRules n) +instance Store BuiltinClassName instance Store (ClassDef n) instance Store (InstanceDef n) instance Store (InstanceBody n) instance Store (DictType n) -instance Store (DictCon n) +instance IRRep r => Store (DictCon r n) instance Store (EffectDef n) instance Store (EffectOpDef n) instance Store (EffectOpType n) @@ -2845,12 +2839,12 @@ instance Store InfVarDesc instance Store IxMethod instance Store ParamRole instance Store (SpecializedDictDef n) +instance IRRep r => Store (Dict r n) instance IRRep r => Store (TypedHof r n) instance IRRep r => Store (Hof r n) instance IRRep r => Store (RefOp r n) instance IRRep r => Store (BaseMonoid r n) instance IRRep r => Store (DAMOp r n) -instance IRRep r => Store (IxDict r n) instance Store (NewtypeCon n) instance Store (NewtypeTyCon n) instance Store (DotMethods n) diff --git a/src/lib/Vectorize.hs b/src/lib/Vectorize.hs index 3be5058b1..2405da77d 100644 --- a/src/lib/Vectorize.hs +++ b/src/lib/Vectorize.hs @@ -178,7 +178,7 @@ vectorizeLoopsExpr expr = do let vn = n `div` loopWidth body' <- vectorizeSeq loopWidth ixty body dest' <- renameM dest - emitExpr =<< mkSeq dir (IxType IdxRepTy (IxDictRawFin (IdxRepVal vn))) dest' body') + emitExpr =<< mkSeq dir (IxType IdxRepTy (DictCon (IxRawFin (IdxRepVal vn)))) dest' body') else renameM expr >>= emitExpr) `catchErr` \err -> do let msg = "In `vectorizeLoopsDecls`:\nExpr:\n" ++ pprint expr @@ -256,7 +256,7 @@ isAdditionMonoid monoid = do BaseMonoid { baseEmpty = (Con (Lit l)) , baseCombine = BinaryLamExpr (b1:>_) (b2:>_) body } <- Just monoid unless (_isZeroLit l) Nothing - PrimOp (BinOp op (Var b1') (Var b2')) <- return body + PrimOp (BinOp op (Stuck _ (Var b1')) (Stuck _ (Var b2'))) <- return body unless (op `elem` [P.IAdd, P.FAdd]) Nothing case (binderName b1, atomVarName b1', binderName b2, atomVarName b2') of -- Checking the raw names here because (i) I don't know how to convince the @@ -300,7 +300,7 @@ vectorSafeEffect (EffectRow effs NoTail) = allM safe $ eSetToList effs where safe :: Effect SimpIR i -> TopVectorizeM i o Bool safe InitEffect = return True safe (RWSEffect Reader _) = return True - safe (RWSEffect Writer (Var h)) = do + safe (RWSEffect Writer (Stuck _ (Var h))) = do h' <- renameM $ atomVarName h commuteMap <- ask case lookupNameMapE h' commuteMap of @@ -313,15 +313,15 @@ vectorizeSeq :: Word32 -> IxType SimpIR i -> LamExpr SimpIR i -> TopVectorizeM i o (LamExpr SimpIR o) vectorizeSeq loopWidth ixty (UnaryLamExpr (b:>ty) body) = do newLoopTy <- case ty of - ProdTy [_ixType, ref] -> do + TyCon (ProdType [_ixType, ref]) -> do ref' <- renameM ref - return $ ProdTy [IdxRepTy, ref'] + return $ TyCon $ ProdType [IdxRepTy, ref'] _ -> error "Unexpected seq binder type" ixty' <- renameM ixty liftVectorizeM loopWidth $ buildUnaryLamExpr (getNameHint b) newLoopTy \ci -> do -- The per-tile loop iterates on `Fin` - (viOrd, dest) <- fromPair $ Var ci + (viOrd, dest) <- fromPair $ toAtom ci iOrd <- imul viOrd $ IdxRepVal loopWidth -- TODO: It would be nice to cancel this UnsafeFromOrdinal with the -- Ordinal that will be taken later when indexing, but that should @@ -362,13 +362,13 @@ vectorizeLamExpr (LamExpr bs body) argStabilities = case (bs, argStabilities) of LamExpr Empty <$> buildBlock (do vectorizeExpr body >>= \case (VVal _ ans) -> return ans - (VRename v) -> Var <$> toAtomVar v) + (VRename v) -> toAtom <$> toAtomVar v) (Nest (b:>ty) rest, (stab:stabs)) -> do ty' <- vectorizeType ty ty'' <- promoteTypeByStability ty' stab withFreshBinder (getNameHint b) ty'' \b' -> do var <- toAtomVar $ binderName b' - extendSubst (b @> VVal stab (Var var)) do + extendSubst (b @> VVal stab (toAtom var)) do LamExpr rest' body' <- vectorizeLamExpr (LamExpr rest body) stabs return $ LamExpr (Nest b' rest') body' _ -> error "Zip error" @@ -377,7 +377,7 @@ vectorizeExpr :: Emits o => SExpr i -> VectorizeM i o (VAtom o) vectorizeExpr expr = addVectErrCtx "vectorizeExpr" ("Expr:\n" ++ pprint expr) do case expr of Block _ block -> vectorizeBlock block - TabApp _ tbl [ix] -> do + TabApp _ tbl ix -> do VVal Uniform tbl' <- vectorizeAtom tbl VVal Contiguous ix' <- vectorizeAtom ix case getType tbl' of @@ -442,7 +442,7 @@ vectorizeRefOp ref' op = VVal Uniform ref <- vectorizeAtom ref' VVal Contiguous i <- vectorizeAtom i' case getType ref of - TC (RefType _ (TabTy _ tb a)) -> do + TyCon (RefType _ (TabTy _ tb a)) -> do vty <- getVectorType =<< case hoist tb a of HoistSuccess a' -> return a' HoistFailure _ -> throwVectErr "Can't vectorize dependent table application" @@ -525,23 +525,29 @@ vectorizeType t = do vectorizeAtom :: SAtom i -> VectorizeM i o (VAtom o) vectorizeAtom atom = addVectErrCtx "vectorizeAtom" ("Atom:\n" ++ pprint atom) do case atom of - Stuck e -> case e of - StuckVar v -> lookupSubstM (atomVarName v) >>= \case - VRename v' -> VVal Uniform . Var <$> toAtomVar v' - v' -> return v' - StuckProject _ i x -> do - VVal vv x' <- vectorizeAtom (Stuck x) - ov <- case vv of - ProdStability sbs -> return $ sbs !! i - _ -> throwVectErr "Invalid projection" - x'' <- reduceProj i x' - return $ VVal ov x'' - -- TODO: think about this case - StuckTabApp _ _ _ -> throwVectErr $ "Cannot vectorize atom: " ++ pprint atom - Con (Lit l) -> return $ VVal Uniform $ Con $ Lit l - _ -> do - subst <- getSubst - VVal Uniform <$> fmapNamesM (uniformSubst subst) atom + Stuck _ e -> vectorizeStuck e + Con con -> case con of + Lit l -> return $ VVal Uniform $ Con $ Lit l + _ -> do + subst <- getSubst + VVal Uniform <$> fmapNamesM (uniformSubst subst) atom + +vectorizeStuck :: SStuck i -> VectorizeM i o (VAtom o) +vectorizeStuck = \case + Var v -> lookupSubstM (atomVarName v) >>= \case + VRename v' -> VVal Uniform . toAtom <$> toAtomVar v' + v' -> return v' + StuckProject i x -> do + VVal vv x' <- vectorizeStuck x + ov <- case vv of + ProdStability sbs -> return $ sbs !! i + _ -> throwVectErr "Invalid projection" + x'' <- reduceProj i x' + return $ VVal ov x'' + -- TODO: think about this case + StuckTabApp _ _ -> throwVectErr $ "Cannot vectorize atom" + PtrVar _ _ -> throwVectErr $ "Cannot vectorize atom" + RepValAtom _ -> throwVectErr $ "Cannot vectorize atom" uniformSubst :: Color c => Subst VSubstValC i o -> Name c i -> AtomSubstVal c o uniformSubst subst n = case subst ! n of @@ -577,7 +583,7 @@ ensureVarying (VVal s val) = case s of _ -> throwVectErr "Not implemented" ProdStability _ -> throwVectErr "Not implemented" ensureVarying (VRename v) = do - x <- Var <$> toAtomVar v + x <- toAtom <$> toAtomVar v ensureVarying (VVal Uniform x) promoteTypeByStability :: SType o -> Stability -> VectorizeM i o (SType o) @@ -586,7 +592,7 @@ promoteTypeByStability ty = \case Contiguous -> return ty Varying -> getVectorType ty ProdStability stabs -> case ty of - ProdTy elts -> ProdTy <$> zipWithZ promoteTypeByStability elts stabs + TyCon (ProdType elts) -> TyCon <$> ProdType <$> zipWithZ promoteTypeByStability elts stabs _ -> throw ZipErr "Type and stability" -- === computing byte widths === @@ -618,7 +624,7 @@ instance ExprVisitorNoEmits (CalcWidthM i o) SimpIR i o where typeByteWidth :: SType n -> Word32 typeByteWidth ty = case ty of - TC (BaseType bt) -> case bt of + BaseTy bt -> case bt of -- Currently only support vectorization of scalar types (cf. `getVectorType` above): Scalar _ -> fromInteger . toInteger $ sizeOf bt _ -> maxWord32 From ac0f6c8f4f726eca58037d25a654adc6c419bd81 Mon Sep 17 00:00:00 2001 From: Dougal Date: Wed, 8 Nov 2023 14:02:24 -0500 Subject: [PATCH 08/41] Fix a couple of bugs --- lib/prelude.dx | 3 +++ src/lib/CheapReduction.hs | 8 +++++--- src/lib/ConcreteSyntax.hs | 13 +++++++++---- src/lib/Inference.hs | 7 ++++--- 4 files changed, 21 insertions(+), 10 deletions(-) diff --git a/lib/prelude.dx b/lib/prelude.dx index 243e09e03..badf5e39e 100644 --- a/lib/prelude.dx +++ b/lib/prelude.dx @@ -32,6 +32,7 @@ interface Data(a:Type) '### Casting +@inline def internal_cast(x:from) -> to given (from:Type, to:Type) = %cast(to, x) @@ -59,6 +60,7 @@ Nat = %Nat() NatRep = Word32 def nat_to_rep(x : Nat) -> NatRep = %projNewtype(x) +@inline def rep_to_nat(x : NatRep) -> Nat = %NatCon(x) def n_to_w8(x: Nat) -> Word8 = nat_to_rep x | internal_cast @@ -72,6 +74,7 @@ def n_to_f(x: Nat) -> Float = nat_to_rep x | internal_cast def w8_to_n(x : Word8) -> Nat = internal_cast x | rep_to_nat def w32_to_n(x : Word32) -> Nat = internal_cast x | rep_to_nat +@inline def w64_to_n(x : Word64) -> Nat = internal_cast x | rep_to_nat def i32_to_n(x : Int32) -> Nat = internal_cast x | rep_to_nat def i64_to_n(x : Int64) -> Nat = internal_cast x | rep_to_nat diff --git a/src/lib/CheapReduction.hs b/src/lib/CheapReduction.hs index 8fc403f03..4feafff29 100644 --- a/src/lib/CheapReduction.hs +++ b/src/lib/CheapReduction.hs @@ -171,14 +171,16 @@ queryStuckType = \case Var v -> return $ getType v StuckProject i s -> projType i =<< mkStuck s StuckTabApp f x -> do - f' <- mkStuck f - typeOfTabApp (getType f') x + fTy <- queryStuckType f + typeOfTabApp fTy x PtrVar t _ -> return $ PtrTy t RepValAtom repVal -> return $ getType repVal StuckUnwrap s -> queryStuckType s >>= \case TyCon (NewtypeTyCon con) -> snd <$> unwrapNewtypeType con _ -> error "not a newtype" - InstantiatedGiven _ _ -> undefined + InstantiatedGiven f xs -> do + fTy <- queryStuckType f + typeOfApp fTy xs SuperclassProj i s -> superclassProjType i =<< queryStuckType s LiftSimp t _ -> return t LiftSimpFun t _ -> return $ TyCon $ Pi t diff --git a/src/lib/ConcreteSyntax.hs b/src/lib/ConcreteSyntax.hs index 32eabb071..2bae7ea7a 100644 --- a/src/lib/ConcreteSyntax.hs +++ b/src/lib/ConcreteSyntax.hs @@ -262,7 +262,7 @@ structDef = do funDefLetWithAnn :: Parser (LetAnn, CDef) funDefLetWithAnn = do - ann <- noInline <|> return PlainLet + ann <- topLetAnn <|> return PlainLet def <- funDefLet return (ann, def) @@ -306,12 +306,17 @@ topLetOrExpr = withSrc topLet >>= \case topLet :: Parser CTopDecl' topLet = do - lAnn <- noInline <|> return PlainLet + lAnn <- topLetAnn <|> return PlainLet decl <- cDecl return $ CSDecl lAnn decl -noInline :: Parser LetAnn -noInline = (char '@' >> string "noinline" $> NoInlineLet) <* nextLine +topLetAnn :: Parser LetAnn +topLetAnn = do + void $ char '@' + ann <- (string "inline" $> InlineLet) + <|> (string "noinline" $> NoInlineLet) + nextLine + return ann onePerLine :: Parser a -> Parser [a] onePerLine p = liftM (:[]) p diff --git a/src/lib/Inference.hs b/src/lib/Inference.hs index eae234af6..cf0422e04 100644 --- a/src/lib/Inference.hs +++ b/src/lib/Inference.hs @@ -628,6 +628,7 @@ withUDecl (WithSrcB src d) cont = addSrcContext src case d of considerInlineAnn :: LetAnn -> CType n -> LetAnn considerInlineAnn PlainLet TyKind = InlineLet +considerInlineAnn PlainLet (TyCon (Pi (CorePiType _ _ _ (EffTy Pure TyKind)))) = InlineLet considerInlineAnn ann _ = ann applyFromLiteralMethod @@ -1970,8 +1971,8 @@ isSkolemName v = lookupEnv v >>= \case _ -> return False {-# INLINE isSkolemName #-} -renameForPrinting :: (EnvReader m, HasNamesE e) - => e n -> m n (Abs (Nest (AtomNameBinder CoreIR)) e n) +renameForPrinting :: EnvReader m + => (PairE CAtom CAtom n) -> m n (Abs (Nest (AtomNameBinder CoreIR)) (PairE CAtom CAtom) n) renameForPrinting e = do infVars <- filterM isSolverName $ freeAtomVarsList e let ab = abstractFreeVarsNoAnn infVars e @@ -2202,7 +2203,7 @@ synthDictFromGiven targetTy = do SynthDictType givenDictTy -> do guard =<< alphaEq targetTy givenDictTy return given - SynthPiType givenPiTy -> do + SynthPiType givenPiTy -> typeErrAsSearchFailure do args <- instantiateSynthArgs targetTy givenPiTy reduceInstantiateGiven given args From edb39df4d029676cfb7760258137cbc8a9967511 Mon Sep 17 00:00:00 2001 From: Dougal Date: Wed, 8 Nov 2023 15:29:06 -0500 Subject: [PATCH 09/41] Fill in some missing cases in cheap reduction --- src/lib/CheapReduction.hs | 56 ++++++++++++++++++++++++++++++++++----- src/lib/Simplify.hs | 37 +++----------------------- src/lib/Types/Core.hs | 2 +- 3 files changed, 54 insertions(+), 41 deletions(-) diff --git a/src/lib/CheapReduction.hs b/src/lib/CheapReduction.hs index 4feafff29..c70815282 100644 --- a/src/lib/CheapReduction.hs +++ b/src/lib/CheapReduction.hs @@ -15,7 +15,8 @@ module CheapReduction , visitBinders, visitPiDefault, visitAlt, toAtomVar, instantiate, withInstantiated , bindersToVars, bindersToAtoms, instantiateNames, withInstantiatedNames, assumeConst , repValAtom, reduceUnwrap, reduceProj, reduceSuperclassProj, typeOfApp - , reduceInstantiateGiven, queryStuckType, substMStuck, reduceTabApp, substStuck) + , reduceInstantiateGiven, queryStuckType, substMStuck, reduceTabApp, substStuck + , liftSimpAtom, reduceACase) where import Control.Applicative @@ -61,6 +62,10 @@ reduceProj :: (IRRep r, EnvReader m) => Int -> Atom r n -> m n (Atom r n) reduceProj i x = liftM fromJust $ liftReducerM $ reduceProjM i x {-# INLINE reduceProj #-} +reduceACase :: EnvReader m => SAtom n -> [Abs SBinder CAtom n] -> CType n -> m n (CAtom n) +reduceACase scrut alts resultTy = liftM fromJust $ liftReducerM $ reduceACaseM scrut alts resultTy +{-# INLINE reduceACase #-} + reduceUnwrap :: EnvReader m => CAtom n -> m n (CAtom n) reduceUnwrap x = liftM fromJust $ liftReducerM $ reduceUnwrapM x {-# INLINE reduceUnwrap #-} @@ -131,6 +136,14 @@ reduceApp f xs = do Con (Lam lam) -> dropSubst $ withInstantiated lam xs \body -> reduceExprM body _ -> empty +reduceACaseM :: SAtom n -> [Abs SBinder CAtom n] -> CType n -> ReducerM i n (CAtom n) +reduceACaseM scrut alts resultTy = case scrut of + Con (SumCon _ i arg) -> do + Abs b body <- return $ alts !! i + applySubst (b@>SubstVal arg) body + Con _ -> error "not a sum type" + Stuck _ scrut' -> mkStuck $ ACase scrut' alts resultTy + reduceProjM :: IRRep r => Int -> Atom r o -> ReducerM i o (Atom r o) reduceProjM i x = case x of Con con -> case con of @@ -183,10 +196,10 @@ queryStuckType = \case typeOfApp fTy xs SuperclassProj i s -> superclassProjType i =<< queryStuckType s LiftSimp t _ -> return t - LiftSimpFun t _ -> return $ TyCon $ Pi t + LiftSimpFun t _ -> return $ toType t -- TabLam and ACase are just defunctionalization tools. The result type -- in both cases should *not* be `Data`. - TabLam _ -> undefined + TabLam (PairE t _) -> return $ toType t ACase _ _ resultTy -> return resultTy projType :: (IRRep r, EnvReader m) => Int -> Atom r n -> m n (Type r n) @@ -657,10 +670,39 @@ reduceStuck = \case reduceSuperclassProjM superclassIx child' PtrVar ptrTy ptr -> mkStuck =<< PtrVar ptrTy <$> substM ptr RepValAtom repVal -> mkStuck =<< RepValAtom <$> substM repVal - LiftSimp _ _ -> undefined - LiftSimpFun _ _ -> undefined - TabLam _ -> undefined - ACase _ _ _ -> undefined + LiftSimp t s -> do + t' <- substM t + s' <- reduceStuck s + liftSimpAtom t' s' + LiftSimpFun t f -> mkStuck =<< (LiftSimpFun <$> substM t <*> substM f) + TabLam lam -> mkStuck =<< (TabLam <$> substM lam) + ACase scrut alts resultTy -> do + scrut' <- reduceStuck scrut + resultTy' <- substM resultTy + alts' <- mapM substM alts + reduceACaseM scrut' alts' resultTy' + +liftSimpAtom :: EnvReader m => Type CoreIR n -> SAtom n -> m n (CAtom n) +liftSimpAtom (StuckTy _ _) _ = error "Can't lift stuck type" +liftSimpAtom ty@(TyCon tyCon) simpAtom = case simpAtom of + Stuck _ stuck -> return $ Stuck ty $ LiftSimp ty stuck + Con con -> Con <$> case (tyCon, con) of + (NewtypeTyCon newtypeCon, _) -> do + (dataCon, repTy) <- unwrapNewtypeType newtypeCon + cAtom <- rec repTy (Con con) + return $ NewtypeCon dataCon cAtom + (BaseType _ , Lit v) -> return $ Lit v + (ProdType tys, ProdCon xs) -> ProdCon <$> zipWithM rec tys xs + (SumType tys, SumCon _ i x) -> SumCon tys i <$> rec (tys!!i) x + (DepPairTy dpt@(DepPairType _ (b:>t1) t2), DepPair x1 x2 _) -> do + x1' <- rec t1 x1 + t2' <- applySubst (b@>SubstVal x1') t2 + x2' <- rec t2' x2 + return $ DepPair x1' x2' dpt + _ -> error $ "can't lift " <> pprint simpAtom <> " to " <> pprint ty + where + rec = liftSimpAtom +{-# INLINE liftSimpAtom #-} instance IRRep r => SubstE AtomSubstVal (EffectRow r) where substE env (EffectRow effs tailVar) = do diff --git a/src/lib/Simplify.hs b/src/lib/Simplify.hs index e98d977fa..239641ede 100644 --- a/src/lib/Simplify.hs +++ b/src/lib/Simplify.hs @@ -485,7 +485,7 @@ defuncCase scrut resultTy cont = do (dataVal, sumVal) <- fromPair caseResult reconAlts <- forM (zip closureTys recons) \(ty, recon) -> buildAbs noHint ty \v -> applyRecon (sink recon) (toAtom v) - nonDataVal <- mkACase sumVal reconAlts newNonDataTy + nonDataVal <- reduceACase sumVal reconAlts newNonDataTy Distinct <- getDistinct fromSplit split dataVal nonDataVal @@ -798,7 +798,8 @@ simplifyHof resultTy = \case ab <- buildAbs noHint ixTy' \i -> do xs <- unpackTelescope bsClosure =<< reduceTabApp (sink ans) (toAtom i) applySubst (bIx@>Rename (atomVarName i) <.> bsClosure @@> map SubstVal xs) reconResult - mkStuck $ TabLam $ IxType ixTy' ixDict' `PairE` ab + TyCon (TabPi resultTy') <- return resultTy + mkStuck $ TabLam $ resultTy' `PairE` ab While body -> do SimplifiedBlock body' (CoerceRecon _) <- buildSimplifiedBlock $ simplifyExpr body result <- emitHof $ While body' @@ -880,15 +881,7 @@ fmapMaybe scrut f = do return (Abs b result', resultTy) nothingAlt <- buildAbs noHint UnitTy \_ -> preludeNothingVal $ sink resultJustTy resultMaybeTy <- makePreludeMaybeTy resultJustTy - mkACase scrut [nothingAlt, justAlt] resultMaybeTy - -mkACase :: SAtom n -> [Abs SBinder CAtom n] -> CType n -> SimplifyM i n (CAtom n) -mkACase scrut alts resultTy = case scrut of - Con (SumCon _ i arg) -> do - Abs b body <- return $ alts !! i - applySubst (b@>SubstVal arg) body - Con _ -> error "not a sum type" - Stuck _ scrut' -> mkStuck $ ACase scrut' alts resultTy + reduceACase scrut [nothingAlt, justAlt] resultMaybeTy -- This is wrong! The correct implementation is below. And yet there's some -- compensatory bug somewhere that means that the wrong answer works and the @@ -911,28 +904,6 @@ preludeMaybeNewtypeCon ty = do let params = TyConParams [Explicit] [toAtom ty] return $ UserADTData sn tyConName params -liftSimpAtom :: EnvReader m => Type CoreIR n -> SAtom n -> m n (CAtom n) -liftSimpAtom (StuckTy _ _) _ = error "Can't lift stuck type" -liftSimpAtom ty@(TyCon tyCon) simpAtom = case simpAtom of - Stuck _ stuck -> return $ Stuck ty $ LiftSimp ty stuck - Con con -> Con <$> case (tyCon, con) of - (NewtypeTyCon newtypeCon, _) -> do - (dataCon, repTy) <- unwrapNewtypeType newtypeCon - cAtom <- rec repTy (Con con) - return $ NewtypeCon dataCon cAtom - (BaseType _ , Lit v) -> return $ Lit v - (ProdType tys, ProdCon xs) -> ProdCon <$> zipWithM rec tys xs - (SumType tys, SumCon _ i x) -> SumCon tys i <$> rec (tys!!i) x - (DepPairTy dpt@(DepPairType _ (b:>t1) t2), DepPair x1 x2 _) -> do - x1' <- rec t1 x1 - t2' <- applySubst (b@>SubstVal x1') t2 - x2' <- rec t2' x2 - return $ DepPair x1' x2' dpt - _ -> error $ "can't lift " <> pprint simpAtom <> " to " <> pprint ty - where - rec = liftSimpAtom -{-# INLINE liftSimpAtom #-} - liftSimpFun :: EnvReader m => Type CoreIR n -> LamExpr SimpIR n -> m n (CAtom n) liftSimpFun (TyCon (Pi piTy)) f = mkStuck $ LiftSimpFun piTy f liftSimpFun _ _ = error "not a pi type" diff --git a/src/lib/Types/Core.hs b/src/lib/Types/Core.hs index 06f2cbc86..2ff148a21 100644 --- a/src/lib/Types/Core.hs +++ b/src/lib/Types/Core.hs @@ -215,7 +215,7 @@ data LamExpr (r::IR) (n::S) where data CoreLamExpr (n::S) = CoreLamExpr (CorePiType n) (LamExpr CoreIR n) deriving (Show, Generic) -type TabLamExpr = PairE (IxType SimpIR) (Abs SBinder CAtom) +type TabLamExpr = PairE (TabPiType CoreIR) (Abs SBinder CAtom) type IxDict = Dict data IxMethod = Size | Ordinal | UnsafeFromOrdinal From be618938ea4234c05e4bd98bf4192555e2e3d8f4 Mon Sep 17 00:00:00 2001 From: Dougal Date: Thu, 9 Nov 2023 10:15:25 -0500 Subject: [PATCH 10/41] Rename `emitExpr` to `emit` --- src/lib/Algebra.hs | 2 +- src/lib/Builder.hs | 101 ++++++++++++++++++---------------------- src/lib/Imp.hs | 2 +- src/lib/Inference.hs | 51 ++++++++++---------- src/lib/Inline.hs | 14 +++--- src/lib/JAX/ToSimp.hs | 4 +- src/lib/Linearize.hs | 42 ++++++++--------- src/lib/Lower.hs | 24 +++++----- src/lib/Optimize.hs | 7 ++- src/lib/RuntimePrint.hs | 8 ++-- src/lib/Simplify.hs | 30 ++++++------ src/lib/TopLevel.hs | 3 +- src/lib/Transpose.hs | 8 ++-- src/lib/Vectorize.hs | 46 +++++++++--------- 14 files changed, 165 insertions(+), 177 deletions(-) diff --git a/src/lib/Algebra.hs b/src/lib/Algebra.hs index bf8462a83..eaebecaaf 100644 --- a/src/lib/Algebra.hs +++ b/src/lib/Algebra.hs @@ -59,7 +59,7 @@ sumUsingPolys lim (Abs i body) = do "Algebraic simplification failed to model index computations:\n" ++ "Trying to sum from 0 to " ++ pprint lim ++ " - 1, \\" ++ pprint i' ++ "." ++ pprint body' - limName <- emit (Atom lim) + limName <- emitToVar (Atom lim) emitPolynomial $ sum (LeftE (atomVarName limName)) sumAbs mul :: Polynomial n-> Polynomial n -> Polynomial n diff --git a/src/lib/Builder.hs b/src/lib/Builder.hs index 7a654911a..e38792bc9 100644 --- a/src/lib/Builder.hs +++ b/src/lib/Builder.hs @@ -66,10 +66,6 @@ emitDecl _ _ (Atom (Stuck _ (Var n))) = return n emitDecl hint ann expr = rawEmitDecl hint ann expr {-# INLINE emitDecl #-} -emit :: (Builder r m, Emits n) => Expr r n -> m n (AtomVar r n) -emit expr = emitDecl noHint PlainLet expr -{-# INLINE emit #-} - emitInline :: (Builder r m, Emits n) => Atom r n -> m n (AtomVar r n) emitInline atom = emitDecl noHint InlineLet $ Atom atom {-# INLINE emitInline #-} @@ -78,21 +74,21 @@ emitHinted :: (Builder r m, Emits n) => NameHint -> Expr r n -> m n (AtomVar r n emitHinted hint expr = emitDecl hint PlainLet expr {-# INLINE emitHinted #-} -emitExpr :: (Builder r m, ToExpr e r, Emits n) => e n -> m n (Atom r n) -emitExpr e = case toExpr e of +emit :: (Builder r m, ToExpr e r, Emits n) => e n -> m n (Atom r n) +emit e = case toExpr e of Atom x -> return x - Block _ block -> emitDecls block >>= emitExpr - expr -> toAtom <$> emit expr -{-# INLINE emitExpr #-} + Block _ block -> emitDecls block >>= emit + expr -> toAtom <$> emitToVar expr +{-# INLINE emit #-} emitToVar :: (Builder r m, ToExpr e r, Emits n) => e n -> m n (AtomVar r n) emitToVar e = case toExpr e of Atom (Stuck _ (Var v)) -> return v - expr -> emit expr + expr -> emitDecl noHint PlainLet expr {-# INLINE emitToVar #-} emitHof :: (Builder r m, Emits n) => Hof r n -> m n (Atom r n) -emitHof hof = mkTypedHof hof >>= emitExpr +emitHof hof = mkTypedHof hof >>= emit mkTypedHof :: (EnvReader m, IRRep r) => Hof r n -> m n (TypedHof r n) mkTypedHof hof = do @@ -100,7 +96,7 @@ mkTypedHof hof = do return $ TypedHof effTy hof emitUnOp :: (Builder r m, Emits n) => UnOp -> Atom r n -> m n (Atom r n) -emitUnOp op x = emitExpr $ UnOp op x +emitUnOp op x = emit $ UnOp op x {-# INLINE emitUnOp #-} emitDecls :: (Builder r m, Emits n, RenameE e, SinkableE e) @@ -115,11 +111,6 @@ emitDecls' (Nest (Let b (DeclBinding ann expr)) rest) e = do AtomVar v _ <- emitDecl (getNameHint b) ann expr' extendSubst (b @> v) $ emitDecls' rest e -emitExprToAtom :: (Builder r m, Emits n) => Expr r n -> m n (Atom r n) -emitExprToAtom (Atom atom) = return atom -emitExprToAtom expr = toAtom <$> emit expr -{-# INLINE emitExprToAtom #-} - buildScopedAssumeNoDecls :: (SinkableE e, ScopableBuilder r m) => (forall l. (Emits l, DExt n l) => m l (e l)) -> m n (e n) @@ -739,7 +730,7 @@ buildCaseAlts scrut indexedAltBody = do injectAltResult :: EnvReader m => [SType n] -> Int -> Alt SimpIR n -> m n (Alt SimpIR n) injectAltResult sumTys con (Abs b body) = liftBuilder do buildAlt (binderType b) \v -> do - originalResult <- emitExpr =<< applySubst (b@>SubstVal (toAtom v)) body + originalResult <- emit =<< applySubst (b@>SubstVal (toAtom v)) body (dataResult, nonDataResult) <- fromPairReduced originalResult return $ toAtom $ ProdCon [dataResult, Con $ SumCon (sinkList sumTys) con nonDataResult] @@ -768,7 +759,7 @@ buildCase :: (Emits n, ScopableBuilder r m) => Atom r n -> Type r n -> (forall l. (Emits l, DExt n l) => Int -> Atom r l -> m l (Atom r l)) -> m n (Atom r n) -buildCase s r b = emitExprToAtom =<< buildCase' s r b +buildCase s r b = emit =<< buildCase' s r b buildEffLam :: ScopableBuilder r m @@ -845,7 +836,7 @@ emitSeq :: (Emits n, ScopableBuilder SimpIR m) -> m n (Atom SimpIR n) emitSeq d t x f = do op <- mkSeq d t x f - emitExpr $ PrimOp $ DAMOp op + emit $ PrimOp $ DAMOp op mkSeq :: EnvReader m => Direction -> IxType SimpIR n -> Atom SimpIR n -> LamExpr SimpIR n @@ -862,7 +853,7 @@ buildRememberDest hint dest cont = do ty <- return $ getType dest doit <- buildUnaryLamExpr hint ty cont effs <- functionEffs doit - emitExpr $ PrimOp $ DAMOp $ RememberDest effs dest doit + emit $ PrimOp $ DAMOp $ RememberDest effs dest doit -- === vector space (ish) type class === @@ -918,7 +909,7 @@ tangentBaseMonoidFor ty = do addTangent :: (Emits n, SBuilder m) => SAtom n -> SAtom n -> m n (SAtom n) addTangent x y = do case getTyCon x of - BaseType (Scalar _) -> emitExpr $ BinOp FAdd x y + BaseType (Scalar _) -> emit $ BinOp FAdd x y ProdType _ -> do xs <- getUnpacked x ys <- getUnpacked y @@ -956,63 +947,63 @@ fLitLike x t = case getTyCon t of _ -> error "Expected a floating point scalar" neg :: (Builder r m, Emits n) => Atom r n -> m n (Atom r n) -neg x = emitExpr $ UnOp FNeg x +neg x = emit $ UnOp FNeg x add :: (Builder r m, Emits n) => Atom r n -> Atom r n -> m n (Atom r n) -add x y = emitExpr $ BinOp FAdd x y +add x y = emit $ BinOp FAdd x y -- TODO: Implement constant folding for fixed-width integer types as well! iadd :: (Builder r m, Emits n) => Atom r n -> Atom r n -> m n (Atom r n) iadd (Con (Lit l)) y | getIntLit l == 0 = return y iadd x (Con (Lit l)) | getIntLit l == 0 = return x iadd x@(Con (Lit _)) y@(Con (Lit _)) = return $ applyIntBinOp (+) x y -iadd x y = emitExpr $ BinOp IAdd x y +iadd x y = emit $ BinOp IAdd x y mul :: (Builder r m, Emits n) => Atom r n -> Atom r n -> m n (Atom r n) -mul x y = emitExpr $ BinOp FMul x y +mul x y = emit $ BinOp FMul x y imul :: (Builder r m, Emits n) => Atom r n -> Atom r n -> m n (Atom r n) imul (Con (Lit l)) y | getIntLit l == 1 = return y imul x (Con (Lit l)) | getIntLit l == 1 = return x imul x@(Con (Lit _)) y@(Con (Lit _)) = return $ applyIntBinOp (*) x y -imul x y = emitExpr $ BinOp IMul x y +imul x y = emit $ BinOp IMul x y sub :: (Builder r m, Emits n) => Atom r n -> Atom r n -> m n (Atom r n) -sub x y = emitExpr $ BinOp FSub x y +sub x y = emit $ BinOp FSub x y isub :: (Builder r m, Emits n) => Atom r n -> Atom r n -> m n (Atom r n) isub x (Con (Lit l)) | getIntLit l == 0 = return x isub x@(Con (Lit _)) y@(Con (Lit _)) = return $ applyIntBinOp (-) x y -isub x y = emitExpr $ BinOp ISub x y +isub x y = emit $ BinOp ISub x y select :: (Builder r m, Emits n) => Atom r n -> Atom r n -> Atom r n -> m n (Atom r n) select (Con (Lit (Word8Lit p))) x y = return $ if p /= 0 then x else y -select p x y = emitExpr $ MiscOp $ Select p x y +select p x y = emit $ MiscOp $ Select p x y div' :: (Builder r m, Emits n) => Atom r n -> Atom r n -> m n (Atom r n) -div' x y = emitExpr $ BinOp FDiv x y +div' x y = emit $ BinOp FDiv x y idiv :: (Builder r m, Emits n) => Atom r n -> Atom r n -> m n (Atom r n) idiv x (Con (Lit l)) | getIntLit l == 1 = return x idiv x@(Con (Lit _)) y@(Con (Lit _)) = return $ applyIntBinOp div x y -idiv x y = emitExpr $ BinOp IDiv x y +idiv x y = emit $ BinOp IDiv x y irem :: (Builder r m, Emits n) => Atom r n -> Atom r n -> m n (Atom r n) -irem x y = emitExpr $ BinOp IRem x y +irem x y = emit $ BinOp IRem x y fpow :: (Builder r m, Emits n) => Atom r n -> Atom r n -> m n (Atom r n) -fpow x y = emitExpr $ BinOp FPow x y +fpow x y = emit $ BinOp FPow x y flog :: (Builder r m, Emits n) => Atom r n -> m n (Atom r n) -flog x = emitExpr $ UnOp Log x +flog x = emit $ UnOp Log x ilt :: (Builder r m, Emits n) => Atom r n -> Atom r n -> m n (Atom r n) ilt x@(Con (Lit _)) y@(Con (Lit _)) = return $ applyIntCmpOp (<) x y -ilt x y = emitExpr $ BinOp (ICmp Less) x y +ilt x y = emit $ BinOp (ICmp Less) x y ieq :: (Builder r m, Emits n) => Atom r n -> Atom r n -> m n (Atom r n) ieq x@(Con (Lit _)) y@(Con (Lit _)) = return $ applyIntCmpOp (==) x y -ieq x y = emitExpr $ BinOp (ICmp Equal) x y +ieq x y = emit $ BinOp (ICmp Equal) x y fromPair :: (Builder r m, Emits n) => Atom r n -> m n (Atom r n, Atom r n) fromPair pair = do @@ -1026,7 +1017,7 @@ applyProjectionsRef [] ref = return ref applyProjectionsRef (i:is) ref = getProjRef i =<< applyProjectionsRef is ref getProjRef :: (Builder r m, Emits n) => Projection -> Atom r n -> m n (Atom r n) -getProjRef i r = emitExpr =<< mkProjRef r i +getProjRef i r = emit =<< mkProjRef r i -- XXX: getUnpacked must reduce its argument to enforce the invariant that -- ProjectElt atoms are always fully reduced (to avoid type errors between two @@ -1048,7 +1039,7 @@ unwrapNewtype (Con (NewtypeCon _ x)) = return x unwrapNewtype x = case getType x of TyCon (NewtypeTyCon con) -> do (_, ty) <- unwrapNewtypeType con - emitExpr $ Unwrap ty x + emit $ Unwrap ty x _ -> error "not a newtype" {-# INLINE unwrapNewtype #-} @@ -1061,7 +1052,7 @@ proj i = \case _ -> error "not a product" x -> do ty <- projType i x - emitExpr $ Project ty i x + emit $ Project ty i x getFst :: (Builder r m, Emits n) => Atom r n -> m n (Atom r n) getFst = proj 0 @@ -1166,21 +1157,21 @@ mkCatchException body = do return $ CatchException resultTy body app :: (CBuilder m, Emits n) => CAtom n -> CAtom n -> m n (CAtom n) -app x i = mkApp x [i] >>= emitExpr +app x i = mkApp x [i] >>= emit naryApp :: (CBuilder m, Emits n) => CAtom n -> [CAtom n] -> m n (CAtom n) naryApp = naryAppHinted noHint {-# INLINE naryApp #-} naryTopApp :: (Builder SimpIR m, Emits n) => TopFunName n -> [SAtom n] -> m n (SAtom n) -naryTopApp f xs = emitExpr =<< mkTopApp f xs +naryTopApp f xs = emit =<< mkTopApp f xs {-# INLINE naryTopApp #-} naryTopAppInlined :: (Builder SimpIR m, Emits n) => TopFunName n -> [SAtom n] -> m n (SAtom n) naryTopAppInlined f xs = do TopFunBinding f' <- lookupEnv f case f' of - DexTopFun _ lam _ -> instantiate lam xs >>= emitExpr + DexTopFun _ lam _ -> instantiate lam xs >>= emit _ -> naryTopApp f xs {-# INLINE naryTopAppInlined #-} @@ -1189,29 +1180,29 @@ naryAppHinted :: (CBuilder m, Emits n) naryAppHinted hint f xs = toAtom <$> (mkApp f xs >>= emitHinted hint) tabApp :: (Builder r m, Emits n) => Atom r n -> Atom r n -> m n (Atom r n) -tabApp x i = mkTabApp x i >>= emitExpr +tabApp x i = mkTabApp x i >>= emit naryTabApp :: (Builder r m, Emits n) => Atom r n -> [Atom r n] -> m n (Atom r n) naryTabApp f [] = return f naryTabApp f (x:xs) = do - ans <- mkTabApp f x >>= emitExpr + ans <- mkTabApp f x >>= emit naryTabApp ans xs {-# INLINE naryTabApp #-} indexRef :: (Builder r m, Emits n) => Atom r n -> Atom r n -> m n (Atom r n) -indexRef ref i = emitExpr =<< mkIndexRef ref i +indexRef ref i = emit =<< mkIndexRef ref i naryIndexRef :: (Builder r m, Emits n) => Atom r n -> [Atom r n] -> m n (Atom r n) naryIndexRef ref is = foldM indexRef ref is ptrOffset :: (Builder r m, Emits n) => Atom r n -> Atom r n -> m n (Atom r n) ptrOffset x (IdxRepVal 0) = return x -ptrOffset x i = emitExpr $ MemOp $ PtrOffset x i +ptrOffset x i = emit $ MemOp $ PtrOffset x i {-# INLINE ptrOffset #-} unsafePtrLoad :: (Builder r m, Emits n) => Atom r n -> m n (Atom r n) unsafePtrLoad x = do - body <- liftEmitBuilder $ buildBlock $ emitExpr . MemOp . PtrLoad =<< sinkM x + body <- liftEmitBuilder $ buildBlock $ emit . MemOp . PtrLoad =<< sinkM x emitHof $ RunIO body mkIndexRef :: (EnvReader m, Fallible1 m, IRRep r) => Atom r n -> Atom r n -> m n (PrimOp r n) @@ -1236,7 +1227,7 @@ applyIxMethod (DictCon dict) method args = case dict of IxSpecialized d params -> do SpecializedDict _ maybeFs <- lookupSpecDict d Just fs <- return maybeFs - instantiate (fs !! fromEnum method) (params ++ args) >>= emitExpr + instantiate (fs !! fromEnum method) (params ++ args) >>= emit unsafeFromOrdinal :: (SBuilder m, Emits n) => IxType SimpIR n -> Atom SimpIR n -> m n (Atom SimpIR n) unsafeFromOrdinal (IxType _ dict) i = applyIxMethod dict UnsafeFromOrdinal [i] @@ -1251,7 +1242,7 @@ indexSetSize (IxType _ dict) = applyIxMethod dict Size [] applyIxMethodCore :: (CBuilder m, Emits n) => IxMethod -> IxType CoreIR n -> [CAtom n] -> m n (CAtom n) applyIxMethodCore method (IxType _ dict) args = - emitExpr =<< mkApplyMethod dict (fromEnum method) args + emit =<< mkApplyMethod dict (fromEnum method) args -- === pseudo-prelude === @@ -1267,7 +1258,7 @@ emitIf :: (Emits n, ScopableBuilder r m) -> (forall l. (Emits l, DExt n l) => m l (Atom r l)) -> m n (Atom r n) emitIf predicate resultTy trueCase falseCase = do - predicate' <- emitExpr $ ToEnum (TyCon (SumType [UnitTy, UnitTy])) predicate + predicate' <- emit $ ToEnum (TyCon (SumType [UnitTy, UnitTy])) predicate buildCase predicate' resultTy \i _ -> case i of 0 -> falseCase @@ -1291,7 +1282,7 @@ fromJustE :: (Emits n, Builder r m) => Atom r n -> m n (Atom r n) fromJustE x = liftEmitBuilder do MaybeTy a <- return $ getType x emitMaybeCase x a - (emitExpr $ MiscOp $ ThrowError $ sink a) + (emit $ MiscOp $ ThrowError $ sink a) (return) -- Maybe a -> Bool @@ -1307,12 +1298,12 @@ reduceE monoid xs = liftEmitBuilder do getSnd =<< emitRunWriter noHint a monoid \_ ref -> buildFor noHint Fwd (sink $ tabIxType tabPi) \i -> do x <- tabApp (sink xs) (toAtom i) - emitExpr $ PrimOp $ RefOp (sink $ toAtom ref) $ MExtend (sink monoid) x + emit $ PrimOp $ RefOp (sink $ toAtom ref) $ MExtend (sink monoid) x andMonoid :: (EnvReader m, IRRep r) => m n (BaseMonoid r n) andMonoid = liftM (BaseMonoid TrueAtom) $ liftBuilder $ buildBinaryLamExpr (noHint, BoolTy) (noHint, BoolTy) \x y -> - emitExpr $ BinOp BAnd (sink $ toAtom x) (toAtom y) + emit $ BinOp BAnd (sink $ toAtom x) (toAtom y) -- (a-> {|eff} b) -> n=>a -> {|eff} (n=>b) mapE :: (Emits n, ScopableBuilder SimpIR m) @@ -1363,7 +1354,7 @@ runMaybeWhile body = do emitWhile do ans <- body emitMaybeCase ans Word8Ty - (emit (toExpr $ RefOp (sink $ toAtom ref) $ MPut TrueAtom) >> return FalseAtom) + (emit (RefOp (sink $ toAtom ref) $ MPut TrueAtom) >> return FalseAtom) (return) return UnitVal emitIf hadError (MaybeTy UnitTy) diff --git a/src/lib/Imp.hs b/src/lib/Imp.hs index e99e527a4..7ab6c865c 100644 --- a/src/lib/Imp.hs +++ b/src/lib/Imp.hs @@ -353,7 +353,7 @@ toImpRefOp refDest' m = do True -> do BinaryLamExpr xb yb body <- return bc body' <- applySubst (xb @> SubstVal x <.> yb @> SubstVal y) body - ans <- liftBuilderImp $ emitExpr (sink body') + ans <- liftBuilderImp $ emit (sink body') storeAtom accDest ans False -> case accTy of TyCon (TabPi t) -> do diff --git a/src/lib/Inference.hs b/src/lib/Inference.hs index cf0422e04..c8e9681e2 100644 --- a/src/lib/Inference.hs +++ b/src/lib/Inference.hs @@ -640,7 +640,7 @@ applyFromLiteralMethod resultTy methodName litVal = MethodBinding className _ <- lookupEnv methodName' dictTy <- toType <$> dictType className [toAtom resultTy] Just d <- toMaybeDict <$> trySynthTerm dictTy Full - emitExpr =<< mkApplyMethod d 0 [litVal] + emit =<< mkApplyMethod d 0 [litVal] -- atom that requires instantiation to become a rho type data SigmaAtom n = @@ -787,11 +787,11 @@ inlineTypeAliases v = do _ -> toAtom <$> toAtomVar v applySigmaAtom :: Emits o => SigmaAtom o -> [CAtom o] -> InfererM i o (CAtom o) -applySigmaAtom (SigmaAtom _ f) args = emitExprWithEffects =<< mkApp f args +applySigmaAtom (SigmaAtom _ f) args = emitWithEffects =<< mkApp f args applySigmaAtom (SigmaUVar _ _ f) args = case f of UAtomVar f' -> do f'' <- inlineTypeAliases f' - emitExprWithEffects =<< mkApp f'' args + emitWithEffects =<< mkApp f'' args UTyConVar f' -> do TyConDef sn roleExpls _ _ <- lookupTyCon f' let expls = snd <$> roleExpls @@ -820,9 +820,9 @@ applySigmaAtom (SigmaUVar _ _ f) args = case f of let numParams = nestLength paramBs -- params aren't needed because they're already implied by the dict argument let (dictArg:args') = drop numParams args - emitExprWithEffects =<< mkApplyMethod (fromJust $ toMaybeDict dictArg) methodIdx args' + emitWithEffects =<< mkApplyMethod (fromJust $ toMaybeDict dictArg) methodIdx args' applySigmaAtom (SigmaPartialApp _ f prevArgs) args = - emitExprWithEffects =<< mkApp f (prevArgs ++ args) + emitWithEffects =<< mkApp f (prevArgs ++ args) splitParamPrefix :: EnvReader m => TyConName n -> [CAtom n] -> m n (TyConParams n, [CAtom n]) splitParamPrefix tc args = do @@ -862,10 +862,10 @@ applyDataCon tc conIx topArgs = do where h:t = args _ -> error $ "Unexpected data con representation type: " ++ pprint rty -emitExprWithEffects :: Emits o => CExpr o -> InfererM i o (CAtom o) -emitExprWithEffects expr = do +emitWithEffects :: Emits o => CExpr o -> InfererM i o (CAtom o) +emitWithEffects expr = do addEffects $ getEffects expr - emitExpr expr + emit expr checkExplicitArity :: [Explicitness] -> [a] -> InfererM i o () checkExplicitArity expls args = do @@ -1016,7 +1016,7 @@ inferPrimArg x = do TyKind -> reduceExpr xBlock >>= \case Just reduced -> return reduced _ -> throw CompilerErr "Type args to primops must be reducible" - _ -> emitExpr xBlock + _ -> emit xBlock matchPrimApp :: Emits o => PrimName -> [CAtom o] -> InfererM i o (CAtom o) matchPrimApp = \case @@ -1035,15 +1035,15 @@ matchPrimApp = \case P.ProdCon -> \xs -> return $ toAtom $ ProdCon xs P.HeapVal -> \case ~[] -> return $ toAtom HeapVal P.SumCon _ -> error "not supported" - UMiscOp op -> \x -> emitExpr =<< MiscOp <$> matchGenericOp op x - UMemOp op -> \x -> emitExpr =<< MemOp <$> matchGenericOp op x - UBinOp op -> \case ~[x, y] -> emitExpr $ BinOp op x y - UUnOp op -> \case ~[x] -> emitExpr $ UnOp op x - UMAsk -> \case ~[r] -> emitExpr $ RefOp r MAsk - UMGet -> \case ~[r] -> emitExpr $ RefOp r MGet - UMPut -> \case ~[r, x] -> emitExpr $ RefOp r $ MPut x + UMiscOp op -> \x -> emit =<< MiscOp <$> matchGenericOp op x + UMemOp op -> \x -> emit =<< MemOp <$> matchGenericOp op x + UBinOp op -> \case ~[x, y] -> emit $ BinOp op x y + UUnOp op -> \case ~[x] -> emit $ UnOp op x + UMAsk -> \case ~[r] -> emit $ RefOp r MAsk + UMGet -> \case ~[r] -> emit $ RefOp r MGet + UMPut -> \case ~[r, x] -> emit $ RefOp r $ MPut x UIndexRef -> \case ~[r, i] -> indexRef r i - UApplyMethod i -> \case ~(d:args) -> emitExpr =<< mkApplyMethod (fromJust $ toMaybeDict d) i args + UApplyMethod i -> \case ~(d:args) -> emit =<< mkApplyMethod (fromJust $ toMaybeDict d) i args ULinearize -> \case ~[f, x] -> do f' <- lam1 f; emitHof $ Linearize f' x UTranspose -> \case ~[f, x] -> do f' <- lam1 f; emitHof $ Transpose f' x URunReader -> \case ~[x, f] -> do f' <- lam2 f; emitHof $ RunReader x f' @@ -1051,7 +1051,7 @@ matchPrimApp = \case UWhile -> \case ~[f] -> do f' <- lam0 f; emitHof $ While f' URunIO -> \case ~[f] -> do f' <- lam0 f; emitHof $ RunIO f' UCatchException-> \case ~[f] -> do f' <- lam0 f; emitHof =<< mkCatchException f' - UMExtend -> \case ~[r, z, f, x] -> do f' <- lam2 f; emitExpr $ RefOp r $ MExtend (BaseMonoid z f') x + UMExtend -> \case ~[r, z, f, x] -> do f' <- lam2 f; emit $ RefOp r $ MExtend (BaseMonoid z f') x URunWriter -> \args -> do [idVal, combiner, f] <- return args combiner' <- lam2 combiner @@ -1131,8 +1131,8 @@ buildNthOrderedAlt alts _ resultTy i v = do case lookup i [(idx, alt) | IndexedAlt idx alt <- alts] of Nothing -> do resultTy' <- sinkM resultTy - emitExpr $ ThrowError resultTy' - Just alt -> applyAbs alt (SubstVal v) >>= emitExpr + emit $ ThrowError resultTy' + Just alt -> applyAbs alt (SubstVal v) >>= emit buildMonomorphicCase :: (Emits n, ScopableBuilder CoreIR m) @@ -1159,7 +1159,7 @@ buildSortedCase scrut alts resultTy = do [_] -> do let [IndexedAlt _ alt] = alts scrut' <- unwrapNewtype scrut - emitExpr =<< applyAbs alt (SubstVal scrut') + emit =<< applyAbs alt (SubstVal scrut') _ -> do scrut' <- unwrapNewtype scrut liftEmitBuilder $ buildMonomorphicCase alts scrut' resultTy @@ -1526,8 +1526,7 @@ checkCasePat (WithSrcB pos pat) scrutineeTy cont = addSrcContext pos $ case pat withFreshBinderInf noHint Explicit repTy \b -> Abs b <$> do buildBlock do args <- forM idxs \projs -> do - ans <- applyProjectionsReduced (init projs) (sink $ toAtom $ binderVar b) - emit $ Atom ans + emitToVar =<< applyProjectionsReduced (init projs) (sink $ toAtom $ binderVar b) bindLetPats ps args $ cont _ -> throw TypeErr $ "Case patterns must start with a data constructor or variant pattern" @@ -1596,7 +1595,7 @@ bindLetPat (WithSrcB pos pat) v cont = addSrcContext pos $ case pat of TyCon (TabPi (TabPiType _ (_:>FinConst n') _)) | n == n' -> return () ty -> throw TypeErr $ "Expected a Fin " ++ show n ++ " table type but got: " ++ pprint ty xs <- forM [0 .. n - 1] \i -> do - emit =<< mkTabApp (toAtom v) (toAtom $ NewtypeCon (FinCon (NatVal n)) (NatVal $ fromIntegral i)) + emitToVar =<< mkTabApp (toAtom v) (toAtom $ NewtypeCon (FinCon (NatVal n)) (NatVal $ fromIntegral i)) bindLetPats ps xs cont checkUType :: UType i -> InfererM i o (CType o) @@ -1621,7 +1620,7 @@ inferTabCon xs = do xs' <- forM xs \x -> topDown elemTy x let dTy = toType $ DataDictType elemTy Just dataDict <- toMaybeDict <$> trySynthTerm dTy Full - emitExpr $ TabCon (Just $ WhenIRE dataDict) tabTy xs' + emit $ TabCon (Just $ WhenIRE dataDict) tabTy xs' checkTabCon :: forall i o. Emits o => TabPiType CoreIR o -> [UExpr i] -> InfererM i o (CAtom o) checkTabCon tabTy@(TabPiType _ b elemTy) xs = do @@ -1640,7 +1639,7 @@ checkTabCon tabTy@(TabPiType _ b elemTy) xs = do let dTy = toType $ DataDictType elemTy' return $ toType $ CorePiType ImplicitApp [Inferred Nothing Unify] (UnaryNest b') (EffTy Pure dTy) Just dataDict <- toMaybeDict <$> trySynthTerm dTy Full - emitExpr $ TabCon (Just $ WhenIRE dataDict) (TyCon (TabPi tabTy)) xs' + emit $ TabCon (Just $ WhenIRE dataDict) (TyCon (TabPi tabTy)) xs' addEffects :: EffectRow CoreIR o -> InfererM i o () addEffects Pure = return () diff --git a/src/lib/Inline.hs b/src/lib/Inline.hs index da14eeb94..9ba40faea 100644 --- a/src/lib/Inline.hs +++ b/src/lib/Inline.hs @@ -238,7 +238,7 @@ inlineStuck :: Emits o => Context SExpr e o -> SStuck i -> InlineM i o (e o) inlineStuck ctx = \case Var name -> inlineName ctx name StuckProject i x -> do - ans <- proj i =<< emitExprToAtom =<< inlineStuck Stop x + ans <- proj i =<< emit =<< inlineStuck Stop x reconstruct ctx $ Atom ans StuckTabApp _ _ -> error "not implemented" PtrVar t p -> do @@ -305,8 +305,8 @@ reconstruct ctx e = case ctx of TabAppCtx ix s ctx' -> withSubst s $ reconstructTabApp ctx' e ix CaseCtx alts resultTy effs s ctx' -> withSubst s $ reconstructCase ctx' e alts resultTy effs - EmitToAtomCtx ctx' -> emitExprToAtom e >>= reconstruct ctx' - EmitToNameCtx ctx' -> emit (Atom e) >>= reconstruct ctx' + EmitToAtomCtx ctx' -> emit e >>= reconstruct ctx' + EmitToNameCtx ctx' -> emitToVar e >>= reconstruct ctx' {-# INLINE reconstruct #-} reconstructTabApp :: Emits o @@ -318,7 +318,7 @@ reconstructTabApp ctx expr i = case expr of dropSubst $ extendSubst (b@>Rename i') do inlineExpr ctx body _ -> do - array' <- emitExprToAtom expr + array' <- emit expr i' <- inline Stop i reconstruct ctx =<< mkTabApp array' i' @@ -333,17 +333,17 @@ reconstructCase ctx scrutExpr alts resultTy effs = -- of the arms of the outer case resultTy' <- inline Stop resultTy reconstruct ctx =<< (buildCase' sscrut resultTy' \i val -> do - ans <- applyAbs (sink $ salts !! i) (SubstVal val) >>= emitExpr + ans <- applyAbs (sink $ salts !! i) (SubstVal val) >>= emit buildCase ans (sink resultTy') \j jval -> do Abs b body <- return $ alts !! j extendSubst (b @> (SubstVal $ DoneEx $ Atom jval)) do - inlineExpr Stop body >>= emitExprToAtom) + inlineExpr Stop body >>= emit) _ -> do -- Attempt case-of-known-constructor optimization -- I can't use `buildCase` here because I want to propagate the incoming -- context `ctx` into the selected alternative if the optimization fires, -- but leave it around the whole reconstructed `Case` if it doesn't. - scrut <- emitExprToAtom scrutExpr + scrut <- emit scrutExpr case scrut of Con con -> do SumCon _ i val <- return con diff --git a/src/lib/JAX/ToSimp.hs b/src/lib/JAX/ToSimp.hs index e7b942f6e..cdf25d73b 100644 --- a/src/lib/JAX/ToSimp.hs +++ b/src/lib/JAX/ToSimp.hs @@ -122,7 +122,7 @@ unaryExpandRank :: forall i o. Emits o unaryExpandRank op arg JArrayName{shape} = go arg shape where go :: Emits l => SAtom l -> [DimSizeName] -> JaxSimpM i l (SAtom l) go arg' = \case - [] -> emitExprToAtom $ PrimOp (UnOp op arg') + [] -> emit $ PrimOp (UnOp op arg') (DimSize sz:rest) -> buildFor noHint P.Fwd (litFinIxTy sz) \i -> do - ixed <- mkTabApp (sink arg') (toAtom i) >>= emitExprToAtom + ixed <- mkTabApp (sink arg') (toAtom i) >>= emit go ixed rest diff --git a/src/lib/Linearize.hs b/src/lib/Linearize.hs index 3f1347bad..1f0054671 100644 --- a/src/lib/Linearize.hs +++ b/src/lib/Linearize.hs @@ -278,7 +278,7 @@ applyLinLam :: Emits o => SLam i -> SubstReaderT AtomSubstVal TangentM i o (Atom applyLinLam (LamExpr bs body) = do TangentArgs args <- liftSubstReaderT $ getTangentArgs extendSubst (bs @@> ((Rename . atomVarName) <$> args)) do - substM body >>= emitExpr + substM body >>= emit -- === actual linearization passs === @@ -307,7 +307,7 @@ linearizeTopLam (TopLam False _ (LamExpr bs body)) actives = do ts <- getUnpacked $ toAtom $ sink $ binderVar bTangent let substFrag = bsRecon @@> map (SubstVal . sink) xs <.> bsTangent @@> map (SubstVal . sink) ts - emitExpr =<< applySubst substFrag tangentBody + emit =<< applySubst substFrag tangentBody return $ LamExpr (bs' >>> BinaryNest bResidual bTangent) tangentBody' return (primalFun, tangentFun) (,) <$> asTopLam primalFun <*> asTopLam tangentFun @@ -316,7 +316,7 @@ linearizeTopLam (TopLam True _ _) _ = error "expected a non-destination-passing -- reify the tangent builder as a lambda linearizeLambdaApp :: Emits o => SLam i -> SAtom o -> PrimalM i o (SAtom o, SLam o) linearizeLambdaApp (UnaryLamExpr b body) x = do - vp <- emit $ Atom x + vp <- emitToVar x extendActiveSubst b vp do WithTangent primalResult tangentAction <- linearizeExpr body tanFun <- tangentFunAsLambda tangentAction @@ -348,7 +348,7 @@ linearizeDecls (Nest (Let b (DeclBinding ann expr)) rest) cont = do expr' <- renameM expr isTrivialForAD expr' >>= \case True -> do - v <- emit expr' + v <- emitToVar expr' extendSubst (b@>atomVarName v) $ linearizeDecls rest cont False -> do WithTangent p tf <- linearizeExpr expr @@ -406,7 +406,7 @@ linearizeExpr expr = case expr of alts'' <- forM (enumerate alts') \(i, alt) -> do injectAltResult tys i alt let fullResultTy = PairTy resultTy' $ TyCon $ SumType tys - result <- emitExpr $ Case e' alts'' (EffTy effs' fullResultTy) + result <- emit $ Case e' alts'' (EffTy effs' fullResultTy) (primal, residualss) <- fromPair result resultTangentType <- tangentType resultTy' return $ WithTangent primal do @@ -418,7 +418,7 @@ linearizeExpr expr = case expr of TabCon _ ty xs -> do ty' <- renameM ty seqLin (map linearizeAtom xs) `bindLin` \(ComposeE xs') -> - emitExpr $ TabCon Nothing (sink ty') xs' + emit $ TabCon Nothing (sink ty') xs' Project _ i x -> do WithTangent x' tx <- linearizeAtom x xi <- proj i x' @@ -431,19 +431,19 @@ linearizeOp op = case op of Hof (TypedHof _ e) -> linearizeHof e DAMOp _ -> error "shouldn't occur here" RefOp ref m -> case m of - MAsk -> linearizeAtom ref `bindLin` \ref' -> liftM toAtom $ emit $ PrimOp $ RefOp ref' MAsk + MAsk -> linearizeAtom ref `bindLin` \ref' -> emit $ RefOp ref' MAsk MExtend monoid x -> do -- TODO: check that we're dealing with a +/0 monoid monoid' <- renameM monoid zipLin (linearizeAtom ref) (linearizeAtom x) `bindLin` \(PairE ref' x') -> - liftM toAtom $ emit $ PrimOp $ RefOp ref' $ MExtend (sink monoid') x' - MGet -> linearizeAtom ref `bindLin` \ref' -> liftM toAtom $ emit $ PrimOp $ RefOp ref' MGet + emit $ RefOp ref' $ MExtend (sink monoid') x' + MGet -> linearizeAtom ref `bindLin` \ref' -> emit $ RefOp ref' MGet MPut x -> zipLin (linearizeAtom ref) (linearizeAtom x) `bindLin` \(PairE ref' x') -> - liftM toAtom $ emit $ PrimOp $ RefOp ref' $ MPut x' + emit $ RefOp ref' $ MPut x' IndexRef _ i -> do zipLin (la ref) (pureLin i) `bindLin` \(PairE ref' i') -> - emitExpr =<< mkIndexRef ref' i' - ProjRef _ i -> la ref `bindLin` \ref' -> emitExpr =<< mkProjRef ref' i + emit =<< mkIndexRef ref' i' + ProjRef _ i -> la ref `bindLin` \ref' -> emit =<< mkProjRef ref' i UnOp uop x -> linearizeUnOp uop x BinOp bop x y -> linearizeBinOp bop x y -- XXX: This assumes that pointers are always constants @@ -451,7 +451,7 @@ linearizeOp op = case op of MiscOp miscOp -> linearizeMiscOp miscOp VectorOp _ -> error "not implemented" where - emitZeroT = withZeroT $ liftM toAtom $ emit =<< renameM (PrimOp op) + emitZeroT = withZeroT $ emit =<< renameM (PrimOp op) la = linearizeAtom linearizeMiscOp :: Emits o => MiscOp SimpIR i -> LinM i o SAtom SAtom @@ -459,7 +459,7 @@ linearizeMiscOp op = case op of SumTag _ -> emitZeroT ToEnum _ _ -> emitZeroT Select p t f -> (pureLin p `zipLin` la t `zipLin` la f) `bindLin` - \(p' `PairE` t' `PairE` f') -> emitExpr $ MiscOp $ Select p' t' f' + \(p' `PairE` t' `PairE` f') -> emit $ MiscOp $ Select p' t' f' CastOp t v -> do vt <- getType <$> renameM v t' <- renameM t @@ -468,14 +468,14 @@ linearizeMiscOp op = case op of ((&&) <$> (vtTangentType `alphaEq` vt) <*> (tTangentType `alphaEq` t')) >>= \case True -> do - linearizeAtom v `bindLin` \v' -> emitExpr $ MiscOp $ CastOp (sink t') v' + linearizeAtom v `bindLin` \v' -> emit $ MiscOp $ CastOp (sink t') v' False -> do WithTangent x xt <- linearizeAtom v yt <- case (vtTangentType, tTangentType) of (_ , UnitTy) -> return $ UnitVal (UnitTy, tt ) -> zeroAt tt _ -> error "Expected at least one side of the CastOp to have a trivial tangent type" - y <- emitExpr $ MiscOp $ CastOp t' x + y <- emit $ MiscOp $ CastOp t' x return $ WithTangent y do xt >> return (sink yt) BitcastOp _ _ -> notImplemented UnsafeCoerce _ _ -> notImplemented @@ -486,13 +486,13 @@ linearizeMiscOp op = case op of ShowAny _ -> error "Shouldn't have ShowAny in simplified IR" ShowScalar _ -> error "Shouldn't have ShowScalar in simplified IR" where - emitZeroT = withZeroT $ liftM toAtom $ emit =<< renameM (PrimOp $ MiscOp op) + emitZeroT = withZeroT $ emit =<< renameM (PrimOp $ MiscOp op) la = linearizeAtom linearizeUnOp :: Emits o => UnOp -> Atom SimpIR i -> LinM i o SAtom SAtom linearizeUnOp op x' = do WithTangent x tx <- linearizeAtom x' - let emitZeroT = withZeroT $ emitExpr $ UnOp op x + let emitZeroT = withZeroT $ emit $ UnOp op x case op of Exp -> do y <- emitUnOp Exp x @@ -523,7 +523,7 @@ linearizeBinOp :: Emits o => BinOp -> SAtom i -> SAtom i -> LinM i o SAtom SAtom linearizeBinOp op x' y' = do WithTangent x tx <- linearizeAtom x' WithTangent y ty <- linearizeAtom y' - let emitZeroT = withZeroT $ emitExpr $ BinOp op x y + let emitZeroT = withZeroT $ emit $ BinOp op x y case op of IAdd -> emitZeroT ISub -> emitZeroT @@ -541,7 +541,7 @@ linearizeBinOp op x' y' = do ty' <- bindM2 div' (bindM2 mul (referToPrimal x) ty) (bindM2 mul (referToPrimal y) (referToPrimal y)) sub tx' ty' - FPow -> withT (emitExpr $ BinOp FPow x y) do + FPow -> withT (emit $ BinOp FPow x y) do px <- referToPrimal x py <- referToPrimal y c <- (1.0 `fLitLike` py) >>= (sub py) >>= fpow px @@ -569,7 +569,7 @@ referToPrimal x = do AtomNameBinding (LetBound (DeclBinding PlainLet (TabApp _ tab i))) -> do tab' <- referToPrimal tab i' <- referToPrimal i - emitExpr =<< mkTabApp tab' i' + emit =<< mkTabApp tab' i' _ -> sinkM x _ -> sinkM x diff --git a/src/lib/Lower.hs b/src/lib/Lower.hs index 6acf8eac5..cf28b0667 100644 --- a/src/lib/Lower.hs +++ b/src/lib/Lower.hs @@ -111,15 +111,15 @@ lowerFor ansTy maybeDest dir ixTy (UnaryLamExpr (ib:>ty) body) = do False -> do initDest <- Con . ProdCon . (:[]) <$> case maybeDest of Just d -> return d - Nothing -> emitExpr $ AllocDest ansTy + Nothing -> emit $ AllocDest ansTy let destTy = getType initDest body' <- buildUnaryLamExpr noHint (PairTy ty' destTy) \b' -> do (i, destProd) <- fromPair $ toAtom b' dest <- proj 0 destProd - idest <- emitExpr =<< mkIndexRef dest i + idest <- emit =<< mkIndexRef dest i extendSubst (ib @> SubstVal i) $ lowerExpr (Just idest) body $> UnitVal ans <- emitSeq dir ixTy' initDest body' >>= proj 0 - emitExpr $ Freeze ans + emit $ Freeze ans lowerFor _ _ _ _ _ = error "expected a unary lambda expression" lowerTabCon :: Emits o => OptDest o -> SType i -> [SAtom i] -> LowerM i o (SAtom o) @@ -127,7 +127,7 @@ lowerTabCon maybeDest tabTy elems = do TyCon (TabPi tabTy') <- substM tabTy dest <- case maybeDest of Just d -> return d - Nothing -> emitExpr $ AllocDest $ TyCon $ TabPi tabTy' + Nothing -> emit $ AllocDest $ TyCon $ TabPi tabTy' Abs bord ufoBlock <- buildAbs noHint IdxRepTy \ord -> do buildBlock $ unsafeFromOrdinal (sink $ tabIxType tabTy') $ toAtom $ sink ord -- This is emitting a chain of RememberDest ops to force `dest` to be used @@ -146,7 +146,7 @@ lowerTabCon maybeDest tabTy elems = do return UnitVal go carried_dest rest dest' <- go dest (enumerate elems) - emitExpr $ Freeze dest' + emit $ Freeze dest' lowerCase :: Emits o => OptDest o -> SAtom i -> [Alt SimpIR i] -> SType i @@ -155,7 +155,7 @@ lowerCase maybeDest scrut alts resultTy = do resultTy' <- substM resultTy dest <- case maybeDest of Just d -> return d - Nothing -> emitExpr $ AllocDest resultTy' + Nothing -> emit $ AllocDest resultTy' scrut' <- visitAtom scrut dest' <- buildRememberDest "case_dest" dest \local_dest -> do alts' <- forM alts \(Abs (b:>ty) body) -> do @@ -164,9 +164,9 @@ lowerCase maybeDest scrut alts resultTy = do extendSubst (b @> Rename (atomVarName b')) $ buildBlock do lowerExpr (Just (toAtom $ sink $ local_dest)) body $> UnitVal - void $ mkCase (sink scrut') UnitTy alts' >>= emitExpr + void $ mkCase (sink scrut') UnitTy alts' >>= emit return UnitVal - emitExpr $ Freeze dest' + emit $ Freeze dest' -- Destination-passing traversals -- @@ -246,7 +246,7 @@ lowerExpr dest expr = case expr of Rename v -> toAtom <$> toAtomVar v SubstVal a -> return a place d x - withSubst s' (substM result) >>= emitExpr + withSubst s' (substM result) >>= emit TabCon Nothing ty els -> lowerTabCon dest ty els PrimOp (Hof (TypedHof (EffTy _ ansTy) (For dir ixDict body))) -> do ansTy' <- substM ansTy @@ -263,13 +263,13 @@ lowerExpr dest expr = case expr of emitHof $ RunState ref' s' body' -- this case is important because this pass changes effects PrimOp (Hof (TypedHof _ hof)) -> do - hof' <- emitExpr =<< (visitGeneric hof >>= mkTypedHof) + hof' <- emit =<< (visitGeneric hof >>= mkTypedHof) placeGeneric hof' Case e alts (EffTy _ ty) -> lowerCase dest e alts ty _ -> generic where generic :: LowerM i o (SAtom o) - generic = visitGeneric expr >>= emitExpr >>= placeGeneric + generic = visitGeneric expr >>= emit >>= placeGeneric placeGeneric :: SAtom o -> LowerM i o (SAtom o) placeGeneric e = do @@ -301,7 +301,7 @@ lowerExpr dest expr = case expr of return $ Just (Just bd, Just rd) place :: Emits o => Dest o -> SAtom o -> LowerM i o () -place d x = void $ emitExpr $ Place d x +place d x = void $ emit $ Place d x -- === Extensions to the name system === diff --git a/src/lib/Optimize.hs b/src/lib/Optimize.hs index dd1d0aaae..2bc0452ab 100644 --- a/src/lib/Optimize.hs +++ b/src/lib/Optimize.hs @@ -254,7 +254,7 @@ ulExpr expr = case expr of getLamExprType body' >>= \case PiType (UnaryNest (tb:>_)) (EffTy _ valTy) -> do let tabTy = toType $ TabPiType (DictCon $ IxRawFin (IdxRepVal n)) (tb:>IdxRepTy) valTy - emitExpr $ TabCon Nothing tabTy vals + emit $ TabCon Nothing tabTy vals _ -> error "Expected `for` body to have a Pi type" _ -> error "Expected `for` body to be a lambda expression" False -> do @@ -268,8 +268,7 @@ ulExpr expr = case expr of _ -> nothingSpecial where inc i = modify \(ULS n) -> ULS (n + i) - nothingSpecial = inc 1 >> (visitGeneric expr >>= liftEnvReaderM . peepholeExpr) - >>= emitExprToAtom + nothingSpecial = inc 1 >> (visitGeneric expr >>= liftEnvReaderM . peepholeExpr) >>= emit unrollBlowupThreshold = 12 withLocalAccounting m = do oldCost <- get @@ -344,7 +343,7 @@ licmExpr = \case block <- mkBlock =<< applyRename (lnb@>binderName i) bodyAbs return $ UnaryLamExpr i block emitHof $ For dir ix' body' - expr -> visitGeneric expr >>= emitExpr + expr -> visitGeneric expr >>= emit seqLICM :: RNest SDecl n1 n2 -- hoisted decls -> [SAtomName n2] -- hoisted dests diff --git a/src/lib/RuntimePrint.hs b/src/lib/RuntimePrint.hs index f073ea0c9..63fc869fd 100644 --- a/src/lib/RuntimePrint.hs +++ b/src/lib/RuntimePrint.hs @@ -67,9 +67,9 @@ showAnyTyCon tyCon atom = case tyCon of Vector _ _ -> error "not implemented" PtrType _ -> printTypeOnly "pointer" Scalar _ -> do - (n, tab) <- fromPair =<< emitExpr (ShowScalar atom) + (n, tab) <- fromPair =<< emit (ShowScalar atom) logicalTabTy <- finTabTyCore (Con $ NewtypeCon NatCon n) CharRepTy - tab' <- emitExpr $ UnsafeCoerce logicalTabTy tab + tab' <- emit $ UnsafeCoerce logicalTabTy tab emitCharTab tab' -- TODO: we could do better than this but it's not urgent because raw sum types -- aren't user-facing. @@ -92,7 +92,7 @@ showAnyTyCon tyCon atom = case tyCon of n <- unwrapNewtype atom -- Cast to Int so that it prints in decimal instead of hex let intTy = toType $ BaseType (Scalar Int64Type) - emitExpr (CastOp intTy n) >>= rec + emit (CastOp intTy n) >>= rec EffectRowKind -> printAsConstant -- hack to print strings nicely. TODO: make `Char` a newtype UserADTType "List" _ (TyConParams [Explicit] [Con (TyConAtom (BaseType (Scalar (Word8Type))))]) -> do @@ -199,7 +199,7 @@ pushBuffer buf x = do stringLitAsCharTab :: (Emits n, CBuilder m) => String -> m n (CAtom n) stringLitAsCharTab s = do t <- finTabTyCore (NatVal $ fromIntegral $ length s) CharRepTy - emitExpr $ TabCon Nothing t (map charRepVal s) + emit $ TabCon Nothing t (map charRepVal s) finTabTyCore :: (Fallible1 m, EnvReader m) => CAtom n -> CType n -> m n (CType n) finTabTyCore n eltTy = return $ IxType (FinTy n) (DictCon $ IxFin n) ==> eltTy diff --git a/src/lib/Simplify.hs b/src/lib/Simplify.hs index 239641ede..53a5fe7f6 100644 --- a/src/lib/Simplify.hs +++ b/src/lib/Simplify.hs @@ -400,7 +400,7 @@ simplifyExpr expr = confuseGHC >>= \_ -> case expr of ty' <- substM ty tySimp <- getRepType ty xs' <- forM xs \x -> toDataAtom x - liftSimpAtom ty' =<< emitExpr (TabCon Nothing tySimp xs') + liftSimpAtom ty' =<< emit (TabCon Nothing tySimp xs') Case scrut alts (EffTy _ resultTy) -> do scrut' <- simplifyAtom scrut resultTy' <- substM resultTy @@ -427,17 +427,17 @@ simplifyRefOp op ref = case op of x' <- toDataAtom x (cb', CoerceReconAbs) <- simplifyLam cb emitRefOp $ MExtend (BaseMonoid em' cb') x' - MGet -> emitExpr $ RefOp ref MGet + MGet -> emit $ RefOp ref MGet MPut x -> do x' <- toDataAtom x emitRefOp $ MPut x' MAsk -> emitRefOp MAsk IndexRef _ x -> do x' <- toDataAtom x - emitExpr =<< mkIndexRef ref x' - ProjRef _ (ProjectProduct i) -> emitExpr =<< mkProjRef ref (ProjectProduct i) + emit =<< mkIndexRef ref x' + ProjRef _ (ProjectProduct i) -> emit =<< mkProjRef ref (ProjectProduct i) ProjRef _ UnwrapNewtype -> return ref - where emitRefOp op' = emitExpr $ RefOp ref op' + where emitRefOp op' = emit $ RefOp ref op' defuncCaseCore :: Emits o => Atom CoreIR o -> Type CoreIR o @@ -472,7 +472,7 @@ defuncCase scrut resultTy cont = do ans <- cont i (toAtom $ sink x) dropSubst $ toDataAtom ans caseExpr <- mkCase scrut resultTyData alts' - emitExpr caseExpr >>= liftSimpAtom resultTy + emit caseExpr >>= liftSimpAtom resultTy Nothing -> do split <- splitDataComponents resultTy (alts', closureTys, recons) <- unzip3 <$> forM (enumerate altBinderTys) \(i, bTy) -> do @@ -481,7 +481,7 @@ defuncCase scrut resultTy cont = do let newNonDataTy = nonDataTy split alts'' <- forM (enumerate alts') \(i, alt) -> injectAltResult closureTys i alt caseExpr <- mkCase scrut (PairTy (dataTy split) closureSumTy) alts'' - caseResult <- emitExpr $ caseExpr + caseResult <- emit $ caseExpr (dataVal, sumVal) <- fromPair caseResult reconAlts <- forM (zip closureTys recons) \(ty, recon) -> buildAbs noHint ty \v -> applyRecon (sink recon) (toAtom v) @@ -519,7 +519,7 @@ simplifyApp resultTy f xs = case f of CCFun ccFun -> case ccFun of CCLiftSimpFun _ lam -> do xs' <- dropSubst $ mapM toDataAtom xs - result <- instantiate lam xs' >>= emitExpr + result <- instantiate lam xs' >>= emit liftSimpAtom resultTy result CCNoInlineFun v _ _ -> simplifyTopFunApp v xs CCFFIFun _ f' -> do @@ -627,7 +627,7 @@ simplifyDictMethod absDict@(Abs bs dict) method = do lamExpr <- liftBuilder $ buildTopLamFromPi ty \allArgs -> do let (extraArgs, methodArgs) = splitAt (nestLength bs) allArgs dict' <- applyRename (bs @@> (atomVarName <$> extraArgs)) dict - emitExpr =<< mkApplyMethod dict' (fromEnum method) (toAtom <$> methodArgs) + emit =<< mkApplyMethod dict' (fromEnum method) (toAtom <$> methodArgs) simplifyTopFunction lamExpr ixMethodType :: IxMethod -> AbsDict n -> EnvReaderM n (PiType CoreIR n) @@ -730,10 +730,10 @@ simplifyOp op = case op of IDiv -> idiv x y ICmp Less -> ilt x y ICmp Equal -> ieq x y - _ -> emitExpr $ BinOp binop x y + _ -> emit $ BinOp binop x y UnOp unOp x' -> do x <- toDataAtom x' - liftResult =<< emitExpr (UnOp unOp x) + liftResult =<< emit (UnOp unOp x) MiscOp op' -> case op' of Select c' x' y' -> do c <- toDataAtom c' @@ -757,7 +757,7 @@ simplifyGenericOp simplifyGenericOp op = do ty <- substM $ getType op op' <- traverseOp op getRepType toDataAtom (error "shouldn't have lambda left") - result <- liftEnvReaderM (peepholeExpr $ toExpr op') >>= emitExprToAtom + result <- liftEnvReaderM (peepholeExpr $ toExpr op') >>= emit liftSimpAtom ty result {-# INLINE simplifyGenericOp #-} @@ -861,7 +861,7 @@ simplifyHof resultTy = \case SimplifiedBlock body' recon <- buildSimplifiedBlock $ simplifyExpr body block <- liftBuilder $ runSubstReaderT idSubst $ buildBlock $ exceptToMaybeExpr body' - result <- emitExpr block + result <- emit block case recon of CoerceRecon ty -> do maybeTy <- makePreludeMaybeTy ty @@ -1029,7 +1029,7 @@ defuncLinearized ab = liftBuilder $ refreshAbs ab \bs ab' -> do primalFun <- LamExpr bs <$> mkBlock declsAndResult LamExpr residualAndTangentBs tangentBody <- buildLamExpr residualsTangentsBs \(residuals:tangents) -> do LamExpr tangentBs' body <- applyReconAbs (sink reconAbs) (toAtom residuals) - applyRename (tangentBs' @@> (atomVarName <$> tangents)) body >>= emitExpr + applyRename (tangentBs' @@> (atomVarName <$> tangents)) body >>= emit let tangentFun = LamExpr (bs >>> residualAndTangentBs) tangentBody return $ PairE primalFun tangentFun @@ -1106,7 +1106,7 @@ exceptToMaybeExpr expr = case expr of False -> do v <- emit expr' let ty = getType v - return $ JustAtom ty (toAtom v) + return $ JustAtom ty v hasExceptions :: SExpr n -> Bool hasExceptions expr = case getEffects expr of diff --git a/src/lib/TopLevel.hs b/src/lib/TopLevel.hs index 8511559f5..2f4970c3d 100644 --- a/src/lib/TopLevel.hs +++ b/src/lib/TopLevel.hs @@ -639,8 +639,7 @@ compileTopLevelFun cc fSimp = do printCodegen :: (Topper m, Mut n) => CAtom n -> m n String printCodegen x = do - block <- liftBuilder $ buildBlock do - emitExpr $ PrimOp $ MiscOp $ ShowAny $ sink x + block <- liftBuilder $ buildBlock $ emit $ ShowAny $ sink x (topBlock, _) <- asTopBlock block getDexString =<< evalBlock topBlock diff --git a/src/lib/Transpose.hs b/src/lib/Transpose.hs index a296087fb..e312de43b 100644 --- a/src/lib/Transpose.hs +++ b/src/lib/Transpose.hs @@ -145,7 +145,7 @@ withAccumulator ty cont = do emitCTToRef :: (Emits n, Builder SimpIR m) => SAtom n -> SAtom n -> m n () emitCTToRef ref ct = do baseMonoid <- tangentBaseMonoidFor (getType ct) - void $ emitExpr $ RefOp ref $ MExtend baseMonoid ct + void $ emit $ RefOp ref $ MExtend baseMonoid ct getLinRegions :: TransposeM i o [SAtomVar o] getLinRegions = asks fromListE @@ -166,7 +166,7 @@ transposeWithDecls (Nest (Let b (DeclBinding _ expr)) rest) result ct = transposeWithDecls rest result (sink ct) transposeExpr expr ctExpr Just nonlinExpr -> do - v <- emit nonlinExpr + v <- emitToVar nonlinExpr extendSubst (b @> RenameNonlin (atomVarName v)) $ transposeWithDecls rest result ct @@ -226,7 +226,7 @@ transposeExpr expr ct = case expr of False -> do e' <- substNonlin e void $ buildCase e' UnitTy \i v -> do - v' <- emit (Atom v) + v' <- emitToVar v Abs b body <- return $ alts !! i extendSubst (b @> RenameNonlin (atomVarName v')) do transposeExpr body (sink ct) @@ -244,7 +244,7 @@ transposeOp op ct = case op of DAMOp _ -> error "unreachable" -- TODO: rule out statically RefOp refArg m -> do refArg' <- substNonlin refArg - let emitEff = emitExpr . RefOp refArg' + let emitEff = emit . RefOp refArg' case m of MAsk -> do baseMonoid <- tangentBaseMonoidFor (getType ct) diff --git a/src/lib/Vectorize.hs b/src/lib/Vectorize.hs index 2405da77d..189fe03dc 100644 --- a/src/lib/Vectorize.hs +++ b/src/lib/Vectorize.hs @@ -178,8 +178,8 @@ vectorizeLoopsExpr expr = do let vn = n `div` loopWidth body' <- vectorizeSeq loopWidth ixty body dest' <- renameM dest - emitExpr =<< mkSeq dir (IxType IdxRepTy (DictCon (IxRawFin (IdxRepVal vn)))) dest' body') - else renameM expr >>= emitExpr) + emit =<< mkSeq dir (IxType IdxRepTy (DictCon (IxRawFin (IdxRepVal vn)))) dest' body') + else renameM expr >>= emit) `catchErr` \err -> do let msg = "In `vectorizeLoopsDecls`:\nExpr:\n" ++ pprint expr ctx = mempty { messageCtx = [msg] } @@ -194,7 +194,7 @@ vectorizeLoopsExpr expr = do extendRenamer (hb' @> atomVarName hb) do extendRenamer (refb' @> atomVarName refb) do vectorizeLoopsExpr body - emitExpr =<< mkTypedHof (RunReader item' lam) + emit =<< mkTypedHof (RunReader item' lam) PrimOp (Hof (TypedHof (EffTy _ ty) (RunWriter (Just dest) monoid (BinaryLamExpr hb' refb' body)))) -> do dest' <- renameM dest @@ -206,8 +206,8 @@ vectorizeLoopsExpr expr = do extendRenamer (refb' @> atomVarName refb) do extendCommuteMap (atomVarName hb) commutativity do vectorizeLoopsExpr body - emitExpr =<< mkTypedHof (RunWriter (Just dest') monoid' lam) - _ -> renameM expr >>= emitExpr + emit =<< mkTypedHof (RunWriter (Just dest') monoid' lam) + _ -> renameM expr >>= emit where recurSeq :: (Emits o) => SExpr i -> TopVectorizeM i o (SAtom o) recurSeq (PrimOp (DAMOp (Seq effs dir ixty dest body))) = do @@ -215,7 +215,7 @@ vectorizeLoopsExpr expr = do ixty' <- renameM ixty dest' <- renameM dest body' <- vectorizeLoopsLamExpr body - emitExpr $ Seq effs' dir ixty' dest' body' + emit $ Seq effs' dir ixty' dest' body' recurSeq _ = error "Impossible" simplifyIxSize :: (EnvReader m, ScopableBuilder SimpIR m) @@ -385,7 +385,7 @@ vectorizeExpr expr = addVectErrCtx "vectorizeExpr" ("Expr:\n" ++ pprint expr) do vty <- getVectorType =<< case hoist tb a of HoistSuccess a' -> return a' HoistFailure _ -> throwVectErr "Can't vectorize dependent table application" - VVal Varying <$> emitExpr (VectorIdx tbl' ix' vty) + VVal Varying <$> emit (VectorIdx tbl' ix' vty) tblTy -> do throwVectErr $ "bad type: " ++ pprint tblTy ++ "\ntbl' : " ++ pprint tbl' Atom atom -> vectorizeAtom atom @@ -405,11 +405,11 @@ vectorizeDAMOp op = VVal vref ref <- vectorizeAtom ref' sval@(VVal vval val) <- vectorizeAtom val' VVal Uniform <$> case (vref, vval) of - (Uniform , Uniform ) -> emitExpr $ Place ref val + (Uniform , Uniform ) -> emit $ Place ref val (Uniform , _ ) -> throwVectErr "Write conflict? This should never happen!" (Varying , _ ) -> throwVectErr "Vector scatter not implemented" - (Contiguous, Varying ) -> emitExpr $ Place ref val - (Contiguous, Contiguous) -> emitExpr . Place ref =<< ensureVarying sval + (Contiguous, Varying ) -> emit $ Place ref val + (Contiguous, Contiguous) -> emit . Place ref =<< ensureVarying sval _ -> throwVectErr "Not implemented yet" _ -> throwVectErr $ "Can't vectorize op: " ++ pprint op @@ -420,7 +420,7 @@ vectorizeRefOp ref' op = -- TODO A contiguous reference becomes a vector load producing a varying -- result. VVal Uniform ref <- vectorizeAtom ref' - VVal Uniform <$> emitExpr (RefOp ref MAsk) + VVal Uniform <$> emit (RefOp ref MAsk) MExtend basemonoid' x' -> do VVal refStab ref <- vectorizeAtom ref' VVal xStab x <- vectorizeAtom x' @@ -437,7 +437,7 @@ vectorizeRefOp ref' op = Contiguous -> do vectorizeBaseMonoid basemonoid' Varying xStab s -> throwVectErr $ "Cannot vectorize reference with loop-varying stability " ++ show s - VVal Uniform <$> emitExpr (RefOp ref $ MExtend basemonoid x) + VVal Uniform <$> emit (RefOp ref $ MExtend basemonoid x) IndexRef _ i' -> do VVal Uniform ref <- vectorizeAtom ref' VVal Contiguous i <- vectorizeAtom i' @@ -446,7 +446,7 @@ vectorizeRefOp ref' op = vty <- getVectorType =<< case hoist tb a of HoistSuccess a' -> return a' HoistFailure _ -> throwVectErr "Can't vectorize dependent table application" - VVal Contiguous <$> emitExpr (VectorSubref ref i vty) + VVal Contiguous <$> emit (VectorSubref ref i vty) refTy -> do throwVectErr do "bad type: " ++ pprint refTy ++ "\nref' : " ++ pprint ref' @@ -472,7 +472,7 @@ vectorizePrimOp op = case op of sx@(VVal vx x) <- vectorizeAtom arg let v = case vx of Uniform -> Uniform; _ -> Varying x' <- if vx /= v then ensureVarying sx else return x - VVal v <$> emitExpr (UnOp opk x') + VVal v <$> emit (UnOp opk x') BinOp opk arg1 arg2 -> do sx@(VVal vx x) <- vectorizeAtom arg1 sy@(VVal vy y) <- vectorizeAtom arg2 @@ -483,7 +483,7 @@ vectorizePrimOp op = case op of _ -> Varying x' <- if v == Varying then ensureVarying sx else return x y' <- if v == Varying then ensureVarying sy else return y - VVal v <$> emitExpr (BinOp opk x' y') + VVal v <$> emit (BinOp opk x' y') MiscOp (CastOp tyArg arg) -> do ty <- vectorizeType tyArg VVal vx x <- vectorizeAtom arg @@ -492,19 +492,19 @@ vectorizePrimOp op = case op of Varying -> getVectorType ty Contiguous -> return ty ProdStability _ -> throwVectErr "Unexpected cast of product type" - VVal vx <$> emitExpr (CastOp ty' x) + VVal vx <$> emit (CastOp ty' x) DAMOp op' -> vectorizeDAMOp op' RefOp ref op' -> vectorizeRefOp ref op' MemOp (PtrOffset arg1 arg2) -> do VVal Uniform ptr <- vectorizeAtom arg1 VVal Contiguous off <- vectorizeAtom arg2 - VVal Contiguous <$> emitExpr (PtrOffset ptr off) + VVal Contiguous <$> emit (PtrOffset ptr off) MemOp (PtrLoad arg) -> do VVal Contiguous ptr <- vectorizeAtom arg BaseTy (PtrType (addrSpace, a)) <- return $ getType ptr BaseTy av <- getVectorType $ BaseTy a - ptr' <- emitExpr $ CastOp (BaseTy $ PtrType (addrSpace, av)) ptr - VVal Varying <$> emitExpr (PtrLoad ptr') + ptr' <- emit $ CastOp (BaseTy $ PtrType (addrSpace, av)) ptr + VVal Varying <$> emit (PtrLoad ptr') -- Vectorizing IO might not always be safe! Here, we depend on vectorizeOp -- being picky about the IO-inducing ops it supports, and expect it to -- complain about FFI calls and the like. @@ -570,16 +570,16 @@ ensureVarying (VVal s val) = case s of Varying -> return val Uniform -> do vty <- getVectorType $ getType val - emitExpr $ VectorBroadcast val vty + emit $ VectorBroadcast val vty -- Note that the implementation of this case will depend on val's type. Contiguous -> do let ty = getType val vty <- getVectorType ty case ty of BaseTy (Scalar sbt) -> do - bval <- emitExpr $ VectorBroadcast val vty - iota <- emitExpr $ VectorIota vty - emitExpr $ BinOp (if isIntegral sbt then IAdd else FAdd) bval iota + bval <- emit $ VectorBroadcast val vty + iota <- emit $ VectorIota vty + emit $ BinOp (if isIntegral sbt then IAdd else FAdd) bval iota _ -> throwVectErr "Not implemented" ProdStability _ -> throwVectErr "Not implemented" ensureVarying (VRename v) = do From d10cfc591dfe5d04d97d56e72dcde6c0969fe1de Mon Sep 17 00:00:00 2001 From: Dougal Date: Thu, 9 Nov 2023 14:16:07 -0500 Subject: [PATCH 11/41] Put peephole optimizations in one place and use them from Builder. The reason to do this now is that I want to make AD linearity explicit for correctness reasons. I started doing that but realized I was going to need linear version of each of the helper functions in Builder `add`, `mul`. This cuts down on that boilerplate and it's a good idea anyway. --- dex.cabal | 1 + src/lib/Algebra.hs | 13 +- src/lib/Builder.hs | 145 +++++-------------- src/lib/Inference.hs | 5 +- src/lib/Inline.hs | 4 +- src/lib/Linearize.hs | 10 +- src/lib/Optimize.hs | 177 +---------------------- src/lib/PeepholeOptimize.hs | 276 ++++++++++++++++++++++++++++++++++++ src/lib/RuntimePrint.hs | 2 +- src/lib/Simplify.hs | 36 ++--- 10 files changed, 346 insertions(+), 323 deletions(-) create mode 100644 src/lib/PeepholeOptimize.hs diff --git a/dex.cabal b/dex.cabal index 1b6b7b771..e3737f4e3 100644 --- a/dex.cabal +++ b/dex.cabal @@ -77,6 +77,7 @@ library , Occurrence , OccAnalysis , Optimize + , PeepholeOptimize , PPrint , RawName , Runtime diff --git a/src/lib/Algebra.hs b/src/lib/Algebra.hs index eaebecaaf..b3d6d2502 100644 --- a/src/lib/Algebra.hs +++ b/src/lib/Algebra.hs @@ -50,7 +50,7 @@ newtype Polynomial (n::S) = -- us compute sums in closed form. This tries to compute -- `\sum_{i=0}^(lim-1) body`. `i`, `lim`, and `body` should all have type `Nat`. sumUsingPolys :: Emits n - => Atom SimpIR n -> Abs (Binder SimpIR) (Expr SimpIR) n -> BuilderM SimpIR n (Atom SimpIR n) + => SAtom n -> Abs (Binder SimpIR) (Expr SimpIR) n -> BuilderM SimpIR n (SAtom n) sumUsingPolys lim (Abs i body) = do sumAbs <- refreshAbs (Abs i body) \(i':>_) body' -> do exprAsPoly body' >>= \case @@ -138,7 +138,7 @@ type BlockTraverserM i o a = SubstReaderT PolySubstVal (MaybeT1 (BuilderM SimpIR exprAsPoly :: (EnvExtender m, EnvReader m) => SExpr n -> m n (Maybe (Polynomial n)) exprAsPoly expr = liftBuilder $ runMaybeT1 $ runSubstReaderT idSubst $ exprAsPolyRec expr -atomAsPoly :: Atom SimpIR i -> BlockTraverserM i o (Polynomial o) +atomAsPoly :: SAtom i -> BlockTraverserM i o (Polynomial o) atomAsPoly = \case Stuck _ (Var v) -> atomVarAsPoly v Stuck _ (RepValAtom (RepVal _ (Leaf (IVar v' _)))) -> impNameAsPoly v' @@ -190,7 +190,7 @@ blockAsPoly (Abs decls result) = case decls of -- coefficients. This is why we have to find the least common multiples and do the -- accumulation over numbers multiplied by that LCM. We essentially do fixed point -- fractional math here. -emitPolynomial :: Emits n => Polynomial n -> BuilderM SimpIR n (Atom SimpIR n) +emitPolynomial :: Emits n => Polynomial n -> BuilderM SimpIR n (SAtom n) emitPolynomial (Polynomial p) = do let constLCM = asAtom $ foldl lcm 1 $ fmap (denominator . snd) $ toList p monoAtoms <- flip traverse (toList p) $ \(m, c) -> do @@ -204,7 +204,7 @@ emitPolynomial (Polynomial p) = do -- because it might be causing overflows due to all arithmetic being shifted. asAtom = IdxRepVal . fromInteger -emitMonomial :: Emits n => Monomial n -> BuilderM SimpIR n (Atom SimpIR n) +emitMonomial :: Emits n => Monomial n -> BuilderM SimpIR n (SAtom n) emitMonomial (Monomial m) = do varAtoms <- forM (toList m) \(v, e) -> case v of LeftE v' -> do @@ -215,9 +215,12 @@ emitMonomial (Monomial m) = do ipow atom e foldM imul (IdxRepVal 1) varAtoms -ipow :: Emits n => Atom SimpIR n -> Int -> BuilderM SimpIR n (Atom SimpIR n) +ipow :: Emits n => SAtom n -> Int -> BuilderM SimpIR n (SAtom n) ipow x i = foldM imul (IdxRepVal 1) (replicate i x) +idiv :: Emits n => SAtom n -> SAtom n -> BuilderM SimpIR n (SAtom n) +idiv = undefined + -- === instances === instance GenericE Monomial where diff --git a/src/lib/Builder.hs b/src/lib/Builder.hs index e38792bc9..b65b1414d 100644 --- a/src/lib/Builder.hs +++ b/src/lib/Builder.hs @@ -29,12 +29,13 @@ import IRVariants import MTL1 import Subst import Name +import PeepholeOptimize import QueryType import Types.Core import Types.Imp import Types.Primitives import Types.Source -import Util (enumerate, transitiveClosureM, bindM2, toSnocList, (...)) +import Util (enumerate, transitiveClosureM, bindM2, toSnocList) -- === Ordinary (local) builder class === @@ -66,50 +67,31 @@ emitDecl _ _ (Atom (Stuck _ (Var n))) = return n emitDecl hint ann expr = rawEmitDecl hint ann expr {-# INLINE emitDecl #-} -emitInline :: (Builder r m, Emits n) => Atom r n -> m n (AtomVar r n) -emitInline atom = emitDecl noHint InlineLet $ Atom atom -{-# INLINE emitInline #-} - -emitHinted :: (Builder r m, Emits n) => NameHint -> Expr r n -> m n (AtomVar r n) -emitHinted hint expr = emitDecl hint PlainLet expr -{-# INLINE emitHinted #-} - emit :: (Builder r m, ToExpr e r, Emits n) => e n -> m n (Atom r n) emit e = case toExpr e of Atom x -> return x Block _ block -> emitDecls block >>= emit - expr -> toAtom <$> emitToVar expr + expr -> do + v <- emitDecl noHint PlainLet $ peepholeExpr expr + return $ toAtom v {-# INLINE emit #-} emitToVar :: (Builder r m, ToExpr e r, Emits n) => e n -> m n (AtomVar r n) -emitToVar e = case toExpr e of - Atom (Stuck _ (Var v)) -> return v - expr -> emitDecl noHint PlainLet expr +emitToVar expr = emit expr >>= \case + Stuck _ (Var v) -> return v + atom -> emitDecl noHint PlainLet (toExpr atom) {-# INLINE emitToVar #-} -emitHof :: (Builder r m, Emits n) => Hof r n -> m n (Atom r n) -emitHof hof = mkTypedHof hof >>= emit - -mkTypedHof :: (EnvReader m, IRRep r) => Hof r n -> m n (TypedHof r n) -mkTypedHof hof = do - effTy <- effTyOfHof hof - return $ TypedHof effTy hof - -emitUnOp :: (Builder r m, Emits n) => UnOp -> Atom r n -> m n (Atom r n) -emitUnOp op x = emit $ UnOp op x -{-# INLINE emitUnOp #-} - emitDecls :: (Builder r m, Emits n, RenameE e, SinkableE e) => WithDecls r e n -> m n (e n) -emitDecls (Abs decls result) = runSubstReaderT idSubst $ emitDecls' decls result - -emitDecls' :: (Builder r m, Emits o, RenameE e, SinkableE e) - => Nest (Decl r) i i' -> e i' -> SubstReaderT Name m i o (e o) -emitDecls' Empty e = renameM e -emitDecls' (Nest (Let b (DeclBinding ann expr)) rest) e = do - expr' <- renameM expr - AtomVar v _ <- emitDecl (getNameHint b) ann expr' - extendSubst (b @> v) $ emitDecls' rest e +emitDecls (Abs decls result) = runSubstReaderT idSubst $ go decls result where + go :: (Builder r m, Emits o, RenameE e, SinkableE e) + => Nest (Decl r) i i' -> e i' -> SubstReaderT Name m i o (e o) + go Empty e = renameM e + go (Nest (Let b (DeclBinding ann expr)) rest) e = do + expr' <- renameM expr + AtomVar v _ <- emitDecl (getNameHint b) ann expr' + extendSubst (b @> v) $ go rest e buildScopedAssumeNoDecls :: (SinkableE e, ScopableBuilder r m) => (forall l. (Emits l, DExt n l) => m l (e l)) @@ -775,6 +757,14 @@ buildEffLam hint ty body = do body' <- buildBlock $ body (sink hVar) $ sink ref return $ LamExpr (BinaryNest h b) body' +emitHof :: (Builder r m, Emits n) => Hof r n -> m n (Atom r n) +emitHof hof = mkTypedHof hof >>= emit + +mkTypedHof :: (EnvReader m, IRRep r) => Hof r n -> m n (TypedHof r n) +mkTypedHof hof = do + effTy <- effTyOfHof hof + return $ TypedHof effTy hof + buildForAnn :: (Emits n, ScopableBuilder r m) => NameHint -> ForAnn -> IxType r n @@ -940,70 +930,38 @@ symbolicTangentNonZero val = do -- === builder versions of common local ops === -fLitLike :: (SBuilder m, Emits n) => Double -> SAtom n -> m n (SAtom n) -fLitLike x t = case getTyCon t of - BaseType (Scalar Float64Type) -> return $ toAtom $ Lit $ Float64Lit x - BaseType (Scalar Float32Type) -> return $ toAtom $ Lit $ Float32Lit $ realToFrac x - _ -> error "Expected a floating point scalar" - neg :: (Builder r m, Emits n) => Atom r n -> m n (Atom r n) neg x = emit $ UnOp FNeg x add :: (Builder r m, Emits n) => Atom r n -> Atom r n -> m n (Atom r n) add x y = emit $ BinOp FAdd x y --- TODO: Implement constant folding for fixed-width integer types as well! -iadd :: (Builder r m, Emits n) => Atom r n -> Atom r n -> m n (Atom r n) -iadd (Con (Lit l)) y | getIntLit l == 0 = return y -iadd x (Con (Lit l)) | getIntLit l == 0 = return x -iadd x@(Con (Lit _)) y@(Con (Lit _)) = return $ applyIntBinOp (+) x y -iadd x y = emit $ BinOp IAdd x y - mul :: (Builder r m, Emits n) => Atom r n -> Atom r n -> m n (Atom r n) mul x y = emit $ BinOp FMul x y +iadd :: (Builder r m, Emits n) => Atom r n -> Atom r n -> m n (Atom r n) +iadd x y = emit $ BinOp IAdd x y + imul :: (Builder r m, Emits n) => Atom r n -> Atom r n -> m n (Atom r n) -imul (Con (Lit l)) y | getIntLit l == 1 = return y -imul x (Con (Lit l)) | getIntLit l == 1 = return x -imul x@(Con (Lit _)) y@(Con (Lit _)) = return $ applyIntBinOp (*) x y imul x y = emit $ BinOp IMul x y -sub :: (Builder r m, Emits n) => Atom r n -> Atom r n -> m n (Atom r n) -sub x y = emit $ BinOp FSub x y - -isub :: (Builder r m, Emits n) => Atom r n -> Atom r n -> m n (Atom r n) -isub x (Con (Lit l)) | getIntLit l == 0 = return x -isub x@(Con (Lit _)) y@(Con (Lit _)) = return $ applyIntBinOp (-) x y -isub x y = emit $ BinOp ISub x y - -select :: (Builder r m, Emits n) => Atom r n -> Atom r n -> Atom r n -> m n (Atom r n) -select (Con (Lit (Word8Lit p))) x y = return $ if p /= 0 then x else y -select p x y = emit $ MiscOp $ Select p x y - div' :: (Builder r m, Emits n) => Atom r n -> Atom r n -> m n (Atom r n) div' x y = emit $ BinOp FDiv x y -idiv :: (Builder r m, Emits n) => Atom r n -> Atom r n -> m n (Atom r n) -idiv x (Con (Lit l)) | getIntLit l == 1 = return x -idiv x@(Con (Lit _)) y@(Con (Lit _)) = return $ applyIntBinOp div x y -idiv x y = emit $ BinOp IDiv x y - -irem :: (Builder r m, Emits n) => Atom r n -> Atom r n -> m n (Atom r n) -irem x y = emit $ BinOp IRem x y - fpow :: (Builder r m, Emits n) => Atom r n -> Atom r n -> m n (Atom r n) fpow x y = emit $ BinOp FPow x y +sub :: (Builder r m, Emits n) => Atom r n -> Atom r n -> m n (Atom r n) +sub x y = emit $ BinOp FSub x y + flog :: (Builder r m, Emits n) => Atom r n -> m n (Atom r n) flog x = emit $ UnOp Log x -ilt :: (Builder r m, Emits n) => Atom r n -> Atom r n -> m n (Atom r n) -ilt x@(Con (Lit _)) y@(Con (Lit _)) = return $ applyIntCmpOp (<) x y -ilt x y = emit $ BinOp (ICmp Less) x y - -ieq :: (Builder r m, Emits n) => Atom r n -> Atom r n -> m n (Atom r n) -ieq x@(Con (Lit _)) y@(Con (Lit _)) = return $ applyIntCmpOp (==) x y -ieq x y = emit $ BinOp (ICmp Equal) x y +fLitLike :: (SBuilder m, Emits n) => Double -> SAtom n -> m n (SAtom n) +fLitLike x t = case getTyCon t of + BaseType (Scalar Float64Type) -> return $ toAtom $ Lit $ Float64Lit x + BaseType (Scalar Float32Type) -> return $ toAtom $ Lit $ Float32Lit $ realToFrac x + _ -> error "Expected a floating point scalar" fromPair :: (Builder r m, Emits n) => Atom r n -> m n (Atom r n, Atom r n) fromPair pair = do @@ -1160,7 +1118,7 @@ app :: (CBuilder m, Emits n) => CAtom n -> CAtom n -> m n (CAtom n) app x i = mkApp x [i] >>= emit naryApp :: (CBuilder m, Emits n) => CAtom n -> [CAtom n] -> m n (CAtom n) -naryApp = naryAppHinted noHint +naryApp f xs= mkApp f xs >>= emit {-# INLINE naryApp #-} naryTopApp :: (Builder SimpIR m, Emits n) => TopFunName n -> [SAtom n] -> m n (SAtom n) @@ -1175,10 +1133,6 @@ naryTopAppInlined f xs = do _ -> naryTopApp f xs {-# INLINE naryTopAppInlined #-} -naryAppHinted :: (CBuilder m, Emits n) - => NameHint -> CAtom n -> [CAtom n] -> m n (CAtom n) -naryAppHinted hint f xs = toAtom <$> (mkApp f xs >>= emitHinted hint) - tabApp :: (Builder r m, Emits n) => Atom r n -> Atom r n -> m n (Atom r n) tabApp x i = mkTabApp x i >>= emit @@ -1581,30 +1535,3 @@ visitDeclsEmits (Nest (Let b (DeclBinding _ expr)) decls) cont = do x <- visitExprEmits expr extendSubst (b@>SubstVal x) do visitDeclsEmits decls cont - --- === Helpers for function evaluation over fixed-width types === - -applyIntBinOp' :: (forall a. (Eq a, Ord a, Num a, Integral a) - => (a -> Atom r n) -> a -> a -> Atom r n) -> Atom r n -> Atom r n -> Atom r n -applyIntBinOp' f x y = case (x, y) of - (Con (Lit (Int64Lit xv)), Con (Lit (Int64Lit yv))) -> f (Con . Lit . Int64Lit) xv yv - (Con (Lit (Int32Lit xv)), Con (Lit (Int32Lit yv))) -> f (Con . Lit . Int32Lit) xv yv - (Con (Lit (Word8Lit xv)), Con (Lit (Word8Lit yv))) -> f (Con . Lit . Word8Lit) xv yv - (Con (Lit (Word32Lit xv)), Con (Lit (Word32Lit yv))) -> f (Con . Lit . Word32Lit) xv yv - (Con (Lit (Word64Lit xv)), Con (Lit (Word64Lit yv))) -> f (Con . Lit . Word64Lit) xv yv - _ -> error "Expected integer atoms" - -applyIntBinOp :: (forall a. (Num a, Integral a) => a -> a -> a) -> Atom r n -> Atom r n -> Atom r n -applyIntBinOp f x y = applyIntBinOp' (\w -> w ... f) x y - -applyIntCmpOp :: (forall a. (Eq a, Ord a) => a -> a -> Bool) -> Atom r n -> Atom r n -> Atom r n -applyIntCmpOp f x y = applyIntBinOp' (\_ -> (Con . Lit . Word8Lit . fromIntegral . fromEnum) ... f) x y - -applyFloatBinOp :: (forall a. (Num a, Fractional a) => a -> a -> a) -> Atom r n -> Atom r n -> Atom r n -applyFloatBinOp f x y = case (x, y) of - (Con (Lit (Float64Lit xv)), Con (Lit (Float64Lit yv))) -> Con $ Lit $ Float64Lit $ f xv yv - (Con (Lit (Float32Lit xv)), Con (Lit (Float32Lit yv))) -> Con $ Lit $ Float32Lit $ f xv yv - _ -> error "Expected float atoms" - -_applyFloatUnOp :: (forall a. (Num a, Fractional a) => a -> a) -> Atom r n -> Atom r n -_applyFloatUnOp f x = applyFloatBinOp (\_ -> f) (error "shouldn't be needed") x diff --git a/src/lib/Inference.hs b/src/lib/Inference.hs index c8e9681e2..b9253952a 100644 --- a/src/lib/Inference.hs +++ b/src/lib/Inference.hs @@ -122,7 +122,7 @@ inferTopUDecl (ULocalDecl (WithSrcB src decl)) result = addSrcContext src case d _ -> do PairE block recon <- liftInfererM $ buildBlockInfWithRecon do val <- checkMaybeAnnExpr tyAnn rhs - v <- emitHinted (getNameHint p) $ Atom val + v <- emitDecl (getNameHint p) PlainLet $ Atom val bindLetPat p v do renameM result (topBlock, _) <- asTopBlock block @@ -1597,6 +1597,9 @@ bindLetPat (WithSrcB pos pat) v cont = addSrcContext pos $ case pat of xs <- forM [0 .. n - 1] \i -> do emitToVar =<< mkTabApp (toAtom v) (toAtom $ NewtypeCon (FinCon (NatVal n)) (NatVal $ fromIntegral i)) bindLetPats ps xs cont + where + emitInline :: Emits n => CAtom n -> InfererM i n (AtomVar CoreIR n) + emitInline atom = emitDecl noHint InlineLet $ Atom atom checkUType :: UType i -> InfererM i o (CType o) checkUType t = do diff --git a/src/lib/Inline.hs b/src/lib/Inline.hs index 9ba40faea..4104948c0 100644 --- a/src/lib/Inline.hs +++ b/src/lib/Inline.hs @@ -14,7 +14,7 @@ import IRVariants import Name import Subst import Occurrence hiding (Var) -import Optimize +import PeepholeOptimize import Types.Core import Types.Primitives @@ -80,7 +80,7 @@ inlineDeclsSubst = \case s <- getSubst extendSubst (b @> SubstVal (SuspEx expr s)) $ inlineDeclsSubst rest else do - expr' <- inlineExpr Stop expr >>= (liftEnvReaderM . peepholeExpr) + expr' <- peepholeExpr <$> inlineExpr Stop expr -- If the inliner starts moving effectful expressions, it may become -- necessary to query the effects of the new expression here. let presInfo = resolveWorkConservation ann expr' diff --git a/src/lib/Linearize.hs b/src/lib/Linearize.hs index 1f0054671..d4ca0417e 100644 --- a/src/lib/Linearize.hs +++ b/src/lib/Linearize.hs @@ -495,18 +495,18 @@ linearizeUnOp op x' = do let emitZeroT = withZeroT $ emit $ UnOp op x case op of Exp -> do - y <- emitUnOp Exp x + y <- emit $ UnOp Exp x return $ WithTangent y (bindM2 mul tx (sinkM y)) Exp2 -> notImplemented - Log -> withT (emitUnOp Log x) $ (tx >>= (`div'` sink x)) + Log -> withT (emit $ UnOp Log x) $ (tx >>= (`div'` sink x)) Log2 -> notImplemented Log10 -> notImplemented Log1p -> notImplemented - Sin -> withT (emitUnOp Sin x) $ bindM2 mul tx (emitUnOp Cos (sink x)) - Cos -> withT (emitUnOp Cos x) $ bindM2 mul tx (neg =<< emitUnOp Sin (sink x)) + Sin -> withT (emit $ UnOp Sin x) $ bindM2 mul tx (emit $ UnOp Cos (sink x)) + Cos -> withT (emit $ UnOp Cos x) $ bindM2 mul tx (neg =<< emit (UnOp Sin (sink x))) Tan -> notImplemented Sqrt -> do - y <- emitUnOp Sqrt x + y <- emit $ UnOp Sqrt x return $ WithTangent y do denominator <- bindM2 mul (2 `fLitLike` sink y) (sinkM y) bindM2 div' tx (pure denominator) diff --git a/src/lib/Optimize.hs b/src/lib/Optimize.hs index 2bc0452ab..b1714fd62 100644 --- a/src/lib/Optimize.hs +++ b/src/lib/Optimize.hs @@ -7,16 +7,11 @@ {-# LANGUAGE UndecidableInstances #-} module Optimize - ( optimize, peepholeOp, peepholeExpr, hoistLoopInvariant, dceTop, foldCast ) where + ( optimize, hoistLoopInvariant, dceTop) where import Data.Functor -import Data.Word -import Data.Bits -import Data.Bits.Floating -import Data.List import Control.Monad import Control.Monad.State.Strict -import GHC.Float import Types.Core import Types.Primitives @@ -37,174 +32,6 @@ optimize = dceTop -- Clean up user code >=> dceTop -- Clean up peephole-optimized code after unrolling >=> hoistLoopInvariant --- === Peephole optimizations === - -peepholeOp :: PrimOp SimpIR o -> EnvReaderM o (SExpr o) -peepholeOp op = case op of - MiscOp (CastOp (TyCon (BaseType (Scalar sTy))) (Con (Lit l))) -> return $ case foldCast sTy l of - Just l' -> lit l' - Nothing -> noop - -- TODO: Support more unary and binary ops. - BinOp IAdd l r -> return $ case (l, r) of - -- TODO: Shortcut when either side is zero. - (Con (Lit ll), Con (Lit rl)) -> case (ll, rl) of - (Word32Lit lv, Word32Lit lr) -> lit $ Word32Lit $ lv + lr - _ -> noop - _ -> noop - BinOp (ICmp cop) (Con (Lit ll)) (Con (Lit rl)) -> - return $ lit $ Word8Lit $ fromIntegral $ fromEnum $ case (ll, rl) of - (Int32Lit lv, Int32Lit rv) -> cmp cop lv rv - (Int64Lit lv, Int64Lit rv) -> cmp cop lv rv - (Word8Lit lv, Word8Lit rv) -> cmp cop lv rv - (Word32Lit lv, Word32Lit rv) -> cmp cop lv rv - (Word64Lit lv, Word64Lit rv) -> cmp cop lv rv - _ -> error "Ill typed ICmp?" - BinOp (FCmp cop) (Con (Lit ll)) (Con (Lit rl)) -> - return $ lit $ Word8Lit $ fromIntegral $ fromEnum $ case (ll, rl) of - (Float32Lit lv, Float32Lit rv) -> cmp cop lv rv - (Float64Lit lv, Float64Lit rv) -> cmp cop lv rv - _ -> error "Ill typed FCmp?" - BinOp BOr (Con (Lit (Word8Lit lv))) (Con (Lit (Word8Lit rv))) -> - return $ lit $ Word8Lit $ lv .|. rv - BinOp BAnd (Con (Lit (Word8Lit lv))) (Con (Lit (Word8Lit rv))) -> - return $ lit $ Word8Lit $ lv .&. rv - MiscOp (ToEnum ty (Con (Lit (Word8Lit tag)))) -> case ty of - TyCon (SumType cases) -> return $ toExpr $ SumCon cases (fromIntegral tag) UnitVal - _ -> error "Ill typed ToEnum?" - MiscOp (SumTag (Con (SumCon _ tag _))) -> return $ lit $ Word8Lit $ fromIntegral tag - _ -> return noop - where - noop = PrimOp op - lit = Atom . Con . Lit - - cmp :: Ord a => CmpOp -> a -> a -> Bool - cmp = \case - Less -> (<) - Greater -> (>) - Equal -> (==) - LessEqual -> (<=) - GreaterEqual -> (>=) - -foldCast :: ScalarBaseType -> LitVal -> Maybe LitVal -foldCast sTy l = case sTy of - -- TODO: Check that the casts relating to floating-point agree with the - -- runtime behavior. The runtime is given by the `ICastOp` case in - -- ImpToLLVM.hs. We should make sure that the Haskell functions here - -- produce bitwise identical results to those instructions, by adjusting - -- either this or that as called for. - -- TODO: Also implement casts that may have unrepresentable results, i.e., - -- casting floating-point numbers to smaller floating-point numbers or to - -- fixed-point. Both of these necessarily have a much smaller dynamic range. - Int32Type -> case l of - Int32Lit _ -> Just l - Int64Lit i -> Just $ Int32Lit $ fromIntegral i - Word8Lit i -> Just $ Int32Lit $ fromIntegral i - Word32Lit i -> Just $ Int32Lit $ fromIntegral i - Word64Lit i -> Just $ Int32Lit $ fromIntegral i - Float32Lit _ -> Nothing - Float64Lit _ -> Nothing - PtrLit _ _ -> Nothing - Int64Type -> case l of - Int32Lit i -> Just $ Int64Lit $ fromIntegral i - Int64Lit _ -> Just l - Word8Lit i -> Just $ Int64Lit $ fromIntegral i - Word32Lit i -> Just $ Int64Lit $ fromIntegral i - Word64Lit i -> Just $ Int64Lit $ fromIntegral i - Float32Lit _ -> Nothing - Float64Lit _ -> Nothing - PtrLit _ _ -> Nothing - Word8Type -> case l of - Int32Lit i -> Just $ Word8Lit $ fromIntegral i - Int64Lit i -> Just $ Word8Lit $ fromIntegral i - Word8Lit _ -> Just l - Word32Lit i -> Just $ Word8Lit $ fromIntegral i - Word64Lit i -> Just $ Word8Lit $ fromIntegral i - Float32Lit _ -> Nothing - Float64Lit _ -> Nothing - PtrLit _ _ -> Nothing - Word32Type -> case l of - Int32Lit i -> Just $ Word32Lit $ fromIntegral i - Int64Lit i -> Just $ Word32Lit $ fromIntegral i - Word8Lit i -> Just $ Word32Lit $ fromIntegral i - Word32Lit _ -> Just l - Word64Lit i -> Just $ Word32Lit $ fromIntegral i - Float32Lit _ -> Nothing - Float64Lit _ -> Nothing - PtrLit _ _ -> Nothing - Word64Type -> case l of - Int32Lit i -> Just $ Word64Lit $ fromIntegral (fromIntegral i :: Word32) - Int64Lit i -> Just $ Word64Lit $ fromIntegral i - Word8Lit i -> Just $ Word64Lit $ fromIntegral i - Word32Lit i -> Just $ Word64Lit $ fromIntegral i - Word64Lit _ -> Just l - Float32Lit _ -> Nothing - Float64Lit _ -> Nothing - PtrLit _ _ -> Nothing - Float32Type -> case l of - Int32Lit i -> Just $ Float32Lit $ fixUlp i $ fromIntegral i - Int64Lit i -> Just $ Float32Lit $ fixUlp i $ fromIntegral i - Word8Lit i -> Just $ Float32Lit $ fromIntegral i - Word32Lit i -> Just $ Float32Lit $ fixUlp i $ fromIntegral i - Word64Lit i -> Just $ Float32Lit $ fixUlp i $ fromIntegral i - Float32Lit _ -> Just l - Float64Lit _ -> Nothing - PtrLit _ _ -> Nothing - Float64Type -> case l of - Int32Lit i -> Just $ Float64Lit $ fromIntegral i - Int64Lit i -> Just $ Float64Lit $ fixUlp i $ fromIntegral i - Word8Lit i -> Just $ Float64Lit $ fromIntegral i - Word32Lit i -> Just $ Float64Lit $ fromIntegral i - Word64Lit i -> Just $ Float64Lit $ fixUlp i $ fromIntegral i - Float32Lit f -> Just $ Float64Lit $ float2Double f - Float64Lit _ -> Just l - PtrLit _ _ -> Nothing - where - -- When casting an integer type to a floating-point type of lower precision - -- (e.g., int32 to float32), GHC between 7.8.3 and 9.2.2 (exclusive) rounds - -- toward zero, instead of rounding to nearest even like everybody else. - -- See https://gitlab.haskell.org/ghc/ghc/-/issues/17231. - -- - -- We patch this by manually checking the two adjacent floats to the - -- candidate answer, and using one of those if the reverse cast is closer - -- to the original input. - -- - -- This rounds to nearest. We round to nearest *even* by considering the - -- candidates in decreasing order of the number of trailing zeros they - -- exhibit when cast back to the original integer type. - fixUlp :: forall a b w. (Num a, Integral a, FiniteBits a, RealFrac b, FloatingBits b w) - => a -> b -> b - fixUlp orig candidate = res where - res = closest $ sortBy moreLowBits [candidate, candidatem1, candidatep1] - candidatem1 = nextDown candidate - candidatep1 = nextUp candidate - closest = minimumBy (\ca cb -> err ca `compare` err cb) - err cand = absdiff orig (round cand) - absdiff a b = if a >= b then a - b else b - a - moreLowBits a b = - compare (0 - countTrailingZeros (round @b @a a)) - (0 - countTrailingZeros (round @b @a b)) - -peepholeExpr :: SExpr o -> EnvReaderM o (SExpr o) -peepholeExpr expr = case expr of - PrimOp op -> peepholeOp op - TabApp _ (Stuck _ (Var (AtomVar t _))) (IdxRepVal ord) -> - lookupAtomName t <&> \case - LetBound (DeclBinding ann (TabCon Nothing tabTy elems)) - | ann /= NoInlineLet && isFinTabTy tabTy-> - -- It is not safe to assume that this index can always be simplified! - -- For example, it might be coming from an unsafe_from_ordinal that is - -- under a case branch that would be dead for all invalid indices. - if 0 <= ord && fromIntegral ord < length elems - then Atom $ elems !! fromIntegral ord - else expr - _ -> expr - -- TODO: Apply a function to literals when it has a cheap body? - -- Think, partial evaluation of threefry. - _ -> return expr - where isFinTabTy = \case - TyCon (TabPi (TabPiType (DictCon (IxRawFin _)) _ _)) -> True - _ -> False - -- === Loop unrolling === unrollLoops :: EnvReader m => STopLam n -> m n (STopLam n) @@ -268,7 +95,7 @@ ulExpr expr = case expr of _ -> nothingSpecial where inc i = modify \(ULS n) -> ULS (n + i) - nothingSpecial = inc 1 >> (visitGeneric expr >>= liftEnvReaderM . peepholeExpr) >>= emit + nothingSpecial = inc 1 >> visitGeneric expr >>= emit unrollBlowupThreshold = 12 withLocalAccounting m = do oldCost <- get diff --git a/src/lib/PeepholeOptimize.hs b/src/lib/PeepholeOptimize.hs new file mode 100644 index 000000000..8ec599acd --- /dev/null +++ b/src/lib/PeepholeOptimize.hs @@ -0,0 +1,276 @@ +-- Copyright 2023 Google LLC +-- +-- Use of this source code is governed by a BSD-style +-- license that can be found in the LICENSE file or at +-- https://developers.google.com/open-source/licenses/bsd + +module PeepholeOptimize (PeepholeOpt (..), peepholeExpr) where + +import Data.Word +import Data.Bits +import Data.List +import Data.Bits.Floating +import GHC.Float + +import Types.Core +import Types.Primitives +import Name +import IRVariants +import qualified Types.OpNames as P + +peepholeExpr :: Expr r n -> Expr r n +peepholeExpr e = case peephole e of + Just x -> Atom x + Nothing -> e +{-# INLINE peepholeExpr #-} + +-- === Peephole optimization = undefined + +-- These are context-free (no env!) optimizations of expressions and ops that +-- are worth doing unconditionally. Builder calls this automatically in `emit`. + +class ToExpr e r => PeepholeOpt (e::E) (r::IR) | e -> r where + peephole :: e n -> Maybe (Atom r n) + +instance PeepholeOpt (Expr r) r where + peephole = \case + Atom x -> Just x + PrimOp op -> peephole op + Project _ i x -> case x of + Con con -> Just case con of + ProdCon xs -> xs !! i + DepPair l _ _ | i == 0 -> l + DepPair _ r _ | i == 1 -> r + _ -> error "not a product" + Stuck _ _ -> Nothing + Unwrap _ x -> case x of + Con con -> Just case con of + NewtypeCon _ x' -> x' + _ -> error "not a newtype" + Stuck _ _ -> Nothing + App _ _ _ -> Nothing + TabApp _ _ _ -> Nothing + Case _ _ _ -> Nothing + TopApp _ _ _ -> Nothing + Block _ _ -> Nothing + TabCon _ _ _ -> Nothing + ApplyMethod _ _ _ _ -> Nothing + {-# INLINE peephole #-} + +instance PeepholeOpt (PrimOp r) r where + peephole = \case + MiscOp op -> peephole op + BinOp op l r -> peepholeBinOp op l r + _ -> Nothing + {-# INLINE peephole #-} + +peepholeBinOp :: P.BinOp -> Atom r n -> Atom r n -> Maybe (Atom r n) +peepholeBinOp op x y = case op of + IAdd -> case (x, y) of + (Con (Lit x'), y') | getIntLit x' == 0 -> Just y' + (x', Con (Lit y')) | getIntLit y' == 0 -> Just x' + (Con (Lit x'), Con (Lit y')) -> Just $ Con $ Lit $ applyIntBinOp (+) x' y' + _ -> Nothing + ISub -> case (x, y) of + (x', Con (Lit y')) | getIntLit y' == 0 -> Just x' + (Con (Lit x'), Con (Lit y')) -> Just $ Con $ Lit $ applyIntBinOp (-) x' y' + _ -> Nothing + IMul -> case (x, y) of + (Con (Lit x'), y') | getIntLit x' == 1 -> Just y' + (x', Con (Lit y')) | getIntLit y' == 1 -> Just x' + (Con (Lit x'), Con (Lit y')) -> Just $ Con $ Lit $ applyIntBinOp (*) x' y' + _ -> Nothing + IDiv -> case (x, y) of + (x', Con (Lit y')) | getIntLit y' == 1 -> Just x' + _ -> Nothing + ICmp cop -> case (x, y) of + (Con (Lit x'), Con (Lit y')) -> Just $ Con $ Lit $ applyIntCmpOp (cmp cop) x' y' + _ -> Nothing + FAdd -> case (x, y) of + (Con (Lit x'), y') | getFloatLit x' == 0 -> Just y' + (x', Con (Lit y')) | getFloatLit y' == 0 -> Just x' + (Con (Lit x'), Con (Lit y')) -> Just $ Con $ Lit $ applyFloatBinOp (+) x' y' + _ -> Nothing + FSub -> case (x, y) of + (x', Con (Lit y')) | getFloatLit y' == 0 -> Just x' + (Con (Lit x'), Con (Lit y')) -> Just $ Con $ Lit $ applyFloatBinOp (-) x' y' + _ -> Nothing + FMul -> case (x, y) of + (Con (Lit x'), y') | getFloatLit x' == 1 -> Just y' + (x', Con (Lit y')) | getFloatLit y' == 1 -> Just x' + (Con (Lit x'), Con (Lit y')) -> Just $ Con $ Lit $ applyFloatBinOp (*) x' y' + _ -> Nothing + FDiv -> case (x, y) of + (x', Con (Lit y')) | getFloatLit y' == 1 -> Just x' + _ -> Nothing + FCmp cop -> case (x, y) of + (Con (Lit x'), Con (Lit y')) -> Just $ Con $ Lit $ applyFloatCmpOp (cmp cop) x' y' + _ -> Nothing + BOr -> case (x, y) of + (Con (Lit (Word8Lit x')), Con (Lit (Word8Lit y'))) -> Just $ Con $ Lit $ Word8Lit $ x' .|. y' + _ -> Nothing + BAnd -> case (x, y) of + (Con (Lit (Word8Lit lv)), Con (Lit (Word8Lit rv))) -> Just $ Con $ Lit $ Word8Lit $ lv .&. rv + _ -> Nothing + BXor -> Nothing -- TODO + BShL -> Nothing -- TODO + BShR -> Nothing -- TODO + IRem -> Nothing -- TODO + FPow -> Nothing -- TODO +{-# INLINE peepholeBinOp #-} + +instance PeepholeOpt (MiscOp r) r where + peephole = \case + CastOp (TyCon (BaseType (Scalar sTy))) (Con (Lit l)) -> case foldCast sTy l of + Just l' -> Just $ Con $ Lit l' + Nothing -> Nothing + ToEnum ty (Con (Lit (Word8Lit tag))) -> case ty of + TyCon (SumType cases) -> Just $ Con $ SumCon cases (fromIntegral tag) UnitVal + _ -> error "Ill typed ToEnum" + SumTag (Con (SumCon _ tag _)) -> Just $ Con $ Lit $ Word8Lit $ fromIntegral tag + Select p x y -> case p of + Con (Lit (Word8Lit p')) -> Just if p' /= 0 then x else y + _ -> Nothing + _ -> Nothing + +foldCast :: ScalarBaseType -> LitVal -> Maybe LitVal +foldCast sTy l = case sTy of + -- TODO: Check that the casts relating to floating-point agree with the + -- runtime behavior. The runtime is given by the `ICastOp` case in + -- ImpToLLVM.hs. We should make sure that the Haskell functions here + -- produce bitwise identical results to those instructions, by adjusting + -- either this or that as called for. + -- TODO: Also implement casts that may have unrepresentable results, i.e., + -- casting floating-point numbers to smaller floating-point numbers or to + -- fixed-point. Both of these necessarily have a much smaller dynamic range. + Int32Type -> case l of + Int32Lit _ -> Just l + Int64Lit i -> Just $ Int32Lit $ fromIntegral i + Word8Lit i -> Just $ Int32Lit $ fromIntegral i + Word32Lit i -> Just $ Int32Lit $ fromIntegral i + Word64Lit i -> Just $ Int32Lit $ fromIntegral i + Float32Lit _ -> Nothing + Float64Lit _ -> Nothing + PtrLit _ _ -> Nothing + Int64Type -> case l of + Int32Lit i -> Just $ Int64Lit $ fromIntegral i + Int64Lit _ -> Just l + Word8Lit i -> Just $ Int64Lit $ fromIntegral i + Word32Lit i -> Just $ Int64Lit $ fromIntegral i + Word64Lit i -> Just $ Int64Lit $ fromIntegral i + Float32Lit _ -> Nothing + Float64Lit _ -> Nothing + PtrLit _ _ -> Nothing + Word8Type -> case l of + Int32Lit i -> Just $ Word8Lit $ fromIntegral i + Int64Lit i -> Just $ Word8Lit $ fromIntegral i + Word8Lit _ -> Just l + Word32Lit i -> Just $ Word8Lit $ fromIntegral i + Word64Lit i -> Just $ Word8Lit $ fromIntegral i + Float32Lit _ -> Nothing + Float64Lit _ -> Nothing + PtrLit _ _ -> Nothing + Word32Type -> case l of + Int32Lit i -> Just $ Word32Lit $ fromIntegral i + Int64Lit i -> Just $ Word32Lit $ fromIntegral i + Word8Lit i -> Just $ Word32Lit $ fromIntegral i + Word32Lit _ -> Just l + Word64Lit i -> Just $ Word32Lit $ fromIntegral i + Float32Lit _ -> Nothing + Float64Lit _ -> Nothing + PtrLit _ _ -> Nothing + Word64Type -> case l of + Int32Lit i -> Just $ Word64Lit $ fromIntegral (fromIntegral i :: Word32) + Int64Lit i -> Just $ Word64Lit $ fromIntegral i + Word8Lit i -> Just $ Word64Lit $ fromIntegral i + Word32Lit i -> Just $ Word64Lit $ fromIntegral i + Word64Lit _ -> Just l + Float32Lit _ -> Nothing + Float64Lit _ -> Nothing + PtrLit _ _ -> Nothing + Float32Type -> case l of + Int32Lit i -> Just $ Float32Lit $ fixUlp i $ fromIntegral i + Int64Lit i -> Just $ Float32Lit $ fixUlp i $ fromIntegral i + Word8Lit i -> Just $ Float32Lit $ fromIntegral i + Word32Lit i -> Just $ Float32Lit $ fixUlp i $ fromIntegral i + Word64Lit i -> Just $ Float32Lit $ fixUlp i $ fromIntegral i + Float32Lit _ -> Just l + Float64Lit _ -> Nothing + PtrLit _ _ -> Nothing + Float64Type -> case l of + Int32Lit i -> Just $ Float64Lit $ fromIntegral i + Int64Lit i -> Just $ Float64Lit $ fixUlp i $ fromIntegral i + Word8Lit i -> Just $ Float64Lit $ fromIntegral i + Word32Lit i -> Just $ Float64Lit $ fromIntegral i + Word64Lit i -> Just $ Float64Lit $ fixUlp i $ fromIntegral i + Float32Lit f -> Just $ Float64Lit $ float2Double f + Float64Lit _ -> Just l + PtrLit _ _ -> Nothing + where + -- When casting an integer type to a floating-point type of lower precision + -- (e.g., int32 to float32), GHC between 7.8.3 and 9.2.2 (exclusive) rounds + -- toward zero, instead of rounding to nearest even like everybody else. + -- See https://gitlab.haskell.org/ghc/ghc/-/issues/17231. + -- + -- We patch this by manually checking the two adjacent floats to the + -- candidate answer, and using one of those if the reverse cast is closer + -- to the original input. + -- + -- This rounds to nearest. We round to nearest *even* by considering the + -- candidates in decreasing order of the number of trailing zeros they + -- exhibit when cast back to the original integer type. + fixUlp :: forall a b w. (Num a, Integral a, FiniteBits a, RealFrac b, FloatingBits b w) + => a -> b -> b + fixUlp orig candidate = res where + res = closest $ sortBy moreLowBits [candidate, candidatem1, candidatep1] + candidatem1 = nextDown candidate + candidatep1 = nextUp candidate + closest = minimumBy (\ca cb -> err ca `compare` err cb) + err cand = absdiff orig (round cand) + absdiff a b = if a >= b then a - b else b - a + moreLowBits a b = + compare (0 - countTrailingZeros (round @b @a a)) + (0 - countTrailingZeros (round @b @a b)) + +-- === Helpers for function evaluation over fixed-width types === + +applyIntBinOp :: (forall a. (Num a, Integral a) => a -> a -> a) -> LitVal -> LitVal -> LitVal +applyIntBinOp f x y = case (x, y) of + (Int64Lit x', Int64Lit y') -> Int64Lit $ f x' y' + (Int32Lit x', Int32Lit y') -> Int32Lit $ f x' y' + (Word8Lit x', Word8Lit y') -> Word8Lit $ f x' y' + (Word32Lit x', Word32Lit y') -> Word32Lit $ f x' y' + (Word64Lit x', Word64Lit y') -> Word64Lit $ f x' y' + _ -> error "Expected integer atoms" + +applyIntCmpOp :: (forall a. (Eq a, Ord a) => a -> a -> Bool) -> LitVal -> LitVal -> LitVal +applyIntCmpOp f x y = boolLit case (x, y) of + (Int64Lit x', Int64Lit y') -> f x' y' + (Int32Lit x', Int32Lit y') -> f x' y' + (Word8Lit x', Word8Lit y') -> f x' y' + (Word32Lit x', Word32Lit y') -> f x' y' + (Word64Lit x', Word64Lit y') -> f x' y' + _ -> error "Expected integer atoms" + +applyFloatBinOp :: (forall a. (Num a, Fractional a) => a -> a -> a) -> LitVal -> LitVal -> LitVal +applyFloatBinOp f x y = case (x, y) of + (Float64Lit x', Float64Lit y') -> Float64Lit $ f x' y' + (Float32Lit x', Float32Lit y') -> Float32Lit $ f x' y' + _ -> error "Expected float atoms" + +applyFloatCmpOp :: (forall a. (Eq a, Ord a) => a -> a -> Bool) -> LitVal -> LitVal -> LitVal +applyFloatCmpOp f x y = boolLit case (x, y) of + (Float64Lit x', Float64Lit y') -> f x' y' + (Float32Lit x', Float32Lit y') -> f x' y' + _ -> error "Expected float atoms" + +boolLit :: Bool -> LitVal +boolLit x = Word8Lit $ fromIntegral $ fromEnum x + +cmp :: Ord a => CmpOp -> a -> a -> Bool +cmp = \case + Less -> (<) + Greater -> (>) + Equal -> (==) + LessEqual -> (<=) + GreaterEqual -> (>=) diff --git a/src/lib/RuntimePrint.hs b/src/lib/RuntimePrint.hs index 63fc869fd..af42c0c2f 100644 --- a/src/lib/RuntimePrint.hs +++ b/src/lib/RuntimePrint.hs @@ -83,7 +83,7 @@ showAnyTyCon tyCon atom = case tyCon of TypeKind -> printAsConstant Pi _ -> printTypeOnly "function" TabPi _ -> brackets $ forEachTabElt atom \iOrd x -> do - isFirst <- ieq iOrd (NatVal 0) + isFirst <- emit $ BinOp (ICmp Equal) iOrd (NatVal 0) void $ emitIf isFirst UnitTy (return UnitVal) (emitLit ", " >> return UnitVal) rec x NewtypeTyCon tc -> case tc of diff --git a/src/lib/Simplify.hs b/src/lib/Simplify.hs index 53a5fe7f6..9997df23c 100644 --- a/src/lib/Simplify.hs +++ b/src/lib/Simplify.hs @@ -26,7 +26,6 @@ import IRVariants import Linearize import Name import Subst -import Optimize (peepholeExpr) import QueryType import RuntimePrint import Transpose @@ -720,29 +719,17 @@ simplifyOp op = case op of RefOp ref eff -> do ref' <- toDataAtom ref liftResult =<< simplifyRefOp eff ref' - BinOp binop x' y' -> do - x <- toDataAtom x' - y <- toDataAtom y' - liftResult =<< case binop of - ISub -> isub x y - IAdd -> iadd x y - IMul -> imul x y - IDiv -> idiv x y - ICmp Less -> ilt x y - ICmp Equal -> ieq x y - _ -> emit $ BinOp binop x y - UnOp unOp x' -> do - x <- toDataAtom x' - liftResult =<< emit (UnOp unOp x) + BinOp binop x y -> do + x' <- toDataAtom x + y' <- toDataAtom y + liftResult =<< emit (BinOp binop x' y') + UnOp unOp x -> do + x' <- toDataAtom x + liftResult =<< emit (UnOp unOp x') MiscOp op' -> case op' of - Select c' x' y' -> do - c <- toDataAtom c' - x <- toDataAtom x' - y <- toDataAtom y' - liftResult =<< select c x y - ShowAny x' -> do - x <- simplifyAtom x' - dropSubst $ showAny x >>= simplifyExpr + ShowAny x -> do + x' <- simplifyAtom x + dropSubst $ showAny x' >>= simplifyExpr _ -> simplifyGenericOp op' where liftResult x = do @@ -757,8 +744,7 @@ simplifyGenericOp simplifyGenericOp op = do ty <- substM $ getType op op' <- traverseOp op getRepType toDataAtom (error "shouldn't have lambda left") - result <- liftEnvReaderM (peepholeExpr $ toExpr op') >>= emit - liftSimpAtom ty result + liftSimpAtom ty =<< emit op' {-# INLINE simplifyGenericOp #-} pattern CoerceReconAbs :: Abs (Nest b) ReconstructAtom n From 9672aa3afe288ba110a6c37616fc047515e09f0b Mon Sep 17 00:00:00 2001 From: Dougal Date: Thu, 9 Nov 2023 16:50:55 -0500 Subject: [PATCH 12/41] Fix autodiff using explicit linearity annotations and handle projections efficiently. --- src/lib/Algebra.hs | 2 +- src/lib/Builder.hs | 44 ++--- src/lib/Inline.hs | 3 + src/lib/Linearize.hs | 321 ++++++++++++++++++------------------ src/lib/PPrint.hs | 1 + src/lib/Transpose.hs | 194 ++++++++-------------- src/lib/Types/Primitives.hs | 1 + 7 files changed, 257 insertions(+), 309 deletions(-) diff --git a/src/lib/Algebra.hs b/src/lib/Algebra.hs index b3d6d2502..5ecc05f76 100644 --- a/src/lib/Algebra.hs +++ b/src/lib/Algebra.hs @@ -18,7 +18,7 @@ import Data.Text.Prettyprint.Doc import Data.List (intersperse) import Data.Tuple (swap) -import Builder hiding (sub, add, mul) +import Builder import Core import CheapReduction import Err diff --git a/src/lib/Builder.hs b/src/lib/Builder.hs index b65b1414d..0bb8e8fec 100644 --- a/src/lib/Builder.hs +++ b/src/lib/Builder.hs @@ -847,6 +847,12 @@ buildRememberDest hint dest cont = do -- === vector space (ish) type class === +emitLin :: (Builder r m, ToExpr e r, Emits n) => e n -> m n (Atom r n) +emitLin e = case toExpr e of + Atom x -> return x + expr -> liftM toAtom $ emitDecl noHint LinearLet $ peepholeExpr expr +{-# INLINE emitLin #-} + zeroAt :: (Emits n, SBuilder m) => SType n -> m n (SAtom n) zeroAt ty = liftEmitBuilder $ go ty where go :: Emits n => SType n -> BuilderM SimpIR n (SAtom n) @@ -930,14 +936,17 @@ symbolicTangentNonZero val = do -- === builder versions of common local ops === -neg :: (Builder r m, Emits n) => Atom r n -> m n (Atom r n) -neg x = emit $ UnOp FNeg x +fadd :: (Builder r m, Emits n) => Atom r n -> Atom r n -> m n (Atom r n) +fadd x y = emit $ BinOp FAdd x y -add :: (Builder r m, Emits n) => Atom r n -> Atom r n -> m n (Atom r n) -add x y = emit $ BinOp FAdd x y +fsub :: (Builder r m, Emits n) => Atom r n -> Atom r n -> m n (Atom r n) +fsub x y = emit $ BinOp FSub x y -mul :: (Builder r m, Emits n) => Atom r n -> Atom r n -> m n (Atom r n) -mul x y = emit $ BinOp FMul x y +fmul :: (Builder r m, Emits n) => Atom r n -> Atom r n -> m n (Atom r n) +fmul x y = emit $ BinOp FMul x y + +fdiv :: (Builder r m, Emits n) => Atom r n -> Atom r n -> m n (Atom r n) +fdiv x y = emit $ BinOp FDiv x y iadd :: (Builder r m, Emits n) => Atom r n -> Atom r n -> m n (Atom r n) iadd x y = emit $ BinOp IAdd x y @@ -945,22 +954,10 @@ iadd x y = emit $ BinOp IAdd x y imul :: (Builder r m, Emits n) => Atom r n -> Atom r n -> m n (Atom r n) imul x y = emit $ BinOp IMul x y -div' :: (Builder r m, Emits n) => Atom r n -> Atom r n -> m n (Atom r n) -div' x y = emit $ BinOp FDiv x y - -fpow :: (Builder r m, Emits n) => Atom r n -> Atom r n -> m n (Atom r n) -fpow x y = emit $ BinOp FPow x y - -sub :: (Builder r m, Emits n) => Atom r n -> Atom r n -> m n (Atom r n) -sub x y = emit $ BinOp FSub x y - -flog :: (Builder r m, Emits n) => Atom r n -> m n (Atom r n) -flog x = emit $ UnOp Log x - -fLitLike :: (SBuilder m, Emits n) => Double -> SAtom n -> m n (SAtom n) +fLitLike :: Double -> SAtom n -> SAtom n fLitLike x t = case getTyCon t of - BaseType (Scalar Float64Type) -> return $ toAtom $ Lit $ Float64Lit x - BaseType (Scalar Float32Type) -> return $ toAtom $ Lit $ Float32Lit $ realToFrac x + BaseType (Scalar Float64Type) -> toAtom $ Lit $ Float64Lit x + BaseType (Scalar Float32Type) -> toAtom $ Lit $ Float32Lit $ realToFrac x _ -> error "Expected a floating point scalar" fromPair :: (Builder r m, Emits n) => Atom r n -> m n (Atom r n, Atom r n) @@ -1085,6 +1082,11 @@ mkTabApp xs ixs = do ty <- typeOfTabApp (getType xs) ixs return $ TabApp ty xs ixs +mkProject :: (EnvReader m, IRRep r) => Int -> Atom r n -> m n (Expr r n) +mkProject i x = do + ty <- projType i x + return $ Project ty i x + mkTopApp :: EnvReader m => TopFunName n -> [SAtom n] -> m n (SExpr n) mkTopApp f xs = do resultTy <- typeOfTopApp f xs diff --git a/src/lib/Inline.hs b/src/lib/Inline.hs index 4104948c0..f3ada792d 100644 --- a/src/lib/Inline.hs +++ b/src/lib/Inline.hs @@ -98,12 +98,14 @@ inlineDeclsSubst = \case inlineDeclsSubst rest where dropOccInfo PlainLet = PlainLet + dropOccInfo LinearLet = LinearLet dropOccInfo InlineLet = InlineLet dropOccInfo NoInlineLet = NoInlineLet dropOccInfo (OccInfoPure _) = PlainLet dropOccInfo (OccInfoImpure _) = PlainLet resolveWorkConservation PlainLet _ = NoInline -- No occurrence info, assume the worst + resolveWorkConservation LinearLet _ = NoInline resolveWorkConservation InlineLet _ = NoInline resolveWorkConservation NoInlineLet _ = NoInline -- Quick hack to always unconditionally inline renames, until we get @@ -176,6 +178,7 @@ preInlineUnconditionally = \case PlainLet -> False -- "Missing occurrence annotation" InlineLet -> True NoInlineLet -> False + LinearLet -> False OccInfoPure (UsageInfo s (0, d)) | s <= One && d <= One -> True OccInfoPure _ -> False OccInfoImpure _ -> False diff --git a/src/lib/Linearize.hs b/src/lib/Linearize.hs index d4ca0417e..870548364 100644 --- a/src/lib/Linearize.hs +++ b/src/lib/Linearize.hs @@ -16,6 +16,7 @@ import GHC.Stack import Builder import Core +import CheapReduction import Imp import IRVariants import MTL1 @@ -26,7 +27,7 @@ import PPrint import QueryType import Types.Core import Types.Primitives -import Util (bindM2, enumerate) +import Util (enumerate) -- === linearization monad === @@ -93,48 +94,39 @@ extendTangentArgss vs' m = local (\(TangentArgs vs) -> TangentArgs $ vs ++ vs') getTangentArgs :: TangentM o (TangentArgs o) getTangentArgs = ask -bindLin - :: Emits o - => LinM i o e e - -> (forall o' m. (Emits o', DExt o o', Builder SimpIR m) => e o' -> m o' (e' o')) - -> LinM i o e' e' -bindLin m f = do - result <- m - withBoth result f - -withBoth - :: Emits o +emitBoth + :: (Emits o, ToExpr e' SimpIR) => WithTangent o e e - -> (forall o' m. (Emits o', DExt o o', Builder SimpIR m) => e o' -> m o' (e' o')) - -> PrimalM i o (WithTangent o e' e') -withBoth (WithTangent x tx) f = do + -> (forall o' m. (DExt o o', Builder SimpIR m) => e o' -> m o' (e' o')) + -> LinM i o SAtom SAtom +emitBoth (WithTangent x tx) f = do Distinct <- getDistinct - y <- f x - return $ WithTangent y do - tx >>= f + x' <- emit =<< f x + return $ WithTangent x' do + tx' <- tx + emitLin =<< f tx' -_withTangentComputation - :: Emits o - => WithTangent o e1 e2 - -> (forall o' m. (Emits o', DExt o o', Builder SimpIR m) => e2 o' -> m o' (e2' o')) - -> PrimalM i o (WithTangent o e1 e2') -_withTangentComputation (WithTangent x tx) f = do - Distinct <- getDistinct - return $ WithTangent x do - tx >>= f +emitZeroT :: (Emits o, HasNamesE e', ToExpr e' SimpIR) => e' i -> LinM i o SAtom SAtom +emitZeroT e = do + x <- emit =<< renameM e + return $ WithTangent x (zeroLikeT x) + +zeroLikeT :: (DExt o o', Emits o', HasType SimpIR e) => e o -> TangentM o' (SAtom o') +zeroLikeT x = do + ty <- sinkM $ getType x + zeroAt =<< tangentType ty fmapLin :: Emits o => (forall o'. e o' -> e' o') -> LinM i o e e -> LinM i o e' e' -fmapLin f m = m `bindLin` (pure . f) +fmapLin f m = do + WithTangent ans tx <- m + return $ WithTangent (f ans) (f <$> tx) -zipLin :: LinM i o e1 e1 -> LinM i o e2 e2 -> LinM i o (PairE e1 e2) (PairE e1 e2) -zipLin m1 m2 = do - WithTangent x1 t1 <- m1 - WithTangent x2 t2 <- m2 - return $ WithTangent (PairE x1 x2) do PairE <$> t1 <*> t2 +zipLin :: WithTangent o e1 e1 -> WithTangent o e2 e2 -> WithTangent o (PairE e1 e2) (PairE e1 e2) +zipLin (WithTangent x1 t1) (WithTangent x2 t2) = WithTangent (PairE x1 x2) do PairE <$> t1 <*> t2 seqLin :: Traversable f @@ -325,19 +317,28 @@ linearizeLambdaApp _ _ = error "not implemented" linearizeAtom :: Emits o => Atom SimpIR i -> LinM i o SAtom SAtom linearizeAtom (Con con) = linearizePrimCon con -linearizeAtom atom@(Stuck _ stuck) = case stuck of - PtrVar _ _ -> emitZeroT +linearizeAtom (Stuck _ stuck) = linearizeStuck stuck + +linearizeStuck :: Emits o => Stuck SimpIR i -> LinM i o SAtom SAtom +linearizeStuck stuck = case stuck of Var v -> do v' <- renameM v activePrimalIdx v' >>= \case - Nothing -> withZeroT $ return (toAtom v') + Nothing -> zero Just idx -> return $ WithTangent (toAtom v') $ getTangentArg idx - -- TODO: buildScoped and reduce the results so we keep expression in non-ANF for type checking purposes - StuckProject _ _ -> undefined - StuckTabApp _ _ -> undefined - RepValAtom _ -> emitZeroT - where emitZeroT = withZeroT $ renameM atom - + PtrVar _ _ -> zero + RepValAtom _ -> zero + -- TODO: de-dup with the Expr versions of these + StuckProject i x -> do + x' <- linearizeStuck x + emitBoth x' \x'' -> mkProject i x'' + StuckTabApp x i -> do + pt <- zipLin <$> linearizeStuck x <*> pureLin i + emitBoth pt \(PairE x' i') -> mkTabApp x' i' + where + zero = do + atom <- mkStuck =<< renameM stuck + return $ WithTangent atom (zeroLikeT atom) linearizeDecls :: Emits o => Nest SDecl i i' -> LinM i' o e1 e2 -> LinM i o e1 e2 linearizeDecls Empty cont = cont @@ -388,7 +389,6 @@ linearizeExpr expr = case expr of where unitLike :: e n -> UnitE n unitLike _ = UnitE - TabApp _ x i -> zipLin (linearizeAtom x) (pureLin i) `bindLin` \(PairE x' i') -> tabApp x' i' PrimOp op -> linearizeOp op Case e alts (EffTy effs resultTy) -> do e' <- renameM e @@ -417,49 +417,54 @@ linearizeExpr expr = case expr of applyLinLam linLam TabCon _ ty xs -> do ty' <- renameM ty - seqLin (map linearizeAtom xs) `bindLin` \(ComposeE xs') -> - emit $ TabCon Nothing (sink ty') xs' + pt <- seqLin (map linearizeAtom xs) + emitBoth pt \(ComposeE xs') -> return $ TabCon Nothing (sink ty') xs' + TabApp _ x i -> do + pt <- zipLin <$> linearizeAtom x <*> pureLin i + emitBoth pt \(PairE x' i') -> mkTabApp x' i' Project _ i x -> do - WithTangent x' tx <- linearizeAtom x - xi <- proj i x' - return $ WithTangent xi do - t <- tx - proj i t + x' <- linearizeAtom x + emitBoth x' \x'' -> mkProject i x'' linearizeOp :: Emits o => PrimOp SimpIR i -> LinM i o SAtom SAtom linearizeOp op = case op of Hof (TypedHof _ e) -> linearizeHof e DAMOp _ -> error "shouldn't occur here" - RefOp ref m -> case m of - MAsk -> linearizeAtom ref `bindLin` \ref' -> emit $ RefOp ref' MAsk - MExtend monoid x -> do - -- TODO: check that we're dealing with a +/0 monoid - monoid' <- renameM monoid - zipLin (linearizeAtom ref) (linearizeAtom x) `bindLin` \(PairE ref' x') -> - emit $ RefOp ref' $ MExtend (sink monoid') x' - MGet -> linearizeAtom ref `bindLin` \ref' -> emit $ RefOp ref' MGet - MPut x -> zipLin (linearizeAtom ref) (linearizeAtom x) `bindLin` \(PairE ref' x') -> - emit $ RefOp ref' $ MPut x' - IndexRef _ i -> do - zipLin (la ref) (pureLin i) `bindLin` \(PairE ref' i') -> - emit =<< mkIndexRef ref' i' - ProjRef _ i -> la ref `bindLin` \ref' -> emit =<< mkProjRef ref' i + RefOp ref m -> do + ref' <- linearizeAtom ref + case m of + MAsk -> emitBoth ref' \ref'' -> return $ RefOp ref'' MAsk + MExtend monoid x -> do + -- TODO: check that we're dealing with a +/0 monoid + monoid' <- renameM monoid + x' <- linearizeAtom x + emitBoth (zipLin ref' x') \(PairE ref'' x'') -> + return $ RefOp ref'' $ MExtend (sink monoid') x'' + MGet -> emitBoth ref' \ref'' -> return $ RefOp ref'' MGet + MPut x -> do + x' <- linearizeAtom x + emitBoth (zipLin ref' x') \(PairE ref'' x'') -> return $ RefOp ref'' $ MPut x'' + IndexRef _ i -> do + i' <- pureLin i + emitBoth (zipLin ref' i') \(PairE ref'' i'') -> mkIndexRef ref'' i'' + ProjRef _ i -> emitBoth ref' \ref'' -> mkProjRef ref'' i UnOp uop x -> linearizeUnOp uop x BinOp bop x y -> linearizeBinOp bop x y -- XXX: This assumes that pointers are always constants - MemOp _ -> emitZeroT + MemOp _ -> emitZeroT op MiscOp miscOp -> linearizeMiscOp miscOp VectorOp _ -> error "not implemented" - where - emitZeroT = withZeroT $ emit =<< renameM (PrimOp op) - la = linearizeAtom linearizeMiscOp :: Emits o => MiscOp SimpIR i -> LinM i o SAtom SAtom linearizeMiscOp op = case op of - SumTag _ -> emitZeroT - ToEnum _ _ -> emitZeroT - Select p t f -> (pureLin p `zipLin` la t `zipLin` la f) `bindLin` - \(p' `PairE` t' `PairE` f') -> emit $ MiscOp $ Select p' t' f' + SumTag _ -> zero + ToEnum _ _ -> zero + Select p t f -> do + p' <- pureLin p + t' <- linearizeAtom t + f' <- linearizeAtom f + emitBoth (p' `zipLin` t' `zipLin` f') + \(p'' `PairE` t'' `PairE` f'') -> return $ Select p'' t'' f'' CastOp t v -> do vt <- getType <$> renameM v t' <- renameM t @@ -468,92 +473,105 @@ linearizeMiscOp op = case op of ((&&) <$> (vtTangentType `alphaEq` vt) <*> (tTangentType `alphaEq` t')) >>= \case True -> do - linearizeAtom v `bindLin` \v' -> emit $ MiscOp $ CastOp (sink t') v' + v' <- linearizeAtom v + emitBoth v' \v'' -> return $ CastOp (sink t') v'' False -> do WithTangent x xt <- linearizeAtom v yt <- case (vtTangentType, tTangentType) of (_ , UnitTy) -> return $ UnitVal (UnitTy, tt ) -> zeroAt tt _ -> error "Expected at least one side of the CastOp to have a trivial tangent type" - y <- emit $ MiscOp $ CastOp t' x + y <- emit $ CastOp t' x return $ WithTangent y do xt >> return (sink yt) BitcastOp _ _ -> notImplemented UnsafeCoerce _ _ -> notImplemented GarbageVal _ -> notImplemented ThrowException _ -> notImplemented - ThrowError _ -> emitZeroT - OutputStream -> emitZeroT + ThrowError _ -> zero + OutputStream -> zero ShowAny _ -> error "Shouldn't have ShowAny in simplified IR" ShowScalar _ -> error "Shouldn't have ShowScalar in simplified IR" - where - emitZeroT = withZeroT $ emit =<< renameM (PrimOp $ MiscOp op) - la = linearizeAtom + where zero = emitZeroT op linearizeUnOp :: Emits o => UnOp -> Atom SimpIR i -> LinM i o SAtom SAtom -linearizeUnOp op x' = do - WithTangent x tx <- linearizeAtom x' - let emitZeroT = withZeroT $ emit $ UnOp op x - case op of - Exp -> do - y <- emit $ UnOp Exp x - return $ WithTangent y (bindM2 mul tx (sinkM y)) - Exp2 -> notImplemented - Log -> withT (emit $ UnOp Log x) $ (tx >>= (`div'` sink x)) - Log2 -> notImplemented - Log10 -> notImplemented - Log1p -> notImplemented - Sin -> withT (emit $ UnOp Sin x) $ bindM2 mul tx (emit $ UnOp Cos (sink x)) - Cos -> withT (emit $ UnOp Cos x) $ bindM2 mul tx (neg =<< emit (UnOp Sin (sink x))) - Tan -> notImplemented - Sqrt -> do - y <- emit $ UnOp Sqrt x - return $ WithTangent y do - denominator <- bindM2 mul (2 `fLitLike` sink y) (sinkM y) - bindM2 div' tx (pure denominator) - Floor -> emitZeroT - Ceil -> emitZeroT - Round -> emitZeroT - LGamma -> notImplemented - Erf -> notImplemented - Erfc -> notImplemented - FNeg -> withT (neg x) (neg =<< tx) - BNot -> emitZeroT +linearizeUnOp op x'' = do + WithTangent x' tx' <- linearizeAtom x'' + ans' <- emit $ UnOp op x' + return $ WithTangent ans' do + ans <- sinkM ans' + x <- sinkM x' + tx <- tx' + let zero = zeroLikeT ans + case op of + Exp -> emitLin $ BinOp FMul tx ans + Exp2 -> notImplemented + Log -> emitLin $ BinOp FDiv tx x + Log2 -> notImplemented + Log10 -> notImplemented + Log1p -> notImplemented + Sin -> do + c <- emit $ UnOp Cos x + emitLin $ BinOp FMul tx c + Cos -> do + c <- emit =<< (UnOp FNeg <$> emit (UnOp Sin x)) + emitLin $ BinOp FMul tx c + Tan -> notImplemented + Sqrt -> do + denominator <- fmul (2 `fLitLike` ans) ans + emitLin $ BinOp FDiv tx denominator + Floor -> zero + Ceil -> zero + Round -> zero + LGamma -> notImplemented + Erf -> notImplemented + Erfc -> notImplemented + FNeg -> emitLin $ UnOp FNeg tx + BNot -> zero linearizeBinOp :: Emits o => BinOp -> SAtom i -> SAtom i -> LinM i o SAtom SAtom -linearizeBinOp op x' y' = do - WithTangent x tx <- linearizeAtom x' - WithTangent y ty <- linearizeAtom y' - let emitZeroT = withZeroT $ emit $ BinOp op x y - case op of - IAdd -> emitZeroT - ISub -> emitZeroT - IMul -> emitZeroT - IDiv -> emitZeroT - IRem -> emitZeroT - ICmp _ -> emitZeroT - FAdd -> withT (add x y) (bindM2 add tx ty) - FSub -> withT (sub x y) (bindM2 sub tx ty) - FMul -> withT (mul x y) - (bindM2 add (bindM2 mul (referToPrimal x) ty) - (bindM2 mul tx (referToPrimal y))) - FDiv -> withT (div' x y) do - tx' <- bindM2 div' tx (referToPrimal y) - ty' <- bindM2 div' (bindM2 mul (referToPrimal x) ty) - (bindM2 mul (referToPrimal y) (referToPrimal y)) - sub tx' ty' - FPow -> withT (emit $ BinOp FPow x y) do - px <- referToPrimal x - py <- referToPrimal y - c <- (1.0 `fLitLike` py) >>= (sub py) >>= fpow px - tx' <- bindM2 mul tx (return py) - ty' <- bindM2 mul (bindM2 mul (return px) ty) (flog px) - mul c =<< add tx' ty' - FCmp _ -> emitZeroT - BAnd -> emitZeroT - BOr -> emitZeroT - BXor -> emitZeroT - BShL -> emitZeroT - BShR -> emitZeroT +linearizeBinOp op x'' y'' = do + WithTangent x' tx' <- linearizeAtom x'' + WithTangent y' ty' <- linearizeAtom y'' + ans' <- emit $ BinOp op x' y' + return $ WithTangent ans' do + ans <- sinkM ans' + x <- referToPrimal x' + y <- referToPrimal y' + tx <- tx' + ty <- ty' + let zero = zeroLikeT ans + case op of + IAdd -> zero + ISub -> zero + IMul -> zero + IDiv -> zero + IRem -> zero + ICmp _ -> zero + FAdd -> emitLin $ BinOp FAdd tx ty + FSub -> emitLin $ BinOp FSub tx ty + FMul -> do + t1 <- emitLin $ BinOp FMul ty x + t2 <- emitLin $ BinOp FMul tx y + emitLin $ BinOp FAdd t1 t2 + FDiv -> do + t1 <- emitLin $ BinOp FDiv tx y + xyy <- fdiv x =<< fmul y y + t2 <- emitLin $ BinOp FMul ty xyy + emitLin $ BinOp FSub t1 t2 + FPow -> do + ym1 <- fsub y (1.0 `fLitLike` y) + xpowym1 <- emit $ BinOp FPow x ym1 + xlogx <- fmul x =<< emit (UnOp Log x) + t1 <- emitLin $ BinOp FMul tx y + t2 <- emitLin $ BinOp FMul ty xlogx + t12 <- emitLin $ BinOp FAdd t1 t2 + emitLin $ BinOp FMul xpowym1 t12 + FCmp _ -> zero + BAnd -> zero + BOr -> zero + BXor -> zero + BShL -> zero + BShR -> zero -- This has the same type as `sinkM` and falls back thereto, but recomputes -- indexing a primal array in the tangent to avoid materializing intermediate @@ -575,12 +593,12 @@ referToPrimal x = do linearizePrimCon :: Emits o => Con SimpIR i -> LinM i o SAtom SAtom linearizePrimCon con = case con of - Lit _ -> emitZeroT + Lit _ -> zero ProdCon xs -> fmapLin (Con . ProdCon . fromComposeE) $ seqLin (fmap linearizeAtom xs) SumCon _ _ _ -> notImplemented - HeapVal -> emitZeroT + HeapVal -> zero DepPair _ _ _ -> notImplemented - where emitZeroT = withZeroT $ renameM $ Con con + where zero = emitZeroT con linearizeHof :: Emits o => Hof SimpIR i -> LinM i o SAtom SAtom linearizeHof hof = case hof of @@ -672,21 +690,6 @@ linearizeEffectFun rws (BinaryLamExpr hB refB body) = do return (BinaryLamExpr h b body', linLam') linearizeEffectFun _ _ = error "expect effect function to be a binary lambda" -withT :: PrimalM i o (e1 o) - -> (forall o'. (Emits o', DExt o o') => TangentM o' (e2 o')) - -> PrimalM i o (WithTangent o e1 e2) -withT p t = do - p' <- p - return $ WithTangent p' t - -withZeroT :: PrimalM i o (Atom SimpIR o) - -> PrimalM i o (WithTangent o SAtom SAtom) -withZeroT p = do - p' <- p - return $ WithTangent p' do - pTy <- return $ getType $ sink p' - zeroAt =<< tangentType pTy - notImplemented :: HasCallStack => a notImplemented = error "Not implemented" diff --git a/src/lib/PPrint.hs b/src/lib/PPrint.hs index 241894483..af1c48f1e 100644 --- a/src/lib/PPrint.hs +++ b/src/lib/PPrint.hs @@ -970,6 +970,7 @@ instance Pretty LetAnn where PlainLet -> "" InlineLet -> "%inline" NoInlineLet -> "%noinline" + LinearLet -> "%linear" OccInfoPure u -> p u <> line OccInfoImpure u -> p u <> ", impure" <> line diff --git a/src/lib/Transpose.hs b/src/lib/Transpose.hs index e312de43b..302ca9e4f 100644 --- a/src/lib/Transpose.hs +++ b/src/lib/Transpose.hs @@ -9,8 +9,6 @@ module Transpose (transpose, transposeTopFun) where import Data.Foldable import Data.Functor import Control.Category ((>>>)) -import Control.Monad.Reader -import qualified Data.Set as S import GHC.Stack import Builder @@ -18,7 +16,6 @@ import Core import Err import Imp import IRVariants -import MTL1 import Name import Subst import QueryType @@ -37,7 +34,7 @@ transpose lam ct = liftEmitBuilder $ runTransposeM do {-# SCC transpose #-} runTransposeM :: TransposeM n n a -> BuilderM SimpIR n a -runTransposeM cont = runReaderT1 (ListE []) $ runSubstReaderT idSubst $ cont +runTransposeM cont = runSubstReaderT idSubst $ cont transposeTopFun :: (MonadFail1 m, EnvReader m) @@ -73,20 +70,15 @@ unpackLinearLamExpr lam@(LamExpr bs body) = do -- === transposition monad === +type AtomTransposeSubstVal = TransposeSubstVal (AtomNameC SimpIR) data TransposeSubstVal c n where RenameNonlin :: Name c n -> TransposeSubstVal c n -- accumulator references corresponding to non-ref linear variables - LinRef :: SAtom n -> TransposeSubstVal (AtomNameC SimpIR) n + LinRef :: SAtom n -> AtomTransposeSubstVal n -- as an optimization, we don't make references for trivial vector spaces - LinTrivial :: TransposeSubstVal (AtomNameC SimpIR) n + LinTrivial :: AtomTransposeSubstVal n -type LinRegions = ListE SAtomVar - -type TransposeM a = SubstReaderT TransposeSubstVal - (ReaderT1 LinRegions (BuilderM SimpIR)) a - -type TransposeM' a = SubstReaderT AtomSubstVal - (ReaderT1 LinRegions (BuilderM SimpIR)) a +type TransposeM a = SubstReaderT TransposeSubstVal (BuilderM SimpIR) a -- TODO: it might make sense to replace substNonlin/isLin -- with a single `trySubtNonlin :: e i -> Maybe (e o)`. @@ -99,30 +91,6 @@ substNonlin e = do RenameNonlin v' -> v' _ -> error "not a nonlinear expression") e --- TODO: Can we generalize onNonLin to accept SubstReaderT Name instead of --- SubstReaderT AtomSubstVal? For that to work, we need another combinator, --- that lifts a SubstReader AtomSubstVal into a SubstReader Name, because --- effectsSubstE is currently typed as SubstReader AtomSubstVal. --- Then we can presumably recode substNonlin as `onNonLin substM`. We may --- be able to do that anyway, except we will then need to restrict the type --- of substNonlin to require `SubstE AtomSubstVal e`; but that may be fine. -onNonLin :: HasCallStack - => TransposeM' i o a -> TransposeM i o a -onNonLin cont = do - subst <- getSubst - let subst' = newSubst (\v -> case subst ! v of - RenameNonlin v' -> Rename v' - _ -> error "not a nonlinear expression") - liftSubstReaderT $ runSubstReaderT subst' cont - -isLin :: HoistableE e => e i -> TransposeM i o Bool -isLin e = do - substVals <- mapM lookupSubstM $ freeAtomVarsList @SimpIR e - return $ flip any substVals \case - LinTrivial -> True - LinRef _ -> True - RenameNonlin _ -> False - withAccumulator :: Emits o => SType o @@ -147,43 +115,42 @@ emitCTToRef ref ct = do baseMonoid <- tangentBaseMonoidFor (getType ct) void $ emit $ RefOp ref $ MExtend baseMonoid ct -getLinRegions :: TransposeM i o [SAtomVar o] -getLinRegions = asks fromListE - -extendLinRegions :: SAtomVar o -> TransposeM i o a -> TransposeM i o a -extendLinRegions v cont = local (\(ListE vs) -> ListE (v:vs)) cont - -- === actual pass === -transposeWithDecls :: Emits o => Nest SDecl i i' -> SExpr i' -> SAtom o -> TransposeM i o () +transposeWithDecls :: forall i i' o. Emits o => Nest SDecl i i' -> SExpr i' -> SAtom o -> TransposeM i o () transposeWithDecls Empty atom ct = transposeExpr atom ct -transposeWithDecls (Nest (Let b (DeclBinding _ expr)) rest) result ct = - substExprIfNonlin expr >>= \case - Nothing -> do - ty' <- substNonlin $ getType expr - ctExpr <- withAccumulator ty' \refSubstVal -> - extendSubst (b @> refSubstVal) $ - transposeWithDecls rest result (sink ct) - transposeExpr expr ctExpr - Just nonlinExpr -> do - v <- emitToVar nonlinExpr - extendSubst (b @> RenameNonlin (atomVarName v)) $ - transposeWithDecls rest result ct - -substExprIfNonlin :: SExpr i -> TransposeM i o (Maybe (SExpr o)) -substExprIfNonlin expr = - isLin expr >>= \case - True -> return Nothing - False -> do - onNonLin (substM $ getEffects expr) >>= isLinEff >>= \case - True -> return Nothing - False -> Just <$> substNonlin expr +transposeWithDecls (Nest (Let b (DeclBinding ann expr)) rest) result ct = case ann of + LinearLet -> do + ty' <- substNonlin $ getType expr + case expr of + Project _ i x -> do + continue =<< projectLinearRef x \ref -> emitLin =<< mkProjRef ref (ProjectProduct i) + TabApp _ x i -> do + continue =<< projectLinearRef x \ref -> do + i' <- substNonlin i + emitLin =<< mkIndexRef ref i' + _ -> do + ctExpr <- withAccumulator ty' \refSubstVal -> continue refSubstVal + transposeExpr expr ctExpr + _ -> do + v <- substNonlin expr >>= emitToVar + continue $ RenameNonlin (atomVarName v) + where + continue :: forall o'. (Emits o', Ext o o') => AtomTransposeSubstVal o' -> TransposeM i o' () + continue substVal = do + ct' <- sinkM ct + extendSubst (b @> substVal) $ transposeWithDecls rest result ct' -isLinEff :: EffectRow SimpIR o -> TransposeM i o Bool -isLinEff effs@(EffectRow _ NoTail) = do - regions <- fmap atomVarName <$> getLinRegions - let effRegions = freeAtomVarsList effs - return $ not $ null $ S.fromList effRegions `S.intersection` S.fromList regions +projectLinearRef + :: Emits o + => SAtom i -> (SAtom o -> TransposeM i o (SAtom o)) + -> TransposeM i o (AtomTransposeSubstVal o) +projectLinearRef x f = do + Stuck _ (Var v) <- return x + lookupSubstM (atomVarName v) >>= \case + RenameNonlin _ -> error "nonlinear" + LinRef ref -> LinRef <$> f ref + LinTrivial -> return LinTrivial getTransposedTopFun :: EnvReader m => TopFunName n -> m n (Maybe (TopFunName n)) getTransposedTopFun f = do @@ -200,44 +167,23 @@ transposeExpr expr ct = case expr of xsNonlin' <- mapM substNonlin xsNonlin ct' <- naryTopApp fT (xsNonlin' ++ [ct]) transposeAtom xLin ct' - -- TODO: Instead, should we handle table application like nonlinear - -- expressions, where we just project the reference? - TabApp _ x i -> do - i' <- substNonlin i - case x of - Stuck _ stuck -> case stuck of - Var v -> do - lookupSubstM (atomVarName v) >>= \case - RenameNonlin _ -> error "shouldn't happen" - LinRef ref -> do - refProj <- indexRef ref i' - emitCTToRef refProj ct - LinTrivial -> return () - StuckProject _ _ -> undefined - StuckTabApp _ _ -> undefined - PtrVar _ _ -> error "not tangent" - RepValAtom _ -> error "not tangent" - _ -> error $ "shouldn't occur: " ++ pprint x PrimOp op -> transposeOp op ct Case e alts _ -> do - linearScrutinee <- isLin e - case linearScrutinee of - True -> notImplemented - False -> do - e' <- substNonlin e - void $ buildCase e' UnitTy \i v -> do - v' <- emitToVar v - Abs b body <- return $ alts !! i - extendSubst (b @> RenameNonlin (atomVarName v')) do - transposeExpr body (sink ct) - return UnitVal + e' <- substNonlin e + void $ buildCase e' UnitTy \i v -> do + v' <- emitToVar v + Abs b body <- return $ alts !! i + extendSubst (b @> RenameNonlin (atomVarName v')) do + transposeExpr body (sink ct) + return UnitVal TabCon _ ty es -> do TabTy d b _ <- return ty idxTy <- substNonlin $ IxType (binderType b) d forM_ (enumerate es) \(ordinalIdx, e) -> do i <- unsafeFromOrdinal idxTy (IdxRepVal $ fromIntegral ordinalIdx) tabApp ct i >>= transposeAtom e - Project _ _ _ -> undefined + TabApp _ _ _ -> error "should have been handled by reference projection" + Project _ _ _ -> error "should have been handled by reference projection" transposeOp :: Emits o => PrimOp SimpIR i -> SAtom o -> TransposeM i o () transposeOp op ct = case op of @@ -262,18 +208,21 @@ transposeOp op ct = case op of ProjRef _ _ -> notImplemented Hof (TypedHof _ hof) -> transposeHof hof ct MiscOp miscOp -> transposeMiscOp miscOp ct - UnOp FNeg x -> transposeAtom x =<< neg ct + UnOp FNeg x -> transposeAtom x =<< (emitLin $ UnOp FNeg ct) UnOp _ _ -> notLinear BinOp FAdd x y -> transposeAtom x ct >> transposeAtom y ct - BinOp FSub x y -> transposeAtom x ct >> (transposeAtom y =<< neg ct) + BinOp FSub x y -> transposeAtom x ct >> (transposeAtom y =<< (emitLin $ UnOp FNeg ct)) + -- XXX: linear argument to FMul is always first BinOp FMul x y -> do - xLin <- isLin x - if xLin - then transposeAtom x =<< mul ct =<< substNonlin y - else transposeAtom y =<< mul ct =<< substNonlin x - BinOp FDiv x y -> transposeAtom x =<< div' ct =<< substNonlin y + y' <- substNonlin y + tx <- emitLin $ BinOp FMul ct y' + transposeAtom x tx + BinOp FDiv x y -> do + y' <- substNonlin y + tx <- emitLin $ BinOp FDiv ct y' + transposeAtom x tx BinOp _ _ _ -> notLinear - MemOp _ -> notLinear + MemOp _ -> notLinear VectorOp _ -> unreachable where notLinear = error $ "Can't transpose a non-linear operation: " ++ pprint op @@ -291,10 +240,9 @@ transposeMiscOp op _ = case op of BitcastOp _ _ -> notImplemented UnsafeCoerce _ _ -> notImplemented GarbageVal _ -> notImplemented - ShowAny _ -> error "Shouldn't have ShowAny in simplified IR" - ShowScalar _ -> error "Shouldn't have ShowScalar in simplified IR" - where - notLinear = error $ "Can't transpose a non-linear operation: " ++ show op + ShowAny _ -> notLinear + ShowScalar _ -> notLinear + where notLinear = error $ "Can't transpose a non-linear operation: " ++ show op transposeAtom :: HasCallStack => Emits o => SAtom i -> SAtom o -> TransposeM i o () transposeAtom atom ct = case atom of @@ -308,16 +256,9 @@ transposeAtom atom ct = case atom of return () LinRef ref -> emitCTToRef ref ct LinTrivial -> return () - StuckProject _ _ -> error "not implemented" - StuckTabApp _ _ -> error "not implemented" - -- let (idxs, v) = asNaryProj i' x' - -- lookupSubstM (atomVarName v) >>= \case - -- RenameNonlin _ -> error "an error, probably" - -- LinRef ref -> do - -- ref' <- applyProjectionsRef (toList idxs) ref - -- emitCTToRef ref' ct - -- LinTrivial -> return () - RepValAtom _ -> error "not implemented" + StuckProject _ _ -> error "not linear" + StuckTabApp _ _ -> error "not linear" + RepValAtom _ -> error "not linear" where notTangent = error $ "Not a tangent atom: " ++ pprint atom transposeHof :: Emits o => Hof SimpIR i -> SAtom o -> TransposeM i o () @@ -333,8 +274,7 @@ transposeHof hof ct = case hof of (ctBody, ctState) <- fromPair ct (_, cts) <- (fromPair =<<) $ emitRunState noHint ctState \h ref -> do extendSubst (hB@>RenameNonlin (atomVarName h)) $ extendSubst (refB@>RenameNonlin (atomVarName ref)) $ - extendLinRegions h $ - transposeExpr body (sink ctBody) + transposeExpr body (sink ctBody) return UnitVal transposeAtom s cts RunReader r (BinaryLamExpr hB refB body) -> do @@ -342,8 +282,7 @@ transposeHof hof ct = case hof of baseMonoid <- tangentBaseMonoidFor accumTy (_, ct') <- (fromPair =<<) $ emitRunWriter noHint accumTy baseMonoid \h ref -> do extendSubst (hB@>RenameNonlin (atomVarName h)) $ extendSubst (refB@>RenameNonlin (atomVarName ref)) $ - extendLinRegions h $ - transposeExpr body (sink ct) + transposeExpr body (sink ct) return UnitVal transposeAtom r ct' RunWriter Nothing _ (BinaryLamExpr hB refB body)-> do @@ -351,8 +290,7 @@ transposeHof hof ct = case hof of (ctBody, ctEff) <- fromPair ct void $ emitRunReader noHint ctEff \h ref -> do extendSubst (hB@>RenameNonlin (atomVarName h)) $ extendSubst (refB@>RenameNonlin (atomVarName ref)) $ - extendLinRegions h $ - transposeExpr body (sink ctBody) + transposeExpr body (sink ctBody) return UnitVal _ -> notImplemented diff --git a/src/lib/Types/Primitives.hs b/src/lib/Types/Primitives.hs index d85d88247..83ba3ffbe 100644 --- a/src/lib/Types/Primitives.hs +++ b/src/lib/Types/Primitives.hs @@ -64,6 +64,7 @@ data LetAnn = | InlineLet -- Binding explicitly tagged "do not inline" | NoInlineLet + | LinearLet -- Bound expression is pure, and the binding's occurrences are summarized by -- the UsageInfo | OccInfoPure UsageInfo From 1b2d252b45592476d32d1e91d69b5def8150d834 Mon Sep 17 00:00:00 2001 From: Dougal Date: Thu, 9 Nov 2023 21:41:14 -0500 Subject: [PATCH 13/41] Add some missing linearity annotations. We really need to build the linearity checker. --- src/lib/Builder.hs | 19 +++++++++++-------- src/lib/Inference.hs | 2 +- src/lib/Linearize.hs | 18 +++++++++--------- src/lib/Simplify.hs | 2 +- src/lib/Transpose.hs | 18 ++++++------------ 5 files changed, 28 insertions(+), 31 deletions(-) diff --git a/src/lib/Builder.hs b/src/lib/Builder.hs index 0bb8e8fec..caebe262f 100644 --- a/src/lib/Builder.hs +++ b/src/lib/Builder.hs @@ -765,23 +765,23 @@ mkTypedHof hof = do effTy <- effTyOfHof hof return $ TypedHof effTy hof -buildForAnn - :: (Emits n, ScopableBuilder r m) +mkFor + :: (ScopableBuilder r m) => NameHint -> ForAnn -> IxType r n -> (forall l. (Emits l, DExt n l) => AtomVar r l -> m l (Atom r l)) - -> m n (Atom r n) -buildForAnn hint ann (IxType iTy ixDict) body = do + -> m n (Expr r n) +mkFor hint ann (IxType iTy ixDict) body = do lam <- withFreshBinder hint iTy \b -> do let v = binderVar b body' <- buildBlock $ body $ sink v return $ LamExpr (UnaryNest b) body' - emitHof $ For ann (IxType iTy ixDict) lam + liftM toExpr $ mkTypedHof $ For ann (IxType iTy ixDict) lam buildFor :: (Emits n, ScopableBuilder r m) => NameHint -> Direction -> IxType r n -> (forall l. (Emits l, DExt n l) => AtomVar r l -> m l (Atom r l)) -> m n (Atom r n) -buildFor hint dir ty body = buildForAnn hint dir ty body +buildFor hint ann ty body = mkFor hint ann ty body >>= emit buildMap :: (Emits n, ScopableBuilder SimpIR m) => SAtom n @@ -853,6 +853,10 @@ emitLin e = case toExpr e of expr -> liftM toAtom $ emitDecl noHint LinearLet $ peepholeExpr expr {-# INLINE emitLin #-} +emitHofLin :: (Builder r m, Emits n) => Hof r n -> m n (Atom r n) +emitHofLin hof = mkTypedHof hof >>= emitLin +{-# INLINE emitHofLin #-} + zeroAt :: (Emits n, SBuilder m) => SType n -> m n (SAtom n) zeroAt ty = liftEmitBuilder $ go ty where go :: Emits n => SType n -> BuilderM SimpIR n (SAtom n) @@ -1100,9 +1104,8 @@ mkApplyMethod d i xs = do mkInstanceDict :: EnvReader m => InstanceName n -> [CAtom n] -> m n (CDict n) mkInstanceDict instanceName args = do instanceDef@(InstanceDef className _ _ _ _) <- lookupInstanceDef instanceName - sourceName <- getSourceName <$> lookupClassDef className PairE (ListE params) _ <- instantiate instanceDef args - let ty = toType $ DictType sourceName className params + ty <- toType <$> dictType className params return $ toDict $ InstanceDict ty instanceName args mkCase :: (EnvReader m, IRRep r) => Atom r n -> Type r n -> [Alt r n] -> m n (Expr r n) diff --git a/src/lib/Inference.hs b/src/lib/Inference.hs index b9253952a..48a672c88 100644 --- a/src/lib/Inference.hs +++ b/src/lib/Inference.hs @@ -2013,7 +2013,7 @@ generalizeDict ty dict = do result <- liftEnvReaderT $ liftInfererM $ generalizeDictRec ty dict case result of Failure e -> error $ "Failed to generalize " ++ pprint dict - ++ " to " ++ pprint ty ++ " because " ++ pprint e + ++ " to " ++ show ty ++ " because " ++ pprint e Success ans -> return ans generalizeDictRec :: CType n -> CDict n -> InfererM i n (CDict n) diff --git a/src/lib/Linearize.hs b/src/lib/Linearize.hs index 870548364..98bbb7d39 100644 --- a/src/lib/Linearize.hs +++ b/src/lib/Linearize.hs @@ -270,7 +270,7 @@ applyLinLam :: Emits o => SLam i -> SubstReaderT AtomSubstVal TangentM i o (Atom applyLinLam (LamExpr bs body) = do TangentArgs args <- liftSubstReaderT $ getTangentArgs extendSubst (bs @@> ((Rename . atomVarName) <$> args)) do - substM body >>= emit + substM body >>= emitLin -- === actual linearization passs === @@ -299,7 +299,7 @@ linearizeTopLam (TopLam False _ (LamExpr bs body)) actives = do ts <- getUnpacked $ toAtom $ sink $ binderVar bTangent let substFrag = bsRecon @@> map (SubstVal . sink) xs <.> bsTangent @@> map (SubstVal . sink) ts - emit =<< applySubst substFrag tangentBody + emitLin =<< applySubst substFrag tangentBody return $ LamExpr (bs' >>> BinaryNest bResidual bTangent) tangentBody' return (primalFun, tangentFun) (,) <$> asTopLam primalFun <*> asTopLam tangentFun @@ -358,7 +358,7 @@ linearizeDecls (Nest (Let b (DeclBinding ann expr)) rest) cont = do WithTangent pRest tfRest <- linearizeDecls rest cont return $ WithTangent pRest do t <- tf - vt <- emitDecl (getNameHint b) ann (Atom t) + vt <- emitDecl (getNameHint b) LinearLet (Atom t) extendTangentArgs vt $ tfRest @@ -410,7 +410,7 @@ linearizeExpr expr = case expr of (primal, residualss) <- fromPair result resultTangentType <- tangentType resultTy' return $ WithTangent primal do - buildCase (sink residualss) (sink resultTangentType) \i residuals -> do + emitLin =<< buildCase' (sink residualss) (sink resultTangentType) \i residuals -> do ObligateRecon _ (Abs bs linLam) <- return $ sinkList recons !! i residuals' <- unpackTelescope bs residuals withSubstReaderT $ extendSubst (bs @@> (SubstVal <$> residuals')) do @@ -613,13 +613,13 @@ linearizeHof hof = case hof of TrivialRecon linLam' -> return $ WithTangent primalsAux do Abs ib'' linLam'' <- sinkM (Abs ib' linLam') - withSubstReaderT $ buildFor noHint d (sink ixTy) \i' -> do + withSubstReaderT $ emitLin =<< mkFor noHint d (sink ixTy) \i' -> do extendSubst (ib''@>Rename (atomVarName i')) $ applyLinLam linLam'' ReconWithData reconAbs -> do primals <- buildMap primalsAux getFst return $ WithTangent primals do Abs ib'' (Abs bs linLam') <- sinkM (Abs ib' reconAbs) - withSubstReaderT $ buildFor noHint d (sink ixTy) \i' -> do + withSubstReaderT $ emitLin =<< mkFor noHint d (sink ixTy) \i' -> do extendSubst (ib''@> Rename (atomVarName i')) do residuals' <- tabApp (sink primalsAux) (toAtom i') >>= getSnd >>= unpackTelescope bs extendSubst (bs @@> (SubstVal <$> residuals')) $ @@ -636,7 +636,7 @@ linearizeHof hof = case hof of tanEffLam <- buildEffLam noHint tt \h ref -> extendTangentArgss [h, ref] do withSubstReaderT $ applyLinLam $ sink linLam - emitHof $ RunReader rLin' tanEffLam + emitHofLin $ RunReader rLin' tanEffLam RunState Nothing sInit lam -> do WithTangent sInit' sLin <- linearizeAtom sInit (lam', recon) <- linearizeEffectFun State lam @@ -649,7 +649,7 @@ linearizeHof hof = case hof of tanEffLam <- buildEffLam noHint tt \h ref -> extendTangentArgss [h, ref] do withSubstReaderT $ applyLinLam $ sink linLam - emitHof $ RunState Nothing sLin' tanEffLam + emitHofLin $ RunState Nothing sLin' tanEffLam RunWriter Nothing bm lam -> do -- TODO: check it's actually the 0/+ monoid (or should we just build that in?) bm' <- renameM bm @@ -663,7 +663,7 @@ linearizeHof hof = case hof of tanEffLam <- buildEffLam noHint tt \h ref -> extendTangentArgss [h, ref] do withSubstReaderT $ applyLinLam $ sink linLam - emitHof $ RunWriter Nothing bm'' tanEffLam + emitHofLin $ RunWriter Nothing bm'' tanEffLam RunIO body -> do (body', recon) <- linearizeExprDefunc body primalAux <- emitHof $ RunIO body' diff --git a/src/lib/Simplify.hs b/src/lib/Simplify.hs index 9997df23c..129039bdb 100644 --- a/src/lib/Simplify.hs +++ b/src/lib/Simplify.hs @@ -1056,7 +1056,7 @@ exceptToMaybeExpr expr = case expr of return $ JustAtom ty x' PrimOp (Hof (TypedHof _ (For ann ixTy' (UnaryLamExpr b body)))) -> do ixTy <- substM ixTy' - maybes <- buildForAnn (getNameHint b) ann ixTy \i -> do + maybes <- buildFor (getNameHint b) ann ixTy \i -> do extendSubst (b@>Rename (atomVarName i)) $ exceptToMaybeExpr body catMaybesE maybes PrimOp (MiscOp (ThrowException _)) -> do diff --git a/src/lib/Transpose.hs b/src/lib/Transpose.hs index 302ca9e4f..10c87d377 100644 --- a/src/lib/Transpose.hs +++ b/src/lib/Transpose.hs @@ -80,16 +80,12 @@ data TransposeSubstVal c n where type TransposeM a = SubstReaderT TransposeSubstVal (BuilderM SimpIR) a --- TODO: it might make sense to replace substNonlin/isLin --- with a single `trySubtNonlin :: e i -> Maybe (e o)`. --- But for that we need a way to traverse names, like a monadic --- version of `substE`. -substNonlin :: (SinkableE e, RenameE e, HasCallStack) => e i -> TransposeM i o (e o) +substNonlin :: (PrettyE e, SinkableE e, RenameE e, HasCallStack) => e i -> TransposeM i o (e o) substNonlin e = do subst <- getSubst fmapRenamingM (\v -> case subst ! v of RenameNonlin v' -> v' - _ -> error "not a nonlinear expression") e + _ -> error $ "not a nonlinear expression: " ++ pprint e) e withAccumulator :: Emits o @@ -113,7 +109,7 @@ withAccumulator ty cont = do emitCTToRef :: (Emits n, Builder SimpIR m) => SAtom n -> SAtom n -> m n () emitCTToRef ref ct = do baseMonoid <- tangentBaseMonoidFor (getType ct) - void $ emit $ RefOp ref $ MExtend baseMonoid ct + void $ emitLin $ RefOp ref $ MExtend baseMonoid ct -- === actual pass === @@ -190,7 +186,7 @@ transposeOp op ct = case op of DAMOp _ -> error "unreachable" -- TODO: rule out statically RefOp refArg m -> do refArg' <- substNonlin refArg - let emitEff = emit . RefOp refArg' + let emitEff = emitLin . RefOp refArg' case m of MAsk -> do baseMonoid <- tangentBaseMonoidFor (getType ct) @@ -251,9 +247,7 @@ transposeAtom atom ct = case atom of PtrVar _ _ -> notTangent Var v -> do lookupSubstM (atomVarName v) >>= \case - RenameNonlin _ -> - -- XXX: we seem to need this case, but it feels like it should be an error! - return () + RenameNonlin _ -> error "nonlinear" LinRef ref -> emitCTToRef ref ct LinTrivial -> return () StuckProject _ _ -> error "not linear" @@ -266,7 +260,7 @@ transposeHof hof ct = case hof of For ann ixTy' lam -> do UnaryLamExpr b body <- return lam ixTy <- substNonlin ixTy' - void $ buildForAnn (getNameHint b) (flipDir ann) ixTy \i -> do + void $ emitLin =<< mkFor (getNameHint b) (flipDir ann) ixTy \i -> do ctElt <- tabApp (sink ct) (toAtom i) extendSubst (b@>RenameNonlin (atomVarName i)) $ transposeExpr body ctElt return UnitVal From 5b7eab452b5f698c8fede5e9c669afb13c295910 Mon Sep 17 00:00:00 2001 From: Dougal Date: Wed, 15 Nov 2023 21:56:05 -0500 Subject: [PATCH 14/41] Rewrite the live view system to be explicitly incremental by dealing in diffs. Previously we achieved incremental updates using a sort of "rebuild the world but memoize all the things" strategy. I've come around to the idea that it's better to deal with updates explicitly. The new pattern could be called "defunctionalized state". You maintain a state but instead of being able to update it arbitrarily (e.g. using `put` in the state monad) you can only update it in a certain number of predefined ways, expressed as a "diff" data type. These diffs can be sent between servers, serialized to JSON and sent to the browser and so on. And they're expected to be monoidal so you can batch a bunch of diffs together for efficiency and even hope for some cancellation, say of pushes and pops. The other part of the pattern is to build "state servers" that serve the current state using these diffs. When you subscribe to them they send you a current state and an endless stream of updates. You can wire the servers together to achieve an incrementally updated pipeline. This worked out really well! I was even able to add a new feature: tracking the current state of a cell (pending/running/complete) in the web view. And the whole thing is much snappier because every piece of it is just working with updates instead of full states. --- dex.cabal | 4 +- src/dex.hs | 12 +- src/lib/Actor.hs | 265 ++++++++++++--- src/lib/IncState.hs | 101 ++++++ src/lib/Live/Eval.hs | 715 ++++++++++++++++----------------------- src/lib/Live/Terminal.hs | 82 ----- src/lib/Live/Web.hs | 29 +- src/lib/MonadUtil.hs | 51 +++ src/lib/PPrint.hs | 31 +- src/lib/TopLevel.hs | 1 + src/lib/Util.hs | 1 + static/index.js | 111 +++--- static/style.css | 11 + 13 files changed, 762 insertions(+), 652 deletions(-) create mode 100644 src/lib/IncState.hs delete mode 100644 src/lib/Live/Terminal.hs create mode 100644 src/lib/MonadUtil.hs diff --git a/dex.cabal b/dex.cabal index e3737f4e3..157158b66 100644 --- a/dex.cabal +++ b/dex.cabal @@ -58,6 +58,7 @@ library , Generalize , Imp , ImpToLLVM + , IncState , Inference , Inline , IRVariants @@ -72,6 +73,7 @@ library , Linearize , Logging , Lower + , MonadUtil , MTL1 , Name , Occurrence @@ -103,7 +105,6 @@ library if flag(live) exposed-modules: Actor , Live.Eval - , Live.Terminal , Live.Web , RenderHtml other-modules: Paths_dex @@ -136,6 +137,7 @@ library -- Serialization , aeson , store + , time -- Floating-point pedanticness (correcting for GHC < 9.2.2) , floating-bits if flag(live) diff --git a/src/dex.hs b/src/dex.hs index 20e90a885..5232ec9c5 100644 --- a/src/dex.hs +++ b/src/dex.hs @@ -22,7 +22,7 @@ import qualified Data.Text as T import qualified Data.Text.Encoding as T import qualified Data.Map.Strict as M -import PPrint (toJSONStr, printResult) +import PPrint (resultAsJSON, printResult) import TopLevel import Err import Name @@ -30,7 +30,7 @@ import AbstractSyntax (parseTopDeclRepl) import ConcreteSyntax (keyWordStrs, preludeImportBlock) #ifdef DEX_LIVE import RenderHtml -import Live.Terminal (runTerminal) +-- import Live.Terminal (runTerminal) import Live.Web (runWeb) #endif import Core @@ -84,14 +84,10 @@ runMode evalMode opts = case evalMode of _ -> liftIO $ putStrLn $ pprint result ClearCache -> clearCache #ifdef DEX_LIVE - -- These are broken if the prelude produces any arrays because the blockId - -- counter restarts at zero. TODO: make prelude an implicit import block WebMode fname -> do env <- loadCache runWeb fname opts env - WatchMode fname -> do - env <- loadCache - runTerminal fname opts env + WatchMode _ -> error "not implemented" #endif printIncrementalSource :: DocFmt -> SourceBlock -> IO () @@ -106,7 +102,7 @@ printIncrementalSource fmt sb = case fmt of printIncrementalResult :: DocFmt -> Result -> IO () printIncrementalResult fmt result = case fmt of ResultOnly -> case pprint result of [] -> return (); msg -> putStrLn msg - JSONDoc -> case toJSONStr result of "{}" -> return (); s -> putStrLn s + JSONDoc -> case resultAsJSON result of "{}" -> return (); s -> putStrLn s TextDoc -> do isatty <- queryTerminal stdOutput putStr $ printResult isatty result diff --git a/src/lib/Actor.hs b/src/lib/Actor.hs index 3fb452c06..18a835bc4 100644 --- a/src/lib/Actor.hs +++ b/src/lib/Actor.hs @@ -1,70 +1,229 @@ --- Copyright 2022 Google LLC +-- Copyright 2023 Google LLC -- -- Use of this source code is governed by a BSD-style -- license that can be found in the LICENSE file or at -- https://developers.google.com/open-source/licenses/bsd -module Actor (PChan, sendPChan, sendOnly, subChan, - Actor, runActor, spawn, - LogServerMsg (..), logServer) where +{-# LANGUAGE UndecidableInstances #-} -import Control.Concurrent (Chan, forkIO, newChan, readChan, ThreadId, writeChan) -import Control.Monad.State.Strict +module Actor ( + ActorM, Actor (..), launchActor, send, selfMailbox, messageLoop, + sliceMailbox, SubscribeMsg (..), IncServer, IncServerT, FileWatcher, + StateServer, flushDiffs, handleSubscribeMsg, subscribe, subscribeIO, sendSync, + runIncServerT, launchFileWatcher + ) where -import Util (onFst, onSnd) +import Control.Concurrent +import Control.Monad +import Control.Monad.State.Strict hiding (get) +import Control.Monad.Reader +import qualified Data.ByteString as BS +import Data.IORef +import Data.Text.Encoding qualified as T +import Data.Text (Text) +import System.Directory (getModificationTime) +import GHC.Generics --- Micro-actors. +import IncState +import MonadUtil --- In this model, an "actor" is just an IO computation (presumably --- running on its own Haskell thread) that receives messages on a --- Control.Concurrent.Chan channel. The idea is that the actor thread --- only receives information (or synchronization) from other threads --- through messages sent on that one channel, and no other thread --- reads messages from that channel. +-- === Actor implementation === --- We start the actor with a two-way view of its input channel so it --- can subscribe itself to message streams by passing (a send-only --- view of) it to another actor. -type Actor a = Chan a -> IO () +newtype ActorM msg a = ActorM { runActorM :: ReaderT (Chan msg) IO a } + deriving (Functor, Applicative, Monad, MonadIO) --- We also define a send-only channel type, to help ourselves not --- accidentally read from channels we aren't supposed to. -newtype PChan a = PChan { sendPChan :: a -> IO () } +newtype Mailbox a = Mailbox { sendToMailbox :: a -> IO () } -sendOnly :: Chan a -> PChan a -sendOnly chan = PChan $ \ !x -> writeChan chan x +class (Show msg, MonadIO m) => Actor msg m | m -> msg where + selfChan :: m (Chan msg) -subChan :: (a -> b) -> PChan b -> PChan a -subChan f chan = PChan (sendPChan chan . f) +instance Show msg => Actor msg (ActorM msg) where + selfChan = ActorM ask --- Synchronously execute an actor. -runActor :: Actor a -> IO () -runActor actor = newChan >>= actor +instance Actor msg m => Actor msg (ReaderT r m) where selfChan = lift $ selfChan +instance Actor msg m => Actor msg (StateT s m) where selfChan = lift $ selfChan --- Asynchronously launch an actor. Immediately returns permission to --- kill that actor and to send it messages. -spawn :: Actor a -> IO (ThreadId, PChan a) -spawn actor = do +send :: MonadIO m => Mailbox msg -> msg -> m () +send chan msg = liftIO $ sendToMailbox chan msg + +selfMailbox :: Actor msg m => (a -> msg) -> m (Mailbox a) +selfMailbox asSelfMessage = do + chan <- selfChan + return $ Mailbox \msg -> writeChan chan (asSelfMessage msg) + +launchActor :: MonadIO m => ActorM msg () -> m (Mailbox msg) +launchActor m = liftIO do chan <- newChan - tid <- forkIO $ actor chan - return (tid, sendOnly chan) - --- A log server. Combines inputs monoidally and pushes incremental --- updates to subscribers. - -data LogServerMsg a = Subscribe (PChan a) - | Publish a - -logServer :: Monoid a => Actor (LogServerMsg a) -logServer self = flip evalStateT (mempty, []) $ forever $ do - msg <- liftIO $ readChan self - case msg of - Subscribe chan -> do - curVal <- gets fst - liftIO $ chan `sendPChan` curVal - modify $ onSnd (chan:) - Publish x -> do - modify $ onFst (<> x) - subscribers <- gets snd - mapM_ (liftIO . (`sendPChan` x)) subscribers + void $ forkIO $ runReaderT (runActorM m) chan + return $ Mailbox \msg -> writeChan chan msg + +messageLoop :: Actor msg m => (msg -> m ()) -> m () +messageLoop handleMessage = do + forever do + msg <- liftIO . readChan =<< selfChan + handleMessage msg + +sliceMailbox :: (b -> a) -> Mailbox a -> Mailbox b +sliceMailbox f (Mailbox sendMsg) = Mailbox $ sendMsg . f + +-- === Promises === + +newtype Promise a = Promise (MVar a) +newtype PromiseSetter a = PromiseSetter (MVar a) + +newPromise :: MonadIO m => m (Promise a, PromiseSetter a) +newPromise = do + v <- liftIO $ newEmptyMVar + return (Promise v, PromiseSetter v) + +waitForPromise :: MonadIO m => Promise a -> m a +waitForPromise (Promise v) = liftIO $ readMVar v + +setPromise :: MonadIO m => PromiseSetter a -> a -> m () +setPromise (PromiseSetter v) x = liftIO $ putMVar v x + +-- Message that expects a synchronous reponse +data SyncMsg msg response = SyncMsg msg (PromiseSetter response) + +sendSync :: MonadIO m => Mailbox (SyncMsg msg response) -> msg -> m response +sendSync mailbox msg = do + (result, resultSetter) <- newPromise + send mailbox (SyncMsg msg resultSetter) + waitForPromise result + + +-- === Diff server === + +data IncServerState s d = IncServerState + { subscribers :: [Mailbox d] + , bufferedUpdates :: d + , curIncState :: s } + deriving (Show, Generic) + +class (Monoid d, MonadIO m) => IncServer s d m | m -> s, m -> d where + getIncServerStateRef :: m (IORef (IncServerState s d)) + +data SubscribeMsg s d = Subscribe (SyncMsg (Mailbox d) s) deriving (Show) + +getIncServerState :: IncServer s d m => m (IncServerState s d) +getIncServerState = readRef =<< getIncServerStateRef + +updateIncServerState :: IncServer s d m => (IncServerState s d -> IncServerState s d) -> m () +updateIncServerState f = do + ref <- getIncServerStateRef + prev <- readRef ref + writeRef ref $ f prev + +handleSubscribeMsg :: IncServer s d m => SubscribeMsg s d -> m () +handleSubscribeMsg (Subscribe (SyncMsg newSub response)) = do + flushDiffs + updateIncServerState \s -> s { subscribers = newSub : subscribers s } + curState <- curIncState <$> getIncServerState + setPromise response curState + +flushDiffs :: IncServer s d m => m () +flushDiffs = do + d <- bufferedUpdates <$> getIncServerState + updateIncServerState \s -> s { bufferedUpdates = mempty } + subs <- subscribers <$> getIncServerState + -- TODO: consider testing for emptiness here + forM_ subs \sub -> send sub d + +type StateServer s d = Mailbox (SubscribeMsg s d) + +subscribe :: Actor msg m => (d -> msg) -> StateServer s d -> m s +subscribe inject server = do + updateChannel <- selfMailbox inject + sendSync (sliceMailbox Subscribe server) updateChannel + +subscribeIO :: StateServer s d -> IO (s, Chan d) +subscribeIO server = do + chan <- newChan + let mailbox = Mailbox (writeChan chan) + s <- sendSync (sliceMailbox Subscribe server) mailbox + return (s, chan) + +newtype IncServerT s d m a = IncServerT { runIncServerT' :: ReaderT (Ref (IncServerState s d)) m a } + deriving (Functor, Applicative, Monad, MonadIO, Actor msg, FreshNames name, MonadTrans) + +instance (MonadIO m, IncState s d) => IncServer s d (IncServerT s d m) where + getIncServerStateRef = IncServerT ask + +instance (MonadIO m, IncState s d) => DefuncState d (IncServerT s d m) where + update d = updateIncServerState \s -> s + { bufferedUpdates = bufferedUpdates s <> d + , curIncState = curIncState s `applyDiff` d} + +instance (MonadIO m, IncState s d) => LabelReader (SingletonLabel s) (IncServerT s d m) where + getl It = curIncState <$> getIncServerState + +runIncServerT :: (MonadIO m, IncState s d) => s -> IncServerT s d m a -> m a +runIncServerT s cont = do + ref <- newRef $ IncServerState [] mempty s + runReaderT (runIncServerT' cont) ref + +-- === Refs === +-- Just a wrapper around IORef lifted to `MonadIO` + +type Ref = IORef + +newRef :: MonadIO m => a -> m (Ref a) +newRef = liftIO . newIORef + +readRef :: MonadIO m => Ref a -> m a +readRef = liftIO . readIORef + +writeRef :: MonadIO m => Ref a -> a -> m () +writeRef ref val = liftIO $ writeIORef ref val + +-- === Clock === + +-- Provides a periodic clock signal. The time interval is in microseconds. +launchClock :: MonadIO m => Int -> Mailbox () -> m () +launchClock intervalMicroseconds mailbox = + liftIO $ void $ forkIO $ forever do + threadDelay intervalMicroseconds + send mailbox () + +-- === File watcher === + +type SourceFileContents = Text +type FileWatcher = StateServer SourceFileContents (Overwrite SourceFileContents) + +readFileContents :: MonadIO m => FilePath -> m Text +readFileContents path = liftIO $ T.decodeUtf8 <$> BS.readFile path + +data FileWatcherMsg = + ClockSignal_FW () + | Subscribe_FW (SubscribeMsg Text (Overwrite Text)) + deriving (Show) + +launchFileWatcher :: MonadIO m => FilePath -> m FileWatcher +launchFileWatcher path = sliceMailbox Subscribe_FW <$> launchActor (fileWatcherImpl path) + +fileWatcherImpl :: FilePath -> ActorM FileWatcherMsg () +fileWatcherImpl path = do + initContents <- readFileContents path + t0 <- liftIO $ getModificationTime path + launchClock 100000 =<< selfMailbox ClockSignal_FW + modTimeRef <- newRef t0 + runIncServerT initContents $ messageLoop \case + Subscribe_FW msg -> handleSubscribeMsg msg + ClockSignal_FW () -> do + tOld <- readRef modTimeRef + tNew <- liftIO $ getModificationTime path + when (tNew /= tOld) do + newContents <- readFileContents path + update $ OverwriteWith newContents + flushDiffs + writeRef modTimeRef tNew + +-- === instances === + +instance Show msg => Show (SyncMsg msg response) where + show (SyncMsg msg _) = show msg + +instance Show (Mailbox a) where + show _ = "mailbox" +deriving instance Actor msg m => Actor msg (FreshNameT m) diff --git a/src/lib/IncState.hs b/src/lib/IncState.hs new file mode 100644 index 000000000..19d0a0884 --- /dev/null +++ b/src/lib/IncState.hs @@ -0,0 +1,101 @@ +-- Copyright 2023 Google LLC +-- +-- Use of this source code is governed by a BSD-style +-- license that can be found in the LICENSE file or at +-- https://developers.google.com/open-source/licenses/bsd + +{-# LANGUAGE UndecidableInstances #-} + +module IncState ( + IncState (..), MapEltUpdate (..), MapUpdate (..), + Overwrite (..), TailUpdate (..)) where + +import qualified Data.Map.Strict as M +import GHC.Generics + +-- === IncState === + +class Monoid d => IncState s d where + applyDiff :: s -> d -> s + +-- === Diff utils === + +data MapEltUpdate v = + Create v + | Update v + | Delete + deriving (Functor, Show, Generic) + +data MapUpdate k v = MapUpdate { mapUpdates :: M.Map k (MapEltUpdate v) } + deriving (Functor, Show, Generic) + +instance Ord k => Monoid (MapUpdate k v) where + mempty = MapUpdate mempty + +instance Ord k => Semigroup (MapUpdate k v) where + MapUpdate m1 <> MapUpdate m2 = MapUpdate $ + M.mapMaybe id (M.intersectionWith combineElts m1 m2) + <> M.difference m1 m2 + <> M.difference m2 m1 + where combineElts e1 e2 = case e1 of + Create _ -> case e2 of + Create _ -> error "shouldn't be creating a node that already exists" + Update v -> Just $ Create v + Delete -> Nothing + Update _ -> case e2 of + Create _ -> error "shouldn't be creating a node that already exists" + Update v -> Just $ Update v + Delete -> Just $ Delete + Delete -> case e2 of + Create v -> Just $ Update v + Update _ -> error "shouldn't be updating a node that doesn't exist" + Delete -> error "shouldn't be deleting a node that doesn't exist" + +instance Ord k => IncState (M.Map k v) (MapUpdate k v) where + applyDiff m (MapUpdate updates) = + M.mapMaybe id (M.intersectionWith applyEltUpdate m updates) + <> M.difference m updates + <> M.mapMaybe applyEltCreation (M.difference updates m) + where applyEltUpdate _ = \case + Create _ -> error "key already exists" + Update v -> Just v + Delete -> Nothing + applyEltCreation = \case + Create v -> Just v + Update _ -> error "key doesn't exist yet" + Delete -> error "key doesn't exist yet" + +data TailUpdate a = TailUpdate + { numDropped :: Int + , newTail :: [a] } + deriving (Show, Generic) + +instance Semigroup (TailUpdate a) where + TailUpdate n1 xs1 <> TailUpdate n2 xs2 = + let xs1Rem = length xs1 - n2 in + if xs1Rem >= 0 + then TailUpdate n1 (take xs1Rem xs1 <> xs2) -- n2 clobbered by xs1 + else TailUpdate (n1 - xs1Rem) xs2 -- xs1 clobbered by n2 + +instance Monoid (TailUpdate a) where + mempty = TailUpdate 0 [] + +instance IncState [a] (TailUpdate a) where + applyDiff xs (TailUpdate numDrop ys) = take (length xs - numDrop) xs <> ys + +-- Trivial diff that works for any type - just replace the old value with a completely new one. +data Overwrite a = NoChange | OverwriteWith a deriving (Show) + +instance Semigroup (Overwrite a) where + l <> r = case r of + OverwriteWith r' -> OverwriteWith r' + NoChange -> l + +instance Monoid (Overwrite a) where + mempty = NoChange + +instance IncState a (Overwrite a) where + applyDiff s = \case + NoChange -> s + OverwriteWith s' -> s' + diff --git a/src/lib/Live/Eval.hs b/src/lib/Live/Eval.hs index 779d1a5ff..4912a8367 100644 --- a/src/lib/Live/Eval.hs +++ b/src/lib/Live/Eval.hs @@ -1,458 +1,311 @@ --- Copyright 2019 Google LLC +-- Copyright 2023 Google LLC -- -- Use of this source code is governed by a BSD-style -- license that can be found in the LICENSE file or at -- https://developers.google.com/open-source/licenses/bsd -module Live.Eval (RFragment (..), SetVal(..), watchAndEvalFile) where +{-# LANGUAGE UndecidableInstances #-} +{-# OPTIONS_GHC -Wno-orphans #-} -import Control.Concurrent (forkIO, killThread, readChan, threadDelay, ThreadId) -import Control.Monad.Reader +module Live.Eval ( + watchAndEvalFile, ResultsServer, ResultsUpdate, subscribeIO, dagAsUpdate) where + +import Control.Concurrent +import Control.Monad import Control.Monad.State.Strict -import Data.ByteString qualified as BS +import qualified Data.Map.Strict as M +import Data.Aeson (ToJSON, ToJSONKey, toJSON, Value) +import Data.Functor ((<&>)) +import Data.Maybe (fromJust) import Data.Text (Text) -import Data.Text.Encoding qualified as T -import Data.Map.Strict qualified as M - -import Data.Aeson (ToJSON, toJSON, (.=)) -import Data.Aeson qualified as A -import Data.Text.Prettyprint.Doc -import System.Directory (getModificationTime) +import GHC.Generics -import ConcreteSyntax import Actor -import RenderHtml (ToMarkup, pprintHtml) -import TopLevel +import IncState import Types.Misc import Types.Source -import Util (onFst, onSnd) - -type NodeId = Int -data WithId a = WithId { getNodeId :: NodeId - , withoutId :: a } - deriving Show +import TopLevel +import ConcreteSyntax +import RenderHtml (ToMarkup, pprintHtml) +import MonadUtil -data RFragment = RFragment (SetVal [NodeId]) - (M.Map NodeId SourceBlock) - (M.Map NodeId Result) +-- === Top-level interface === --- Start watching and evaluating the given file. Returns a channel on --- which one can subscribe to updates to the evaluation state. --- --- The overall system looks like this: --- - `forkWatchFile` creates an actor that watches the file for --- changes and sends `FileChanged` messages to the driver. --- - `runDriver` creates the main driver actor, which manages --- the evaluation state and produces rendering fragments. --- - `logServer` creates an actor that accumulates rendering fragments --- from the driver and broadcasts them to any subscribed clients. --- --- `FileChanged` messages from the watch file actor may invalidate the --- current state. The driver delegates the actual evaluation to a --- sub-thread so it can remain responsive. --- -- `watchAndEvalFile` returns the channel by which a client may -- subscribe by sending a write-only view of its input channel. watchAndEvalFile :: FilePath -> EvalConfig -> TopStateEx - -> IO (PChan (PChan RFragment)) + -> IO (Evaluator SourceBlock Result) watchAndEvalFile fname opts env = do - (_, resultsChan) <- spawn logServer - let cfg = (opts, subChan Publish resultsChan) - (_, driverChan) <- spawn $ runDriver cfg env - forkWatchFile fname $ subChan FileChanged driverChan - return $ subChan Subscribe resultsChan - --- === executing blocks concurrently === - -type SourceContents = Text - -type DriverCfg = (EvalConfig, PChan RFragment) - --- The evaluation-in-progress state is --- - The (identified) current top-level environment --- - If a worker is currently running, its ThreadId and the --- SourceBlock it it working on (necessarily in the current --- top-level environment) --- - The list of blocks that remain to be evaluated (if any) after --- the current worker completes. If nonempty, there should be --- a current worker. --- This is consistent at entry and exit from handling each message, --- but may be briefly inconsistent while a message is being handled. -type WorkerSpec = Maybe (ThreadId, WithId SourceBlock) -data SourceEvalState = SourceEvalState - (WithId TopStateEx) WorkerSpec [WithId SourceBlock] - -initialEvalState :: TopStateEx -> SourceEvalState -initialEvalState env = (SourceEvalState (WithId 0 env) Nothing []) - -newtype DriverM a = DriverM - { drive :: (ReaderT DriverCfg - (ReaderT (PChan DriverEvent) - (StateT (SourceEvalState, CacheState) IO)) a) - } - deriving (Functor, Applicative, Monad, MonadIO) - -type EvalCache = M.Map (SourceBlock, WithId TopStateEx) (NodeId, WithId TopStateEx) -data CacheState = CacheState - { nextBlockId :: NodeId - , nextStateId :: NodeId - , evalCache :: EvalCache } - -emptyCache :: CacheState -emptyCache = CacheState 0 1 mempty - -class (Monad m, MonadIO m) => Driver m where - askOptions :: m EvalConfig - askResultsOutput :: m (PChan RFragment) - askSelf :: m (PChan DriverEvent) - getTopState :: m (WithId TopStateEx) - putTopState :: WithId TopStateEx -> m () - -- Resets the evaluation state to initial, from the given TopStateEx. - -- Returns the old top state and the old worker spec, for reuse - refresh :: TopStateEx -> m (WithId TopStateEx, WorkerSpec) - -- Get the work chunk we are waiting for, if any - getWorkingBlock :: m (Maybe (WithId SourceBlock)) - -- Run the action if there is no worker, otherwise don't - whenNoWorker :: m () -> m () - putWorker :: WorkerSpec -> m () - -- If a block is pending, remove it from the queue and run the - -- action on it, otherwise don't. - popPending :: (WithId SourceBlock -> m ()) -> m () - putPending :: [WithId SourceBlock] -> m () - newBlockId :: m Int - newStateId :: m Int - lookupCache :: SourceBlock - -> WithId TopStateEx - -> m (Maybe (NodeId, WithId TopStateEx)) - insertCache :: SourceBlock - -> WithId TopStateEx - -> (NodeId, WithId TopStateEx) - -> m () - --- The externally visible behavior of the main driver loop: --- - When the source file changes, send the new set of visible node IDs --- (`updateResultList`) to the `PChan RFragment` --- - When a new source block is discovered, assign an ID to it and send --- the association of that block with that ID (`makeNewBlockId`) --- - When a source block is successfully evaluated, associate the result --- with its ID and send that (inside `evalBlock`) - --- Internally, we implement this behavior with a driver thread that --- forks a worker thread. Why two threads? So the driver can notice --- if a source block in progress has disappeared from the file and --- kill the worker when that happens. - --- The worker communicates with the driver by sending a "work --- complete" message. Note that a worker due to be killed may send a --- "work complete" message before the driver actually kills it. If a --- "file changed" message arrived in the interim, the TopState the --- worker delivers remains valid to enter into the cache, but should --- not change the driver's then-current TopState. - --- For this reason, the WorkComplete message contains the ids of the --- TopStateEx and SourceBlock that the woker evaluated. - -data DriverEvent = FileChanged SourceContents - | WorkComplete (WithId TopStateEx) (WithId SourceBlock) (Result, TopStateEx) - -runDriver :: DriverCfg -> TopStateEx -> Actor DriverEvent -runDriver cfg env self = do - liftM fst - $ flip runStateT (initialEvalState env, emptyCache) - $ flip runReaderT (sendOnly self) - $ flip runReaderT cfg - $ drive $ forever $ do - msg <- liftIO $ readChan self - case msg of - (FileChanged source) -> evalSource env source - (WorkComplete block topState payload) -> processWork block topState payload - --- Start evaluation of the (updated) source file in the given (fresh) --- evaluation state. The evaluation state carried in the monad is --- still the state as of the end of the previous message. -evalSource :: Driver m => TopStateEx -> SourceContents -> m () -evalSource env source = do - -- Save the old state from the monad, because we need to kill or - -- reuse the worker from it. - (oldTopState, oldWorker) <- refresh env - let UModule _ _ blocks = parseUModule Main source - (evaluated, remaining) <- tryEvalBlocksCached blocks - (reused, remaining') <- tryReuseWorker oldTopState oldWorker remaining - remaining'' <- mapM makeNewBlockId remaining' - updateResultList $ map getNodeId $ evaluated ++ reused ++ remaining'' - putPending $ reused ++ remaining'' - maybeLaunchWorker - --- See which blocks already have completed values and reuse those. -tryEvalBlocksCached :: Driver m - => [SourceBlock] - -> m ([WithId SourceBlock], [SourceBlock]) -tryEvalBlocksCached [] = return ([], []) -tryEvalBlocksCached blocks@(block:rest) = do - env <- getTopState - res <- lookupCache block env - case res of - Nothing -> return ([], blocks) - Just (blockId, env') -> do - let block' = WithId blockId block - putTopState env' - (evaluated, remaining) <- tryEvalBlocksCached rest - return (block':evaluated, remaining) - --- See whether the formerly active worker (if any) is still doing --- something useful given the list of blocks we are waiting to finish; --- if so reuse it, and if not kill it. -tryReuseWorker :: Driver m - => WithId TopStateEx - -> WorkerSpec - -> [SourceBlock] - -> m ([WithId SourceBlock], [SourceBlock]) -tryReuseWorker _ w [] = - liftIO (forM_ w (killThread . fst)) >> return ([], []) -tryReuseWorker _ Nothing blocks = - return ([], blocks) -tryReuseWorker oldEnv w@(Just (_, oldNext)) (next:rest) = do - curEnv <- getTopState - if (curEnv == oldEnv) && (withoutId oldNext == next) then do - -- Reuse the worker - putWorker w - return ([oldNext], rest) - else - liftIO (forM_ w (killThread . fst)) >> return ([], next:rest) - -processWork :: Driver m - => WithId TopStateEx - -> WithId SourceBlock - -> (Result, TopStateEx) - -> m () -processWork oldState block answer = do - -- The computed result is true regardless of whether this is the - -- worker we are waiting for or not, and therefore safe to cache - -- outside the `when` clause. There is a narrow benefit here: if a - -- worker completes normally while we're processing a FileChanged - -- message, it can send a sound WorkComplete message before we - -- actually kill it. We record that result in case the user edits - -- back to a state where it can be shown. - newState <- recordTruth oldState block answer - curState <- getTopState - waitingFor <- getWorkingBlock - when (oldState == curState - && (fmap withoutId waitingFor == Just (withoutId block))) $ do - -- We only update our working state if this message is, in fact, - -- from the worker we are currently waiting for. - rotateWorkingState newState - --- Record what the worker computed in our cache of truths, and return --- the updated environment. This is sound regardless of whether we --- are waiting for this evaluation or not. -recordTruth :: Driver m - => WithId TopStateEx - -> WithId SourceBlock - -> (Result, TopStateEx) - -> m (WithId TopStateEx) -recordTruth oldState (WithId blockId block) (result, s) = do - resultsChan <- askResultsOutput - liftIO $ resultsChan `sendPChan` oneResult blockId result - newState <- makeNewStateId s - insertCache block oldState (blockId, newState) - return newState - --- Update our current evaluation state assuming the work we were --- waiting for was just completed with the given new evaluation --- environment. -rotateWorkingState :: Driver m => WithId TopStateEx -> m () -rotateWorkingState newState = do - putTopState newState - putWorker Nothing -- Worker finished - maybeLaunchWorker - --- === DriverM utils === - --- If we have work to do but no worker doing it, launch such a worker. -maybeLaunchWorker :: (Driver m) => m () -maybeLaunchWorker = do - whenNoWorker $ popPending \next -> do - curState <- getTopState - opts <- askOptions - self <- askSelf - tid <- liftIO $ forkWorker opts curState next self - putWorker $ Just (tid, next) - -forkWorker :: EvalConfig -> WithId TopStateEx -> WithId SourceBlock - -> PChan DriverEvent -> IO ThreadId -forkWorker opts curState block chan = forkIO $ do - result <- evalSourceBlockIO opts (withoutId curState) (withoutId block) - chan `sendPChan` (WorkComplete curState block result) - -makeNewBlockId :: Driver m => SourceBlock -> m (WithId SourceBlock) -makeNewBlockId block = do - newId <- newBlockId - resultsChan <- askResultsOutput - liftIO $ resultsChan `sendPChan` oneSourceBlock newId block - return $ WithId newId block - -makeNewStateId :: Driver m => TopStateEx -> m (WithId TopStateEx) -makeNewStateId env = do - newId <- newStateId - return $ WithId newId env - --- === utils for sending results === - -updateResultList :: Driver m => [NodeId] -> m () -updateResultList ids = do - resultChan <- askResultsOutput - liftIO $ resultChan `sendPChan` RFragment (Set ids) mempty mempty - -oneResult :: NodeId -> Result -> RFragment -oneResult k r = RFragment mempty mempty (M.singleton k r) - -oneSourceBlock :: NodeId -> SourceBlock -> RFragment -oneSourceBlock k b = RFragment mempty (M.singleton k b) mempty - --- === watching files === - --- A non-Actor source. Sends file contents to channel whenever file --- is modified. -forkWatchFile :: FilePath -> PChan Text -> IO () -forkWatchFile fname chan = onmod fname $ sendFileContents fname chan - -sendFileContents :: String -> PChan Text -> IO () -sendFileContents fname chan = do - putStrLn $ fname ++ " updated" - s <- T.decodeUtf8 <$> BS.readFile fname - sendPChan chan s - -onmod :: FilePath -> IO () -> IO () -onmod fname action = do - action - t <- getModificationTime fname - void $ forkIO $ loop t - where - loop t = do - t' <- getModificationTime fname - threadDelay 100000 - unless (t == t') action - loop t' - --- === instances === - -instance Driver DriverM where - askOptions = DriverM $ asks fst - askResultsOutput = DriverM $ asks snd - askSelf = DriverM $ lift $ ask - getTopState = DriverM $ do - (SourceEvalState s _ _) <- gets fst - return s - - putTopState s = DriverM $ modify $ onFst \(SourceEvalState _ w blocks) - -> (SourceEvalState s w blocks) - - refresh env = DriverM $ do - (SourceEvalState oldState oldWorker _) <- gets fst - modify $ onFst $ const $ initialEvalState env - return (oldState, oldWorker) - - getWorkingBlock = DriverM $ do - (SourceEvalState _ w _) <- gets fst - return $ (fmap snd) w - - whenNoWorker (DriverM action) = DriverM $ do - (SourceEvalState _ w _) <- gets fst - case w of - (Just _) -> return () - Nothing -> action + watcher <- launchFileWatcher fname + parser <- launchCellParser watcher \source -> uModuleSourceBlocks $ parseUModule Main source + launchDagEvaluator parser env (evalSourceBlockIO opts) - putWorker w = DriverM $ modify $ onFst \(SourceEvalState s _ blocks) - -> (SourceEvalState s w blocks) +type ResultsServer = Evaluator SourceBlock Result +type ResultsUpdate = EvalStatusUpdate SourceBlock Result - popPending action = do - (SourceEvalState _ _ curPending) <- DriverM $ gets fst - case curPending of - [] -> return () - (next:rest) -> do - DriverM $ modify $ onFst \(SourceEvalState s w _) - -> (SourceEvalState s w rest) - action next +-- === DAG diff state === - putPending blocks = DriverM $ modify $ onFst \(SourceEvalState s w _) - -> (SourceEvalState s w blocks) +-- We intend to make this an arbitrary Dag at some point but for now we just +-- assume that dependence is just given by the top-to-bottom ordering of blocks +-- within the file. - lookupCache block env = DriverM $ do - cache <- gets (evalCache . snd) - return $ M.lookup (block, env) cache - - newBlockId = DriverM $ do - newId <- gets $ nextBlockId . snd - modify $ onSnd \cache -> cache {nextBlockId = newId + 1 } - return newId - - newStateId = DriverM $ do - newId <- gets $ nextStateId . snd - modify $ onSnd \cache -> cache {nextStateId = newId + 1 } - return newId - - insertCache block env val = DriverM $ modify $ onSnd \cache -> - cache { evalCache = M.insert (block, env) val $ evalCache cache } - -instance Semigroup RFragment where - (RFragment x y z) <> (RFragment x' y' z') = RFragment (x<>x') (y<>y') (z<>z') - -instance Monoid RFragment where - mempty = RFragment mempty mempty mempty - -instance Eq (WithId a) where - (==) (WithId x _) (WithId y _) = x == y - -instance Ord (WithId a) where - compare (WithId x _) (WithId y _) = compare x y - -instance ToJSON a => ToJSON (SetVal a) where - toJSON (Set x) = A.object ["val" .= toJSON x] - toJSON NotSet = A.Null - -instance (ToJSON k, ToJSON v) => ToJSON (MonMap k v) where - toJSON (MonMap m) = toJSON (M.toList m) - -instance ToJSON RFragment where - toJSON (RFragment ids blocks results) = toJSON (ids, contents) - where contents = MonMap (M.map toHtmlFragment blocks) - <> MonMap (M.map toHtmlFragment results) - -type TreeAddress = [Int] -type HtmlFragment = [(TreeAddress, String)] - -toHtmlFragment :: ToMarkup a => a -> HtmlFragment -toHtmlFragment x = [([], pprintHtml x)] - -instance Pretty SourceEvalState where - pretty (SourceEvalState env worker pending) = - "In env ID" <+> pretty (getNodeId env) <> line - <> "waiting for" <+> pretty (show worker) <+> "to evaluate" <> line - <> pretty (map prettify pending) where - prettify (WithId blockId block) = (blockId, block) - -instance Pretty DriverEvent where - pretty (FileChanged contents) = "New file contents" <> line <> pretty contents - pretty (WorkComplete env (WithId blockId block) (result, _)) = - "Finished evaluating" <+> pretty (blockId, block) - <+> "in env with ID" <+> pretty (getNodeId env) - <+> "got" <+> pretty result - --- === some handy monoids === - -data SetVal a = Set a | NotSet +type NodeId = Int -instance Semigroup (SetVal a) where - x <> NotSet = x - _ <> Set x = Set x +data NodeList a = NodeList + { orderedNodes :: [NodeId] + , nodeMap :: M.Map NodeId a } + deriving (Show, Generic) + +data NodeListUpdate a = NodeListUpdate + { orderedNodesUpdate :: TailUpdate NodeId + , nodeMapUpdate :: MapUpdate NodeId a } + deriving (Show, Functor, Generic) + +instance Semigroup (NodeListUpdate a) where + NodeListUpdate x1 y1 <> NodeListUpdate x2 y2 = NodeListUpdate (x1<>x2) (y1<>y2) + +instance Monoid (NodeListUpdate a) where + mempty = NodeListUpdate mempty mempty + +instance IncState (NodeList a) (NodeListUpdate a) where + applyDiff (NodeList m xs) (NodeListUpdate dm dxs) = + NodeList (applyDiff m dm) (applyDiff xs dxs) + +type Dag = NodeList +type DagUpdate = NodeListUpdate + +dagAsUpdate :: Dag a -> DagUpdate a +dagAsUpdate (NodeList xs m)= NodeListUpdate (TailUpdate 0 xs) (MapUpdate $ fmap Create m) + +emptyNodeList :: NodeList a +emptyNodeList = NodeList [] mempty + +buildNodeList :: FreshNames NodeId m => [a] -> m (NodeList a) +buildNodeList vals = do + nodeList <- forM vals \val -> do + nodeId <- freshName + return (nodeId, val) + return $ NodeList (fst <$> nodeList) (M.fromList nodeList) + +commonPrefixLength :: Eq a => [a] -> [a] -> Int +commonPrefixLength (x:xs) (y:ys) | x == y = 1 + commonPrefixLength xs ys +commonPrefixLength _ _ = 0 + +nodeListVals :: NodeList a -> [a] +nodeListVals nodes = orderedNodes nodes <&> \k -> fromJust $ M.lookup k (nodeMap nodes) + +computeNodeListUpdate :: (Eq a, FreshNames NodeId m) => NodeList a -> [a] -> m (NodeListUpdate a) +computeNodeListUpdate nodes newVals = do + let prefixLength = commonPrefixLength (nodeListVals nodes) newVals + let oldTail = drop prefixLength $ orderedNodes nodes + NodeList newTail nodesCreated <- buildNodeList $ drop prefixLength newVals + let nodeUpdates = fmap Create nodesCreated <> M.fromList (fmap (,Delete) oldTail) + return $ NodeListUpdate (TailUpdate (length oldTail) newTail) (MapUpdate nodeUpdates) + +-- === Cell parser === + +-- This coarsely parses the full file into blocks and forms a DAG (for now a +-- trivial one assuming all top-to-bottom dependencies) of the results. + +type CellParser a = StateServer (Dag a) (DagUpdate a) + +data CellParserMsg a = + Subscribe_CP (SubscribeMsg (Dag a) (DagUpdate a)) + | Update_CP (Overwrite Text) + deriving (Show) + +launchCellParser :: (Eq a, MonadIO m) => FileWatcher -> (Text -> [a]) -> m (CellParser a) +launchCellParser fileWatcher parseCells = + sliceMailbox Subscribe_CP <$> launchActor (cellParserImpl fileWatcher parseCells) + +cellParserImpl :: Eq a => FileWatcher -> (Text -> [a]) -> ActorM (CellParserMsg a) () +cellParserImpl fileWatcher parseCells = runFreshNameT do + initContents <- subscribe Update_CP fileWatcher + initNodeList <- buildNodeList $ parseCells initContents + runIncServerT initNodeList $ messageLoop \case + Subscribe_CP msg -> handleSubscribeMsg msg + Update_CP NoChange -> return () + Update_CP (OverwriteWith newContents) -> do + let newCells = parseCells newContents + curNodeList <- getl It + update =<< computeNodeListUpdate curNodeList newCells + flushDiffs + +-- === Dag evaluator === + +-- This is where we track the state of evaluation and decide what we needs to be +-- run and what needs to be killed. + +type Evaluator i o = StateServer (EvalStatus i o) (EvalStatusUpdate i o) +newtype EvaluatorM s i o a = + EvaluatorM { runEvaluatorM' :: + IncServerT (EvalStatus i o) (EvalStatusUpdate i o) + (StateT (EvaluatorState s i o) + (ActorM (EvaluatorMsg s i o))) a } + deriving (Functor, Applicative, Monad, MonadIO, + Actor (EvaluatorMsg s i o), + IncServer (EvalStatus i o) (EvalStatusUpdate i o)) + +instance DefuncState (EvaluatorMUpdate s i o) (EvaluatorM s i o) where + update = \case + UpdateDagEU dag -> EvaluatorM $ update dag + UpdateCurJob status -> EvaluatorM $ lift $ modify \s -> s { curRunningJob = status } + UpdateEnvs envs -> EvaluatorM $ lift $ modify \s -> s { prevEnvs = envs} + AppendEnv env -> do + envs <- getl PrevEnvs + update $ UpdateEnvs $ envs ++ [env] + UpdateJobStatus nodeId status -> do + NodeState i _ <- fromJust <$> getl (NodeInfo nodeId) + let newState = NodeState i status + update $ UpdateDagEU $ NodeListUpdate mempty $ MapUpdate $ M.singleton nodeId (Update newState) + +instance LabelReader (EvaluatorMLabel s i o) (EvaluatorM s i o) where + getl l = case l of + NodeListEM -> EvaluatorM $ orderedNodes <$> getl It + NodeInfo nodeId -> EvaluatorM $ M.lookup nodeId <$> nodeMap <$> getl It + PrevEnvs -> EvaluatorM $ lift $ prevEnvs <$> get + CurRunningJob -> EvaluatorM $ lift $ curRunningJob <$> get + EvalFun -> EvaluatorM $ lift $ evalFun <$> get + +data EvaluatorMUpdate s i o = + UpdateDagEU (NodeListUpdate (NodeState i o)) + | UpdateJobStatus NodeId (NodeEvalStatus o) + | UpdateCurJob CurJobStatus + | UpdateEnvs [s] + | AppendEnv s + +data EvaluatorMLabel s i o a where + NodeListEM :: EvaluatorMLabel s i o [NodeId] + NodeInfo :: NodeId -> EvaluatorMLabel s i o (Maybe (NodeState i o)) + PrevEnvs :: EvaluatorMLabel s i o [s] + CurRunningJob :: EvaluatorMLabel s i o (CurJobStatus) + EvalFun :: EvaluatorMLabel s i o (EvalFun s i o) + +-- The envs after each cell evaluated so far +type EvalFun s i o = s -> i -> IO (o, s) +type CurJobStatus = Maybe (ThreadId, NodeId, CellIndex) + +data EvaluatorState s i o = EvaluatorState + { prevEnvs :: [s] + , evalFun :: EvalFun s i o + , curRunningJob :: CurJobStatus } + +data NodeEvalStatus o = + Waiting + | Running + | Complete o + deriving (Show, Generic) + +data NodeState i o = NodeState i (NodeEvalStatus o) deriving (Show, Generic) + +type Show3 s i o = (Show s, Show i, Show o) + +type EvalStatus i o = NodeList (NodeState i o) +type EvalStatusUpdate i o = NodeListUpdate (NodeState i o) + +type CellIndex = Int -- index in the list of cells, not the NodeId + +data EvaluatorMsg s i o = + SourceUpdate (DagUpdate i) + | JobComplete ThreadId s o + | Subscribe_E (SubscribeMsg (EvalStatus i o) (EvalStatusUpdate i o)) + deriving (Show) + +initEvaluatorState :: s -> EvalFun s i o -> EvaluatorState s i o +initEvaluatorState s evalCell = EvaluatorState [s] evalCell Nothing + +launchDagEvaluator :: (Show3 s i o, MonadIO m) => CellParser i -> s -> EvalFun s i o -> m (Evaluator i o) +launchDagEvaluator cellParser env evalCell = do + mailbox <- launchActor do + let s = initEvaluatorState env evalCell + void $ flip runStateT s $ runIncServerT emptyNodeList $ runEvaluatorM' $ + dagEvaluatorImpl cellParser + return $ sliceMailbox Subscribe_E mailbox + +dagEvaluatorImpl :: (Show3 s i o) => CellParser i -> EvaluatorM s i o () +dagEvaluatorImpl cellParser = do + initDag <- subscribe SourceUpdate cellParser + processDagUpdate (dagAsUpdate initDag) >> flushDiffs + launchNextJob + messageLoop \case + Subscribe_E msg -> handleSubscribeMsg msg + SourceUpdate dagUpdate -> do + processDagUpdate dagUpdate + flushDiffs + JobComplete threadId env result -> do + processJobComplete threadId env result + flushDiffs + +processJobComplete :: (Show3 s i o) => ThreadId -> s -> o -> EvaluatorM s i o () +processJobComplete threadId newEnv result = do + getl CurRunningJob >>= \case + Just (expectedThreadId, nodeId, _) -> do + when (threadId == expectedThreadId) do -- otherwise it's a zombie + update $ UpdateJobStatus nodeId (Complete result) + update $ UpdateCurJob Nothing + update $ AppendEnv newEnv + launchNextJob + Nothing -> return () -- this job is a zombie + +nextJobIndex :: EvaluatorM s i o Int +nextJobIndex = do + envs <- getl PrevEnvs + return $ length envs - 1 + +launchNextJob :: (Show3 s i o) => EvaluatorM s i o () +launchNextJob = do + jobIndex <- nextJobIndex + nodeList <- getl NodeListEM + when (jobIndex < length nodeList) do -- otherwise we're all done + curEnv <- (!! jobIndex) <$> getl PrevEnvs + let nodeId = nodeList !! jobIndex + launchJob jobIndex nodeId curEnv + +launchJob :: (Show3 s i o) => CellIndex -> NodeId -> s -> EvaluatorM s i o () +launchJob jobIndex nodeId env = do + jobAction <- getl EvalFun + NodeState source _ <- fromJust <$> getl (NodeInfo nodeId) + resultMailbox <- selfMailbox id + threadId <- liftIO $ forkIO do + threadId <- myThreadId + (result, finalEnv) <- jobAction env source + send resultMailbox $ JobComplete threadId finalEnv result + update $ UpdateJobStatus nodeId Running + update $ UpdateCurJob (Just (threadId, nodeId, jobIndex)) + +computeNumValidCells :: DagUpdate i -> EvaluatorM s i o Int +computeNumValidCells dagUpdate = do + let nDropped = numDropped $ orderedNodesUpdate dagUpdate + nTotal <- length <$> getl NodeListEM + return $ nTotal - nDropped + +processDagUpdate :: (Show3 s i o) => DagUpdate i -> EvaluatorM s i o () +processDagUpdate dagUpdate = do + nValid <- computeNumValidCells dagUpdate + envs <- getl PrevEnvs + update $ UpdateEnvs $ take (nValid + 1) envs + update $ UpdateDagEU $ fmap (\i -> NodeState i Waiting) dagUpdate + getl CurRunningJob >>= \case + Nothing -> launchNextJob + Just (threadId, _, jobIndex) + | (jobIndex >= nValid) -> do + -- Current job is no longer valid. Kill it and restart. + liftIO $ killThread threadId + update $ UpdateCurJob Nothing + launchNextJob + | otherwise -> return () -- Current job is fine. Let it continue. -instance Monoid (SetVal a) where - mempty = NotSet +-- === instances === -newtype MonMap k v = MonMap (M.Map k v) deriving (Show, Eq) +instance ToJSON a => ToJSON (NodeListUpdate a) +instance (ToJSON a, ToJSONKey k) => ToJSON (MapUpdate k a) +instance ToJSON a => ToJSON (TailUpdate a) +instance ToJSON a => ToJSON (MapEltUpdate a) +instance ToJSON o => ToJSON (NodeEvalStatus o) +instance (ToJSON i, ToJSON o) => ToJSON (NodeState i o) -instance (Ord k, Semigroup v) => Semigroup (MonMap k v) where - MonMap m <> MonMap m' = MonMap $ M.unionWith (<>) m m' +instance ToJSON SourceBlock where toJSON = toJSONViaHtml +instance ToJSON Result where toJSON = toJSONViaHtml -instance (Ord k, Semigroup v) => Monoid (MonMap k v) where - mempty = MonMap mempty +toJSONViaHtml :: ToMarkup a => a -> Value +toJSONViaHtml x = toJSON $ pprintHtml x diff --git a/src/lib/Live/Terminal.hs b/src/lib/Live/Terminal.hs deleted file mode 100644 index c995ea64f..000000000 --- a/src/lib/Live/Terminal.hs +++ /dev/null @@ -1,82 +0,0 @@ --- Copyright 2019 Google LLC --- --- Use of this source code is governed by a BSD-style --- license that can be found in the LICENSE file or at --- https://developers.google.com/open-source/licenses/bsd - -module Live.Terminal (runTerminal) where - -import Control.Concurrent (Chan, readChan, forkIO) -import Control.Monad.State.Strict -import Data.Foldable (fold) -import qualified Data.Map.Strict as M - -import System.Console.ANSI (clearScreen, setCursorPosition) -import System.IO (BufferMode (..), hSetBuffering, stdin) - -import Actor -import Cat -import Live.Eval -import PPrint (printLitBlock) -import TopLevel - -runTerminal :: FilePath -> EvalConfig -> TopStateEx -> IO () -runTerminal fname opts env = do - resultsChan <- watchAndEvalFile fname opts env - displayResultsTerm resultsChan - -type DisplayPos = Int -data KeyboardCommand = ScrollUp | ScrollDown | ResetDisplay - -type TermDisplayM = StateT DisplayPos (CatT RFragment IO) - -displayResultsTerm :: PChan (PChan RFragment) -> IO () -displayResultsTerm resultsSubscribe = - runActor \self -> do - resultsSubscribe `sendPChan` subChan Left (sendOnly self) - void $ forkIO $ monitorKeyboard $ subChan Right (sendOnly self) - evalCatT $ flip evalStateT 0 $ forever $ termDisplayLoop self - -termDisplayLoop :: (Chan (Either RFragment KeyboardCommand)) -> TermDisplayM () -termDisplayLoop self = do - req <- liftIO $ readChan self - case req of - Left result -> extend result - Right command -> case command of - ScrollUp -> modify (+ 4) - ScrollDown -> modify (\p -> max 0 (p - 4)) - ResetDisplay -> put 0 - results <- look - pos <- get - case renderResults results of - Nothing -> return () - Just s -> liftIO $ do - let cropped = cropTrailingLines pos s - setCursorPosition 0 0 - clearScreen -- TODO: clean line-by-line instead - putStr cropped - -cropTrailingLines :: Int -> String -> String -cropTrailingLines n s = unlines $ reverse $ drop n $ reverse $ lines s - --- TODO: show incremental results -renderResults :: RFragment -> Maybe String -renderResults (RFragment NotSet _ _) = Nothing -renderResults (RFragment (Set ids) blocks results) = - liftM fold $ forM ids $ \i -> do - b <- M.lookup i blocks - r <- M.lookup i results - return $ printLitBlock True b r - --- A non-Actor source. Sends keyboard command signals as they occur. -monitorKeyboard :: PChan KeyboardCommand -> IO () -monitorKeyboard chan = do - hSetBuffering stdin NoBuffering - forever $ do - c <- getChar - case c of - 'k' -> chan `sendPChan` ScrollUp - 'j' -> chan `sendPChan` ScrollDown - 'q' -> chan `sendPChan` ResetDisplay - _ -> return () - diff --git a/src/lib/Live/Web.hs b/src/lib/Live/Web.hs index ad7715599..b4e4060d1 100644 --- a/src/lib/Live/Web.hs +++ b/src/lib/Live/Web.hs @@ -14,12 +14,12 @@ import Network.Wai (Application, StreamingBody, pathInfo, import Network.Wai.Handler.Warp (run) import Network.HTTP.Types (status200, status404) import Data.Aeson (ToJSON, encode) -import Data.Binary.Builder (fromByteString, Builder) +import Data.Binary.Builder (fromByteString) import Data.ByteString.Lazy (toStrict) +import qualified Data.ByteString as BS import Paths_dex (getDataFileName) -import Actor import Live.Eval import TopLevel @@ -29,7 +29,7 @@ runWeb fname opts env = do putStrLn "Streaming output to http://localhost:8000/" run 8000 $ serveResults resultsChan -serveResults :: ToJSON a => PChan (PChan a) -> Application +serveResults :: ResultsServer -> Application serveResults resultsSubscribe request respond = do print (pathInfo request) case pathInfo request of @@ -47,13 +47,18 @@ serveResults resultsSubscribe request respond = do fname <- getDataFileName dataFname respond $ responseFile status200 [("Content-Type", ctype)] fname Nothing -resultStream :: ToJSON a => PChan (PChan a) -> StreamingBody -resultStream resultsSubscribe write flush = runActor \self -> do - write (makeBuilder ("start"::String)) >> flush - resultsSubscribe `sendPChan` (sendOnly self) - forever $ do msg <- readChan self - write (makeBuilder msg) >> flush +resultStream :: ResultsServer -> StreamingBody +resultStream resultsServer write flush = do + write (fromByteString $ encodeResults ("start"::String)) >> flush + (initResult, resultsChan) <- subscribeIO resultsServer + sendUpdate $ dagAsUpdate initResult + forever $ readChan resultsChan >>= sendUpdate + where + sendUpdate :: ResultsUpdate -> IO () + sendUpdate update = do + let s = encodeResults update + write (fromByteString s) >> flush -makeBuilder :: ToJSON a => a -> Builder -makeBuilder = fromByteString . toStrict . wrap . encode - where wrap s = "data:" <> s <> "\n\n" + encodeResults :: ToJSON a => a -> BS.ByteString + encodeResults = toStrict . wrap . encode + where wrap s = "data:" <> s <> "\n\n" diff --git a/src/lib/MonadUtil.hs b/src/lib/MonadUtil.hs new file mode 100644 index 000000000..6d75e2377 --- /dev/null +++ b/src/lib/MonadUtil.hs @@ -0,0 +1,51 @@ +-- Copyright 2023 Google LLC +-- +-- Use of this source code is governed by a BSD-style +-- license that can be found in the LICENSE file or at +-- https://developers.google.com/open-source/licenses/bsd + +{-# LANGUAGE UndecidableInstances #-} + +module MonadUtil ( + DefuncState (..), LabelReader (..), SingletonLabel (..), FreshNames (..), + runFreshNameT, FreshNameT (..)) where + +import Control.Monad.Reader +import Control.Monad.State.Strict + +-- === Defunctionalized state === +-- Interface for state whose allowable updates are specified by a data type. +-- Useful for `IncState`, for specifying read-only env components, or +-- generally for specifying certain constraints on updates. + +class DefuncState d m | m -> d where + update :: d -> m () + +class LabelReader (l :: * -> *) m | m -> l where + getl :: l a -> m a + +data SingletonLabel a b where + It :: SingletonLabel a a + +-- === Fresh name monad === + +-- Used for ad-hoc names with no nested binders that don't need to be treated +-- carefully using the whole "foil" name system. + +class Monad m => FreshNames a m | m -> a where + freshName :: m a + +newtype FreshNameT m a = FreshNameT { runFreshNameT' :: StateT Int m a } + deriving (Functor, Applicative, Monad, MonadIO) + +instance MonadIO m => FreshNames Int (FreshNameT m) where + freshName = FreshNameT do + fresh <- get + put (fresh + 1) + return fresh + +instance FreshNames a m => FreshNames a (ReaderT r m) where + freshName = lift freshName + +runFreshNameT :: MonadIO m => FreshNameT m a -> m a +runFreshNameT cont = evalStateT (runFreshNameT' cont) 0 diff --git a/src/lib/PPrint.hs b/src/lib/PPrint.hs index af1c48f1e..d5e8473cd 100644 --- a/src/lib/PPrint.hs +++ b/src/lib/PPrint.hs @@ -9,11 +9,12 @@ {-# OPTIONS_GHC -Wno-orphans #-} module PPrint ( - pprint, pprintCanonicalized, pprintList, asStr , atPrec, toJSONStr, + pprint, pprintCanonicalized, pprintList, asStr , atPrec, resultAsJSON, PrettyPrec(..), PrecedenceLevel (..), prettyBlock, printLitBlock, printResult, prettyFromPrettyPrec) where import Data.Aeson hiding (Result, Null, Value, Success) +import Data.Aeson.Encoding (encodingToLazyByteString, value) import GHC.Exts (Constraint) import GHC.Float import Data.Foldable (toList, fold) @@ -1015,21 +1016,19 @@ addColor True c s = setSGRCode [SetConsoleIntensity BoldIntensity, SetColor Foreground Vivid c] ++ s ++ setSGRCode [Reset] -toJSONStr :: ToJSON a => a -> String -toJSONStr = B.unpack . encode - -instance ToJSON Result where - toJSON (Result outs err) = object (outMaps <> errMaps) - where - errMaps = case err of - Failure e -> ["error" .= String (fromString $ pprint e)] - Success () -> [] - outMaps = flip foldMap outs $ \case - BenchResult name compileTime runTime _ -> - [ "bench_name" .= toJSON name - , "compile_time" .= toJSON compileTime - , "run_time" .= toJSON runTime ] - out -> ["result" .= String (fromString $ pprint out)] +resultAsJSON :: Result -> String +resultAsJSON (Result outs err) = + B.unpack $ encodingToLazyByteString $ value $ object (outMaps <> errMaps) + where + errMaps = case err of + Failure e -> ["error" .= String (fromString $ pprint e)] + Success () -> [] + outMaps = flip foldMap outs $ \case + BenchResult name compileTime runTime _ -> + [ "bench_name" .= toJSON name + , "compile_time" .= toJSON compileTime + , "run_time" .= toJSON runTime ] + out -> ["result" .= String (fromString $ pprint out)] -- === Concrete syntax rendering === diff --git a/src/lib/TopLevel.hs b/src/lib/TopLevel.hs index 2f4970c3d..fb7e566a9 100644 --- a/src/lib/TopLevel.hs +++ b/src/lib/TopLevel.hs @@ -136,6 +136,7 @@ newtype TopperM (n::S) a = TopperM -- Hides the `n` parameter as an existential data TopStateEx where TopStateEx :: Distinct n => Env n -> RuntimeEnv -> TopStateEx +instance Show TopStateEx where show _ = "TopStateEx" -- Hides the `n` parameter as an existential data TopSerializedStateEx where diff --git a/src/lib/Util.hs b/src/lib/Util.hs index 853c384e5..4c257f95d 100644 --- a/src/lib/Util.hs +++ b/src/lib/Util.hs @@ -12,6 +12,7 @@ import Prelude import qualified Data.Set as Set import qualified Data.Map.Strict as M import Control.Applicative +import Control.Monad.Reader import Control.Monad.State.Strict import System.CPUTime import GHC.Base (getTag) diff --git a/static/index.js b/static/index.js index 4b6862b09..4bec4bb01 100644 --- a/static/index.js +++ b/static/index.js @@ -16,25 +16,6 @@ var katexOptions = { trust: true }; -var cells = {}; - -function append_contents(key, contents) { - if (key in cells) { - var cur_cells = cells[key]; - } else { - var cell = document.createElement("div"); - cell.className = "cell"; - cells[key] = [cell]; - var cur_cells = [cell]; - } - for (var i = 0; i < contents.length; i++) { - for (var j = 0; j < cur_cells.length; j++) { - var node = lookup_address(cur_cells[j], contents[i][0]) - node.innerHTML += contents[i][1]; - } - } -} - function lookup_address(cell, address) { var node = cell for (i = 0; i < address.length; i++) { @@ -99,7 +80,7 @@ function renderLaTeX() { * Rendering the Table of Contents / Navigation Bar * 2 key functions * - `updateNavigation()` which inserts/updates the navigation bar - * - and it's helper `extractStructure()` which extracts the structure of the page + * - and its helper `extractStructure()` which extracts the structure of the page * and adds ids to heading elements. */ function updateNavigation() { @@ -185,6 +166,10 @@ var RENDER_MODE = Object.freeze({ DYNAMIC: "dynamic", }) +// mapping from server-provided NodeID to HTML id +var cells = {}; +var body = document.getElementById("main-output"); + /** * Renders the webpage. * @param {RENDER_MODE} renderMode The render mode, either static or dynamic. @@ -199,41 +184,69 @@ function render(renderMode) { // For dynamic pages (via `dex web`), listen to update events. var source = new EventSource("/getnext"); source.onmessage = function(event) { - var body = document.getElementById("main-output"); var msg = JSON.parse(event.data); if (msg == "start") { body.innerHTML = ""; cells = {} return + } else { + process_update(msg); + renderLaTeX(); + renderHovertips(); + updateNavigation(); } - var order = msg[0]; - var contents = msg[1]; - for (var i = 0; i < contents.length; i++) { - append_contents(contents[i][0], contents[i][1]); - } - if (order != null) { - var new_cells = {}; - body.innerHTML = ""; - for (var i = 0; i < order.val.length; i++) { - var key = order.val[i] - var cur_cells = cells[key] - if (cur_cells.length == 0) { - var cur_cell = new_cells[key][0].cloneNode(true) - } else { - var cur_cell = cur_cells.pop() - if (key in new_cells) { - new_cells[key].push(cur_cell); - } else { - new_cells[key] = [cur_cell]; - } - } - body.appendChild(cur_cell); - } - Object.assign(cells, new_cells); - } - renderLaTeX(); - renderHovertips(); - updateNavigation(); }; } } + +function set_cell_contents(cell, contents) { + var source_text = contents[0]; + cell.innerHTML = source_text + var results = contents[1]; + tag = results["tag"] + if (tag == "Waiting") { + cell.className = "waiting-cell"; + } else if (tag == "Running") { + cell.className = "running-cell"; + } else if (tag == "Complete") { + cell.className = "complete-cell"; + cell.innerHTML += results["contents"] + } else { + console.error(tag); + } +} + +function process_update(msg) { + var cell_updates = msg["nodeMapUpdate"]["mapUpdates"]; + var num_dropped = msg["orderedNodesUpdate"]["numDropped"]; + var new_tail = msg["orderedNodesUpdate"]["newTail"]; + + // drop_dead_cells + for (i = 0; i < num_dropped; i++) { + body.lastElementChild.remove(); + } + + Object.keys(cell_updates).forEach(function (node_id) { + var update = cell_updates[node_id]; + var tag = update["tag"] + var contents = update["contents"] + if (tag == "Create") { + var cell = document.createElement("div"); + cells[node_id] = cell; + set_cell_contents(cell, contents) + } else if (tag == "Update") { + var cell = cells[node_id]; + set_cell_contents(cell, contents); + } else if (tag == "Delete") { + delete cells[node_id] + } else { + console.error(tag); + } + }); + + // append_new_cells + new_tail.forEach(function (node_id) { + body.appendChild(cells[node_id]); + }); + +} diff --git a/static/style.css b/static/style.css index cab311add..a6a36fe52 100644 --- a/static/style.css +++ b/static/style.css @@ -103,3 +103,14 @@ code { .iso-sugar { color: #25BBA7; } + +.waiting-cell { + background-color: #DDDDFF; +} + +.running-cell { + background-color: #DDFFDD; +} + +.complete-cell { +} From add49a9c8d03f26de1919bb7d2493ba8dd4df2b3 Mon Sep 17 00:00:00 2001 From: Dougal Date: Wed, 15 Nov 2023 22:48:36 -0500 Subject: [PATCH 15/41] Improve live view performance * Only update hovertips/latex for new and updated cells * Skip the nav bar in live view. We can probably make this one incremental too if we really want to. --- src/lib/Live/Web.hs | 3 ++- static/index.js | 17 ++++++++--------- 2 files changed, 10 insertions(+), 10 deletions(-) diff --git a/src/lib/Live/Web.hs b/src/lib/Live/Web.hs index b4e4060d1..820b5d4ae 100644 --- a/src/lib/Live/Web.hs +++ b/src/lib/Live/Web.hs @@ -44,7 +44,8 @@ serveResults resultsSubscribe request respond = do [("Content-Type", "text/plain")] "404 - Not Found" where respondWith dataFname ctype = do - fname <- getDataFileName dataFname + fname <- return dataFname -- lets us skip rebuilding during development + -- fname <- getDataFileName dataFname respond $ responseFile status200 [("Content-Type", ctype)] fname Nothing resultStream :: ResultsServer -> StreamingBody diff --git a/static/index.js b/static/index.js index 4bec4bb01..65c3a1cc7 100644 --- a/static/index.js +++ b/static/index.js @@ -24,8 +24,8 @@ function lookup_address(cell, address) { return node } -function renderHovertips() { - var spans = document.querySelectorAll(".code-span"); +function renderHovertips(root) { + var spans = root.querySelectorAll(".code-span"); Array.from(spans).map((span) => attachHovertip(span)); } @@ -63,14 +63,14 @@ function removeHighlighting(event, node) { }) } -function renderLaTeX() { +function renderLaTeX(root) { // Render LaTeX equations in prose blocks via KaTeX, if available. // Skip rendering if KaTeX is unavailable. if (typeof renderMathInElement == 'undefined') { return; } // Render LaTeX equations in prose blocks via KaTeX. - var proseBlocks = document.querySelectorAll(".prose-block"); + var proseBlocks = root.querySelectorAll(".prose-block"); Array.from(proseBlocks).map((proseBlock) => renderMathInElement(proseBlock, katexOptions) ); @@ -177,8 +177,8 @@ var body = document.getElementById("main-output"); function render(renderMode) { if (renderMode == RENDER_MODE.STATIC) { // For static pages, simply call rendering functions once. - renderLaTeX(); - renderHovertips(); + renderLaTeX(document); + renderHovertips(document); updateNavigation(); } else { // For dynamic pages (via `dex web`), listen to update events. @@ -191,9 +191,6 @@ function render(renderMode) { return } else { process_update(msg); - renderLaTeX(); - renderHovertips(); - updateNavigation(); } }; } @@ -214,6 +211,8 @@ function set_cell_contents(cell, contents) { } else { console.error(tag); } + renderLaTeX(cell); + renderHovertips(cell); } function process_update(msg) { From 5f8d27f7a64704474a1e7dad44fcaf25169f5318 Mon Sep 17 00:00:00 2001 From: Dougal Date: Thu, 16 Nov 2023 10:14:16 -0500 Subject: [PATCH 16/41] Default to Nat/Int when inferring literals --- src/lib/Inference.hs | 22 ++++++++++++++-------- 1 file changed, 14 insertions(+), 8 deletions(-) diff --git a/src/lib/Inference.hs b/src/lib/Inference.hs index 48a672c88..e0b9624a1 100644 --- a/src/lib/Inference.hs +++ b/src/lib/Inference.hs @@ -434,12 +434,8 @@ topDownExplicit reqTy exprWithSrc@(WithSrcE pos expr) = addSrcContext pos case e case reqTy of TyCon (TabPi tabPiTy) -> checkTabCon tabPiTy xs _ -> throw TypeErr $ "Unexpected table constructor. Expected: " ++ pprint reqTy - UNatLit x -> do - let litVal = Con $ Lit $ Word64Lit $ fromIntegral x - applyFromLiteralMethod reqTy "from_unsigned_integer" litVal - UIntLit x -> do - let litVal = Con $ Lit $ Int64Lit $ fromIntegral x - applyFromLiteralMethod reqTy "from_integer" litVal + UNatLit x -> fromNatLit x reqTy + UIntLit x -> fromIntLit x reqTy UPrim UTuple xs -> case reqTy of TyKind -> toAtom . ProdType <$> mapM checkUType xs TyCon (ProdType reqTys) -> do @@ -554,8 +550,8 @@ bottomUpExplicit (WithSrcE pos expr) = addSrcContext pos case expr of _ -> return $ toAtom v x' -> return x' liftM (SigmaAtom Nothing) $ matchPrimApp prim xs' - UNatLit _ -> throw TypeErr $ "Can't infer type of literal. Try an explicit annotation" - UIntLit _ -> throw TypeErr $ "Can't infer type of literal. Try an explicit annotation" + UNatLit l -> liftM (SigmaAtom Nothing) $ fromNatLit l NatTy + UIntLit l -> liftM (SigmaAtom Nothing) $ fromIntLit l (BaseTy $ Scalar Int32Type) UFloatLit x -> return $ SigmaAtom Nothing $ Con $ Lit $ Float32Lit $ realToFrac x UHole -> throw TypeErr "Can't infer value of hole" @@ -566,6 +562,16 @@ expectEq reqTy actualTy = alphaEq reqTy actualTy >>= \case "\nActual: " ++ pprint actualTy {-# INLINE expectEq #-} +fromIntLit :: Emits o => Int -> CType o -> InfererM i o (CAtom o) +fromIntLit x ty = do + let litVal = Con $ Lit $ Int64Lit $ fromIntegral x + applyFromLiteralMethod ty "from_integer" litVal + +fromNatLit :: Emits o => Word64 -> CType o -> InfererM i o (CAtom o) +fromNatLit x ty = do + let litVal = Con $ Lit $ Word64Lit $ fromIntegral x + applyFromLiteralMethod ty "from_unsigned_integer" litVal + matchReq :: Ext o o' => RequiredTy o -> CAtom o' -> InfererM i o' (CAtom o') matchReq (Check reqTy) x = do reqTy' <- sinkM reqTy From 0b11ea9aa41170e485d20322f6ad5f2eba745bcc Mon Sep 17 00:00:00 2001 From: Dougal Date: Thu, 16 Nov 2023 20:58:14 -0500 Subject: [PATCH 17/41] Fix some unimplemented cases --- src/lib/Builder.hs | 14 ++++++++++++++ src/lib/CheapReduction.hs | 3 ++- src/lib/Optimize.hs | 1 + src/lib/Vectorize.hs | 2 ++ 4 files changed, 19 insertions(+), 1 deletion(-) diff --git a/src/lib/Builder.hs b/src/lib/Builder.hs index caebe262f..55a7e71ce 100644 --- a/src/lib/Builder.hs +++ b/src/lib/Builder.hs @@ -1506,6 +1506,20 @@ visitLamNoEmits visitLamNoEmits (LamExpr bs body) = visitBinders bs \bs' -> LamExpr bs' <$> visitExprNoEmits body +visitDeclsNoEmits + :: (ExprVisitorNoEmits2 m r, IRRep r, AtomSubstReader v m, EnvExtender2 m) + => Nest (Decl r) i i' + -> (forall o'. DExt o o' => Nest (Decl r) o o' -> m i' o' a) + -> m i o a +visitDeclsNoEmits Empty cont = getDistinct >>= \Distinct -> cont Empty +visitDeclsNoEmits (Nest (Let b (DeclBinding ann expr)) decls) cont = do + expr' <- visitExprNoEmits expr + withFreshBinder (getNameHint b) (getType expr') \(b':>_) -> do + let decl' = Let b' $ DeclBinding ann expr' + extendRenamer (b@>binderName b') do + visitDeclsNoEmits decls \decls' -> + cont $ Nest decl' decls' + -- === Emitting expression visitor === class Visitor m r i o => ExprVisitorEmits m r i o | m -> i, m -> o where diff --git a/src/lib/CheapReduction.hs b/src/lib/CheapReduction.hs index c70815282..b9bc0ce01 100644 --- a/src/lib/CheapReduction.hs +++ b/src/lib/CheapReduction.hs @@ -35,6 +35,7 @@ import Types.Core import Types.Imp import Types.Primitives import Util +import GHC.Stack -- Carry out the reductions we are willing to carry out during type -- inference. The goal is to support type aliases like `Int = Int32` @@ -341,7 +342,7 @@ class NonAtomRenamer m i o => Visitor m r i o | m -> i, m -> o where visitPi :: PiType r i -> m (PiType r o) class VisitGeneric (e:: E) (r::IR) | e -> r where - visitGeneric :: Visitor m r i o => e i -> m (e o) + visitGeneric :: HasCallStack => Visitor m r i o => e i -> m (e o) type Visitor2 (m::MonadKind2) r = forall i o . Visitor (m i o) r i o diff --git a/src/lib/Optimize.hs b/src/lib/Optimize.hs index b1714fd62..425291cd4 100644 --- a/src/lib/Optimize.hs +++ b/src/lib/Optimize.hs @@ -170,6 +170,7 @@ licmExpr = \case block <- mkBlock =<< applyRename (lnb@>binderName i) bodyAbs return $ UnaryLamExpr i block emitHof $ For dir ix' body' + Block _ (Abs decls result) -> visitDeclsEmits decls $ licmExpr result expr -> visitGeneric expr >>= emit seqLICM :: RNest SDecl n1 n2 -- hoisted decls diff --git a/src/lib/Vectorize.hs b/src/lib/Vectorize.hs index 189fe03dc..eafefb538 100644 --- a/src/lib/Vectorize.hs +++ b/src/lib/Vectorize.hs @@ -619,6 +619,8 @@ instance ExprVisitorNoEmits (CalcWidthM i o) SimpIR i o where let ty = getType expr' modify (\(LiftE x) -> LiftE $ min (typeByteWidth ty) x) return expr' + Block _ (Abs decls result) -> mkBlock =<< visitDeclsNoEmits decls \decls' -> do + Abs decls' <$> visitExprNoEmits result _ -> fallback where fallback = visitGeneric expr From 9b1d22ec3a393513cc93a793e2f1bcbfa9413dc8 Mon Sep 17 00:00:00 2001 From: Dougal Date: Thu, 16 Nov 2023 23:03:00 -0500 Subject: [PATCH 18/41] Add line numbers to live view. It's a bit janky but it's enough to help me update tests and examples. --- src/lib/Live/Eval.hs | 3 ++- static/index.js | 20 +++++++++++++------- static/style.css | 23 ++++++++++++++++++++--- 3 files changed, 35 insertions(+), 11 deletions(-) diff --git a/src/lib/Live/Eval.hs b/src/lib/Live/Eval.hs index 4912a8367..a8c449a8c 100644 --- a/src/lib/Live/Eval.hs +++ b/src/lib/Live/Eval.hs @@ -304,7 +304,8 @@ instance ToJSON a => ToJSON (MapEltUpdate a) instance ToJSON o => ToJSON (NodeEvalStatus o) instance (ToJSON i, ToJSON o) => ToJSON (NodeState i o) -instance ToJSON SourceBlock where toJSON = toJSONViaHtml +instance ToJSON SourceBlock where + toJSON b = toJSON (sbLine b, pprintHtml b) instance ToJSON Result where toJSON = toJSONViaHtml toJSONViaHtml :: ToMarkup a => a -> Value diff --git a/static/index.js b/static/index.js index 65c3a1cc7..1d2c36a5a 100644 --- a/static/index.js +++ b/static/index.js @@ -197,16 +197,23 @@ function render(renderMode) { } function set_cell_contents(cell, contents) { - var source_text = contents[0]; - cell.innerHTML = source_text + var line_num = contents[0][0]; + var source_text = contents[0][1]; + var line_num_div = document.createElement("div"); + + line_num_div.innerHTML = line_num.toString(); + line_num_div.className = "line-num"; + cell.innerHTML = "" + cell.appendChild(line_num_div); + cell.innerHTML += source_text var results = contents[1]; tag = results["tag"] if (tag == "Waiting") { - cell.className = "waiting-cell"; + cell.className = "cell waiting-cell"; } else if (tag == "Running") { - cell.className = "running-cell"; + cell.className = "cell running-cell"; } else if (tag == "Complete") { - cell.className = "complete-cell"; + cell.className = "cell complete-cell"; cell.innerHTML += results["contents"] } else { console.error(tag); @@ -222,8 +229,7 @@ function process_update(msg) { // drop_dead_cells for (i = 0; i < num_dropped; i++) { - body.lastElementChild.remove(); - } + body.lastElementChild.remove();} Object.keys(cell_updates).forEach(function (node_id) { var update = cell_updates[node_id]; diff --git a/static/style.css b/static/style.css index a6a36fe52..0383c9578 100644 --- a/static/style.css +++ b/static/style.css @@ -11,7 +11,6 @@ body { display: flex; justify-content: space-between; overflow-x: hidden; - --main-width: 50rem; --nav-width: 20rem; } @@ -47,7 +46,12 @@ nav ol { margin: auto; } + +.code-block { +} + .code-block, .err-block, .result-block { + margin: 0em 0em 0em 4em; padding: 0em 0em 0em 2em; display: block; font-family: monospace; @@ -104,12 +108,25 @@ code { color: #25BBA7; } +.cell { +} + +.line-num { + display: block; + font-family: monospace; + width: 3em; + color: #808080; + float: left; + text-align: right; +} + + .waiting-cell { - background-color: #DDDDFF; + border-left: 6px solid #AAAAFF; } .running-cell { - background-color: #DDFFDD; + border-left: 6px solid #AAFFAA; } .complete-cell { From 83b0cd0107f89ef854e1536edd2a2325fa973b60 Mon Sep 17 00:00:00 2001 From: Dougal Date: Fri, 17 Nov 2023 23:29:46 -0500 Subject: [PATCH 19/41] Add lexeme IDs during parsing rather than computing them from source spans later. --- dex.cabal | 2 +- src/lib/ConcreteSyntax.hs | 15 +-- src/lib/Lexing.hs | 69 ++++++++--- src/lib/RenderHtml.hs | 149 ++++++++---------------- src/lib/SourceInfo.hs | 213 +++------------------------------- src/lib/TopLevel.hs | 2 +- src/lib/TraverseSourceInfo.hs | 127 -------------------- src/lib/Types/Source.hs | 49 ++++++-- 8 files changed, 163 insertions(+), 463 deletions(-) delete mode 100644 src/lib/TraverseSourceInfo.hs diff --git a/dex.cabal b/dex.cabal index 157158b66..f7bb1bbdf 100644 --- a/dex.cabal +++ b/dex.cabal @@ -91,7 +91,6 @@ library , SourceRename , TopLevel , Transpose - , TraverseSourceInfo , Types.Core , Types.Imp , Types.Misc @@ -143,6 +142,7 @@ library if flag(live) build-depends: binary , blaze-html + , blaze-markup , cmark , http-types , wai diff --git a/src/lib/ConcreteSyntax.hs b/src/lib/ConcreteSyntax.hs index 2bae7ea7a..2e5aebcf0 100644 --- a/src/lib/ConcreteSyntax.hs +++ b/src/lib/ConcreteSyntax.hs @@ -61,7 +61,7 @@ parseUModule name s = do {-# SCC parseUModule #-} preludeImportBlock :: SourceBlock -preludeImportBlock = SourceBlock 0 0 LogNothing "" $ Misc $ ImportModule Prelude +preludeImportBlock = SourceBlock 0 0 LogNothing "" mempty (Misc $ ImportModule Prelude) sourceBlocks :: Parser [SourceBlock] sourceBlocks = manyTill (sourceBlock <* outputLines) eof @@ -108,11 +108,12 @@ sourceBlock :: Parser SourceBlock sourceBlock = do offset <- getOffset pos <- getSourcePos - (src, (level, b)) <- withSource $ withRecovery recover $ do + (src, (sm, (level, b))) <- withSource $ withSourceMaps $ withRecovery recover do level <- logLevel <|> logTime <|> logBench <|> return LogNothing b <- sourceBlock' return (level, b) - return $ SourceBlock (unPos (sourceLine pos)) offset level src b + let sm' = sm { lexemeInfo = lexemeInfo sm <&> \(t, (l, r)) -> (t, (l-offset, r-offset))} + return $ SourceBlock (unPos (sourceLine pos)) offset level src sm' b recover :: ParseError Text Void -> Parser (LogLevel, SourceBlock') recover e = do @@ -154,7 +155,7 @@ consumeTillBreak = void $ manyTill anySingle $ eof <|> void (try (eol >> eol)) logLevel :: Parser LogLevel logLevel = do - void $ try $ lexeme $ char '%' >> string "passes" + void $ try $ lexeme MiscLexeme $ char '%' >> string "passes" passes <- many passName eol case passes of @@ -163,13 +164,13 @@ logLevel = do logTime :: Parser LogLevel logTime = do - void $ try $ lexeme $ char '%' >> string "time" + void $ try $ lexeme MiscLexeme $ char '%' >> string "time" eol return PrintEvalTime logBench :: Parser LogLevel logBench = do - void $ try $ lexeme $ char '%' >> string "bench" + void $ try $ lexeme MiscLexeme $ char '%' >> string "bench" benchName <- strLit eol return $ PrintBench benchName @@ -391,7 +392,7 @@ immediateLParen = label "'(' (without preceding whitespace)" do nextChar >>= \case '(' -> precededByWhitespace >>= \case True -> empty - False -> charLexeme '(' + False -> lParen _ -> empty immediateParens :: Parser a -> Parser a diff --git a/src/lib/Lexing.hs b/src/lib/Lexing.hs index 0f0fc3ddb..5854b743f 100644 --- a/src/lib/Lexing.hs +++ b/src/lib/Lexing.hs @@ -15,6 +15,7 @@ import Data.Text (Text) import Data.Text qualified as T import Data.Void import Data.Word +import qualified Data.Map.Strict as M import Text.Megaparsec hiding (Label, State) import Text.Megaparsec.Char hiding (space, eol) @@ -25,15 +26,18 @@ import Text.Megaparsec.Debug import Err import SourceInfo import Types.Primitives +import Types.Source +import Util (toSnocList) data ParseCtx = ParseCtx { curIndent :: Int -- used Reader-style (i.e. ask/local) , canBreak :: Bool -- used Reader-style (i.e. ask/local) , prevWhitespace :: Bool -- tracks whether we just consumed whitespace - } + , sourceIdCounter :: Int + , curSourceMap :: SourceMaps } -- append to, writer-style initParseCtx :: ParseCtx -initParseCtx = ParseCtx 0 False False +initParseCtx = ParseCtx 0 False False 0 mempty type Parser = StateT ParseCtx (Parsec Void Text) @@ -64,7 +68,7 @@ nextChar = do {-# INLINE nextChar #-} anyCaseName :: Lexer SourceName -anyCaseName = label "name" $ lexeme $ +anyCaseName = label "name" $ lexeme LowerName $ -- TODO: distinguish lowercase/uppercase checkNotKeyword $ (:) <$> satisfy (\c -> isLower c || isUpper c) <*> (T.unpack <$> takeWhileP Nothing (\c -> isAlphaNum c || c == '\'' || c == '_')) @@ -121,7 +125,7 @@ keyWordToken = \case PassKW -> "pass" keyWord :: KeyWord -> Lexer () -keyWord kw = lexeme $ try $ string (fromString $ keyWordToken kw) +keyWord kw = lexeme Keyword $ try $ string (fromString $ keyWordToken kw) >> notFollowedBy nameTailChar keyWordSet :: HS.HashSet String @@ -131,19 +135,19 @@ keyWordStrs :: [String] keyWordStrs = map keyWordToken [DefKW .. PassKW] primName :: Lexer String -primName = lexeme $ try $ char '%' >> some alphaNumChar +primName = lexeme MiscLexeme $ try $ char '%' >> some alphaNumChar charLit :: Lexer Char -charLit = lexeme $ char '\'' >> L.charLiteral <* char '\'' +charLit = lexeme MiscLexeme $ char '\'' >> L.charLiteral <* char '\'' strLit :: Lexer String -strLit = lexeme $ char '"' >> manyTill L.charLiteral (char '"') +strLit = lexeme StringLiteralLexeme $ char '"' >> manyTill L.charLiteral (char '"') natLit :: Lexer Word64 -natLit = lexeme $ try $ L.decimal <* notFollowedBy (char '.') +natLit = lexeme LiteralLexeme $ try $ L.decimal <* notFollowedBy (char '.') doubleLit :: Lexer Double -doubleLit = lexeme $ +doubleLit = lexeme LiteralLexeme $ try L.float <|> try (fromIntegral <$> (L.decimal :: Parser Int) <* char '.') <|> try do @@ -161,22 +165,22 @@ knownSymStrs = HS.fromList -- string must be in `knownSymStrs` sym :: Text -> Lexer () -sym s = lexeme $ try $ string s >> notFollowedBy symChar +sym s = lexeme Symbol $ try $ string s >> notFollowedBy symChar anySym :: Lexer String -anySym = lexeme $ try $ do +anySym = lexeme Symbol $ try $ do s <- some symChar failIf (s `HS.member` knownSymStrs) "" return s symName :: Lexer SourceName -symName = label "symbol name" $ lexeme $ try $ do +symName = label "symbol name" $ lexeme Symbol $ try $ do s <- between (char '(') (char ')') $ some symChar return $ "(" <> s <> ")" backquoteName :: Lexer SourceName backquoteName = label "backquoted name" $ - lexeme $ try $ between (char '`') (char '`') anyCaseName + lexeme Symbol $ try $ between (char '`') (char '`') anyCaseName -- brackets and punctuation -- (can't treat as sym because e.g. `((` is two separate lexemes) @@ -192,7 +196,7 @@ semicolon = charLexeme ';' underscore = charLexeme '_' charLexeme :: Char -> Parser () -charLexeme c = void $ lexeme $ char c +charLexeme c = void $ lexeme Symbol $ char c nameTailChar :: Parser Char nameTailChar = alphaNumChar <|> char '\'' <|> char '_' @@ -243,10 +247,10 @@ recordNonWhitespace = modify \ctx -> ctx { prevWhitespace = False } {-# INLINE recordNonWhitespace #-} nameString :: Parser String -nameString = lexeme . try $ (:) <$> lowerChar <*> many alphaNumChar +nameString = lexeme LowerName . try $ (:) <$> lowerChar <*> many alphaNumChar thisNameString :: Text -> Parser () -thisNameString s = lexeme $ try $ string s >> notFollowedBy alphaNumChar +thisNameString s = lexeme MiscLexeme $ try $ string s >> notFollowedBy alphaNumChar bracketed :: Parser () -> Parser () -> Parser a -> Parser a bracketed left right p = between left right $ mayBreak $ sc >> p @@ -310,10 +314,37 @@ failIf :: Bool -> String -> Parser () failIf True s = fail s failIf False _ = return () -lexeme :: Parser a -> Parser a -lexeme p = L.lexeme sc (p <* recordNonWhitespace) +newSourceId :: Parser SourceId +newSourceId = do + c <- gets sourceIdCounter + modify \ctx -> ctx { sourceIdCounter = c + 1 } + return $ SourceId c + +withSourceMaps :: Parser a -> Parser (SourceMaps, a) +withSourceMaps cont = do + smPrev <- gets curSourceMap + modify \ctx -> ctx { curSourceMap = mempty } + result <- cont + sm <- gets curSourceMap + modify \ctx -> ctx { curSourceMap = smPrev } + return (sm, result) + +emitSourceMaps :: SourceMaps -> Parser () +emitSourceMaps m = modify \ctx -> ctx { curSourceMap = curSourceMap ctx <> m } + +lexeme :: LexemeType -> Parser a -> Parser a +lexeme lexemeType p = do + start <- getOffset + ans <- p + end <- getOffset + recordNonWhitespace + sc + name <- newSourceId + emitSourceMaps $ mempty + { lexemeList = toSnocList [name] + , lexemeInfo = M.singleton name (lexemeType, (start, end)) } + return ans {-# INLINE lexeme #-} symbol :: Text -> Parser () symbol s = void $ L.symbol sc s - diff --git a/src/lib/RenderHtml.hs b/src/lib/RenderHtml.hs index f87d2accb..3f192eb63 100644 --- a/src/lib/RenderHtml.hs +++ b/src/lib/RenderHtml.hs @@ -7,30 +7,26 @@ {-# OPTIONS_GHC -Wno-orphans #-} {-# OPTIONS_GHC -Wno-incomplete-patterns #-} -module RenderHtml (pprintHtml, progHtml, ToMarkup, treeToHtml) where +module RenderHtml (pprintHtml, progHtml, ToMarkup) where +import Text.Blaze.Internal (MarkupM) import Text.Blaze.Html5 as H hiding (map) import Text.Blaze.Html5.Attributes as At import Text.Blaze.Html.Renderer.String -import Data.List qualified as L +import qualified Data.Map.Strict as M +import Control.Monad.State.Strict +import Data.Maybe (fromJust) import Data.Text qualified as T import Data.Text.IO qualified as T import CMark (commonmarkToHtml) import System.IO.Unsafe -import Control.Monad -import Text.Megaparsec hiding (chunk) -import Text.Megaparsec.Char as C import Err -import Lexing (Parser, symChar, keyWordStrs, symbol, parseit, withSource) import Paths_dex (getDataFileName) import PPrint () -import SourceInfo -import TraverseSourceInfo import Types.Misc import Types.Source -import Util cssSource :: T.Text cssSource = unsafePerformIO $ @@ -79,9 +75,7 @@ instance ToMarkup Output where instance ToMarkup SourceBlock where toMarkup block = case sbContents block of (Misc (ProseBlock s)) -> cdiv "prose-block" $ mdToHtml s - TopDecl decl -> renderSpans decl block - Command _ g -> renderSpans g block - _ -> cdiv "code-block" $ highlightSyntax (sbText block) + _ -> renderSpans (sbSourceMaps block) (sbText block) mdToHtml :: T.Text -> Html mdToHtml s = preEscapedText $ commonmarkToHtml [] s @@ -89,94 +83,43 @@ mdToHtml s = preEscapedText $ commonmarkToHtml [] s cdiv :: String -> Html -> Html cdiv c inner = H.div inner ! class_ (stringValue c) --- === syntax highlighting === - -spanDelimitedCode :: SourceBlock -> [SrcPosCtx] -> Html -spanDelimitedCode block ctxs = - let (Just tree) = srcCtxsToTree block ctxs in - spanDelimitedCode' block tree - -spanDelimitedCode' :: SourceBlock -> SpanTree -> Html -spanDelimitedCode' block tree = treeToHtml (sbText block) tree - -treeToHtml :: T.Text -> SpanTree -> Html -treeToHtml source' tree = - let tree' = fillTreeAndAddTrivialLeaves (T.unpack source') tree in - treeToHtml' source' tree' - -treeToHtml' :: T.Text -> SpanTree -> Html -treeToHtml' source' tree = case tree of - Span (_, _, _) children -> - let body' = foldMap (treeToHtml' source') children in - H.span body' ! spanClass - LeafSpan (l, r, _) -> - let spanText = sliceText l r source' in - H.span (highlightSyntax spanText) ! spanLeaf - Trivia (l, r) -> - let spanText = sliceText l r source' in - highlightSyntax spanText - where - spanClass :: Attribute - spanClass = At.class_ "code-span" - - spanLeaf :: Attribute - spanLeaf = At.class_ "code-span-leaf" - -srcCtxsToSpanInfos :: SourceBlock -> [SrcPosCtx] -> [SpanPayload] -srcCtxsToSpanInfos block ctxs = - let blockOffset = sbOffset block in - let ctxs' = L.sort ctxs in - (0, maxBound, 0) : mapMaybe (convert' blockOffset) ctxs' - where convert' :: Int -> SrcPosCtx -> Maybe SpanPayload - convert' offset (SrcPosCtx (Just (l, r)) (Just spanId)) = Just (l - offset, r - offset, spanId + 1) - convert' _ _ = Nothing - -srcCtxsToTree :: SourceBlock -> [SrcPosCtx] -> Maybe SpanTree -srcCtxsToTree block ctxs = makeEmptySpanTree (srcCtxsToSpanInfos block ctxs) - -renderSpans :: HasSourceInfo a => a -> SourceBlock -> Html -renderSpans x block = - let x' = addSpanIds x in - let ctxs = gatherSourceInfo x' in - toHtml $ cdiv "code-block" $ spanDelimitedCode block ctxs - -highlightSyntax :: T.Text -> Html -highlightSyntax s = foldMap (uncurry syntaxSpan) classified - where classified = ignoreExcept $ parseit s (many (withSource classify) <* eof) - -syntaxSpan :: T.Text -> StrClass -> Html -syntaxSpan s NormalStr = toHtml s -syntaxSpan s c = H.span (toHtml s) ! class_ (stringValue className) - where - className = case c of - CommentStr -> "comment" - KeywordStr -> "keyword" - CommandStr -> "command" - SymbolStr -> "symbol" - TypeNameStr -> "type-name" - IsoSugarStr -> "iso-sugar" - WhitespaceStr -> "whitespace" - NormalStr -> error "Should have been matched already" - -data StrClass = NormalStr - | CommentStr | KeywordStr | CommandStr | SymbolStr | TypeNameStr - | IsoSugarStr | WhitespaceStr - -classify :: Parser StrClass -classify = - (try (char ':' >> lowerWord) >> return CommandStr) - <|> (symbol "-- " >> manyTill anySingle (void eol <|> eof) >> return CommentStr) - <|> (do s <- lowerWord - return $ if s `elem` keyWordStrs then KeywordStr else NormalStr) - <|> (upperWord >> return TypeNameStr) - <|> try (char '#' >> (char '?' <|> char '&' <|> char '|' <|> pure ' ') - >> lowerWord >> return IsoSugarStr) - <|> (some symChar >> return SymbolStr) - <|> (some space1 >> return WhitespaceStr) - <|> (anySingle >> return NormalStr) - -lowerWord :: Parser String -lowerWord = (:) <$> lowerChar <*> many alphaNumChar - -upperWord :: Parser String -upperWord = (:) <$> upperChar <*> many alphaNumChar +renderSpans :: SourceMaps -> T.Text -> Markup +renderSpans sm sourceText = cdiv "code-block" do + runTextWalkerT sourceText do + forM_ (lexemeList sm) \sourceId -> do + let (lexemeTy, (l, r)) = fromJust $ M.lookup sourceId (lexemeInfo sm) + takeTo l >>= emitSpan "" + takeTo r >>= emitSpan (lexemeClass lexemeTy) + takeRest >>= emitSpan "" + +emitSpan :: String -> T.Text -> TextWalker () +emitSpan className t = lift $ H.span (toHtml t) ! class_ (stringValue className) + +lexemeClass :: LexemeType -> String +lexemeClass = \case + Keyword -> "keyword" + Symbol -> "symbol" + TypeName -> "type-name" + LowerName -> "" + UpperName -> "" + LiteralLexeme -> "literal" + StringLiteralLexeme -> "" + MiscLexeme -> "" + +type TextWalker a = StateT (Int, T.Text) MarkupM a + +runTextWalkerT :: T.Text -> TextWalker a -> MarkupM a +runTextWalkerT t cont = evalStateT cont (0, t) + +-- index is the *absolute* index, from the very beginning +takeTo :: Int -> TextWalker T.Text +takeTo startPos = do + (curPos, curText) <- get + let (prefix, remText) = T.splitAt (startPos- curPos) curText + put (startPos, remText) + return prefix + +takeRest :: TextWalker T.Text +takeRest = do + endPos <- gets $ T.length . snd + takeTo endPos diff --git a/src/lib/SourceInfo.hs b/src/lib/SourceInfo.hs index 6079fe5ff..b768af81f 100644 --- a/src/lib/SourceInfo.hs +++ b/src/lib/SourceInfo.hs @@ -7,42 +7,21 @@ {-# OPTIONS_GHC -Wno-incomplete-patterns #-} module SourceInfo ( - SrcPos, SpanId, SrcPosCtx (..), emptySrcPosCtx, fromPos, - pattern EmptySrcPosCtx, - sliceText, SpanTree (..), SpanTreeM (..), SpanPayload, SpanPos, - evalSpanTree, makeSpanTree, makeEmptySpanTree, makeSpanTreeRec, - fixSpanPayloads, - fillTreeAndAddTrivialLeaves - ) where + SrcPos, SourceId (..), SrcPosCtx (..), emptySrcPosCtx, fromPos, + pattern EmptySrcPosCtx) where -import Data.Data import Data.Hashable -import Data.Char (isSpace) -import Data.List (findIndex) -import Data.Maybe (listToMaybe, maybeToList) import Data.Store (Store (..)) -import qualified Data.Text as T import GHC.Generics (Generic (..)) -import Control.Applicative -import Control.Monad.State.Strict -- === Core API === -type SrcPos = (Int, Int) -type SpanId = Int +newtype SourceId = SourceId Int deriving (Show, Eq, Ord, Generic) -data SrcPosCtx = SrcPosCtx (Maybe SrcPos) (Maybe SpanId) - deriving (Show, Eq, Generic, Data) -instance Hashable SrcPosCtx -instance Store SrcPosCtx +type SrcPos = (Int, Int) -instance Ord SrcPosCtx where - compare (SrcPosCtx pos spanId) (SrcPosCtx pos' spanId') = - case (pos, pos') of - (Just (l, r), Just (l', r')) -> compare (l, r', spanId) (l', r, spanId') - (Just _, _) -> GT - (_, Just _) -> LT - (_, _) -> compare spanId spanId' +data SrcPosCtx = SrcPosCtx (Maybe SrcPos) (Maybe SourceId) + deriving (Show, Eq, Generic) emptySrcPosCtx :: SrcPosCtx emptySrcPosCtx = SrcPosCtx Nothing Nothing @@ -53,174 +32,16 @@ pattern EmptySrcPosCtx = SrcPosCtx Nothing Nothing fromPos :: SrcPos -> SrcPosCtx fromPos pos = SrcPosCtx (Just pos) Nothing --- === Span utilities === - -type SpanPayload = (Int, Int, SpanId) -type SpanPos = (Int, Int) - -data SpanTree = - Span SpanPayload [SpanTree] | - LeafSpan SpanPayload | - Trivia SpanPos - deriving (Show, Eq) - -newtype SpanTreeM a = SpanTreeM - { runSpanTree' :: StateT [SpanPayload] Maybe a } - deriving (Functor, Applicative, Monad, MonadState [SpanPayload], Alternative) - -evalSpanTree :: SpanTreeM a -> [SpanPayload] -> Maybe a -evalSpanTree m spans = evalStateT (runSpanTree' m) spans - -getNextSpanPayload :: SpanTreeM (Maybe SpanPayload) -getNextSpanPayload = SpanTreeM $ do - infos <- get - case infos of - [] -> return Nothing - x:xs -> put xs >> return (Just x) - -data SpanContained = Contained | NotContained | PartialOverlap - deriving (Show, Eq) - --- | @contained x y@ returns whether @y@ is contained in @x@. -spanContained :: SpanPayload -> SpanPayload -> SpanContained -spanContained (lpos, rpos, _) (lpos', rpos', _) = - case (lpos <= lpos', rpos >= rpos') of - (True, True) -> Contained - (False, False) -> NotContained - (_, _) -> if rpos <= lpos' - then NotContained - else PartialOverlap - --- | @makeSpanTreeRec x@ returns a @[SpanTree]@ with the children of @x@. -getSpanChildren :: SpanPayload -> SpanTreeM (Maybe [SpanTree]) -getSpanChildren root = do - getNextSpanPayload >>= \case - Just child -> do - case spanContained root child of - -- If `child` is contained in `root`, then we add it as a child. - Contained -> do - childTree <- makeSpanTreeRec child - remainingChildren <- getSpanChildren root - return $ Just (maybeToList childTree ++ concat (maybeToList remainingChildren)) - NotContained -> do infos <- get; put (child : infos); return $ Just [] - PartialOverlap -> do infos <- get; put (child : infos); return $ Just [] - Nothing -> return $ Just [] - --- | @makeSpanTreeRec x@ returns a @SpanTree@ with the @x@ as the root. -makeSpanTreeRec :: SpanPayload -> SpanTreeM (Maybe SpanTree) -makeSpanTreeRec root = do - children <- getSpanChildren root - case children of - Nothing -> return Nothing - Just [] -> return $ Just (LeafSpan root) - Just xs -> return $ Just (Span root xs) - -makeEmptySpanTree :: [SpanPayload] -> Maybe SpanTree -makeEmptySpanTree [] = Nothing -makeEmptySpanTree (root:children) = join $ evalSpanTree (makeSpanTreeRec root) children - -makeSpanTree :: (Show a, IsTrivia a) => [a] -> [SpanPayload] -> Maybe SpanTree -makeSpanTree xs infos = case makeEmptySpanTree infos of - Nothing -> Nothing - Just posTree -> Just (fillTreeAndAddTrivialLeaves xs posTree) - -slice :: Int -> Int -> [a] -> [a] -slice left right xs = take (right - left) (drop left xs) - -sliceText :: Int -> Int -> T.Text -> T.Text -sliceText left right xs = T.take (right - left) (T.drop left xs) - -getSpanPos :: SpanTree -> SpanPos -getSpanPos tree = case tree of - Span (l, r, _) _ -> (l, r) - LeafSpan (l, r, _) -> (l, r) - Trivia pos -> pos - -fillTrivia :: SpanPayload -> [SpanTree] -> [SpanTree] -fillTrivia (l, r, _) offsets = - let (before, after) = case offsets of - [] -> ([], []) - _ -> - let (headL, _) = getSpanPos (head offsets) in - let (_, tailR) = getSpanPos (last offsets) in - let before' = [Trivia (l, headL) | l /= headL] in - let after' = [Trivia (tailR, r) | r /= tailR] in - (before', after') in - let offsets' = before ++ offsets ++ after in - let pairs = zip offsets' (drop 1 offsets') in - let unzipped = pairs >>= getOffsetAndTrivia in - maybeToList (listToMaybe offsets') ++ unzipped - where getOffsetAndTrivia :: (SpanTree, SpanTree) -> [SpanTree] - getOffsetAndTrivia (t, t') = - let (_, r') = endpoints t in - let (l', _) = endpoints t' in - let diff = l' - r' in - if diff == 0 then - [t'] - else - [Trivia (r', l'), t'] - -fixSpanPayloads :: [SpanPayload] -> [SpanPayload] -fixSpanPayloads spans = - let pairs = zip spans (drop 1 spans) in - let unzipped = pairs >>= mergeSpans in - unzipped ++ [last spans] - where mergeSpans :: (SpanPayload, SpanPayload) -> [SpanPayload] - mergeSpans (s, s') = case spanContained s s' of - Contained -> [s] - NotContained -> [s] - -- Note: currently, overlapping spans are simply dropped. - -- Consider replacing with approach that preserves partial span info. - PartialOverlap -> [] - -rebalanceTrivia :: Show a => (a -> Bool) -> [a] -> [SpanTree] -> [SpanTree] -rebalanceTrivia trivia xs trees = - let whitespaceSeparated = trees >>= createTrivia in - whitespaceSeparated - where - createTrivia :: SpanTree -> [SpanTree] - createTrivia t = case t of - Span _ _ -> [t] - LeafSpan _ -> blah - Trivia _ -> blah - where blah :: [SpanTree] - blah = - let (l, r) = endpoints t in - let s' = slice l r xs in - let firstNonTrivia = findIndex (not . trivia) s' in - let lastNonTrivia = fmap (length s' -) (findIndex (not . trivia) (reverse s')) in - case (firstNonTrivia, lastNonTrivia) of - (Just l', Nothing) | l' > 0 -> [Trivia (l, l + l'), shiftTree (l + l', r) t] - (Nothing, Just r') | r' < length s' -> [shiftTree (l, l + r') t, Trivia (l + r', r)] - (Just l', Just r') | l' > 0 || r' < length s' -> - [Trivia (l, l + l'), shiftTree (l + l', l + r') t, Trivia (l + r', r)] - (_, _) -> [t] - - -- - shiftTree :: SpanPos -> SpanTree -> SpanTree - shiftTree (l', r') t = case t of - Span (_, _, i) children -> Span (l', r', i) children - LeafSpan (_, _, i) -> LeafSpan (l', r', i) - Trivia _ -> Trivia (l', r') - -endpoints :: SpanTree -> (Int, Int) -endpoints (Span (l, r, _) _) = (l, r) -endpoints (LeafSpan (l, r, _)) = (l, r) -endpoints (Trivia (l, r)) = (l, r) - -class IsTrivia a where - isTrivia :: a -> Bool +instance Ord SrcPosCtx where + compare (SrcPosCtx pos spanId) (SrcPosCtx pos' spanId') = + case (pos, pos') of + (Just (l, r), Just (l', r')) -> compare (l, r', spanId) (l', r, spanId') + (Just _, _) -> GT + (_, Just _) -> LT + (_, _) -> compare spanId spanId' -instance IsTrivia Char where - isTrivia = isSpace +instance Hashable SourceId +instance Hashable SrcPosCtx --- | Fills a @SpanTree@ with @Trivia@ in span gaps. -fillTreeAndAddTrivialLeaves :: Show a => IsTrivia a => [a] -> SpanTree -> SpanTree -fillTreeAndAddTrivialLeaves xs tree = case tree of - Span info children -> - let children' = fillTrivia info children in - let children'' = rebalanceTrivia isTrivia xs children' in - let filled = map (fillTreeAndAddTrivialLeaves xs) children'' in - Span info filled - LeafSpan _ -> tree - Trivia _ -> tree +instance Store SourceId +instance Store SrcPosCtx diff --git a/src/lib/TopLevel.hs b/src/lib/TopLevel.hs index fb7e566a9..c56423212 100644 --- a/src/lib/TopLevel.hs +++ b/src/lib/TopLevel.hs @@ -206,7 +206,7 @@ catchLogsAndErrs m = do evalSourceBlockRepl :: (Topper m, Mut n) => SourceBlock -> m n Result evalSourceBlockRepl block = do case block of - SourceBlock _ _ _ _ (Misc (ImportModule name)) -> do + SourceBlock _ _ _ _ _ (Misc (ImportModule name)) -> do -- TODO: clear source map and synth candidates before calling this ensureModuleLoaded name _ -> return () diff --git a/src/lib/TraverseSourceInfo.hs b/src/lib/TraverseSourceInfo.hs deleted file mode 100644 index 1265fbf51..000000000 --- a/src/lib/TraverseSourceInfo.hs +++ /dev/null @@ -1,127 +0,0 @@ --- Copyright 2022 Google LLC --- --- Use of this source code is governed by a BSD-style --- license that can be found in the LICENSE file or at --- https://developers.google.com/open-source/licenses/bsd - -{-# OPTIONS_GHC -Wno-incomplete-patterns #-} - -module TraverseSourceInfo (HasSourceInfo, gatherSourceInfo, addSpanIds) where - -import qualified Data.ByteString as BS -import Control.Monad.State -import Control.Monad.Writer -import GHC.Generics -import GHC.Int -import GHC.Word - -import Occurrence qualified as Occ -import SourceInfo -import Types.OpNames qualified as P -import Types.Primitives -import Types.Source - -class HasSourceInfo a where - traverseSourceInfo :: Applicative m => (SrcPosCtx -> m SrcPosCtx) -> a -> m a - - default traverseSourceInfo :: (Applicative m, Generic a, HasSourceInfo (Rep a Any)) => (SrcPosCtx -> m SrcPosCtx) -> a -> m a - traverseSourceInfo f x = to <$> traverseSourceInfo f (from x :: Rep a Any) - -tc :: HasSourceInfo a => Applicative m => (SrcPosCtx -> m SrcPosCtx) -> a -> m a -tc = traverseSourceInfo - -instance HasSourceInfo (V1 p) where - traverseSourceInfo _ x = pure x - -instance HasSourceInfo (U1 p) where - traverseSourceInfo _ x = pure x - -instance (HasSourceInfo c) => HasSourceInfo (K1 i c p) where - traverseSourceInfo f (K1 x) = K1 <$> traverseSourceInfo f x - -instance HasSourceInfo (f p) => HasSourceInfo (M1 i c f p) where - traverseSourceInfo f (M1 x) = M1 <$> traverseSourceInfo f x - -instance (HasSourceInfo (a p), HasSourceInfo (b p)) => HasSourceInfo ((a :+: b) p) where - traverseSourceInfo f (L1 x) = L1 <$> traverseSourceInfo f x - traverseSourceInfo f (R1 x) = R1 <$> traverseSourceInfo f x - -instance (HasSourceInfo (a p), HasSourceInfo (b p)) => HasSourceInfo ((a :*: b) p) where - traverseSourceInfo f (a :*: b) = (:*:) <$> traverseSourceInfo f a <*> traverseSourceInfo f b - -instance HasSourceInfo P.TC -instance HasSourceInfo P.Con -instance HasSourceInfo P.MemOp -instance HasSourceInfo P.VectorOp -instance HasSourceInfo P.MiscOp -instance HasSourceInfo PrimName -instance HasSourceInfo UnOp -instance HasSourceInfo BinOp -instance HasSourceInfo CmpOp -instance HasSourceInfo BaseType -instance HasSourceInfo ScalarBaseType -instance HasSourceInfo Device - -instance (HasSourceInfo a, HasSourceInfo b) => HasSourceInfo (a, b) -instance (HasSourceInfo a, HasSourceInfo b, HasSourceInfo c) => HasSourceInfo (a, b, c) -instance (HasSourceInfo a, HasSourceInfo b) => HasSourceInfo (Either a b) -instance HasSourceInfo a => HasSourceInfo [a] -instance HasSourceInfo a => HasSourceInfo (Maybe a) - -instance HasSourceInfo Occ.Count -instance HasSourceInfo Occ.UsageInfo -instance HasSourceInfo LetAnn -instance HasSourceInfo UResumePolicy -instance HasSourceInfo CInstanceDef -instance HasSourceInfo CTopDecl' - -instance HasSourceInfo AppExplicitness -instance HasSourceInfo CDef -instance HasSourceInfo CSDecl' -instance HasSourceInfo CSBlock -instance HasSourceInfo ForKind -instance HasSourceInfo Group' - -instance HasSourceInfo Bin' - -instance HasSourceInfo a => HasSourceInfo (WithSrc a) where - traverseSourceInfo f (WithSrc pos x) = WithSrc <$> f pos <*> tc f x - -instance HasSourceInfo () where - traverseSourceInfo _ x = pure x -instance HasSourceInfo Char where - traverseSourceInfo _ x = pure x -instance HasSourceInfo Int where - traverseSourceInfo _ x = pure x -instance HasSourceInfo Int32 where - traverseSourceInfo _ x = pure x -instance HasSourceInfo Int64 where - traverseSourceInfo _ x = pure x -instance HasSourceInfo Word8 where - traverseSourceInfo _ x = pure x -instance HasSourceInfo Word16 where - traverseSourceInfo _ x = pure x -instance HasSourceInfo Word32 where - traverseSourceInfo _ x = pure x -instance HasSourceInfo Word64 where - traverseSourceInfo _ x = pure x -instance HasSourceInfo Float where - traverseSourceInfo _ x = pure x -instance HasSourceInfo Double where - traverseSourceInfo _ x = pure x -instance HasSourceInfo BS.ByteString where - traverseSourceInfo _ x = pure x - --- The real base case. -instance HasSourceInfo SrcPosCtx where - traverseSourceInfo f x = f x - -gatherSourceInfo :: (HasSourceInfo a) => a -> [SrcPosCtx] -gatherSourceInfo x = execWriter (tc (\(ctx :: SrcPosCtx) -> tell [ctx] >> return ctx) x) - -addSpanIds :: (HasSourceInfo a) => a -> a -addSpanIds x = evalState (tc f x) 0 - where f (SrcPosCtx maybeSrcPos _) = do - currentId <- get - put (currentId + 1) - return (SrcPosCtx maybeSrcPos (Just currentId)) diff --git a/src/lib/Types/Source.hs b/src/lib/Types/Source.hs index fdb2a3eba..9212a4072 100644 --- a/src/lib/Types/Source.hs +++ b/src/lib/Types/Source.hs @@ -20,7 +20,6 @@ module Types.Source where -import Data.Data import Data.Hashable import Data.Foldable import qualified Data.Map.Strict as M @@ -37,7 +36,7 @@ import Name import qualified Types.OpNames as P import IRVariants import SourceInfo -import Util (File (..)) +import Util (File (..), SnocList) import Types.Primitives @@ -66,9 +65,40 @@ newtype SourceOrInternalName (c::C) (n::S) = SourceOrInternalName (SourceNameOr pattern SISourceName :: (n ~ VoidS) => SourceName -> SourceOrInternalName c n pattern SISourceName n = SourceOrInternalName (SourceName EmptySrcPosCtx n) -pattern SIInternalName :: SourceName -> Name c n -> Maybe SrcPos -> Maybe SpanId -> SourceOrInternalName c n +pattern SIInternalName :: SourceName -> Name c n -> Maybe SrcPos -> Maybe SourceId -> SourceOrInternalName c n pattern SIInternalName n a srcPos spanId = SourceOrInternalName (InternalName (SrcPosCtx srcPos spanId) n a) +-- === Source Info === + +-- This is just for syntax highlighting. It won't be needed if we have +-- a separate lexing pass where we have a complete lossless data type for +-- lexemes. +data LexemeType = + Keyword + | Symbol + | TypeName + | LowerName + | UpperName + | LiteralLexeme + | StringLiteralLexeme + | MiscLexeme + deriving (Show, Generic) + +type Span = (Int, Int) +data SourceMaps = SourceMaps + { lexemeList :: SnocList SourceId + , lexemeInfo :: M.Map SourceId (LexemeType, Span) + , astParent :: M.Map SourceId SourceId + , astChildren :: M.Map SourceId [SourceId]} + deriving (Show, Generic) + +instance Semigroup SourceMaps where + SourceMaps a b c d <> SourceMaps a' b' c' d' = + SourceMaps (a <> a') (b <> b') (c <> c') (d <> d') + +instance Monoid SourceMaps where + mempty = SourceMaps mempty mempty mempty mempty + -- === Concrete syntax === -- The grouping-level syntax of the source language @@ -393,7 +423,7 @@ data WithSrcE (a::E) (n::S) = WithSrcE SrcPosCtx (a n) deriving (Show, Generic) data WithSrcB (binder::B) (n::S) (l::S) = WithSrcB SrcPosCtx (binder n l) - deriving (Show, Data, Generic) + deriving (Show, Generic) class HasSrcPos a where srcPos :: a -> SrcPosCtx @@ -443,11 +473,12 @@ data UModule = UModule -- === top-level blocks === data SourceBlock = SourceBlock - { sbLine :: Int - , sbOffset :: Int - , sbLogLevel :: LogLevel - , sbText :: Text - , sbContents :: SourceBlock' } + { sbLine :: Int + , sbOffset :: Int + , sbLogLevel :: LogLevel + , sbText :: Text + , sbSourceMaps :: SourceMaps + , sbContents :: SourceBlock' } deriving (Show, Generic) type ReachedEOF = Bool From 14ada502f9fa9dcd95813d0c09f0718259e63025 Mon Sep 17 00:00:00 2001 From: Dougal Date: Mon, 20 Nov 2023 14:37:39 -0500 Subject: [PATCH 20/41] Replace source spans with IDs. These are better because we can hang more information off them. We can give an ID to e.g. the binder of a `Ix n` constraint that doesn't actually have a source span and then choose how to highlight later on depending on the situation. And we can use the IDs to model the graph of relationships between source components. --- dex.cabal | 1 - src/lib/AbstractSyntax.hs | 505 ++++++++++++++++++-------------------- src/lib/Builder.hs | 6 +- src/lib/CheapReduction.hs | 6 +- src/lib/CheckType.hs | 12 +- src/lib/ConcreteSyntax.hs | 489 ++++++++++++++++++------------------ src/lib/Core.hs | 5 - src/lib/Err.hs | 212 ++-------------- src/lib/Export.hs | 4 +- src/lib/Inference.hs | 56 ++--- src/lib/Lexing.hs | 121 +++++---- src/lib/Live/Web.hs | 2 +- src/lib/MTL1.hs | 52 +--- src/lib/Name.hs | 12 +- src/lib/PPrint.hs | 81 +++--- src/lib/QueryType.hs | 4 +- src/lib/QueryTypePure.hs | 2 +- src/lib/SourceInfo.hs | 47 ---- src/lib/SourceRename.hs | 58 +++-- src/lib/Subst.hs | 1 - src/lib/TopLevel.hs | 33 +-- src/lib/Types/Core.hs | 21 +- src/lib/Types/Source.hs | 322 ++++++++++++------------ src/lib/Vectorize.hs | 41 ++-- 24 files changed, 891 insertions(+), 1202 deletions(-) delete mode 100644 src/lib/SourceInfo.hs diff --git a/dex.cabal b/dex.cabal index f7bb1bbdf..f8eb49f41 100644 --- a/dex.cabal +++ b/dex.cabal @@ -87,7 +87,6 @@ library , Serialize , Simplify , Subst - , SourceInfo , SourceRename , TopLevel , Transpose diff --git a/src/lib/AbstractSyntax.hs b/src/lib/AbstractSyntax.hs index 367959987..bc1026a81 100644 --- a/src/lib/AbstractSyntax.hs +++ b/src/lib/AbstractSyntax.hs @@ -53,7 +53,6 @@ import Data.Functor import Data.Either import Data.Maybe (catMaybes) import Data.Set qualified as S -import Data.String (fromString) import Data.Text (Text) import ConcreteSyntax @@ -61,29 +60,28 @@ import Err import Name import PPrint () import Types.Primitives -import SourceInfo import Types.Source import qualified Types.OpNames as P import Util -- === Converting concrete syntax to abstract syntax === -parseExpr :: Fallible m => Group -> m (UExpr VoidS) +parseExpr :: Fallible m => GroupW -> m (UExpr VoidS) parseExpr e = liftSyntaxM $ expr e -parseDecl :: Fallible m => CTopDecl -> m (UTopDecl VoidS VoidS) +parseDecl :: Fallible m => CTopDeclW -> m (UTopDecl VoidS VoidS) parseDecl d = liftSyntaxM $ topDecl d parseBlock :: Fallible m => CSBlock -> m (UBlock VoidS) parseBlock b = liftSyntaxM $ block b liftSyntaxM :: Fallible m => SyntaxM a -> m a -liftSyntaxM cont = liftExcept $ runFallibleM cont +liftSyntaxM cont = liftExcept cont parseTopDeclRepl :: Text -> Maybe SourceBlock parseTopDeclRepl s = case sbContents b of UnParseable True _ -> Nothing - _ -> case runFallibleM (checkSourceBlockParses $ sbContents b) of + _ -> case checkSourceBlockParses $ sbContents b of Success _ -> Just b Failure _ -> Nothing where b = mustParseSourceBlock s @@ -91,7 +89,7 @@ parseTopDeclRepl s = case sbContents b of checkSourceBlockParses :: SourceBlock' -> SyntaxM () checkSourceBlockParses = \case - TopDecl (WithSrc _ (CSDecl ann (CExpr e)))-> do + TopDecl (WithSrcs _ _ (CSDecl ann (CExpr e)))-> do when (ann /= PlainLet) $ fail "Cannot annotate expressions" void $ expr e TopDecl d -> void $ topDecl d @@ -103,67 +101,68 @@ checkSourceBlockParses = \case -- === Converting concrete syntax to abstract syntax === -type SyntaxM = FallibleM +type SyntaxM = Except -topDecl :: CTopDecl -> SyntaxM (UTopDecl VoidS VoidS) -topDecl = dropSrc topDecl' where - topDecl' (CSDecl ann d) = ULocalDecl <$> decl ann (WithSrc emptySrcPosCtx d) - topDecl' (CData name tyConParams givens constructors) = do - tyConParams' <- aExplicitParams tyConParams +topDecl :: CTopDeclW -> SyntaxM (UTopDecl VoidS VoidS) +topDecl (WithSrcs sid sids topDecl') = case topDecl' of + CSDecl ann d -> ULocalDecl <$> decl ann (WithSrcs sid sids d) + CData name tyConParams givens constructors -> do + tyConParams' <- fromMaybeM tyConParams Empty aExplicitParams givens' <- aOptGivens givens constructors' <- forM constructors \(v, ps) -> do - ps' <- toNest <$> mapM (tyOptBinder Explicit) ps + ps' <- fromMaybeM ps Empty \(WithSrcs _ _ ps') -> + toNest <$> mapM (tyOptBinder Explicit) ps' return (v, ps') return $ UDataDefDecl - (UDataDef name (givens' >>> tyConParams') $ - map (\(name', cons) -> (name', UDataDefTrail cons)) constructors') - (fromString name) - (toNest $ map (fromString . fst) constructors') - topDecl' (CStruct name params givens fields defs) = do - params' <- aExplicitParams params + (UDataDef (withoutSrc name) (givens' >>> tyConParams') $ + map (\(name', cons) -> (withoutSrc name', UDataDefTrail cons)) constructors') + (fromSourceNameW name) + (toNest $ map (fromSourceNameW . fst) constructors') + CStruct name params givens fields defs -> do + params' <- fromMaybeM params Empty aExplicitParams givens' <- aOptGivens givens fields' <- forM fields \(v, ty) -> (v,) <$> expr ty methods <- forM defs \(ann, d) -> do - (methodName, lam) <- aDef d - return (ann, methodName, Abs (UBindSource emptySrcPosCtx "self") lam) - return $ UStructDecl (fromString name) (UStructDef name (givens' >>> params') fields' methods) - topDecl' (CInterface name params methods) = do + (WithSrc _ methodName, lam) <- aDef d + return (ann, methodName, Abs (WithSrcB sid (UBindSource "self")) lam) + return $ UStructDecl (fromSourceNameW name) (UStructDef (withoutSrc name) (givens' >>> params') fields' methods) + CInterface name params methods -> do params' <- aExplicitParams params (methodNames, methodTys) <- unzip <$> forM methods \(methodName, ty) -> do ty' <- expr ty - return (fromString methodName, ty') - return $ UInterface params' methodTys (fromString name) (toNest methodNames) - topDecl' (CInstanceDecl def) = aInstanceDef def + return (fromSourceNameW methodName, ty') + return $ UInterface params' methodTys (fromSourceNameW name) (toNest methodNames) + CInstanceDecl def -> aInstanceDef def -decl :: LetAnn -> CSDecl -> SyntaxM (UDecl VoidS VoidS) -decl ann = propagateSrcB \case +decl :: LetAnn -> CSDeclW -> SyntaxM (UDecl VoidS VoidS) +decl ann (WithSrcs sid _ d) = WithSrcB sid <$> case d of CLet binder rhs -> do (p, ty) <- patOptAnn binder ULet ann p ty <$> asExpr <$> block rhs CBind _ _ -> throw SyntaxErr "Arrow binder syntax <- not permitted at the top level, because the binding would have unbounded scope." CDefDecl def -> do (name, lam) <- aDef def - return $ ULet ann (fromString name) Nothing (ns $ ULam lam) + return $ ULet ann (fromSourceNameW name) Nothing (WithSrcE sid (ULam lam)) CExpr g -> UExprDecl <$> expr g CPass -> return UPass aInstanceDef :: CInstanceDef -> SyntaxM (UTopDecl VoidS VoidS) -aInstanceDef (CInstanceDef clName args givens methods instNameAndParams) = do - let clName' = fromString clName +aInstanceDef (CInstanceDef (WithSrc clNameId clName) args givens methods instNameAndParams) = do + let clName' = SourceName clNameId clName args' <- mapM expr args givens' <- aOptGivens givens methods' <- catMaybes <$> mapM aMethod methods case instNameAndParams of Nothing -> return $ UInstance clName' givens' args' methods' NothingB ImplicitApp - Just (instName, optParams) -> do - let instName' = JustB $ fromString instName + Just (WithSrc sid instName, optParams) -> do + let instName' = JustB $ WithSrcB sid $ UBindSource instName case optParams of Just params -> do params' <- aExplicitParams params return $ UInstance clName' (givens' >>> params') args' methods' instName' ExplicitApp Nothing -> return $ UInstance clName' givens' args' methods' instName' ImplicitApp -aDef :: CDef -> SyntaxM (SourceName, ULamExpr VoidS) +aDef :: CDef -> SyntaxM (SourceNameW, ULamExpr VoidS) aDef (CDef name params optRhs optGivens body) = do explicitParams <- explicitBindersOptAnn params let rhsDefault = (ExplicitApp, Nothing, Nothing) @@ -176,8 +175,8 @@ aDef (CDef name params optRhs optGivens body) = do body' <- block body return (name, ULamExpr allParams expl effs resultTy body') -stripParens :: Group -> Group -stripParens (WithSrc _ (CParens [g])) = stripParens g +stripParens :: GroupW -> GroupW +stripParens (WithSrcs _ _ (CParens [g])) = stripParens g stripParens g = g -- === combinators for different sorts of binder lists === @@ -186,140 +185,150 @@ aOptGivens :: Maybe GivenClause -> SyntaxM (Nest UAnnBinder VoidS VoidS) aOptGivens optGivens = fromMaybeM optGivens Empty aGivens binderList - :: [Group] -> (Group -> SyntaxM (Nest UAnnBinder VoidS VoidS)) + :: [GroupW] -> (GroupW -> SyntaxM (Nest UAnnBinder VoidS VoidS)) -> SyntaxM (Nest UAnnBinder VoidS VoidS) binderList gs cont = concatNests <$> forM gs \case - WithSrc _ (CGivens gs') -> aGivens gs' + WithSrcs _ _ (CGivens gs') -> aGivens gs' g -> cont g withTrailingConstraints - :: Group -> (Group -> SyntaxM (UAnnBinder VoidS VoidS)) + :: GroupW -> (GroupW -> SyntaxM (UAnnBinder VoidS VoidS)) -> SyntaxM (Nest UAnnBinder VoidS VoidS) withTrailingConstraints g cont = case g of - Binary Pipe lhs c -> do - Nest (UAnnBinder expl b ann cs) bs <- withTrailingConstraints lhs cont - (ctx, s) <- case b of - UBindSource ctx s -> return (ctx, s) - UIgnore -> throw SyntaxErr "Can't constrain anonymous binders" - UBind _ _ _ -> error "Shouldn't have internal names until renaming pass" + WithSrcs _ _ (CBin Pipe lhs c) -> do + Nest (UAnnBinder expl (WithSrcB sid b) ann cs) bs <- withTrailingConstraints lhs cont + s <- case b of + UBindSource s -> return s + UIgnore -> throw SyntaxErr "Can't constrain anonymous binders" + UBind _ _ -> error "Shouldn't have internal names until renaming pass" c' <- expr c - let v = WithSrcE ctx $ UVar (SourceName ctx s) - return $ UnaryNest (UAnnBinder expl b ann (cs ++ [c'])) + return $ UnaryNest (UAnnBinder expl (WithSrcB sid b) ann (cs ++ [c'])) >>> bs - >>> UnaryNest (asConstraintBinder v c') + >>> UnaryNest (asConstraintBinder (mkUVar sid s) c') _ -> UnaryNest <$> cont g where asConstraintBinder :: UExpr VoidS -> UConstraint VoidS -> UAnnBinder VoidS VoidS asConstraintBinder v c = do - let t = ns $ UApp c [v] [] - UAnnBinder (Inferred Nothing (Synth Full)) UIgnore (UAnn t) [] + let sid = srcPos c + let t = WithSrcE sid (UApp c [v] []) + UAnnBinder (Inferred Nothing (Synth Full)) (WithSrcB sid UIgnore) (UAnn t) [] + +mkUVar :: SrcId -> SourceName -> UExpr VoidS +mkUVar sid v = WithSrcE sid $ UVar $ SourceName sid v aGivens :: GivenClause -> SyntaxM (Nest UAnnBinder VoidS VoidS) -aGivens (implicits, optConstraints) = do +aGivens ((WithSrcs _ _ implicits), optConstraints) = do implicits' <- concatNests <$> forM implicits \b -> withTrailingConstraints b implicitArgBinder - constraints <- fromMaybeM optConstraints Empty (\gs -> toNest <$> mapM synthBinder gs) + constraints <- fromMaybeM optConstraints Empty (\(WithSrcs _ _ gs) -> toNest <$> mapM synthBinder gs) return $ implicits' >>> constraints -synthBinder :: Group -> SyntaxM (UAnnBinder VoidS VoidS) +synthBinder :: GroupW -> SyntaxM (UAnnBinder VoidS VoidS) synthBinder g = tyOptBinder (Inferred Nothing (Synth Full)) g concatNests :: [Nest b VoidS VoidS] -> Nest b VoidS VoidS concatNests [] = Empty concatNests (b:bs) = b >>> concatNests bs -implicitArgBinder :: Group -> SyntaxM (UAnnBinder VoidS VoidS) +implicitArgBinder :: GroupW -> SyntaxM (UAnnBinder VoidS VoidS) implicitArgBinder g = do UAnnBinder _ b ann cs <- binderOptTy (Inferred Nothing Unify) g s <- case b of - UBindSource _ s -> return $ Just s - _ -> return Nothing + WithSrcB _ (UBindSource s) -> return $ Just s + _ -> return Nothing return $ UAnnBinder (Inferred s Unify) b ann cs aExplicitParams :: ExplicitParams -> SyntaxM (Nest UAnnBinder VoidS VoidS) -aExplicitParams bs = binderList bs \b -> withTrailingConstraints b \b' -> +aExplicitParams (WithSrcs _ _ bs) = binderList bs \b -> withTrailingConstraints b \b' -> binderOptTy Explicit b' -aPiBinders :: ExplicitParams -> SyntaxM (Nest UAnnBinder VoidS VoidS) +aPiBinders :: [GroupW] -> SyntaxM (Nest UAnnBinder VoidS VoidS) aPiBinders bs = binderList bs \b -> UnaryNest <$> tyOptBinder Explicit b explicitBindersOptAnn :: ExplicitParams -> SyntaxM (Nest UAnnBinder VoidS VoidS) -explicitBindersOptAnn bs = binderList bs \b -> withTrailingConstraints b \b' -> - binderOptTy Explicit b' +explicitBindersOptAnn (WithSrcs _ _ bs) = + binderList bs \b -> withTrailingConstraints b \b' -> binderOptTy Explicit b' -- === -- Binder pattern with an optional type annotation -patOptAnn :: Group -> SyntaxM (UPat VoidS VoidS, Maybe (UType VoidS)) -patOptAnn (Binary Colon lhs typeAnn) = (,) <$> pat lhs <*> (Just <$> expr typeAnn) -patOptAnn (WithSrc _ (CParens [g])) = patOptAnn g +patOptAnn :: GroupW -> SyntaxM (UPat VoidS VoidS, Maybe (UType VoidS)) +patOptAnn (WithSrcs _ _ (CBin Colon lhs typeAnn)) = (,) <$> pat lhs <*> (Just <$> expr typeAnn) +patOptAnn (WithSrcs _ _ (CParens [g])) = patOptAnn g patOptAnn g = (,Nothing) <$> pat g -uBinder :: Group -> SyntaxM (UBinder c VoidS VoidS) -uBinder (WithSrc src b) = addSrcContext src $ case b of - CIdentifier name -> return $ fromString name - CHole -> return UIgnore +uBinder :: GroupW -> SyntaxM (UBinder c VoidS VoidS) +uBinder (WithSrcs sid _ b) = case b of + CLeaf (CIdentifier name) -> return $ fromSourceNameW $ WithSrc sid name + CLeaf CHole -> return $ WithSrcB sid UIgnore _ -> throw SyntaxErr "Binder must be an identifier or `_`" -- Type annotation with an optional binder pattern -tyOptPat :: Group -> SyntaxM (UAnnBinder VoidS VoidS) -tyOptPat = \case +tyOptPat :: GroupW -> SyntaxM (UAnnBinder VoidS VoidS) +tyOptPat grpTop@(WithSrcs sid _ grp) = case grp of -- Named type - Binary Colon lhs typeAnn -> UAnnBinder Explicit <$> uBinder lhs <*> (UAnn <$> expr typeAnn) <*> pure [] + CBin Colon lhs typeAnn -> + UAnnBinder Explicit <$> uBinder lhs <*> (UAnn <$> expr typeAnn) <*> pure [] -- Binder in grouping parens. - WithSrc _ (CParens [g]) -> tyOptPat g + CParens [g] -> tyOptPat g -- Anonymous type - g -> UAnnBinder Explicit UIgnore <$> (UAnn <$> expr g) <*> pure [] + _ -> UAnnBinder Explicit (WithSrcB sid UIgnore) <$> (UAnn <$> expr grpTop) <*> pure [] -- Pattern of a case binder. This treats bare names specially, in -- that they become (nullary) constructors to match rather than names -- to bind. -casePat :: Group -> SyntaxM (UPat VoidS VoidS) +casePat :: GroupW -> SyntaxM (UPat VoidS VoidS) casePat = \case - (WithSrc src (CIdentifier name)) -> return $ WithSrcB src $ UPatCon (fromString name) Empty + WithSrcs src _ (CLeaf (CIdentifier name)) -> + return $ WithSrcB src $ UPatCon (fromSourceNameW (WithSrc src name)) Empty g -> pat g -pat :: Group -> SyntaxM (UPat VoidS VoidS) -pat = propagateSrcB pat' where - pat' (CBin (WithSrc _ DepComma) lhs rhs) = do +pat :: GroupW -> SyntaxM (UPat VoidS VoidS) +pat (WithSrcs sid _ grp) = WithSrcB sid <$> case grp of + CBin DepComma lhs rhs -> do lhs' <- pat lhs rhs' <- pat rhs return $ UPatDepPair $ PairB lhs' rhs' - pat' (CBrackets gs) = UPatTable . toNest <$> (mapM pat gs) + CBrackets gs -> UPatTable . toNest <$> (mapM pat gs) -- TODO: use Python-style trailing comma (like `(x,y,)`) for singleton tuples - pat' (CParens [g]) = dropSrcB <$> casePat g - pat' (CParens gs) = UPatProd . toNest <$> mapM pat gs - pat' CHole = return $ UPatBinder UIgnore - pat' (CIdentifier name) = return $ UPatBinder $ fromString name - pat' (CBin (WithSrc _ JuxtaposeWithSpace) lhs rhs) = do + CParens gs -> case gs of + [g] -> do + WithSrcB _ g' <- casePat g + return g' + _ -> UPatProd . toNest <$> mapM pat gs + CLeaf CHole -> return $ UPatBinder (WithSrcB sid UIgnore) + CLeaf (CIdentifier name) -> return $ UPatBinder $ fromSourceNameW $ WithSrc sid name + CJuxtapose True lhs rhs -> do case lhs of - WithSrc _ (CBin (WithSrc _ JuxtaposeWithSpace) _ _) -> + WithSrcs _ _ (CJuxtapose True _ _) -> throw SyntaxErr "Only unary constructors can form patterns without parens" _ -> return () name <- identifier "pattern constructor name" lhs arg <- pat rhs - return $ UPatCon (fromString name) (UnaryNest arg) - pat' (CBin (WithSrc _ JuxtaposeNoSpace) lhs rhs) = do + return $ UPatCon (fromSourceNameW name) (UnaryNest arg) + CJuxtapose False lhs rhs -> do name <- identifier "pattern constructor name" lhs case rhs of - WithSrc _ (CParens gs) -> UPatCon (fromString name) . toNest <$> mapM pat gs + WithSrcs _ _ (CParens gs) -> do + gs' <- mapM pat gs + return $ UPatCon (fromSourceNameW name) (toNest gs') _ -> error "unexpected postfix group (should be ruled out at grouping stage)" - pat' _ = throw SyntaxErr "Illegal pattern" + _ -> throw SyntaxErr "Illegal pattern" -tyOptBinder :: Explicitness -> Group -> SyntaxM (UAnnBinder VoidS VoidS) -tyOptBinder expl = \case - Binary Pipe _ _ -> throw SyntaxErr "Unexpected constraint" - Binary Colon name ty -> do +tyOptBinder :: Explicitness -> GroupW -> SyntaxM (UAnnBinder VoidS VoidS) +tyOptBinder expl (WithSrcs sid sids grp) = case grp of + CBin Pipe _ _ -> throw SyntaxErr "Unexpected constraint" + CBin Colon name ty -> do b <- uBinder name ann <- UAnn <$> expr ty return $ UAnnBinder expl b ann [] g -> do - ty <- expr g - return $ UAnnBinder expl UIgnore (UAnn ty) [] + ty <- expr (WithSrcs sid sids g) + return $ UAnnBinder expl (WithSrcB sid UIgnore) (UAnn ty) [] -binderOptTy :: Explicitness -> Group -> SyntaxM (UAnnBinder VoidS VoidS) +binderOptTy :: Explicitness -> GroupW -> SyntaxM (UAnnBinder VoidS VoidS) binderOptTy expl = \case - Binary Colon name ty -> do + WithSrcs _ _ (CBin Colon name ty) -> do b <- uBinder name ann <- UAnn <$> expr ty return $ UAnnBinder expl b ann [] @@ -327,55 +336,56 @@ binderOptTy expl = \case b <- uBinder g return $ UAnnBinder expl b UNoAnn [] -binderReqTy :: Explicitness -> Group -> SyntaxM (UAnnBinder VoidS VoidS) -binderReqTy expl (Binary Colon name ty) = do +binderReqTy :: Explicitness -> GroupW -> SyntaxM (UAnnBinder VoidS VoidS) +binderReqTy expl (WithSrcs _ _ (CBin Colon name ty)) = do b <- uBinder name ann <- UAnn <$> expr ty return $ UAnnBinder expl b ann [] binderReqTy _ _ = throw SyntaxErr $ "Expected an annotated binder" -argList :: [Group] -> SyntaxM ([UExpr VoidS], [UNamedArg VoidS]) +argList :: [GroupW] -> SyntaxM ([UExpr VoidS], [UNamedArg VoidS]) argList gs = partitionEithers <$> mapM singleArg gs -singleArg :: Group -> SyntaxM (Either (UExpr VoidS) (UNamedArg VoidS)) +singleArg :: GroupW -> SyntaxM (Either (UExpr VoidS) (UNamedArg VoidS)) singleArg = \case - WithSrc src (CBin (WithSrc _ CSEqual) lhs rhs) -> addSrcContext src $ Right <$> - ((,) <$> identifier "named argument" lhs <*> expr rhs) + WithSrcs _ _ (CBin CSEqual lhs rhs) -> Right <$> + ((,) <$> withoutSrc <$> identifier "named argument" lhs <*> expr rhs) g -> Left <$> expr g -identifier :: String -> Group -> SyntaxM SourceName -identifier ctx = dropSrc identifier' where - identifier' (CIdentifier name) = return name - identifier' _ = throw SyntaxErr $ "Expected " ++ ctx ++ " to be an identifier" +identifier :: String -> GroupW -> SyntaxM SourceNameW +identifier ctx (WithSrcs sid _ g) = case g of + CLeaf (CIdentifier name) -> return $ WithSrc sid name + _ -> throw SyntaxErr $ "Expected " ++ ctx ++ " to be an identifier" -aEffects :: ([Group], Maybe Group) -> SyntaxM (UEffectRow VoidS) -aEffects (effs, optEffTail) = do +aEffects :: WithSrcs ([GroupW], Maybe GroupW) -> SyntaxM (UEffectRow VoidS) +aEffects (WithSrcs _ _ (effs, optEffTail)) = do lhs <- mapM effect effs rhs <- forM optEffTail \effTail -> - fromString <$> identifier "effect row remainder variable" effTail + fromSourceNameW <$> identifier "effect row remainder variable" effTail return $ UEffectRow (S.fromList lhs) rhs -effect :: Group -> SyntaxM (UEffect VoidS) -effect (WithSrc _ (CParens [g])) = effect g -effect (Binary JuxtaposeWithSpace (Identifier "Read") (Identifier h)) = - return $ URWSEffect Reader $ fromString h -effect (Binary JuxtaposeWithSpace (Identifier "Accum") (Identifier h)) = - return $ URWSEffect Writer $ fromString h -effect (Binary JuxtaposeWithSpace (Identifier "State") (Identifier h)) = - return $ URWSEffect State $ fromString h -effect (Identifier "Except") = return UExceptionEffect -effect (Identifier "IO") = return UIOEffect -effect _ = throw SyntaxErr "Unexpected effect form; expected one of `Read h`, `Accum h`, `State h`, `Except`, `IO`, or the name of a user-defined effect." - -aMethod :: CSDecl -> SyntaxM (Maybe (UMethodDef VoidS)) -aMethod (WithSrc _ CPass) = return Nothing -aMethod (WithSrc src d) = Just . WithSrcE src <$> addSrcContext src case d of +effect :: GroupW -> SyntaxM (UEffect VoidS) +effect (WithSrcs _ _ grp) = case grp of + CParens [g] -> effect g + CJuxtapose True (Identifier "Read" ) (WithSrcs sid _ (CLeaf (CIdentifier h))) -> + return $ URWSEffect Reader $ fromSourceNameW (WithSrc sid h) + CJuxtapose True (Identifier "Accum") (WithSrcs sid _ (CLeaf (CIdentifier h))) -> + return $ URWSEffect Writer $ fromSourceNameW (WithSrc sid h) + CJuxtapose True (Identifier "State") (WithSrcs sid _ (CLeaf (CIdentifier h))) -> + return $ URWSEffect State $ fromSourceNameW (WithSrc sid h) + CLeaf (CIdentifier "Except") -> return UExceptionEffect + CLeaf (CIdentifier "IO" ) -> return UIOEffect + _ -> throw SyntaxErr "Unexpected effect form; expected one of `Read h`, `Accum h`, `State h`, `Except`, `IO`, or the name of a user-defined effect." + +aMethod :: CSDeclW -> SyntaxM (Maybe (UMethodDef VoidS)) +aMethod (WithSrcs _ _ CPass) = return Nothing +aMethod (WithSrcs src _ d) = Just . WithSrcE src <$> case d of CDefDecl def -> do - (name, lam) <- aDef def - return $ UMethodDef (fromString name) lam - CLet (WithSrc _ (CIdentifier name)) rhs -> do + (WithSrc sid name, lam) <- aDef def + return $ UMethodDef (SourceName sid name) lam + CLet (WithSrcs sid _ (CLeaf (CIdentifier name))) rhs -> do rhs' <- ULamExpr Empty ImplicitApp Nothing Nothing <$> block rhs - return $ UMethodDef (fromString name) rhs' + return $ UMethodDef (fromSourceNameW (WithSrc sid name)) rhs' _ -> throw SyntaxErr "Unexpected method definition. Expected `def` or `x = ...`." asExpr :: UBlock VoidS -> UExpr VoidS @@ -384,22 +394,22 @@ asExpr (WithSrcE src b) = case b of _ -> WithSrcE src $ UDo $ WithSrcE src b block :: CSBlock -> SyntaxM (UBlock VoidS) -block (ExprBlock g) = WithSrcE emptySrcPosCtx . UBlock Empty <$> expr g -block (IndentedBlock decls) = do +block (ExprBlock g) = WithSrcE (srcPos g) . UBlock Empty <$> expr g +block (IndentedBlock sid decls) = do (decls', result) <- blockDecls decls - return $ WithSrcE emptySrcPosCtx $ UBlock decls' result + return $ WithSrcE sid $ UBlock decls' result -blockDecls :: [CSDecl] -> SyntaxM (Nest UDecl VoidS VoidS, UExpr VoidS) +blockDecls :: [CSDeclW] -> SyntaxM (Nest UDecl VoidS VoidS, UExpr VoidS) blockDecls [] = error "shouldn't have empty list of decls" -blockDecls [WithSrc src d] = addSrcContext src case d of +blockDecls [WithSrcs _ _ d] = case d of CExpr g -> (Empty,) <$> expr g _ -> throw SyntaxErr "Block must end in expression" -blockDecls (WithSrc pos (CBind b rhs):ds) = do +blockDecls (WithSrcs sid _ (CBind b rhs):ds) = do b' <- binderOptTy Explicit b rhs' <- asExpr <$> block rhs - body <- block $ IndentedBlock ds + body <- block $ IndentedBlock sid ds -- Not really the right SrcId let lam = ULam $ ULamExpr (UnaryNest b') ExplicitApp Nothing Nothing body - return (Empty, WithSrcE pos $ extendAppRight rhs' (ns lam)) + return (Empty, WithSrcE sid $ extendAppRight rhs' (WithSrcE sid lam)) blockDecls (d:ds) = do d' <- decl PlainLet d (ds', e) <- blockDecls ds @@ -407,86 +417,76 @@ blockDecls (d:ds) = do -- === Concrete to abstract syntax of expressions === -expr :: Group -> SyntaxM (UExpr VoidS) -expr = propagateSrcE expr' where - expr' CEmpty = return UHole - -- Binders (e.g., in pi types) should not hit this case - expr' (CIdentifier name) = return $ fromString name - expr' (CPrim prim xs) = UPrim prim <$> mapM expr xs - expr' (CNat word) = return $ UNatLit word - expr' (CInt int) = return $ UIntLit int - expr' (CString str) = return $ explicitApp (fromString "to_list") - [ns $ UTabCon $ map (ns . charExpr) str] - expr' (CChar char) = return $ charExpr char - expr' (CFloat num) = return $ UFloatLit num - expr' CHole = return UHole - expr' (CParens [g]) = dropSrcE <$> expr g - expr' (CParens gs) = UPrim UTuple <$> mapM expr gs +expr :: GroupW -> SyntaxM (UExpr VoidS) +expr (WithSrcs sid _ grp) = WithSrcE sid <$> case grp of + CLeaf x -> leaf sid x + CPrim prim xs -> UPrim prim <$> mapM expr xs + CParens [g] -> do + WithSrcE _ result <- expr g + return result + CParens gs -> UPrim UTuple <$> mapM expr gs -- Table constructors here. Other uses of square brackets -- should be detected upstream, before calling expr. - expr' (CBrackets gs) = UTabCon <$> mapM expr gs - expr' (CGivens _) = throw SyntaxErr $ "Unexpected `given` clause" - expr' (CArrow lhs effs rhs) = do + CBrackets gs -> UTabCon <$> mapM expr gs + CGivens _ -> throw SyntaxErr $ "Unexpected `given` clause" + CArrow lhs effs rhs -> do case lhs of - WithSrc _ (CParens gs) -> do + WithSrcs _ _ (CParens gs) -> do bs <- aPiBinders gs effs' <- fromMaybeM effs UPure aEffects resultTy <- expr rhs return $ UPi $ UPiExpr bs ExplicitApp effs' resultTy _ -> throw SyntaxErr "Argument types should be in parentheses" - expr' (CDo b) = UDo <$> block b - -- Binders (e.g., in pi types) should not hit this case - expr' (CBin (WithSrc opSrc op) lhs rhs) = - case op of - JuxtaposeNoSpace -> do - f <- expr lhs - case rhs of - WithSrc _ (CParens args) -> do - (posArgs, namedArgs) <- argList args - return $ UApp f posArgs namedArgs - WithSrc _ (CBrackets args) -> do - args' <- mapM expr args - return $ UTabApp f args' - _ -> error "unexpected postfix group (should be ruled out at grouping stage)" - JuxtaposeWithSpace -> extendAppRight <$> expr lhs <*> expr rhs - Dollar -> extendAppRight <$> expr lhs <*> expr rhs - Pipe -> extendAppLeft <$> expr lhs <*> expr rhs - Dot -> do - lhs' <- expr lhs - WithSrc src rhs' <- return rhs - name <- addSrcContext src $ case rhs' of - CIdentifier name -> return $ FieldName name - CNat i -> return $ FieldNum $ fromIntegral i - _ -> throw SyntaxErr "Field must be a name or an integer" - return $ UFieldAccess lhs' (WithSrc src name) - DoubleColon -> UTypeAnn <$> (expr lhs) <*> expr rhs - EvalBinOp s -> evalOp s - DepAmpersand -> do - lhs' <- tyOptPat lhs - UDepPairTy . (UDepPairType ExplicitDepPair lhs') <$> expr rhs - DepComma -> UDepPair <$> (expr lhs) <*> expr rhs - CSEqual -> throw SyntaxErr "Equal sign must be used as a separator for labels or binders, not a standalone operator" - Colon -> throw SyntaxErr "Colon separates binders from their type annotations, is not a standalone operator.\nIf you are trying to write a dependent type, use parens: (i:Fin 4) => (..i)" - ImplicitArrow -> case lhs of - WithSrc _ (CParens gs) -> do - bs <- aPiBinders gs - resultTy <- expr rhs - return $ UPi $ UPiExpr bs ImplicitApp UPure resultTy - _ -> throw SyntaxErr "Argument types should be in parentheses" - FatArrow -> do - lhs' <- tyOptPat lhs - UTabPi . (UTabPiExpr lhs') <$> expr rhs - where - evalOp s = do - let f = WithSrcE opSrc (fromString s) - lhs' <- expr lhs - rhs' <- expr rhs - return $ explicitApp f [lhs', rhs'] - expr' (CPrefix name g) = + CDo b -> UDo <$> block b + CJuxtapose hasSpace lhs rhs -> case hasSpace of + True -> extendAppRight <$> expr lhs <*> expr rhs + False -> do + f <- expr lhs + case rhs of + WithSrcs _ _ (CParens args) -> do + (posArgs, namedArgs) <- argList args + return $ UApp f posArgs namedArgs + WithSrcs _ _ (CBrackets args) -> do + args' <- mapM expr args + return $ UTabApp f args' + _ -> error "unexpected postfix group (should be ruled out at grouping stage)" + CBin op lhs rhs -> case op of + Dollar -> extendAppRight <$> expr lhs <*> expr rhs + Pipe -> extendAppLeft <$> expr lhs <*> expr rhs + Dot -> do + lhs' <- expr lhs + WithSrcs src _ rhs' <- return rhs + name <- case rhs' of + CLeaf (CIdentifier name) -> return $ FieldName name + CLeaf (CNat i ) -> return $ FieldNum $ fromIntegral i + _ -> throw SyntaxErr "Field must be a name or an integer" + return $ UFieldAccess lhs' (WithSrc src name) + DoubleColon -> UTypeAnn <$> (expr lhs) <*> expr rhs + EvalBinOp s -> evalOp s + DepAmpersand -> do + lhs' <- tyOptPat lhs + UDepPairTy . (UDepPairType ExplicitDepPair lhs') <$> expr rhs + DepComma -> UDepPair <$> (expr lhs) <*> expr rhs + CSEqual -> throw SyntaxErr "Equal sign must be used as a separator for labels or binders, not a standalone operator" + Colon -> throw SyntaxErr "Colon separates binders from their type annotations, is not a standalone operator.\nIf you are trying to write a dependent type, use parens: (i:Fin 4) => (..i)" + ImplicitArrow -> case lhs of + WithSrcs _ _ (CParens gs) -> do + bs <- aPiBinders gs + resultTy <- expr rhs + return $ UPi $ UPiExpr bs ImplicitApp UPure resultTy + _ -> throw SyntaxErr "Argument types should be in parentheses" + FatArrow -> do + lhs' <- tyOptPat lhs + UTabPi . (UTabPiExpr lhs') <$> expr rhs + where + evalOp s = do + let f = WithSrcE (srcPos s) (fromSourceNameW s) + lhs' <- expr lhs + rhs' <- expr rhs + return $ explicitApp f [lhs', rhs'] + CPrefix (WithSrc _ name) g -> do case name of - ".." -> range "RangeTo" <$> expr g - "..<" -> range "RangeToExc" <$> expr g - "+" -> (dropSrcE <$> expr g) <&> \case + "+" -> (withoutSrc <$> expr g) <&> \case UNatLit i -> UIntLit (fromIntegral i) UIntLit i -> UIntLit i UFloatLit i -> UFloatLit i @@ -495,68 +495,72 @@ expr = propagateSrcE expr' where WithSrcE _ (UNatLit i) -> UIntLit (-(fromIntegral i)) WithSrcE _ (UIntLit i) -> UIntLit (-i) WithSrcE _ (UFloatLit i) -> UFloatLit (-i) - e -> unaryApp "neg" e + e -> unaryApp (mkUVar sid "neg") e _ -> throw SyntaxErr $ "Prefix (" ++ name ++ ") not legal as a bare expression" - where - range :: UExpr VoidS -> UExpr VoidS -> UExpr' VoidS - range rangeName lim = explicitApp rangeName [lim] - expr' (CPostfix name g) = - case name of - ".." -> range "RangeFrom" <$> expr g - "<.." -> range "RangeFromExc" <$> expr g - _ -> throw SyntaxErr $ "Postfix (" ++ name ++ ") not legal as a bare expression" - where - range :: UExpr VoidS -> UExpr VoidS -> UExpr' VoidS - range rangeName lim = explicitApp rangeName [lim] - expr' (CLambda params body) = do - params' <- explicitBindersOptAnn $ map stripParens params + CLambda params body -> do + params' <- explicitBindersOptAnn $ WithSrcs sid [] $ map stripParens params body' <- block body return $ ULam $ ULamExpr params' ExplicitApp Nothing Nothing body' - expr' (CFor kind indices body) = do + CFor kind indices body -> do let (dir, trailingUnit) = case kind of KFor -> (Fwd, False) KFor_ -> (Fwd, True) KRof -> (Rev, False) KRof_ -> (Rev, True) -- TODO: Can we fetch the source position from the error context, to feed into `buildFor`? - e <- buildFor (0, 0) dir <$> mapM (binderOptTy Explicit) indices <*> block body + e <- buildFor sid dir <$> mapM (binderOptTy Explicit) indices <*> block body if trailingUnit - then return $ UDo $ ns $ UBlock (UnaryNest (nsB $ UExprDecl e)) (ns unitExpr) - else return $ dropSrcE e - expr' (CCase scrut alts) = UCase <$> (expr scrut) <*> mapM alternative alts + then return $ UDo $ WithSrcE sid $ UBlock (UnaryNest (WithSrcB sid $ UExprDecl e)) (unitExpr sid) + else return $ withoutSrc e + CCase scrut alts -> UCase <$> (expr scrut) <*> mapM alternative alts where alternative (match, body) = UAlt <$> casePat match <*> block body - expr' (CIf p c a) = do + CIf p c a -> do p' <- expr p c' <- block c a' <- case a of - Nothing -> return $ ns $ UBlock Empty $ ns unitExpr + Nothing -> return $ WithSrcE sid $ UBlock Empty $ unitExpr sid (Just alternative) -> block alternative return $ UCase p' - [ UAlt (nsB $ UPatCon "True" Empty) c' - , UAlt (nsB $ UPatCon "False" Empty) a'] - expr' (CWith lhs rhs) = do + [ UAlt (WithSrcB sid $ UPatCon (SourceName sid "True") Empty) c' + , UAlt (WithSrcB sid $ UPatCon (SourceName sid "False") Empty) a'] + CWith lhs rhs -> do ty <- expr lhs case rhs of - [b] -> do + WithSrcs _ _ [b] -> do b' <- binderReqTy Explicit b return $ UDepPairTy $ UDepPairType ImplicitDepPair b' ty _ -> error "n-ary dependent pairs not implemented" +leaf :: SrcId -> CLeaf -> SyntaxM (UExpr' VoidS) +leaf sid = \case + -- Binders (e.g., in pi types) should not hit this case + CIdentifier name -> return $ fromSourceNameW $ WithSrc sid name + CNat word -> return $ UNatLit word + CInt int -> return $ UIntLit int + CString str -> do + xs <- return $ map (WithSrcE sid . charExpr) str + let toListVar = mkUVar sid "to_list" + return $ explicitApp toListVar [WithSrcE sid (UTabCon xs)] + CChar char -> return $ charExpr char + CFloat num -> return $ UFloatLit num + CHole -> return UHole + charExpr :: Char -> (UExpr' VoidS) charExpr c = ULit $ Word8Lit $ fromIntegral $ fromEnum c -unitExpr :: UExpr' VoidS -unitExpr = UPrim (UCon $ P.ProdCon) [] +unitExpr :: SrcId -> UExpr VoidS +unitExpr sid = WithSrcE sid $ UPrim (UCon $ P.ProdCon) [] -- === Builders === -- TODO Does this generalize? Swap list for Nest? -buildFor :: SrcPos -> Direction -> [UAnnBinder VoidS VoidS] -> UBlock VoidS -> UExpr VoidS -buildFor pos dir binders body = case binders of +-- TODO: these SrcIds aren't really correct +buildFor :: SrcId -> Direction -> [UAnnBinder VoidS VoidS] -> UBlock VoidS -> UExpr VoidS +buildFor sid dir binders body = case binders of [] -> error "should have nonempty list of binder" - [b] -> WithSrcE (fromPos pos) $ UFor dir $ UForExpr b body - b:bs -> WithSrcE (fromPos pos) $ UFor dir $ UForExpr b $ - ns $ UBlock Empty $ buildFor pos dir bs body + [b] -> WithSrcE sid $ UFor dir $ UForExpr b body + b:bs -> WithSrcE sid $ UFor dir $ UForExpr b $ + WithSrcE sid $ UBlock Empty $ buildFor sid dir bs body -- === Helpers === @@ -574,26 +578,5 @@ unaryApp f x = UApp f [x] [] explicitApp :: UExpr n -> [UExpr n] -> UExpr' n explicitApp f xs = UApp f xs [] -ns :: (a n) -> WithSrcE a n -ns = WithSrcE emptySrcPosCtx - -nsB :: (b n l) -> WithSrcB b n l -nsB = WithSrcB emptySrcPosCtx - toNest :: [a VoidS VoidS] -> Nest a VoidS VoidS toNest = foldr Nest Empty - -dropSrc :: (t -> SyntaxM a) -> WithSrc t -> SyntaxM a -dropSrc act (WithSrc src x) = addSrcContext src $ act x - -propagateSrcE :: (t -> SyntaxM (e n)) -> WithSrc t -> SyntaxM (WithSrcE e n) -propagateSrcE act (WithSrc src x) = addSrcContext src (WithSrcE src <$> act x) - -dropSrcE :: WithSrcE e n -> e n -dropSrcE (WithSrcE _ x) = x - -propagateSrcB :: (t -> SyntaxM (binder n l)) -> WithSrc t -> SyntaxM (WithSrcB binder n l) -propagateSrcB act (WithSrc src x) = addSrcContext src (WithSrcB src <$> act x) - -dropSrcB :: WithSrcB binder n l -> binder n l -dropSrcB (WithSrcB _ x) = x diff --git a/src/lib/Builder.hs b/src/lib/Builder.hs index 55a7e71ce..a12b5c8b3 100644 --- a/src/lib/Builder.hs +++ b/src/lib/Builder.hs @@ -144,7 +144,7 @@ liftTopBuilderAndEmit cont = do newtype DoubleBuilderT (r::IR) (topEmissions::B) (m::MonadKind) (n::S) (a:: *) = DoubleBuilderT { runDoubleBuilderT' :: DoubleInplaceT Env topEmissions (BuilderEmissions r) m n a } deriving ( Functor, Applicative, Monad, MonadFail, Fallible - , CtxReader, MonadIO, Catchable, MonadReader r') + , MonadIO, Catchable, MonadReader r') deriving instance (ExtOutMap Env frag, HoistableB frag, OutFrag frag, Fallible m, IRRep r) => ScopeReader (DoubleBuilderT r frag m) @@ -349,7 +349,7 @@ getCache = withEnv $ envCache . topEnv newtype TopBuilderT (m::MonadKind) (n::S) (a:: *) = TopBuilderT { runTopBuilderT' :: InplaceT Env TopEnvFrag m n a } deriving ( Functor, Applicative, Monad, MonadFail, Fallible - , CtxReader, ScopeReader, MonadTrans1, MonadReader r + , ScopeReader, MonadTrans1, MonadReader r , MonadWriter w, MonadState s, MonadIO, Catchable) type TopBuilderM = TopBuilderT HardFailM @@ -424,7 +424,7 @@ type BuilderEmissions r = RNest (Decl r) newtype BuilderT (r::IR) (m::MonadKind) (n::S) (a:: *) = BuilderT { runBuilderT' :: InplaceT Env (BuilderEmissions r) m n a } deriving ( Functor, Applicative, Monad, MonadTrans1, MonadFail, Fallible - , Catchable, CtxReader, ScopeReader, Alternative + , Catchable, ScopeReader, Alternative , MonadWriter w, MonadReader r') type BuilderM (r::IR) = BuilderT r HardFailM diff --git a/src/lib/CheapReduction.hs b/src/lib/CheapReduction.hs index b9bc0ce01..6fdd7280f 100644 --- a/src/lib/CheapReduction.hs +++ b/src/lib/CheapReduction.hs @@ -85,11 +85,11 @@ reduceTabApp f x = liftM fromJust $ liftReducerM $ reduceTabAppM f x -- === internal === -type ReducerM = SubstReaderT AtomSubstVal (EnvReaderT FallibleM) +type ReducerM = SubstReaderT AtomSubstVal (EnvReaderT Except) liftReducerM :: EnvReader m => ReducerM n n a -> m n (Maybe a) liftReducerM cont = do - liftM (ignoreExcept . runFallibleM) $ liftEnvReaderT $ runSubstReaderT idSubst do + liftM ignoreExcept $ liftEnvReaderT $ runSubstReaderT idSubst do (Just <$> cont) <|> return Nothing reduceWithDeclsM :: IRRep r => Nest (Decl r) i i' -> ReducerM i' o a -> ReducerM i o a @@ -644,7 +644,7 @@ substMStuck stuck = do substStuck :: (IRRep r, Distinct o) => (Env o, Subst AtomSubstVal i o) -> Stuck r i -> Atom r o substStuck (env, subst) stuck = - ignoreExcept $ runFallibleM $ runEnvReaderT env $ runSubstReaderT subst $ reduceStuck stuck + ignoreExcept $ runEnvReaderT env $ runSubstReaderT subst $ reduceStuck stuck reduceStuck :: (IRRep r, Distinct o) => Stuck r i -> ReducerM i o (Atom r o) reduceStuck = \case diff --git a/src/lib/CheckType.hs b/src/lib/CheckType.hs index 31db509cf..f808e7153 100644 --- a/src/lib/CheckType.hs +++ b/src/lib/CheckType.hs @@ -39,16 +39,12 @@ checkTypeIs e ty = liftTyperM (void $ e |: ty) >>= liftExcept -- === the type checking/querying monad === newtype TyperM (r::IR) (i::S) (o::S) (a :: *) = - TyperM { runTyperT' :: SubstReaderT Name (StateT1 (NameMap (AtomNameC r) Int) FallibleEnvReaderM) i o a } + TyperM { runTyperT' :: SubstReaderT Name (StateT1 (NameMap (AtomNameC r) Int) (EnvReaderT Except)) i o a } deriving ( Functor, Applicative, Monad , SubstReader Name , MonadFail , Fallible , ScopeReader - , EnvReader, EnvExtender) + , EnvReader, EnvExtender, Catchable) liftTyperM :: EnvReader m => TyperM r n n a -> m n (Except a) -liftTyperM cont = - liftM runFallibleM $ liftEnvReaderT $ - flip evalStateT1 mempty $ - runSubstReaderT idSubst $ - runTyperT' cont +liftTyperM cont = liftEnvReaderT $ flip evalStateT1 mempty $ runSubstReaderT idSubst $ runTyperT' cont {-# INLINE liftTyperM #-} -- I can't make up my mind whether a `Seq` loop should be allowed to @@ -201,7 +197,7 @@ checkBinderType ty b cont = do cont b' instance IRRep r => CheckableWithEffects r (Expr r) where - checkWithEffects allowedEffs expr = addContext ("Checking expr:\n" ++ pprint expr) case expr of + checkWithEffects allowedEffs expr = case expr of App effTy f xs -> do effTy' <- checkEffTy allowedEffs effTy f' <- checkE f diff --git a/src/lib/ConcreteSyntax.hs b/src/lib/ConcreteSyntax.hs index 2e5aebcf0..b37f3a747 100644 --- a/src/lib/ConcreteSyntax.hs +++ b/src/lib/ConcreteSyntax.hs @@ -9,7 +9,7 @@ module ConcreteSyntax ( keyWordStrs, showPrimName, parseUModule, parseUModuleDeps, finishUModuleParse, preludeImportBlock, mustParseSourceBlock, - pattern Binary, pattern Prefix, pattern Postfix, pattern Identifier) where + pattern Identifier) where import Control.Monad.Combinators.Expr qualified as Expr import Control.Monad.Reader @@ -28,8 +28,6 @@ import Text.Megaparsec hiding (Label, State) import Text.Megaparsec.Char hiding (space, eol) import Lexing -import Name -import SourceInfo import Types.Core import Types.Source import Types.Primitives @@ -72,8 +70,8 @@ mustParseSourceBlock s = mustParseit s sourceBlock -- === helpers for target ADT === -interp_operator :: String -> Bin' -interp_operator = \case +interpOperator :: WithSrc String -> Bin +interpOperator (WithSrc sid s) = case s of "&>" -> DepAmpersand "." -> Dot ",>" -> DepComma @@ -84,23 +82,10 @@ interp_operator = \case "->>" -> ImplicitArrow "=>" -> FatArrow "=" -> CSEqual - name -> EvalBinOp $ "(" <> name <> ")" + name -> EvalBinOp $ WithSrc sid $ "(" <> name <> ")" -pattern Binary :: Bin' -> Group -> Group -> Group -pattern Binary op lhs rhs <- (WithSrc _ (CBin (WithSrc _ op) lhs rhs)) where - Binary op lhs rhs = joinSrc lhs rhs $ CBin (WithSrc emptySrcPosCtx op) lhs rhs - -pattern Prefix :: SourceName -> Group -> Group -pattern Prefix op g <- (WithSrc _ (CPrefix op g)) where - Prefix op g = WithSrc emptySrcPosCtx $ CPrefix op g - -pattern Postfix :: SourceName -> Group -> Group -pattern Postfix op g <- (WithSrc _ (CPostfix op g)) where - Postfix op g = WithSrc emptySrcPosCtx $ CPostfix op g - -pattern Identifier :: SourceName -> Group -pattern Identifier name <- (WithSrc _ (CIdentifier name)) where - Identifier name = WithSrc emptySrcPosCtx $ CIdentifier name +pattern Identifier :: SourceName -> GroupW +pattern Identifier name <- (WithSrcs _ _ (CLeaf (CIdentifier name))) -- === Parser (top-level structure) === @@ -127,7 +112,7 @@ recover e = do importModule :: Parser SourceBlock' importModule = Misc . ImportModule . OrdinaryModule <$> do keyWord ImportKW - s <- anyCaseName + WithSrc _ s <- anyCaseName eol return s @@ -171,7 +156,7 @@ logTime = do logBench :: Parser LogLevel logBench = do void $ try $ lexeme MiscLexeme $ char '%' >> string "bench" - benchName <- strLit + WithSrc _ benchName <- strLit eol return $ PrintBench benchName @@ -190,10 +175,10 @@ sourceBlock' = <|> hidden (some eol >> return (Misc EmptyLines)) <|> hidden (sc >> eol >> return (Misc CommentLine)) -topDecl :: Parser CTopDecl -topDecl = withSrc $ topDecl' <* eolf +topDecl :: Parser CTopDeclW +topDecl = withSrcs topDecl' <* eolf -topDecl' :: Parser CTopDecl' +topDecl' :: Parser CTopDecl topDecl' = dataDef <|> structDef @@ -202,7 +187,8 @@ topDecl' = <|> (CInstanceDecl <$> instanceDef False) proseBlock :: Parser SourceBlock' -proseBlock = label "prose block" $ char '\'' >> fmap (Misc . ProseBlock . fst) (withSource consumeTillBreak) +proseBlock = label "prose block" $ + char '\'' >> fmap (Misc . ProseBlock . fst) (withSource consumeTillBreak) topLevelCommand :: Parser SourceBlock' topLevelCommand = @@ -214,14 +200,15 @@ topLevelCommand = "top-level command" envQuery :: Parser EnvQuery -envQuery = string ":debug" >> sc >> ( - (DumpSubst <$ (string "env" >> sc)) - <|> (InternalNameInfo <$> (string "iname" >> sc >> rawName)) - <|> (SourceNameInfo <$> (string "sname" >> sc >> anyName))) - <* eol - where - rawName :: Parser RawName - rawName = undefined -- RawName <$> (fromString <$> anyName) <*> intLit +envQuery = error "not implemented" +-- string ":debug" >> sc >> ( +-- (DumpSubst <$ (string "env" >> sc)) +-- <|> (InternalNameInfo <$> (string "iname" >> sc >> rawName)) +-- <|> (SourceNameInfo <$> (string "sname" >> sc >> anyName))) +-- <* eol +-- where +-- rawName :: Parser RawName +-- rawName = RawName <$> (fromString <$> anyName) <*> intLit explicitCommand :: Parser SourceBlock' explicitCommand = do @@ -237,13 +224,13 @@ explicitCommand = do b <- cBlock <* eolf e <- case b of ExprBlock e -> return e - IndentedBlock decls -> return $ WithSrc emptySrcPosCtx $ CDo $ IndentedBlock decls + IndentedBlock sid decls -> withSrcs $ return $ CDo $ IndentedBlock sid decls return $ case (e, cmd) of - (WithSrc _ (CIdentifier v), GetType) -> Misc $ GetNameType v + (WithSrcs sid _ (CLeaf (CIdentifier v)), GetType) -> Misc $ GetNameType (WithSrc sid v) _ -> Command cmd e -type CDefBody = ([(SourceName, Group)], [(LetAnn, CDef)]) -structDef :: Parser CTopDecl' +type CDefBody = ([(SourceNameW, GroupW)], [(LetAnn, CDef)]) +structDef :: Parser CTopDecl structDef = do keyWord StructKW tyName <- anyName @@ -267,18 +254,18 @@ funDefLetWithAnn = do def <- funDefLet return (ann, def) -dataDef :: Parser CTopDecl' +dataDef :: Parser CTopDecl dataDef = do keyWord DataKW tyName <- anyName (params, givens) <- typeParams dataCons <- onePerLine do dataConName <- anyName - dataConArgs <- optExplicitParams + dataConArgs <- optional explicitParams return (dataConName, dataConArgs) return $ CData tyName params givens dataCons -interfaceDef :: Parser CTopDecl' +interfaceDef :: Parser CTopDecl interfaceDef = do keyWord InterfaceKW className <- anyName @@ -291,7 +278,7 @@ interfaceDef = do return (methodName, ty) return $ CInterface className params methodDecls -nameAndType :: Parser (SourceName, Group) +nameAndType :: Parser (SourceNameW, GroupW) nameAndType = do n <- anyName sym ":" @@ -299,14 +286,14 @@ nameAndType = do return (n, arg) topLetOrExpr :: Parser SourceBlock' -topLetOrExpr = withSrc topLet >>= \case - WithSrc _ (CSDecl ann (CExpr e)) -> do +topLetOrExpr = topLet >>= \case + WithSrcs _ _ (CSDecl ann (CExpr e)) -> do when (ann /= PlainLet) $ fail "Cannot annotate expressions" return $ Command (EvalExpr (Printed Nothing)) e d -> return $ TopDecl d -topLet :: Parser CTopDecl' -topLet = do +topLet :: Parser CTopDeclW +topLet = withSrcs do lAnn <- topLetAnn <|> return PlainLet decl <- cDecl return $ CSDecl lAnn decl @@ -330,15 +317,16 @@ cBlock :: Parser CSBlock cBlock = indentedBlock <|> ExprBlock <$> cGroup indentedBlock :: Parser CSBlock -indentedBlock = withIndent $ - IndentedBlock <$> (withSrc cDecl `sepBy1` (semicolon <|> try nextLine)) +indentedBlock = withIndent do + WithSrcs sid _ decls <- withSrcs $ withSrcs cDecl `sepBy1` (void semicolon <|> try nextLine) + return $ IndentedBlock sid decls -cDecl :: Parser CSDecl' +cDecl :: Parser CSDecl cDecl = (CDefDecl <$> funDefLet) <|> simpleLet <|> (keyWord PassKW >> return CPass) -simpleLet :: Parser CSDecl' +simpleLet :: Parser CSDecl simpleLet = do lhs <- cGroupNoEqual next <- nextChar @@ -352,14 +340,14 @@ instanceDef isNamed = do optNameAndArgs <- case isNamed of False -> keyWord InstanceKW $> Nothing True -> keyWord NamedInstanceKW >> do - name <- fromString <$> anyName + name <- anyName args <- (sym ":" >> return Nothing) - <|> ((Just <$> parens (commaSep cParenGroup)) <* sym "->") + <|> ((Just <$> parenList cParenGroup) <* sym "->") return $ Just (name, args) className <- anyName args <- argList givens <- optional givenClause - methods <- withIndent $ withSrc cDecl `sepBy1` try nextLine + methods <- withIndent $ (withSrcs cDecl) `sepBy1` try nextLine return $ CInstanceDef className args givens methods optNameAndArgs funDefLet :: Parser CDef @@ -384,8 +372,10 @@ explicitness = (sym "->" $> ExplicitApp) <|> (sym "->>" $> ImplicitApp) -- Intended for occurrences, like `foo(x, y, z)` (cf. defParamsList). -argList :: Parser [Group] -argList = immediateParens (commaSep cParenGroup) +argList :: Parser [GroupW] +argList = do + WithSrcs _ _ gs <- withSrcs $ bracketedGroup immediateLParen rParen cParenGroup + return gs immediateLParen :: Parser () immediateLParen = label "'(' (without preceding whitespace)" do @@ -395,25 +385,24 @@ immediateLParen = label "'(' (without preceding whitespace)" do False -> lParen _ -> empty -immediateParens :: Parser a -> Parser a -immediateParens p = bracketed immediateLParen rParen p - -- Putting `sym =` inside the cases gives better errors. -typeParams :: Parser (ExplicitParams, Maybe GivenClause) +typeParams :: Parser (Maybe ExplicitParams, Maybe GivenClause) typeParams = (explicitParamsAndGivens <* sym "=") - <|> (return ([], Nothing) <* sym "=") - -explicitParamsAndGivens :: Parser (ExplicitParams, Maybe GivenClause) -explicitParamsAndGivens = (,) <$> explicitParams <*> optional givenClause + <|> (return (Nothing, Nothing) <* sym "=") -optExplicitParams :: Parser ExplicitParams -optExplicitParams = label "optional parameter list" $ - explicitParams <|> return [] +explicitParamsAndGivens :: Parser (Maybe ExplicitParams, Maybe GivenClause) +explicitParamsAndGivens = (,) <$> (Just <$> explicitParams) <*> optional givenClause explicitParams :: Parser ExplicitParams explicitParams = label "parameter list in parentheses (without preceding whitespace)" $ - immediateParens $ commaSep cParenGroup + withSrcs $ bracketedGroup immediateLParen rParen cParenGroup + +parenList :: Parser GroupW -> Parser BracketedGroup +parenList p = withSrcs $ bracketedGroup lParen rParen p + +bracketedGroup :: Parser () -> Parser () -> Parser GroupW -> Parser [GroupW] +bracketedGroup l r p = bracketed l r $ commaSep p noGap :: Parser () noGap = precededByWhitespace >>= \case @@ -421,66 +410,63 @@ noGap = precededByWhitespace >>= \case False -> return () givenClause :: Parser GivenClause -givenClause = keyWord GivenKW >> do - (,) <$> parens (commaSep cGroup) - <*> optional (parens (commaSep cGroup)) +givenClause = do + keyWord GivenKW + (,) <$> parenList cGroup <*> optional (parenList cGroup) withClause :: Parser WithClause -withClause = keyWord WithKW >> parens (commaSep cGroup) - -arrowOptEffs :: Parser (Maybe CEffs) -arrowOptEffs = sym "->" >> optional cEffs +withClause = keyWord WithKW >> parenList cGroup cEffs :: Parser CEffs -cEffs = braces do +cEffs = withSrcs $ braces do effs <- commaSep cGroupNoPipe effTail <- optional $ sym "|" >> cGroup return (effs, effTail) commaSep :: Parser a -> Parser [a] -commaSep p = p `sepBy` sym "," +commaSep p = sepBy p (sym ",") -cParenGroup :: Parser Group -cParenGroup = withSrc (CGivens <$> givenClause) <|> cGroup +cParenGroup :: Parser GroupW +cParenGroup = (withSrcs (CGivens <$> givenClause)) <|> cGroup -cGroup :: Parser Group +cGroup :: Parser GroupW cGroup = makeExprParser leafGroup ops -cGroupNoJuxt :: Parser Group +cGroupNoJuxt :: Parser GroupW cGroupNoJuxt = makeExprParser leafGroup $ withoutOp "space" $ withoutOp "." ops -cGroupNoEqual :: Parser Group +cGroupNoEqual :: Parser GroupW cGroupNoEqual = makeExprParser leafGroup $ withoutOp "=" ops -cGroupNoPipe :: Parser Group +cGroupNoPipe :: Parser GroupW cGroupNoPipe = makeExprParser leafGroup $ withoutOp "|" ops -cGroupNoArrow :: Parser Group +cGroupNoArrow :: Parser GroupW cGroupNoArrow = makeExprParser leafGroup $ withoutOp "->" ops -cNullaryLam :: Parser Group' +cNullaryLam :: Parser Group cNullaryLam = do - sym "\\." + void $ sym "\\." body <- cBlock return $ CLambda [] body -cLam :: Parser Group' +cLam :: Parser Group cLam = do - sym "\\" + void $ sym "\\" bs <- many cGroupNoJuxt - mayNotBreak $ sym "." + void $ mayNotBreak $ sym "." body <- cBlock return $ CLambda bs body -cFor :: Parser Group' +cFor :: Parser Group cFor = do kw <- forKW indices <- many cGroupNoJuxt - mayNotBreak $ sym "." + void $ mayNotBreak $ sym "." body <- cBlock return $ CFor kw indices body where forKW = keyWord ForKW $> KFor @@ -488,58 +474,26 @@ cFor = do <|> keyWord RofKW $> KRof <|> keyWord Rof_KW $> KRof_ -cDo :: Parser Group' -cDo = keyWord DoKW >> CDo <$> cBlock +cDo :: Parser Group +cDo = CDo <$> cBlock -cCase :: Parser Group' +cCase :: Parser Group cCase = do keyWord CaseKW scrut <- cGroup keyWord OfKW - alts <- onePerLine $ (,) <$> cGroupNoArrow <*> (sym "->" *> cBlock) + alts <- onePerLine cAlt return $ CCase scrut alts --- We support the following syntaxes for `if`: --- - 1-armed then-newline --- if predicate --- then consequent --- if predicate --- then --- consequent1 --- consequent2 --- - 2-armed then-newline else-newline --- if predicate --- then consequent --- else alternate --- and the three other versions where the consequent or the --- alternate are themselves blocks --- - 1-armed then-inline --- if predicate then consequent --- if predicate then --- consequent1 --- consequent2 --- - 2-armed then-inline else-inline --- if predicate then consequent else alternate --- if predicate then consequent else --- alternate1 --- alternate2 --- - Notably, an imagined 2-armed then-inline else-newline --- if predicate then --- consequent1 --- consequent2 --- else alternate --- is not an option, because the indentation lines up badly. To wit, --- one would want the `else` to be indented relative to the `if`, --- but outdented relative to the consequent block, and if the `then` is --- inline there is no indentation level that does that. --- - Last candiate is --- if predicate --- then consequent else alternate --- if predicate --- then consequent else --- alternate1 --- alternate2 -cIf :: Parser Group' +cAlt :: Parser CaseAlt +cAlt = do + pat <- cGroupNoArrow + sym "->" + body <- cBlock + return (pat, body) + +-- see [Note if-syntax] +cIf :: Parser Group cIf = mayNotBreak do keyWord IfKW predicate <- cGroup @@ -548,14 +502,14 @@ cIf = mayNotBreak do thenSameLine :: Parser (CSBlock, Maybe CSBlock) thenSameLine = do - keyWord ThenKW + void $ keyWord ThenKW cBlock >>= \case - IndentedBlock blk -> do + IndentedBlock sid blk -> do let msg = ("No `else` may follow same-line `then` and indented consequent" ++ "; indent and align both `then` and `else`, or write the " ++ "whole `if` on one line.") mayBreak $ noElse msg - return (IndentedBlock blk, Nothing) + return (IndentedBlock sid blk, Nothing) ExprBlock ex -> do alt <- optional $ (keyWord ElseKW >> cBlock) @@ -565,17 +519,17 @@ thenSameLine = do thenNewLine :: Parser (CSBlock, Maybe CSBlock) thenNewLine = withIndent $ do - keyWord ThenKW + void $ keyWord ThenKW cBlock >>= \case - IndentedBlock blk -> do + IndentedBlock sid blk -> do alt <- do -- With `mayNotBreak`, this just forbids inline else noElse ("Same-line `else` may not follow indented consequent;" ++ " put the `else` on the next line.") optional $ do - try $ nextLine >> keyWord ElseKW + void $ try $ nextLine >> keyWord ElseKW cBlock - return (IndentedBlock blk, alt) + return (IndentedBlock sid blk, alt) ExprBlock ex -> do void $ optional $ try nextLine alt <- optional $ keyWord ElseKW >> cBlock @@ -583,59 +537,69 @@ thenNewLine = withIndent $ do noElse :: String -> Parser () noElse msg = (optional $ try $ sc >> keyWord ElseKW) >>= \case - Just () -> fail msg + Just _ -> fail msg Nothing -> return () -leafGroup :: Parser Group -leafGroup = do - leaf <- leafGroup' - postOps <- many postfixGroup - return $ foldl (\accum (op, opLhs) -> joinSrc accum opLhs $ CBin (WithSrc emptySrcPosCtx op) accum opLhs) leaf postOps +leafGroup :: Parser GroupW +leafGroup = leafGroup' >>= appendPostfixGroups where - - leafGroup' :: Parser Group - leafGroup' = withSrc do + leafGroup' :: Parser GroupW + leafGroup' = do next <- nextChar case next of - '_' -> underscore $> CHole - '(' -> (CIdentifier <$> symName) + '_' -> withSrcs $ CLeaf <$> (underscore >> pure CHole) + '(' -> toCLeaf CIdentifier <$> symName <|> cParens '[' -> cBrackets - '\"' -> CString <$> strLit - '\'' -> CChar <$> charLit + '\"' -> toCLeaf CString <$> strLit + '\'' -> toCLeaf CChar <$> charLit '%' -> do - name <- primName + WithSrc sid name <- primName case strToPrimName name of - Just prim -> CPrim prim <$> argList + Just prim -> WithSrcs sid [] <$> CPrim prim <$> argList Nothing -> fail $ "Unrecognized primitive: " ++ name - _ | isDigit next -> ( CNat <$> natLit - <|> CFloat <$> doubleLit) - '\\' -> cNullaryLam <|> cLam + _ | isDigit next -> ( toCLeaf CNat <$> natLit + <|> toCLeaf CFloat <$> doubleLit) + '\\' -> withSrcs (cNullaryLam <|> cLam) -- For exprs include for, rof, for_, rof_ - 'f' -> cFor <|> cIdentifier - 'd' -> cDo <|> cIdentifier - 'r' -> cFor <|> cIdentifier - 'c' -> cCase <|> cIdentifier - 'i' -> cIf <|> cIdentifier + 'f' -> (withSrcs cFor) <|> cIdentifier + 'd' -> (withSrcs cDo) <|> cIdentifier + 'r' -> (withSrcs cFor) <|> cIdentifier + 'c' -> (withSrcs cCase) <|> cIdentifier + 'i' -> (withSrcs cIf) <|> cIdentifier _ -> cIdentifier - postfixGroup :: Parser (Bin', Group) - postfixGroup = noGap >> - ((JuxtaposeNoSpace,) <$> withSrc cParens) - <|> ((JuxtaposeNoSpace,) <$> withSrc cBrackets) - <|> ((Dot,) <$> (try $ char '.' >> withSrc cFieldName)) + appendPostfixGroups :: GroupW -> Parser GroupW + appendPostfixGroups g = + (noGap >> appendPostfixGroup g >>= appendPostfixGroups) + <|> return g -cFieldName :: Parser Group' -cFieldName = cIdentifier <|> (CNat <$> natLit) + appendPostfixGroup :: GroupW -> Parser GroupW + appendPostfixGroup g = withSrcs $ + (CJuxtapose False g <$> cParens) + <|> (CJuxtapose False g <$> cBrackets) + <|> appendFieldAccess g -cIdentifier :: Parser Group' -cIdentifier = CIdentifier <$> anyName + appendFieldAccess :: GroupW -> Parser Group + appendFieldAccess g = try do + void $ char '.' + field <- cFieldName + return $ CBin Dot g field -cParens :: Parser Group' -cParens = CParens <$> parens (commaSep cParenGroup) +cFieldName :: Parser GroupW +cFieldName = cIdentifier <|> (toCLeaf CNat <$> natLit) -cBrackets :: Parser Group' -cBrackets = CBrackets <$> brackets (commaSep cGroup) +cIdentifier :: Parser GroupW +cIdentifier = toCLeaf CIdentifier <$> anyName + +toCLeaf :: (a -> CLeaf) -> WithSrc a -> GroupW +toCLeaf toLeaf (WithSrc sid leaf) = WithSrcs sid [] $ CLeaf $ toLeaf leaf + +cParens :: Parser GroupW +cParens = withSrcs $ CParens <$> bracketedGroup lParen rParen cParenGroup + +cBrackets :: Parser GroupW +cBrackets = withSrcs $ CBrackets <$> bracketedGroup lBracket rBracket cGroup -- A `PrecTable` is enough information to (i) remove or replace -- operators for special contexts, and (ii) build the input structure @@ -649,7 +613,7 @@ makeExprParser p tbl = Expr.makeExprParser p tbl' where withoutOp :: SourceName -> PrecTable a -> PrecTable a withoutOp op tbl = map (filter ((/= op) . fst)) tbl -ops :: PrecTable Group +ops :: PrecTable GroupW ops = [ [symOpL "!"] , [juxtaposition] @@ -670,7 +634,6 @@ ops = , [symOpN "==", symOpN "!="] , [symOpL "&&"] , [symOpL "||"] - , [unOpPre "..", unOpPre "..<", unOpPost "..", unOpPost "<.."] , [symOpR "=>"] , [arrow, symOpR "->>"] , [symOpL ">>>"] @@ -689,87 +652,79 @@ ops = , [symOpL "="] ] where other = ("other", anySymOp) - backquote = ("backquote", Expr.InfixL $ opWithSrc $ backquoteName >>= (return . binApp . EvalBinOp)) - juxtaposition = ("space", Expr.InfixL $ opWithSrc $ sc $> (binApp JuxtaposeWithSpace)) + backquote = ("backquote", Expr.InfixL backquoteOp) + juxtaposition = ("space", Expr.InfixL $ sc >> addSrcIdToBinOp (return $ CJuxtapose True)) + withClausePostfix = ("with", Expr.Postfix withClausePostfixOp) arrow = ("->", Expr.InfixR arrowOp) -opWithSrc :: Parser (SrcPos -> a -> a -> a) - -> Parser (a -> a -> a) -opWithSrc p = do - (f, pos) <- withPos p - return $ f pos -{-# INLINE opWithSrc #-} - -anySymOp :: Expr.Operator Parser Group -anySymOp = Expr.InfixL $ opWithSrc $ do +addSrcIdToBinOp :: Parser (GroupW -> GroupW -> Group) -> Parser (GroupW -> GroupW -> GroupW) +addSrcIdToBinOp op = do + f <- op + sid <- freshSrcId + return \x y -> WithSrcs sid [] $ f x y +{-# INLINE addSrcIdToBinOp #-} + +addSrcIdToUnOp :: Parser (GroupW -> Group) -> Parser (GroupW -> GroupW) +addSrcIdToUnOp op = do + f <- op + sid <- freshSrcId + return \x -> WithSrcs sid [] $ f x +{-# INLINE addSrcIdToUnOp #-} + +backquoteOp :: Parser (GroupW -> GroupW -> GroupW) +backquoteOp = binApp do + fname <- backquoteName + return $ EvalBinOp fname + +anySymOp :: Expr.Operator Parser GroupW +anySymOp = Expr.InfixL $ binApp do s <- label "infix operator" (mayBreak anySym) - return $ binApp $ interp_operator s + return $ interpOperator s -infixSym :: SourceName -> Parser () -infixSym s = mayBreak $ sym $ T.pack s +infixSym :: SourceName -> Parser SrcId +infixSym s = mayBreak $ symWithId $ T.pack s -symOpN :: SourceName -> (SourceName, Expr.Operator Parser Group) +symOpN :: SourceName -> (SourceName, Expr.Operator Parser GroupW) symOpN s = (s, Expr.InfixN $ symOp s) -symOpL :: SourceName -> (SourceName, Expr.Operator Parser Group) +symOpL :: SourceName -> (SourceName, Expr.Operator Parser GroupW) symOpL s = (s, Expr.InfixL $ symOp s) -symOpR :: SourceName -> (SourceName, Expr.Operator Parser Group) +symOpR :: SourceName -> (SourceName, Expr.Operator Parser GroupW) symOpR s = (s, Expr.InfixR $ symOp s) -symOp :: SourceName -> Parser (Group -> Group -> Group) -symOp s = opWithSrc $ do - label "infix operator" (infixSym s) - return $ binApp $ interp_operator s +symOp :: SourceName -> Parser (GroupW -> GroupW -> GroupW) +symOp s = binApp do + sid <- label "infix operator" (infixSym s) + return $ interpOperator (WithSrc sid s) -arrowOp :: Parser (Group -> Group -> Group) -arrowOp = do - WithSrc src effs <- withSrc arrowOptEffs - return \lhs rhs -> WithSrc src $ CArrow lhs effs rhs +arrowOp :: Parser (GroupW -> GroupW -> GroupW) +arrowOp = addSrcIdToBinOp do + sym "->" + optEffs <- optional cEffs + return \lhs rhs -> CArrow lhs optEffs rhs -unOpPre :: SourceName -> (SourceName, Expr.Operator Parser Group) -unOpPre s = (s, Expr.Prefix $ unOp CPrefix s) +unOpPre :: SourceName -> (SourceName, Expr.Operator Parser GroupW) +unOpPre s = (s, Expr.Prefix $ prefixOp s) -unOpPost :: SourceName -> (SourceName, Expr.Operator Parser Group) -unOpPost s = (s, Expr.Postfix $ unOp CPostfix s) +prefixOp :: SourceName -> Parser (GroupW -> GroupW) +prefixOp s = addSrcIdToUnOp do + symId <- symWithId (fromString s) + return $ CPrefix (WithSrc symId s) -unOp :: (SourceName -> Group -> Group') -> SourceName -> Parser (Group -> Group) -unOp f s = do - ((), pos) <- withPos $ sym $ fromString s - return \g@(WithSrc grpPos _) -> WithSrc (joinPos (fromPos pos) grpPos) $ f s g - -binApp :: Bin' -> SrcPos -> Group -> Group -> Group -binApp f pos x y = joinSrc3 f' x y $ CBin f' x y - where f' = WithSrc (fromPos pos) f - -withClausePostfix :: (SourceName, Expr.Operator Parser Group) -withClausePostfix = ("with", op) - where - op = Expr.Postfix do - rhs <- withClause - return \lhs -> WithSrc emptySrcPosCtx $ CWith lhs rhs -- TODO: source info +binApp :: Parser Bin -> Parser (GroupW -> GroupW -> GroupW) +binApp f = addSrcIdToBinOp $ CBin <$> f -withSrc :: Parser a -> Parser (WithSrc a) -withSrc p = do - (x, pos) <- withPos p - return $ WithSrc (fromPos pos) x +withClausePostfixOp :: Parser (GroupW -> GroupW) +withClausePostfixOp = addSrcIdToUnOp do + rhs <- withClause + return \lhs -> CWith lhs rhs -joinSrc :: WithSrc a1 -> WithSrc a2 -> a3 -> WithSrc a3 -joinSrc (WithSrc p1 _) (WithSrc p2 _) x = WithSrc (joinPos p1 p2) x - -joinSrc3 :: WithSrc a1 -> WithSrc a2 -> WithSrc a3 -> a4 -> WithSrc a4 -joinSrc3 (WithSrc p1 _) (WithSrc p2 _) (WithSrc p3 _) x = - WithSrc (concatPos [p1, p2, p3]) x - -joinPos :: SrcPosCtx -> SrcPosCtx -> SrcPosCtx -joinPos (SrcPosCtx Nothing _) c@(SrcPosCtx _ _) = c -joinPos c@(SrcPosCtx _ _) (SrcPosCtx Nothing _) = c -joinPos (SrcPosCtx (Just (l, h)) spanId) (SrcPosCtx (Just (l', h')) _) = - SrcPosCtx (Just (min l l', max h h')) spanId - -concatPos :: [SrcPosCtx] -> SrcPosCtx -concatPos [] = error "concatPos: unexpected empty [SrcPosCtx]" -concatPos (pos:rest) = foldl joinPos pos rest +withSrcs :: Parser a -> Parser (WithSrcs a) +withSrcs p = do + sid <- freshSrcId + (sids, result) <- collectAtomicLexemeIds p + return $ WithSrcs sid sids result -- === primitive constructors and operators === @@ -865,3 +820,47 @@ primNames = M.fromList unary op = UUnOp op ptrTy ty = PtrType (CPU, ty) miscOp op = UMiscOp op + +-- === notes === + +-- note [if-syntax] +-- We support the following syntaxes for `if`: +-- - 1-armed then-newline +-- if predicate +-- then consequent +-- if predicate +-- then +-- consequent1 +-- consequent2 +-- - 2-armed then-newline else-newline +-- if predicate +-- then consequent +-- else alternate +-- and the three other versions where the consequent or the +-- alternate are themselves blocks +-- - 1-armed then-inline +-- if predicate then consequent +-- if predicate then +-- consequent1 +-- consequent2 +-- - 2-armed then-inline else-inline +-- if predicate then consequent else alternate +-- if predicate then consequent else +-- alternate1 +-- alternate2 +-- - Notably, an imagined 2-armed then-inline else-newline +-- if predicate then +-- consequent1 +-- consequent2 +-- else alternate +-- is not an option, because the indentation lines up badly. To wit, +-- one would want the `else` to be indented relative to the `if`, +-- but outdented relative to the consequent block, and if the `then` is +-- inline there is no indentation level that does that. +-- - Last candiate is +-- if predicate +-- then consequent else alternate +-- if predicate +-- then consequent else +-- alternate1 +-- alternate2 diff --git a/src/lib/Core.hs b/src/lib/Core.hs index 8f5ee3acb..2c60f846e 100644 --- a/src/lib/Core.hs +++ b/src/lib/Core.hs @@ -82,7 +82,6 @@ newtype EnvReaderT (m::MonadKind) (n::S) (a:: *) = , MonadWriter w, Fallible, Alternative) type EnvReaderM = EnvReaderT Identity -type FallibleEnvReaderM = EnvReaderT FallibleM runEnvReaderM :: Distinct n => Env n -> EnvReaderM n a -> a runEnvReaderM bindings m = runIdentity $ runEnvReaderT bindings m @@ -132,10 +131,6 @@ instance MonadIO m => MonadIO (EnvReaderT m n) where deriving instance (Monad m, MonadState s m) => MonadState s (EnvReaderT m o) -instance (Monad m, CtxReader m) => CtxReader (EnvReaderT m o) where - getErrCtx = EnvReaderT $ lift getErrCtx - {-# INLINE getErrCtx #-} - instance (Monad m, Catchable m) => Catchable (EnvReaderT m o) where catchErr (EnvReaderT (ReaderT m)) f = EnvReaderT $ ReaderT \env -> m env `catchErr` \err -> runReaderT (runEnvReaderT' $ f err) env diff --git a/src/lib/Err.hs b/src/lib/Err.hs index 34e1a374c..47e279ee2 100644 --- a/src/lib/Err.hs +++ b/src/lib/Err.hs @@ -5,11 +5,8 @@ -- https://developers.google.com/open-source/licenses/bsd module Err (Err (..), ErrType (..), Except (..), - ErrCtx (..), SrcTextCtx, Fallible (..), Catchable (..), catchErrExcept, - FallibleM (..), HardFailM (..), CtxReader (..), - runFallibleM, runHardFail, throw, - addContext, addSrcContext, addSrcTextContext, + HardFailM (..), runHardFail, throw, catchIOExcept, liftExcept, liftExceptAlt, assertEq, ignoreExcept, pprint, docAsStr, getCurrentCallStack, printCurrentCallStack @@ -24,19 +21,16 @@ import Control.Monad.State.Strict import Control.Monad.Reader import Data.Coerce import Data.Foldable (fold) -import Data.Text (Text) import Data.Text qualified as T import Data.Text.Prettyprint.Doc.Render.Text import Data.Text.Prettyprint.Doc -import GHC.Generics (Generic (..)) import GHC.Stack import System.Environment import System.IO.Unsafe -import SourceInfo -- === core API === -data Err = Err ErrType ErrCtx String deriving (Show, Eq) +data Err = Err ErrType String deriving (Show, Eq) data ErrType = NoErr | ParseErr @@ -62,17 +56,8 @@ data ErrType = NoErr | SearchFailure -- used as the identity for `Alternative` instances and for MonadFail deriving (Show, Eq) -type SrcTextCtx = Maybe (Int, Text) -- Int is the offset in the source file -data ErrCtx = ErrCtx - { srcTextCtx :: SrcTextCtx - , srcPosCtx :: SrcPosCtx - , messageCtx :: [String] - , stackCtx :: Maybe [String] } - deriving (Show, Eq, Generic) - class MonadFail m => Fallible m where throwErr :: Err -> m a - addErrCtx :: ErrCtx -> m a -> m a class Fallible m => Catchable m where catchErr :: m a -> (Err -> m a) -> m a @@ -80,55 +65,14 @@ class Fallible m => Catchable m where catchErrExcept :: Catchable m => m a -> m (Except a) catchErrExcept m = catchErr (Success <$> m) (\e -> return $ Failure e) --- We have this in its own class because IO and `Except` can't implement it --- (but FallibleM can) -class Fallible m => CtxReader m where - getErrCtx :: m ErrCtx - -newtype FallibleM a = - FallibleM { fromFallibleM :: ReaderT ErrCtx Except a } - deriving (Functor, Applicative, Monad) - -instance Fallible FallibleM where - -- TODO: we end up adding the context multiple times when we do throw/catch. - -- We should fix it. - throwErr (Err errTy ctx s) = FallibleM $ ReaderT \ambientCtx -> - throwErr $ Err errTy (ambientCtx <> ctx) s - {-# INLINE throwErr #-} - addErrCtx ctx (FallibleM m) = FallibleM $ local (<> ctx) m - {-# INLINE addErrCtx #-} - -instance Catchable FallibleM where - FallibleM m `catchErr` handler = FallibleM $ ReaderT \ctx -> - case runReaderT m ctx of - Failure errs -> runReaderT (fromFallibleM $ handler errs) ctx - Success ans -> return ans - -instance CtxReader FallibleM where - getErrCtx = FallibleM ask - {-# INLINE getErrCtx #-} - -instance Alternative FallibleM where - empty = throw SearchFailure "" - {-# INLINE empty #-} - m1 <|> m2 = do - catchSearchFailure m1 >>= \case - Nothing -> m2 - Just x -> return x - {-# INLINE (<|>) #-} - catchSearchFailure :: Catchable m => m a -> m (Maybe a) catchSearchFailure m = (Just <$> m) `catchErr` \case - Err SearchFailure _ _ -> return Nothing + Err SearchFailure _ -> return Nothing err -> throwErr err instance Fallible IO where throwErr errs = throwIO errs {-# INLINE throwErr #-} - addErrCtx ctx m = do - result <- catchIOExcept m - liftExcept $ addErrCtx ctx result - {-# INLINE addErrCtx #-} instance Catchable IO where catchErr cont handler = @@ -136,10 +80,6 @@ instance Catchable IO where Success result -> return result Failure errs -> handler errs -runFallibleM :: FallibleM a -> Except a -runFallibleM m = runReaderT (fromFallibleM m) mempty -{-# INLINE runFallibleM #-} - -- === Except type === -- Except is isomorphic to `Either Err` but having a distinct type makes it @@ -167,6 +107,20 @@ instance Monad Except where Success x >>= f = f x {-# INLINE (>>=) #-} +instance Alternative Except where + empty = throw SearchFailure "" + {-# INLINE empty #-} + m1 <|> m2 = do + catchSearchFailure m1 >>= \case + Nothing -> m2 + Just x -> return x + {-# INLINE (<|>) #-} + +instance Catchable Except where + Success ans `catchErr` _ = Success ans + Failure errs `catchErr` f = f errs + {-# INLINE catchErr #-} + -- === HardFail === -- Implements Fallible by crashing. Used in type querying when we want to avoid @@ -205,24 +159,13 @@ instance MonadFail HardFailM where instance Fallible HardFailM where throwErr errs = error $ pprint errs {-# INLINE throwErr #-} - addErrCtx _ cont = cont - {-# INLINE addErrCtx #-} -- === convenience layer === throw :: Fallible m => ErrType -> String -> m a -throw errTy s = throwErr $ addCompilerStackCtx $ Err errTy mempty s +throw errTy s = throwErr $ Err errTy s {-# INLINE throw #-} -addCompilerStackCtx :: Err -> Err -addCompilerStackCtx (Err ty ctx msg) = Err ty ctx{stackCtx = compilerStack} msg - where -#ifdef DEX_DEBUG - compilerStack = getCurrentCallStack () -#else - compilerStack = stackCtx ctx -#endif - getCurrentCallStack :: () -> Maybe [String] getCurrentCallStack () = #ifdef DEX_DEBUG @@ -238,27 +181,15 @@ printCurrentCallStack :: Maybe [String] -> String printCurrentCallStack Nothing = "" printCurrentCallStack (Just frames) = fold frames -addContext :: Fallible m => String -> m a -> m a -addContext s m = addErrCtx (mempty {messageCtx = [s]}) m -{-# INLINE addContext #-} - -addSrcContext :: Fallible m => SrcPosCtx -> m a -> m a -addSrcContext ctx m = addErrCtx (mempty {srcPosCtx = ctx}) m -{-# INLINE addSrcContext #-} - -addSrcTextContext :: Fallible m => Int -> Text -> m a -> m a -addSrcTextContext offset text m = - addErrCtx (mempty {srcTextCtx = Just (offset, text)}) m - catchIOExcept :: MonadIO m => IO a -> m (Except a) catchIOExcept m = liftIO $ (liftM Success m) `catches` [ Handler \(e::Err) -> return $ Failure e - , Handler \(e::IOError) -> return $ Failure $ Err DataIOErr mempty $ show e + , Handler \(e::IOError) -> return $ Failure $ Err DataIOErr $ show e -- Propagate asynchronous exceptions like ThreadKilled; they are -- part of normal operation (of the live evaluation modes), not -- compiler bugs. , Handler \(e::AsyncException) -> liftIO $ throwIO e - , Handler \(e::SomeException) -> return $ Failure $ Err CompilerErr mempty $ show e + , Handler \(e::SomeException) -> return $ Failure $ Err CompilerErr $ show e ] liftExcept :: Fallible m => Except a -> m a @@ -287,20 +218,14 @@ assertEq x y s = if x == y then return () instance (Monoid w, Fallible m) => Fallible (WriterT w m) where throwErr errs = lift $ throwErr errs {-# INLINE throwErr #-} - addErrCtx ctx (WriterT m) = WriterT $ addErrCtx ctx m - {-# INLINE addErrCtx #-} instance Fallible [] where throwErr _ = [] {-# INLINE throwErr #-} - addErrCtx _ m = m - {-# INLINE addErrCtx #-} instance Fallible Maybe where throwErr _ = Nothing {-# INLINE throwErr #-} - addErrCtx _ m = m - {-# INLINE addErrCtx #-} -- === small pretty-printing utils === -- These are here instead of in PPrint.hs for import cycle reasons @@ -318,46 +243,18 @@ layout = if unbounded then LayoutOptions Unbounded else defaultLayoutOptions -- === instances === -instance MonadFail FallibleM where - fail s = throw SearchFailure s - {-# INLINE fail #-} - instance Fallible Except where throwErr errs = Failure errs {-# INLINE throwErr #-} - addErrCtx _ (Success ans) = Success ans - addErrCtx ctx (Failure (Err errTy ctx' s)) = - Failure $ Err errTy (ctx <> ctx') s - {-# INLINE addErrCtx #-} - instance MonadFail Except where - fail s = Failure $ Err SearchFailure mempty s + fail s = Failure $ Err SearchFailure s {-# INLINE fail #-} instance Exception Err instance Pretty Err where - pretty (Err e ctx s) = pretty e <> pretty s <> prettyCtx - -- TODO: figure out a more uniform way to newlines - where prettyCtx = case ctx of - ErrCtx _ (SrcPosCtx Nothing _) [] Nothing -> mempty - _ -> hardline <> pretty ctx - -instance Pretty ErrCtx where - pretty (ErrCtx maybeTextCtx maybePosCtx messages stack) = - -- The order of messages is outer-scope-to-inner-scope, but we want to print - -- them starting the other way around (Not for a good reason. It's just what - -- we've always done.) - prettyLines (reverse messages) <> highlightedSource <> prettyStack - where - highlightedSource = case (maybeTextCtx, maybePosCtx) of - (Just (offset, text), SrcPosCtx (Just (start, stop)) _) -> - hardline <> pretty (highlightRegion (start - offset, stop - offset) text) - _ -> mempty - prettyStack = case stack of - Nothing -> mempty - Just s -> hardline <> "Compiler stack trace:" <> nest 2 (hardline <> prettyLines s) + pretty (Err e s) = pretty e <> pretty s instance Pretty a => Pretty (Except a) where pretty (Success x) = "Success:" <+> pretty x @@ -397,80 +294,15 @@ instance Pretty ErrType where instance Fallible m => Fallible (ReaderT r m) where throwErr errs = lift $ throwErr errs {-# INLINE throwErr #-} - addErrCtx ctx (ReaderT f) = ReaderT \r -> addErrCtx ctx $ f r - {-# INLINE addErrCtx #-} instance Catchable m => Catchable (ReaderT r m) where ReaderT f `catchErr` handler = ReaderT \r -> f r `catchErr` \e -> runReaderT (handler e) r -instance CtxReader m => CtxReader (ReaderT r m) where - getErrCtx = lift getErrCtx - {-# INLINE getErrCtx #-} - instance Fallible m => Fallible (StateT s m) where throwErr errs = lift $ throwErr errs {-# INLINE throwErr #-} - addErrCtx ctx (StateT f) = StateT \s -> addErrCtx ctx $ f s - {-# INLINE addErrCtx #-} instance Catchable m => Catchable (StateT s m) where StateT f `catchErr` handler = StateT \s -> f s `catchErr` \e -> runStateT (handler e) s - -instance CtxReader m => CtxReader (StateT s m) where - getErrCtx = lift getErrCtx - {-# INLINE getErrCtx #-} - -instance Semigroup ErrCtx where - ErrCtx text (SrcPosCtx p spanId) ctxStrs stk <> ErrCtx text' (SrcPosCtx p' spanId') ctxStrs' stk' = - ErrCtx (leftmostJust text text') - (SrcPosCtx (rightmostJust p p') (rightmostJust spanId spanId')) - (ctxStrs <> ctxStrs') - (leftmostJust stk stk') -- We usually extend errors form the right - -instance Monoid ErrCtx where - mempty = ErrCtx Nothing emptySrcPosCtx [] Nothing - --- === misc util stuff === - -leftmostJust :: Maybe a -> Maybe a -> Maybe a -leftmostJust (Just x) _ = Just x -leftmostJust Nothing y = y - -rightmostJust :: Maybe a -> Maybe a -> Maybe a -rightmostJust = flip leftmostJust - -prettyLines :: (Foldable f, Pretty a) => f a -> Doc ann -prettyLines xs = foldMap (\d -> pretty d <> hardline) xs - -highlightRegion :: (Int, Int) -> Text -> Text -highlightRegion pos@(low, high) s - | low > high || high > T.length s = - error $ "Bad region: \n" ++ show pos ++ "\n" ++ T.unpack s - | otherwise = - -- TODO: flag to control line numbers - -- (disabling for now because it makes quine tests tricky) - -- "Line " ++ show (1 + lineNum) ++ "\n" - allLines !! lineNum <> "\n" <> - T.replicate start " " <> T.replicate (stop - start) "^" <> "\n" - where - allLines = T.lines s - (lineNum, start, stop) = getPosTriple pos allLines - -getPosTriple :: (Int, Int) -> [Text] -> (Int, Int, Int) -getPosTriple (start, stop) lines_ = (lineNum, start - offset, stop') - where - lineLengths = map ((+1) . T.length) lines_ - lineOffsets = cumsum lineLengths - lineNum = maxLT lineOffsets start - offset = lineOffsets !! lineNum - stop' = min (stop - offset) (lineLengths !! lineNum) - -cumsum :: [Int] -> [Int] -cumsum xs = scanl (+) 0 xs - -maxLT :: Ord a => [a] -> a -> Int -maxLT [] _ = 0 -maxLT (x:xs) n = if n < x then -1 - else 1 + maxLT xs n diff --git a/src/lib/Export.hs b/src/lib/Export.hs index 813874bba..7983f52cb 100644 --- a/src/lib/Export.hs +++ b/src/lib/Export.hs @@ -87,7 +87,7 @@ instance FromName (Rename r) where fromName = JustRefer newtype ExportSigM (r::IR) (i::S) (o::S) (a:: *) = ExportSigM { - runExportSigM :: SubstReaderT (Rename r) (EnvReaderT FallibleM) i o a } + runExportSigM :: SubstReaderT (Rename r) (EnvReaderT Except) i o a } deriving ( Functor, Applicative, Monad, ScopeReader, EnvExtender, Fallible , EnvReader, SubstReader (Rename r), MonadFail) @@ -95,7 +95,7 @@ liftExportSigM :: (EnvReader m, Fallible1 m) => ExportSigM r n n a -> m n a liftExportSigM cont = do Distinct <- getDistinct env <- unsafeGetEnv - liftExcept $ runFallibleM $ runEnvReaderT env $ runSubstReaderT idSubst $ + liftExcept $ runEnvReaderT env $ runSubstReaderT idSubst $ runExportSigM cont corePiToExportSig :: CallingConvention diff --git a/src/lib/Inference.hs b/src/lib/Inference.hs index e0b9624a1..2cd8f1056 100644 --- a/src/lib/Inference.hs +++ b/src/lib/Inference.hs @@ -36,7 +36,6 @@ import Err import IRVariants import MTL1 import Name -import SourceInfo import Subst import QueryType import Types.Core @@ -109,7 +108,7 @@ inferTopUDecl (UInstance className bs params methods maybeName expl) result = do instanceAtomName <- emitTopLet (getNameHint instanceName') PlainLet $ Atom lam applyRename (instanceName' @> atomVarName instanceAtomName) result _ -> error "impossible" -inferTopUDecl (ULocalDecl (WithSrcB src decl)) result = addSrcContext src case decl of +inferTopUDecl (ULocalDecl (WithSrcB _ decl)) result = case decl of UPass -> return $ UDeclResultDone result UExprDecl _ -> error "Shouldn't have this at the top level (should have become a command instead)" ULet letAnn p tyAnn rhs -> case p of @@ -154,10 +153,10 @@ data InfState (n::S) = InfState , infEffects :: EffectRow CoreIR n } newtype InfererM (i::S) (o::S) (a:: *) = InfererM - { runInfererM' :: SubstReaderT Name (ReaderT1 InfState (BuilderT CoreIR FallibleM)) i o a } + { runInfererM' :: SubstReaderT Name (ReaderT1 InfState (BuilderT CoreIR Except)) i o a } deriving (Functor, Applicative, Monad, MonadFail, Alternative, Builder CoreIR, EnvExtender, ScopableBuilder CoreIR, - ScopeReader, EnvReader, Fallible, Catchable, CtxReader, SubstReader Name) + ScopeReader, EnvReader, Fallible, Catchable, SubstReader Name) type InfererCPSB b i o a = (forall o'. DExt o o' => b o o' -> InfererM i o' a) -> InfererM i o a type InfererCPSB2 b i i' o a = (forall o'. DExt o o' => b o o' -> InfererM i' o' a) -> InfererM i o a @@ -165,7 +164,7 @@ type InfererCPSB2 b i i' o a = (forall o'. DExt o o' => b o o' -> InfererM i' o' liftInfererM :: (EnvReader m, Fallible1 m) => InfererM n n a -> m n a liftInfererM cont = do ansM <- liftBuilderT $ runReaderT1 emptyInfState $ runSubstReaderT idSubst $ runInfererM' cont - liftExcept $ runFallibleM ansM + liftExcept ansM where emptyInfState :: InfState n emptyInfState = InfState (Givens HM.empty) Pure @@ -276,9 +275,7 @@ withFreshUnificationVar -> (forall o'. (Emits o', DExt o o') => CAtomVar o' -> SolverM i o' (e o')) -> SolverM i o (e o) withFreshUnificationVar desc k cont = do - -- TODO: we shouldn't need the context stuff on `InfVarBound` anymore - ctx <- srcPosCtx <$> getErrCtx - withInferenceVar "_unif_" (InfVarBound k (ctx, desc)) \v -> do + withInferenceVar "_unif_" (InfVarBound k) \v -> do ans <- toAtomVar v >>= cont soln <- (M.lookup v <$> fromSolverSubst <$> getDiffState) >>= \case Just soln -> return soln @@ -377,7 +374,7 @@ topDown :: forall i o. Emits o => CType o -> UExpr i -> InfererM i o (CAtom o) topDown ty uexpr = topDownPartial (typeAsPartialType ty) uexpr topDownPartial :: Emits o => PartialType o -> UExpr i -> InfererM i o (CAtom o) -topDownPartial partialTy exprWithSrc@(WithSrcE pos expr) = addSrcContext pos $ +topDownPartial partialTy exprWithSrc@(WithSrcE _ expr) = case partialTy of PartialType partialPiTy -> case expr of ULam lam -> toAtom <$> Lam <$> checkULamPartial partialPiTy lam @@ -404,7 +401,7 @@ etaExpandPartialPi (PartialPiType appExpl expls bs effs reqTy) cont = do -- Doesn't introduce implicit pi binders or dependent pairs topDownExplicit :: forall i o. Emits o => CType o -> UExpr i -> InfererM i o (CAtom o) -topDownExplicit reqTy exprWithSrc@(WithSrcE pos expr) = addSrcContext pos case expr of +topDownExplicit reqTy exprWithSrc@(WithSrcE _ expr) = case expr of ULam lamExpr -> case reqTy of TyCon (Pi piTy) -> toAtom <$> Lam <$> checkULam lamExpr piTy _ -> throw TypeErr $ "Unexpected lambda. Expected: " ++ pprint reqTy @@ -464,13 +461,13 @@ bottomUp expr = bottomUpExplicit expr >>= instantiateSigma Infer -- Doesn't instantiate implicit args bottomUpExplicit :: Emits o => UExpr i -> InfererM i o (SigmaAtom o) -bottomUpExplicit (WithSrcE pos expr) = addSrcContext pos case expr of +bottomUpExplicit (WithSrcE _ expr) = case expr of UVar ~(InternalName _ sn v) -> do v' <- renameM v ty <- getUVarType v' return $ SigmaUVar sn ty v' ULit l -> return $ SigmaAtom Nothing $ Con $ Lit l - UFieldAccess x (WithSrc pos' field) -> addSrcContext pos' do + UFieldAccess x (WithSrc _ field) -> do x' <- bottomUp x ty <- return $ getType x' fields <- getFieldDefs ty @@ -494,7 +491,7 @@ bottomUpExplicit (WithSrcE pos expr) = addSrcContext pos case expr of SigmaAtom Nothing <$> checkOrInferApp f' posArgs namedArgs Infer UTabApp tab args -> do tab' <- bottomUp tab - SigmaAtom Nothing <$> inferTabApp (srcPos tab) tab' args + SigmaAtom Nothing <$> inferTabApp tab' args UPi (UPiExpr bs appExpl effs ty) -> do -- TODO: check explicitness constraints withUBinders bs \(ZipB expls bs') -> do @@ -608,8 +605,8 @@ withBlockDecls => UBlock i -> (forall i' o'. (Emits o', DExt o o') => UExpr i' -> InfererM i' o' (e o')) -> InfererM i o (e o) -withBlockDecls (WithSrcE src (UBlock declsTop result)) contTop = - addSrcContext src $ go declsTop $ contTop result where +withBlockDecls (WithSrcE _ (UBlock declsTop result)) contTop = + go declsTop $ contTop result where go :: (Emits o, Zonkable e) => Nest UDecl i i' -> (forall o'. (Emits o', DExt o o') => InfererM i' o' (e o')) @@ -623,7 +620,7 @@ withUDecl => UDecl i i' -> (forall o'. (Emits o', DExt o o') => InfererM i' o' (e o')) -> InfererM i o (e o) -withUDecl (WithSrcB src d) cont = addSrcContext src case d of +withUDecl (WithSrcB _ d) cont = case d of UPass -> withDistinct cont UExprDecl e -> withDistinct $ bottomUp e >> cont ULet letAnn p ann rhs -> do @@ -1095,8 +1092,8 @@ pattern ExplicitCoreLam bs body <- Con (Lam (CoreLamExpr _ (LamExpr bs body))) -- === n-ary applications === -inferTabApp :: Emits o => SrcPosCtx -> CAtom o -> [UExpr i] -> InfererM i o (CAtom o) -inferTabApp tabCtx tab args = addSrcContext tabCtx do +inferTabApp :: Emits o => CAtom o -> [UExpr i] -> InfererM i o (CAtom o) +inferTabApp tab args = do tabTy <- return $ getType tab args' <- inferNaryTabAppArgs tabTy args naryTabApp tab args' @@ -1116,8 +1113,7 @@ inferNaryTabAppArgs tabTy (arg:rest) = case tabTy of _ -> throw TypeErr $ "Expected a table type but got: " ++ pprint tabTy checkSigmaDependent :: UExpr i -> PartialType o -> InfererM i o (CAtom o) -checkSigmaDependent e@(WithSrcE ctx _) ty = addSrcContext ctx $ - withReducibleEmissions depFunErrMsg $ topDownPartial (sink ty) e +checkSigmaDependent e ty = withReducibleEmissions depFunErrMsg $ topDownPartial (sink ty) e where depFunErrMsg = "Dependent functions can only be applied to fully evaluated expressions. " ++ @@ -1200,7 +1196,7 @@ inferStructDef (UStructDef tyConName paramBs fields _) = do withRoleUBinders paramBs \(ZipB roleExpls paramBs') -> do let (fieldNames, fieldTys) = unzip fields tys <- mapM checkUType fieldTys - let dataConDefs = StructFields $ zip fieldNames tys + let dataConDefs = StructFields $ zip (withoutSrc <$> fieldNames) tys return $ TyConDef tyConName roleExpls paramBs' dataConDefs inferDotMethod @@ -1470,7 +1466,7 @@ superclassDictTys (Nest b bs) = do (binderType b:) <$> superclassDictTys bs' checkMethodDef :: ClassName o -> [CorePiType o] -> UMethodDef i -> InfererM i o (Int, CAtom o) -checkMethodDef className methodTys (WithSrcE src m) = addSrcContext src do +checkMethodDef className methodTys (WithSrcE _ m) = do UMethodDef ~(InternalName _ sourceName v) rhs <- return m MethodBinding className' i <- renameM v >>= lookupEnv when (className /= className') do @@ -1483,7 +1479,7 @@ checkUEffRow (UEffectRow effs t) = do effs' <- liftM eSetFromList $ mapM checkUEff $ toList effs t' <- case t of Nothing -> return NoTail - Just (~(SIInternalName _ v _ _)) -> do + Just (SourceOrInternalName ~(InternalName _ _ v)) -> do v' <- toAtomVar =<< renameM v expectEq EffKind (getType v') return $ EffectRowTail v' @@ -1491,7 +1487,7 @@ checkUEffRow (UEffectRow effs t) = do checkUEff :: UEffect i -> InfererM i o (Effect CoreIR o) checkUEff eff = case eff of - URWSEffect rws (~(SIInternalName _ region _ _)) -> do + URWSEffect rws (SourceOrInternalName ~(InternalName _ _ region)) -> do region' <- renameM region >>= toAtomVar expectEq (TyCon HeapType) (getType region') return $ RWSEffect rws (toAtom region') @@ -1519,7 +1515,7 @@ checkCasePat => UPat i i' -> CType o -> (forall o'. (Emits o', Ext o o') => InfererM i' o' (CAtom o')) -> InfererM i o (Alt CoreIR o) -checkCasePat (WithSrcB pos pat) scrutineeTy cont = addSrcContext pos $ case pat of +checkCasePat (WithSrcB _ pat) scrutineeTy cont = case pat of UPatCon ~(InternalName _ _ conName) ps -> do (dataDefName, con) <- renameM conName >>= lookupDataCon tyConDef <- lookupTyCon dataDefName @@ -1563,7 +1559,7 @@ bindLetPat => UPat i i' -> CAtomVar o -> (forall o'. (Emits o', DExt o o') => InfererM i' o' (e o')) -> InfererM i o (e o) -bindLetPat (WithSrcB pos pat) v cont = addSrcContext pos $ case pat of +bindLetPat (WithSrcB _ pat) v cont = case pat of UPatBinder b -> getDistinct >>= \Distinct -> extendSubst (b @> atomVarName v) cont UPatProd ps -> do let n = nestLength ps @@ -1613,7 +1609,7 @@ checkUType t = do return t' checkUParam :: Kind CoreIR o -> UType i -> InfererM i o (CAtom o) -checkUParam k uty@(WithSrcE pos _) = addSrcContext pos $ +checkUParam k uty = withReducibleEmissions msg $ withAllowedEffects Pure $ topDownExplicit (sink k) uty where msg = "Can't reduce type expression: " ++ pprint uty @@ -1963,7 +1959,7 @@ extendSolution (AtomVar v _) t = isUnificationName :: EnvReader m => CAtomName n -> m n Bool isUnificationName v = lookupEnv v >>= \case - AtomNameBinding (SolverBound (InfVarBound _ _)) -> return True + AtomNameBinding (SolverBound (InfVarBound _)) -> return True _ -> return False {-# INLINE isUnificationName #-} @@ -2126,7 +2122,7 @@ typeAsSynthType = \case TyCon (DictTy dictTy) -> return $ SynthDictType dictTy TyCon (Pi (CorePiType ImplicitApp expls bs (EffTy Pure (TyCon (DictTy d))))) -> return $ SynthPiType (expls, Abs bs d) - ty -> Failure $ Err TypeErr mempty $ "Can't synthesize terms of type: " ++ pprint ty + ty -> Failure $ Err TypeErr $ "Can't synthesize terms of type: " ++ pprint ty {-# SCC typeAsSynthType #-} getSuperclassClosure :: EnvReader m => Givens n -> [SynthAtom n] -> m n (Givens n) @@ -2250,7 +2246,7 @@ emptyMixedArgs :: MixedArgs (CAtom n) emptyMixedArgs = ([], []) typeErrAsSearchFailure :: InfererM i n a -> InfererM i n a -typeErrAsSearchFailure cont = cont `catchErr` \err@(Err errTy _ _) -> do +typeErrAsSearchFailure cont = cont `catchErr` \err@(Err errTy _) -> do case errTy of TypeErr -> empty _ -> throwErr err diff --git a/src/lib/Lexing.hs b/src/lib/Lexing.hs index 5854b743f..37613d05a 100644 --- a/src/lib/Lexing.hs +++ b/src/lib/Lexing.hs @@ -24,7 +24,6 @@ import qualified Text.Megaparsec.Char.Lexer as L import Text.Megaparsec.Debug import Err -import SourceInfo import Types.Primitives import Types.Source import Util (toSnocList) @@ -34,10 +33,11 @@ data ParseCtx = ParseCtx , canBreak :: Bool -- used Reader-style (i.e. ask/local) , prevWhitespace :: Bool -- tracks whether we just consumed whitespace , sourceIdCounter :: Int + , curAtomicLexemes :: [SrcId] , curSourceMap :: SourceMaps } -- append to, writer-style initParseCtx :: ParseCtx -initParseCtx = ParseCtx 0 False False 0 mempty +initParseCtx = ParseCtx 0 False False 0 mempty mempty type Parser = StateT ParseCtx (Parsec Void Text) @@ -67,12 +67,15 @@ nextChar = do return $ T.head i {-# INLINE nextChar #-} -anyCaseName :: Lexer SourceName -anyCaseName = label "name" $ lexeme LowerName $ -- TODO: distinguish lowercase/uppercase +anyCaseName :: Lexer (WithSrc SourceName) +anyCaseName = label "name" $ lexeme LowerName anyCaseName' -- TODO: distinguish lowercase/uppercase + +anyCaseName' :: Lexer SourceName +anyCaseName' = checkNotKeyword $ (:) <$> satisfy (\c -> isLower c || isUpper c) <*> (T.unpack <$> takeWhileP Nothing (\c -> isAlphaNum c || c == '\'' || c == '_')) -anyName :: Lexer SourceName +anyName :: Lexer (WithSrc SourceName) anyName = anyCaseName <|> symName checkNotKeyword :: Parser String -> Parser String @@ -125,8 +128,11 @@ keyWordToken = \case PassKW -> "pass" keyWord :: KeyWord -> Lexer () -keyWord kw = lexeme Keyword $ try $ string (fromString $ keyWordToken kw) - >> notFollowedBy nameTailChar +keyWord kw = atomicLexeme Keyword $ try $ + string (fromString $ keyWordToken kw) >> notFollowedBy nameTailChar + where + nameTailChar :: Parser Char + nameTailChar = alphaNumChar <|> char '\'' <|> char '_' keyWordSet :: HS.HashSet String keyWordSet = HS.fromList keyWordStrs @@ -134,19 +140,19 @@ keyWordSet = HS.fromList keyWordStrs keyWordStrs :: [String] keyWordStrs = map keyWordToken [DefKW .. PassKW] -primName :: Lexer String +primName :: Lexer (WithSrc String) primName = lexeme MiscLexeme $ try $ char '%' >> some alphaNumChar -charLit :: Lexer Char +charLit :: Lexer (WithSrc Char) charLit = lexeme MiscLexeme $ char '\'' >> L.charLiteral <* char '\'' -strLit :: Lexer String +strLit :: Lexer (WithSrc String) strLit = lexeme StringLiteralLexeme $ char '"' >> manyTill L.charLiteral (char '"') -natLit :: Lexer Word64 +natLit :: Lexer (WithSrc Word64) natLit = lexeme LiteralLexeme $ try $ L.decimal <* notFollowedBy (char '.') -doubleLit :: Lexer Double +doubleLit :: Lexer (WithSrc Double) doubleLit = lexeme LiteralLexeme $ try L.float <|> try (fromIntegral <$> (L.decimal :: Parser Int) <* char '.') @@ -163,24 +169,30 @@ knownSymStrs = HS.fromList , "->", "->>", "=>", "?->", "?=>", "--o", "--", "<<<", ">>>" , "..", "<..", "..<", "..<", "<..<", "?", "#", "##", "#?", "#&", "#|", "@"] --- string must be in `knownSymStrs` sym :: Text -> Lexer () -sym s = lexeme Symbol $ try $ string s >> notFollowedBy symChar +sym s = atomicLexeme Symbol $ sym' s + +symWithId :: Text -> Lexer SrcId +symWithId s = liftM srcPos $ lexeme Symbol $ sym' s -anySym :: Lexer String +-- string must be in `knownSymStrs` +sym' :: Text -> Lexer () +sym' s = void $ try $ string s >> notFollowedBy symChar + +anySym :: Lexer (WithSrc String) anySym = lexeme Symbol $ try $ do s <- some symChar failIf (s `HS.member` knownSymStrs) "" return s -symName :: Lexer SourceName +symName :: Lexer (WithSrc SourceName) symName = label "symbol name" $ lexeme Symbol $ try $ do s <- between (char '(') (char ')') $ some symChar return $ "(" <> s <> ")" -backquoteName :: Lexer SourceName +backquoteName :: Lexer (WithSrc SourceName) backquoteName = label "backquoted name" $ - lexeme Symbol $ try $ between (char '`') (char '`') anyCaseName + lexeme Symbol $ try $ between (char '`') (char '`') anyCaseName' -- brackets and punctuation -- (can't treat as sym because e.g. `((` is two separate lexemes) @@ -196,10 +208,7 @@ semicolon = charLexeme ';' underscore = charLexeme '_' charLexeme :: Char -> Parser () -charLexeme c = void $ lexeme Symbol $ char c - -nameTailChar :: Parser Char -nameTailChar = alphaNumChar <|> char '\'' <|> char '_' +charLexeme c = atomicLexeme Symbol $ void $ char c symChar :: Parser Char symChar = token (\c -> if HS.member c symChars then Just c else Nothing) mempty @@ -247,35 +256,23 @@ recordNonWhitespace = modify \ctx -> ctx { prevWhitespace = False } {-# INLINE recordNonWhitespace #-} nameString :: Parser String -nameString = lexeme LowerName . try $ (:) <$> lowerChar <*> many alphaNumChar +nameString = lexemeIgnoreSrcId LowerName . try $ (:) <$> lowerChar <*> many alphaNumChar thisNameString :: Text -> Parser () -thisNameString s = lexeme MiscLexeme $ try $ string s >> notFollowedBy alphaNumChar +thisNameString s = lexemeIgnoreSrcId MiscLexeme $ try $ string s >> notFollowedBy alphaNumChar bracketed :: Parser () -> Parser () -> Parser a -> Parser a -bracketed left right p = between left right $ mayBreak $ sc >> p +bracketed left right p = do + left + ans <- mayBreak $ sc >> p + right + return ans {-# INLINE bracketed #-} -parens :: Parser a -> Parser a -parens p = bracketed lParen rParen p -{-# INLINE parens #-} - -brackets :: Parser a -> Parser a -brackets p = bracketed lBracket rBracket p -{-# INLINE brackets #-} - braces :: Parser a -> Parser a braces p = bracketed lBrace rBrace p {-# INLINE braces #-} -withPos :: Parser a -> Parser (a, SrcPos) -withPos p = do - n <- getOffset - x <- p - n' <- getOffset - return $ (x, (n, n')) -{-# INLINE withPos #-} - nextLine :: Parser () nextLine = do eol @@ -286,7 +283,9 @@ nextLine = do withSource :: Parser a -> Parser (Text, a) withSource p = do s <- getInput - (x, (start, end)) <- withPos p + start <- getOffset + x <- p + end <- getOffset return (T.take (end - start) s, x) {-# INLINE withSource #-} @@ -314,11 +313,11 @@ failIf :: Bool -> String -> Parser () failIf True s = fail s failIf False _ = return () -newSourceId :: Parser SourceId -newSourceId = do +freshSrcId :: Parser SrcId +freshSrcId = do c <- gets sourceIdCounter modify \ctx -> ctx { sourceIdCounter = c + 1 } - return $ SourceId c + return $ SrcId c withSourceMaps :: Parser a -> Parser (SourceMaps, a) withSourceMaps cont = do @@ -332,19 +331,37 @@ withSourceMaps cont = do emitSourceMaps :: SourceMaps -> Parser () emitSourceMaps m = modify \ctx -> ctx { curSourceMap = curSourceMap ctx <> m } -lexeme :: LexemeType -> Parser a -> Parser a +lexemeIgnoreSrcId :: LexemeType -> Parser a -> Parser a +lexemeIgnoreSrcId lexemeType p = withoutSrc <$> lexeme lexemeType p + +symbol :: Text -> Parser () +symbol s = void $ L.symbol sc s + +lexeme :: LexemeType -> Parser a -> Parser (WithSrc a) lexeme lexemeType p = do start <- getOffset ans <- p end <- getOffset recordNonWhitespace sc - name <- newSourceId + sid <- freshSrcId emitSourceMaps $ mempty - { lexemeList = toSnocList [name] - , lexemeInfo = M.singleton name (lexemeType, (start, end)) } - return ans + { lexemeList = toSnocList [sid] + , lexemeInfo = M.singleton sid (lexemeType, (start, end)) } + return $ WithSrc sid ans {-# INLINE lexeme #-} -symbol :: Text -> Parser () -symbol s = void $ L.symbol sc s +atomicLexeme :: LexemeType -> Parser () -> Parser () +atomicLexeme lexemeType p = do + WithSrc sid () <- lexeme lexemeType p + modify \ctx -> ctx { curAtomicLexemes = curAtomicLexemes ctx ++ [sid] } +{-# INLINE atomicLexeme #-} + +collectAtomicLexemeIds :: Parser a -> Parser ([SrcId], a) +collectAtomicLexemeIds p = do + prevAtomicLexemes <- gets curAtomicLexemes + modify \ctx -> ctx { curAtomicLexemes = [] } + ans <- p + localLexemes <- gets curAtomicLexemes + modify \ctx -> ctx { curAtomicLexemes = prevAtomicLexemes } + return (localLexemes, ans) diff --git a/src/lib/Live/Web.hs b/src/lib/Live/Web.hs index 820b5d4ae..ae7273028 100644 --- a/src/lib/Live/Web.hs +++ b/src/lib/Live/Web.hs @@ -18,7 +18,7 @@ import Data.Binary.Builder (fromByteString) import Data.ByteString.Lazy (toStrict) import qualified Data.ByteString as BS -import Paths_dex (getDataFileName) +-- import Paths_dex (getDataFileName) import Live.Eval import TopLevel diff --git a/src/lib/MTL1.hs b/src/lib/MTL1.hs index bb14ca55c..2011fa64d 100644 --- a/src/lib/MTL1.hs +++ b/src/lib/MTL1.hs @@ -12,7 +12,6 @@ import Control.Monad.Reader import Control.Monad.Writer.Class import Control.Monad.State.Strict import Control.Monad.Trans.Maybe -import qualified Control.Monad.Trans.Except as MTE import Control.Applicative import Data.Foldable (toList) @@ -137,17 +136,10 @@ instance (SinkableE r, EnvExtender m) => EnvExtender (ReaderT1 r m) where instance (Monad1 m, Fallible (m n)) => Fallible (ReaderT1 r m n) where throwErr = lift11 . throwErr - addErrCtx ctx (ReaderT1 m) = ReaderT1 $ addErrCtx ctx m - {-# INLINE addErrCtx #-} instance (Monad1 m, Catchable (m n)) => Catchable (ReaderT1 s m n) where catchErr (ReaderT1 m) f = ReaderT1 $ catchErr m (runReaderT1' . f) -instance (Monad1 m, CtxReader (m n)) => CtxReader (ReaderT1 s m n) where - getErrCtx = lift11 getErrCtx - {-# INLINE getErrCtx #-} - - -------------------- StateT1 -------------------- newtype StateT1 (s :: E) (m :: MonadKind1) (n :: S) (a :: *) = @@ -194,16 +186,10 @@ instance (SinkableE s, ScopeReader m) => ScopeReader (StateT1 s m) where instance (Monad1 m, Fallible (m n)) => Fallible (StateT1 s m n) where throwErr = lift11 . throwErr - addErrCtx ctx (WrapStateT1 m) = WrapStateT1 $ addErrCtx ctx m - {-# INLINE addErrCtx #-} instance (Monad1 m, Catchable (m n)) => Catchable (StateT1 s m n) where catchErr (WrapStateT1 m) f = WrapStateT1 $ catchErr m (runStateT1' . f) -instance (Monad1 m, CtxReader (m n)) => CtxReader (StateT1 s m n) where - getErrCtx = lift11 getErrCtx - {-# INLINE getErrCtx #-} - instance (Monad1 m, Alternative1 m) => Alternative ((StateT1 s m) n) where empty = lift11 empty {-# INLINE empty #-} @@ -259,7 +245,6 @@ runScopedT1 m s = fst <$> runStateT1 (runScopedT1' m) s deriving instance (Monad1 m, Fallible1 m) => Fallible (ScopedT1 s m n) deriving instance (Monad1 m, Catchable1 m) => Catchable (ScopedT1 s m n) -deriving instance (Monad1 m, CtxReader1 m) => CtxReader (ScopedT1 s m n) instance (SinkableE s, EnvExtender m) => EnvExtender (ScopedT1 s m) where refreshAbs ab cont = ScopedT1 \s -> do @@ -286,8 +271,6 @@ instance Monad (m n) => MonadFail (MaybeT1 m n) where instance Monad (m n) => Fallible (MaybeT1 m n) where throwErr _ = empty - addErrCtx _ cont = cont - {-# INLINE addErrCtx #-} instance EnvReader m => EnvReader (MaybeT1 m) where unsafeGetEnv = lift11 unsafeGetEnv @@ -303,39 +286,6 @@ instance EnvExtender m => EnvExtender (MaybeT1 m) where refreshAbs ab cont = MaybeT1 $ MaybeT $ refreshAbs ab \b e -> runMaybeT $ runMaybeT1' $ cont b e --------------------- FallibleT1 -------------------- - -newtype FallibleT1 (m::MonadKind1) (n::S) a = - FallibleT1 { fromFallibleT :: ReaderT ErrCtx (MTE.ExceptT Err (m n)) a } - deriving (Functor, Applicative, Monad) - -runFallibleT1 :: Monad1 m => FallibleT1 m n a -> m n (Except a) -runFallibleT1 m = - MTE.runExceptT (runReaderT (fromFallibleT m) mempty) >>= \case - Right ans -> return $ Success ans - Left errs -> return $ Failure errs -{-# INLINE runFallibleT1 #-} - -instance Monad1 m => MonadFail (FallibleT1 m n) where - fail s = throw SearchFailure s - {-# INLINE fail #-} - -instance Monad1 m => Fallible (FallibleT1 m n) where - throwErr (Err errTy ctx s) = FallibleT1 $ ReaderT \ambientCtx -> - MTE.throwE $ Err errTy (ambientCtx <> ctx) s - addErrCtx ctx (FallibleT1 m) = FallibleT1 $ local (<> ctx) m - {-# INLINE addErrCtx #-} - -instance ScopeReader m => ScopeReader (FallibleT1 m) where - unsafeGetScope = FallibleT1 $ lift $ lift unsafeGetScope - {-# INLINE unsafeGetScope #-} - getDistinct = FallibleT1 $ lift $ lift $ getDistinct - {-# INLINE getDistinct #-} - -instance EnvReader m => EnvReader (FallibleT1 m) where - unsafeGetEnv = FallibleT1 $ lift $ lift unsafeGetEnv - {-# INLINE unsafeGetEnv #-} - -------------------- StreamWriter -------------------- class Monad m => StreamWriter w m | m -> w where @@ -389,7 +339,7 @@ class MonoidE d => DiffStateE (s::E) (d::E) where newtype DiffStateT1 (s::E) (d::E) (m::MonadKind1) (n::S) (a:: *) = DiffStateT1' { runDiffStateT1'' :: StateT (s n, d n) (m n) a } deriving ( Functor, Applicative, Monad, MonadFail, MonadIO - , Fallible, Catchable, CtxReader) + , Fallible, Catchable) pattern DiffStateT1 :: ((s n, d n) -> m n (a, (s n, d n))) -> DiffStateT1 s d m n a pattern DiffStateT1 cont = DiffStateT1' (StateT cont) diff --git a/src/lib/Name.hs b/src/lib/Name.hs index bf12eec1f..dc36f6c38 100644 --- a/src/lib/Name.hs +++ b/src/lib/Name.hs @@ -881,9 +881,6 @@ type MonadIO2 (m :: MonadKind2) = forall (n::S) (l::S) . MonadIO (m n l) type Catchable1 (m :: MonadKind1) = forall (n::S) . Catchable (m n ) type Catchable2 (m :: MonadKind2) = forall (n::S) (l::S) . Catchable (m n l) -type CtxReader1 (m :: MonadKind1) = forall (n::S) . CtxReader (m n ) -type CtxReader2 (m :: MonadKind2) = forall (n::S) (l::S) . CtxReader (m n l) - type MonadFail1 (m :: MonadKind1) = forall (n::S) . MonadFail (m n ) type MonadFail2 (m :: MonadKind2) = forall (n::S) (l::S) . MonadFail (m n l) @@ -1562,13 +1559,6 @@ instance (ExtOutMap bindings decls, BindsNames decls, SinkableB decls, Monad m, instance (ExtOutMap bindings decls, BindsNames decls, SinkableB decls, Monad m, Fallible m) => Fallible (InplaceT bindings decls m n) where throwErr errs = UnsafeMakeInplaceT \_ _ -> throwErr errs - addErrCtx ctx cont = UnsafeMakeInplaceT \env decls -> - addErrCtx ctx $ unsafeRunInplaceT cont env decls - {-# INLINE addErrCtx #-} - -instance (ExtOutMap bindings decls, BindsNames decls, SinkableB decls, Monad m, CtxReader m) - => CtxReader (InplaceT bindings decls m n) where - getErrCtx = lift1 getErrCtx instance ( ExtOutMap bindings decls, BindsNames decls, SinkableB decls, Monad m , Alternative m) @@ -1637,7 +1627,7 @@ newtype DoubleInplaceT (bindings::E) (d1::B) (d2::B) (m::MonadKind) (n::S) (a :: { unsafeRunDoubleInplaceT :: StateT (Scope UnsafeS, d1 UnsafeS UnsafeS) (InplaceT bindings d2 m n) a } deriving ( Functor, Applicative, Monad, MonadFail, Fallible - , CtxReader, MonadWriter w, MonadReader r, MonadIO, Catchable) + , MonadWriter w, MonadReader r, MonadIO, Catchable) liftDoubleInplaceT :: (Monad m, ExtOutMap bindings d2, OutFrag d2) diff --git a/src/lib/PPrint.hs b/src/lib/PPrint.hs index d5e8473cd..1a189ae1e 100644 --- a/src/lib/PPrint.hs +++ b/src/lib/PPrint.hs @@ -379,9 +379,9 @@ instance Pretty IxMethod where pretty method = p $ show method instance Pretty (SolverBinding n) where - pretty (InfVarBound ty _) = "Inference variable of type:" <+> p ty - pretty (SkolemBound ty ) = "Skolem variable of type:" <+> p ty - pretty (DictBound ty ) = "Dictionary variable of type:" <+> p ty + pretty (InfVarBound ty) = "Inference variable of type:" <+> p ty + pretty (SkolemBound ty) = "Skolem variable of type:" <+> p ty + pretty (DictBound ty) = "Dictionary variable of type:" <+> p ty instance Pretty (Binding c n) where pretty b = case b of @@ -510,24 +510,24 @@ instance Pretty Result where where maybeErr = case r of Failure err -> p err Success () -> mempty -instance Pretty (UBinder c n l) where pretty = prettyFromPrettyPrec -instance PrettyPrec (UBinder c n l) where +instance Pretty (UBinder' c n l) where pretty = prettyFromPrettyPrec +instance PrettyPrec (UBinder' c n l) where prettyPrec b = atPrec ArgPrec case b of - UBindSource _ v -> p v - UIgnore -> "_" - UBind _ v _ -> p v + UBindSource v -> p v + UIgnore -> "_" + UBind v _ -> p v -instance PrettyE e => Pretty (WithSrcE e n) where - pretty (WithSrcE _ x) = p x +instance Pretty e => Pretty (WithSrcs e) where pretty (WithSrcs _ _ x) = p x +instance PrettyPrec e => PrettyPrec (WithSrcs e) where prettyPrec (WithSrcs _ _ x) = prettyPrec x -instance PrettyPrecE e => PrettyPrec (WithSrcE e n) where - prettyPrec (WithSrcE _ x) = prettyPrec x +instance Pretty e => Pretty (WithSrc e) where pretty (WithSrc _ x) = p x +instance PrettyPrec e => PrettyPrec (WithSrc e) where prettyPrec (WithSrc _ x) = prettyPrec x -instance PrettyB b => Pretty (WithSrcB b n l) where - pretty (WithSrcB _ x) = p x +instance PrettyE e => Pretty (WithSrcE e n) where pretty (WithSrcE _ x) = p x +instance PrettyPrecE e => PrettyPrec (WithSrcE e n) where prettyPrec (WithSrcE _ x) = prettyPrec x -instance PrettyPrecB b => PrettyPrec (WithSrcB b n l) where - prettyPrec (WithSrcB _ x) = prettyPrec x +instance PrettyB b => Pretty (WithSrcB b n l) where pretty (WithSrcB _ x) = p x +instance PrettyPrecB b => PrettyPrec (WithSrcB b n l) where prettyPrec (WithSrcB _ x) = prettyPrec x instance PrettyE e => Pretty (SourceNameOr e n) where pretty (SourceName _ v) = p v @@ -1037,9 +1037,6 @@ instance Pretty SourceBlock' where pretty d = fromString $ show d instance Pretty CTopDecl where - pretty (WithSrc _ d) = p d - -instance Pretty CTopDecl' where pretty (CSDecl ann decl) = annDoc <> p decl where annDoc = case ann of PlainLet -> mempty @@ -1047,9 +1044,6 @@ instance Pretty CTopDecl' where pretty d = fromString $ show d instance Pretty CSDecl where - pretty (WithSrc _ d) = p d - -instance Pretty CSDecl' where pretty = undefined -- pretty (CLet pat blk) = pArg pat <+> "=" <+> p blk -- pretty (CBind pat blk) = pArg pat <+> "<-" <+> p blk @@ -1070,38 +1064,27 @@ instance Pretty AppExplicitness where pretty ImplicitApp = "->>" instance Pretty CSBlock where - pretty (IndentedBlock decls) = nest 2 $ prettyLines decls + pretty (IndentedBlock _ decls) = nest 2 $ prettyLines decls pretty (ExprBlock g) = pArg g +instance Pretty Group where pretty = prettyFromPrettyPrec instance PrettyPrec Group where - prettyPrec (WithSrc _ g) = prettyPrec g - -instance Pretty Group where - pretty = prettyFromPrettyPrec - -instance PrettyPrec Group' where - prettyPrec (CIdentifier n) = atPrec ArgPrec $ fromString n - prettyPrec (CPrim prim args) = prettyOpDefault prim args - prettyPrec (CParens blk) = - atPrec ArgPrec $ "(" <> p blk <> ")" - prettyPrec (CBrackets g) = atPrec ArgPrec $ pretty g - prettyPrec (CBin (WithSrc _ JuxtaposeWithSpace) lhs rhs) = - atPrec AppPrec $ pApp lhs <+> pArg rhs - prettyPrec (CBin op lhs rhs) = - atPrec LowestPrec $ pArg lhs <+> p op <+> pArg rhs - prettyPrec (CLambda args body) = - atPrec LowestPrec $ "\\" <> spaced args <> "." <> p body - prettyPrec (CCase scrut alts) = - atPrec LowestPrec $ "case " <> p scrut <> " of " <> prettyLines alts - prettyPrec g = atPrec ArgPrec $ fromString $ show g + prettyPrec = undefined + -- prettyPrec (CIdentifier n) = atPrec ArgPrec $ fromString n + -- prettyPrec (CPrim prim args) = prettyOpDefault prim args + -- prettyPrec (CParens blk) = + -- atPrec ArgPrec $ "(" <> p blk <> ")" + -- prettyPrec (CBrackets g) = atPrec ArgPrec $ pretty g + -- prettyPrec (CBin op lhs rhs) = + -- atPrec LowestPrec $ pArg lhs <+> p op <+> pArg rhs + -- prettyPrec (CLambda args body) = + -- atPrec LowestPrec $ "\\" <> spaced args <> "." <> p body + -- prettyPrec (CCase scrut alts) = + -- atPrec LowestPrec $ "case " <> p scrut <> " of " <> prettyLines alts + -- prettyPrec g = atPrec ArgPrec $ fromString $ show g instance Pretty Bin where - pretty (WithSrc _ b) = p b - -instance Pretty Bin' where - pretty (EvalBinOp name) = fromString name - pretty JuxtaposeWithSpace = " " - pretty JuxtaposeNoSpace = "" + pretty (EvalBinOp name) = fromString (withoutSrc name) pretty DepAmpersand = "&>" pretty Dot = "." pretty DepComma = ",>" diff --git a/src/lib/QueryType.hs b/src/lib/QueryType.hs index 5210014cf..d0fcbd93d 100644 --- a/src/lib/QueryType.hs +++ b/src/lib/QueryType.hs @@ -347,11 +347,11 @@ liftIFunType (IFunType _ argTys resultTys) = liftEnvReaderM $ go argTys where isData :: EnvReader m => Type CoreIR n -> m n Bool isData ty = do result <- liftEnvReaderT $ withSubstReaderT $ checkDataLike ty - case runFallibleM result of + case result of Success () -> return True Failure _ -> return False -checkDataLike :: Type CoreIR i -> SubstReaderT Name FallibleEnvReaderM i o () +checkDataLike :: Type CoreIR i -> SubstReaderT Name (EnvReaderT Except) i o () checkDataLike ty = case ty of StuckTy _ _ -> notData TyCon con -> case con of diff --git a/src/lib/QueryTypePure.hs b/src/lib/QueryTypePure.hs index f21a066eb..45a080f79 100644 --- a/src/lib/QueryTypePure.hs +++ b/src/lib/QueryTypePure.hs @@ -31,7 +31,7 @@ instance IRRep r => HasType r (AtomBinding r) where getType = \case LetBound (DeclBinding _ e) -> getType e MiscBound ty -> ty - SolverBound (InfVarBound ty _) -> ty + SolverBound (InfVarBound ty) -> ty SolverBound (SkolemBound ty) -> ty SolverBound (DictBound ty) -> ty NoinlineFun ty _ -> ty diff --git a/src/lib/SourceInfo.hs b/src/lib/SourceInfo.hs deleted file mode 100644 index b768af81f..000000000 --- a/src/lib/SourceInfo.hs +++ /dev/null @@ -1,47 +0,0 @@ --- Copyright 2021 Google LLC --- --- Use of this source code is governed by a BSD-style --- license that can be found in the LICENSE file or at --- https://developers.google.com/open-source/licenses/bsd - -{-# OPTIONS_GHC -Wno-incomplete-patterns #-} - -module SourceInfo ( - SrcPos, SourceId (..), SrcPosCtx (..), emptySrcPosCtx, fromPos, - pattern EmptySrcPosCtx) where - -import Data.Hashable -import Data.Store (Store (..)) -import GHC.Generics (Generic (..)) - --- === Core API === - -newtype SourceId = SourceId Int deriving (Show, Eq, Ord, Generic) - -type SrcPos = (Int, Int) - -data SrcPosCtx = SrcPosCtx (Maybe SrcPos) (Maybe SourceId) - deriving (Show, Eq, Generic) - -emptySrcPosCtx :: SrcPosCtx -emptySrcPosCtx = SrcPosCtx Nothing Nothing - -pattern EmptySrcPosCtx :: SrcPosCtx -pattern EmptySrcPosCtx = SrcPosCtx Nothing Nothing - -fromPos :: SrcPos -> SrcPosCtx -fromPos pos = SrcPosCtx (Just pos) Nothing - -instance Ord SrcPosCtx where - compare (SrcPosCtx pos spanId) (SrcPosCtx pos' spanId') = - case (pos, pos') of - (Just (l, r), Just (l', r')) -> compare (l, r', spanId) (l', r, spanId') - (Just _, _) -> GT - (_, Just _) -> LT - (_, _) -> compare spanId spanId' - -instance Hashable SourceId -instance Hashable SrcPosCtx - -instance Store SourceId -instance Store SrcPosCtx diff --git a/src/lib/SourceRename.hs b/src/lib/SourceRename.hs index f9c4abcd4..bf876df27 100644 --- a/src/lib/SourceRename.hs +++ b/src/lib/SourceRename.hs @@ -60,7 +60,7 @@ data RenamerSubst n = RenamerSubst { renamerSourceMap :: SourceMap n , renamerMayShadow :: Bool } newtype RenamerM (n::S) (a:: *) = - RenamerM { runRenamerM :: OutReaderT RenamerSubst (ScopeReaderT FallibleM) n a } + RenamerM { runRenamerM :: OutReaderT RenamerSubst (ScopeReaderT Except) n a } deriving ( Functor, Applicative, Monad, MonadFail, Fallible , ScopeReader, ScopeExtender) @@ -68,7 +68,7 @@ liftRenamer :: (EnvReader m, Fallible1 m, SinkableE e) => RenamerM n (e n) -> m liftRenamer cont = do sm <- withEnv $ envSourceMap . moduleEnv Distinct <- getDistinct - (liftExcept =<<) $ liftM runFallibleM $ liftScopeReaderT $ + (liftExcept =<<) $ liftScopeReaderT $ runOutReaderT (RenamerSubst sm False) $ runRenamerM $ cont class ( Monad1 m, ScopeReader m @@ -154,8 +154,8 @@ instance SourceRenamableE (SourceNameOr (Name c)) => SourceRenamableE (SourceOrI instance (SourceRenamableE e, SourceRenamableB b) => SourceRenamableE (Abs b e) where sourceRenameE (Abs b e) = sourceRenameB b \b' -> Abs b' <$> sourceRenameE e -instance SourceRenamableB (UBinder (AtomNameC CoreIR)) where - sourceRenameB b cont = sourceRenameUBinder UAtomVar b cont +instance SourceRenamableB (UBinder' (AtomNameC CoreIR)) where + sourceRenameB b cont = sourceRenameUBinder' UAtomVar b cont instance SourceRenamableE UAnn where sourceRenameE UNoAnn = return UNoAnn @@ -218,12 +218,10 @@ instance SourceRenamableE UEffect where sourceRenameE UIOEffect = return UIOEffect instance SourceRenamableE a => SourceRenamableE (WithSrcE a) where - sourceRenameE (WithSrcE pos e) = addSrcContext pos $ - WithSrcE pos <$> sourceRenameE e + sourceRenameE (WithSrcE pos e) = WithSrcE pos <$> sourceRenameE e instance SourceRenamableB a => SourceRenamableB (WithSrcB a) where - sourceRenameB (WithSrcB pos b) cont = addSrcContext pos $ - sourceRenameB b \b' -> cont $ WithSrcB pos b' + sourceRenameB (WithSrcB pos b) cont = sourceRenameB b \b' -> cont $ WithSrcB pos b' instance SourceRenamableB UTopDecl where sourceRenameB decl cont = case decl of @@ -302,13 +300,13 @@ sourceRenameUBinderNest asUVar (Nest b bs) cont = sourceRenameUBinderNest asUVar bs \bs' -> cont $ Nest b' bs' -sourceRenameUBinder :: (Color c, Distinct o, Renamer m) +sourceRenameUBinder' :: (Color c, Distinct o, Renamer m) => (forall l. Name c l -> UVar l) - -> UBinder c i i' - -> (forall o'. DExt o o' => UBinder c o o' -> m o' a) + -> UBinder' c i i' + -> (forall o'. DExt o o' => UBinder' c o o' -> m o' a) -> m o a -sourceRenameUBinder asUVar ubinder cont = case ubinder of - UBindSource pos b -> do +sourceRenameUBinder' asUVar ubinder cont = case ubinder of + UBindSource b -> do SourceMap sm <- askSourceMap mayShadow <- askMayShadow let shadows = M.member b sm @@ -317,9 +315,17 @@ sourceRenameUBinder asUVar ubinder cont = case ubinder of withFreshM (getNameHint b) \freshName -> do Distinct <- getDistinct extendSourceMap b (asUVar $ binderName freshName) $ - cont $ UBind pos b freshName - UBind _ _ _ -> error "Shouldn't be source-renaming internal names" - UIgnore -> cont UIgnore + cont $ UBind b freshName + UBind _ _ -> error "Shouldn't be source-renaming internal names" + UIgnore -> cont $ UIgnore + +sourceRenameUBinder :: (Color c, Distinct o, Renamer m) + => (forall l. Name c l -> UVar l) + -> UBinder c i i' + -> (forall o'. DExt o o' => UBinder c o o' -> m o' a) + -> m o a +sourceRenameUBinder asUVar (WithSrcB sid ubinder) cont = + sourceRenameUBinder' asUVar ubinder \ubinder' -> cont (WithSrcB sid ubinder') instance SourceRenamableE UDataDef where sourceRenameE (UDataDef tyConName paramBs dataCons) = do @@ -384,14 +390,14 @@ class SourceRenamablePat (pat::B) where -> (forall o'. DExt o o' => SiblingSet -> pat o o' -> m o' a) -> m o a -instance SourceRenamablePat (UBinder (AtomNameC CoreIR)) where +instance SourceRenamablePat (UBinder' (AtomNameC CoreIR)) where sourceRenamePat sibs ubinder cont = do newSibs <- case ubinder of - UBindSource _ b -> do + UBindSource b -> do when (S.member b sibs) $ throw RepeatedPatVarErr $ pprint b return $ S.singleton b UIgnore -> return mempty - UBind _ _ _ -> error "Shouldn't be source-renaming internal names" + UBind _ _ -> error "Shouldn't be source-renaming internal names" sourceRenameB ubinder \ubinder' -> cont (sibs <> newSibs) ubinder' @@ -431,7 +437,7 @@ instance (SourceRenamablePat p1, SourceRenamablePat p2) cont sibs' $ RightB p' instance SourceRenamablePat p => SourceRenamablePat (WithSrcB p) where - sourceRenamePat sibs (WithSrcB pos pat) cont = addSrcContext pos do + sourceRenamePat sibs (WithSrcB pos pat) cont = do sourceRenamePat sibs pat \sibs' pat' -> cont sibs' $ WithSrcB pos pat' @@ -460,11 +466,11 @@ class HasSourceNames (b::B) where instance HasSourceNames UTopDecl where sourceNames decl = case decl of ULocalDecl d -> sourceNames d - UDataDefDecl _ ~(UBindSource _ tyConName) dataConNames -> do + UDataDefDecl _ ~(WithSrcB _ (UBindSource tyConName)) dataConNames -> do S.singleton tyConName <> sourceNames dataConNames - UStructDecl ~(UBindSource _ tyConName) _ -> do + UStructDecl ~(WithSrcB _ (UBindSource tyConName)) _ -> do S.singleton tyConName - UInterface _ _ ~(UBindSource _ className) methodNames -> do + UInterface _ _ ~(WithSrcB _ (UBindSource className)) methodNames -> do S.singleton className <> sourceNames methodNames UInstance _ _ _ _ instanceName _ -> sourceNames instanceName @@ -499,11 +505,11 @@ instance HasSourceNames b => HasSourceNames (Nest b)where sourceNames (Nest b rest) = sourceNames b <> sourceNames rest -instance HasSourceNames (UBinder c) where +instance HasSourceNames (UBinder' c) where sourceNames b = case b of - UBindSource _ name -> S.singleton name + UBindSource name -> S.singleton name UIgnore -> mempty - UBind {} -> error "Shouldn't be source-renaming internal names" + UBind _ _ -> error "Shouldn't be source-renaming internal names" -- === misc instance === diff --git a/src/lib/Subst.hs b/src/lib/Subst.hs index 0983908ad..06265b78a 100644 --- a/src/lib/Subst.hs +++ b/src/lib/Subst.hs @@ -334,7 +334,6 @@ deriving instance (Monad1 m, MonadFail1 m) => MonadFail (SubstReaderT v m i deriving instance (Monad1 m, Alternative1 m) => Alternative (SubstReaderT v m i o) deriving instance Fallible1 m => Fallible (SubstReaderT v m i o) deriving instance Catchable1 m => Catchable (SubstReaderT v m i o) -deriving instance CtxReader1 m => CtxReader (SubstReaderT v m i o) type ScopedSubstReader (v::V) = SubstReaderT v (ScopeReaderT Identity) :: MonadKind2 diff --git a/src/lib/TopLevel.hs b/src/lib/TopLevel.hs index c56423212..fe8e2fe9e 100644 --- a/src/lib/TopLevel.hs +++ b/src/lib/TopLevel.hs @@ -59,7 +59,6 @@ import Inline import Logging import Lower import MTL1 -import SourceInfo import Subst import Name import OccAnalysis @@ -237,12 +236,12 @@ evalSourceBlock mname block = do case resultErrs result of Failure _ -> case sbContents block of TopDecl decl -> do - case runFallibleM (parseDecl decl) of + case parseDecl decl of Success decl' -> emitSourceMap $ uDeclErrSourceMap mname decl' Failure _ -> return () _ -> return () _ -> return () - return $ filterLogs block $ addResultCtx block result + return $ filterLogs block result evalSourceBlock' :: (Topper m, Mut n) => ModuleSourceName -> SourceBlock -> m n () @@ -267,7 +266,7 @@ evalSourceBlock' mname block = case sbContents block of s <- getDexString stringVal logTop $ TextOut s RenderHtml -> do - stringVal <- evalUExpr $ addTypeAnn expr (referTo "String") + stringVal <- evalUExpr $ addTypeAnn expr (referTo $ WithSrc (srcPos expr) "String") s <- getDexString stringVal logTop $ HtmlOut s ExportFun _ -> error "not implemented" @@ -280,23 +279,21 @@ evalSourceBlock' mname block = case sbContents block of GetType -> do -- TODO: don't actually evaluate it val <- evalUExpr expr logTop $ TextOut $ pprintCanonicalized $ getType val - DeclareForeign fname dexName cTy -> do - let b = fromString dexName :: UBinder (AtomNameC CoreIR) VoidS VoidS + DeclareForeign fname (WithSrc _ dexName) cTy -> do ty <- evalUType =<< parseExpr cTy asFFIFunType ty >>= \case Nothing -> throw TypeErr "FFI functions must be n-ary first order functions with the IO effect" Just (impFunTy, naryPiTy) -> do -- TODO: query linking stuff and check the function is actually available - let hint = getNameHint b - fTop <- emitBinding hint $ TopFunBinding $ FFITopFun fname impFunTy + let hint = fromString dexName + fTop <- emitBinding hint $ TopFunBinding $ FFITopFun (withoutSrc fname) impFunTy vCore <- emitBinding hint $ AtomNameBinding $ FFIFunBound naryPiTy fTop - UBindSource _ sourceName <- return b emitSourceMap $ SourceMap $ - M.singleton sourceName [ModuleVar mname (Just $ UAtomVar vCore)] + M.singleton dexName [ModuleVar mname (Just $ UAtomVar vCore)] DeclareCustomLinearization fname zeros g -> do expr <- parseExpr g - lookupSourceMap fname >>= \case + lookupSourceMap (withoutSrc fname) >>= \case Nothing -> throw UnboundVarErr $ pprint fname Just (UAtomVar fname') -> do lookupCustomRules fname' >>= \case @@ -328,7 +325,7 @@ evalSourceBlock' mname block = case sbContents block of UnParseable _ s -> throw ParseErr s Misc m -> case m of GetNameType v -> do - ty <- sourceNameType v + ty <- sourceNameType (withoutSrc v) logTop $ TextOut $ pprintCanonicalized ty ImportModule moduleName -> importModule moduleName QueryEnv query -> void $ runEnvQuery query $> UnitE @@ -337,11 +334,11 @@ evalSourceBlock' mname block = case sbContents block of EmptyLines -> return () where addTypeAnn :: UExpr n -> UExpr n -> UExpr n - addTypeAnn e = WithSrcE emptySrcPosCtx . UTypeAnn e + addTypeAnn e = WithSrcE (srcPos e) . UTypeAnn e addShowAny :: UExpr n -> UExpr n - addShowAny e = WithSrcE emptySrcPosCtx $ UApp (referTo "show_any") [e] [] - referTo :: SourceName -> UExpr n - referTo = WithSrcE emptySrcPosCtx . UVar . SourceName emptySrcPosCtx + addShowAny e = WithSrcE (srcPos e) $ UApp (referTo $ WithSrc (srcPos e) "show_any") [e] [] + referTo :: SourceNameW -> UExpr n + referTo (WithSrc sid name) = WithSrcE sid $ UVar $ SourceName sid name runEnvQuery :: Topper m => EnvQuery -> m n () runEnvQuery query = do @@ -738,10 +735,6 @@ checkPass name cont = do #endif return result -addResultCtx :: SourceBlock -> Result -> Result -addResultCtx block (Result outs errs) = - Result outs (addSrcTextContext (sbOffset block) (sbText block) errs) - logTop :: TopLogger m => Output -> m () logTop x = logIO [x] diff --git a/src/lib/Types/Core.hs b/src/lib/Types/Core.hs index 2ff148a21..0475f6ac8 100644 --- a/src/lib/Types/Core.hs +++ b/src/lib/Types/Core.hs @@ -38,7 +38,6 @@ import Foreign.Ptr import Name import Util (FileHash, SnocList (..), Tree (..)) import IRVariants -import SourceInfo import qualified Types.OpNames as P import Types.Primitives @@ -809,15 +808,11 @@ data InfVarDesc = deriving (Show, Generic, Eq, Ord) data SolverBinding (n::S) = - InfVarBound (CType n) InfVarCtx + InfVarBound (CType n) | SkolemBound (CType n) | DictBound (CType n) deriving (Show, Generic) --- Context for why we created an inference variable. --- This helps us give better "ambiguous variable" errors. -type InfVarCtx = (SrcPosCtx, InfVarDesc) - newtype EnvFrag (n::S) (l::S) = EnvFrag (RecSubstFrag Binding n l) deriving (OutFrag) @@ -2314,19 +2309,19 @@ instance AlphaHashableE LinearizationSpec instance GenericE SolverBinding where type RepE SolverBinding = EitherE3 - (PairE CType (LiftE InfVarCtx)) + CType CType CType fromE = \case - InfVarBound ty ctx -> Case0 (PairE ty (LiftE ctx)) - SkolemBound ty -> Case1 ty - DictBound ty -> Case2 ty + InfVarBound ty -> Case0 ty + SkolemBound ty -> Case1 ty + DictBound ty -> Case2 ty {-# INLINE fromE #-} toE = \case - Case0 (PairE ty (LiftE ct)) -> InfVarBound ty ct - Case1 ty -> SkolemBound ty - Case2 ty -> DictBound ty + Case0 ty -> InfVarBound ty + Case1 ty -> SkolemBound ty + Case2 ty -> DictBound ty _ -> error "impossible" {-# INLINE toE #-} diff --git a/src/lib/Types/Source.hs b/src/lib/Types/Source.hs index 9212a4072..e14da16d9 100644 --- a/src/lib/Types/Source.hs +++ b/src/lib/Types/Source.hs @@ -24,7 +24,6 @@ import Data.Hashable import Data.Foldable import qualified Data.Map.Strict as M import qualified Data.Set as S -import Data.String (IsString, fromString) import Data.Text (Text) import Data.Text.Prettyprint.Doc (Pretty (..), hardline, (<+>)) import Data.Word @@ -35,26 +34,16 @@ import Data.Store (Store (..)) import Name import qualified Types.OpNames as P import IRVariants -import SourceInfo import Util (File (..), SnocList) import Types.Primitives -data SourceName' = SourceName' SrcPosCtx SourceName - deriving (Show, Eq, Ord, Generic) - -fromName :: SourceName -> SourceName' -fromName = SourceName' emptySrcPosCtx - -instance HasNameHint SourceName' where - getNameHint (SourceName' _ name) = getNameHint name - data SourceNameOr (a::E) (n::S) where -- Only appears before renaming pass - SourceName :: SrcPosCtx -> SourceName -> SourceNameOr a n + SourceName :: SrcId -> SourceName -> SourceNameOr a n -- Only appears after renaming pass -- We maintain the source name for user-facing error messages. - InternalName :: SrcPosCtx -> SourceName -> a n -> SourceNameOr a n + InternalName :: SrcId -> SourceName -> a n -> SourceNameOr a n deriving instance Eq (a n) => Eq (SourceNameOr a n) deriving instance Ord (a n) => Ord (SourceNameOr a n) deriving instance Show (a n) => Show (SourceNameOr a n) @@ -62,14 +51,10 @@ deriving instance Show (a n) => Show (SourceNameOr a n) newtype SourceOrInternalName (c::C) (n::S) = SourceOrInternalName (SourceNameOr (Name c) n) deriving (Eq, Ord, Show, Generic) -pattern SISourceName :: (n ~ VoidS) => SourceName -> SourceOrInternalName c n -pattern SISourceName n = SourceOrInternalName (SourceName EmptySrcPosCtx n) - -pattern SIInternalName :: SourceName -> Name c n -> Maybe SrcPos -> Maybe SourceId -> SourceOrInternalName c n -pattern SIInternalName n a srcPos spanId = SourceOrInternalName (InternalName (SrcPosCtx srcPos spanId) n a) - -- === Source Info === +newtype SrcId = SrcId Int deriving (Show, Eq, Ord, Generic) + -- This is just for syntax highlighting. It won't be needed if we have -- a separate lexing pass where we have a complete lossless data type for -- lexemes. @@ -86,10 +71,10 @@ data LexemeType = type Span = (Int, Int) data SourceMaps = SourceMaps - { lexemeList :: SnocList SourceId - , lexemeInfo :: M.Map SourceId (LexemeType, Span) - , astParent :: M.Map SourceId SourceId - , astChildren :: M.Map SourceId [SourceId]} + { lexemeList :: SnocList SrcId + , lexemeInfo :: M.Map SrcId (LexemeType, Span) + , astParent :: M.Map SrcId SrcId + , astChildren :: M.Map SrcId [SrcId]} deriving (Show, Generic) instance Semigroup SourceMaps where @@ -102,93 +87,100 @@ instance Monoid SourceMaps where -- === Concrete syntax === -- The grouping-level syntax of the source language +-- aliases for the "with source ID versions" + +type GroupW = WithSrcs Group +type CTopDeclW = WithSrcs CTopDecl +type CSDeclW = WithSrcs CSDecl +type SourceNameW = WithSrc SourceName + +type BracketedGroup = WithSrcs [GroupW] -- optional arrow, effects, result type -type ExplicitParams = [Group] -type GivenClause = ([Group], Maybe [Group]) -- implicits, classes -type WithClause = [Group] -- no classes because we don't want to carry class dicts at runtime +type ExplicitParams = BracketedGroup +type GivenClause = (BracketedGroup, Maybe BracketedGroup) -- implicits, classes +type WithClause = BracketedGroup -- no classes because we don't want to carry class dicts at runtime -type CTopDecl = WithSrc CTopDecl' -data CTopDecl' - = CSDecl LetAnn CSDecl' +data CTopDecl + = CSDecl LetAnn CSDecl | CData - SourceName -- Type constructor name - ExplicitParams + SourceNameW -- Type constructor name + (Maybe ExplicitParams) (Maybe GivenClause) - [(SourceName, ExplicitParams)] -- Constructor names and argument sets + [(SourceNameW, Maybe ExplicitParams)] -- Constructor names and argument sets | CStruct - SourceName -- Type constructor name - ExplicitParams + SourceNameW -- Type constructor name + (Maybe ExplicitParams) (Maybe GivenClause) - [(SourceName, Group)] -- Field names and types + [(SourceNameW, GroupW)] -- Field names and types [(LetAnn, CDef)] | CInterface - SourceName -- Interface name + SourceNameW -- Interface name ExplicitParams - [(SourceName, Group)] -- Method declarations + [(SourceNameW, GroupW)] -- Method declarations -- header, givens (may be empty), methods, optional name. The header should contain -- the prerequisites, class name, and class arguments. | CInstanceDecl CInstanceDef deriving (Show, Generic) -type CSDecl = WithSrc CSDecl' -data CSDecl' - = CLet Group CSBlock +data CSDecl + = CLet GroupW CSBlock | CDefDecl CDef - | CExpr Group - | CBind Group CSBlock -- Arrow binder <- + | CExpr GroupW + | CBind GroupW CSBlock -- Arrow binder <- | CPass deriving (Show, Generic) -type CEffs = ([Group], Maybe Group) +type CEffs = WithSrcs ([GroupW], Maybe GroupW) data CDef = CDef - SourceName - (ExplicitParams) + SourceNameW + ExplicitParams (Maybe CDefRhs) (Maybe GivenClause) CSBlock deriving (Show, Generic) -type CDefRhs = (AppExplicitness, Maybe CEffs, Group) +type CDefRhs = (AppExplicitness, Maybe CEffs, GroupW) data CInstanceDef = CInstanceDef - SourceName -- interface name - [Group] -- args at which we're instantiating the interface + SourceNameW -- interface name + [GroupW] -- args at which we're instantiating the interface (Maybe GivenClause) - [CSDecl] -- Method definitions - (Maybe (SourceName, Maybe [Group])) -- Optional name of instance, with explicit parameters + [CSDeclW] -- Method definitions + (Maybe (SourceNameW, Maybe BracketedGroup)) -- Optional name of instance, with explicit parameters deriving (Show, Generic) -type Group = WithSrc Group' -data Group' - = CEmpty - | CIdentifier SourceName - | CPrim PrimName [Group] +data Group + = CLeaf CLeaf + | CPrim PrimName [GroupW] + | CParens [GroupW] + | CBrackets [GroupW] + | CBin Bin GroupW GroupW + | CJuxtapose Bool GroupW GroupW -- Bool means "there's a space between the groups" + | CPrefix SourceNameW GroupW -- covers unary - and unary + among others + | CGivens GivenClause + | CLambda [GroupW] CSBlock + | CFor ForKind [GroupW] CSBlock -- also for_, rof, rof_ + | CCase GroupW [CaseAlt] -- scrutinee, alternatives + | CIf GroupW CSBlock (Maybe CSBlock) + | CDo CSBlock + | CArrow GroupW (Maybe CEffs) GroupW + | CWith GroupW WithClause + deriving (Show, Generic) + +data CLeaf + = CIdentifier SourceName | CNat Word64 | CInt Int | CString String | CChar Char | CFloat Double | CHole - | CParens [Group] - | CBrackets [Group] - | CBin Bin Group Group - | CPrefix SourceName Group -- covers unary - and unary + among others - | CPostfix SourceName Group - | CLambda [Group] CSBlock - | CFor ForKind [Group] CSBlock -- also for_, rof, rof_ - | CCase Group [(Group, CSBlock)] -- scrutinee, alternatives - | CIf Group CSBlock (Maybe CSBlock) - | CDo CSBlock - | CGivens GivenClause - | CArrow Group (Maybe CEffs) Group - | CWith Group WithClause deriving (Show, Generic) -type Bin = WithSrc Bin' -data Bin' - = JuxtaposeWithSpace - | JuxtaposeNoSpace - | EvalBinOp String +type CaseAlt = (GroupW, CSBlock) -- scrutinee, lexeme Id, body + +data Bin + = EvalBinOp SourceNameW | DepAmpersand | Dot | DepComma @@ -199,7 +191,7 @@ data Bin' | FatArrow -- => | Pipe | CSEqual - deriving (Eq, Ord, Show, Generic) + deriving (Show, Generic) data LabelPrefix = PlainLabel deriving (Show, Generic) @@ -213,8 +205,8 @@ data ForKind -- `CSBlock` instead of `CBlock` because the latter is an alias for `Block CoreIR`. data CSBlock = - IndentedBlock [CSDecl] -- last decl should be a CExpr - | ExprBlock Group + IndentedBlock SrcId [CSDeclW] -- last decl should be a CExpr + | ExprBlock GroupW deriving (Show, Generic) -- === Untyped IR === @@ -244,15 +236,16 @@ data UVar (n::S) = deriving (Eq, Ord, Show, Generic) type UAtomBinder = UBinder (AtomNameC CoreIR) -data UBinder (c::C) (n::S) (l::S) where +type UBinder c = WithSrcB (UBinder' c) +data UBinder' (c::C) (n::S) (l::S) where -- Only appears before renaming pass - UBindSource :: SrcPosCtx -> SourceName -> UBinder c n n + UBindSource :: SourceName -> UBinder' c n n -- May appear before or after renaming pass - UIgnore :: UBinder c n n + UIgnore :: UBinder' c n n -- The following binders only appear after the renaming pass. -- We maintain the source name for user-facing error messages -- and named arguments. - UBind :: SrcPosCtx -> SourceName -> NameBinder c n l -> UBinder c n l + UBind :: SourceName -> NameBinder c n l -> UBinder' c n l type UBlock = WithSrcE UBlock' data UBlock' (n::S) where @@ -326,7 +319,7 @@ data UStructDef (n::S) where UStructDef :: SourceName -- source name for pretty printing -> Nest UAnnBinder n l - -> [(SourceName, UType l)] -- named payloads + -> [(SourceNameW, UType l)] -- named payloads -> [(LetAnn, SourceName, Abs UAtomBinder ULamExpr l)] -- named methods (initial binder is for `self`) -> UStructDef n @@ -381,6 +374,7 @@ data UMethodDef' (n::S) = UMethodDef (SourceNameOr (Name MethodNameC) n) (ULamEx data UAnn (n::S) = UAnn (UType n) | UNoAnn deriving Show +-- TODO: SrcId data UAnnBinder (n::S) (l::S) = UAnnBinder Explicitness (UAtomBinder n l) (UAnn n) [UConstraint n] deriving (Show, Generic) @@ -397,48 +391,84 @@ data UPat' (n::S) (l::S) = | UPatTable (Nest UPat n l) deriving (Show, Generic) -pattern UPatIgnore :: UPat' (n::S) n -pattern UPatIgnore = UPatBinder UIgnore - -- === source names for error messages === class HasSourceName a where getSourceName :: a -> SourceName +instance HasSourceName (b n l) => HasSourceName (WithSrcB b n l) where + getSourceName (WithSrcB _ b) = getSourceName b + instance HasSourceName (UAnnBinder n l) where getSourceName (UAnnBinder _ b _ _) = getSourceName b -instance HasSourceName (UBinder c n l) where +instance HasSourceName (UBinder' c n l) where getSourceName = \case - UBindSource _ sn -> sn - UIgnore -> "_" - UBind _ sn _ -> sn + UBindSource sn -> sn + UIgnore -> "_" + UBind sn _ -> sn -- === Source context helpers === -data WithSrc a = WithSrc SrcPosCtx a +-- First SrcId is for the group itself. The rest are for keywords, symbols, etc. +data WithSrcs a = WithSrcs SrcId [SrcId] a + deriving (Show, Functor, Generic) + +data WithSrc a = WithSrc SrcId a deriving (Show, Functor, Generic) -data WithSrcE (a::E) (n::S) = WithSrcE SrcPosCtx (a n) +data WithSrcE (a::E) (n::S) = WithSrcE SrcId (a n) deriving (Show, Generic) -data WithSrcB (binder::B) (n::S) (l::S) = WithSrcB SrcPosCtx (binder n l) +data WithSrcB (binder::B) (n::S) (l::S) = WithSrcB SrcId (binder n l) deriving (Show, Generic) -class HasSrcPos a where - srcPos :: a -> SrcPosCtx +class HasSrcPos withSrc a | withSrc -> a where + srcPos :: withSrc -> SrcId + withoutSrc :: withSrc -> a + +instance HasSrcPos (WithSrc (a:: *)) a where + srcPos (WithSrc pos _) = pos + withoutSrc (WithSrc _ x) = x + +instance HasSrcPos (WithSrcs (a:: *)) a where + srcPos (WithSrcs pos _ _) = pos + withoutSrc (WithSrcs _ _ x) = x -instance HasSrcPos (WithSrcE (a::E) (n::S)) where +instance HasSrcPos (WithSrcE (e::E) (n::S)) (e n) where srcPos (WithSrcE pos _) = pos + withoutSrc (WithSrcE _ x) = x -instance HasSrcPos (WithSrcB (b::B) (n::S) (n::S)) where +instance HasSrcPos (WithSrcB (b::B) (n::S) (l::S)) (b n l) where srcPos (WithSrcB pos _) = pos + withoutSrc (WithSrcB _ x) = x -instance HasSrcPos (UBinder c n l) where - srcPos = \case - UBindSource ctx _ -> ctx - UIgnore -> SrcPosCtx Nothing Nothing - UBind ctx _ _ -> ctx +class FromSourceNameW a where + fromSourceNameW :: SourceNameW -> a + +instance FromSourceNameW (SourceNameOr a VoidS) where + fromSourceNameW (WithSrc sid x) = SourceName sid x + +instance FromSourceNameW (SourceOrInternalName c VoidS) where + fromSourceNameW x = SourceOrInternalName $ fromSourceNameW x + +instance FromSourceNameW (UBinder' s VoidS VoidS) where + fromSourceNameW x = UBindSource $ withoutSrc x + +instance FromSourceNameW (UPat' VoidS VoidS) where + fromSourceNameW = UPatBinder . fromSourceNameW + +instance FromSourceNameW (UAnnBinder VoidS VoidS) where + fromSourceNameW s = UAnnBinder Explicit (fromSourceNameW s) UNoAnn [] + +instance FromSourceNameW (UExpr' VoidS) where + fromSourceNameW = UVar . fromSourceNameW + +instance FromSourceNameW (a n) => FromSourceNameW (WithSrcE a n) where + fromSourceNameW x = WithSrcE (srcPos x) $ fromSourceNameW x + +instance FromSourceNameW (b n l) => FromSourceNameW (WithSrcB b n l) where + fromSourceNameW x = WithSrcB (srcPos x) $ fromSourceNameW x -- === SourceMap === @@ -487,16 +517,16 @@ data SymbolicZeros = SymbolicZeros | InstantiateZeros deriving (Generic, Eq, Show) data SourceBlock' - = TopDecl CTopDecl - | Command CmdName Group - | DeclareForeign SourceName SourceName Group - | DeclareCustomLinearization SourceName SymbolicZeros Group + = TopDecl CTopDeclW + | Command CmdName GroupW + | DeclareForeign SourceNameW SourceNameW GroupW + | DeclareCustomLinearization SourceNameW SymbolicZeros GroupW | Misc SourceBlockMisc | UnParseable ReachedEOF String deriving (Show, Generic) data SourceBlockMisc - = GetNameType SourceName + = GetNameType SourceNameW | ImportModule ModuleSourceName | QueryEnv EnvQuery | ProseBlock Text @@ -669,32 +699,48 @@ instance HasNameHint ModuleSourceName where getNameHint Prelude = getNameHint @String "prelude" getNameHint Main = getNameHint @String "main" -instance HasNameHint (UBinder c n l) where +instance HasNameHint (UBinder' c n l) where getNameHint b = case b of - UBindSource _ v -> getNameHint v - UIgnore -> noHint - UBind _ v _ -> getNameHint v + UBindSource v -> getNameHint v + UIgnore -> noHint + UBind v _ -> getNameHint v -instance Color c => BindsNames (UBinder c) where - toScopeFrag (UBindSource _ _) = emptyOutFrag +instance Color c => BindsNames (UBinder' c) where + toScopeFrag (UBindSource _) = emptyOutFrag toScopeFrag (UIgnore) = emptyOutFrag - toScopeFrag (UBind _ _ b) = toScopeFrag b + toScopeFrag (UBind _ b) = toScopeFrag b -instance Color c => ProvesExt (UBinder c) where -instance Color c => BindsAtMostOneName (UBinder c) c where +instance Color c => ProvesExt (UBinder' c) where +instance Color c => BindsAtMostOneName (UBinder' c) c where b @> x = case b of - UBindSource _ _ -> emptyInFrag - UIgnore -> emptyInFrag - UBind _ _ b' -> b' @> x + UBindSource _ -> emptyInFrag + UIgnore -> emptyInFrag + UBind _ b' -> b' @> x -instance Color c => SinkableB (UBinder c) where +instance Color c => SinkableB (UBinder' c) where sinkingProofB _ _ _ = todoSinkableProof -instance Color c => RenameB (UBinder c) where +instance Color c => RenameB (UBinder' c) where renameB env ub cont = case ub of - UBindSource pos sn -> cont env $ UBindSource pos sn + UBindSource sn -> cont env $ UBindSource sn UIgnore -> cont env UIgnore - UBind ctx sn b -> renameB env b \env' b' -> cont env' $ UBind ctx sn b' + UBind sn b -> renameB env b \env' b' -> cont env' $ UBind sn b' + +instance SinkableB b => SinkableB (WithSrcB b) where + sinkingProofB _ _ _ = todoSinkableProof + +instance RenameB b => RenameB (WithSrcB b) where + renameB env (WithSrcB sid b) cont = + renameB env b \env' b' -> cont env' (WithSrcB sid b') + +instance ProvesExt b => ProvesExt (WithSrcB b) where + toExtEvidence (WithSrcB _ b) = toExtEvidence b + +instance BindsNames b => BindsNames (WithSrcB b) where + toScopeFrag (WithSrcB _ b) = toScopeFrag b + +instance BindsAtMostOneName b r => BindsAtMostOneName (WithSrcB b) r where + WithSrcB _ b @> x = b @> x instance ProvesExt UAnnBinder where instance BindsNames UAnnBinder where @@ -704,7 +750,7 @@ instance BindsAtMostOneName UAnnBinder (AtomNameC CoreIR) where UAnnBinder _ b _ _ @> x = b @> x instance GenericE (WithSrcE e) where - type RepE (WithSrcE e) = PairE (LiftE SrcPosCtx) e + type RepE (WithSrcE e) = PairE (LiftE SrcId) e fromE (WithSrcE ctx x) = PairE (LiftE ctx) x toE (PairE (LiftE ctx) x) = WithSrcE ctx x @@ -738,37 +784,7 @@ instance Store (SourceMap n) instance Hashable ModuleSourceName -instance Store SourceName' -instance Hashable SourceName' - -instance IsString SourceName' where - fromString = SourceName' emptySrcPosCtx - -instance IsString (SourceNameOr a VoidS) where - fromString = SourceName emptySrcPosCtx - -instance IsString (SourceOrInternalName c VoidS) where - fromString = SISourceName - -instance IsString (UBinder s VoidS VoidS) where - fromString = UBindSource emptySrcPosCtx - -instance IsString (UPat' VoidS VoidS) where - fromString = UPatBinder . fromString - -instance IsString (UAnnBinder VoidS VoidS) where - fromString s = UAnnBinder Explicit (fromString s) UNoAnn [] - -instance IsString (UExpr' VoidS) where - fromString = UVar . fromString - -instance IsString (a n) => IsString (WithSrcE a n) where - fromString = WithSrcE emptySrcPosCtx . fromString - -instance IsString (b n l) => IsString (WithSrcB b n l) where - fromString = WithSrcB emptySrcPosCtx . fromString - -deriving instance Show (UBinder s n l) +deriving instance Show (UBinder' s n l) deriving instance Show (UDataDefTrail n) deriving instance Show (ULamExpr n) deriving instance Show (UPiExpr n) diff --git a/src/lib/Vectorize.hs b/src/lib/Vectorize.hs index eafefb538..d6fec397e 100644 --- a/src/lib/Vectorize.hs +++ b/src/lib/Vectorize.hs @@ -85,7 +85,7 @@ newtype TopVectorizeM (i::S) (o::S) (a:: *) = TopVectorizeM SubstReaderT Name (ReaderT1 CommuteMap (ReaderT1 (LiftE Word32) - (StateT1 (LiftE [Err]) (BuilderT SimpIR FallibleM)))) i o a } + (StateT1 (LiftE [Err]) (BuilderT SimpIR Except)))) i o a } deriving ( Functor, Applicative, Monad, MonadFail, MonadReader (CommuteMap o) , MonadState (LiftE [Err] o), Fallible, ScopeReader, EnvReader , EnvExtender, Builder SimpIR, ScopableBuilder SimpIR, Catchable @@ -108,23 +108,15 @@ liftTopVectorizeM vectorByteWidth action = do flip runStateT1 mempty $ runReaderT1 (LiftE vectorByteWidth) $ runReaderT1 mempty $ runSubstReaderT idSubst $ runTopVectorizeM action - case runFallibleM fallible of + case fallible of -- The failure case should not occur: vectorization errors should have been -- caught inside `vectorizeLoopsDecls` (and should have been added to the -- `Err` state of the `StateT` instance that is run with `runStateT` above). Failure errs -> error $ pprint errs Success (a, (LiftE errs)) -> return $ (a, errs) -addVectErrCtx :: Fallible m => String -> String -> m a -> m a -addVectErrCtx name payload m = - let ctx = mempty { messageCtx = ["In `" ++ name ++ "`:\n" ++ payload] } - in addErrCtx ctx m - throwVectErr :: Fallible m => String -> m a -throwVectErr msg = throwErr (Err MiscErr mempty msg) - -prependCtxToErr :: ErrCtx -> Err -> Err -prependCtxToErr ctx (Err ty ctx' msg) = Err ty (ctx <> ctx') msg +throwVectErr msg = throw MiscErr msg askVectorByteWidth :: TopVectorizeM i o Word32 askVectorByteWidth = TopVectorizeM $ liftSubstReaderT $ lift11 (fromLiftE <$> ask) @@ -181,10 +173,7 @@ vectorizeLoopsExpr expr = do emit =<< mkSeq dir (IxType IdxRepTy (DictCon (IxRawFin (IdxRepVal vn)))) dest' body') else renameM expr >>= emit) `catchErr` \err -> do - let msg = "In `vectorizeLoopsDecls`:\nExpr:\n" ++ pprint expr - ctx = mempty { messageCtx = [msg] } - err' = prependCtxToErr ctx err - modify (\(LiftE errs) -> LiftE (err':errs)) + modify (\(LiftE errs) -> LiftE (err:errs)) recurSeq expr _ -> recurSeq expr PrimOp (Hof (TypedHof _ (RunReader item (BinaryLamExpr hb' refb' body)))) -> do @@ -333,7 +322,7 @@ vectorizeSeq _ _ _ = error "expected a unary lambda expression" newtype VectorizeM i o a = VectorizeM { runVectorizeM :: - SubstReaderT VSubstValC (BuilderT SimpIR (ReaderT Word32 FallibleM)) i o a } + SubstReaderT VSubstValC (BuilderT SimpIR (ReaderT Word32 Except)) i o a } deriving ( Functor, Applicative, Monad, Fallible, MonadFail , SubstReader VSubstValC , Builder SimpIR, EnvReader, EnvExtender , ScopeReader, ScopableBuilder SimpIR) @@ -343,8 +332,7 @@ liftVectorizeM :: (SubstReader Name m, EnvReader (m i), Fallible (m i o)) liftVectorizeM loopWidth action = do subst <- getSubst act <- liftBuilderT $ runSubstReaderT (newSubst $ vSubst subst) $ runVectorizeM action - let fallible = flip runReaderT loopWidth act - case runFallibleM fallible of + case flip runReaderT loopWidth act of Success a -> return a Failure errs -> throwErr errs -- re-raise inside ambient monad where @@ -374,7 +362,7 @@ vectorizeLamExpr (LamExpr bs body) argStabilities = case (bs, argStabilities) of _ -> error "Zip error" vectorizeExpr :: Emits o => SExpr i -> VectorizeM i o (VAtom o) -vectorizeExpr expr = addVectErrCtx "vectorizeExpr" ("Expr:\n" ++ pprint expr) do +vectorizeExpr expr = do case expr of Block _ block -> vectorizeBlock block TabApp _ tbl ix -> do @@ -523,7 +511,7 @@ vectorizeType t = do fmapNamesM (uniformSubst subst) t vectorizeAtom :: SAtom i -> VectorizeM i o (VAtom o) -vectorizeAtom atom = addVectErrCtx "vectorizeAtom" ("Atom:\n" ++ pprint atom) do +vectorizeAtom atom = do case atom of Stuck _ e -> vectorizeStuck e Con con -> case con of @@ -557,13 +545,12 @@ uniformSubst subst n = case subst ! n of _ -> error "Can't vectorize atom" getVectorType :: SType o -> VectorizeM i o (SType o) -getVectorType ty = addVectErrCtx "getVectorType" ("Type:\n" ++ pprint ty) do - case ty of - BaseTy (Scalar sbt) -> do - els <- getLoopWidth - return $ BaseTy $ Vector [els] sbt - -- TODO: Should we support tables? - _ -> throwVectErr $ "Can't make a vector of " ++ pprint ty +getVectorType ty = case ty of + BaseTy (Scalar sbt) -> do + els <- getLoopWidth + return $ BaseTy $ Vector [els] sbt + -- TODO: Should we support tables? + _ -> throwVectErr $ "Can't make a vector of " ++ pprint ty ensureVarying :: Emits o => VAtom o -> VectorizeM i o (SAtom o) ensureVarying (VVal s val) = case s of From 28d3538b5d8c508303481152c6f14e29dc40c9db Mon Sep 17 00:00:00 2001 From: Dougal Date: Wed, 22 Nov 2023 21:44:16 -0500 Subject: [PATCH 21/41] Traverse concrete AST to get information about relationships between bits of source text --- dex.cabal | 1 + src/lib/AbstractSyntax.hs | 2 +- src/lib/ConcreteSyntax.hs | 38 ++++++------ src/lib/ImpToLLVM.hs | 12 ++-- src/lib/Inference.hs | 8 +-- src/lib/Lexing.hs | 28 ++++----- src/lib/PPrint.hs | 4 +- src/lib/QueryType.hs | 2 +- src/lib/RenderHtml.hs | 10 ++-- src/lib/RuntimePrint.hs | 16 ++--- src/lib/SourceIdTraversal.hs | 113 +++++++++++++++++++++++++++++++++++ src/lib/SourceRename.hs | 2 +- src/lib/TopLevel.hs | 12 ++-- src/lib/Types/Primitives.hs | 16 ++++- src/lib/Types/Source.hs | 26 +++++--- 15 files changed, 214 insertions(+), 76 deletions(-) create mode 100644 src/lib/SourceIdTraversal.hs diff --git a/dex.cabal b/dex.cabal index f8eb49f41..6ce74e218 100644 --- a/dex.cabal +++ b/dex.cabal @@ -88,6 +88,7 @@ library , Simplify , Subst , SourceRename + , SourceIdTraversal , TopLevel , Transpose , Types.Core diff --git a/src/lib/AbstractSyntax.hs b/src/lib/AbstractSyntax.hs index bc1026a81..8d8ec9818 100644 --- a/src/lib/AbstractSyntax.hs +++ b/src/lib/AbstractSyntax.hs @@ -496,7 +496,7 @@ expr (WithSrcs sid _ grp) = WithSrcE sid <$> case grp of WithSrcE _ (UIntLit i) -> UIntLit (-i) WithSrcE _ (UFloatLit i) -> UFloatLit (-i) e -> unaryApp (mkUVar sid "neg") e - _ -> throw SyntaxErr $ "Prefix (" ++ name ++ ") not legal as a bare expression" + _ -> throw SyntaxErr $ "Prefix (" ++ pprint name ++ ") not legal as a bare expression" CLambda params body -> do params' <- explicitBindersOptAnn $ WithSrcs sid [] $ map stripParens params body' <- block body diff --git a/src/lib/ConcreteSyntax.hs b/src/lib/ConcreteSyntax.hs index b37f3a747..2aa2c0f6e 100644 --- a/src/lib/ConcreteSyntax.hs +++ b/src/lib/ConcreteSyntax.hs @@ -31,6 +31,7 @@ import Lexing import Types.Core import Types.Source import Types.Primitives +import SourceIdTraversal (getASTInfo) import qualified Types.OpNames as P import Util @@ -59,7 +60,7 @@ parseUModule name s = do {-# SCC parseUModule #-} preludeImportBlock :: SourceBlock -preludeImportBlock = SourceBlock 0 0 LogNothing "" mempty (Misc $ ImportModule Prelude) +preludeImportBlock = SourceBlock 0 0 LogNothing "" mempty mempty (Misc $ ImportModule Prelude) sourceBlocks :: Parser [SourceBlock] sourceBlocks = manyTill (sourceBlock <* outputLines) eof @@ -82,7 +83,7 @@ interpOperator (WithSrc sid s) = case s of "->>" -> ImplicitArrow "=>" -> FatArrow "=" -> CSEqual - name -> EvalBinOp $ WithSrc sid $ "(" <> name <> ")" + name -> EvalBinOp $ WithSrc sid $ fromString $ "(" <> name <> ")" pattern Identifier :: SourceName -> GroupW pattern Identifier name <- (WithSrcs _ _ (CLeaf (CIdentifier name))) @@ -93,12 +94,13 @@ sourceBlock :: Parser SourceBlock sourceBlock = do offset <- getOffset pos <- getSourcePos - (src, (sm, (level, b))) <- withSource $ withSourceMaps $ withRecovery recover do + (src, (lexInfo, (level, b))) <- withSource $ withLexemeInfo $ withRecovery recover do level <- logLevel <|> logTime <|> logBench <|> return LogNothing b <- sourceBlock' return (level, b) - let sm' = sm { lexemeInfo = lexemeInfo sm <&> \(t, (l, r)) -> (t, (l-offset, r-offset))} - return $ SourceBlock (unPos (sourceLine pos)) offset level src sm' b + let lexInfo' = lexInfo { lexemeInfo = lexemeInfo lexInfo <&> \(t, (l, r)) -> (t, (l-offset, r-offset))} + let astInfo = getASTInfo b + return $ SourceBlock (unPos (sourceLine pos)) offset level src lexInfo' astInfo b recover :: ParseError Text Void -> Parser (LogLevel, SourceBlock') recover e = do @@ -124,7 +126,7 @@ declareForeign = do void $ label "type annotation" $ sym ":" ty <- cGroup eol - return $ DeclareForeign foreignName b ty + return $ DeclareForeign (fmap fromString foreignName) b ty declareCustomLinearization :: Parser SourceBlock' declareCustomLinearization = do @@ -681,19 +683,19 @@ anySymOp = Expr.InfixL $ binApp do s <- label "infix operator" (mayBreak anySym) return $ interpOperator s -infixSym :: SourceName -> Parser SrcId +infixSym :: String -> Parser SrcId infixSym s = mayBreak $ symWithId $ T.pack s -symOpN :: SourceName -> (SourceName, Expr.Operator Parser GroupW) -symOpN s = (s, Expr.InfixN $ symOp s) +symOpN :: String -> (SourceName, Expr.Operator Parser GroupW) +symOpN s = (fromString s, Expr.InfixN $ symOp s) -symOpL :: SourceName -> (SourceName, Expr.Operator Parser GroupW) -symOpL s = (s, Expr.InfixL $ symOp s) +symOpL :: String -> (SourceName, Expr.Operator Parser GroupW) +symOpL s = (fromString s, Expr.InfixL $ symOp s) -symOpR :: SourceName -> (SourceName, Expr.Operator Parser GroupW) -symOpR s = (s, Expr.InfixR $ symOp s) +symOpR :: String -> (SourceName, Expr.Operator Parser GroupW) +symOpR s = (fromString s, Expr.InfixR $ symOp s) -symOp :: SourceName -> Parser (GroupW -> GroupW -> GroupW) +symOp :: String -> Parser (GroupW -> GroupW -> GroupW) symOp s = binApp do sid <- label "infix operator" (infixSym s) return $ interpOperator (WithSrc sid s) @@ -704,13 +706,13 @@ arrowOp = addSrcIdToBinOp do optEffs <- optional cEffs return \lhs rhs -> CArrow lhs optEffs rhs -unOpPre :: SourceName -> (SourceName, Expr.Operator Parser GroupW) -unOpPre s = (s, Expr.Prefix $ prefixOp s) +unOpPre :: String -> (SourceName, Expr.Operator Parser GroupW) +unOpPre s = (fromString s, Expr.Prefix $ prefixOp s) -prefixOp :: SourceName -> Parser (GroupW -> GroupW) +prefixOp :: String -> Parser (GroupW -> GroupW) prefixOp s = addSrcIdToUnOp do symId <- symWithId (fromString s) - return $ CPrefix (WithSrc symId s) + return $ CPrefix (WithSrc symId $ fromString s) binApp :: Parser Bin -> Parser (GroupW -> GroupW -> GroupW) binApp f = addSrcIdToBinOp $ CBin <$> f diff --git a/src/lib/ImpToLLVM.hs b/src/lib/ImpToLLVM.hs index c5691cd2f..396aa980e 100644 --- a/src/lib/ImpToLLVM.hs +++ b/src/lib/ImpToLLVM.hs @@ -311,7 +311,7 @@ compileInstr instr = case instr of compileIf p' (compileVoidBlock cons) (compileVoidBlock alt) IQueryParallelism f s -> do let IFunType cc _ _ = snd f - let kernelFuncName = topLevelFunName $ fst f + let kernelFuncName = topLevelFunName $ fromString $ fst f n <- (`asIntWidth` i64) =<< compileExpr s case cc of MCThreadLaunch -> do @@ -339,7 +339,7 @@ compileInstr instr = case instr of ILaunch (fname, IFunType cc _ _) size args -> [] <$ do size' <- (`asIntWidth` i64) =<< compileExpr size args' <- mapM compileExpr args - let kernelFuncName = topLevelFunName fname + let kernelFuncName = topLevelFunName (fromString fname) case cc of MCThreadLaunch -> do kernelParams <- packArgs args' @@ -508,11 +508,11 @@ compileInstr instr = case instr of let resultTys = map scalarTy impResultTys case cc of FFICC -> do - ans <- emitExternCall (makeFunSpec fname ty) args' + ans <- emitExternCall (makeFunSpec (fromString fname) ty) args' return [ans] FFIMultiResultCC -> do resultPtr <- makeMultiResultAlloc resultTys - emitVoidExternCall (makeFunSpec fname ty) (resultPtr : args') + emitVoidExternCall (makeFunSpec (fromString fname) ty) (resultPtr : args') loadMultiResultAlloc resultTys resultPtr _ -> error $ "Unsupported calling convention: " ++ show cc DebugPrint fmtStr x -> [] <$ do @@ -539,11 +539,11 @@ compileInstr instr = case instr of -- TODO: use a careful naming discipline rather than strings -- (this is only used on the CUDA path which is currently broken anyway) topLevelFunName :: SourceName -> L.Name -topLevelFunName name = fromString name +topLevelFunName name = fromString $ pprint name makeFunSpec :: SourceName -> IFunType -> ExternFunSpec makeFunSpec name impFunTy = - ExternFunSpec (L.Name (fromString name)) retTy [] [] argTys + ExternFunSpec (L.Name (fromString $ pprint name)) retTy [] [] argTys where (retTy, argTys) = impFunTyToLLVMTy impFunTy impFunTyToLLVMTy :: IFunType -> LLVMFunType diff --git a/src/lib/Inference.hs b/src/lib/Inference.hs index 2cd8f1056..bd1964d15 100644 --- a/src/lib/Inference.hs +++ b/src/lib/Inference.hs @@ -638,7 +638,7 @@ applyFromLiteralMethod :: Emits n => CType n -> SourceName -> CAtom n -> InfererM i n (CAtom n) applyFromLiteralMethod resultTy methodName litVal = lookupSourceMap methodName >>= \case - Nothing -> error $ "prelude function not found: " ++ methodName + Nothing -> error $ "prelude function not found: " ++ pprint methodName Just ~(UMethodVar methodName') -> do MethodBinding className _ <- lookupEnv methodName' dictTy <- toType <$> dictType className [toAtom resultTy] @@ -955,14 +955,14 @@ inferMixedArgs fSourceName explsTop (dependenceTop, bsAbs) argsTop@(_, namedArgs (arg:argsRest, namedArgs) <- return args if isHole arg then do - let desc = (fSourceName, "_") + let desc = (pprint fSourceName, "_") withFreshUnificationVar (ImplicitArgInfVar desc) argTy \v -> cont (toAtom v) (argsRest, namedArgs) else do arg' <- checkOrInferExplicitArg isDependent arg argTy withDistinct $ cont arg' (argsRest, namedArgs) Inferred argName infMech -> do - let desc = (fSourceName, fromMaybe "_" argName) + let desc = (pprint $ fSourceName, fromMaybe "_" (fmap pprint argName)) case lookupNamedArg args argName of Just arg -> do arg' <- checkOrInferExplicitArg isDependent arg argTy @@ -1471,7 +1471,7 @@ checkMethodDef className methodTys (WithSrcE _ m) = do MethodBinding className' i <- renameM v >>= lookupEnv when (className /= className') do ClassBinding classDef <- lookupEnv className - throw TypeErr $ pprint sourceName ++ " is not a method of " ++ getSourceName classDef + throw TypeErr $ pprint sourceName ++ " is not a method of " ++ pprint (getSourceName classDef) (i,) <$> toAtom <$> Lam <$> checkULam rhs (methodTys !! i) checkUEffRow :: UEffectRow i -> InfererM i o (EffectRow CoreIR o) diff --git a/src/lib/Lexing.hs b/src/lib/Lexing.hs index 37613d05a..d40c4c8a4 100644 --- a/src/lib/Lexing.hs +++ b/src/lib/Lexing.hs @@ -32,12 +32,12 @@ data ParseCtx = ParseCtx { curIndent :: Int -- used Reader-style (i.e. ask/local) , canBreak :: Bool -- used Reader-style (i.e. ask/local) , prevWhitespace :: Bool -- tracks whether we just consumed whitespace - , sourceIdCounter :: Int + , sourceIdCounter :: Int -- starts at 1 (0 is reserved for the root) , curAtomicLexemes :: [SrcId] - , curSourceMap :: SourceMaps } -- append to, writer-style + , curLexemeInfo :: LexemeInfo } -- append to, writer-style initParseCtx :: ParseCtx -initParseCtx = ParseCtx 0 False False 0 mempty mempty +initParseCtx = ParseCtx 0 False False 1 mempty mempty type Parser = StateT ParseCtx (Parsec Void Text) @@ -72,7 +72,7 @@ anyCaseName = label "name" $ lexeme LowerName anyCaseName' -- TODO: distinguish anyCaseName' :: Lexer SourceName anyCaseName' = - checkNotKeyword $ (:) <$> satisfy (\c -> isLower c || isUpper c) <*> + liftM MkSourceName $ checkNotKeyword $ (:) <$> satisfy (\c -> isLower c || isUpper c) <*> (T.unpack <$> takeWhileP Nothing (\c -> isAlphaNum c || c == '\'' || c == '_')) anyName :: Lexer (WithSrc SourceName) @@ -188,7 +188,7 @@ anySym = lexeme Symbol $ try $ do symName :: Lexer (WithSrc SourceName) symName = label "symbol name" $ lexeme Symbol $ try $ do s <- between (char '(') (char ')') $ some symChar - return $ "(" <> s <> ")" + return $ MkSourceName $ "(" <> s <> ")" backquoteName :: Lexer (WithSrc SourceName) backquoteName = label "backquoted name" $ @@ -319,17 +319,17 @@ freshSrcId = do modify \ctx -> ctx { sourceIdCounter = c + 1 } return $ SrcId c -withSourceMaps :: Parser a -> Parser (SourceMaps, a) -withSourceMaps cont = do - smPrev <- gets curSourceMap - modify \ctx -> ctx { curSourceMap = mempty } +withLexemeInfo :: Parser a -> Parser (LexemeInfo, a) +withLexemeInfo cont = do + smPrev <- gets curLexemeInfo + modify \ctx -> ctx { curLexemeInfo = mempty } result <- cont - sm <- gets curSourceMap - modify \ctx -> ctx { curSourceMap = smPrev } + sm <- gets curLexemeInfo + modify \ctx -> ctx { curLexemeInfo = smPrev } return (sm, result) -emitSourceMaps :: SourceMaps -> Parser () -emitSourceMaps m = modify \ctx -> ctx { curSourceMap = curSourceMap ctx <> m } +emitLexemeInfo :: LexemeInfo -> Parser () +emitLexemeInfo m = modify \ctx -> ctx { curLexemeInfo = curLexemeInfo ctx <> m } lexemeIgnoreSrcId :: LexemeType -> Parser a -> Parser a lexemeIgnoreSrcId lexemeType p = withoutSrc <$> lexeme lexemeType p @@ -345,7 +345,7 @@ lexeme lexemeType p = do recordNonWhitespace sc sid <- freshSrcId - emitSourceMaps $ mempty + emitLexemeInfo $ mempty { lexemeList = toSnocList [sid] , lexemeInfo = M.singleton sid (lexemeType, (start, end)) } return $ WithSrc sid ans diff --git a/src/lib/PPrint.hs b/src/lib/PPrint.hs index 1a189ae1e..bac0e3bbe 100644 --- a/src/lib/PPrint.hs +++ b/src/lib/PPrint.hs @@ -835,7 +835,7 @@ instance PrettyPrec (NewtypeTyCon n) where UserADTType name _ (TyConParams infs params) -> case (infs, params) of ([], []) -> atPrec ArgPrec $ p name ([Explicit, Explicit], [l, r]) - | Just sym <- fromInfix (fromString name) -> + | Just sym <- fromInfix (fromString $ pprint name) -> atPrec ArgPrec $ align $ group $ parens $ flatAlt " " "" <> pApp l <> line <> p sym <+> pApp r _ -> atPrec LowestPrec $ pAppArg (p name) $ ignoreSynthParams (TyConParams infs params) @@ -1084,7 +1084,7 @@ instance PrettyPrec Group where -- prettyPrec g = atPrec ArgPrec $ fromString $ show g instance Pretty Bin where - pretty (EvalBinOp name) = fromString (withoutSrc name) + pretty (EvalBinOp name) = pretty (withoutSrc name) pretty DepAmpersand = "&>" pretty Dot = "." pretty DepComma = ",>" diff --git a/src/lib/QueryType.hs b/src/lib/QueryType.hs index d0fcbd93d..79214c57a 100644 --- a/src/lib/QueryType.hs +++ b/src/lib/QueryType.hs @@ -151,7 +151,7 @@ getMethodIndex :: EnvReader m => ClassName n -> SourceName -> m n Int getMethodIndex className methodSourceName = do ClassDef _ _ methodNames _ _ _ _ _ <- lookupClassDef className case elemIndex methodSourceName methodNames of - Nothing -> error $ methodSourceName ++ " is not a method of " ++ pprint className + Nothing -> error $ pprint methodSourceName ++ " is not a method of " ++ pprint className Just i -> return i {-# INLINE getMethodIndex #-} diff --git a/src/lib/RenderHtml.hs b/src/lib/RenderHtml.hs index 3f192eb63..fb9e7fadb 100644 --- a/src/lib/RenderHtml.hs +++ b/src/lib/RenderHtml.hs @@ -75,7 +75,7 @@ instance ToMarkup Output where instance ToMarkup SourceBlock where toMarkup block = case sbContents block of (Misc (ProseBlock s)) -> cdiv "prose-block" $ mdToHtml s - _ -> renderSpans (sbSourceMaps block) (sbText block) + _ -> renderSpans (sbLexemeInfo block) (sbASTInfo block) (sbText block) mdToHtml :: T.Text -> Html mdToHtml s = preEscapedText $ commonmarkToHtml [] s @@ -83,11 +83,11 @@ mdToHtml s = preEscapedText $ commonmarkToHtml [] s cdiv :: String -> Html -> Html cdiv c inner = H.div inner ! class_ (stringValue c) -renderSpans :: SourceMaps -> T.Text -> Markup -renderSpans sm sourceText = cdiv "code-block" do +renderSpans :: LexemeInfo -> ASTInfo -> T.Text -> Markup +renderSpans lexInfo astInfo sourceText = cdiv "code-block" do runTextWalkerT sourceText do - forM_ (lexemeList sm) \sourceId -> do - let (lexemeTy, (l, r)) = fromJust $ M.lookup sourceId (lexemeInfo sm) + forM_ (lexemeList lexInfo) \sourceId -> do + let (lexemeTy, (l, r)) = fromJust $ M.lookup sourceId (lexemeInfo lexInfo) takeTo l >>= emitSpan "" takeTo r >>= emitSpan (lexemeClass lexemeTy) takeRest >>= emitSpan "" diff --git a/src/lib/RuntimePrint.hs b/src/lib/RuntimePrint.hs index af42c0c2f..15fa1c86b 100644 --- a/src/lib/RuntimePrint.hs +++ b/src/lib/RuntimePrint.hs @@ -109,7 +109,7 @@ showAnyTyCon tyCon atom = case tyCon of showDataCon (sink $ cons !! i) arg return UnitVal StructFields fields -> do - emitLit tySourceName + emitLit $ pprint tySourceName parens do sepBy ", " $ (enumerate fields) <&> \(i, _) -> rec =<< projectStruct i atom @@ -117,9 +117,9 @@ showAnyTyCon tyCon atom = case tyCon of showDataCon :: Emits n' => DataConDef n' -> CAtom n' -> Print n' showDataCon (DataConDef sn _ _ projss) arg = do case projss of - [] -> emitLit sn + [] -> emitLit $ pprint sn _ -> parens do - emitLit (sn ++ " ") + emitLit (pprint sn ++ " ") sepBy " " $ projss <&> \projs -> -- we use `init` to strip off the `UnwrapCompoundNewtype` since -- we're already under the case alternative @@ -204,16 +204,16 @@ stringLitAsCharTab s = do finTabTyCore :: (Fallible1 m, EnvReader m) => CAtom n -> CType n -> m n (CType n) finTabTyCore n eltTy = return $ IxType (FinTy n) (DictCon $ IxFin n) ==> eltTy -getPreludeFunction :: EnvReader m => String -> m n (CAtom n) +getPreludeFunction :: EnvReader m => SourceName -> m n (CAtom n) getPreludeFunction sourceName = do lookupSourceMap sourceName >>= \case Just uvar -> case uvar of UAtomVar v -> toAtom <$> toAtomVar v _ -> notfound Nothing -> notfound - where notfound = error $ "Function not defined: " ++ sourceName + where notfound = error $ "Function not defined: " ++ pprint sourceName -applyPreludeFunction :: (Emits n, CBuilder m) => String -> [CAtom n] -> m n (CAtom n) +applyPreludeFunction :: (Emits n, CBuilder m) => SourceName -> [CAtom n] -> m n (CAtom n) applyPreludeFunction name args = do f <- getPreludeFunction name naryApp f args @@ -221,14 +221,14 @@ applyPreludeFunction name args = do strType :: forall n m. EnvReader m => m n (CType n) strType = constructPreludeType "List" $ TyConParams [Explicit] [toAtom (CharRepTy :: CType n)] -constructPreludeType :: EnvReader m => String -> TyConParams n -> m n (CType n) +constructPreludeType :: EnvReader m => SourceName -> TyConParams n -> m n (CType n) constructPreludeType sourceName params = do lookupSourceMap sourceName >>= \case Just uvar -> case uvar of UTyConVar v -> return $ toType $ UserADTType sourceName v params _ -> notfound Nothing -> notfound - where notfound = error $ "Type constructor not defined: " ++ sourceName + where notfound = error $ "Type constructor not defined: " ++ pprint sourceName forEachTabElt :: (Emits n, ScopableBuilder CoreIR m) diff --git a/src/lib/SourceIdTraversal.hs b/src/lib/SourceIdTraversal.hs new file mode 100644 index 000000000..1fb33a50f --- /dev/null +++ b/src/lib/SourceIdTraversal.hs @@ -0,0 +1,113 @@ +-- Copyright 2023 Google LLC +-- +-- Use of this source code is governed by a BSD-style +-- license that can be found in the LICENSE file or at +-- https://developers.google.com/open-source/licenses/bsd + +module SourceIdTraversal (getASTInfo) where + +import qualified Data.Map.Strict as M +import Control.Monad.Reader +import Control.Monad.Writer.Strict + +import Types.Source +import Types.Primitives + +getASTInfo :: SourceBlock' -> ASTInfo +getASTInfo b = runTreeM (SrcId 0) $ visit b + +type TreeM = ReaderT SrcId (Writer ASTInfo) + +runTreeM :: SrcId -> TreeM () -> ASTInfo +runTreeM root cont = snd $ runWriter $ runReaderT cont root + +enterNode :: SrcId -> TreeM a -> TreeM a +enterNode sid cont = do + emitNode sid + local (const sid) cont + +emitNode :: SrcId -> TreeM () +emitNode child = do + parent <- ask + tell $ ASTInfo (M.singleton child parent) (M.singleton parent [child]) + +class IsTree a where + visit :: a -> TreeM () + +instance IsTree SourceBlock' where + visit = \case + TopDecl decl -> visit decl + Command _ g -> visit g + DeclareForeign v1 v2 g -> visit v1 >> visit v2 >> visit g + DeclareCustomLinearization v _ g -> visit v >> visit g + Misc _ -> return () + UnParseable _ _ -> return () + +instance IsTree Group where + visit = \case + CLeaf _ -> return () + CPrim _ xs -> mapM_ visit xs + CParens xs -> mapM_ visit xs + CBrackets xs -> mapM_ visit xs + CBin _ l r -> visit l >> visit r + CJuxtapose _ l r -> visit l >> visit r + CPrefix l r -> visit l >> visit r + CGivens (x,y) -> visit x >> visit y + CLambda args body -> visit args >> visit body + CFor _ args body -> visit args >> visit body + CCase scrut alts -> visit scrut >> visit alts + CIf scrut ifTrue ifFalse -> visit scrut >> visit ifTrue >> visit ifFalse + CDo body -> visit body + CArrow l effs r -> visit l >> visit effs >> visit r + CWith b body -> visit b >> visit body + +instance IsTree CSBlock where + visit = \case + IndentedBlock sid decls -> enterNode sid $ visit decls + ExprBlock body -> visit body + +instance IsTree CSDecl where + visit = \case + CLet v rhs -> visit v >> visit rhs + CDefDecl def -> visit def + CExpr g -> visit g + CBind v body -> visit v >> visit body + CPass -> return () + +instance IsTree CTopDecl where + visit = \case + CSDecl _ decl -> visit decl + CData v params givens cons -> visit v >> visit params >> visit givens >> visit cons + CStruct v params givens fields methods -> visit v >> visit params >> visit givens >> visit fields >> visit methods + CInterface v params methods -> visit v >> visit params >> visit methods + CInstanceDecl def -> visit def + +instance IsTree CDef where + visit (CDef v params rhs givens body) = + visit v >> visit params >> visit rhs >> visit givens >> visit body + +instance IsTree CInstanceDef where + visit (CInstanceDef v args givens methods name) = + visit v >> visit args >> visit givens >> visit methods >> visit name + +instance IsTree a => IsTree (WithSrc a) where + visit (WithSrc sid x) = enterNode sid $ visit x + +instance IsTree a => IsTree (WithSrcs a) where + visit (WithSrcs sid sids x) = enterNode sid $ mapM_ emitNode sids >> visit x + +instance IsTree a => IsTree [a] where + visit xs = mapM_ visit xs + +instance IsTree a => IsTree (Maybe a) where + visit xs = mapM_ visit xs + +instance (IsTree a, IsTree b) => IsTree (a, b) where + visit (x, y) = visit x >> visit y + +instance (IsTree a, IsTree b, IsTree c) => IsTree (a, b, c) where + visit (x, y, z) = visit x >> visit y >> visit z + +instance IsTree AppExplicitness where visit _ = return () +instance IsTree SourceName where visit _ = return () +instance IsTree LetAnn where visit _ = return () diff --git a/src/lib/SourceRename.hs b/src/lib/SourceRename.hs index bf876df27..ee2b6f9f4 100644 --- a/src/lib/SourceRename.hs +++ b/src/lib/SourceRename.hs @@ -111,7 +111,7 @@ lookupSourceName v = do LocalVar v' : _ -> return v' [ModuleVar _ maybeV] -> case maybeV of Just v' -> return v' - Nothing -> throw VarDefErr v + Nothing -> throw VarDefErr $ pprint v vs -> throw AmbiguousVarErr $ ambiguousVarErrMsg v vs ambiguousVarErrMsg :: SourceName -> [SourceNameDef n] -> String diff --git a/src/lib/TopLevel.hs b/src/lib/TopLevel.hs index fe8e2fe9e..89021559b 100644 --- a/src/lib/TopLevel.hs +++ b/src/lib/TopLevel.hs @@ -204,8 +204,8 @@ catchLogsAndErrs m = do -- know ahead of time which modules will be needed. evalSourceBlockRepl :: (Topper m, Mut n) => SourceBlock -> m n Result evalSourceBlockRepl block = do - case block of - SourceBlock _ _ _ _ _ (Misc (ImportModule name)) -> do + case sbContents block of + Misc (ImportModule name) -> do -- TODO: clear source map and synth candidates before calling this ensureModuleLoaded name _ -> return () @@ -286,8 +286,8 @@ evalSourceBlock' mname block = case sbContents block of "FFI functions must be n-ary first order functions with the IO effect" Just (impFunTy, naryPiTy) -> do -- TODO: query linking stuff and check the function is actually available - let hint = fromString dexName - fTop <- emitBinding hint $ TopFunBinding $ FFITopFun (withoutSrc fname) impFunTy + let hint = fromString $ pprint dexName + fTop <- emitBinding hint $ TopFunBinding $ FFITopFun (pprint $ withoutSrc fname) impFunTy vCore <- emitBinding hint $ AtomNameBinding $ FFIFunBound naryPiTy fTop emitSourceMap $ SourceMap $ M.singleton dexName [ModuleVar mname (Just $ UAtomVar vCore)] @@ -751,7 +751,7 @@ loadModuleSource :: (MonadIO m, Fallible m) => EvalConfig -> ModuleSourceName -> m File loadModuleSource config moduleName = do fullPath <- case moduleName of - OrdinaryModule moduleName' -> findFullPath $ moduleName' ++ ".dx" + OrdinaryModule moduleName' -> findFullPath $ pprint moduleName' ++ ".dx" Prelude -> case preludeFile config of Nothing -> findFullPath "prelude.dx" Just path -> return path @@ -766,7 +766,7 @@ loadModuleSource config moduleName = do Nothing -> throw ModuleImportErr $ unlines [ "Couldn't find a source file for module " ++ (case moduleName of - OrdinaryModule n -> n; Prelude -> "prelude"; Main -> error "") + OrdinaryModule n -> pprint n; Prelude -> "prelude"; Main -> error "") , "Hint: Consider extending --lib-path?" ] diff --git a/src/lib/Types/Primitives.hs b/src/lib/Types/Primitives.hs index 83ba3ffbe..8096f7e6e 100644 --- a/src/lib/Types/Primitives.hs +++ b/src/lib/Types/Primitives.hs @@ -24,9 +24,11 @@ module Types.Primitives ( import qualified Data.ByteString as BS import Data.Int +import Data.String (IsString (..)) import Data.Word import Data.Hashable import Data.Store (Store (..)) +import Data.Text.Prettyprint.Doc (Pretty (..)) import qualified Data.Store.Internal as SI import Foreign.Ptr @@ -34,8 +36,9 @@ import GHC.Generics (Generic (..)) import Occurrence import Types.OpNames (UnOp (..), BinOp (..), CmpOp (..), Projection (..)) +import Name -type SourceName = String +newtype SourceName = MkSourceName String deriving (Show, Eq, Ord, Generic) newtype AlwaysEqual a = AlwaysEqual a deriving (Show, Generic, Functor, Foldable, Traversable, Hashable, Store) @@ -181,6 +184,16 @@ emptyLit = \case -- === Typeclass instances === +instance HasNameHint SourceName where + getNameHint (MkSourceName v) = getNameHint v + +instance Pretty SourceName where + pretty (MkSourceName v) = pretty v + +instance IsString SourceName where + fromString v = MkSourceName v + +instance Store SourceName instance Store RequiredMethodAccess instance Store LetAnn instance Store RWS @@ -194,6 +207,7 @@ instance Store AppExplicitness instance Store DepPairExplicitness instance Store InferenceMechanism +instance Hashable SourceName instance Hashable RWS instance Hashable Direction instance Hashable BaseType diff --git a/src/lib/Types/Source.hs b/src/lib/Types/Source.hs index e14da16d9..961445ea4 100644 --- a/src/lib/Types/Source.hs +++ b/src/lib/Types/Source.hs @@ -53,6 +53,7 @@ newtype SourceOrInternalName (c::C) (n::S) = SourceOrInternalName (SourceNameOr -- === Source Info === +-- XXX: 0 is reserved for the root newtype SrcId = SrcId Int deriving (Show, Eq, Ord, Generic) -- This is just for syntax highlighting. It won't be needed if we have @@ -70,19 +71,25 @@ data LexemeType = deriving (Show, Generic) type Span = (Int, Int) -data SourceMaps = SourceMaps +data LexemeInfo = LexemeInfo { lexemeList :: SnocList SrcId - , lexemeInfo :: M.Map SrcId (LexemeType, Span) - , astParent :: M.Map SrcId SrcId + , lexemeInfo :: M.Map SrcId (LexemeType, Span) } + deriving (Show, Generic) + +data ASTInfo = ASTInfo + { astParent :: M.Map SrcId SrcId , astChildren :: M.Map SrcId [SrcId]} deriving (Show, Generic) -instance Semigroup SourceMaps where - SourceMaps a b c d <> SourceMaps a' b' c' d' = - SourceMaps (a <> a') (b <> b') (c <> c') (d <> d') +instance Semigroup LexemeInfo where + LexemeInfo a b <> LexemeInfo a' b' = LexemeInfo (a <> a') (b <> b') +instance Monoid LexemeInfo where + mempty = LexemeInfo mempty mempty -instance Monoid SourceMaps where - mempty = SourceMaps mempty mempty mempty mempty +instance Semigroup ASTInfo where + ASTInfo a b <> ASTInfo a' b' = ASTInfo (a <> a') (M.unionWith (<>) b b') +instance Monoid ASTInfo where + mempty = ASTInfo mempty mempty -- === Concrete syntax === -- The grouping-level syntax of the source language @@ -507,7 +514,8 @@ data SourceBlock = SourceBlock , sbOffset :: Int , sbLogLevel :: LogLevel , sbText :: Text - , sbSourceMaps :: SourceMaps + , sbLexemeInfo :: LexemeInfo + , sbASTInfo :: ASTInfo , sbContents :: SourceBlock' } deriving (Show, Generic) From 3fe24e97264bcf07fcf9f612c2c82b5d83a1f8aa Mon Sep 17 00:00:00 2001 From: Dougal Date: Sat, 25 Nov 2023 22:14:23 -0500 Subject: [PATCH 22/41] Fix bugs in lexeme tracking --- src/dex.hs | 2 +- src/lib/ConcreteSyntax.hs | 6 ++-- src/lib/IncState.hs | 5 +++- src/lib/Lexing.hs | 27 ++++++++++++------ src/lib/Live/Eval.hs | 15 ++++++---- src/lib/Live/Web.hs | 2 +- src/lib/RenderHtml.hs | 59 +++++++++++++++++++++++---------------- src/lib/Types/Source.hs | 2 ++ 8 files changed, 74 insertions(+), 44 deletions(-) diff --git a/src/dex.hs b/src/dex.hs index 5232ec9c5..2874649e2 100644 --- a/src/dex.hs +++ b/src/dex.hs @@ -116,7 +116,7 @@ printFinal fmt prog = case fmt of TextDoc -> return () JSONDoc -> return () #ifdef DEX_LIVE - HTMLDoc -> putStr $ progHtml prog + HTMLDoc -> undefined -- putStr $ progHtml prog #endif readSourceBlock :: (MonadIO (m n), EnvReader m) => String -> m n SourceBlock diff --git a/src/lib/ConcreteSyntax.hs b/src/lib/ConcreteSyntax.hs index 2aa2c0f6e..5762083ed 100644 --- a/src/lib/ConcreteSyntax.hs +++ b/src/lib/ConcreteSyntax.hs @@ -197,12 +197,12 @@ topLevelCommand = importModule <|> declareForeign <|> declareCustomLinearization - <|> (Misc . QueryEnv <$> envQuery) + -- <|> (Misc . QueryEnv <$> envQuery) <|> explicitCommand "top-level command" -envQuery :: Parser EnvQuery -envQuery = error "not implemented" +_envQuery :: Parser EnvQuery +_envQuery = error "not implemented" -- string ":debug" >> sc >> ( -- (DumpSubst <$ (string "env" >> sc)) -- <|> (InternalNameInfo <$> (string "iname" >> sc >> rawName)) diff --git a/src/lib/IncState.hs b/src/lib/IncState.hs index 19d0a0884..3c8f90d77 100644 --- a/src/lib/IncState.hs +++ b/src/lib/IncState.hs @@ -8,7 +8,7 @@ module IncState ( IncState (..), MapEltUpdate (..), MapUpdate (..), - Overwrite (..), TailUpdate (..)) where + Overwrite (..), TailUpdate (..), mapUpdateMapWithKey) where import qualified Data.Map.Strict as M import GHC.Generics @@ -29,6 +29,9 @@ data MapEltUpdate v = data MapUpdate k v = MapUpdate { mapUpdates :: M.Map k (MapEltUpdate v) } deriving (Functor, Show, Generic) +mapUpdateMapWithKey :: MapUpdate k v -> (k -> v -> v') -> MapUpdate k v' +mapUpdateMapWithKey (MapUpdate m) f = MapUpdate $ M.mapWithKey (\k v -> fmap (f k) v) m + instance Ord k => Monoid (MapUpdate k v) where mempty = MapUpdate mempty diff --git a/src/lib/Lexing.hs b/src/lib/Lexing.hs index d40c4c8a4..18c85f55d 100644 --- a/src/lib/Lexing.hs +++ b/src/lib/Lexing.hs @@ -235,12 +235,21 @@ space = gets canBreak >>= \case True -> space1 False -> void $ takeWhile1P (Just "white space") (`elem` (" \t" :: String)) +setCanBreakLocally :: Bool -> Parser a -> Parser a +setCanBreakLocally brLocal p = do + brPrev <- gets canBreak + modify \ctx -> ctx {canBreak = brLocal} + ans <- p + modify \ctx -> ctx {canBreak = brPrev} + return ans +{-# INLINE setCanBreakLocally #-} + mayBreak :: Parser a -> Parser a -mayBreak p = pLocal (\ctx -> ctx { canBreak = True }) p +mayBreak p = setCanBreakLocally True p {-# INLINE mayBreak #-} mayNotBreak :: Parser a -> Parser a -mayNotBreak p = pLocal (\ctx -> ctx { canBreak = False }) p +mayNotBreak p = setCanBreakLocally False p {-# INLINE mayNotBreak #-} precededByWhitespace :: Parser Bool @@ -294,14 +303,16 @@ withIndent p = do nextLine indent <- T.length <$> takeWhileP (Just "space") (==' ') when (indent <= 0) empty - pLocal (\ctx -> ctx { curIndent = curIndent ctx + indent }) $ mayNotBreak p + locallyExtendCurIndent indent $ mayNotBreak p {-# INLINE withIndent #-} -pLocal :: (ParseCtx -> ParseCtx) -> Parser a -> Parser a -pLocal f p = do - s <- get - put (f s) >> p <* put s -{-# INLINE pLocal #-} +locallyExtendCurIndent :: Int -> Parser a -> Parser a +locallyExtendCurIndent n p = do + indentPrev <- gets curIndent + modify \ctx -> ctx { curIndent = indentPrev + n } + ans <- p + modify \ctx -> ctx { curIndent = indentPrev } + return ans eol :: Parser () eol = void MC.eol diff --git a/src/lib/Live/Eval.hs b/src/lib/Live/Eval.hs index a8c449a8c..f5248302f 100644 --- a/src/lib/Live/Eval.hs +++ b/src/lib/Live/Eval.hs @@ -8,7 +8,7 @@ {-# OPTIONS_GHC -Wno-orphans #-} module Live.Eval ( - watchAndEvalFile, ResultsServer, ResultsUpdate, subscribeIO, dagAsUpdate) where + watchAndEvalFile, ResultsServer, ResultsUpdate, subscribeIO, dagAsUpdate, addSourceBlockIds) where import Control.Concurrent import Control.Monad @@ -33,13 +33,16 @@ import MonadUtil -- `watchAndEvalFile` returns the channel by which a client may -- subscribe by sending a write-only view of its input channel. -watchAndEvalFile :: FilePath -> EvalConfig -> TopStateEx - -> IO (Evaluator SourceBlock Result) +watchAndEvalFile :: FilePath -> EvalConfig -> TopStateEx -> IO ResultsServer watchAndEvalFile fname opts env = do watcher <- launchFileWatcher fname parser <- launchCellParser watcher \source -> uModuleSourceBlocks $ parseUModule Main source launchDagEvaluator parser env (evalSourceBlockIO opts) +addSourceBlockIds :: NodeListUpdate (NodeState SourceBlock o) -> NodeListUpdate (NodeState SourceBlockWithId o) +addSourceBlockIds (NodeListUpdate listUpdate mapUpdate) = NodeListUpdate listUpdate mapUpdate' + where mapUpdate' = mapUpdateMapWithKey mapUpdate \k (NodeState i o) -> NodeState (SourceBlockWithId k i) o + type ResultsServer = Evaluator SourceBlock Result type ResultsUpdate = EvalStatusUpdate SourceBlock Result @@ -297,15 +300,15 @@ processDagUpdate dagUpdate = do -- === instances === -instance ToJSON a => ToJSON (NodeListUpdate a) +instance (ToJSON i, ToJSON o) => ToJSON (NodeListUpdate (NodeState i o)) where instance (ToJSON a, ToJSONKey k) => ToJSON (MapUpdate k a) instance ToJSON a => ToJSON (TailUpdate a) instance ToJSON a => ToJSON (MapEltUpdate a) instance ToJSON o => ToJSON (NodeEvalStatus o) instance (ToJSON i, ToJSON o) => ToJSON (NodeState i o) -instance ToJSON SourceBlock where - toJSON b = toJSON (sbLine b, pprintHtml b) +instance ToJSON SourceBlockWithId where + toJSON b@(SourceBlockWithId _ b') = toJSON (sbLine b', pprintHtml b) instance ToJSON Result where toJSON = toJSONViaHtml toJSONViaHtml :: ToMarkup a => a -> Value diff --git a/src/lib/Live/Web.hs b/src/lib/Live/Web.hs index ae7273028..1ab7394ca 100644 --- a/src/lib/Live/Web.hs +++ b/src/lib/Live/Web.hs @@ -57,7 +57,7 @@ resultStream resultsServer write flush = do where sendUpdate :: ResultsUpdate -> IO () sendUpdate update = do - let s = encodeResults update + let s = encodeResults $ addSourceBlockIds update write (fromByteString s) >> flush encodeResults :: ToJSON a => a -> BS.ByteString diff --git a/src/lib/RenderHtml.hs b/src/lib/RenderHtml.hs index fb9e7fadb..7e4aac451 100644 --- a/src/lib/RenderHtml.hs +++ b/src/lib/RenderHtml.hs @@ -16,6 +16,7 @@ import Text.Blaze.Html.Renderer.String import qualified Data.Map.Strict as M import Control.Monad.State.Strict import Data.Maybe (fromJust) +import Data.String (fromString) import Data.Text qualified as T import Data.Text.IO qualified as T import CMark (commonmarkToHtml) @@ -72,10 +73,10 @@ instance ToMarkup Output where HtmlOut s -> preEscapedString s _ -> cdiv "result-block" $ toHtml $ pprint out -instance ToMarkup SourceBlock where - toMarkup block = case sbContents block of - (Misc (ProseBlock s)) -> cdiv "prose-block" $ mdToHtml s - _ -> renderSpans (sbLexemeInfo block) (sbASTInfo block) (sbText block) +instance ToMarkup SourceBlockWithId where + toMarkup (SourceBlockWithId blockId block) = case sbContents block of + Misc (ProseBlock s) -> cdiv "prose-block" $ mdToHtml s + _ -> renderSpans blockId (sbLexemeInfo block) (sbASTInfo block) (sbText block) mdToHtml :: T.Text -> Html mdToHtml s = preEscapedText $ commonmarkToHtml [] s @@ -83,28 +84,37 @@ mdToHtml s = preEscapedText $ commonmarkToHtml [] s cdiv :: String -> Html -> Html cdiv c inner = H.div inner ! class_ (stringValue c) -renderSpans :: LexemeInfo -> ASTInfo -> T.Text -> Markup -renderSpans lexInfo astInfo sourceText = cdiv "code-block" do +type BlockId = Int + +renderSpans :: BlockId -> LexemeInfo -> ASTInfo -> T.Text -> Markup +renderSpans blockId lexInfo astInfo sourceText = cdiv "code-block" do runTextWalkerT sourceText do forM_ (lexemeList lexInfo) \sourceId -> do let (lexemeTy, (l, r)) = fromJust $ M.lookup sourceId (lexemeInfo lexInfo) - takeTo l >>= emitSpan "" - takeTo r >>= emitSpan (lexemeClass lexemeTy) - takeRest >>= emitSpan "" - -emitSpan :: String -> T.Text -> TextWalker () -emitSpan className t = lift $ H.span (toHtml t) ! class_ (stringValue className) - -lexemeClass :: LexemeType -> String + takeTo l >>= emitSpan Nothing (Just "comment") + takeTo r >>= emitSpan (Just (blockId, sourceId)) (lexemeClass lexemeTy) + takeRest >>= emitSpan Nothing (Just "comment") + +emitSpan :: Maybe (BlockId, SrcId) -> Maybe String -> T.Text -> TextWalker () +emitSpan maybeSrcId className t = lift do + let classAttr = case className of + Nothing -> mempty + Just c -> class_ (stringValue c) + let idAttr = case maybeSrcId of + Nothing -> mempty + Just (bid, SrcId sid) -> At.id (fromString $ "span_" ++ show bid ++ "_"++ show sid) + H.span (toHtml t) ! classAttr ! idAttr + +lexemeClass :: LexemeType -> Maybe String lexemeClass = \case - Keyword -> "keyword" - Symbol -> "symbol" - TypeName -> "type-name" - LowerName -> "" - UpperName -> "" - LiteralLexeme -> "literal" - StringLiteralLexeme -> "" - MiscLexeme -> "" + Keyword -> Just "keyword" + Symbol -> Just "symbol" + TypeName -> Just "type-name" + LowerName -> Nothing + UpperName -> Nothing + LiteralLexeme -> Just "literal" + StringLiteralLexeme -> Nothing + MiscLexeme -> Nothing type TextWalker a = StateT (Int, T.Text) MarkupM a @@ -121,5 +131,6 @@ takeTo startPos = do takeRest :: TextWalker T.Text takeRest = do - endPos <- gets $ T.length . snd - takeTo endPos + (curPos, curText) <- get + put (curPos + T.length curText, mempty) + return curText diff --git a/src/lib/Types/Source.hs b/src/lib/Types/Source.hs index 961445ea4..81f1c3387 100644 --- a/src/lib/Types/Source.hs +++ b/src/lib/Types/Source.hs @@ -509,6 +509,8 @@ data UModule = UModule -- === top-level blocks === +data SourceBlockWithId = SourceBlockWithId Int SourceBlock + data SourceBlock = SourceBlock { sbLine :: Int , sbOffset :: Int From b94d35919bbb1731fa7a46375b9161692c850c9c Mon Sep 17 00:00:00 2001 From: Dougal Date: Mon, 27 Nov 2023 11:08:03 -0500 Subject: [PATCH 23/41] Use explicit data structures to represent the AST on the browser side. Previously we baked the AST into the HTML tree which made it hard to change and add more information. --- src/lib/AbstractSyntax.hs | 22 ++-- src/lib/ConcreteSyntax.hs | 26 ++-- src/lib/Lexing.hs | 14 +- src/lib/Live/Eval.hs | 17 ++- src/lib/PPrint.hs | 2 +- src/lib/RenderHtml.hs | 6 +- src/lib/SourceIdTraversal.hs | 3 +- src/lib/Types/Source.hs | 5 +- static/index.js | 239 ++++++++++++++--------------------- 9 files changed, 156 insertions(+), 178 deletions(-) diff --git a/src/lib/AbstractSyntax.hs b/src/lib/AbstractSyntax.hs index 8d8ec9818..a21062ff1 100644 --- a/src/lib/AbstractSyntax.hs +++ b/src/lib/AbstractSyntax.hs @@ -195,7 +195,7 @@ withTrailingConstraints :: GroupW -> (GroupW -> SyntaxM (UAnnBinder VoidS VoidS)) -> SyntaxM (Nest UAnnBinder VoidS VoidS) withTrailingConstraints g cont = case g of - WithSrcs _ _ (CBin Pipe lhs c) -> do + WithSrcs _ _ (CBin (WithSrc _ Pipe) lhs c) -> do Nest (UAnnBinder expl (WithSrcB sid b) ann cs) bs <- withTrailingConstraints lhs cont s <- case b of UBindSource s -> return s @@ -253,7 +253,7 @@ explicitBindersOptAnn (WithSrcs _ _ bs) = -- Binder pattern with an optional type annotation patOptAnn :: GroupW -> SyntaxM (UPat VoidS VoidS, Maybe (UType VoidS)) -patOptAnn (WithSrcs _ _ (CBin Colon lhs typeAnn)) = (,) <$> pat lhs <*> (Just <$> expr typeAnn) +patOptAnn (WithSrcs _ _ (CBin (WithSrc _ Colon) lhs typeAnn)) = (,) <$> pat lhs <*> (Just <$> expr typeAnn) patOptAnn (WithSrcs _ _ (CParens [g])) = patOptAnn g patOptAnn g = (,Nothing) <$> pat g @@ -267,7 +267,7 @@ uBinder (WithSrcs sid _ b) = case b of tyOptPat :: GroupW -> SyntaxM (UAnnBinder VoidS VoidS) tyOptPat grpTop@(WithSrcs sid _ grp) = case grp of -- Named type - CBin Colon lhs typeAnn -> + CBin (WithSrc _ Colon) lhs typeAnn -> UAnnBinder Explicit <$> uBinder lhs <*> (UAnn <$> expr typeAnn) <*> pure [] -- Binder in grouping parens. CParens [g] -> tyOptPat g @@ -285,7 +285,7 @@ casePat = \case pat :: GroupW -> SyntaxM (UPat VoidS VoidS) pat (WithSrcs sid _ grp) = WithSrcB sid <$> case grp of - CBin DepComma lhs rhs -> do + CBin (WithSrc _ DepComma) lhs rhs -> do lhs' <- pat lhs rhs' <- pat rhs return $ UPatDepPair $ PairB lhs' rhs' @@ -317,8 +317,8 @@ pat (WithSrcs sid _ grp) = WithSrcB sid <$> case grp of tyOptBinder :: Explicitness -> GroupW -> SyntaxM (UAnnBinder VoidS VoidS) tyOptBinder expl (WithSrcs sid sids grp) = case grp of - CBin Pipe _ _ -> throw SyntaxErr "Unexpected constraint" - CBin Colon name ty -> do + CBin (WithSrc _ Pipe) _ _ -> throw SyntaxErr "Unexpected constraint" + CBin (WithSrc _ Colon) name ty -> do b <- uBinder name ann <- UAnn <$> expr ty return $ UAnnBinder expl b ann [] @@ -328,7 +328,7 @@ tyOptBinder expl (WithSrcs sid sids grp) = case grp of binderOptTy :: Explicitness -> GroupW -> SyntaxM (UAnnBinder VoidS VoidS) binderOptTy expl = \case - WithSrcs _ _ (CBin Colon name ty) -> do + WithSrcs _ _ (CBin (WithSrc _ Colon) name ty) -> do b <- uBinder name ann <- UAnn <$> expr ty return $ UAnnBinder expl b ann [] @@ -337,7 +337,7 @@ binderOptTy expl = \case return $ UAnnBinder expl b UNoAnn [] binderReqTy :: Explicitness -> GroupW -> SyntaxM (UAnnBinder VoidS VoidS) -binderReqTy expl (WithSrcs _ _ (CBin Colon name ty)) = do +binderReqTy expl (WithSrcs _ _ (CBin (WithSrc _ Colon) name ty)) = do b <- uBinder name ann <- UAnn <$> expr ty return $ UAnnBinder expl b ann [] @@ -348,7 +348,7 @@ argList gs = partitionEithers <$> mapM singleArg gs singleArg :: GroupW -> SyntaxM (Either (UExpr VoidS) (UNamedArg VoidS)) singleArg = \case - WithSrcs _ _ (CBin CSEqual lhs rhs) -> Right <$> + WithSrcs _ _ (CBin (WithSrc _ CSEqual) lhs rhs) -> Right <$> ((,) <$> withoutSrc <$> identifier "named argument" lhs <*> expr rhs) g -> Left <$> expr g @@ -450,7 +450,7 @@ expr (WithSrcs sid _ grp) = WithSrcE sid <$> case grp of args' <- mapM expr args return $ UTabApp f args' _ -> error "unexpected postfix group (should be ruled out at grouping stage)" - CBin op lhs rhs -> case op of + CBin (WithSrc opSid op) lhs rhs -> case op of Dollar -> extendAppRight <$> expr lhs <*> expr rhs Pipe -> extendAppLeft <$> expr lhs <*> expr rhs Dot -> do @@ -480,7 +480,7 @@ expr (WithSrcs sid _ grp) = WithSrcE sid <$> case grp of UTabPi . (UTabPiExpr lhs') <$> expr rhs where evalOp s = do - let f = WithSrcE (srcPos s) (fromSourceNameW s) + let f = WithSrcE opSid (fromSourceNameW (WithSrc opSid s)) lhs' <- expr lhs rhs' <- expr rhs return $ explicitApp f [lhs', rhs'] diff --git a/src/lib/ConcreteSyntax.hs b/src/lib/ConcreteSyntax.hs index 5762083ed..1d9a0c320 100644 --- a/src/lib/ConcreteSyntax.hs +++ b/src/lib/ConcreteSyntax.hs @@ -71,8 +71,8 @@ mustParseSourceBlock s = mustParseit s sourceBlock -- === helpers for target ADT === -interpOperator :: WithSrc String -> Bin -interpOperator (WithSrc sid s) = case s of +interpOperator :: String -> Bin +interpOperator = \case "&>" -> DepAmpersand "." -> Dot ",>" -> DepComma @@ -83,7 +83,7 @@ interpOperator (WithSrc sid s) = case s of "->>" -> ImplicitArrow "=>" -> FatArrow "=" -> CSEqual - name -> EvalBinOp $ WithSrc sid $ fromString $ "(" <> name <> ")" + name -> EvalBinOp $ fromString $ "(" <> name <> ")" pattern Identifier :: SourceName -> GroupW pattern Identifier name <- (WithSrcs _ _ (CLeaf (CIdentifier name))) @@ -477,7 +477,9 @@ cFor = do <|> keyWord Rof_KW $> KRof_ cDo :: Parser Group -cDo = CDo <$> cBlock +cDo = do + keyWord DoKW + CDo <$> cBlock cCase :: Parser Group cCase = do @@ -584,9 +586,9 @@ leafGroup = leafGroup' >>= appendPostfixGroups appendFieldAccess :: GroupW -> Parser Group appendFieldAccess g = try do - void $ char '.' + sid <- dot field <- cFieldName - return $ CBin Dot g field + return $ CBin (WithSrc sid Dot) g field cFieldName :: Parser GroupW cFieldName = cIdentifier <|> (toCLeaf CNat <$> natLit) @@ -675,13 +677,13 @@ addSrcIdToUnOp op = do backquoteOp :: Parser (GroupW -> GroupW -> GroupW) backquoteOp = binApp do - fname <- backquoteName - return $ EvalBinOp fname + WithSrc sid fname <- backquoteName + return $ WithSrc sid $ EvalBinOp fname anySymOp :: Expr.Operator Parser GroupW anySymOp = Expr.InfixL $ binApp do - s <- label "infix operator" (mayBreak anySym) - return $ interpOperator s + WithSrc sid s <- label "infix operator" (mayBreak anySym) + return $ WithSrc sid $ interpOperator s infixSym :: String -> Parser SrcId infixSym s = mayBreak $ symWithId $ T.pack s @@ -698,7 +700,7 @@ symOpR s = (fromString s, Expr.InfixR $ symOp s) symOp :: String -> Parser (GroupW -> GroupW -> GroupW) symOp s = binApp do sid <- label "infix operator" (infixSym s) - return $ interpOperator (WithSrc sid s) + return $ WithSrc sid $ interpOperator s arrowOp :: Parser (GroupW -> GroupW -> GroupW) arrowOp = addSrcIdToBinOp do @@ -714,7 +716,7 @@ prefixOp s = addSrcIdToUnOp do symId <- symWithId (fromString s) return $ CPrefix (WithSrc symId $ fromString s) -binApp :: Parser Bin -> Parser (GroupW -> GroupW -> GroupW) +binApp :: Parser BinW -> Parser (GroupW -> GroupW -> GroupW) binApp f = addSrcIdToBinOp $ CBin <$> f withClausePostfixOp :: Parser (GroupW -> GroupW) diff --git a/src/lib/Lexing.hs b/src/lib/Lexing.hs index 18c85f55d..4d3b6dc8e 100644 --- a/src/lib/Lexing.hs +++ b/src/lib/Lexing.hs @@ -216,6 +216,10 @@ symChar = token (\c -> if HS.member c symChars then Just c else Nothing) mempty symChars :: HS.HashSet Char symChars = HS.fromList ".,!$^&*:-~+/=<>|?\\@#" +-- XXX: unlike other lexemes, this doesn't consume trailing whitespace +dot :: Parser SrcId +dot = srcPos <$> lexeme' (return ()) Symbol (void $ char '.') + -- === Util === sc :: Parser () @@ -349,18 +353,22 @@ symbol :: Text -> Parser () symbol s = void $ L.symbol sc s lexeme :: LexemeType -> Parser a -> Parser (WithSrc a) -lexeme lexemeType p = do +lexeme lexemeType p = lexeme' sc lexemeType p +{-# INLINE lexeme #-} + +lexeme' :: Parser () -> LexemeType -> Parser a -> Parser (WithSrc a) +lexeme' sc' lexemeType p = do start <- getOffset ans <- p end <- getOffset recordNonWhitespace - sc + sc' sid <- freshSrcId emitLexemeInfo $ mempty { lexemeList = toSnocList [sid] , lexemeInfo = M.singleton sid (lexemeType, (start, end)) } return $ WithSrc sid ans -{-# INLINE lexeme #-} +{-# INLINE lexeme' #-} atomicLexeme :: LexemeType -> Parser () -> Parser () atomicLexeme lexemeType p = do diff --git a/src/lib/Live/Eval.hs b/src/lib/Live/Eval.hs index f5248302f..5363a2ee2 100644 --- a/src/lib/Live/Eval.hs +++ b/src/lib/Live/Eval.hs @@ -28,6 +28,7 @@ import TopLevel import ConcreteSyntax import RenderHtml (ToMarkup, pprintHtml) import MonadUtil +import Util (unsnoc) -- === Top-level interface === @@ -305,10 +306,24 @@ instance (ToJSON a, ToJSONKey k) => ToJSON (MapUpdate k a) instance ToJSON a => ToJSON (TailUpdate a) instance ToJSON a => ToJSON (MapEltUpdate a) instance ToJSON o => ToJSON (NodeEvalStatus o) +instance ToJSON SrcId +deriving instance ToJSONKey SrcId +instance ToJSON ASTInfo +instance ToJSON LexemeType instance (ToJSON i, ToJSON o) => ToJSON (NodeState i o) +data SourceBlockJSONData = SourceBlockJSONData + { jdLine :: Int + , jdBlockId :: Int + , jdLexemeList :: [SrcId] + , jdASTInfo :: ASTInfo + , jdHTML :: String } deriving (Generic) + +instance ToJSON SourceBlockJSONData + instance ToJSON SourceBlockWithId where - toJSON b@(SourceBlockWithId _ b') = toJSON (sbLine b', pprintHtml b) + toJSON b@(SourceBlockWithId blockId b') = toJSON $ SourceBlockJSONData + (sbLine b') blockId (unsnoc $ lexemeList $ sbLexemeInfo b') (sbASTInfo b') (pprintHtml b) instance ToJSON Result where toJSON = toJSONViaHtml toJSONViaHtml :: ToMarkup a => a -> Value diff --git a/src/lib/PPrint.hs b/src/lib/PPrint.hs index bac0e3bbe..9add7af52 100644 --- a/src/lib/PPrint.hs +++ b/src/lib/PPrint.hs @@ -1084,7 +1084,7 @@ instance PrettyPrec Group where -- prettyPrec g = atPrec ArgPrec $ fromString $ show g instance Pretty Bin where - pretty (EvalBinOp name) = pretty (withoutSrc name) + pretty (EvalBinOp name) = pretty name pretty DepAmpersand = "&>" pretty Dot = "." pretty DepComma = ",>" diff --git a/src/lib/RenderHtml.hs b/src/lib/RenderHtml.hs index 7e4aac451..bc4809a41 100644 --- a/src/lib/RenderHtml.hs +++ b/src/lib/RenderHtml.hs @@ -76,7 +76,7 @@ instance ToMarkup Output where instance ToMarkup SourceBlockWithId where toMarkup (SourceBlockWithId blockId block) = case sbContents block of Misc (ProseBlock s) -> cdiv "prose-block" $ mdToHtml s - _ -> renderSpans blockId (sbLexemeInfo block) (sbASTInfo block) (sbText block) + _ -> renderSpans blockId (sbLexemeInfo block) (sbText block) mdToHtml :: T.Text -> Html mdToHtml s = preEscapedText $ commonmarkToHtml [] s @@ -86,8 +86,8 @@ cdiv c inner = H.div inner ! class_ (stringValue c) type BlockId = Int -renderSpans :: BlockId -> LexemeInfo -> ASTInfo -> T.Text -> Markup -renderSpans blockId lexInfo astInfo sourceText = cdiv "code-block" do +renderSpans :: BlockId -> LexemeInfo -> T.Text -> Markup +renderSpans blockId lexInfo sourceText = cdiv "code-block" do runTextWalkerT sourceText do forM_ (lexemeList lexInfo) \sourceId -> do let (lexemeTy, (l, r)) = fromJust $ M.lookup sourceId (lexemeInfo lexInfo) diff --git a/src/lib/SourceIdTraversal.hs b/src/lib/SourceIdTraversal.hs index 1fb33a50f..565027294 100644 --- a/src/lib/SourceIdTraversal.hs +++ b/src/lib/SourceIdTraversal.hs @@ -49,7 +49,7 @@ instance IsTree Group where CPrim _ xs -> mapM_ visit xs CParens xs -> mapM_ visit xs CBrackets xs -> mapM_ visit xs - CBin _ l r -> visit l >> visit r + CBin b l r -> visit b >> visit l >> visit r CJuxtapose _ l r -> visit l >> visit r CPrefix l r -> visit l >> visit r CGivens (x,y) -> visit x >> visit y @@ -111,3 +111,4 @@ instance (IsTree a, IsTree b, IsTree c) => IsTree (a, b, c) where instance IsTree AppExplicitness where visit _ = return () instance IsTree SourceName where visit _ = return () instance IsTree LetAnn where visit _ = return () +instance IsTree Bin where visit _ = return () diff --git a/src/lib/Types/Source.hs b/src/lib/Types/Source.hs index 81f1c3387..5c01c3f37 100644 --- a/src/lib/Types/Source.hs +++ b/src/lib/Types/Source.hs @@ -100,6 +100,7 @@ type GroupW = WithSrcs Group type CTopDeclW = WithSrcs CTopDecl type CSDeclW = WithSrcs CSDecl type SourceNameW = WithSrc SourceName +type BinW = WithSrc Bin type BracketedGroup = WithSrcs [GroupW] -- optional arrow, effects, result type @@ -161,7 +162,7 @@ data Group | CPrim PrimName [GroupW] | CParens [GroupW] | CBrackets [GroupW] - | CBin Bin GroupW GroupW + | CBin BinW GroupW GroupW | CJuxtapose Bool GroupW GroupW -- Bool means "there's a space between the groups" | CPrefix SourceNameW GroupW -- covers unary - and unary + among others | CGivens GivenClause @@ -187,7 +188,7 @@ data CLeaf type CaseAlt = (GroupW, CSBlock) -- scrutinee, lexeme Id, body data Bin - = EvalBinOp SourceNameW + = EvalBinOp SourceName | DepAmpersand | Dot | DepComma diff --git a/static/index.js b/static/index.js index 1d2c36a5a..7657b3024 100644 --- a/static/index.js +++ b/static/index.js @@ -16,53 +16,6 @@ var katexOptions = { trust: true }; -function lookup_address(cell, address) { - var node = cell - for (i = 0; i < address.length; i++) { - node = node.children[address[i]] - } - return node -} - -function renderHovertips(root) { - var spans = root.querySelectorAll(".code-span"); - Array.from(spans).map((span) => attachHovertip(span)); -} - -function attachHovertip(node) { - node.addEventListener("mouseover", (event) => highlightNode( event, node)); - node.addEventListener("mouseout" , (event) => removeHighlighting(event, node)); -} - -function highlightNode(event, node) { - event.stopPropagation(); - node.style.backgroundColor = "lightblue"; - node.style.outlineColor = "lightblue"; - node.style.outlineStyle = "solid"; - Array.from(node.children).map(function (child) { - if (isCodeSpanOrLeaf(child)) { - child.style.backgroundColor = "yellow"; - } - }) -} - -function isCodeSpanOrLeaf(node) { - return node.classList.contains("code-span") || node.classList.contains("code-span-leaf") - -} - -function removeHighlighting(event, node) { - event.stopPropagation(); - node.style.backgroundColor = null; - node.style.outlineColor = null; - node.style.outlineStyle = null; - Array.from(node.children).map(function (child) { - if (isCodeSpanOrLeaf(child)) { - child.style.backgroundColor = null; - } - }) -} - function renderLaTeX(root) { // Render LaTeX equations in prose blocks via KaTeX, if available. // Skip rendering if KaTeX is unavailable. @@ -76,84 +29,6 @@ function renderLaTeX(root) { ); } -/** - * Rendering the Table of Contents / Navigation Bar - * 2 key functions - * - `updateNavigation()` which inserts/updates the navigation bar - * - and its helper `extractStructure()` which extracts the structure of the page - * and adds ids to heading elements. -*/ -function updateNavigation() { - function navItemList(struct) { - var listEle = document.createElement('ol') - struct.children.forEach(childStruct=> - listEle.appendChild(navItem(childStruct)) - ); - return listEle; - } - function navItem(struct) { - var a = document.createElement('a'); - a.appendChild(document.createTextNode(struct.text)); - a.title = struct.text; - a.href = "#"+struct.id; - - var ele = document.createElement('li') - ele.appendChild(a) - ele.appendChild(navItemList(struct)); - return ele; - } - - var navbarEle = document.getElementById("navbar") - if (navbarEle === null) { // create it - navbarEle = document.createElement("div"); - navbarEle.id="navbar"; - navOuterEle = document.createElement("nav") - navOuterEle.appendChild(navbarEle); - document.body.prepend(navOuterEle); - } - - navbarEle.innerHTML = "" - var structure = extractStructure() - navbarEle.appendChild(navItemList(structure)); -} - -function extractStructure() { // Also sets ids on h1,h2,... - var headingsNodes = document.querySelectorAll("h1, h2, h3, h4, h5, h6"); - // For now we are just fulling going to regenerate the structure each time - // Might be better if we made minimal changes, but 🤷 - - // Extract the structure of the document - var structure = {children:[]} - var active = [structure.children]; - headingsNodes.forEach( - function(currentValue, currentIndex) { - currentValue.id = "s-" + currentIndex; - var currentLevel = parseInt(currentValue.nodeName[1]); - - // Insert dummy levels up for any levels that are skipped - for (var i=active.length; i < currentLevel; i++) { - var dummy = {id: "", text: "", children: []} - active.push(dummy.children); - var parentList = active[i-1] - parentList.push(dummy); - } - // delete this level and everything after - active.splice(currentLevel, active.length); - - var currentStructure = { - id: currentValue.id, - text: currentValue.textContent, - children: [], - }; - active.push(currentStructure.children); - - var parentList = active[active.length-2] - parentList.push(currentStructure); - }, - ); - return structure; -} - /** * HTML rendering mode. * Static rendering is used for static HTML pages. @@ -178,8 +53,6 @@ function render(renderMode) { if (renderMode == RENDER_MODE.STATIC) { // For static pages, simply call rendering functions once. renderLaTeX(document); - renderHovertips(document); - updateNavigation(); } else { // For dynamic pages (via `dex web`), listen to update events. var source = new EventSource("/getnext"); @@ -190,24 +63,88 @@ function render(renderMode) { cells = {} return } else { - process_update(msg); + processUpdate(msg); } }; } } -function set_cell_contents(cell, contents) { - var line_num = contents[0][0]; - var source_text = contents[0][1]; - var line_num_div = document.createElement("div"); +function selectSpan(cellCtx, srcId) { + let [cell, blockId, _] = cellCtx + return cell.querySelector("#span_".concat(blockId.toString(), "_", srcId.toString()));} + +function attachHovertip(cellCtx, srcId) { + let span = selectSpan(cellCtx, srcId); + span.addEventListener("mouseover", (event) => enterSpan(event, cellCtx, srcId)); + span.addEventListener("mouseout" , (event) => leaveSpan(event, cellCtx, srcId));} + +function getParent(cellCtx, srcId) { + let [ , , astInfo] = cellCtx; + let parent = astInfo["astParent"][srcId.toString()] + if (parent == undefined) { + console.error(srcId, astInfo); + throw new Error("Can't find parent"); + } else { + return parent; + }} + +function getChildren(cellCtx, srcId) { + let [ , , astInfo] = cellCtx; + let children = astInfo["astChildren"][srcId.toString()] + if (children == undefined) { + return []; + } else { + return children; + }} + +function traverseSpans(cellCtx, srcId, f) { + let span = selectSpan(cellCtx, srcId) + if (span !== null) f(span); + getChildren(cellCtx, srcId).map(function (childId) { + traverseSpans(cellCtx, childId, f); + })} - line_num_div.innerHTML = line_num.toString(); - line_num_div.className = "line-num"; +function enterSpan(event, cellCtx, srcId) { + event.stopPropagation(); + let parentId = getParent(cellCtx, srcId); + traverseSpans(cellCtx, parentId, function (span) { + span.style.backgroundColor = "lightblue"; + span.style.outlineColor = "lightblue"; + span.style.outlineStyle = "solid"; + }); + let siblingIds = getChildren(cellCtx, parentId); + siblingIds.map(function (siblingId) { + traverseSpans(cellCtx, siblingId, function (span) { + span.style.backgroundColor = "yellow"; + })})} + +function leaveSpan(event, cellCtx, srcId) { + event.stopPropagation(); + let parentId = getParent(cellCtx, srcId); + traverseSpans(cellCtx, parentId, function (span) { + span.style.backgroundColor = null; + span.style.outlineColor = null; + span.style.outlineStyle = null; + }); + let siblingIds = getChildren(cellCtx, parentId); + siblingIds.map(function (siblingId) { + traverseSpans(cellCtx, siblingId, function (span) { + span.style.backgroundColor = null; + })})} + +function setCellContents(cell, contents) { + let source = contents[0]; + let results = contents[1]; + let lineNum = source["jdLine"]; + let sourceText = source["jdHTML"]; + let lineNumDiv = document.createElement("div"); + lineNumDiv.innerHTML = lineNum.toString(); + lineNumDiv.className = "line-num"; cell.innerHTML = "" - cell.appendChild(line_num_div); - cell.innerHTML += source_text - var results = contents[1]; - tag = results["tag"] + cell.appendChild(lineNumDiv); + cell.innerHTML += sourceText + + tag = results["tag"] if (tag == "Waiting") { cell.className = "cell waiting-cell"; } else if (tag == "Running") { @@ -219,10 +156,9 @@ function set_cell_contents(cell, contents) { console.error(tag); } renderLaTeX(cell); - renderHovertips(cell); } -function process_update(msg) { +function processUpdate(msg) { var cell_updates = msg["nodeMapUpdate"]["mapUpdates"]; var num_dropped = msg["orderedNodesUpdate"]["numDropped"]; var new_tail = msg["orderedNodesUpdate"]["newTail"]; @@ -238,10 +174,10 @@ function process_update(msg) { if (tag == "Create") { var cell = document.createElement("div"); cells[node_id] = cell; - set_cell_contents(cell, contents) + setCellContents(cell, contents) } else if (tag == "Update") { var cell = cells[node_id]; - set_cell_contents(cell, contents); + setCellContents(cell, contents); } else if (tag == "Delete") { delete cells[node_id] } else { @@ -251,7 +187,22 @@ function process_update(msg) { // append_new_cells new_tail.forEach(function (node_id) { - body.appendChild(cells[node_id]); - }); + cell = cells[node_id]; + body.appendChild(cell); + }) + // add hovertips + new_tail.forEach(function (node_id) { + cell = cells[node_id]; + var update = cell_updates[node_id]; + if (update["tag"] == "Create") { + var source = update["contents"][0]; + var blockId = source["jdBlockId"]; + var astInfo = source["jdASTInfo"]; + var lexemeList = source["jdLexemeList"]; + cellCtx = [cell, blockId, astInfo]; + lexemeList.map(function (lexemeId) {attachHovertip(cellCtx, lexemeId)}) + } + }); } + From 562c10c55288696eaaf132bc430f7d65e3e509f0 Mon Sep 17 00:00:00 2001 From: Dougal Date: Tue, 28 Nov 2023 10:40:55 -0500 Subject: [PATCH 24/41] Tweaks to parse highlighting. It's works pretty nicely now! --- src/lib/Live/Eval.hs | 72 +++++++++++++++++++++++++++++++++++- src/lib/SourceIdTraversal.hs | 2 +- src/lib/Types/Source.hs | 3 ++ static/index.js | 72 +++++++++++++++++++----------------- static/style.css | 8 ++++ 5 files changed, 120 insertions(+), 37 deletions(-) diff --git a/src/lib/Live/Eval.hs b/src/lib/Live/Eval.hs index 5363a2ee2..286e05160 100644 --- a/src/lib/Live/Eval.hs +++ b/src/lib/Live/Eval.hs @@ -13,11 +13,14 @@ module Live.Eval ( import Control.Concurrent import Control.Monad import Control.Monad.State.Strict +import Control.Monad.Reader import qualified Data.Map.Strict as M import Data.Aeson (ToJSON, ToJSONKey, toJSON, Value) import Data.Functor ((<&>)) -import Data.Maybe (fromJust) +import Data.Foldable (fold) +import Data.Maybe (fromJust, fromMaybe) import Data.Text (Text) +import Prelude hiding (span) import GHC.Generics import Actor @@ -317,14 +320,79 @@ data SourceBlockJSONData = SourceBlockJSONData , jdBlockId :: Int , jdLexemeList :: [SrcId] , jdASTInfo :: ASTInfo + , jdASTLims :: M.Map SrcId (SrcId, SrcId) -- precomputed leftmost and rightmost spans associated with each node , jdHTML :: String } deriving (Generic) instance ToJSON SourceBlockJSONData instance ToJSON SourceBlockWithId where toJSON b@(SourceBlockWithId blockId b') = toJSON $ SourceBlockJSONData - (sbLine b') blockId (unsnoc $ lexemeList $ sbLexemeInfo b') (sbASTInfo b') (pprintHtml b) + { jdLine = sbLine b' + , jdBlockId = blockId + , jdLexemeList = unsnoc $ lexemeList $ sbLexemeInfo b' + , jdASTInfo = sbASTInfo b' + , jdASTLims = computeASTLims (sbASTInfo b') (sbLexemeInfo b') + , jdHTML = pprintHtml b + } instance ToJSON Result where toJSON = toJSONViaHtml + toJSONViaHtml :: ToMarkup a => a -> Value toJSONViaHtml x = toJSON $ pprintHtml x + +-- === computing the linear lexeme limits on each SrcId === + +data OrdSrcId = OrdSrcId Int SrcId +newtype OrdSrcIdSpan = OrdSrcIdSpan (Maybe (OrdSrcId, OrdSrcId)) + +type SpanMap = M.Map SrcId OrdSrcIdSpan +type ComputeSpanM = ReaderT (ASTInfo, LexemeInfo) (State SpanMap) + +instance Eq OrdSrcId where + OrdSrcId x _ == OrdSrcId y _ = x == y + +instance Ord OrdSrcId where + compare (OrdSrcId x _) (OrdSrcId y _) = compare x y + +instance Monoid OrdSrcIdSpan where + mempty = OrdSrcIdSpan Nothing + +instance Semigroup OrdSrcIdSpan where + OrdSrcIdSpan Nothing <> s = s + s <> OrdSrcIdSpan Nothing = s + OrdSrcIdSpan (Just (l, r)) <> OrdSrcIdSpan (Just (l', r')) = + OrdSrcIdSpan $ Just (min l l', max r r') + +computeASTLims :: ASTInfo -> LexemeInfo -> M.Map SrcId (SrcId, SrcId) +computeASTLims astInfo lexemeInfo = + M.mapMaybe stripOrd $ flip execState mempty $ flip runReaderT (astInfo, lexemeInfo) $ + visitSrcId rootSrcId + where stripOrd :: OrdSrcIdSpan -> Maybe (SrcId, SrcId) + stripOrd (OrdSrcIdSpan s) = case s of + Just (OrdSrcId _ l, OrdSrcId _ r) -> Just (l, r) + Nothing -> Nothing + +insertSpan :: SrcId -> OrdSrcIdSpan -> ComputeSpanM () +insertSpan sid span = modify \m -> M.insert sid span m + +getSelfSpans :: SrcId -> ComputeSpanM [OrdSrcIdSpan] +getSelfSpans sid = do + lexemes <- asks $ lexemeInfo . snd + case M.lookup sid lexemes of + Nothing -> return [] + Just (_, (low, _)) -> do + let sidOrd = OrdSrcId low sid + return [OrdSrcIdSpan $ Just $ (sidOrd, sidOrd)] + +getChildren :: SrcId -> ComputeSpanM [SrcId] +getChildren sid = do + astInfo <- asks fst + return $ fromMaybe [] $ M.lookup sid $ astChildren astInfo + +visitSrcId :: SrcId -> ComputeSpanM OrdSrcIdSpan +visitSrcId sid = do + childSpans <- mapM visitSrcId =<< getChildren sid + selfSpans <- getSelfSpans sid + let finalSpan = fold $ selfSpans ++ childSpans + insertSpan sid finalSpan + return finalSpan diff --git a/src/lib/SourceIdTraversal.hs b/src/lib/SourceIdTraversal.hs index 565027294..9897bc360 100644 --- a/src/lib/SourceIdTraversal.hs +++ b/src/lib/SourceIdTraversal.hs @@ -49,7 +49,7 @@ instance IsTree Group where CPrim _ xs -> mapM_ visit xs CParens xs -> mapM_ visit xs CBrackets xs -> mapM_ visit xs - CBin b l r -> visit b >> visit l >> visit r + CBin b l r -> visit l >> visit b >> visit r CJuxtapose _ l r -> visit l >> visit r CPrefix l r -> visit l >> visit r CGivens (x,y) -> visit x >> visit y diff --git a/src/lib/Types/Source.hs b/src/lib/Types/Source.hs index 5c01c3f37..cdb0a47f4 100644 --- a/src/lib/Types/Source.hs +++ b/src/lib/Types/Source.hs @@ -56,6 +56,9 @@ newtype SourceOrInternalName (c::C) (n::S) = SourceOrInternalName (SourceNameOr -- XXX: 0 is reserved for the root newtype SrcId = SrcId Int deriving (Show, Eq, Ord, Generic) +rootSrcId :: SrcId +rootSrcId = SrcId 0 + -- This is just for syntax highlighting. It won't be needed if we have -- a separate lexing pass where we have a complete lossless data type for -- lexemes. diff --git a/static/index.js b/static/index.js index 7657b3024..836c0a11b 100644 --- a/static/index.js +++ b/static/index.js @@ -70,16 +70,16 @@ function render(renderMode) { } function selectSpan(cellCtx, srcId) { - let [cell, blockId, _] = cellCtx + let [cell, blockId, , ] = cellCtx return cell.querySelector("#span_".concat(blockId.toString(), "_", srcId.toString()));} function attachHovertip(cellCtx, srcId) { let span = selectSpan(cellCtx, srcId); - span.addEventListener("mouseover", (event) => enterSpan(event, cellCtx, srcId)); - span.addEventListener("mouseout" , (event) => leaveSpan(event, cellCtx, srcId));} + span.addEventListener("mouseover", (event) => toggleSpan(event, cellCtx, srcId)); + span.addEventListener("mouseout" , (event) => toggleSpan(event, cellCtx, srcId));} function getParent(cellCtx, srcId) { - let [ , , astInfo] = cellCtx; + let [ , , astInfo, ] = cellCtx; let parent = astInfo["astParent"][srcId.toString()] if (parent == undefined) { console.error(srcId, astInfo); @@ -89,7 +89,7 @@ function getParent(cellCtx, srcId) { }} function getChildren(cellCtx, srcId) { - let [ , , astInfo] = cellCtx; + let [ , , astInfo, ] = cellCtx; let children = astInfo["astChildren"][srcId.toString()] if (children == undefined) { return []; @@ -97,40 +97,43 @@ function getChildren(cellCtx, srcId) { return children; }} -function traverseSpans(cellCtx, srcId, f) { - let span = selectSpan(cellCtx, srcId) - if (span !== null) f(span); - getChildren(cellCtx, srcId).map(function (childId) { - traverseSpans(cellCtx, childId, f); - })} +function isLeafGroup(span) { + return span !== null && (span.classList.contains("keyword") || span.classList.contains("symbol")) +} -function enterSpan(event, cellCtx, srcId) { - event.stopPropagation(); - let parentId = getParent(cellCtx, srcId); - traverseSpans(cellCtx, parentId, function (span) { - span.style.backgroundColor = "lightblue"; - span.style.outlineColor = "lightblue"; - span.style.outlineStyle = "solid"; - }); - let siblingIds = getChildren(cellCtx, parentId); - siblingIds.map(function (siblingId) { - traverseSpans(cellCtx, siblingId, function (span) { - span.style.backgroundColor = "yellow"; - })})} +function toggleSrcIdHighlighting(cellCtx, srcId) { + let maybeLeaf = selectSpan(cellCtx, srcId) + // XXX: this is a bit of a hack. We should probably collect information + // about node types on the Haskell side + if (isLeafGroup(maybeLeaf)) { + maybeLeaf.classList.toggle("highlighted-leaf"); + } else { + getSrcIdSpans(cellCtx, srcId).map(function (span) { + span.classList.toggle("highlighted"); + })}} + +// All HTML spans associated with the srcId (these should be contiguous) +function getSrcIdSpans(cellCtx, srcId) { + let [ , , , nodeSpans] = cellCtx; + let [leftSrcId, rightSrcId] = nodeSpans[srcId]; + return spansBetween(selectSpan(cellCtx, leftSrcId), selectSpan(cellCtx, rightSrcId));} + +function spansBetween(l, r) { + let spans = [] + while (l !== null && !(Object.is(l, r))) { + spans.push(l); + l = l.nextSibling; + } + spans.push(r) + return spans;} -function leaveSpan(event, cellCtx, srcId) { +function toggleSpan(event, cellCtx, srcId) { event.stopPropagation(); let parentId = getParent(cellCtx, srcId); - traverseSpans(cellCtx, parentId, function (span) { - span.style.backgroundColor = null; - span.style.outlineColor = null; - span.style.outlineStyle = null; - }); let siblingIds = getChildren(cellCtx, parentId); siblingIds.map(function (siblingId) { - traverseSpans(cellCtx, siblingId, function (span) { - span.style.backgroundColor = null; - })})} + toggleSrcIdHighlighting(cellCtx, siblingId) + })} function setCellContents(cell, contents) { let source = contents[0]; @@ -200,7 +203,8 @@ function processUpdate(msg) { var blockId = source["jdBlockId"]; var astInfo = source["jdASTInfo"]; var lexemeList = source["jdLexemeList"]; - cellCtx = [cell, blockId, astInfo]; + var astLims = source["jdASTLims"]; + cellCtx = [cell, blockId, astInfo, astLims]; lexemeList.map(function (lexemeId) {attachHovertip(cellCtx, lexemeId)}) } }); diff --git a/static/style.css b/static/style.css index 0383c9578..6132484d0 100644 --- a/static/style.css +++ b/static/style.css @@ -100,6 +100,14 @@ code { color: #E07000; } +.highlighted { + background-color: yellow; +} + +.highlighted-leaf { + background-color: lightblue; +} + .type-name { color: #A80000; } From 9de972dc8115ff272ebd114746c8c51d3e151c5f Mon Sep 17 00:00:00 2001 From: Dougal Date: Tue, 28 Nov 2023 22:13:41 -0500 Subject: [PATCH 25/41] Move most of the highlighting logic to Haskell where we can change it more easily. --- src/lib/ConcreteSyntax.hs | 8 +-- src/lib/Live/Eval.hs | 104 +++++++++++++-------------------- src/lib/SourceIdTraversal.hs | 36 ++++++------ src/lib/Types/Source.hs | 21 ++++--- static/index.js | 109 ++++++++++++++--------------------- 5 files changed, 114 insertions(+), 164 deletions(-) diff --git a/src/lib/ConcreteSyntax.hs b/src/lib/ConcreteSyntax.hs index 1d9a0c320..f0442b361 100644 --- a/src/lib/ConcreteSyntax.hs +++ b/src/lib/ConcreteSyntax.hs @@ -31,7 +31,7 @@ import Lexing import Types.Core import Types.Source import Types.Primitives -import SourceIdTraversal (getASTInfo) +import SourceIdTraversal import qualified Types.OpNames as P import Util @@ -60,7 +60,7 @@ parseUModule name s = do {-# SCC parseUModule #-} preludeImportBlock :: SourceBlock -preludeImportBlock = SourceBlock 0 0 LogNothing "" mempty mempty (Misc $ ImportModule Prelude) +preludeImportBlock = SourceBlock 0 0 LogNothing "" mempty Nothing (Misc $ ImportModule Prelude) sourceBlocks :: Parser [SourceBlock] sourceBlocks = manyTill (sourceBlock <* outputLines) eof @@ -99,8 +99,8 @@ sourceBlock = do b <- sourceBlock' return (level, b) let lexInfo' = lexInfo { lexemeInfo = lexemeInfo lexInfo <&> \(t, (l, r)) -> (t, (l-offset, r-offset))} - let astInfo = getASTInfo b - return $ SourceBlock (unPos (sourceLine pos)) offset level src lexInfo' astInfo b + let groupTree = getGroupTree b + return $ SourceBlock (unPos (sourceLine pos)) offset level src lexInfo' (Just groupTree) b recover :: ParseError Text Void -> Parser (LogLevel, SourceBlock') recover e = do diff --git a/src/lib/Live/Eval.hs b/src/lib/Live/Eval.hs index 286e05160..fadbd02aa 100644 --- a/src/lib/Live/Eval.hs +++ b/src/lib/Live/Eval.hs @@ -13,12 +13,11 @@ module Live.Eval ( import Control.Concurrent import Control.Monad import Control.Monad.State.Strict -import Control.Monad.Reader +import Control.Monad.Writer.Strict import qualified Data.Map.Strict as M import Data.Aeson (ToJSON, ToJSONKey, toJSON, Value) import Data.Functor ((<&>)) -import Data.Foldable (fold) -import Data.Maybe (fromJust, fromMaybe) +import Data.Maybe (fromJust) import Data.Text (Text) import Prelude hiding (span) import GHC.Generics @@ -311,7 +310,6 @@ instance ToJSON a => ToJSON (MapEltUpdate a) instance ToJSON o => ToJSON (NodeEvalStatus o) instance ToJSON SrcId deriving instance ToJSONKey SrcId -instance ToJSON ASTInfo instance ToJSON LexemeType instance (ToJSON i, ToJSON o) => ToJSON (NodeState i o) @@ -319,8 +317,8 @@ data SourceBlockJSONData = SourceBlockJSONData { jdLine :: Int , jdBlockId :: Int , jdLexemeList :: [SrcId] - , jdASTInfo :: ASTInfo - , jdASTLims :: M.Map SrcId (SrcId, SrcId) -- precomputed leftmost and rightmost spans associated with each node + , jdFocusMap :: FocusMap + , jdHighlightMap :: HighlightMap , jdHTML :: String } deriving (Generic) instance ToJSON SourceBlockJSONData @@ -330,69 +328,45 @@ instance ToJSON SourceBlockWithId where { jdLine = sbLine b' , jdBlockId = blockId , jdLexemeList = unsnoc $ lexemeList $ sbLexemeInfo b' - , jdASTInfo = sbASTInfo b' - , jdASTLims = computeASTLims (sbASTInfo b') (sbLexemeInfo b') + , jdFocusMap = computeFocus b' + , jdHighlightMap = computeHighlights b' , jdHTML = pprintHtml b } instance ToJSON Result where toJSON = toJSONViaHtml - toJSONViaHtml :: ToMarkup a => a -> Value toJSONViaHtml x = toJSON $ pprintHtml x --- === computing the linear lexeme limits on each SrcId === - -data OrdSrcId = OrdSrcId Int SrcId -newtype OrdSrcIdSpan = OrdSrcIdSpan (Maybe (OrdSrcId, OrdSrcId)) - -type SpanMap = M.Map SrcId OrdSrcIdSpan -type ComputeSpanM = ReaderT (ASTInfo, LexemeInfo) (State SpanMap) - -instance Eq OrdSrcId where - OrdSrcId x _ == OrdSrcId y _ = x == y - -instance Ord OrdSrcId where - compare (OrdSrcId x _) (OrdSrcId y _) = compare x y - -instance Monoid OrdSrcIdSpan where - mempty = OrdSrcIdSpan Nothing - -instance Semigroup OrdSrcIdSpan where - OrdSrcIdSpan Nothing <> s = s - s <> OrdSrcIdSpan Nothing = s - OrdSrcIdSpan (Just (l, r)) <> OrdSrcIdSpan (Just (l', r')) = - OrdSrcIdSpan $ Just (min l l', max r r') - -computeASTLims :: ASTInfo -> LexemeInfo -> M.Map SrcId (SrcId, SrcId) -computeASTLims astInfo lexemeInfo = - M.mapMaybe stripOrd $ flip execState mempty $ flip runReaderT (astInfo, lexemeInfo) $ - visitSrcId rootSrcId - where stripOrd :: OrdSrcIdSpan -> Maybe (SrcId, SrcId) - stripOrd (OrdSrcIdSpan s) = case s of - Just (OrdSrcId _ l, OrdSrcId _ r) -> Just (l, r) - Nothing -> Nothing - -insertSpan :: SrcId -> OrdSrcIdSpan -> ComputeSpanM () -insertSpan sid span = modify \m -> M.insert sid span m - -getSelfSpans :: SrcId -> ComputeSpanM [OrdSrcIdSpan] -getSelfSpans sid = do - lexemes <- asks $ lexemeInfo . snd - case M.lookup sid lexemes of - Nothing -> return [] - Just (_, (low, _)) -> do - let sidOrd = OrdSrcId low sid - return [OrdSrcIdSpan $ Just $ (sidOrd, sidOrd)] - -getChildren :: SrcId -> ComputeSpanM [SrcId] -getChildren sid = do - astInfo <- asks fst - return $ fromMaybe [] $ M.lookup sid $ astChildren astInfo - -visitSrcId :: SrcId -> ComputeSpanM OrdSrcIdSpan -visitSrcId sid = do - childSpans <- mapM visitSrcId =<< getChildren sid - selfSpans <- getSelfSpans sid - let finalSpan = fold $ selfSpans ++ childSpans - insertSpan sid finalSpan - return finalSpan +-- === highlighting on hover === +-- TODO: put this somewhere else, like RenderHtml or something + +newtype FocusMap = FocusMap (M.Map LexemeId SrcId) deriving (ToJSON, Semigroup, Monoid) +newtype HighlightMap = HighlightMap (M.Map SrcId Highlights) deriving (ToJSON, Semigroup, Monoid) +type Highlights = [(HighlightType, LexemeSpan)] +data HighlightType = HighlightGroup | HighlightLeaf deriving Generic + +instance ToJSON HighlightType + +computeFocus :: SourceBlock -> FocusMap +computeFocus sb = execWriter $ mapM go $ sbGroupTree sb where + go :: GroupTree -> Writer FocusMap () + go t = forM_ (gtChildren t) \child-> do + go child + tell $ FocusMap $ M.singleton (gtSrcId child) (gtSrcId t) + +computeHighlights :: SourceBlock -> HighlightMap +computeHighlights sb = execWriter $ mapM go $ sbGroupTree sb where + go :: GroupTree -> Writer HighlightMap () + go t = do + spans <- forM (gtChildren t) \child -> do + go child + return (getHighlightType (gtSrcId child), gtSpan child) + tell $ HighlightMap $ M.singleton (gtSrcId t) spans + + getHighlightType :: SrcId -> HighlightType + getHighlightType sid = case M.lookup sid (lexemeInfo $ sbLexemeInfo sb) of + Nothing -> HighlightGroup -- not a lexeme + Just (lexemeTy, _) -> case lexemeTy of + Symbol -> HighlightLeaf + Keyword -> HighlightLeaf + _ -> HighlightGroup diff --git a/src/lib/SourceIdTraversal.hs b/src/lib/SourceIdTraversal.hs index 9897bc360..1900cafbc 100644 --- a/src/lib/SourceIdTraversal.hs +++ b/src/lib/SourceIdTraversal.hs @@ -4,32 +4,34 @@ -- license that can be found in the LICENSE file or at -- https://developers.google.com/open-source/licenses/bsd -module SourceIdTraversal (getASTInfo) where +module SourceIdTraversal (getGroupTree) where -import qualified Data.Map.Strict as M -import Control.Monad.Reader import Control.Monad.Writer.Strict +import Data.Functor ((<&>)) import Types.Source import Types.Primitives -getASTInfo :: SourceBlock' -> ASTInfo -getASTInfo b = runTreeM (SrcId 0) $ visit b +getGroupTree :: SourceBlock' -> GroupTree +getGroupTree b = mkGroupTree rootSrcId $ runTreeM $ visit b -type TreeM = ReaderT SrcId (Writer ASTInfo) +type TreeM = Writer [GroupTree] -runTreeM :: SrcId -> TreeM () -> ASTInfo -runTreeM root cont = snd $ runWriter $ runReaderT cont root +mkGroupTree :: SrcId -> [GroupTree] -> GroupTree +mkGroupTree sid = \case + [] -> GroupTree sid (sid,sid) [] -- no children - must be a lexeme + subtrees -> GroupTree sid (l,r) subtrees + where l = minimum $ subtrees <&> \(GroupTree _ (l',_) _) -> l' + r = maximum $ subtrees <&> \(GroupTree _ (_,r') _) -> r' -enterNode :: SrcId -> TreeM a -> TreeM a -enterNode sid cont = do - emitNode sid - local (const sid) cont +runTreeM :: TreeM () -> [GroupTree] +runTreeM cont = snd $ runWriter $ cont -emitNode :: SrcId -> TreeM () -emitNode child = do - parent <- ask - tell $ ASTInfo (M.singleton child parent) (M.singleton parent [child]) +enterNode :: SrcId -> TreeM () -> TreeM () +enterNode sid cont = tell [mkGroupTree sid (runTreeM cont)] + +emitLexeme :: SrcId -> TreeM () +emitLexeme lexemeId = tell [mkGroupTree lexemeId []] class IsTree a where visit :: a -> TreeM () @@ -94,7 +96,7 @@ instance IsTree a => IsTree (WithSrc a) where visit (WithSrc sid x) = enterNode sid $ visit x instance IsTree a => IsTree (WithSrcs a) where - visit (WithSrcs sid sids x) = enterNode sid $ mapM_ emitNode sids >> visit x + visit (WithSrcs sid sids x) = enterNode sid $ mapM_ emitLexeme sids >> visit x instance IsTree a => IsTree [a] where visit xs = mapM_ visit xs diff --git a/src/lib/Types/Source.hs b/src/lib/Types/Source.hs index cdb0a47f4..3c39d603a 100644 --- a/src/lib/Types/Source.hs +++ b/src/lib/Types/Source.hs @@ -53,7 +53,8 @@ newtype SourceOrInternalName (c::C) (n::S) = SourceOrInternalName (SourceNameOr -- === Source Info === --- XXX: 0 is reserved for the root +-- XXX: 0 is reserved for the root The IDs are generated from left to right in +-- parsing order, so IDs for lexemes are guaranteed to be sorted correctly. newtype SrcId = SrcId Int deriving (Show, Eq, Ord, Generic) rootSrcId :: SrcId @@ -79,21 +80,19 @@ data LexemeInfo = LexemeInfo , lexemeInfo :: M.Map SrcId (LexemeType, Span) } deriving (Show, Generic) -data ASTInfo = ASTInfo - { astParent :: M.Map SrcId SrcId - , astChildren :: M.Map SrcId [SrcId]} - deriving (Show, Generic) +type LexemeId = SrcId +type LexemeSpan = (LexemeId, LexemeId) +data GroupTree = GroupTree + { gtSrcId :: SrcId + , gtSpan :: LexemeSpan + , gtChildren :: [GroupTree] } + deriving (Show, Generic) instance Semigroup LexemeInfo where LexemeInfo a b <> LexemeInfo a' b' = LexemeInfo (a <> a') (b <> b') instance Monoid LexemeInfo where mempty = LexemeInfo mempty mempty -instance Semigroup ASTInfo where - ASTInfo a b <> ASTInfo a' b' = ASTInfo (a <> a') (M.unionWith (<>) b b') -instance Monoid ASTInfo where - mempty = ASTInfo mempty mempty - -- === Concrete syntax === -- The grouping-level syntax of the source language @@ -521,7 +520,7 @@ data SourceBlock = SourceBlock , sbLogLevel :: LogLevel , sbText :: Text , sbLexemeInfo :: LexemeInfo - , sbASTInfo :: ASTInfo + , sbGroupTree :: Maybe GroupTree , sbContents :: SourceBlock' } deriving (Show, Generic) diff --git a/static/index.js b/static/index.js index 836c0a11b..725cb7a86 100644 --- a/static/index.js +++ b/static/index.js @@ -69,54 +69,38 @@ function render(renderMode) { } } -function selectSpan(cellCtx, srcId) { - let [cell, blockId, , ] = cellCtx - return cell.querySelector("#span_".concat(blockId.toString(), "_", srcId.toString()));} - function attachHovertip(cellCtx, srcId) { let span = selectSpan(cellCtx, srcId); span.addEventListener("mouseover", (event) => toggleSpan(event, cellCtx, srcId)); span.addEventListener("mouseout" , (event) => toggleSpan(event, cellCtx, srcId));} -function getParent(cellCtx, srcId) { - let [ , , astInfo, ] = cellCtx; - let parent = astInfo["astParent"][srcId.toString()] - if (parent == undefined) { - console.error(srcId, astInfo); - throw new Error("Can't find parent"); - } else { - return parent; - }} - -function getChildren(cellCtx, srcId) { - let [ , , astInfo, ] = cellCtx; - let children = astInfo["astChildren"][srcId.toString()] - if (children == undefined) { - return []; - } else { - return children; - }} - -function isLeafGroup(span) { - return span !== null && (span.classList.contains("keyword") || span.classList.contains("symbol")) -} +function selectSpan(cellCtx, srcId) { + let [cell, blockId, , ] = cellCtx + return cell.querySelector("#span_".concat(blockId.toString(), "_", srcId.toString()));} -function toggleSrcIdHighlighting(cellCtx, srcId) { - let maybeLeaf = selectSpan(cellCtx, srcId) - // XXX: this is a bit of a hack. We should probably collect information - // about node types on the Haskell side - if (isLeafGroup(maybeLeaf)) { - maybeLeaf.classList.toggle("highlighted-leaf"); +function getHighlightClass(highlightType) { + if (highlightType == "HighlightGroup") { + return "highlighted"; + } else if (highlightType == "HighlightLeaf") { + return "highlighted-leaf"; } else { - getSrcIdSpans(cellCtx, srcId).map(function (span) { - span.classList.toggle("highlighted"); - })}} + throw new Error("Unrecognized highlight type"); + } +} -// All HTML spans associated with the srcId (these should be contiguous) -function getSrcIdSpans(cellCtx, srcId) { - let [ , , , nodeSpans] = cellCtx; - let [leftSrcId, rightSrcId] = nodeSpans[srcId]; - return spansBetween(selectSpan(cellCtx, leftSrcId), selectSpan(cellCtx, rightSrcId));} +function toggleSpan(event, cellCtx, srcId) { + event.stopPropagation(); + let [ , , focusMap, highlightMap] = cellCtx; + let focus = focusMap[srcId.toString()]; + if (focus == null) return; + let highlights = highlightMap[focus.toString()]; + highlights.map(function (highlight) { + let [highlightType, [l, r]] = highlight; + let spans = spansBetween(selectSpan(cellCtx, l), selectSpan(cellCtx, r)); + let highlightClass = getHighlightClass(highlightType); + spans.map(function (span) { + span.classList.toggle(highlightClass); + })})} function spansBetween(l, r) { let spans = [] @@ -127,14 +111,6 @@ function spansBetween(l, r) { spans.push(r) return spans;} -function toggleSpan(event, cellCtx, srcId) { - event.stopPropagation(); - let parentId = getParent(cellCtx, srcId); - let siblingIds = getChildren(cellCtx, parentId); - siblingIds.map(function (siblingId) { - toggleSrcIdHighlighting(cellCtx, siblingId) - })} - function setCellContents(cell, contents) { let source = contents[0]; let results = contents[1]; @@ -162,18 +138,18 @@ function setCellContents(cell, contents) { } function processUpdate(msg) { - var cell_updates = msg["nodeMapUpdate"]["mapUpdates"]; - var num_dropped = msg["orderedNodesUpdate"]["numDropped"]; - var new_tail = msg["orderedNodesUpdate"]["newTail"]; + let cell_updates = msg["nodeMapUpdate"]["mapUpdates"]; + let num_dropped = msg["orderedNodesUpdate"]["numDropped"]; + let new_tail = msg["orderedNodesUpdate"]["newTail"]; // drop_dead_cells for (i = 0; i < num_dropped; i++) { body.lastElementChild.remove();} Object.keys(cell_updates).forEach(function (node_id) { - var update = cell_updates[node_id]; - var tag = update["tag"] - var contents = update["contents"] + let update = cell_updates[node_id]; + let tag = update["tag"] + let contents = update["contents"] if (tag == "Create") { var cell = document.createElement("div"); cells[node_id] = cell; @@ -190,23 +166,22 @@ function processUpdate(msg) { // append_new_cells new_tail.forEach(function (node_id) { - cell = cells[node_id]; + let cell = cells[node_id]; body.appendChild(cell); }) - // add hovertips - new_tail.forEach(function (node_id) { - cell = cells[node_id]; - var update = cell_updates[node_id]; - if (update["tag"] == "Create") { - var source = update["contents"][0]; - var blockId = source["jdBlockId"]; - var astInfo = source["jdASTInfo"]; - var lexemeList = source["jdLexemeList"]; - var astLims = source["jdASTLims"]; - cellCtx = [cell, blockId, astInfo, astLims]; + Object.keys(cell_updates).forEach(function (node_id) { + let update = cell_updates[node_id]; + let tag = update["tag"] + let cell = cells[node_id]; + if (tag == "Create" || tag == "Update") { + let source = update["contents"][0]; + let blockId = source["jdBlockId"]; + let lexemeList = source["jdLexemeList"]; + let focusMap = source["jdFocusMap"]; + let highlightMap = source["jdHighlightMap"]; + cellCtx = [cell, blockId, focusMap, highlightMap]; lexemeList.map(function (lexemeId) {attachHovertip(cellCtx, lexemeId)}) } }); } - From a97a641ed83ba49303ea2b370f091ec997d2b8d4 Mon Sep 17 00:00:00 2001 From: Dougal Date: Wed, 29 Nov 2023 10:44:20 -0500 Subject: [PATCH 26/41] Plumbing for adding textual information on hover. --- src/lib/Live/Eval.hs | 12 ++++++++++++ static/dynamic.html | 6 ++++-- static/index.js | 23 ++++++++++++++++++++--- static/style.css | 38 ++++++++++---------------------------- 4 files changed, 46 insertions(+), 33 deletions(-) diff --git a/src/lib/Live/Eval.hs b/src/lib/Live/Eval.hs index fadbd02aa..730c60928 100644 --- a/src/lib/Live/Eval.hs +++ b/src/lib/Live/Eval.hs @@ -17,6 +17,7 @@ import Control.Monad.Writer.Strict import qualified Data.Map.Strict as M import Data.Aeson (ToJSON, ToJSONKey, toJSON, Value) import Data.Functor ((<&>)) +import Data.Foldable (toList) import Data.Maybe (fromJust) import Data.Text (Text) import Prelude hiding (span) @@ -319,6 +320,7 @@ data SourceBlockJSONData = SourceBlockJSONData , jdLexemeList :: [SrcId] , jdFocusMap :: FocusMap , jdHighlightMap :: HighlightMap + , jdHoverInfoMap :: HoverInfoMap , jdHTML :: String } deriving (Generic) instance ToJSON SourceBlockJSONData @@ -330,6 +332,7 @@ instance ToJSON SourceBlockWithId where , jdLexemeList = unsnoc $ lexemeList $ sbLexemeInfo b' , jdFocusMap = computeFocus b' , jdHighlightMap = computeHighlights b' + , jdHoverInfoMap = computeHoverInfo b' , jdHTML = pprintHtml b } instance ToJSON Result where toJSON = toJSONViaHtml @@ -337,6 +340,15 @@ instance ToJSON Result where toJSON = toJSONViaHtml toJSONViaHtml :: ToMarkup a => a -> Value toJSONViaHtml x = toJSON $ pprintHtml x +-- === textual information on hover === + +type HoverInfo = String +newtype HoverInfoMap = HoverInfoMap (M.Map LexemeId HoverInfo) deriving (ToJSON, Semigroup, Monoid) + +computeHoverInfo :: SourceBlock -> HoverInfoMap +computeHoverInfo sb = HoverInfoMap $ + M.fromList $ toList (lexemeList (sbLexemeInfo sb)) <&> \srcId -> (srcId, show srcId) + -- === highlighting on hover === -- TODO: put this somewhere else, like RenderHtml or something diff --git a/static/dynamic.html b/static/dynamic.html index 5e636424a..eb0111d13 100644 --- a/static/dynamic.html +++ b/static/dynamic.html @@ -21,8 +21,10 @@ -
- +
+ (hover over code for more information) +
+
diff --git a/static/index.js b/static/index.js index 725cb7a86..842acc812 100644 --- a/static/index.js +++ b/static/index.js @@ -44,6 +44,7 @@ var RENDER_MODE = Object.freeze({ // mapping from server-provided NodeID to HTML id var cells = {}; var body = document.getElementById("main-output"); +var hoverInfoDiv = document.getElementById("hover-info"); /** * Renders the webpage. @@ -71,8 +72,23 @@ function render(renderMode) { function attachHovertip(cellCtx, srcId) { let span = selectSpan(cellCtx, srcId); - span.addEventListener("mouseover", (event) => toggleSpan(event, cellCtx, srcId)); - span.addEventListener("mouseout" , (event) => toggleSpan(event, cellCtx, srcId));} + span.addEventListener("mouseover", function (event) { + addHoverInfo(cellCtx, srcId); + toggleSpan(event, cellCtx, srcId);}) + span.addEventListener("mouseout" , function(event) { + removeHoverInfo(); + toggleSpan(event, cellCtx, srcId)}); +} + +function addHoverInfo(cellCtx, srcId) { + let [ , , , , hoverInfoMap] = cellCtx + s = hoverInfoMap[srcId.toString()] + hoverInfoDiv.innerHTML = s; +} + +function removeHoverInfo() { + hoverInfoDiv.innerHTML = ""; +} function selectSpan(cellCtx, srcId) { let [cell, blockId, , ] = cellCtx @@ -180,7 +196,8 @@ function processUpdate(msg) { let lexemeList = source["jdLexemeList"]; let focusMap = source["jdFocusMap"]; let highlightMap = source["jdHighlightMap"]; - cellCtx = [cell, blockId, focusMap, highlightMap]; + let hoverInfoMap = source["jdHoverInfoMap"]; + cellCtx = [cell, blockId, focusMap, highlightMap, hoverInfoMap]; lexemeList.map(function (lexemeId) {attachHovertip(cellCtx, lexemeId)}) } }); diff --git a/static/style.css b/static/style.css index 6132484d0..5c83bb2c1 100644 --- a/static/style.css +++ b/static/style.css @@ -11,42 +11,25 @@ body { display: flex; justify-content: space-between; overflow-x: hidden; - --main-width: 50rem; - --nav-width: 20rem; + padding-bottom:50vw; } -@media (max-width: 70rem) { - /*For narrow screens hide nav and enable horizontal scrolling */ - nav {display: none;} - body {overflow-x: auto;} -} - -nav {/* this actually just holds space for #navbar, which is fixed */ - min-width: var(--nav-width); - max-width: var(--nav-width); -} -#navbar { +#hover-info { position: fixed; - height: 100vh; - width: var(--nav-width); - overflow-y: scroll; - border-right: 1px solid firebrick; -} -#navbar:before { - content: "Contents"; - font-weight: bold; -} -nav ol { - list-style-type:none; - padding-left: 1rem; + height: 10rem; + bottom: 0em; + width: 100vw; + overflow: hidden; + background-color: white; + border-top: 1px solid firebrick; + font-family: monospace; + white-space: pre; } #main-output { - max-width: var(--main-width); margin: auto; } - .code-block { } @@ -128,7 +111,6 @@ code { text-align: right; } - .waiting-cell { border-left: 6px solid #AAAAFF; } From 6f1ab77296de6e31c7e2a48c9cae6cb54af5a70c Mon Sep 17 00:00:00 2001 From: Dougal Date: Wed, 29 Nov 2023 14:41:33 -0500 Subject: [PATCH 27/41] Freeze highlighting and hover on click. --- static/index.js | 158 +++++++++++++++++++++++++++--------------------- 1 file changed, 88 insertions(+), 70 deletions(-) diff --git a/static/index.js b/static/index.js index 842acc812..0f33463e3 100644 --- a/static/index.js +++ b/static/index.js @@ -41,11 +41,68 @@ var RENDER_MODE = Object.freeze({ DYNAMIC: "dynamic", }) -// mapping from server-provided NodeID to HTML id -var cells = {}; -var body = document.getElementById("main-output"); +var body = document.getElementById("main-output"); var hoverInfoDiv = document.getElementById("hover-info"); +// State of the system beyond the HTML +var cells = {} +var frozenHover = false; +var curHighlights = []; // HTML elements currently highlighted +var focusMap = {} +var highlightMap = {} +var hoverInfoMap = {} + +function removeHover() { + if (frozenHover) return; + hoverInfoDiv.innerHTML = "" + curHighlights.map(function (element) { + element.classList.remove("highlighted", "highlighted-leaf")}) + curHighlights = []; +} +function lookupSrcMap(m, cellId, srcId) { + let blockMap = m[cellId] + if (blockMap == null) { + return null + } else { + return blockMap[srcId]} +} +function applyHover(cellId, srcId) { + if (frozenHover) return; + applyHoverInfo(cellId, srcId) + applyHoverHighlights(cellId, srcId) +} +function applyHoverInfo(cellId, srcId) { + let hoverInfo = lookupSrcMap(hoverInfoMap, cellId, srcId) + hoverInfoDiv.innerHTML = srcId.toString() // hoverInfo +} +function applyHoverHighlights(cellId, srcId) { + let focus = lookupSrcMap(focusMap, cellId, srcId) + if (focus == null) return + let highlights = lookupSrcMap(highlightMap, cellId, focus) + highlights.map(function (highlight) { + let [highlightType, [l, r]] = highlight + let spans = spansBetween(selectSpan(cellId, l), selectSpan(cellId, r)); + let highlightClass = getHighlightClass(highlightType) + spans.map(function (span) { + span.classList.add(highlightClass) + curHighlights.push(span)})}) +} +function toggleFrozenHover() { + if (frozenHover) { + frozenHover = false + removeHover() + } else { + frozenHover = true} +} +function attachHovertip(cellId, srcId) { + let span = selectSpan(cellId, srcId) + span.addEventListener("mouseover", function (event) { + event.stopPropagation() + applyHover(cellId, srcId)}) + span.addEventListener("mouseout" , function (event) { + event.stopPropagation() + removeHover()})} + /** * Renders the webpage. * @param {RENDER_MODE} renderMode The render mode, either static or dynamic. @@ -60,40 +117,23 @@ function render(renderMode) { source.onmessage = function(event) { var msg = JSON.parse(event.data); if (msg == "start") { - body.innerHTML = ""; + body.innerHTML = "" + body.addEventListener("click", function (event) { + event.stopPropagation() + toggleFrozenHover()}) cells = {} return } else { - processUpdate(msg); - } - }; - } + processUpdate(msg)}};} } -function attachHovertip(cellCtx, srcId) { - let span = selectSpan(cellCtx, srcId); - span.addEventListener("mouseover", function (event) { - addHoverInfo(cellCtx, srcId); - toggleSpan(event, cellCtx, srcId);}) - span.addEventListener("mouseout" , function(event) { - removeHoverInfo(); - toggleSpan(event, cellCtx, srcId)}); -} -function addHoverInfo(cellCtx, srcId) { - let [ , , , , hoverInfoMap] = cellCtx - s = hoverInfoMap[srcId.toString()] - hoverInfoDiv.innerHTML = s; +function selectSpan(cellId, srcId) { + return cells[cellId].querySelector("#span_".concat(cellId, "_", srcId)) } - -function removeHoverInfo() { - hoverInfoDiv.innerHTML = ""; +function selectCell(cellId) { + return cells[cellId] } - -function selectSpan(cellCtx, srcId) { - let [cell, blockId, , ] = cellCtx - return cell.querySelector("#span_".concat(blockId.toString(), "_", srcId.toString()));} - function getHighlightClass(highlightType) { if (highlightType == "HighlightGroup") { return "highlighted"; @@ -103,21 +143,6 @@ function getHighlightClass(highlightType) { throw new Error("Unrecognized highlight type"); } } - -function toggleSpan(event, cellCtx, srcId) { - event.stopPropagation(); - let [ , , focusMap, highlightMap] = cellCtx; - let focus = focusMap[srcId.toString()]; - if (focus == null) return; - let highlights = highlightMap[focus.toString()]; - highlights.map(function (highlight) { - let [highlightType, [l, r]] = highlight; - let spans = spansBetween(selectSpan(cellCtx, l), selectSpan(cellCtx, r)); - let highlightClass = getHighlightClass(highlightType); - spans.map(function (span) { - span.classList.toggle(highlightClass); - })})} - function spansBetween(l, r) { let spans = [] while (l !== null && !(Object.is(l, r))) { @@ -125,8 +150,8 @@ function spansBetween(l, r) { l = l.nextSibling; } spans.push(r) - return spans;} - + return spans +} function setCellContents(cell, contents) { let source = contents[0]; let results = contents[1]; @@ -152,7 +177,6 @@ function setCellContents(cell, contents) { } renderLaTeX(cell); } - function processUpdate(msg) { let cell_updates = msg["nodeMapUpdate"]["mapUpdates"]; let num_dropped = msg["orderedNodesUpdate"]["numDropped"]; @@ -162,43 +186,37 @@ function processUpdate(msg) { for (i = 0; i < num_dropped; i++) { body.lastElementChild.remove();} - Object.keys(cell_updates).forEach(function (node_id) { - let update = cell_updates[node_id]; + Object.keys(cell_updates).forEach(function (cellId) { + let update = cell_updates[cellId]; let tag = update["tag"] let contents = update["contents"] if (tag == "Create") { var cell = document.createElement("div"); - cells[node_id] = cell; + cells[cellId] = cell; setCellContents(cell, contents) } else if (tag == "Update") { - var cell = cells[node_id]; + var cell = cells[cellId]; setCellContents(cell, contents); } else if (tag == "Delete") { - delete cells[node_id] + delete cells[cellId] } else { console.error(tag); - } - }); + }}); // append_new_cells - new_tail.forEach(function (node_id) { - let cell = cells[node_id]; - body.appendChild(cell); - }) + new_tail.forEach(function (cellId) { + let cell = selectCell(cellId); + body.appendChild(cell);}) - Object.keys(cell_updates).forEach(function (node_id) { - let update = cell_updates[node_id]; + Object.keys(cell_updates).forEach(function (cellId) { + let update = cell_updates[cellId] let tag = update["tag"] - let cell = cells[node_id]; if (tag == "Create" || tag == "Update") { + let update = cell_updates[cellId]; let source = update["contents"][0]; - let blockId = source["jdBlockId"]; - let lexemeList = source["jdLexemeList"]; - let focusMap = source["jdFocusMap"]; - let highlightMap = source["jdHighlightMap"]; - let hoverInfoMap = source["jdHoverInfoMap"]; - cellCtx = [cell, blockId, focusMap, highlightMap, hoverInfoMap]; - lexemeList.map(function (lexemeId) {attachHovertip(cellCtx, lexemeId)}) - } - }); + focusMap[cellId] = source["jdFocusMap"] + highlightMap[cellId] = source["jdHighlightMap"] + hoverInfoMap[cellId] = source["jsHoverInfoMap"] + let lexemeList = source["jdLexemeList"]; + lexemeList.map(function (lexemeId) {attachHovertip(cellId, lexemeId.toString())})}}); } From b0bd94ac50e4c9ff02d375f43f6b57b99923bcd3 Mon Sep 17 00:00:00 2001 From: Dougal Date: Thu, 30 Nov 2023 11:23:14 -0500 Subject: [PATCH 28/41] Make hover-info updates incremental so that passes can report as they complete --- src/lib/Actor.hs | 8 +- src/lib/IncState.hs | 74 ++++++++++----- src/lib/Live/Eval.hs | 212 ++++++++++++++++++++++++------------------ src/lib/Live/Web.hs | 2 +- src/lib/Types/Misc.hs | 9 ++ static/index.js | 52 +++++++---- 6 files changed, 218 insertions(+), 139 deletions(-) diff --git a/src/lib/Actor.hs b/src/lib/Actor.hs index 18a835bc4..1da61268e 100644 --- a/src/lib/Actor.hs +++ b/src/lib/Actor.hs @@ -10,7 +10,7 @@ module Actor ( ActorM, Actor (..), launchActor, send, selfMailbox, messageLoop, sliceMailbox, SubscribeMsg (..), IncServer, IncServerT, FileWatcher, StateServer, flushDiffs, handleSubscribeMsg, subscribe, subscribeIO, sendSync, - runIncServerT, launchFileWatcher + runIncServerT, launchFileWatcher, Mailbox ) where import Control.Concurrent @@ -188,14 +188,14 @@ launchClock intervalMicroseconds mailbox = -- === File watcher === type SourceFileContents = Text -type FileWatcher = StateServer SourceFileContents (Overwrite SourceFileContents) +type FileWatcher = StateServer (Overwritable SourceFileContents) (Overwrite SourceFileContents) readFileContents :: MonadIO m => FilePath -> m Text readFileContents path = liftIO $ T.decodeUtf8 <$> BS.readFile path data FileWatcherMsg = ClockSignal_FW () - | Subscribe_FW (SubscribeMsg Text (Overwrite Text)) + | Subscribe_FW (SubscribeMsg (Overwritable Text) (Overwrite Text)) deriving (Show) launchFileWatcher :: MonadIO m => FilePath -> m FileWatcher @@ -207,7 +207,7 @@ fileWatcherImpl path = do t0 <- liftIO $ getModificationTime path launchClock 100000 =<< selfMailbox ClockSignal_FW modTimeRef <- newRef t0 - runIncServerT initContents $ messageLoop \case + runIncServerT (Overwritable initContents) $ messageLoop \case Subscribe_FW msg -> handleSubscribeMsg msg ClockSignal_FW () -> do tOld <- readRef modTimeRef diff --git a/src/lib/IncState.hs b/src/lib/IncState.hs index 3c8f90d77..43f3ef044 100644 --- a/src/lib/IncState.hs +++ b/src/lib/IncState.hs @@ -8,7 +8,8 @@ module IncState ( IncState (..), MapEltUpdate (..), MapUpdate (..), - Overwrite (..), TailUpdate (..), mapUpdateMapWithKey) where + Overwrite (..), TailUpdate (..), Unchanging (..), Overwritable (..), + mapUpdateMapWithKey) where import qualified Data.Map.Strict as M import GHC.Generics @@ -20,53 +21,69 @@ class Monoid d => IncState s d where -- === Diff utils === -data MapEltUpdate v = - Create v - | Update v +data MapEltUpdate s d = + Create s + | Replace s -- TODO: should we merge Create/Replace? + | Update d | Delete deriving (Functor, Show, Generic) -data MapUpdate k v = MapUpdate { mapUpdates :: M.Map k (MapEltUpdate v) } +data MapUpdate k s d = MapUpdate { mapUpdates :: M.Map k (MapEltUpdate s d) } deriving (Functor, Show, Generic) -mapUpdateMapWithKey :: MapUpdate k v -> (k -> v -> v') -> MapUpdate k v' -mapUpdateMapWithKey (MapUpdate m) f = MapUpdate $ M.mapWithKey (\k v -> fmap (f k) v) m +mapUpdateMapWithKey :: MapUpdate k s d -> (k -> s -> s') -> (k -> d -> d') -> MapUpdate k s' d' +mapUpdateMapWithKey (MapUpdate m) fs fd = + MapUpdate $ flip M.mapWithKey m \k v -> case v of + Create s -> Create $ fs k s + Replace s -> Replace $ fs k s + Update d -> Update $ fd k d + Delete -> Delete -instance Ord k => Monoid (MapUpdate k v) where +instance (IncState s d, Ord k) => Monoid (MapUpdate k s d) where mempty = MapUpdate mempty -instance Ord k => Semigroup (MapUpdate k v) where +instance (IncState s d, Ord k) => Semigroup (MapUpdate k s d) where MapUpdate m1 <> MapUpdate m2 = MapUpdate $ M.mapMaybe id (M.intersectionWith combineElts m1 m2) <> M.difference m1 m2 <> M.difference m2 m1 where combineElts e1 e2 = case e1 of - Create _ -> case e2 of + Create s -> case e2 of Create _ -> error "shouldn't be creating a node that already exists" - Update v -> Just $ Create v + Replace s' -> Just $ Create s' + Update d -> Just $ Create (applyDiff s d) Delete -> Nothing - Update _ -> case e2 of + Replace s -> case e2 of Create _ -> error "shouldn't be creating a node that already exists" - Update v -> Just $ Update v + Replace s' -> Just $ Replace s' + Update d -> Just $ Replace (applyDiff s d) + Delete -> Nothing + Update d -> case e2 of + Create _ -> error "shouldn't be creating a node that already exists" + Replace s -> Just $ Replace s + Update d' -> Just $ Update (d <> d') Delete -> Just $ Delete Delete -> case e2 of - Create v -> Just $ Update v - Update _ -> error "shouldn't be updating a node that doesn't exist" - Delete -> error "shouldn't be deleting a node that doesn't exist" + Create s -> Just $ Replace s + Replace _ -> error "shouldn't be replacing a node that doesn't exist" + Update _ -> error "shouldn't be updating a node that doesn't exist" + Delete -> error "shouldn't be deleting a node that doesn't exist" -instance Ord k => IncState (M.Map k v) (MapUpdate k v) where +instance (IncState s d, Ord k) => IncState (M.Map k s) (MapUpdate k s d) where applyDiff m (MapUpdate updates) = M.mapMaybe id (M.intersectionWith applyEltUpdate m updates) <> M.difference m updates <> M.mapMaybe applyEltCreation (M.difference updates m) - where applyEltUpdate _ = \case + where applyEltUpdate s = \case Create _ -> error "key already exists" - Update v -> Just v + Replace s' -> Just s' + Update d -> Just $ applyDiff s d Delete -> Nothing applyEltCreation = \case - Create v -> Just v - Update _ -> error "key doesn't exist yet" - Delete -> error "key doesn't exist yet" + Create s -> Just s + Replace _ -> error "key doesn't exist yet" + Update _ -> error "key doesn't exist yet" + Delete -> error "key doesn't exist yet" data TailUpdate a = TailUpdate { numDropped :: Int @@ -87,7 +104,8 @@ instance IncState [a] (TailUpdate a) where applyDiff xs (TailUpdate numDrop ys) = take (length xs - numDrop) xs <> ys -- Trivial diff that works for any type - just replace the old value with a completely new one. -data Overwrite a = NoChange | OverwriteWith a deriving (Show) +data Overwrite a = NoChange | OverwriteWith a deriving (Show, Generic) +newtype Overwritable a = Overwritable { fromOverwritable :: a } deriving (Show, Eq, Ord) instance Semigroup (Overwrite a) where l <> r = case r of @@ -97,8 +115,14 @@ instance Semigroup (Overwrite a) where instance Monoid (Overwrite a) where mempty = NoChange -instance IncState a (Overwrite a) where +instance IncState (Overwritable a) (Overwrite a) where applyDiff s = \case NoChange -> s - OverwriteWith s' -> s' + OverwriteWith s' -> Overwritable s' + + +-- Trivial diff that works for any type - just replace the old value with a completely new one. +newtype Unchanging a = Unchanging { fromUnchanging :: a } deriving (Show, Eq, Ord) +instance IncState (Unchanging a) () where + applyDiff s () = s diff --git a/src/lib/Live/Eval.hs b/src/lib/Live/Eval.hs index 730c60928..82f4bcee5 100644 --- a/src/lib/Live/Eval.hs +++ b/src/lib/Live/Eval.hs @@ -8,7 +8,7 @@ {-# OPTIONS_GHC -Wno-orphans #-} module Live.Eval ( - watchAndEvalFile, ResultsServer, ResultsUpdate, subscribeIO, dagAsUpdate, addSourceBlockIds) where + watchAndEvalFile, ResultsServer, ResultsUpdate, subscribeIO, nodeListAsUpdate, addSourceBlockIds) where import Control.Concurrent import Control.Monad @@ -41,14 +41,24 @@ watchAndEvalFile :: FilePath -> EvalConfig -> TopStateEx -> IO ResultsServer watchAndEvalFile fname opts env = do watcher <- launchFileWatcher fname parser <- launchCellParser watcher \source -> uModuleSourceBlocks $ parseUModule Main source - launchDagEvaluator parser env (evalSourceBlockIO opts) + launchDagEvaluator parser env (evalSourceBlockIO' opts) -addSourceBlockIds :: NodeListUpdate (NodeState SourceBlock o) -> NodeListUpdate (NodeState SourceBlockWithId o) +addSourceBlockIds :: CellsUpdate SourceBlock o -> CellsUpdate SourceBlockWithId o addSourceBlockIds (NodeListUpdate listUpdate mapUpdate) = NodeListUpdate listUpdate mapUpdate' - where mapUpdate' = mapUpdateMapWithKey mapUpdate \k (NodeState i o) -> NodeState (SourceBlockWithId k i) o + where mapUpdate' = mapUpdateMapWithKey mapUpdate + (\k (CellState b s o) -> CellState (SourceBlockWithId k b) s o) + (\_ d -> d) + +-- shim to pretend that evalSourceBlockIO streams its results. TODO: make it actually do that. +evalSourceBlockIO' + :: EvalConfig -> Mailbox Result -> TopStateEx -> SourceBlock -> IO TopStateEx +evalSourceBlockIO' cfg resultChan env block = do + (result, env') <- evalSourceBlockIO cfg env block + send resultChan result + return env' type ResultsServer = Evaluator SourceBlock Result -type ResultsUpdate = EvalStatusUpdate SourceBlock Result +type ResultsUpdate = CellsUpdate SourceBlock Result -- === DAG diff state === @@ -63,26 +73,26 @@ data NodeList a = NodeList , nodeMap :: M.Map NodeId a } deriving (Show, Generic) -data NodeListUpdate a = NodeListUpdate +data NodeListUpdate s d = NodeListUpdate { orderedNodesUpdate :: TailUpdate NodeId - , nodeMapUpdate :: MapUpdate NodeId a } - deriving (Show, Functor, Generic) + , nodeMapUpdate :: MapUpdate NodeId s d } + deriving (Show, Generic) -instance Semigroup (NodeListUpdate a) where +instance IncState s d => Semigroup (NodeListUpdate s d) where NodeListUpdate x1 y1 <> NodeListUpdate x2 y2 = NodeListUpdate (x1<>x2) (y1<>y2) -instance Monoid (NodeListUpdate a) where +instance IncState s d => Monoid (NodeListUpdate s d) where mempty = NodeListUpdate mempty mempty -instance IncState (NodeList a) (NodeListUpdate a) where +instance IncState s d => IncState (NodeList s) (NodeListUpdate s d) where applyDiff (NodeList m xs) (NodeListUpdate dm dxs) = NodeList (applyDiff m dm) (applyDiff xs dxs) -type Dag = NodeList -type DagUpdate = NodeListUpdate +type Dag a = NodeList (Unchanging a) +type DagUpdate a = NodeListUpdate (Unchanging a) () -dagAsUpdate :: Dag a -> DagUpdate a -dagAsUpdate (NodeList xs m)= NodeListUpdate (TailUpdate 0 xs) (MapUpdate $ fmap Create m) +nodeListAsUpdate :: NodeList s -> NodeListUpdate s d +nodeListAsUpdate (NodeList xs m)= NodeListUpdate (TailUpdate 0 xs) (MapUpdate $ fmap Create m) emptyNodeList :: NodeList a emptyNodeList = NodeList [] mempty @@ -101,7 +111,7 @@ commonPrefixLength _ _ = 0 nodeListVals :: NodeList a -> [a] nodeListVals nodes = orderedNodes nodes <&> \k -> fromJust $ M.lookup k (nodeMap nodes) -computeNodeListUpdate :: (Eq a, FreshNames NodeId m) => NodeList a -> [a] -> m (NodeListUpdate a) +computeNodeListUpdate :: (Eq s, FreshNames NodeId m) => NodeList s -> [s] -> m (NodeListUpdate s d) computeNodeListUpdate nodes newVals = do let prefixLength = commonPrefixLength (nodeListVals nodes) newVals let oldTail = drop prefixLength $ orderedNodes nodes @@ -127,13 +137,13 @@ launchCellParser fileWatcher parseCells = cellParserImpl :: Eq a => FileWatcher -> (Text -> [a]) -> ActorM (CellParserMsg a) () cellParserImpl fileWatcher parseCells = runFreshNameT do - initContents <- subscribe Update_CP fileWatcher - initNodeList <- buildNodeList $ parseCells initContents + Overwritable initContents <- subscribe Update_CP fileWatcher + initNodeList <- buildNodeList $ fmap Unchanging $ parseCells initContents runIncServerT initNodeList $ messageLoop \case Subscribe_CP msg -> handleSubscribeMsg msg Update_CP NoChange -> return () Update_CP (OverwriteWith newContents) -> do - let newCells = parseCells newContents + let newCells = fmap Unchanging $ parseCells newContents curNodeList <- getl It update =<< computeNodeListUpdate curNodeList newCells flushDiffs @@ -143,17 +153,27 @@ cellParserImpl fileWatcher parseCells = runFreshNameT do -- This is where we track the state of evaluation and decide what we needs to be -- run and what needs to be killed. -type Evaluator i o = StateServer (EvalStatus i o) (EvalStatusUpdate i o) +type Evaluator i o = StateServer (CellsState i o) (CellsUpdate i o) newtype EvaluatorM s i o a = EvaluatorM { runEvaluatorM' :: - IncServerT (EvalStatus i o) (EvalStatusUpdate i o) + IncServerT (CellsState i o) (CellsUpdate i o) (StateT (EvaluatorState s i o) (ActorM (EvaluatorMsg s i o))) a } deriving (Functor, Applicative, Monad, MonadIO, - Actor (EvaluatorMsg s i o), - IncServer (EvalStatus i o) (EvalStatusUpdate i o)) + Actor (EvaluatorMsg s i o)) +deriving instance Monoid o => IncServer (CellsState i o) (CellsUpdate i o) (EvaluatorM s i o) + +instance Monoid o => Semigroup (CellUpdate o) where + CellUpdate s o <> CellUpdate s' o' = CellUpdate (s<>s') (o<>o') + +instance Monoid o => Monoid (CellUpdate o) where + mempty = CellUpdate mempty mempty + +instance Monoid o => IncState (CellState i o) (CellUpdate o) where + applyDiff (CellState source status result) (CellUpdate status' result') = + CellState source (fromOverwritable (applyDiff (Overwritable status) status')) (result <> result') -instance DefuncState (EvaluatorMUpdate s i o) (EvaluatorM s i o) where +instance Monoid o => DefuncState (EvaluatorMUpdate s i o) (EvaluatorM s i o) where update = \case UpdateDagEU dag -> EvaluatorM $ update dag UpdateCurJob status -> EvaluatorM $ lift $ modify \s -> s { curRunningJob = status } @@ -161,12 +181,10 @@ instance DefuncState (EvaluatorMUpdate s i o) (EvaluatorM s i o) where AppendEnv env -> do envs <- getl PrevEnvs update $ UpdateEnvs $ envs ++ [env] - UpdateJobStatus nodeId status -> do - NodeState i _ <- fromJust <$> getl (NodeInfo nodeId) - let newState = NodeState i status - update $ UpdateDagEU $ NodeListUpdate mempty $ MapUpdate $ M.singleton nodeId (Update newState) + UpdateCellState nodeId cellUpdate -> update $ UpdateDagEU $ NodeListUpdate mempty $ + MapUpdate $ M.singleton nodeId $ Update cellUpdate -instance LabelReader (EvaluatorMLabel s i o) (EvaluatorM s i o) where +instance Monoid o => LabelReader (EvaluatorMLabel s i o) (EvaluatorM s i o) where getl l = case l of NodeListEM -> EvaluatorM $ orderedNodes <$> getl It NodeInfo nodeId -> EvaluatorM $ M.lookup nodeId <$> nodeMap <$> getl It @@ -175,53 +193,57 @@ instance LabelReader (EvaluatorMLabel s i o) (EvaluatorM s i o) where EvalFun -> EvaluatorM $ lift $ evalFun <$> get data EvaluatorMUpdate s i o = - UpdateDagEU (NodeListUpdate (NodeState i o)) - | UpdateJobStatus NodeId (NodeEvalStatus o) + UpdateDagEU (NodeListUpdate (CellState i o) (CellUpdate o)) + | UpdateCellState NodeId (CellUpdate o) | UpdateCurJob CurJobStatus | UpdateEnvs [s] | AppendEnv s data EvaluatorMLabel s i o a where NodeListEM :: EvaluatorMLabel s i o [NodeId] - NodeInfo :: NodeId -> EvaluatorMLabel s i o (Maybe (NodeState i o)) + NodeInfo :: NodeId -> EvaluatorMLabel s i o (Maybe (CellState i o)) PrevEnvs :: EvaluatorMLabel s i o [s] CurRunningJob :: EvaluatorMLabel s i o (CurJobStatus) EvalFun :: EvaluatorMLabel s i o (EvalFun s i o) --- The envs after each cell evaluated so far -type EvalFun s i o = s -> i -> IO (o, s) -type CurJobStatus = Maybe (ThreadId, NodeId, CellIndex) +-- `s` is the persistent state (i.e. TopEnvEx the environment) +-- `i` is the type of input cell (e.g. SourceBlock) +-- `o` is the (monoidal) type of updates, e.g. `Result` +type EvalFun s i o = Mailbox o -> s -> i -> IO s +-- It's redundant to have both NodeId and TheadId but it defends against +-- possible GHC reuse of ThreadId (I don't know if that can actually happen) +type JobId = (ThreadId, NodeId) +type CurJobStatus = Maybe (JobId, CellIndex) data EvaluatorState s i o = EvaluatorState { prevEnvs :: [s] , evalFun :: EvalFun s i o , curRunningJob :: CurJobStatus } -data NodeEvalStatus o = - Waiting - | Running - | Complete o - deriving (Show, Generic) +data CellStatus = Waiting | Running | Complete deriving (Show, Generic) -data NodeState i o = NodeState i (NodeEvalStatus o) deriving (Show, Generic) +data CellState i o = CellState i CellStatus o deriving (Show, Generic) +data CellUpdate o = CellUpdate (Overwrite CellStatus) o deriving (Show, Generic) type Show3 s i o = (Show s, Show i, Show o) -type EvalStatus i o = NodeList (NodeState i o) -type EvalStatusUpdate i o = NodeListUpdate (NodeState i o) +type CellsState i o = NodeList (CellState i o) +type CellsUpdate i o = NodeListUpdate (CellState i o) (CellUpdate o) type CellIndex = Int -- index in the list of cells, not the NodeId +data JobUpdate o s = PartialJobUpdate o | JobComplete s deriving (Show) + data EvaluatorMsg s i o = SourceUpdate (DagUpdate i) - | JobComplete ThreadId s o - | Subscribe_E (SubscribeMsg (EvalStatus i o) (EvalStatusUpdate i o)) + | JobUpdate JobId (JobUpdate o s) + | Subscribe_E (SubscribeMsg (CellsState i o) (CellsUpdate i o)) deriving (Show) initEvaluatorState :: s -> EvalFun s i o -> EvaluatorState s i o initEvaluatorState s evalCell = EvaluatorState [s] evalCell Nothing -launchDagEvaluator :: (Show3 s i o, MonadIO m) => CellParser i -> s -> EvalFun s i o -> m (Evaluator i o) +launchDagEvaluator :: (Show3 s i o, Monoid o, MonadIO m) => CellParser i -> s -> EvalFun s i o -> m (Evaluator i o) launchDagEvaluator cellParser env evalCell = do mailbox <- launchActor do let s = initEvaluatorState env evalCell @@ -229,73 +251,82 @@ launchDagEvaluator cellParser env evalCell = do dagEvaluatorImpl cellParser return $ sliceMailbox Subscribe_E mailbox -dagEvaluatorImpl :: (Show3 s i o) => CellParser i -> EvaluatorM s i o () +dagEvaluatorImpl :: (Show3 s i o, Monoid o) => CellParser i -> EvaluatorM s i o () dagEvaluatorImpl cellParser = do initDag <- subscribe SourceUpdate cellParser - processDagUpdate (dagAsUpdate initDag) >> flushDiffs + processDagUpdate (nodeListAsUpdate initDag) >> flushDiffs launchNextJob messageLoop \case Subscribe_E msg -> handleSubscribeMsg msg SourceUpdate dagUpdate -> do processDagUpdate dagUpdate flushDiffs - JobComplete threadId env result -> do - processJobComplete threadId env result + JobUpdate jobId jobUpdate -> do + processJobUpdate jobId jobUpdate flushDiffs -processJobComplete :: (Show3 s i o) => ThreadId -> s -> o -> EvaluatorM s i o () -processJobComplete threadId newEnv result = do +processJobUpdate :: (Show3 s i o, Monoid o) => JobId -> JobUpdate o s -> EvaluatorM s i o () +processJobUpdate jobId jobUpdate = do getl CurRunningJob >>= \case - Just (expectedThreadId, nodeId, _) -> do - when (threadId == expectedThreadId) do -- otherwise it's a zombie - update $ UpdateJobStatus nodeId (Complete result) - update $ UpdateCurJob Nothing - update $ AppendEnv newEnv - launchNextJob + Just (jobId', _) -> when (jobId == jobId') do + let nodeId = snd jobId + case jobUpdate of + JobComplete newEnv -> do + update $ UpdateCellState nodeId $ CellUpdate (OverwriteWith Complete) mempty + update $ UpdateCurJob Nothing + update $ AppendEnv newEnv + launchNextJob + flushDiffs + PartialJobUpdate result -> update $ UpdateCellState nodeId $ CellUpdate NoChange result Nothing -> return () -- this job is a zombie -nextJobIndex :: EvaluatorM s i o Int -nextJobIndex = do +nextCellIndex :: Monoid o => EvaluatorM s i o Int +nextCellIndex = do envs <- getl PrevEnvs return $ length envs - 1 -launchNextJob :: (Show3 s i o) => EvaluatorM s i o () +launchNextJob :: (Show3 s i o, Monoid o) => EvaluatorM s i o () launchNextJob = do - jobIndex <- nextJobIndex + cellIndex <- nextCellIndex nodeList <- getl NodeListEM - when (jobIndex < length nodeList) do -- otherwise we're all done - curEnv <- (!! jobIndex) <$> getl PrevEnvs - let nodeId = nodeList !! jobIndex - launchJob jobIndex nodeId curEnv + when (cellIndex < length nodeList) do -- otherwise we're all done + curEnv <- (!! cellIndex) <$> getl PrevEnvs + let nodeId = nodeList !! cellIndex + launchJob cellIndex nodeId curEnv -launchJob :: (Show3 s i o) => CellIndex -> NodeId -> s -> EvaluatorM s i o () -launchJob jobIndex nodeId env = do +launchJob :: (Show3 s i o, Monoid o) => CellIndex -> NodeId -> s -> EvaluatorM s i o () +launchJob cellIndex nodeId env = do jobAction <- getl EvalFun - NodeState source _ <- fromJust <$> getl (NodeInfo nodeId) - resultMailbox <- selfMailbox id + CellState source _ _ <- fromJust <$> getl (NodeInfo nodeId) + mailbox <- selfMailbox id + update $ UpdateCellState nodeId $ CellUpdate (OverwriteWith Running) mempty threadId <- liftIO $ forkIO do threadId <- myThreadId - (result, finalEnv) <- jobAction env source - send resultMailbox $ JobComplete threadId finalEnv result - update $ UpdateJobStatus nodeId Running - update $ UpdateCurJob (Just (threadId, nodeId, jobIndex)) - -computeNumValidCells :: DagUpdate i -> EvaluatorM s i o Int -computeNumValidCells dagUpdate = do - let nDropped = numDropped $ orderedNodesUpdate dagUpdate + let jobId = (threadId, nodeId) + let resultsMailbox = sliceMailbox (JobUpdate jobId . PartialJobUpdate) mailbox + finalEnv <- jobAction resultsMailbox env source + send mailbox $ JobUpdate jobId $ JobComplete finalEnv + let jobId = (threadId, nodeId) + update $ UpdateCurJob (Just (jobId, cellIndex)) + +computeNumValidCells :: Monoid o => TailUpdate NodeId -> EvaluatorM s i o Int +computeNumValidCells tailUpdate = do + let nDropped = numDropped tailUpdate nTotal <- length <$> getl NodeListEM return $ nTotal - nDropped -processDagUpdate :: (Show3 s i o) => DagUpdate i -> EvaluatorM s i o () -processDagUpdate dagUpdate = do - nValid <- computeNumValidCells dagUpdate +processDagUpdate :: (Show3 s i o, Monoid o) => DagUpdate i -> EvaluatorM s i o () +processDagUpdate (NodeListUpdate tailUpdate mapUpdate) = do + nValid <- computeNumValidCells tailUpdate envs <- getl PrevEnvs update $ UpdateEnvs $ take (nValid + 1) envs - update $ UpdateDagEU $ fmap (\i -> NodeState i Waiting) dagUpdate + update $ UpdateDagEU $ NodeListUpdate tailUpdate $ mapUpdateMapWithKey mapUpdate + (\_ (Unchanging i) -> CellState i Waiting mempty) + (\_ () -> mempty) getl CurRunningJob >>= \case Nothing -> launchNextJob - Just (threadId, _, jobIndex) - | (jobIndex >= nValid) -> do + Just ((threadId, _), cellIndex) + | (cellIndex >= nValid) -> do -- Current job is no longer valid. Kill it and restart. liftIO $ killThread threadId update $ UpdateCurJob Nothing @@ -304,15 +335,18 @@ processDagUpdate dagUpdate = do -- === instances === -instance (ToJSON i, ToJSON o) => ToJSON (NodeListUpdate (NodeState i o)) where -instance (ToJSON a, ToJSONKey k) => ToJSON (MapUpdate k a) +instance (ToJSON i, ToJSON o) => ToJSON (NodeListUpdate (CellState i o) o) where +instance (ToJSON s, ToJSON d, ToJSONKey k) => ToJSON (MapUpdate k s d) instance ToJSON a => ToJSON (TailUpdate a) -instance ToJSON a => ToJSON (MapEltUpdate a) -instance ToJSON o => ToJSON (NodeEvalStatus o) +instance (ToJSON s, ToJSON d) => ToJSON (MapEltUpdate s d) instance ToJSON SrcId deriving instance ToJSONKey SrcId instance ToJSON LexemeType -instance (ToJSON i, ToJSON o) => ToJSON (NodeState i o) +instance (ToJSON i, ToJSON o) => ToJSON (CellState i o) +instance (ToJSON i, ToJSON o) => ToJSON (CellsUpdate i o) +instance ToJSON o => ToJSON (CellUpdate o) +instance ToJSON a => ToJSON (Overwrite a) +instance ToJSON CellStatus data SourceBlockJSONData = SourceBlockJSONData { jdLine :: Int diff --git a/src/lib/Live/Web.hs b/src/lib/Live/Web.hs index 1ab7394ca..e36a472b9 100644 --- a/src/lib/Live/Web.hs +++ b/src/lib/Live/Web.hs @@ -52,7 +52,7 @@ resultStream :: ResultsServer -> StreamingBody resultStream resultsServer write flush = do write (fromByteString $ encodeResults ("start"::String)) >> flush (initResult, resultsChan) <- subscribeIO resultsServer - sendUpdate $ dagAsUpdate initResult + sendUpdate $ nodeListAsUpdate initResult forever $ readChan resultsChan >>= sendUpdate where sendUpdate :: ResultsUpdate -> IO () diff --git a/src/lib/Types/Misc.hs b/src/lib/Types/Misc.hs index 71416eead..eb7bfb288 100644 --- a/src/lib/Types/Misc.hs +++ b/src/lib/Types/Misc.hs @@ -34,3 +34,12 @@ data Output = type PassLogger = FilteredLogger PassName [Output] data OptLevel = NoOptimize | Optimize + +instance Semigroup Result where + Result outs err <> Result outs' err' = Result (outs <> outs') err'' + where err'' = case err' of + Success () -> err + Failure _ -> err' + +instance Monoid Result where + mempty = Result mempty (Success ()) diff --git a/static/index.js b/static/index.js index 0f33463e3..1ef4dc708 100644 --- a/static/index.js +++ b/static/index.js @@ -143,6 +143,17 @@ function getHighlightClass(highlightType) { throw new Error("Unrecognized highlight type"); } } +function getStatusClass(status) { + if (status == "Waiting") { + return "waiting-cell"; + } else if (status == "Running") { + return "running-cell"; + } else if (status == "Complete") { + return "complete-cell"; + } else { + throw new Error("Unrecognized status type"); + } +} function spansBetween(l, r) { let spans = [] while (l !== null && !(Object.is(l, r))) { @@ -152,31 +163,32 @@ function spansBetween(l, r) { spans.push(r) return spans } +function setCellStatus(cell, status) { + cell.className = "class" + cell.classList.add(getStatusClass(status)) +} + function setCellContents(cell, contents) { - let source = contents[0]; - let results = contents[1]; + let [source, status, result] = contents; let lineNum = source["jdLine"]; let sourceText = source["jdHTML"]; let lineNumDiv = document.createElement("div"); lineNumDiv.innerHTML = lineNum.toString(); lineNumDiv.className = "line-num"; cell.innerHTML = "" - cell.appendChild(lineNumDiv); + cell.appendChild(lineNumDiv) + setCellStatus(cell, status) cell.innerHTML += sourceText - - tag = results["tag"] - if (tag == "Waiting") { - cell.className = "cell waiting-cell"; - } else if (tag == "Running") { - cell.className = "cell running-cell"; - } else if (tag == "Complete") { - cell.className = "cell complete-cell"; - cell.innerHTML += results["contents"] - } else { - console.error(tag); - } + cell.innerHTML += result renderLaTeX(cell); } +function updateCellContents(cell, contents) { + let [statusUpdate, resultsUpdate] = contents; + if (statusUpdate["tag"] == "OverwriteWith") { + setCellStatus(cell, statusUpdate["contents"])} + if (resultsUpdate !== "") { + cell.innerHTML += resultsUpdate} +} function processUpdate(msg) { let cell_updates = msg["nodeMapUpdate"]["mapUpdates"]; let num_dropped = msg["orderedNodesUpdate"]["numDropped"]; @@ -190,13 +202,13 @@ function processUpdate(msg) { let update = cell_updates[cellId]; let tag = update["tag"] let contents = update["contents"] - if (tag == "Create") { - var cell = document.createElement("div"); + if (tag == "Create" || tag == "Replace") { + let cell = document.createElement("div"); cells[cellId] = cell; setCellContents(cell, contents) } else if (tag == "Update") { - var cell = cells[cellId]; - setCellContents(cell, contents); + let cell = cells[cellId]; + updateCellContents(cell, contents); } else if (tag == "Delete") { delete cells[cellId] } else { @@ -211,7 +223,7 @@ function processUpdate(msg) { Object.keys(cell_updates).forEach(function (cellId) { let update = cell_updates[cellId] let tag = update["tag"] - if (tag == "Create" || tag == "Update") { + if (tag == "Create" || tag == "Replace") { let update = cell_updates[cellId]; let source = update["contents"][0]; focusMap[cellId] = source["jdFocusMap"] From 1c9d613073457b169453318668e8e4d6111e35e8 Mon Sep 17 00:00:00 2001 From: Dougal Date: Thu, 30 Nov 2023 20:57:31 -0500 Subject: [PATCH 29/41] Make hover-info updates even more incremental --- dex.cabal | 1 - src/dex.hs | 1 - src/lib/ConcreteSyntax.hs | 6 +- src/lib/ImpToLLVM.hs | 1 - src/lib/IncState.hs | 9 ++- src/lib/LLVM/CUDA.hs | 2 +- src/lib/LLVM/Compile.hs | 1 - src/lib/Live/Eval.hs | 110 +++++--------------------------- src/lib/Live/Web.hs | 31 +++++---- src/lib/PPrint.hs | 2 +- src/lib/RenderHtml.hs | 120 ++++++++++++++++++++++++++++------- src/lib/Runtime.hs | 2 +- src/lib/SourceIdTraversal.hs | 18 +++--- src/lib/TopLevel.hs | 4 +- src/lib/Types/Misc.hs | 45 ------------- src/lib/Types/Source.hs | 56 ++++++++++++++-- src/lib/Util.hs | 10 +-- static/index.js | 64 +++++++++---------- 18 files changed, 243 insertions(+), 240 deletions(-) delete mode 100644 src/lib/Types/Misc.hs diff --git a/dex.cabal b/dex.cabal index 6ce74e218..cf31f3149 100644 --- a/dex.cabal +++ b/dex.cabal @@ -93,7 +93,6 @@ library , Transpose , Types.Core , Types.Imp - , Types.Misc , Types.Primitives , Types.OpNames , Types.Source diff --git a/src/dex.hs b/src/dex.hs index 2874649e2..7d26b39ff 100644 --- a/src/dex.hs +++ b/src/dex.hs @@ -36,7 +36,6 @@ import Live.Web (runWeb) import Core import Types.Core import Types.Imp -import Types.Misc import Types.Source data ErrorHandling = HaltOnErr | ContinueOnErr diff --git a/src/lib/ConcreteSyntax.hs b/src/lib/ConcreteSyntax.hs index f0442b361..8695e8d24 100644 --- a/src/lib/ConcreteSyntax.hs +++ b/src/lib/ConcreteSyntax.hs @@ -31,7 +31,6 @@ import Lexing import Types.Core import Types.Source import Types.Primitives -import SourceIdTraversal import qualified Types.OpNames as P import Util @@ -60,7 +59,7 @@ parseUModule name s = do {-# SCC parseUModule #-} preludeImportBlock :: SourceBlock -preludeImportBlock = SourceBlock 0 0 LogNothing "" mempty Nothing (Misc $ ImportModule Prelude) +preludeImportBlock = SourceBlock 0 0 LogNothing "" mempty (Misc $ ImportModule Prelude) sourceBlocks :: Parser [SourceBlock] sourceBlocks = manyTill (sourceBlock <* outputLines) eof @@ -99,8 +98,7 @@ sourceBlock = do b <- sourceBlock' return (level, b) let lexInfo' = lexInfo { lexemeInfo = lexemeInfo lexInfo <&> \(t, (l, r)) -> (t, (l-offset, r-offset))} - let groupTree = getGroupTree b - return $ SourceBlock (unPos (sourceLine pos)) offset level src lexInfo' (Just groupTree) b + return $ SourceBlock (unPos (sourceLine pos)) offset level src lexInfo' b recover :: ParseError Text Void -> Parser (LogLevel, SourceBlock') recover e = do diff --git a/src/lib/ImpToLLVM.hs b/src/lib/ImpToLLVM.hs index 396aa980e..c628486e0 100644 --- a/src/lib/ImpToLLVM.hs +++ b/src/lib/ImpToLLVM.hs @@ -55,7 +55,6 @@ import PPrint import RawName qualified as R import Types.Core import Types.Imp -import Types.Misc import Types.Primitives import Types.Source import Util (IsBool (..), bindM2, enumerate) diff --git a/src/lib/IncState.hs b/src/lib/IncState.hs index 43f3ef044..b5eb01b24 100644 --- a/src/lib/IncState.hs +++ b/src/lib/IncState.hs @@ -11,6 +11,7 @@ module IncState ( Overwrite (..), TailUpdate (..), Unchanging (..), Overwritable (..), mapUpdateMapWithKey) where +import Data.Aeson (ToJSON, ToJSONKey) import qualified Data.Map.Strict as M import GHC.Generics @@ -104,7 +105,8 @@ instance IncState [a] (TailUpdate a) where applyDiff xs (TailUpdate numDrop ys) = take (length xs - numDrop) xs <> ys -- Trivial diff that works for any type - just replace the old value with a completely new one. -data Overwrite a = NoChange | OverwriteWith a deriving (Show, Generic) +data Overwrite a = NoChange | OverwriteWith a + deriving (Show, Eq, Generic, Functor, Foldable, Traversable) newtype Overwritable a = Overwritable { fromOverwritable :: a } deriving (Show, Eq, Ord) instance Semigroup (Overwrite a) where @@ -126,3 +128,8 @@ newtype Unchanging a = Unchanging { fromUnchanging :: a } deriving (Show, Eq, Or instance IncState (Unchanging a) () where applyDiff s () = s + +instance ToJSON a => ToJSON (Overwrite a) +instance (ToJSON s, ToJSON d, ToJSONKey k) => ToJSON (MapUpdate k s d) +instance ToJSON a => ToJSON (TailUpdate a) +instance (ToJSON s, ToJSON d) => ToJSON (MapEltUpdate s d) diff --git a/src/lib/LLVM/CUDA.hs b/src/lib/LLVM/CUDA.hs index f0a6edca8..646fe59ec 100644 --- a/src/lib/LLVM/CUDA.hs +++ b/src/lib/LLVM/CUDA.hs @@ -40,7 +40,7 @@ import qualified Data.Set as S import LLVM.Compile import Types.Imp -import Types.Misc +import Types.Source data LLVMKernel = LLVMKernel L.Module diff --git a/src/lib/LLVM/Compile.hs b/src/lib/LLVM/Compile.hs index c3d690ad9..f7caa7637 100644 --- a/src/lib/LLVM/Compile.hs +++ b/src/lib/LLVM/Compile.hs @@ -33,7 +33,6 @@ import Control.Monad import Logging import PPrint () import Paths_dex (getDataFileName) -import Types.Misc -- The only reason this module depends on Types.Source is that we pass in the logger, -- in order to optionally print out the IRs. LLVM mutates its IRs in-place, so -- we can't just expose a functional API for each stage without taking a diff --git a/src/lib/Live/Eval.hs b/src/lib/Live/Eval.hs index 82f4bcee5..00f5beb28 100644 --- a/src/lib/Live/Eval.hs +++ b/src/lib/Live/Eval.hs @@ -5,19 +5,18 @@ -- https://developers.google.com/open-source/licenses/bsd {-# LANGUAGE UndecidableInstances #-} -{-# OPTIONS_GHC -Wno-orphans #-} module Live.Eval ( - watchAndEvalFile, ResultsServer, ResultsUpdate, subscribeIO, nodeListAsUpdate, addSourceBlockIds) where + watchAndEvalFile, EvalServer, EvalUpdate, CellsUpdate, fmapCellsUpdate, + NodeList (..), NodeListUpdate (..), subscribeIO, nodeListAsUpdate) where import Control.Concurrent import Control.Monad import Control.Monad.State.Strict import Control.Monad.Writer.Strict import qualified Data.Map.Strict as M -import Data.Aeson (ToJSON, ToJSONKey, toJSON, Value) +import Data.Aeson (ToJSON) import Data.Functor ((<&>)) -import Data.Foldable (toList) import Data.Maybe (fromJust) import Data.Text (Text) import Prelude hiding (span) @@ -25,30 +24,25 @@ import GHC.Generics import Actor import IncState -import Types.Misc import Types.Source import TopLevel import ConcreteSyntax -import RenderHtml (ToMarkup, pprintHtml) import MonadUtil -import Util (unsnoc) -- === Top-level interface === +type EvalServer = StateServer EvalState EvalUpdate +type EvalState = CellsState SourceBlock Result +type EvalUpdate = CellsUpdate SourceBlock Result + -- `watchAndEvalFile` returns the channel by which a client may -- subscribe by sending a write-only view of its input channel. -watchAndEvalFile :: FilePath -> EvalConfig -> TopStateEx -> IO ResultsServer +watchAndEvalFile :: FilePath -> EvalConfig -> TopStateEx -> IO EvalServer watchAndEvalFile fname opts env = do watcher <- launchFileWatcher fname parser <- launchCellParser watcher \source -> uModuleSourceBlocks $ parseUModule Main source launchDagEvaluator parser env (evalSourceBlockIO' opts) -addSourceBlockIds :: CellsUpdate SourceBlock o -> CellsUpdate SourceBlockWithId o -addSourceBlockIds (NodeListUpdate listUpdate mapUpdate) = NodeListUpdate listUpdate mapUpdate' - where mapUpdate' = mapUpdateMapWithKey mapUpdate - (\k (CellState b s o) -> CellState (SourceBlockWithId k b) s o) - (\_ d -> d) - -- shim to pretend that evalSourceBlockIO streams its results. TODO: make it actually do that. evalSourceBlockIO' :: EvalConfig -> Mailbox Result -> TopStateEx -> SourceBlock -> IO TopStateEx @@ -57,8 +51,11 @@ evalSourceBlockIO' cfg resultChan env block = do send resultChan result return env' -type ResultsServer = Evaluator SourceBlock Result -type ResultsUpdate = CellsUpdate SourceBlock Result +fmapCellsUpdate :: CellsUpdate i o -> (NodeId -> i -> i') -> (NodeId -> o -> o') -> CellsUpdate i' o' +fmapCellsUpdate (NodeListUpdate t m) fi fo = NodeListUpdate t m' where + m' = mapUpdateMapWithKey m + (\k (CellState i s o) -> CellState (fi k i) s (fo k o)) + (\k (CellUpdate s o) -> CellUpdate s (fo k o)) -- === DAG diff state === @@ -335,84 +332,7 @@ processDagUpdate (NodeListUpdate tailUpdate mapUpdate) = do -- === instances === -instance (ToJSON i, ToJSON o) => ToJSON (NodeListUpdate (CellState i o) o) where -instance (ToJSON s, ToJSON d, ToJSONKey k) => ToJSON (MapUpdate k s d) -instance ToJSON a => ToJSON (TailUpdate a) -instance (ToJSON s, ToJSON d) => ToJSON (MapEltUpdate s d) -instance ToJSON SrcId -deriving instance ToJSONKey SrcId -instance ToJSON LexemeType +instance ToJSON CellStatus instance (ToJSON i, ToJSON o) => ToJSON (CellState i o) -instance (ToJSON i, ToJSON o) => ToJSON (CellsUpdate i o) instance ToJSON o => ToJSON (CellUpdate o) -instance ToJSON a => ToJSON (Overwrite a) -instance ToJSON CellStatus - -data SourceBlockJSONData = SourceBlockJSONData - { jdLine :: Int - , jdBlockId :: Int - , jdLexemeList :: [SrcId] - , jdFocusMap :: FocusMap - , jdHighlightMap :: HighlightMap - , jdHoverInfoMap :: HoverInfoMap - , jdHTML :: String } deriving (Generic) - -instance ToJSON SourceBlockJSONData - -instance ToJSON SourceBlockWithId where - toJSON b@(SourceBlockWithId blockId b') = toJSON $ SourceBlockJSONData - { jdLine = sbLine b' - , jdBlockId = blockId - , jdLexemeList = unsnoc $ lexemeList $ sbLexemeInfo b' - , jdFocusMap = computeFocus b' - , jdHighlightMap = computeHighlights b' - , jdHoverInfoMap = computeHoverInfo b' - , jdHTML = pprintHtml b - } -instance ToJSON Result where toJSON = toJSONViaHtml - -toJSONViaHtml :: ToMarkup a => a -> Value -toJSONViaHtml x = toJSON $ pprintHtml x - --- === textual information on hover === - -type HoverInfo = String -newtype HoverInfoMap = HoverInfoMap (M.Map LexemeId HoverInfo) deriving (ToJSON, Semigroup, Monoid) - -computeHoverInfo :: SourceBlock -> HoverInfoMap -computeHoverInfo sb = HoverInfoMap $ - M.fromList $ toList (lexemeList (sbLexemeInfo sb)) <&> \srcId -> (srcId, show srcId) - --- === highlighting on hover === --- TODO: put this somewhere else, like RenderHtml or something - -newtype FocusMap = FocusMap (M.Map LexemeId SrcId) deriving (ToJSON, Semigroup, Monoid) -newtype HighlightMap = HighlightMap (M.Map SrcId Highlights) deriving (ToJSON, Semigroup, Monoid) -type Highlights = [(HighlightType, LexemeSpan)] -data HighlightType = HighlightGroup | HighlightLeaf deriving Generic - -instance ToJSON HighlightType - -computeFocus :: SourceBlock -> FocusMap -computeFocus sb = execWriter $ mapM go $ sbGroupTree sb where - go :: GroupTree -> Writer FocusMap () - go t = forM_ (gtChildren t) \child-> do - go child - tell $ FocusMap $ M.singleton (gtSrcId child) (gtSrcId t) - -computeHighlights :: SourceBlock -> HighlightMap -computeHighlights sb = execWriter $ mapM go $ sbGroupTree sb where - go :: GroupTree -> Writer HighlightMap () - go t = do - spans <- forM (gtChildren t) \child -> do - go child - return (getHighlightType (gtSrcId child), gtSpan child) - tell $ HighlightMap $ M.singleton (gtSrcId t) spans - - getHighlightType :: SrcId -> HighlightType - getHighlightType sid = case M.lookup sid (lexemeInfo $ sbLexemeInfo sb) of - Nothing -> HighlightGroup -- not a lexeme - Just (lexemeTy, _) -> case lexemeTy of - Symbol -> HighlightLeaf - Keyword -> HighlightLeaf - _ -> HighlightGroup +instance (ToJSON s, ToJSON d) => ToJSON (NodeListUpdate s d) diff --git a/src/lib/Live/Web.hs b/src/lib/Live/Web.hs index e36a472b9..372bbf2d0 100644 --- a/src/lib/Live/Web.hs +++ b/src/lib/Live/Web.hs @@ -21,7 +21,9 @@ import qualified Data.ByteString as BS -- import Paths_dex (getDataFileName) import Live.Eval +import RenderHtml import TopLevel +import Types.Source runWeb :: FilePath -> EvalConfig -> TopStateEx -> IO () runWeb fname opts env = do @@ -29,7 +31,7 @@ runWeb fname opts env = do putStrLn "Streaming output to http://localhost:8000/" run 8000 $ serveResults resultsChan -serveResults :: ResultsServer -> Application +serveResults :: EvalServer -> Application serveResults resultsSubscribe request respond = do print (pathInfo request) case pathInfo request of @@ -48,18 +50,23 @@ serveResults resultsSubscribe request respond = do -- fname <- getDataFileName dataFname respond $ responseFile status200 [("Content-Type", ctype)] fname Nothing -resultStream :: ResultsServer -> StreamingBody +resultStream :: EvalServer -> StreamingBody resultStream resultsServer write flush = do - write (fromByteString $ encodeResults ("start"::String)) >> flush + sendUpdate ("start"::String) (initResult, resultsChan) <- subscribeIO resultsServer - sendUpdate $ nodeListAsUpdate initResult - forever $ readChan resultsChan >>= sendUpdate + sendUpdate $ renderEvalUpdate $ nodeListAsUpdate initResult + forever do + nextUpdate <- readChan resultsChan + sendUpdate $ renderEvalUpdate nextUpdate where - sendUpdate :: ResultsUpdate -> IO () - sendUpdate update = do - let s = encodeResults $ addSourceBlockIds update - write (fromByteString s) >> flush + sendUpdate :: ToJSON a => a -> IO () + sendUpdate x = write (fromByteString $ encodePacket x) >> flush - encodeResults :: ToJSON a => a -> BS.ByteString - encodeResults = toStrict . wrap . encode - where wrap s = "data:" <> s <> "\n\n" +encodePacket :: ToJSON a => a -> BS.ByteString +encodePacket = toStrict . wrap . encode + where wrap s = "data:" <> s <> "\n\n" + +renderEvalUpdate :: CellsUpdate SourceBlock Result -> CellsUpdate RenderedSourceBlock RenderedResult +renderEvalUpdate cellsUpdate = fmapCellsUpdate cellsUpdate + (\k b -> renderSourceBlock k b) + (\_ r -> renderResult r) diff --git a/src/lib/PPrint.hs b/src/lib/PPrint.hs index 9add7af52..361393989 100644 --- a/src/lib/PPrint.hs +++ b/src/lib/PPrint.hs @@ -39,7 +39,6 @@ import Occurrence (Count (Bounded), UsageInfo (..)) import Occurrence qualified as Occ import Types.Core import Types.Imp -import Types.Misc import Types.Primitives import Types.Source import QueryTypePure @@ -484,6 +483,7 @@ prettyDuration d = p (showFFloat (Just 3) (d * mult) "") <+> unit instance Pretty Output where pretty (TextOut s) = pretty s + pretty (SourceInfo _) = "hello" pretty (HtmlOut _) = "" -- pretty (ExportedFun _ _) = "" pretty (BenchResult name compileTime runTime stats) = diff --git a/src/lib/RenderHtml.hs b/src/lib/RenderHtml.hs index bc4809a41..92ae92851 100644 --- a/src/lib/RenderHtml.hs +++ b/src/lib/RenderHtml.hs @@ -7,27 +7,121 @@ {-# OPTIONS_GHC -Wno-orphans #-} {-# OPTIONS_GHC -Wno-incomplete-patterns #-} -module RenderHtml (pprintHtml, progHtml, ToMarkup) where +module RenderHtml ( + progHtml, ToMarkup, renderSourceBlock, renderResult, + RenderedSourceBlock, RenderedResult) where import Text.Blaze.Internal (MarkupM) -import Text.Blaze.Html5 as H hiding (map) +import Text.Blaze.Html5 as H hiding (map, b) import Text.Blaze.Html5.Attributes as At import Text.Blaze.Html.Renderer.String +import Data.Aeson (ToJSON) import qualified Data.Map.Strict as M import Control.Monad.State.Strict +import Control.Monad.Writer.Strict +import Data.Functor ((<&>)) import Data.Maybe (fromJust) import Data.String (fromString) import Data.Text qualified as T import Data.Text.IO qualified as T import CMark (commonmarkToHtml) import System.IO.Unsafe - +import GHC.Generics import Err import Paths_dex (getDataFileName) import PPrint () -import Types.Misc import Types.Source +import Util (unsnoc, foldJusts) + +-- === rendering results === + +-- RenderedResult, RenderedSourceBlock aren't 100% HTML themselves but the idea +-- is that they should be trivially convertable to JSON and sent over to the +-- client which can do the final rendering without much code or runtime work. + +type BlockId = Int +data RenderedSourceBlock = RenderedSourceBlock + { rsbLine :: Int + , rsbBlockId :: BlockId + , rsbLexemeList :: [SrcId] + , rsbHtml :: String } + deriving (Generic) + +data RenderedResult = RenderedResult + { rrHtml :: String + , rrHighlightMap :: HighlightMap + , rrHoverInfoMap :: HoverInfoMap } + deriving (Generic) + +renderResult :: Result -> RenderedResult +renderResult r = RenderedResult + { rrHtml = pprintHtml r + , rrHighlightMap = computeHighlights r + , rrHoverInfoMap = computeHoverInfo r } + +renderSourceBlock :: BlockId -> SourceBlock -> RenderedSourceBlock +renderSourceBlock n b = RenderedSourceBlock + { rsbLine = sbLine b + , rsbBlockId = n + , rsbLexemeList = unsnoc $ lexemeList $ sbLexemeInfo b + , rsbHtml = renderHtml case sbContents b of + Misc (ProseBlock s) -> cdiv "prose-block" $ mdToHtml s + _ -> renderSpans n (sbLexemeInfo b) (sbText b) + } + +instance ToMarkup Result where + toMarkup (Result outs err) = foldMap toMarkup outs <> err' + where err' = case err of + Failure e -> cdiv "err-block" $ toHtml $ pprint e + Success () -> mempty + +instance ToMarkup Output where + toMarkup out = case out of + HtmlOut s -> preEscapedString s + SourceInfo _ -> mempty + _ -> cdiv "result-block" $ toHtml $ pprint out + +instance ToJSON RenderedResult +instance ToJSON RenderedSourceBlock + +-- === textual information on hover === + +type HoverInfo = String +newtype HoverInfoMap = HoverInfoMap (M.Map LexemeId HoverInfo) deriving (ToJSON, Semigroup, Monoid) + +computeHoverInfo :: Result -> HoverInfoMap +computeHoverInfo _ = mempty + +-- === highlighting on hover === + +newtype FocusMap = FocusMap (M.Map LexemeId SrcId) deriving (ToJSON, Semigroup, Monoid) +newtype HighlightMap = HighlightMap (M.Map SrcId Highlights) deriving (ToJSON, Semigroup, Monoid) +type Highlights = [(HighlightType, LexemeSpan)] +data HighlightType = HighlightGroup | HighlightLeaf deriving Generic + +instance ToJSON HighlightType + +computeHighlights :: Result -> HighlightMap +computeHighlights result = do + execWriter $ mapM go $ foldJusts (resultOutputs result) \case + SourceInfo (SIGroupTree t) -> Just t + _ -> Nothing + where + go :: GroupTree -> Writer HighlightMap () + go t = do + let children = gtChildren t + let highlights = children <&> \child -> + (getHighlightType (gtIsAtomicLexeme child), gtSpan child) + forM_ children \child-> do + tell $ HighlightMap $ M.singleton (gtSrcId child) highlights + go child + + getHighlightType :: Bool -> HighlightType + getHighlightType True = HighlightLeaf + getHighlightType False = HighlightGroup + +-- ----------------- cssSource :: T.Text cssSource = unsafePerformIO $ @@ -62,30 +156,12 @@ wrapBody blocks = docTypeHtml $ do inner = foldMap (cdiv "cell") blocks jsSource = textValue $ javascriptSource <> "render(RENDER_MODE.STATIC);" -instance ToMarkup Result where - toMarkup (Result outs err) = foldMap toMarkup outs <> err' - where err' = case err of - Failure e -> cdiv "err-block" $ toHtml $ pprint e - Success () -> mempty - -instance ToMarkup Output where - toMarkup out = case out of - HtmlOut s -> preEscapedString s - _ -> cdiv "result-block" $ toHtml $ pprint out - -instance ToMarkup SourceBlockWithId where - toMarkup (SourceBlockWithId blockId block) = case sbContents block of - Misc (ProseBlock s) -> cdiv "prose-block" $ mdToHtml s - _ -> renderSpans blockId (sbLexemeInfo block) (sbText block) - mdToHtml :: T.Text -> Html mdToHtml s = preEscapedText $ commonmarkToHtml [] s cdiv :: String -> Html -> Html cdiv c inner = H.div inner ! class_ (stringValue c) -type BlockId = Int - renderSpans :: BlockId -> LexemeInfo -> T.Text -> Markup renderSpans blockId lexInfo sourceText = cdiv "code-block" do runTextWalkerT sourceText do diff --git a/src/lib/Runtime.hs b/src/lib/Runtime.hs index 1ea5dad66..de00c78c7 100644 --- a/src/lib/Runtime.hs +++ b/src/lib/Runtime.hs @@ -33,8 +33,8 @@ import PPrint () import CUDA (synchronizeCUDA) import Types.Core hiding (DexDestructor) +import Types.Source hiding (CInt) import Types.Primitives -import Types.Misc -- === One-shot evaluation === diff --git a/src/lib/SourceIdTraversal.hs b/src/lib/SourceIdTraversal.hs index 1900cafbc..19ca2f8ca 100644 --- a/src/lib/SourceIdTraversal.hs +++ b/src/lib/SourceIdTraversal.hs @@ -13,25 +13,25 @@ import Types.Source import Types.Primitives getGroupTree :: SourceBlock' -> GroupTree -getGroupTree b = mkGroupTree rootSrcId $ runTreeM $ visit b +getGroupTree b = mkGroupTree False rootSrcId $ runTreeM $ visit b type TreeM = Writer [GroupTree] -mkGroupTree :: SrcId -> [GroupTree] -> GroupTree -mkGroupTree sid = \case - [] -> GroupTree sid (sid,sid) [] -- no children - must be a lexeme - subtrees -> GroupTree sid (l,r) subtrees - where l = minimum $ subtrees <&> \(GroupTree _ (l',_) _) -> l' - r = maximum $ subtrees <&> \(GroupTree _ (_,r') _) -> r' +mkGroupTree :: Bool -> SrcId -> [GroupTree] -> GroupTree +mkGroupTree isAtomic sid = \case + [] -> GroupTree sid (sid,sid) [] isAtomic -- no children - must be a lexeme + subtrees -> GroupTree sid (l,r) subtrees isAtomic + where l = minimum $ subtrees <&> (fst . gtSpan) + r = maximum $ subtrees <&> (snd . gtSpan) runTreeM :: TreeM () -> [GroupTree] runTreeM cont = snd $ runWriter $ cont enterNode :: SrcId -> TreeM () -> TreeM () -enterNode sid cont = tell [mkGroupTree sid (runTreeM cont)] +enterNode sid cont = tell [mkGroupTree False sid (runTreeM cont)] emitLexeme :: SrcId -> TreeM () -emitLexeme lexemeId = tell [mkGroupTree lexemeId []] +emitLexeme lexemeId = tell [mkGroupTree True lexemeId []] class IsTree a where visit :: a -> TreeM () diff --git a/src/lib/TopLevel.hs b/src/lib/TopLevel.hs index 89021559b..b13d81293 100644 --- a/src/lib/TopLevel.hs +++ b/src/lib/TopLevel.hs @@ -54,6 +54,7 @@ import Err import IRVariants import Imp import ImpToLLVM +import IncState import Inference import Inline import Logging @@ -70,9 +71,9 @@ import Runtime import Serialize (takePtrSnapshot, restorePtrSnapshot) import Simplify import SourceRename +import SourceIdTraversal import Types.Core import Types.Imp -import Types.Misc import Types.Primitives import Types.Source import Util ( Tree (..), measureSeconds, File (..), readFileWithHash) @@ -229,6 +230,7 @@ evalSourceBlock evalSourceBlock mname block = do result <- withCompileTime do (maybeErr, logs) <- catchLogsAndErrs do + logTop $ SourceInfo $ SIGroupTree $ OverwriteWith $ getGroupTree $ sbContents block benchReq <- getBenchRequirement block withPassCtx (PassCtx benchReq (passLogFilter $ sbLogLevel block)) $ evalSourceBlock' mname block diff --git a/src/lib/Types/Misc.hs b/src/lib/Types/Misc.hs deleted file mode 100644 index eb7bfb288..000000000 --- a/src/lib/Types/Misc.hs +++ /dev/null @@ -1,45 +0,0 @@ --- Copyright 2022 Google LLC --- --- Use of this source code is governed by a BSD-style --- license that can be found in the LICENSE file or at --- https://developers.google.com/open-source/licenses/bsd - -module Types.Misc where - -import GHC.Generics (Generic (..)) - -import Err -import Logging -import Types.Source - -type LitProg = [(SourceBlock, Result)] - -data Result = Result - { resultOutputs :: [Output] - , resultErrs :: Except () } - deriving (Show, Eq) - -type BenchStats = (Int, Double) -- number of runs, total benchmarking time -data Output = - TextOut String - | HtmlOut String - | PassInfo PassName String - | EvalTime Double (Maybe BenchStats) - | TotalTime Double - | BenchResult String Double Double (Maybe BenchStats) -- name, compile time, eval time - | MiscLog String - -- Used to have | ExportedFun String Atom - deriving (Show, Eq, Generic) - -type PassLogger = FilteredLogger PassName [Output] - -data OptLevel = NoOptimize | Optimize - -instance Semigroup Result where - Result outs err <> Result outs' err' = Result (outs <> outs') err'' - where err'' = case err' of - Success () -> err - Failure _ -> err' - -instance Monoid Result where - mempty = Result mempty (Success ()) diff --git a/src/lib/Types/Source.hs b/src/lib/Types/Source.hs index 3c39d603a..f83f14a18 100644 --- a/src/lib/Types/Source.hs +++ b/src/lib/Types/Source.hs @@ -20,6 +20,7 @@ module Types.Source where +import Data.Aeson (ToJSON, ToJSONKey) import Data.Hashable import Data.Foldable import qualified Data.Map.Strict as M @@ -31,10 +32,13 @@ import Data.Word import GHC.Generics (Generic (..)) import Data.Store (Store (..)) +import Err +import Logging import Name import qualified Types.OpNames as P import IRVariants import Util (File (..), SnocList) +import IncState import Types.Primitives @@ -85,14 +89,55 @@ type LexemeSpan = (LexemeId, LexemeId) data GroupTree = GroupTree { gtSrcId :: SrcId , gtSpan :: LexemeSpan - , gtChildren :: [GroupTree] } - deriving (Show, Generic) + , gtChildren :: [GroupTree] + , gtIsAtomicLexeme :: Bool } + deriving (Show, Eq, Generic) instance Semigroup LexemeInfo where LexemeInfo a b <> LexemeInfo a' b' = LexemeInfo (a <> a') (b <> b') instance Monoid LexemeInfo where mempty = LexemeInfo mempty mempty +-- === Results === + +type LitProg = [(SourceBlock, Result)] + +data Result = Result + { resultOutputs :: [Output] + , resultErrs :: Except () } + deriving (Show, Eq) + +type BenchStats = (Int, Double) -- number of runs, total benchmarking time + +data SourceInfo = + SIGroupTree (Overwrite GroupTree) + deriving (Show, Eq, Generic) + +data Output = + TextOut String + | HtmlOut String + | SourceInfo SourceInfo -- for hovertips etc + | PassInfo PassName String + | EvalTime Double (Maybe BenchStats) + | TotalTime Double + | BenchResult String Double Double (Maybe BenchStats) -- name, compile time, eval time + | MiscLog String + -- Used to have | ExportedFun String Atom + deriving (Show, Eq, Generic) + +type PassLogger = FilteredLogger PassName [Output] + +data OptLevel = NoOptimize | Optimize + +instance Semigroup Result where + Result outs err <> Result outs' err' = Result (outs <> outs') err'' + where err'' = case err' of + Success () -> err + Failure _ -> err' + +instance Monoid Result where + mempty = Result mempty (Success ()) + -- === Concrete syntax === -- The grouping-level syntax of the source language @@ -512,15 +557,12 @@ data UModule = UModule -- === top-level blocks === -data SourceBlockWithId = SourceBlockWithId Int SourceBlock - data SourceBlock = SourceBlock { sbLine :: Int , sbOffset :: Int , sbLogLevel :: LogLevel , sbText :: Text , sbLexemeInfo :: LexemeInfo - , sbGroupTree :: Maybe GroupTree , sbContents :: SourceBlock' } deriving (Show, Generic) @@ -817,3 +859,7 @@ deriving instance Ord (UEffect n) deriving instance Show (UEffectRow n) deriving instance Eq (UEffectRow n) deriving instance Ord (UEffectRow n) + +instance ToJSON SrcId +deriving instance ToJSONKey SrcId +instance ToJSON LexemeType diff --git a/src/lib/Util.hs b/src/lib/Util.hs index 4c257f95d..8a44e7234 100644 --- a/src/lib/Util.hs +++ b/src/lib/Util.hs @@ -19,7 +19,7 @@ import GHC.Base (getTag) import GHC.Exts ((==#), tagToEnum#) import Crypto.Hash import Data.Functor.Identity (Identity(..)) -import Data.Maybe (catMaybes) +import Data.Maybe (catMaybes, mapMaybe) import Data.List (sort) import Data.Hashable (Hashable) import Data.Store (Store) @@ -134,12 +134,8 @@ mapFst f zs = [(f x, y) | (x, y) <- zs] mapSnd :: (a -> b) -> [(c, a)] -> [(c, b)] mapSnd f zs = [(x, f y) | (x, y) <- zs] -mapMaybe :: (a -> Maybe b) -> [a] -> [b] -mapMaybe _ [] = [] -mapMaybe f (x:xs) = let rest = mapMaybe f xs - in case f x of - Just y -> y : rest - Nothing -> rest +foldJusts :: Monoid b => [a] -> (a -> Maybe b) -> b +foldJusts xs f = fold $ mapMaybe f xs forMFilter :: Monad m => [a] -> (a -> m (Maybe b)) -> m [b] forMFilter xs f = catMaybes <$> mapM f xs diff --git a/static/index.js b/static/index.js index 1ef4dc708..5a0e0693f 100644 --- a/static/index.js +++ b/static/index.js @@ -76,9 +76,8 @@ function applyHoverInfo(cellId, srcId) { hoverInfoDiv.innerHTML = srcId.toString() // hoverInfo } function applyHoverHighlights(cellId, srcId) { - let focus = lookupSrcMap(focusMap, cellId, srcId) - if (focus == null) return - let highlights = lookupSrcMap(highlightMap, cellId, focus) + let highlights = lookupSrcMap(highlightMap, cellId, srcId) + if (highlights == null) return highlights.map(function (highlight) { let [highlightType, [l, r]] = highlight let spans = spansBetween(selectSpan(cellId, l), selectSpan(cellId, r)); @@ -126,8 +125,6 @@ function render(renderMode) { } else { processUpdate(msg)}};} } - - function selectSpan(cellId, srcId) { return cells[cellId].querySelector("#span_".concat(cellId, "_", srcId)) } @@ -158,8 +155,7 @@ function spansBetween(l, r) { let spans = [] while (l !== null && !(Object.is(l, r))) { spans.push(l); - l = l.nextSibling; - } + l = l.nextSibling;} spans.push(r) return spans } @@ -168,10 +164,10 @@ function setCellStatus(cell, status) { cell.classList.add(getStatusClass(status)) } -function setCellContents(cell, contents) { +function setCellContents(cellId, cell, contents) { let [source, status, result] = contents; - let lineNum = source["jdLine"]; - let sourceText = source["jdHTML"]; + let lineNum = source["rsbLine"]; + let sourceText = source["rsbHtml"]; let lineNumDiv = document.createElement("div"); lineNumDiv.innerHTML = lineNum.toString(); lineNumDiv.className = "line-num"; @@ -179,36 +175,43 @@ function setCellContents(cell, contents) { cell.appendChild(lineNumDiv) setCellStatus(cell, status) cell.innerHTML += sourceText - cell.innerHTML += result - renderLaTeX(cell); + renderLaTeX(cell) + extendCellResult(cellId, cell, result) +} +function extendCellResult(cellId, cell, result) { + let resultText = result["rrHtml"] + if (resultText !== "") { + cell.innerHTML += resultText + } + highlightMap[cellId] = result["rrHighlightMap"] + hoverInfoMap[cellId] = result["rrHoverInfoMap"] } -function updateCellContents(cell, contents) { - let [statusUpdate, resultsUpdate] = contents; +function updateCellContents(cellId, cell, contents) { + let [statusUpdate, result] = contents; if (statusUpdate["tag"] == "OverwriteWith") { setCellStatus(cell, statusUpdate["contents"])} - if (resultsUpdate !== "") { - cell.innerHTML += resultsUpdate} + extendCellResult(cellId, cell, result) } function processUpdate(msg) { - let cell_updates = msg["nodeMapUpdate"]["mapUpdates"]; - let num_dropped = msg["orderedNodesUpdate"]["numDropped"]; - let new_tail = msg["orderedNodesUpdate"]["newTail"]; + let cellUpdates = msg["nodeMapUpdate"]["mapUpdates"]; + let numDropped = msg["orderedNodesUpdate"]["numDropped"]; + let newTail = msg["orderedNodesUpdate"]["newTail"]; // drop_dead_cells - for (i = 0; i < num_dropped; i++) { + for (i = 0; i < numDropped; i++) { body.lastElementChild.remove();} - Object.keys(cell_updates).forEach(function (cellId) { - let update = cell_updates[cellId]; + Object.keys(cellUpdates).forEach(function (cellId) { + let update = cellUpdates[cellId]; let tag = update["tag"] let contents = update["contents"] if (tag == "Create" || tag == "Replace") { let cell = document.createElement("div"); cells[cellId] = cell; - setCellContents(cell, contents) + setCellContents(cellId, cell, contents) } else if (tag == "Update") { let cell = cells[cellId]; - updateCellContents(cell, contents); + updateCellContents(cellId, cell, contents); } else if (tag == "Delete") { delete cells[cellId] } else { @@ -216,19 +219,16 @@ function processUpdate(msg) { }}); // append_new_cells - new_tail.forEach(function (cellId) { + newTail.forEach(function (cellId) { let cell = selectCell(cellId); body.appendChild(cell);}) - Object.keys(cell_updates).forEach(function (cellId) { - let update = cell_updates[cellId] + Object.keys(cellUpdates).forEach(function (cellId) { + let update = cellUpdates[cellId] let tag = update["tag"] if (tag == "Create" || tag == "Replace") { - let update = cell_updates[cellId]; + let update = cellUpdates[cellId]; let source = update["contents"][0]; - focusMap[cellId] = source["jdFocusMap"] - highlightMap[cellId] = source["jdHighlightMap"] - hoverInfoMap[cellId] = source["jsHoverInfoMap"] - let lexemeList = source["jdLexemeList"]; + let lexemeList = source["rsbLexemeList"]; lexemeList.map(function (lexemeId) {attachHovertip(cellId, lexemeId.toString())})}}); } From 9624d29d47086ab23d8a87dc2e0388d2bc642734 Mon Sep 17 00:00:00 2001 From: Dougal Date: Fri, 1 Dec 2023 09:03:13 -0500 Subject: [PATCH 30/41] Add placeholder for types on hover. Need to add the logging logic next. --- src/lib/RenderHtml.hs | 6 +++++- src/lib/Types/Source.hs | 6 ++++++ 2 files changed, 11 insertions(+), 1 deletion(-) diff --git a/src/lib/RenderHtml.hs b/src/lib/RenderHtml.hs index 92ae92851..04e35a1bd 100644 --- a/src/lib/RenderHtml.hs +++ b/src/lib/RenderHtml.hs @@ -91,7 +91,11 @@ type HoverInfo = String newtype HoverInfoMap = HoverInfoMap (M.Map LexemeId HoverInfo) deriving (ToJSON, Semigroup, Monoid) computeHoverInfo :: Result -> HoverInfoMap -computeHoverInfo _ = mempty +computeHoverInfo result = do + let typeInfo = foldJusts (resultOutputs result) \case + SourceInfo (SITypeInfo m) -> Just m + _ -> Nothing + HoverInfoMap $ fromTypeInfo typeInfo -- === highlighting on hover === diff --git a/src/lib/Types/Source.hs b/src/lib/Types/Source.hs index f83f14a18..6b432474e 100644 --- a/src/lib/Types/Source.hs +++ b/src/lib/Types/Source.hs @@ -98,6 +98,11 @@ instance Semigroup LexemeInfo where instance Monoid LexemeInfo where mempty = LexemeInfo mempty mempty +-- === Type Info === + +newtype TypeInfo = TypeInfo { fromTypeInfo :: M.Map SrcId String } + deriving (Semigroup, Monoid, ToJSON, Show, Eq) + -- === Results === type LitProg = [(SourceBlock, Result)] @@ -111,6 +116,7 @@ type BenchStats = (Int, Double) -- number of runs, total benchmarking time data SourceInfo = SIGroupTree (Overwrite GroupTree) + | SITypeInfo TypeInfo deriving (Show, Eq, Generic) data Output = From 7042a479b33d26557de5a61be8bac8604d4e451b Mon Sep 17 00:00:00 2001 From: Dougal Date: Fri, 1 Dec 2023 16:29:21 -0500 Subject: [PATCH 31/41] Simplify logging and allow realtime updates from runtime prints Regarding simplification: the previous system is really complicated with lots of logging options that could be specified at runtime on a block-by-block basis. It was nice to have but I decided it wasn't worth the implementation complexity. Now we just have two log levels - ordinary logging (outputs, errors) and debug logging (dumps from passes). --- dex.cabal | 1 - src/dex.hs | 85 ++++---------- src/lib/ConcreteSyntax.hs | 41 +------ src/lib/ImpToLLVM.hs | 5 +- src/lib/LLVM/Compile.hs | 12 +- src/lib/Live/Eval.hs | 19 ++- src/lib/Live/Web.hs | 4 +- src/lib/Logging.hs | 84 ------------- src/lib/MonadUtil.hs | 47 +++++++- src/lib/PPrint.hs | 73 +++--------- src/lib/RenderHtml.hs | 36 +++--- src/lib/Runtime.hs | 52 ++------ src/lib/TopLevel.hs | 241 ++++++++++++-------------------------- src/lib/Types/Source.hs | 19 +-- 14 files changed, 216 insertions(+), 503 deletions(-) delete mode 100644 src/lib/Logging.hs diff --git a/dex.cabal b/dex.cabal index cf31f3149..74a4eacf9 100644 --- a/dex.cabal +++ b/dex.cabal @@ -71,7 +71,6 @@ library , LLVM.Shims , Lexing , Linearize - , Logging , Lower , MonadUtil , MTL1 diff --git a/src/dex.hs b/src/dex.hs index 7d26b39ff..6d88c86db 100644 --- a/src/dex.hs +++ b/src/dex.hs @@ -22,23 +22,21 @@ import qualified Data.Text as T import qualified Data.Text.Encoding as T import qualified Data.Map.Strict as M -import PPrint (resultAsJSON, printResult) +import PPrint (printOutput) import TopLevel import Err import Name import AbstractSyntax (parseTopDeclRepl) import ConcreteSyntax (keyWordStrs, preludeImportBlock) -#ifdef DEX_LIVE import RenderHtml -- import Live.Terminal (runTerminal) import Live.Web (runWeb) -#endif import Core import Types.Core import Types.Imp import Types.Source +import MonadUtil -data ErrorHandling = HaltOnErr | ContinueOnErr data DocFmt = ResultOnly | TextDoc | JSONDoc @@ -57,35 +55,25 @@ data EvalMode = ReplMode String data CmdOpts = CmdOpts EvalMode EvalConfig runMode :: EvalMode -> EvalConfig -> IO () -runMode evalMode opts = case evalMode of +runMode evalMode cfg = case evalMode of ScriptMode fname fmt onErr -> do env <- loadCache - (litProg, finalEnv) <- runTopperM opts env do + ((), finalEnv) <- runTopperM cfg env do source <- liftIO $ T.decodeUtf8 <$> BS.readFile fname - evalSourceText source (printIncrementalSource fmt) \result@(Result _ errs) -> do - printIncrementalResult fmt result - return case (onErr, errs) of (HaltOnErr, Failure _) -> False; _ -> True - printFinal fmt litProg + evalSourceText source $ printIncrementalSource fmt storeCache finalEnv ReplMode prompt -> do env <- loadCache - void $ runTopperM opts env do - evalBlockRepl preludeImportBlock + void $ runTopperM cfg env do + evalSourceBlockRepl preludeImportBlock forever do block <- readSourceBlock prompt - evalBlockRepl block - where - evalBlockRepl :: (Topper m, Mut n) => SourceBlock -> m n () - evalBlockRepl block = do - result <- evalSourceBlockRepl block - case result of - Result [] (Success ()) -> return () - _ -> liftIO $ putStrLn $ pprint result + evalSourceBlockRepl block ClearCache -> clearCache #ifdef DEX_LIVE WebMode fname -> do env <- loadCache - runWeb fname opts env + runWeb fname cfg env WatchMode _ -> error "not implemented" #endif @@ -98,26 +86,6 @@ printIncrementalSource fmt sb = case fmt of HTMLDoc -> return () #endif -printIncrementalResult :: DocFmt -> Result -> IO () -printIncrementalResult fmt result = case fmt of - ResultOnly -> case pprint result of [] -> return (); msg -> putStrLn msg - JSONDoc -> case resultAsJSON result of "{}" -> return (); s -> putStrLn s - TextDoc -> do - isatty <- queryTerminal stdOutput - putStr $ printResult isatty result -#ifdef DEX_LIVE - HTMLDoc -> return () -#endif - -printFinal :: DocFmt -> [(SourceBlock, Result)] -> IO () -printFinal fmt prog = case fmt of - ResultOnly -> return () - TextDoc -> return () - JSONDoc -> return () -#ifdef DEX_LIVE - HTMLDoc -> undefined -- putStr $ progHtml prog -#endif - readSourceBlock :: (MonadIO (m n), EnvReader m) => String -> m n SourceBlock readSourceBlock prompt = do sourceMap <- withEnv $ envSourceMap . moduleEnv @@ -205,24 +173,25 @@ parseEvalOpts = EvalConfig <*> (option pathOption $ long "lib-path" <> value [LibBuiltinPath] <> metavar "PATH" <> help "Library path") <*> optional (strOption $ long "prelude" <> metavar "FILE" <> help "Prelude file") - <*> optional (strOption $ long "logto" - <> metavar "FILE" - <> help "File to log to" <> showDefault) - <*> pure Nothing <*> flag NoOptimize Optimize (short 'O' <> help "Optimize generated code") <*> enumOption "print" "Print backend" PrintCodegen printBackends + <*> flag ContinueOnErr HaltOnErr ( + long "stop-on-error" + <> help "Stop program evaluation when an error occurs (type or runtime)") + <*> enumOption "loglevel" "Log level" NormalLogLevel logLevels + <*> pure stdOutLogger where printBackends = [ ("haskell", PrintHaskell) , ("dex" , PrintCodegen) ] - backends = [ ("llvm", LLVM) - , ("llvm-mc", LLVMMC) -#ifdef DEX_CUDA - , ("llvm-cuda", LLVMCUDA) -#endif -#if DEX_LLVM_VERSION == HEAD - , ("mlir", MLIR) -#endif - , ("interpreter", Interpreter)] + backends = [ ("llvm" , LLVM ) + , ("llvm-mc", LLVMMC) ] + logLevels = [ ("normal", NormalLogLevel) + , ("debug" , DebugLogLevel ) ] + +stdOutLogger :: Outputs -> IO () +stdOutLogger (Outputs outs) = do + isatty <- queryTerminal stdOutput + forM_ outs \out -> putStr $ printOutput isatty out pathOption :: ReadM [LibPath] pathOption = splitPaths [] <$> str @@ -237,13 +206,7 @@ pathOption = splitPaths [] <$> str "BUILTIN_LIBRARIES" -> LibBuiltinPath path -> LibDirectory path -openLogFile :: EvalConfig -> IO EvalConfig -openLogFile EvalConfig {..} = do - logFile <- forM logFileName (`openFile` WriteMode) - return $ EvalConfig {..} - main :: IO () main = do CmdOpts evalMode opts <- execParser parseOpts - opts' <- openLogFile opts - runMode evalMode opts' + runMode evalMode opts diff --git a/src/lib/ConcreteSyntax.hs b/src/lib/ConcreteSyntax.hs index 8695e8d24..838089e7c 100644 --- a/src/lib/ConcreteSyntax.hs +++ b/src/lib/ConcreteSyntax.hs @@ -59,7 +59,7 @@ parseUModule name s = do {-# SCC parseUModule #-} preludeImportBlock :: SourceBlock -preludeImportBlock = SourceBlock 0 0 LogNothing "" mempty (Misc $ ImportModule Prelude) +preludeImportBlock = SourceBlock 0 0 "" mempty (Misc $ ImportModule Prelude) sourceBlocks :: Parser [SourceBlock] sourceBlocks = manyTill (sourceBlock <* outputLines) eof @@ -93,21 +93,18 @@ sourceBlock :: Parser SourceBlock sourceBlock = do offset <- getOffset pos <- getSourcePos - (src, (lexInfo, (level, b))) <- withSource $ withLexemeInfo $ withRecovery recover do - level <- logLevel <|> logTime <|> logBench <|> return LogNothing - b <- sourceBlock' - return (level, b) + (src, (lexInfo, b)) <- withSource $ withLexemeInfo $ withRecovery recover $ sourceBlock' let lexInfo' = lexInfo { lexemeInfo = lexemeInfo lexInfo <&> \(t, (l, r)) -> (t, (l-offset, r-offset))} - return $ SourceBlock (unPos (sourceLine pos)) offset level src lexInfo' b + return $ SourceBlock (unPos (sourceLine pos)) offset src lexInfo' b -recover :: ParseError Text Void -> Parser (LogLevel, SourceBlock') +recover :: ParseError Text Void -> Parser SourceBlock' recover e = do pos <- liftM statePosState getParserState reachedEOF <- try (mayBreak sc >> eof >> return True) <|> return False consumeTillBreak let errmsg = errorBundlePretty (ParseErrorBundle (e :| []) pos) - return (LogNothing, UnParseable reachedEOF errmsg) + return $ UnParseable reachedEOF errmsg importModule :: Parser SourceBlock' importModule = Misc . ImportModule . OrdinaryModule <$> do @@ -138,34 +135,6 @@ declareCustomLinearization = do consumeTillBreak :: Parser () consumeTillBreak = void $ manyTill anySingle $ eof <|> void (try (eol >> eol)) -logLevel :: Parser LogLevel -logLevel = do - void $ try $ lexeme MiscLexeme $ char '%' >> string "passes" - passes <- many passName - eol - case passes of - [] -> return LogAll - _ -> return $ LogPasses passes - -logTime :: Parser LogLevel -logTime = do - void $ try $ lexeme MiscLexeme $ char '%' >> string "time" - eol - return PrintEvalTime - -logBench :: Parser LogLevel -logBench = do - void $ try $ lexeme MiscLexeme $ char '%' >> string "bench" - WithSrc _ benchName <- strLit - eol - return $ PrintBench benchName - -passName :: Parser PassName -passName = choice [thisNameString s $> x | (s, x) <- passNames] - -passNames :: [(Text, PassName)] -passNames = [(T.pack $ show x, x) | x <- [minBound..maxBound]] - sourceBlock' :: Parser SourceBlock' sourceBlock' = proseBlock diff --git a/src/lib/ImpToLLVM.hs b/src/lib/ImpToLLVM.hs index c628486e0..556f927bd 100644 --- a/src/lib/ImpToLLVM.hs +++ b/src/lib/ImpToLLVM.hs @@ -48,7 +48,6 @@ import Core import Err import Imp import LLVM.CUDA (LLVMKernel (..), compileCUDAKernel, ptxDataLayout, ptxTargetTriple) -import Logging import Subst import Name import PPrint @@ -109,7 +108,7 @@ instance Compiler CompileM -- === Imp to LLVM === impToLLVM :: (EnvReader m, MonadIO1 m) - => FilteredLogger PassName [Output] -> NameHint + => PassLogger -> NameHint -> ClosedImpFunction n -> m n (WithCNameInterface L.Module) impToLLVM logger fNameHint (ClosedImpFunction funBinders ptrBinders impFun) = do @@ -185,7 +184,7 @@ impToLLVM logger fNameHint (ClosedImpFunction funBinders ptrBinders impFun) = do compileFunction :: (EnvReader m, MonadIO1 m) - => FilteredLogger PassName [Output] -> L.Name + => PassLogger -> L.Name -> OperandEnv i o -> ImpFunction i -> m o ([L.Definition], S.Set ExternFunSpec, [L.Name]) compileFunction logger fName env fun@(ImpFunction (IFunType cc argTys retTys) diff --git a/src/lib/LLVM/Compile.hs b/src/lib/LLVM/Compile.hs index f7caa7637..b4248bbf3 100644 --- a/src/lib/LLVM/Compile.hs +++ b/src/lib/LLVM/Compile.hs @@ -30,14 +30,10 @@ import System.IO.Unsafe import Control.Monad -import Logging import PPrint () import Paths_dex (getDataFileName) --- The only reason this module depends on Types.Source is that we pass in the logger, --- in order to optionally print out the IRs. LLVM mutates its IRs in-place, so --- we can't just expose a functional API for each stage without taking a --- performance hit. But maybe the performance hit isn't so bad? import Types.Source +import MonadUtil data LLVMOptLevel = OptALittle -- -O1 @@ -109,7 +105,11 @@ standardCompilationPipeline opt logger exports tm m = do {-# SCC showAssembly #-} logPass AsmPass $ showAsm tm m where logPass :: PassName -> IO String -> IO () - logPass passName cont = logFiltered logger passName $ cont >>= \s -> return [PassInfo passName s] + logPass passName showIt = case ioLogLevel logger of + DebugLogLevel -> do + s <- showIt + ioLogAction logger $ Outputs [PassInfo passName s] + NormalLogLevel -> return () {-# SCC standardCompilationPipeline #-} internalize :: [String] -> Mod.Module -> IO () diff --git a/src/lib/Live/Eval.hs b/src/lib/Live/Eval.hs index 00f5beb28..89fcd3e3b 100644 --- a/src/lib/Live/Eval.hs +++ b/src/lib/Live/Eval.hs @@ -32,8 +32,8 @@ import MonadUtil -- === Top-level interface === type EvalServer = StateServer EvalState EvalUpdate -type EvalState = CellsState SourceBlock Result -type EvalUpdate = CellsUpdate SourceBlock Result +type EvalState = CellsState SourceBlock Outputs +type EvalUpdate = CellsUpdate SourceBlock Outputs -- `watchAndEvalFile` returns the channel by which a client may -- subscribe by sending a write-only view of its input channel. @@ -41,15 +41,12 @@ watchAndEvalFile :: FilePath -> EvalConfig -> TopStateEx -> IO EvalServer watchAndEvalFile fname opts env = do watcher <- launchFileWatcher fname parser <- launchCellParser watcher \source -> uModuleSourceBlocks $ parseUModule Main source - launchDagEvaluator parser env (evalSourceBlockIO' opts) - --- shim to pretend that evalSourceBlockIO streams its results. TODO: make it actually do that. -evalSourceBlockIO' - :: EvalConfig -> Mailbox Result -> TopStateEx -> SourceBlock -> IO TopStateEx -evalSourceBlockIO' cfg resultChan env block = do - (result, env') <- evalSourceBlockIO cfg env block - send resultChan result - return env' + launchDagEvaluator parser env (sourceBlockEvalFun opts) + +sourceBlockEvalFun :: EvalConfig -> Mailbox Outputs -> TopStateEx -> SourceBlock -> IO TopStateEx +sourceBlockEvalFun cfg resultChan env block = do + let cfg' = cfg { cfgLogAction = send resultChan } + evalSourceBlockIO cfg' env block fmapCellsUpdate :: CellsUpdate i o -> (NodeId -> i -> i') -> (NodeId -> o -> o') -> CellsUpdate i' o' fmapCellsUpdate (NodeListUpdate t m) fi fo = NodeListUpdate t m' where diff --git a/src/lib/Live/Web.hs b/src/lib/Live/Web.hs index 372bbf2d0..4e23d805c 100644 --- a/src/lib/Live/Web.hs +++ b/src/lib/Live/Web.hs @@ -66,7 +66,7 @@ encodePacket :: ToJSON a => a -> BS.ByteString encodePacket = toStrict . wrap . encode where wrap s = "data:" <> s <> "\n\n" -renderEvalUpdate :: CellsUpdate SourceBlock Result -> CellsUpdate RenderedSourceBlock RenderedResult +renderEvalUpdate :: CellsUpdate SourceBlock Outputs -> CellsUpdate RenderedSourceBlock RenderedOutputs renderEvalUpdate cellsUpdate = fmapCellsUpdate cellsUpdate (\k b -> renderSourceBlock k b) - (\_ r -> renderResult r) + (\_ r -> renderOutputs r) diff --git a/src/lib/Logging.hs b/src/lib/Logging.hs deleted file mode 100644 index 1c3f0eef1..000000000 --- a/src/lib/Logging.hs +++ /dev/null @@ -1,84 +0,0 @@ --- Copyright 2019 Google LLC --- --- Use of this source code is governed by a BSD-style --- license that can be found in the LICENSE file or at --- https://developers.google.com/open-source/licenses/bsd - -{-# LANGUAGE UndecidableInstances #-} - -module Logging (Logger, LoggerT (..), MonadLogger (..), logIO, runLoggerT, - FilteredLogger (..), logFiltered, logSkippingFilter, - MonadLogger1, MonadLogger2, - runLogger, execLogger, logThis, readLog, ) where - -import Control.Monad -import Control.Monad.Reader -import Data.Text.Prettyprint.Doc -import Control.Concurrent.MVar -import Prelude hiding (log) -import System.IO - -import Err -import Name - -data Logger l = Logger (MVar l) (Maybe Handle) - -data FilteredLogger k l = FilteredLogger (k -> Bool) (Logger l) - -runLogger :: (Monoid l, MonadIO m) => Maybe Handle -> (Logger l -> m a) -> m (a, l) -runLogger logFile m = do - log <- liftIO $ newMVar mempty - ans <- m $ Logger log logFile - logged <- liftIO $ readMVar log - return (ans, logged) - -execLogger :: (Monoid l, MonadIO m) => Maybe Handle -> (Logger l -> m a) -> m a -execLogger logFile m = fst <$> runLogger logFile m - -logThis :: (Pretty l, Monoid l, MonadIO m) => Logger l -> l -> m () -logThis (Logger log maybeLogHandle) x = liftIO $ do - forM_ maybeLogHandle \h -> do - hPutStrLn h $ pprint x - hFlush h - modifyMVar_ log \cur -> return (cur <> x) - -logFiltered :: (Monoid l, MonadIO m, Pretty l) => FilteredLogger k l -> k -> m l -> m () -logFiltered (FilteredLogger shouldLog logger) k m = - when (shouldLog k) $ m >>= logThis logger - -logSkippingFilter :: (Monoid l, MonadIO m, Pretty l) => FilteredLogger k l -> l -> m () -logSkippingFilter (FilteredLogger _ logger) = logThis logger - -readLog :: MonadIO m => Logger l -> m l -readLog (Logger log _) = liftIO $ readMVar log - --- === monadic interface === - -newtype LoggerT l m a = LoggerT { runLoggerT' :: ReaderT (Logger l) m a } - deriving (Functor, Applicative, Monad, MonadTrans, - MonadIO, MonadFail, Fallible, Catchable) - -class (Pretty l, Monoid l, Monad m) => MonadLogger l m | m -> l where - getLogger :: m (Logger l) - withLogger :: Logger l -> m a -> m a - -instance (MonadIO m, Pretty l, Monoid l) => MonadLogger l (LoggerT l m) where - getLogger = LoggerT ask - withLogger l m = LoggerT $ local (const l) $ runLoggerT' m - -type MonadLogger1 l (m :: MonadKind1) = forall (n::S) . MonadLogger l (m n) -type MonadLogger2 l (m :: MonadKind2) = forall (n1::S) (n2::S) . MonadLogger l (m n1 n2) - -logIO :: MonadIO m => MonadLogger l m => l -> m () -logIO val = do - logger <- getLogger - liftIO $ logThis logger val - -runLoggerT :: Monoid l => Logger l -> LoggerT l m a -> m a -runLoggerT l (LoggerT m) = runReaderT m l - --- === more instances === - -instance MonadLogger l m => MonadLogger l (ReaderT r m) where - getLogger = lift getLogger - withLogger l cont = ReaderT \r -> withLogger l $ runReaderT cont r diff --git a/src/lib/MonadUtil.hs b/src/lib/MonadUtil.hs index 6d75e2377..17a21bd95 100644 --- a/src/lib/MonadUtil.hs +++ b/src/lib/MonadUtil.hs @@ -8,10 +8,14 @@ module MonadUtil ( DefuncState (..), LabelReader (..), SingletonLabel (..), FreshNames (..), - runFreshNameT, FreshNameT (..)) where + runFreshNameT, FreshNameT (..), Logger (..), LogLevel (..), getIOLogger, + IOLoggerT (..), runIOLoggerT, LoggerT (..), runLoggerT, IOLogger (..), HasIOLogger (..)) where import Control.Monad.Reader import Control.Monad.State.Strict +import Control.Monad.Writer.Strict + +import Err -- === Defunctionalized state === -- Interface for state whose allowable updates are specified by a data type. @@ -49,3 +53,44 @@ instance FreshNames a m => FreshNames a (ReaderT r m) where runFreshNameT :: MonadIO m => FreshNameT m a -> m a runFreshNameT cont = evalStateT (runFreshNameT' cont) 0 + +-- === Logging monad === + +data IOLogger w = IOLogger { ioLogLevel :: LogLevel + , ioLogAction :: w -> IO () } +data LogLevel = NormalLogLevel | DebugLogLevel + +class (Monoid w, Monad m) => Logger w m | m -> w where + emitLog :: w -> m () + getLogLevel :: m LogLevel + +newtype IOLoggerT w m a = IOLoggerT { runIOLoggerT' :: ReaderT (IOLogger w) m a } + deriving (Functor, Applicative, Monad, MonadIO, Fallible, MonadFail, Catchable) + +class Monad m => HasIOLogger w m | m -> w where + getIOLogAction :: Monad m => m (w -> IO ()) + +instance (Monoid w, MonadIO m) => HasIOLogger w (IOLoggerT w m) where + getIOLogAction = IOLoggerT $ asks ioLogAction + +instance (Monoid w, MonadIO m) => Logger w (IOLoggerT w m) where + emitLog w = do + logger <- getIOLogAction + liftIO $ logger w + getLogLevel = IOLoggerT $ asks ioLogLevel + +getIOLogger :: (HasIOLogger w m, Logger w m) => m (IOLogger w) +getIOLogger = IOLogger <$> getLogLevel <*> getIOLogAction + +runIOLoggerT :: (Monoid w, MonadIO m) => LogLevel -> (w -> IO ()) -> IOLoggerT w m a -> m a +runIOLoggerT logLevel write cont = runReaderT (runIOLoggerT' cont) (IOLogger logLevel write) + +newtype LoggerT w m a = LoggerT { runLoggerT' :: WriterT w m a } + deriving (Functor, Applicative, Monad, MonadIO) + +instance (Monoid w, Monad m) => Logger w (LoggerT w m) where + emitLog w = LoggerT $ tell w + getLogLevel = return NormalLogLevel + +runLoggerT :: (Monoid w, Monad m) => LoggerT w m a -> m (a, w) +runLoggerT cont = runWriterT (runLoggerT' cont) diff --git a/src/lib/PPrint.hs b/src/lib/PPrint.hs index 361393989..0344bd861 100644 --- a/src/lib/PPrint.hs +++ b/src/lib/PPrint.hs @@ -9,16 +9,13 @@ {-# OPTIONS_GHC -Wno-orphans #-} module PPrint ( - pprint, pprintCanonicalized, pprintList, asStr , atPrec, resultAsJSON, - PrettyPrec(..), PrecedenceLevel (..), prettyBlock, printLitBlock, - printResult, prettyFromPrettyPrec) where + pprint, pprintCanonicalized, pprintList, asStr , atPrec, + PrettyPrec(..), PrecedenceLevel (..), prettyBlock, + printOutput, prettyFromPrettyPrec) where -import Data.Aeson hiding (Result, Null, Value, Success) -import Data.Aeson.Encoding (encodingToLazyByteString, value) import GHC.Exts (Constraint) import GHC.Float import Data.Foldable (toList, fold) -import qualified Data.ByteString.Lazy.Char8 as B import qualified Data.Map.Strict as M import Data.Text.Prettyprint.Doc.Render.Text import Data.Text.Prettyprint.Doc @@ -474,39 +471,20 @@ instance Pretty SourceBlock where Just (_, '\n') -> t _ -> t `snoc` '\n' -prettyDuration :: Double -> Doc ann -prettyDuration d = p (showFFloat (Just 3) (d * mult) "") <+> unit - where (mult, unit) = if d >= 1 then (1 , "s") - else if d >= 1e-3 then (1e3, "ms") - else if d >= 1e-6 then (1e6, "us") - else (1e9, "ns") - instance Pretty Output where - pretty (TextOut s) = pretty s - pretty (SourceInfo _) = "hello" - pretty (HtmlOut _) = "" - -- pretty (ExportedFun _ _) = "" - pretty (BenchResult name compileTime runTime stats) = - benchName <> hardline <> - "Compile time: " <> prettyDuration compileTime <> hardline <> - "Run time: " <> prettyDuration runTime <+> - (case stats of - Just (runs, _) -> - "\t" <> parens ("based on" <+> p runs <+> plural "run" "runs" runs) - Nothing -> "") - where benchName = case name of "" -> "" - _ -> "\n" <> p name - pretty (PassInfo _ s) = p s - pretty (EvalTime t _) = "Eval (s): " <+> p t - pretty (TotalTime t) = "Total (s): " <+> p t <+> " (eval + compile)" - pretty (MiscLog s) = p s - + pretty = \case + TextOut s -> pretty s + HtmlOut _ -> "" + SourceInfo _ -> "" + PassInfo _ s -> p s + MiscLog s -> p s + Error e -> p e instance Pretty PassName where pretty x = p $ show x instance Pretty Result where - pretty (Result outs r) = vcat (map pretty outs) <> maybeErr + pretty (Result (Outputs outs) r) = vcat (map pretty outs) <> maybeErr where maybeErr = case r of Failure err -> p err Success () -> mempty @@ -992,17 +970,10 @@ instance Pretty RWS where Writer -> "Accum" State -> "State" -printLitBlock :: Pretty block => Bool -> block -> Result -> String -printLitBlock isatty block result = pprint block ++ printResult isatty result - -printResult :: Bool -> Result -> String -printResult isatty (Result outs errs) = - concat (map printOutput outs) ++ case errs of - Success () -> "" - Failure err -> addColor isatty Red $ addPrefix ">" $ pprint err - where - printOutput :: Output -> String - printOutput out = addPrefix (addColor isatty Cyan ">") $ pprint $ out +printOutput :: Bool -> Output -> String +printOutput isatty out = case out of + Error _ -> addColor isatty Red $ addPrefix ">" $ pprint out + _ -> addPrefix (addColor isatty Cyan ">") $ pprint $ out addPrefix :: String -> String -> String addPrefix prefix str = unlines $ map prefixLine $ lines str @@ -1016,20 +987,6 @@ addColor True c s = setSGRCode [SetConsoleIntensity BoldIntensity, SetColor Foreground Vivid c] ++ s ++ setSGRCode [Reset] -resultAsJSON :: Result -> String -resultAsJSON (Result outs err) = - B.unpack $ encodingToLazyByteString $ value $ object (outMaps <> errMaps) - where - errMaps = case err of - Failure e -> ["error" .= String (fromString $ pprint e)] - Success () -> [] - outMaps = flip foldMap outs $ \case - BenchResult name compileTime runTime _ -> - [ "bench_name" .= toJSON name - , "compile_time" .= toJSON compileTime - , "run_time" .= toJSON runTime ] - out -> ["result" .= String (fromString $ pprint out)] - -- === Concrete syntax rendering === instance Pretty SourceBlock' where diff --git a/src/lib/RenderHtml.hs b/src/lib/RenderHtml.hs index 04e35a1bd..3cbc25c86 100644 --- a/src/lib/RenderHtml.hs +++ b/src/lib/RenderHtml.hs @@ -8,8 +8,8 @@ {-# OPTIONS_GHC -Wno-incomplete-patterns #-} module RenderHtml ( - progHtml, ToMarkup, renderSourceBlock, renderResult, - RenderedSourceBlock, RenderedResult) where + progHtml, ToMarkup, renderSourceBlock, renderOutputs, + RenderedSourceBlock, RenderedOutputs) where import Text.Blaze.Internal (MarkupM) import Text.Blaze.Html5 as H hiding (map, b) @@ -36,7 +36,7 @@ import Util (unsnoc, foldJusts) -- === rendering results === --- RenderedResult, RenderedSourceBlock aren't 100% HTML themselves but the idea +-- RenderedOutputs, RenderedSourceBlock aren't 100% HTML themselves but the idea -- is that they should be trivially convertable to JSON and sent over to the -- client which can do the final rendering without much code or runtime work. @@ -48,14 +48,14 @@ data RenderedSourceBlock = RenderedSourceBlock , rsbHtml :: String } deriving (Generic) -data RenderedResult = RenderedResult +data RenderedOutputs = RenderedOutputs { rrHtml :: String , rrHighlightMap :: HighlightMap , rrHoverInfoMap :: HoverInfoMap } deriving (Generic) -renderResult :: Result -> RenderedResult -renderResult r = RenderedResult +renderOutputs :: Outputs -> RenderedOutputs +renderOutputs r = RenderedOutputs { rrHtml = pprintHtml r , rrHighlightMap = computeHighlights r , rrHoverInfoMap = computeHoverInfo r } @@ -70,19 +70,17 @@ renderSourceBlock n b = RenderedSourceBlock _ -> renderSpans n (sbLexemeInfo b) (sbText b) } -instance ToMarkup Result where - toMarkup (Result outs err) = foldMap toMarkup outs <> err' - where err' = case err of - Failure e -> cdiv "err-block" $ toHtml $ pprint e - Success () -> mempty +instance ToMarkup Outputs where + toMarkup (Outputs outs) = foldMap toMarkup outs instance ToMarkup Output where toMarkup out = case out of HtmlOut s -> preEscapedString s SourceInfo _ -> mempty + Error _ -> cdiv "err-block" $ toHtml $ pprint out _ -> cdiv "result-block" $ toHtml $ pprint out -instance ToJSON RenderedResult +instance ToJSON RenderedOutputs instance ToJSON RenderedSourceBlock -- === textual information on hover === @@ -90,9 +88,9 @@ instance ToJSON RenderedSourceBlock type HoverInfo = String newtype HoverInfoMap = HoverInfoMap (M.Map LexemeId HoverInfo) deriving (ToJSON, Semigroup, Monoid) -computeHoverInfo :: Result -> HoverInfoMap -computeHoverInfo result = do - let typeInfo = foldJusts (resultOutputs result) \case +computeHoverInfo :: Outputs -> HoverInfoMap +computeHoverInfo (Outputs outputs) = do + let typeInfo = foldJusts outputs \case SourceInfo (SITypeInfo m) -> Just m _ -> Nothing HoverInfoMap $ fromTypeInfo typeInfo @@ -106,9 +104,9 @@ data HighlightType = HighlightGroup | HighlightLeaf deriving Generic instance ToJSON HighlightType -computeHighlights :: Result -> HighlightMap -computeHighlights result = do - execWriter $ mapM go $ foldJusts (resultOutputs result) \case +computeHighlights :: Outputs -> HighlightMap +computeHighlights (Outputs outputs) = do + execWriter $ mapM go $ foldJusts outputs \case SourceInfo (SIGroupTree t) -> Just t _ -> Nothing where @@ -142,7 +140,7 @@ pprintHtml x = renderHtml $ toMarkup x progHtml :: (ToMarkup a, ToMarkup b) => [(a, b)] -> String progHtml blocks = renderHtml $ wrapBody $ map toHtmlBlock blocks - where toHtmlBlock (block,result) = toMarkup block <> toMarkup result + where toHtmlBlock (block,outputs) = toMarkup block <> toMarkup outputs wrapBody :: [Html] -> Html wrapBody blocks = docTypeHtml $ do diff --git a/src/lib/Runtime.hs b/src/lib/Runtime.hs index de00c78c7..1bac0c11c 100644 --- a/src/lib/Runtime.hs +++ b/src/lib/Runtime.hs @@ -24,13 +24,10 @@ import Control.Monad import Control.Concurrent import Control.Exception hiding (throw) import qualified Control.Exception as E -import qualified System.Environment as E import Err -import Logging -import Util (measureSeconds) +import MonadUtil import PPrint () -import CUDA (synchronizeCUDA) import Types.Core hiding (DexDestructor) import Types.Source hiding (CInt) @@ -55,75 +52,44 @@ data BenchRequirement = data LLVMCallable = LLVMCallable { nativeFun :: NativeFunction - , benchRequired :: BenchRequirement , logger :: PassLogger - , resultTypes :: [BaseType] - } + , resultTypes :: [BaseType] } -- The NativeFunction needs to have been compiled with EntryFunCC. callEntryFun :: LLVMCallable -> [LitVal] -> IO [LitVal] -callEntryFun LLVMCallable{nativeFun, benchRequired, logger, resultTypes} args = do +callEntryFun LLVMCallable{nativeFun, logger, resultTypes} args = do withPipeToLogger logger \fd -> allocaCells (length args) \argsPtr -> allocaCells (length resultTypes) \resultPtr -> do storeLitVals argsPtr args let fPtr = castFunPtr $ nativeFunPtr nativeFun - evalTime <- checkedCallFunPtr fd argsPtr resultPtr fPtr + checkedCallFunPtr fd argsPtr resultPtr fPtr results <- loadLitVals resultPtr resultTypes - case benchRequired of - NoBench -> logSkippingFilter logger [EvalTime evalTime Nothing] - DoBench shouldSyncCUDA -> do - let sync = when shouldSyncCUDA $ synchronizeCUDA - (avgTime, benchRuns, totalTime) <- runBench do - let (CInt fd') = fdFD fd - exitCode <- callFunPtr fPtr fd' argsPtr resultPtr - unless (exitCode == 0) $ throw RuntimeErr "" - freeLitVals resultPtr resultTypes - sync - logSkippingFilter logger [EvalTime avgTime (Just (benchRuns, totalTime + evalTime))] return results {-# SCC callEntryFun #-} -checkedCallFunPtr :: FD -> Ptr () -> Ptr () -> DexExecutable -> IO Double +checkedCallFunPtr :: FD -> Ptr () -> Ptr () -> DexExecutable -> IO () checkedCallFunPtr fd argsPtr resultPtr fPtr = do let (CInt fd') = fdFD fd - (exitCode, duration) <- measureSeconds $ do - exitCode <- callFunPtr fPtr fd' argsPtr resultPtr - return exitCode + exitCode <- callFunPtr fPtr fd' argsPtr resultPtr unless (exitCode == 0) $ throw RuntimeErr "" - return duration withPipeToLogger :: PassLogger -> (FD -> IO a) -> IO a withPipeToLogger logger writeAction = do result <- snd <$> withPipe - (\h -> readStream h \s -> logSkippingFilter logger [TextOut s]) + (\h -> readStream h \s -> ioLogAction logger $ Outputs [TextOut s]) (\h -> handleToFd h >>= writeAction) case result of Left e -> E.throw e Right ans -> return ans -runBench :: IO () -> IO (Double, Int, Double) -runBench run = do - exampleDuration <- snd <$> measureSeconds run - test_mode <- (Just "t" ==) <$> E.lookupEnv "DEX_TEST_MODE" - let timeBudget = (2 - exampleDuration) `max` 0 -- seconds - let benchRuns = if test_mode - then 0 - else (ceiling $ timeBudget / exampleDuration) :: Int - totalTime' <- liftM snd $ measureSeconds $ do - forM_ [1..benchRuns] $ const run - let totalTime = totalTime' + exampleDuration - avgTime = totalTime / (fromIntegral $ benchRuns + 1) - - return (avgTime, benchRuns + 1, totalTime) - -- === serializing scalars === loadLitVals :: MonadIO m => Ptr () -> [BaseType] -> m [LitVal] loadLitVals p types = zipWithM loadLitVal (ptrArray p) types -freeLitVals :: MonadIO m => Ptr () -> [BaseType] -> m () -freeLitVals p types = zipWithM_ freeLitVal (ptrArray p) types +_freeLitVals :: MonadIO m => Ptr () -> [BaseType] -> m () +_freeLitVals p types = zipWithM_ freeLitVal (ptrArray p) types storeLitVals :: MonadIO m => Ptr () -> [LitVal] -> m () storeLitVals p xs = zipWithM_ storeLitVal (ptrArray p) xs diff --git a/src/lib/TopLevel.hs b/src/lib/TopLevel.hs index b13d81293..78dafb62b 100644 --- a/src/lib/TopLevel.hs +++ b/src/lib/TopLevel.hs @@ -8,12 +8,12 @@ module TopLevel ( EvalConfig (..), Topper, TopperM, runTopperM, - evalSourceBlock, evalSourceBlockRepl, OptLevel (..), + evalSourceBlockRepl, OptLevel (..), evalSourceText, TopStateEx (..), LibPath (..), evalSourceBlockIO, initTopState, loadCache, storeCache, clearCache, ensureModuleLoaded, importModule, printCodegen, loadObject, toCFunction, packageLLVMCallable, - simpOptimizations, loweredOptimizations, compileTopLevelFun) where + simpOptimizations, loweredOptimizations, compileTopLevelFun, ErrorHandling (..)) where import Data.Functor import Data.Maybe (catMaybes) @@ -26,7 +26,6 @@ import Data.Text (Text) import Data.Text.Prettyprint.Doc import Data.Store (encode, decode) import Data.String (fromString) -import Data.List (partition) import qualified Data.Map.Strict as M import qualified Data.Set as S import Foreign.Ptr @@ -34,7 +33,7 @@ import Foreign.C.String import GHC.Generics (Generic (..)) import System.FilePath import System.Directory -import System.IO (stderr, hPutStrLn, Handle) +import System.IO (stderr, hPutStrLn) import System.IO.Error (isDoesNotExistError) import LLVM.Link @@ -57,8 +56,8 @@ import ImpToLLVM import IncState import Inference import Inline -import Logging import Lower +import MonadUtil import MTL1 import Subst import Name @@ -76,60 +75,49 @@ import Types.Core import Types.Imp import Types.Primitives import Types.Source -import Util ( Tree (..), measureSeconds, File (..), readFileWithHash) +import Util ( Tree (..), File (..), readFileWithHash) import Vectorize -- === top-level monad === data LibPath = LibDirectory FilePath | LibBuiltinPath +data ErrorHandling = HaltOnErr | ContinueOnErr data EvalConfig = EvalConfig { backendName :: Backend , libPaths :: [LibPath] , preludeFile :: Maybe FilePath - , logFileName :: Maybe FilePath - , logFile :: Maybe Handle , optLevel :: OptLevel - , printBackend :: PrintBackend } + , printBackend :: PrintBackend + , errorHandling :: ErrorHandling + , cfgLogLevel :: LogLevel + , cfgLogAction :: Outputs -> IO ()} class Monad m => ConfigReader m where getConfig :: m EvalConfig -data PassCtx = PassCtx - { requiresBench :: BenchRequirement - , shouldLogPass :: PassName -> Bool - } - -initPassCtx :: PassCtx -initPassCtx = PassCtx NoBench (const True) - -class Monad m => PassCtxReader m where - getPassCtx :: m PassCtx - withPassCtx :: PassCtx -> m a -> m a - class Monad m => RuntimeEnvReader m where getRuntimeEnv :: m RuntimeEnv -type TopLogger m = (MonadIO m, MonadLogger [Output] m) +type TopLogger m = (MonadIO m, Logger Outputs m) class ( forall n. Fallible (m n) - , forall n. MonadLogger [Output] (m n) + , forall n. Logger Outputs (m n) + , forall n. HasIOLogger Outputs (m n) , forall n. Catchable (m n) , forall n. ConfigReader (m n) - , forall n. PassCtxReader (m n) , forall n. RuntimeEnvReader (m n) , forall n. MonadIO (m n) -- TODO: something more restricted here , TopBuilder m ) => Topper m data TopperReaderData = TopperReaderData - { topperPassCtx :: PassCtx - , topperEvalConfig :: EvalConfig + { topperEvalConfig :: EvalConfig , topperRuntimeEnv :: RuntimeEnv } newtype TopperM (n::S) a = TopperM { runTopperM' - :: TopBuilderT (ReaderT TopperReaderData (LoggerT [Output] IO)) n a } + :: TopBuilderT (ReaderT TopperReaderData IO) n a } deriving ( Functor, Applicative, Monad, MonadIO, MonadFail , Fallible, EnvReader, ScopeReader, Catchable) @@ -147,9 +135,8 @@ runTopperM -> (forall n. Mut n => TopperM n a) -> IO (a, TopStateEx) runTopperM opts (TopStateEx env rtEnv) cont = do - let maybeLogFile = logFile opts - (Abs frag (LiftE result), _) <- runLogger maybeLogFile \l -> runLoggerT l $ - flip runReaderT (TopperReaderData initPassCtx opts rtEnv) $ + Abs frag (LiftE result) <- + flip runReaderT (TopperReaderData opts rtEnv) $ runTopBuilderT env $ runTopperM' do localTopBuilder $ LiftE <$> cont return (result, extendTopEnv env rtEnv frag) @@ -172,45 +159,42 @@ allocateDynamicVarKeyPtrs = do -- ====== evalSourceBlockIO - :: EvalConfig -> TopStateEx -> SourceBlock -> IO (Result, TopStateEx) + :: EvalConfig -> TopStateEx -> SourceBlock -> IO TopStateEx evalSourceBlockIO opts env block = - runTopperM opts env $ evalSourceBlockRepl block + liftM snd $ runTopperM opts env $ evalSourceBlockRepl block -- Used for the top-level source file (rather than imported modules) -evalSourceText - :: (Topper m, Mut n) - => Text -> (SourceBlock -> IO ()) -> (Result -> IO Bool) - -> m n [(SourceBlock, Result)] -evalSourceText source beginCallback endCallback = do - let (UModule mname deps sbs) = parseUModule Main source +evalSourceText :: (Topper m, Mut n) => Text -> (SourceBlock -> IO ()) -> m n () +evalSourceText source logSourceBlock = do + let UModule mname deps sbs = parseUModule Main source mapM_ ensureModuleLoaded deps evalSourceBlocks mname sbs where evalSourceBlocks mname = \case - [] -> return [] - (sb:rest) -> do - liftIO $ beginCallback sb - result <- evalSourceBlock mname sb - liftIO (endCallback result) >>= \case - False -> return [(sb, result)] - True -> ((sb, result):) <$> evalSourceBlocks mname rest - -catchLogsAndErrs :: (Topper m, Mut n) => m n a -> m n (Except a, [Output]) -catchLogsAndErrs m = do - maybeLogFile <- logFile <$> getConfig - runLogger maybeLogFile \l -> withLogger l $ - catchErrExcept m + [] -> return () + sb:rest -> do + liftIO $ logSourceBlock sb + evalSourceBlock mname sb >>= \case + Success () -> return () + Failure e -> do + logTop $ Error e + (errorHandling <$> getConfig) >>= \case + HaltOnErr -> return () + ContinueOnErr -> evalSourceBlocks mname rest -- Module imports have to be handled differently in the repl because we don't -- know ahead of time which modules will be needed. -evalSourceBlockRepl :: (Topper m, Mut n) => SourceBlock -> m n Result +evalSourceBlockRepl :: (Topper m, Mut n) => SourceBlock -> m n () evalSourceBlockRepl block = do case sbContents block of Misc (ImportModule name) -> do -- TODO: clear source map and synth candidates before calling this ensureModuleLoaded name _ -> return () - evalSourceBlock Main block + maybeErr <- evalSourceBlock Main block + case maybeErr of + Success () -> return () + Failure e -> logTop $ Error e -- XXX: This ensures that a module and its transitive dependencies are loaded, -- (which will require evaluating them if they're not in the cache) but it @@ -226,24 +210,18 @@ ensureModuleLoaded moduleSourceName = do {-# SCC ensureModuleLoaded #-} evalSourceBlock - :: (Topper m, Mut n) => ModuleSourceName -> SourceBlock -> m n Result + :: (Topper m, Mut n) => ModuleSourceName -> SourceBlock -> m n (Except ()) evalSourceBlock mname block = do - result <- withCompileTime do - (maybeErr, logs) <- catchLogsAndErrs do - logTop $ SourceInfo $ SIGroupTree $ OverwriteWith $ getGroupTree $ sbContents block - benchReq <- getBenchRequirement block - withPassCtx (PassCtx benchReq (passLogFilter $ sbLogLevel block)) $ - evalSourceBlock' mname block - return $ Result logs maybeErr - case resultErrs result of - Failure _ -> case sbContents block of - TopDecl decl -> do - case parseDecl decl of - Success decl' -> emitSourceMap $ uDeclErrSourceMap mname decl' - Failure _ -> return () - _ -> return () + maybeErr <- catchErrExcept do + logTop $ SourceInfo $ SIGroupTree $ OverwriteWith $ getGroupTree $ sbContents block + evalSourceBlock' mname block + case (maybeErr, sbContents block) of + (Failure _, TopDecl decl) -> do + case parseDecl decl of + Success decl' -> emitSourceMap $ uDeclErrSourceMap mname decl' + Failure _ -> return () _ -> return () - return $ filterLogs block result + return maybeErr evalSourceBlock' :: (Topper m, Mut n) => ModuleSourceName -> SourceBlock -> m n () @@ -367,12 +345,6 @@ runEnvQuery query = do return $ pprint val ++ "\n(type constructor and data constructor share the same name)" logTop $ TextOut $ "Binding:\n" ++ info -filterLogs :: SourceBlock -> Result -> Result -filterLogs block (Result outs err) = let - (logOuts, requiredOuts) = partition isLogInfo outs - outs' = requiredOuts ++ processLogs (sbLogLevel block) logOuts - in Result outs' err - -- returns a toposorted list of the module's transitive dependencies (including -- the module itself) excluding those provided in the set of already known -- modules. @@ -470,58 +442,15 @@ importModule name = do emitLocalModuleEnv $ mempty { envImportStatus = importStatus } {-# SCC importModule #-} -passLogFilter :: LogLevel -> PassName -> Bool -passLogFilter = \case - LogAll -> const True - LogNothing -> const False - LogPasses passes -> (`elem` passes) - PrintEvalTime -> const False - PrintBench _ -> const False - -processLogs :: LogLevel -> [Output] -> [Output] -processLogs logLevel logs = case logLevel of - LogAll -> logs - LogNothing -> [] - LogPasses passes -> flip filter logs \case - PassInfo pass _ | pass `elem` passes -> True - | otherwise -> False - _ -> False - PrintEvalTime -> [BenchResult "" compileTime runTime benchStats] - where (compileTime, runTime, benchStats) = timesFromLogs logs - PrintBench benchName -> [BenchResult benchName compileTime runTime benchStats] - where (compileTime, runTime, benchStats) = timesFromLogs logs - -timesFromLogs :: [Output] -> (Double, Double, Maybe BenchStats) -timesFromLogs logs = (totalTime - totalEvalTime, singleEvalTime, benchStats) - where - (totalEvalTime, singleEvalTime, benchStats) = - case [(t, stats) | EvalTime t stats <- logs] of - [] -> (0.0 , 0.0, Nothing) - [(t, stats)] -> (total, t , stats) - where total = maybe t snd stats - _ -> error "Expect at most one result" - totalTime = case [tTotal | TotalTime tTotal <- logs] of - [] -> 0.0 - [t] -> t - _ -> error "Expect at most one result" - -isLogInfo :: Output -> Bool -isLogInfo out = case out of - PassInfo _ _ -> True - MiscLog _ -> True - EvalTime _ _ -> True - TotalTime _ -> True - _ -> False - evalUType :: (Topper m, Mut n) => UType VoidS -> m n (CType n) evalUType ty = do - logTop $ PassInfo Parse $ pprint ty + logDebug $ return $ PassInfo Parse $ pprint ty renamed <- logPass RenamePass $ renameSourceNamesUExpr ty checkPass TypePass $ checkTopUType renamed evalUExpr :: (Topper m, Mut n) => UExpr VoidS -> m n (CAtom n) evalUExpr expr = do - logTop $ PassInfo Parse $ pprint expr + logDebug $ return $ PassInfo Parse $ pprint expr renamed <- logPass RenamePass $ renameSourceNamesUExpr expr typed <- checkPass TypePass $ inferTopUExpr renamed evalBlock typed @@ -564,8 +493,7 @@ loweredOptimizations lowered = do (dceTop >=> hoistLoopInvariant) whenOpt lopt \lo -> do (vo, errs) <- vectorizeLoops 64 lo - l <- getFilteredLogger - logFiltered l VectPass $ return [TextOut $ pprint errs] + logTop $ TextOut $ pprint errs checkPass VectPass $ return vo loweredOptimizationsNoDest :: Topper m => STopLam n -> m n (STopLam n) @@ -607,7 +535,7 @@ evalDictSpecializations ds = do execUDecl :: (Topper m, Mut n) => ModuleSourceName -> UTopDecl VoidS VoidS -> m n () execUDecl mname decl = do - logTop $ PassInfo Parse $ pprint decl + logDebug $ return $ PassInfo Parse $ pprint decl Abs renamedDecl sourceMap <- logPass RenamePass $ renameSourceNamesTopUDecl mname decl inferenceResult <- checkPass TypePass $ inferTopUDecl renamedDecl sourceMap @@ -679,7 +607,7 @@ linkFunObjCode objCode dyvarStores (LinktimeVals funVals ptrVals) = do toCFunction :: (Topper m, Mut n) => NameHint -> ImpFunction n -> m n (CFunction n) toCFunction nameHint impFun = do - logger <- getFilteredLogger + logger <- getIOLogger (closedImpFun, reqFuns, reqPtrNames) <- abstractLinktimeObjects impFun obj <- impToLLVM logger nameHint closedImpFun >>= compileToObjCode reqObjNames <- mapM funNameToObj reqFuns @@ -699,14 +627,13 @@ packageLLVMCallable :: forall n m. (Topper m, Mut n) => ImpFunction n -> m n LLVMCallable packageLLVMCallable impFun = do nativeFun <- toCFunction "main" impFun >>= loadObjectContent - benchRequired <- requiresBench <$> getPassCtx - logger <- getFilteredLogger + logger <- getIOLogger let IFunType _ _ resultTypes = impFunType impFun return LLVMCallable{..} compileToObjCode :: Topper m => WithCNameInterface LLVM.AST.Module -> m n FunObjCode compileToObjCode astWithNames = forM astWithNames \ast -> do - logger <- getFilteredLogger + logger <- getIOLogger opt <- getLLVMOptLevel <$> getConfig liftIO $ compileLLVM logger opt ast (cniMainFunName astWithNames) @@ -717,11 +644,6 @@ funNameToObj v = do TopFunBinding (DexTopFun _ _ (Finished impl)) -> return $ topFunObjCode impl b -> error $ "couldn't find object cache entry for " ++ pprint v ++ "\ngot:\n" ++ pprint b -withCompileTime :: MonadIO m => m Result -> m Result -withCompileTime m = do - (Result outs err, t) <- measureSeconds m - return $ Result (outs ++ [TotalTime t]) err - checkPass :: (Topper m, Pretty (e n), CheckableE r e) => PassName -> m n (e n) -> m n (e n) checkPass name cont = do @@ -729,24 +651,30 @@ checkPass name cont = do result <- cont return result #ifdef DEX_DEBUG - logTop $ MiscLog $ "Running checks" + logDebug $ return $ MiscLog $ "Running checks" checkTypes result - logTop $ MiscLog $ "Checks passed" + logDebug $ return $ MiscLog $ "Checks passed" #else - logTop $ MiscLog $ "Checks skipped (not a debug build)" + logDebug $ return $ MiscLog $ "Checks skipped (not a debug build)" #endif return result logTop :: TopLogger m => Output -> m () -logTop x = logIO [x] +logTop x = emitLog $ Outputs [x] + +logDebug :: TopLogger m => m Output -> m () +logDebug m = getLogLevel >>= \case + NormalLogLevel -> return () + DebugLogLevel -> do + x <- m + emitLog $ Outputs [x] logPass :: Topper m => Pretty a => PassName -> m n a -> m n a logPass passName cont = do - logTop $ PassInfo passName $ "=== " <> pprint passName <> " ===" - logTop $ MiscLog $ "Starting "++ pprint passName + logDebug $ return $ PassInfo passName $ "=== " <> pprint passName <> " ===" + logDebug $ return $ MiscLog $ "Starting "++ pprint passName result <- cont - {-# SCC logPassPrinting #-} logTop $ PassInfo passName - $ "=== Result ===\n" <> pprint result + logDebug $ return $ PassInfo passName $ "=== Result ===\n" <> pprint result return result loadModuleSource @@ -777,15 +705,6 @@ loadModuleSource config moduleName = do LibDirectory dir -> return dir {-# SCC loadModuleSource #-} -getBenchRequirement :: Topper m => SourceBlock -> m n BenchRequirement -getBenchRequirement block = case sbLogLevel block of - PrintBench _ -> do - backend <- backendName <$> getConfig - let needsSync = case backend of LLVMCUDA -> True - _ -> False - return $ DoBench needsSync - _ -> return NoBench - getDexString :: (MonadIO1 m, EnvReader m, Fallible1 m) => Val CoreIR n -> m n String getDexString val = do -- TODO: use a `ByteString` instead of `String` @@ -868,20 +787,8 @@ restorePtrSnapshots s = traverseBindingsTopStateEx s \case PtrBinding ty p -> liftIO $ PtrBinding ty <$> restorePtrSnapshot p b -> return b -getFilteredLogger :: Topper m => m n PassLogger -getFilteredLogger = do - shouldLog <- shouldLogPass <$> getPassCtx - logger <- getLogger - return $ FilteredLogger shouldLog logger - -- === instances === -instance PassCtxReader (TopperM n) where - getPassCtx = TopperM $ asks topperPassCtx - withPassCtx ctx cont = TopperM $ - liftTopBuilderTWith (local \r -> r {topperPassCtx = ctx}) $ - runTopperM' cont - instance RuntimeEnvReader (TopperM n) where getRuntimeEnv = TopperM $ asks topperRuntimeEnv @@ -904,10 +811,14 @@ instance TopBuilder TopperM where emitNamelessEnv env = TopperM $ emitNamelessEnv env localTopBuilder cont = TopperM $ localTopBuilder $ runTopperM' cont -instance MonadLogger [Output] (TopperM n) where - getLogger = TopperM $ lift1 $ lift $ getLogger - withLogger l cont = - TopperM $ liftTopBuilderTWith (withLogger l) (runTopperM' cont) +instance Logger Outputs (TopperM n) where + emitLog x = do + logger <- getIOLogAction + liftIO $ logger x + getLogLevel = cfgLogLevel <$> getConfig + +instance HasIOLogger Outputs (TopperM n) where + getIOLogAction = cfgLogAction <$> getConfig instance Generic TopStateEx where type Rep TopStateEx = Rep (Env UnsafeS, RuntimeEnv) diff --git a/src/lib/Types/Source.hs b/src/lib/Types/Source.hs index 6b432474e..4249b9a92 100644 --- a/src/lib/Types/Source.hs +++ b/src/lib/Types/Source.hs @@ -33,10 +33,10 @@ import GHC.Generics (Generic (..)) import Data.Store (Store (..)) import Err -import Logging import Name import qualified Types.OpNames as P import IRVariants +import MonadUtil import Util (File (..), SnocList) import IncState @@ -108,7 +108,7 @@ newtype TypeInfo = TypeInfo { fromTypeInfo :: M.Map SrcId String } type LitProg = [(SourceBlock, Result)] data Result = Result - { resultOutputs :: [Output] + { resultOutputs :: Outputs , resultErrs :: Except () } deriving (Show, Eq) @@ -124,14 +124,13 @@ data Output = | HtmlOut String | SourceInfo SourceInfo -- for hovertips etc | PassInfo PassName String - | EvalTime Double (Maybe BenchStats) - | TotalTime Double - | BenchResult String Double Double (Maybe BenchStats) -- name, compile time, eval time | MiscLog String - -- Used to have | ExportedFun String Atom + | Error Err deriving (Show, Eq, Generic) +newtype Outputs = Outputs { fromOutputs :: [Output] } + deriving (Show, Eq, Generic, Semigroup, Monoid) -type PassLogger = FilteredLogger PassName [Output] +type PassLogger = IOLogger Outputs data OptLevel = NoOptimize | Optimize @@ -566,7 +565,6 @@ data UModule = UModule data SourceBlock = SourceBlock { sbLine :: Int , sbOffset :: Int - , sbLogLevel :: LogLevel , sbText :: Text , sbLexemeInfo :: LexemeInfo , sbContents :: SourceBlock' } @@ -598,10 +596,6 @@ data SourceBlockMisc data CmdName = GetType | EvalExpr OutFormat | ExportFun String deriving (Show, Generic) -data LogLevel = LogNothing | PrintEvalTime | PrintBench String - | LogPasses [PassName] | LogAll - deriving (Show, Generic) - data PrintBackend = PrintCodegen -- Soon-to-be default path based on `PrintAny` | PrintHaskell -- Backup path for debugging in case the codegen path breaks. @@ -836,7 +830,6 @@ instance Ord SourceBlock where compare x y = compare (sbText x) (sbText y) instance Store SymbolicZeros -instance Store LogLevel instance Store PassName instance Store ModuleSourceName instance Store (UVar n) From 81007ebfb6f42a7b5a6015601ea0065b4c404ed7 Mon Sep 17 00:00:00 2001 From: Dougal Date: Fri, 1 Dec 2023 22:11:39 -0500 Subject: [PATCH 32/41] Add an ExceptT monad transformer --- dex.cabal | 1 - src/lib/Cat.hs | 178 ------------------------------------------------- src/lib/Err.hs | 60 ++++++++++++++++- 3 files changed, 59 insertions(+), 180 deletions(-) delete mode 100644 src/lib/Cat.hs diff --git a/dex.cabal b/dex.cabal index 74a4eacf9..faa4e372e 100644 --- a/dex.cabal +++ b/dex.cabal @@ -48,7 +48,6 @@ library , Algebra , Builder , CUDA - , Cat , CheapReduction , CheckType , ConcreteSyntax diff --git a/src/lib/Cat.hs b/src/lib/Cat.hs deleted file mode 100644 index 7c1e49132..000000000 --- a/src/lib/Cat.hs +++ /dev/null @@ -1,178 +0,0 @@ --- Copyright 2019 Google LLC --- --- Use of this source code is governed by a BSD-style --- license that can be found in the LICENSE file or at --- https://developers.google.com/open-source/licenses/bsd - -{-# LANGUAGE UndecidableInstances #-} - -module Cat (CatT, MonadCat, runCatT, look, extend, scoped, looks, extendLocal, - extendR, captureW, asFst, asSnd, capture, asCat, evalCatT, evalCat, - Cat, runCat, newCatT, catTraverse, evalScoped, execCat, execCatT, - catFold, catFoldM, catMap, catMapM) where - --- Monad for tracking monoidal state - -import Control.Applicative -import Control.Monad.State.Strict -import Control.Monad.Reader -import Control.Monad.Writer -import Control.Monad.Identity -import Control.Monad.Except hiding (Except) - -import Err - -newtype CatT env m a = CatT (StateT (env, env) m a) - deriving (Functor, Applicative, Monad, MonadTrans, MonadIO, MonadFail, Alternative, - Fallible) - -type Cat env = CatT env Identity - -class (Monoid env, Monad m) => MonadCat env m | m -> env where - look :: m env - extend :: env -> m () - scoped :: m a -> m (a, env) - -instance (Monoid env, Monad m) => MonadCat env (CatT env m) where - look = CatT $ gets fst - extend x = CatT $ do - (fullState, localState) <- get - put (fullState <> x, localState <> x) - scoped (CatT m) = CatT $ do - originalState <- get - put (fst originalState, mempty) - ans <- m - newLocalState <- gets snd - put originalState - return (ans, newLocalState) - -instance MonadCat env m => MonadCat env (StateT s m) where - look = lift look - extend x = lift $ extend x - scoped m = StateT \s -> do - ((ans, s'), env) <- scoped $ runStateT m s - return $ ((ans, env), s') - -instance MonadCat env m => MonadCat env (ReaderT r m) where - look = lift look - extend x = lift $ extend x - scoped m = do r <- ask - lift $ scoped $ runReaderT m r - -instance (Monoid w, MonadCat env m) => MonadCat env (WriterT w m) where - look = lift look - extend x = lift $ extend x - scoped m = do ((x, w), env) <- lift $ scoped $ runWriterT m - tell w - return (x, env) - -instance MonadCat env m => MonadCat env (ExceptT e m) where - look = lift look - extend x = lift $ extend x - scoped m = do (xerr, env) <- lift $ scoped $ runExceptT m - case xerr of - Left err -> throwError err - Right x -> return (x, env) - -instance (Monoid env, MonadReader r m) => MonadReader r (CatT env m) where - ask = lift ask - local f m = do - env <- look - (ans, env') <- lift $ local f $ runCatT m env - extend env' - return ans - -runCatT :: (Monoid env, Monad m) => CatT env m a -> env -> m (a, env) -runCatT (CatT m) initEnv = do - (ans, (_, newEnv)) <- runStateT m (initEnv, mempty) - return (ans, newEnv) - -evalCatT :: (Monoid env, Monad m) => CatT env m a -> m a -evalCatT m = fst <$> runCatT m mempty - -execCatT :: (Monoid env, Monad m) => CatT env m a -> m env -execCatT m = snd <$> runCatT m mempty - -newCatT :: (Monoid env, Monad m) => (env -> m (a, env)) -> CatT env m a -newCatT f = do - env <- look - (ans, env') <- lift $ f env - extend env' - return ans - -runCat :: Monoid env => Cat env a -> env -> (a, env) -runCat m env = runIdentity $ runCatT m env - -evalCat :: Monoid env => Cat env a -> a -evalCat m = runIdentity $ evalCatT m - -execCat :: Monoid env => Cat env a -> env -execCat m = runIdentity $ execCatT m - -looks :: (Monoid env, MonadCat env m) => (env -> a) -> m a -looks getter = liftM getter look - -evalScoped :: Monoid env => Cat env a -> Cat env a -evalScoped m = fst <$> scoped m - -capture :: (Monoid env, MonadCat env m) => m a -> m (a, env) -capture m = do - (x, env) <- scoped m - extend env - return (x, env) - -extendLocal :: (Monoid env, MonadCat env m) => env -> m a -> m a -extendLocal x m = do - ((ans, env), _) <- scoped $ do extend x - scoped m - extend env - return ans - --- Not part of the cat monad, but related utils for monoidal envs - -catTraverse :: (Monoid menv, MonadReader env m, Traversable f) - => (a -> m (b, menv)) -> (menv -> env) -> f a -> menv -> m (f b, menv) -catTraverse f inj xs env = runCatT (traverse (asCat f inj) xs) env - -catFoldM :: (Monoid env, Traversable t, Monad m) - => (env -> a -> m env) -> env -> t a -> m env -catFoldM f env xs = liftM snd $ flip runCatT env $ forM_ xs \x -> do - cur <- look - new <- lift $ f cur x - extend new - -catFold :: (Monoid env, Traversable t) - => (env -> a -> env) -> env -> t a -> env -catFold f env xs = runIdentity $ catFoldM (\e x -> Identity $ f e x) env xs - -catMapM :: (Monoid env, Traversable t, Monad m) - => (env -> a -> m (b, env)) -> env -> t a -> m (t b, env) -catMapM f env xs = flip runCatT env $ forM xs \x -> do - cur <- look - (y, new) <- lift $ f cur x - extend new - return y - -catMap :: (Monoid env, Traversable t) - => (env -> a -> (b, env)) -> env -> t a -> (t b, env) -catMap f env xs = runIdentity $ catMapM (\e x -> Identity $ f e x) env xs - -asCat :: (Monoid menv, MonadReader env m) - => (a -> m (b, menv)) -> (menv -> env) -> a -> CatT menv m b -asCat f inj x = do - env' <- look - (x', env'') <- lift $ local (const $ inj env') (f x) - extend env'' - return x' - -extendR :: (Monoid env, MonadReader env m) => env -> m a -> m a -extendR x m = local (<> x) m - -asFst :: Monoid b => a -> (a, b) -asFst x = (x, mempty) - -asSnd :: Monoid a => b -> (a, b) -asSnd y = (mempty, y) - -captureW :: MonadWriter w m => m a -> m (a, w) -captureW m = censor (const mempty) (listen m) diff --git a/src/lib/Err.hs b/src/lib/Err.hs index 47e279ee2..d0ad6c9da 100644 --- a/src/lib/Err.hs +++ b/src/lib/Err.hs @@ -4,12 +4,15 @@ -- license that can be found in the LICENSE file or at -- https://developers.google.com/open-source/licenses/bsd +{-# LANGUAGE UndecidableInstances #-} + module Err (Err (..), ErrType (..), Except (..), Fallible (..), Catchable (..), catchErrExcept, HardFailM (..), runHardFail, throw, catchIOExcept, liftExcept, liftExceptAlt, assertEq, ignoreExcept, - pprint, docAsStr, getCurrentCallStack, printCurrentCallStack + pprint, docAsStr, getCurrentCallStack, printCurrentCallStack, + ExceptT (..) ) where import Control.Exception hiding (throw) @@ -80,6 +83,61 @@ instance Catchable IO where Success result -> return result Failure errs -> handler errs +-- === ExceptT type === + +newtype ExceptT m a = ExceptT { runExceptT :: m (Except a) } + +instance Monad m => Functor (ExceptT m) where + fmap = liftM + {-# INLINE fmap #-} + +instance Monad m => Applicative (ExceptT m) where + pure = return + {-# INLINE pure #-} + liftA2 = liftM2 + {-# INLINE liftA2 #-} + +instance Monad m => Monad (ExceptT m) where + return x = ExceptT $ return (Success x) + {-# INLINE return #-} + m >>= f = ExceptT $ runExceptT m >>= \case + Failure errs -> return $ Failure errs + Success x -> runExceptT $ f x + {-# INLINE (>>=) #-} + +instance Monad m => MonadFail (ExceptT m) where + fail s = ExceptT $ return $ Failure $ Err SearchFailure s + {-# INLINE fail #-} + +instance Monad m => Fallible (ExceptT m) where + throwErr errs = ExceptT $ return $ Failure errs + {-# INLINE throwErr #-} + +instance Monad m => Alternative (ExceptT m) where + empty = throw SearchFailure "" + {-# INLINE empty #-} + m1 <|> m2 = do + catchSearchFailure m1 >>= \case + Nothing -> m2 + Just x -> return x + {-# INLINE (<|>) #-} + +instance Monad m => Catchable (ExceptT m) where + m `catchErr` f = ExceptT $ runExceptT m >>= \case + Failure errs -> runExceptT $ f errs + Success x -> return $ Success x + {-# INLINE catchErr #-} + +instance MonadState s m => MonadState s (ExceptT m) where + get = lift get + {-# INLINE get #-} + put x = lift $ put x + {-# INLINE put #-} + +instance MonadTrans ExceptT where + lift m = ExceptT $ Success <$> m + {-# INLINE lift #-} + -- === Except type === -- Except is isomorphic to `Either Err` but having a distinct type makes it From c23ef4faa90354931aa05fc38c0f8aede873347a Mon Sep 17 00:00:00 2001 From: Dougal Date: Fri, 1 Dec 2023 22:11:51 -0500 Subject: [PATCH 33/41] Show types on hover! --- src/lib/Inference.hs | 40 ++++++++++++++++++++++++++++------------ static/index.js | 4 +++- static/style.css | 2 +- 3 files changed, 32 insertions(+), 14 deletions(-) diff --git a/src/lib/Inference.hs b/src/lib/Inference.hs index bd1964d15..f6fb4d926 100644 --- a/src/lib/Inference.hs +++ b/src/lib/Inference.hs @@ -34,6 +34,7 @@ import CheckType import Core import Err import IRVariants +import MonadUtil import MTL1 import Name import Subst @@ -47,11 +48,11 @@ import Util hiding (group) -- === Top-level interface === -checkTopUType :: (Fallible1 m, EnvReader m) => UType n -> m n (CType n) +checkTopUType :: (Fallible1 m, TopLogger m, EnvReader m) => UType n -> m n (CType n) checkTopUType ty = liftInfererM $ checkUType ty {-# SCC checkTopUType #-} -inferTopUExpr :: (Fallible1 m, EnvReader m) => UExpr n -> m n (TopBlock CoreIR n) +inferTopUExpr :: (Fallible1 m, TopLogger m, EnvReader m) => UExpr n -> m n (TopBlock CoreIR n) inferTopUExpr e = fst <$> (asTopBlock =<< liftInfererM (buildBlock $ bottomUp e)) {-# SCC inferTopUExpr #-} @@ -60,7 +61,9 @@ data UDeclInferenceResult e n = | UDeclResultBindName LetAnn (TopBlock CoreIR n) (Abs (UBinder (AtomNameC CoreIR)) e n) | UDeclResultBindPattern NameHint (TopBlock CoreIR n) (ReconAbs CoreIR e n) -inferTopUDecl :: (Mut n, Fallible1 m, TopBuilder m, HasNamesE e) +type TopLogger (m::MonadKind1) = forall n. Logger Outputs (m n) + +inferTopUDecl :: (Mut n, Fallible1 m, TopBuilder m, HasNamesE e, TopLogger m) => UTopDecl n l -> e l -> m n (UDeclInferenceResult e n) inferTopUDecl (UStructDecl tc def) result = do tc' <- emitBinding (getNameHint tc) $ TyConBinding Nothing (DotMethods mempty) @@ -153,7 +156,7 @@ data InfState (n::S) = InfState , infEffects :: EffectRow CoreIR n } newtype InfererM (i::S) (o::S) (a:: *) = InfererM - { runInfererM' :: SubstReaderT Name (ReaderT1 InfState (BuilderT CoreIR Except)) i o a } + { runInfererM' :: SubstReaderT Name (ReaderT1 InfState (BuilderT CoreIR (ExceptT (State TypeInfo)))) i o a } deriving (Functor, Applicative, Monad, MonadFail, Alternative, Builder CoreIR, EnvExtender, ScopableBuilder CoreIR, ScopeReader, EnvReader, Fallible, Catchable, SubstReader Name) @@ -161,14 +164,21 @@ newtype InfererM (i::S) (o::S) (a:: *) = InfererM type InfererCPSB b i o a = (forall o'. DExt o o' => b o o' -> InfererM i o' a) -> InfererM i o a type InfererCPSB2 b i i' o a = (forall o'. DExt o o' => b o o' -> InfererM i' o' a) -> InfererM i o a -liftInfererM :: (EnvReader m, Fallible1 m) => InfererM n n a -> m n a +liftInfererM :: (EnvReader m, TopLogger m, Fallible1 m) => InfererM n n a -> m n a liftInfererM cont = do + (ansExcept, typeInfo) <- liftInfererMPure cont + emitLog $ Outputs [SourceInfo $ SITypeInfo typeInfo] + liftExcept ansExcept +{-# INLINE liftInfererM #-} + +liftInfererMPure :: EnvReader m => InfererM n n a -> m n (Except a, TypeInfo) +liftInfererMPure cont = do ansM <- liftBuilderT $ runReaderT1 emptyInfState $ runSubstReaderT idSubst $ runInfererM' cont - liftExcept ansM + return $ runState (runExceptT ansM) mempty where emptyInfState :: InfState n emptyInfState = InfState (Givens HM.empty) Pure -{-# INLINE liftInfererM #-} +{-# INLINE liftInfererMPure #-} -- === Solver monad === @@ -346,6 +356,11 @@ withAllowedEffects :: EffectRow CoreIR o -> InfererM i o a -> InfererM i o a withAllowedEffects effs cont = withInfState (\(InfState g _) -> InfState g effs) cont {-# INLINE withAllowedEffects #-} +emitTypeInfo :: SrcId -> String -> InfererM i o () +emitTypeInfo sid ty = do + InfererM $ liftSubstReaderT $ lift11 $ lift1 $ lift do + modify \(TypeInfo m) -> TypeInfo $ M.insert sid ty m + -- === actual inference pass === data RequiredTy (n::S) = @@ -461,10 +476,11 @@ bottomUp expr = bottomUpExplicit expr >>= instantiateSigma Infer -- Doesn't instantiate implicit args bottomUpExplicit :: Emits o => UExpr i -> InfererM i o (SigmaAtom o) -bottomUpExplicit (WithSrcE _ expr) = case expr of +bottomUpExplicit (WithSrcE sid expr) = case expr of UVar ~(InternalName _ sn v) -> do v' <- renameM v ty <- getUVarType v' + emitTypeInfo sid $ pprint sn ++ " : " ++ pprint ty return $ SigmaUVar sn ty v' ULit l -> return $ SigmaAtom Nothing $ Con $ Lit l UFieldAccess x (WithSrc _ field) -> do @@ -2012,7 +2028,7 @@ makeStructRepVal tyConName args = do -- shortcut of just generalizing the data parameters. generalizeDict :: EnvReader m => CType n -> CDict n -> m n (CDict n) generalizeDict ty dict = do - result <- liftEnvReaderT $ liftInfererM $ generalizeDictRec ty dict + result <- liftEnvReaderM $ liftM fst $ liftInfererMPure $ generalizeDictRec ty dict case result of Failure e -> error $ "Failed to generalize " ++ pprint dict ++ " to " ++ show ty ++ " because " ++ pprint e @@ -2349,10 +2365,10 @@ checkScalarOrPairType ty = throw TypeErr $ pprint ty instance DiffStateE SolverSubst SolverDiff where updateDiffStateE :: forall n. Distinct n => Env n -> SolverSubst n -> SolverDiff n -> SolverSubst n - updateDiffStateE _ initState (SolverDiff (RListE diffs)) = foldl update initState (unsnoc diffs) + updateDiffStateE _ initState (SolverDiff (RListE diffs)) = foldl update' initState (unsnoc diffs) where - update :: Distinct n => SolverSubst n -> Solution n -> SolverSubst n - update (SolverSubst subst) (PairE v x) = SolverSubst $ M.insert v x subst + update' :: Distinct n => SolverSubst n -> Solution n -> SolverSubst n + update' (SolverSubst subst) (PairE v x) = SolverSubst $ M.insert v x subst instance SinkableE InfState where sinkingProofE _ = todoSinkableProof diff --git a/static/index.js b/static/index.js index 5a0e0693f..4a9044db1 100644 --- a/static/index.js +++ b/static/index.js @@ -73,7 +73,9 @@ function applyHover(cellId, srcId) { } function applyHoverInfo(cellId, srcId) { let hoverInfo = lookupSrcMap(hoverInfoMap, cellId, srcId) - hoverInfoDiv.innerHTML = srcId.toString() // hoverInfo + if (hoverInfo !== undefined) { + hoverInfoDiv.innerHTML = hoverInfo + } } function applyHoverHighlights(cellId, srcId) { let highlights = lookupSrcMap(highlightMap, cellId, srcId) diff --git a/static/style.css b/static/style.css index 5c83bb2c1..e9f78a6d9 100644 --- a/static/style.css +++ b/static/style.css @@ -16,7 +16,7 @@ body { #hover-info { position: fixed; - height: 10rem; + height: 5rem; bottom: 0em; width: 100vw; overflow: hidden; From 6f62fb4cb19fa9327eaeee4a592afcabcb7f45f2 Mon Sep 17 00:00:00 2001 From: Dougal Date: Sat, 2 Dec 2023 16:56:30 -0500 Subject: [PATCH 34/41] Update prelude to use sugarfree versions of RangeX types --- lib/prelude.dx | 36 ++++++++++++++++-------------------- 1 file changed, 16 insertions(+), 20 deletions(-) diff --git a/lib/prelude.dx b/lib/prelude.dx index badf5e39e..be288660a 100644 --- a/lib/prelude.dx +++ b/lib/prelude.dx @@ -333,17 +333,13 @@ def unsafe_nat_diff(x:Nat, y:Nat) -> Nat = y' = nat_to_rep y rep_to_nat %isub(x', y') --- `(i..)` parses as `RangeFrom(i)` -- TODO: need to a way to indicate constructor as private struct RangeFrom(i:q) given (q:Type) = val : Nat --- `(i<..)` parses as `RangeFromExc i` struct RangeFromExc(i:q) given (q:Type) = val : Nat --- `(..i)` parses as `RangeTo i` struct RangeTo(i:q) given (q:Type) = val : Nat --- `(..a) given (a|Add, n|Ix) instance Sub(n=>a) given (a|Sub, n|Ix) def (-)(xs, ys) = for i. xs[i] - ys[i] -instance Add((i:n) => (i..) => a) given (a|Add, n|Ix) -- Upper triangular tables +instance Add((i:n) => RangeFrom i => a) given (a|Add, n|Ix) -- Upper triangular tables def (+)(xs, ys) = for i. xs[i] + ys[i] zero = for _. zero -instance Sub((i:n) => (i..) => a) given (a|Sub, n|Ix) -- Upper triangular tables +instance Sub((i:n) => RangeFrom i => a) given (a|Sub, n|Ix) -- Upper triangular tables def (-)(xs, ys) = for i. xs[i] - ys[i] -instance Add((i:n) => (..i) => a) given (a|Add, n|Ix) -- Lower triangular tables +instance Add((i:n) => RangeTo i => a) given (a|Add, n|Ix) -- Lower triangular tables def (+)(xs, ys) = for i. xs[i] + ys[i] zero = for _. zero -instance Sub((i:n) => (..i) => a) given (a|Sub, n|Ix) -- Lower triangular tables +instance Sub((i:n) => RangeTo i => a) given (a|Sub, n|Ix) -- Lower triangular tables def (-)(xs, ys) = for i. xs[i] - ys[i] -instance Add((i:n) => (.. a) given (a|Add, n|Ix) +instance Add((i:n) => RangeToExc i => a) given (a|Add, n|Ix) def (+)(xs, ys) = for i. xs[i] + ys[i] zero = for _. zero -instance Sub((i:n) => (.. a) given (a|Sub, n|Ix) +instance Sub((i:n) => RangeToExc i => a) given (a|Sub, n|Ix) def (-)(xs, ys) = for i. xs[i] - ys[i] -instance Add((i:n) => (i<..) => a) given (a|Add, n|Ix) +instance Add((i:n) => RangeFromExc i => a) given (a|Add, n|Ix) def (+)(xs, ys) = for i. xs[i] + ys[i] zero = for _. zero -instance Sub((i:n) => (i<..) => a) given (a|Sub, n|Ix) +instance Sub((i:n) => RangeFromExc i => a) given (a|Sub, n|Ix) def (-)(xs, ys) = for i. xs[i] - ys[i] instance Mul(n=>a) given (a|Mul, n|Ix) @@ -478,16 +474,16 @@ instance VSpace((a, b)) given (a|VSpace, b|VSpace) (x, y) = pair (s .* x, s .* y) -instance VSpace((i:n) => (..i) => a) given (n|Ix, a|VSpace) +instance VSpace((i:n) => RangeTo i => a) given (n|Ix, a|VSpace) def (.*)(s, xs) = for i. s .* xs[i] -instance VSpace((i:n) => (i..) => a) given (n|Ix, a|VSpace) +instance VSpace((i:n) => RangeFrom i => a) given (n|Ix, a|VSpace) def (.*)(s, xs) = for i. s .* xs[i] -instance VSpace((i:n) => (.. a) given (n|Ix, a|VSpace) +instance VSpace((i:n) => RangeToExc i => a) given (n|Ix, a|VSpace) def (.*)(s, xs) = for i. s .* xs[i] -instance VSpace((i:n) => (i<..) => a) given (n|Ix, a|VSpace) +instance VSpace((i:n) => RangeFromExc i => a) given (n|Ix, a|VSpace) def (.*)(s, xs) = for i. s .* xs[i] instance VSpace(()) @@ -1840,16 +1836,16 @@ instance Arbitrary(Nat) instance Arbitrary(n=>a) given (n|Ix, a|Arbitrary) def arb(key) = for i. arb $ ixkey(key, i) -instance Arbitrary((i:n)=>(.. a) given (n|Ix, a|Arbitrary) +instance Arbitrary((i:n)=> RangeToExc i => a) given (n|Ix, a|Arbitrary) def arb(key) = for i. arb $ ixkey(key, i) -instance Arbitrary((i:n)=>(..i) => a) given (n|Ix, a|Arbitrary) +instance Arbitrary((i:n)=> RangeTo i => a) given (n|Ix, a|Arbitrary) def arb(key) = for i. arb $ ixkey(key, i) -instance Arbitrary((i:n)=>(i..) => a) given (n|Ix, a|Arbitrary) +instance Arbitrary((i:n)=> RangeFrom i => a) given (n|Ix, a|Arbitrary) def arb(key) = for i. arb $ ixkey(key, i) -instance Arbitrary((i:n)=>(i<..) => a) given (n|Ix, a|Arbitrary) +instance Arbitrary((i:n)=> RangeFromExc i => a) given (n|Ix, a|Arbitrary) def arb(key) = for i. arb $ ixkey(key, i) instance Arbitrary((a, b)) given (a|Arbitrary, b|Arbitrary) From 75c41841a321f0984ebae6ee9326f8377e5262ca Mon Sep 17 00:00:00 2001 From: Dougal Date: Sat, 2 Dec 2023 20:12:07 -0500 Subject: [PATCH 35/41] Move Pretty instances to where the data types are defined. This avoids circular import issues and orphan instances. Also move top-level data types out of Types.Core to make the file size more reasonable. --- dex.cabal | 3 +- src/dex.hs | 22 +- src/lib/Builder.hs | 3 +- src/lib/CheapReduction.hs | 1 + src/lib/CheckType.hs | 1 + src/lib/ConcreteSyntax.hs | 98 --- src/lib/Core.hs | 1 + src/lib/Err.hs | 20 +- src/lib/Export.hs | 1 + src/lib/Generalize.hs | 1 + src/lib/Imp.hs | 4 +- src/lib/ImpToLLVM.hs | 1 + src/lib/Inference.hs | 1 + src/lib/Inline.hs | 1 + src/lib/JAX/ToSimp.hs | 1 + src/lib/Linearize.hs | 1 + src/lib/Lower.hs | 1 + src/lib/MTL1.hs | 2 +- src/lib/Name.hs | 20 + src/lib/OccAnalysis.hs | 3 +- src/lib/Occurrence.hs | 13 + src/lib/Optimize.hs | 1 + src/lib/PPrint.hs | 976 +------------------------- src/lib/QueryType.hs | 1 + src/lib/QueryTypePure.hs | 1 + src/lib/Runtime.hs | 2 +- src/lib/Simplify.hs | 1 + src/lib/Simplify.hs-boot | 1 + src/lib/SourceRename.hs | 2 +- src/lib/Subst.hs | 1 + src/lib/TopLevel.hs | 2 +- src/lib/Transpose.hs | 5 +- src/lib/Types/Core.hs | 1294 ++++++++++------------------------- src/lib/Types/Imp.hs | 95 ++- src/lib/Types/OpNames.hs | 7 + src/lib/Types/Primitives.hs | 76 +- src/lib/Types/Source.hs | 364 +++++++++- src/lib/Types/Top.hs | 1046 ++++++++++++++++++++++++++++ src/lib/Util.hs | 6 + src/lib/Vectorize.hs | 3 +- 40 files changed, 2077 insertions(+), 2006 deletions(-) create mode 100644 src/lib/Types/Top.hs diff --git a/dex.cabal b/dex.cabal index faa4e372e..50703eeb6 100644 --- a/dex.cabal +++ b/dex.cabal @@ -94,6 +94,7 @@ library , Types.Primitives , Types.OpNames , Types.Source + , Types.Top , QueryType , QueryTypePure , Util @@ -124,7 +125,6 @@ library , prettyprinter , text -- Portable system utilities - , ansi-terminal , directory , filepath , haskeline @@ -234,6 +234,7 @@ executable dex main-is: dex.hs build-depends: dex , ansi-wl-pprint + , ansi-terminal , base , bytestring , containers diff --git a/src/dex.hs b/src/dex.hs index 6d88c86db..623c2bdec 100644 --- a/src/dex.hs +++ b/src/dex.hs @@ -21,8 +21,9 @@ import Data.List import qualified Data.Text as T import qualified Data.Text.Encoding as T import qualified Data.Map.Strict as M +import qualified System.Console.ANSI as ANSI +import System.Console.ANSI hiding (Color) -import PPrint (printOutput) import TopLevel import Err import Name @@ -35,6 +36,7 @@ import Core import Types.Core import Types.Imp import Types.Source +import Types.Top import MonadUtil data DocFmt = ResultOnly @@ -193,6 +195,24 @@ stdOutLogger (Outputs outs) = do isatty <- queryTerminal stdOutput forM_ outs \out -> putStr $ printOutput isatty out +printOutput :: Bool -> Output -> String +printOutput isatty out = case out of + Error _ -> addColor isatty Red $ addPrefix ">" $ pprint out + _ -> addPrefix (addColor isatty Cyan ">") $ pprint $ out + +addPrefix :: String -> String -> String +addPrefix prefix str = unlines $ map prefixLine $ lines str + where prefixLine :: String -> String + prefixLine s = case s of "" -> prefix + _ -> prefix ++ " " ++ s + +addColor :: Bool -> ANSI.Color -> String -> String +addColor False _ s = s +addColor True c s = + setSGRCode [SetConsoleIntensity BoldIntensity, SetColor Foreground Vivid c] + ++ s ++ setSGRCode [Reset] + + pathOption :: ReadM [LibPath] pathOption = splitPaths [] <$> str where diff --git a/src/lib/Builder.hs b/src/lib/Builder.hs index a12b5c8b3..7415cec74 100644 --- a/src/lib/Builder.hs +++ b/src/lib/Builder.hs @@ -35,6 +35,7 @@ import Types.Core import Types.Imp import Types.Primitives import Types.Source +import Types.Top import Util (enumerate, transitiveClosureM, bindM2, toSnocList) -- === Ordinary (local) builder class === @@ -281,7 +282,7 @@ emitTopLet hint letAnn expr = do v <- emitBinding hint $ AtomNameBinding $ LetBound (DeclBinding letAnn expr) return $ AtomVar v ty -emitTopFunBinding :: (Mut n, TopBuilder m) => NameHint -> TopFunDef n -> STopLam n -> m n (TopFunName n) +emitTopFunBinding :: (Mut n, TopBuilder m) => NameHint -> TopFunDef n -> TopLam SimpIR n -> m n (TopFunName n) emitTopFunBinding hint def f = do emitBinding hint $ TopFunBinding $ DexTopFun def f Waiting diff --git a/src/lib/CheapReduction.hs b/src/lib/CheapReduction.hs index 6fdd7280f..8df743f24 100644 --- a/src/lib/CheapReduction.hs +++ b/src/lib/CheapReduction.hs @@ -32,6 +32,7 @@ import Name import PPrint () import QueryTypePure import Types.Core +import Types.Top import Types.Imp import Types.Primitives import Util diff --git a/src/lib/CheckType.hs b/src/lib/CheckType.hs index f808e7153..9bcbf029d 100644 --- a/src/lib/CheckType.hs +++ b/src/lib/CheckType.hs @@ -27,6 +27,7 @@ import QueryType import Types.Core import Types.Primitives import Types.Source +import Types.Top -- === top-level API === diff --git a/src/lib/ConcreteSyntax.hs b/src/lib/ConcreteSyntax.hs index 838089e7c..70bc67b9e 100644 --- a/src/lib/ConcreteSyntax.hs +++ b/src/lib/ConcreteSyntax.hs @@ -17,12 +17,10 @@ import Data.Char import Data.Either import Data.Functor import Data.List.NonEmpty (NonEmpty (..)) -import Data.Map qualified as M import Data.String (fromString) import Data.Text (Text) import Data.Text qualified as T import Data.Text.Encoding qualified as T -import Data.Tuple import Data.Void import Text.Megaparsec hiding (Label, State) import Text.Megaparsec.Char hiding (space, eol) @@ -31,7 +29,6 @@ import Lexing import Types.Core import Types.Source import Types.Primitives -import qualified Types.OpNames as P import Util -- TODO: implement this more efficiently rather than just parsing the whole @@ -697,101 +694,6 @@ withSrcs p = do (sids, result) <- collectAtomicLexemeIds p return $ WithSrcs sid sids result --- === primitive constructors and operators === - -strToPrimName :: String -> Maybe PrimName -strToPrimName s = M.lookup s primNames - -primNameToStr :: PrimName -> String -primNameToStr prim = case lookup prim $ map swap $ M.toList primNames of - Just s -> s - Nothing -> show prim - -showPrimName :: PrimName -> String -showPrimName prim = primNameToStr prim -{-# NOINLINE showPrimName #-} - -primNames :: M.Map String PrimName -primNames = M.fromList - [ ("ask" , UMAsk), ("mextend", UMExtend) - , ("get" , UMGet), ("put" , UMPut) - , ("while" , UWhile) - , ("linearize", ULinearize), ("linearTranspose", UTranspose) - , ("runReader", URunReader), ("runWriter" , URunWriter), ("runState", URunState) - , ("runIO" , URunIO ), ("catchException" , UCatchException) - , ("iadd" , binary IAdd), ("isub" , binary ISub) - , ("imul" , binary IMul), ("fdiv" , binary FDiv) - , ("fadd" , binary FAdd), ("fsub" , binary FSub) - , ("fmul" , binary FMul), ("idiv" , binary IDiv) - , ("irem" , binary IRem) - , ("fpow" , binary FPow) - , ("and" , binary BAnd), ("or" , binary BOr ) - , ("not" , unary BNot), ("xor" , binary BXor) - , ("shl" , binary BShL), ("shr" , binary BShR) - , ("ieq" , binary (ICmp Equal)), ("feq", binary (FCmp Equal)) - , ("igt" , binary (ICmp Greater)), ("fgt", binary (FCmp Greater)) - , ("ilt" , binary (ICmp Less)), ("flt", binary (FCmp Less)) - , ("fneg" , unary FNeg) - , ("exp" , unary Exp), ("exp2" , unary Exp2) - , ("log" , unary Log), ("log2" , unary Log2), ("log10" , unary Log10) - , ("sin" , unary Sin), ("cos" , unary Cos) - , ("tan" , unary Tan), ("sqrt" , unary Sqrt) - , ("floor", unary Floor), ("ceil" , unary Ceil), ("round", unary Round) - , ("log1p", unary Log1p), ("lgamma", unary LGamma) - , ("erf" , unary Erf), ("erfc" , unary Erfc) - , ("TyKind" , UPrimTC $ P.TypeKind) - , ("Float64" , baseTy $ Scalar Float64Type) - , ("Float32" , baseTy $ Scalar Float32Type) - , ("Int64" , baseTy $ Scalar Int64Type) - , ("Int32" , baseTy $ Scalar Int32Type) - , ("Word8" , baseTy $ Scalar Word8Type) - , ("Word32" , baseTy $ Scalar Word32Type) - , ("Word64" , baseTy $ Scalar Word64Type) - , ("Int32Ptr" , baseTy $ ptrTy $ Scalar Int32Type) - , ("Word8Ptr" , baseTy $ ptrTy $ Scalar Word8Type) - , ("Word32Ptr" , baseTy $ ptrTy $ Scalar Word32Type) - , ("Word64Ptr" , baseTy $ ptrTy $ Scalar Word64Type) - , ("Float32Ptr", baseTy $ ptrTy $ Scalar Float32Type) - , ("PtrPtr" , baseTy $ ptrTy $ ptrTy $ Scalar Word8Type) - , ("Nat" , UNat) - , ("Fin" , UFin) - , ("EffKind" , UEffectRowKind) - , ("NatCon" , UNatCon) - , ("Ref" , UPrimTC $ P.RefType) - , ("HeapType" , UPrimTC $ P.HeapType) - , ("indexRef" , UIndexRef) - , ("alloc" , memOp $ P.IOAlloc) - , ("free" , memOp $ P.IOFree) - , ("ptrOffset", memOp $ P.PtrOffset) - , ("ptrLoad" , memOp $ P.PtrLoad) - , ("ptrStore" , memOp $ P.PtrStore) - , ("throwError" , miscOp $ P.ThrowError) - , ("throwException", miscOp $ P.ThrowException) - , ("dataConTag" , miscOp $ P.SumTag) - , ("toEnum" , miscOp $ P.ToEnum) - , ("outputStream" , miscOp $ P.OutputStream) - , ("cast" , miscOp $ P.CastOp) - , ("bitcast" , miscOp $ P.BitcastOp) - , ("unsafeCoerce" , miscOp $ P.UnsafeCoerce) - , ("garbageVal" , miscOp $ P.GarbageVal) - , ("select" , miscOp $ P.Select) - , ("showAny" , miscOp $ P.ShowAny) - , ("showScalar" , miscOp $ P.ShowScalar) - , ("projNewtype" , UProjNewtype) - , ("applyMethod0" , UApplyMethod 0) - , ("applyMethod1" , UApplyMethod 1) - , ("applyMethod2" , UApplyMethod 2) - , ("explicitApply", UExplicitApply) - , ("monoLit", UMonoLiteral) - ] - where - binary op = UBinOp op - baseTy b = UBaseType b - memOp op = UMemOp op - unary op = UUnOp op - ptrTy ty = PtrType (CPU, ty) - miscOp op = UMiscOp op - -- === notes === -- note [if-syntax] diff --git a/src/lib/Core.hs b/src/lib/Core.hs index 2c60f846e..e420b50a7 100644 --- a/src/lib/Core.hs +++ b/src/lib/Core.hs @@ -37,6 +37,7 @@ import Err import IRVariants import Types.Core +import Types.Top import Types.Imp import Types.Primitives import Types.Source diff --git a/src/lib/Err.hs b/src/lib/Err.hs index d0ad6c9da..51b34eb1f 100644 --- a/src/lib/Err.hs +++ b/src/lib/Err.hs @@ -24,12 +24,10 @@ import Control.Monad.State.Strict import Control.Monad.Reader import Data.Coerce import Data.Foldable (fold) -import Data.Text qualified as T -import Data.Text.Prettyprint.Doc.Render.Text import Data.Text.Prettyprint.Doc import GHC.Stack -import System.Environment -import System.IO.Unsafe + +import PPrint -- === core API === @@ -285,20 +283,6 @@ instance Fallible Maybe where throwErr _ = Nothing {-# INLINE throwErr #-} --- === small pretty-printing utils === --- These are here instead of in PPrint.hs for import cycle reasons - -pprint :: Pretty a => a -> String -pprint x = docAsStr $ pretty x -{-# SCC pprint #-} - -docAsStr :: Doc ann -> String -docAsStr doc = T.unpack $ renderStrict $ layoutPretty layout $ doc - -layout :: LayoutOptions -layout = if unbounded then LayoutOptions Unbounded else defaultLayoutOptions - where unbounded = unsafePerformIO $ (Just "1"==) <$> lookupEnv "DEX_PPRINT_UNBOUNDED" - -- === instances === instance Fallible Except where diff --git a/src/lib/Export.hs b/src/lib/Export.hs index 7983f52cb..1108507fb 100644 --- a/src/lib/Export.hs +++ b/src/lib/Export.hs @@ -29,6 +29,7 @@ import Subst hiding (Rename) import TopLevel import Types.Core import Types.Imp +import Types.Top import Types.Primitives hiding (sizeOf) type ExportAtomNameC = AtomNameC CoreIR diff --git a/src/lib/Generalize.hs b/src/lib/Generalize.hs index 945552a7d..7ace599c4 100644 --- a/src/lib/Generalize.hs +++ b/src/lib/Generalize.hs @@ -18,6 +18,7 @@ import QueryType import Name import Subst import Types.Primitives +import Types.Top type RolePiBinder = WithAttrB RoleExpl CBinder type RolePiBinders = Nest RolePiBinder diff --git a/src/lib/Imp.hs b/src/lib/Imp.hs index 7ab6c865c..07fada480 100644 --- a/src/lib/Imp.hs +++ b/src/lib/Imp.hs @@ -42,10 +42,10 @@ import QueryType import Types.Core import Types.Imp import Types.Primitives +import Types.Top import Util (forMFilter, Tree (..), zipTrees, enumerate) -toImpFunction :: EnvReader m - => CallingConvention -> STopLam n -> m n (ImpFunction n) +toImpFunction :: EnvReader m => CallingConvention -> STopLam n -> m n (ImpFunction n) toImpFunction cc (TopLam True destTy lam) = do LamExpr bsAndRefB body <- return lam PairB bs destB <- case popNest bsAndRefB of diff --git a/src/lib/ImpToLLVM.hs b/src/lib/ImpToLLVM.hs index 556f927bd..f333c1f08 100644 --- a/src/lib/ImpToLLVM.hs +++ b/src/lib/ImpToLLVM.hs @@ -56,6 +56,7 @@ import Types.Core import Types.Imp import Types.Primitives import Types.Source +import Types.Top import Util (IsBool (..), bindM2, enumerate) -- === Compile monad === diff --git a/src/lib/Inference.hs b/src/lib/Inference.hs index f6fb4d926..bea45b654 100644 --- a/src/lib/Inference.hs +++ b/src/lib/Inference.hs @@ -43,6 +43,7 @@ import Types.Core import Types.Imp import Types.Primitives import Types.Source +import Types.Top import qualified Types.OpNames as P import Util hiding (group) diff --git a/src/lib/Inline.hs b/src/lib/Inline.hs index f3ada792d..f72f24bef 100644 --- a/src/lib/Inline.hs +++ b/src/lib/Inline.hs @@ -17,6 +17,7 @@ import Occurrence hiding (Var) import PeepholeOptimize import Types.Core import Types.Primitives +import Types.Top -- === External API === diff --git a/src/lib/JAX/ToSimp.hs b/src/lib/JAX/ToSimp.hs index cdf25d73b..7466d237b 100644 --- a/src/lib/JAX/ToSimp.hs +++ b/src/lib/JAX/ToSimp.hs @@ -17,6 +17,7 @@ import JAX.Concrete import Subst import QueryType import Types.Core +import Types.Top import Types.Primitives qualified as P newtype JaxSimpM (i::S) (o::S) a = JaxSimpM diff --git a/src/lib/Linearize.hs b/src/lib/Linearize.hs index 98bbb7d39..ee61d8437 100644 --- a/src/lib/Linearize.hs +++ b/src/lib/Linearize.hs @@ -27,6 +27,7 @@ import PPrint import QueryType import Types.Core import Types.Primitives +import Types.Top import Util (enumerate) -- === linearization monad === diff --git a/src/lib/Lower.hs b/src/lib/Lower.hs index cf28b0667..db7b83fad 100644 --- a/src/lib/Lower.hs +++ b/src/lib/Lower.hs @@ -26,6 +26,7 @@ import Name import Subst import QueryType import Types.Core +import Types.Top import Types.Primitives import Util (enumerate) diff --git a/src/lib/MTL1.hs b/src/lib/MTL1.hs index 2011fa64d..47fe8b8c1 100644 --- a/src/lib/MTL1.hs +++ b/src/lib/MTL1.hs @@ -17,7 +17,7 @@ import Data.Foldable (toList) import Name import Err -import Types.Core (Env) +import Types.Top (Env) import Core (EnvReader (..), EnvExtender (..)) import Util (SnocList (..), snoc, emptySnocList) diff --git a/src/lib/Name.hs b/src/lib/Name.hs index dc36f6c38..fd23def5e 100644 --- a/src/lib/Name.hs +++ b/src/lib/Name.hs @@ -44,6 +44,7 @@ import RawName ( RawNameMap, RawName, NameHint, HasNameHint (..) , freshRawName, rawNameFromHint, rawNames, noHint) import qualified RawName as R import Util ( zipErr, onFst, onSnd, transitiveClosure, SnocList (..), unsnoc ) +import PPrint import Err import IRVariants @@ -445,6 +446,9 @@ type OrdE e = (forall (n::S) . Ord (e n )) :: Constraint type OrdV v = (forall (c::C) (n::S). Ord (v c n)) :: Constraint type OrdB b = (forall (n::S) (l::S). Ord (b n l)) :: Constraint +type PrettyPrecE e = (forall (n::S) . PrettyPrec (e n )) :: Constraint +type PrettyPrecB b = (forall (n::S) (l::S). PrettyPrec (b n l)) :: Constraint + type HashableE (e::E) = forall n. Hashable (e n) data UnitE (n::S) = UnitE @@ -2164,6 +2168,8 @@ instance PrettyE e => Pretty (ListE e n) where instance PrettyE e => Pretty (RListE e n) where pretty (RListE e) = pretty $ unsnoc e +deriving instance (forall c n. Pretty (v c n)) => Pretty (RecSubst v o) + instance ( Generic (b UnsafeS UnsafeS) , Generic (body UnsafeS) ) => Generic (Abs b body n) where @@ -2746,6 +2752,9 @@ canonicalizeForPrinting e cont = do ClosedWithScope scope e' -> cont $ renameE (scope, newSubst id) e' +pprintCanonicalized :: (HoistableE e, RenameE e, PrettyE e) => e n -> String +pprintCanonicalized e = canonicalizeForPrinting e \e' -> pprint e' + liftHoistExcept :: Fallible m => HoistExcept a -> m a liftHoistExcept (HoistSuccess x) = return x liftHoistExcept (HoistFailure vs) = throw EscapedNameErr (pprint vs) @@ -2887,6 +2896,10 @@ abstractFreeVarsNoAnn vs e = Abs bs e' -> Abs bs' e' where bs' = fmapNest (\(b:>UnitE) -> b) bs +unsafeFromNest :: Nest b n l -> [b UnsafeS UnsafeS] +unsafeFromNest Empty = [] +unsafeFromNest (Nest b rest) = unsafeCoerceB b : unsafeFromNest rest + instance Color c => HoistableB (NameBinder c) where freeVarsB _ = mempty @@ -3389,6 +3402,13 @@ hoistNameMap b = ignoreHoistFailure . hoistNameMapE b unsafeCoerceIRE :: forall (r'::IR) (r::IR) (e::IR->E) (n::S). e r n -> e r' n unsafeCoerceIRE = TrulyUnsafe.unsafeCoerce +-- === Pretty instances === + +instance PrettyPrec (Name s n) where prettyPrec = atPrec ArgPrec . pretty + +instance PrettyE ann => Pretty (BinderP c ann n l) + where pretty (b:>ty) = pretty b <> ":" <> pretty ty + -- === notes === {- diff --git a/src/lib/OccAnalysis.hs b/src/lib/OccAnalysis.hs index fcf04cdf2..0e75165be 100644 --- a/src/lib/OccAnalysis.hs +++ b/src/lib/OccAnalysis.hs @@ -20,6 +20,7 @@ import Occurrence hiding (Var) import Occurrence qualified as Occ import Types.Core import Types.Primitives +import Types.Top import QueryType -- === External API === @@ -28,7 +29,7 @@ import QueryType -- annotation holding a summary of how that binding is used. It also eliminates -- unused pure bindings as it goes, since it has all the needed information. -analyzeOccurrences :: EnvReader m => STopLam n -> m n (STopLam n) +analyzeOccurrences :: EnvReader m => TopLam SimpIR n -> m n (TopLam SimpIR n) analyzeOccurrences lam = liftLamExpr lam \e -> liftOCCM $ occ accessOnce e {-# INLINE analyzeOccurrences #-} diff --git a/src/lib/Occurrence.hs b/src/lib/Occurrence.hs index 5e024e854..ea8248de8 100644 --- a/src/lib/Occurrence.hs +++ b/src/lib/Occurrence.hs @@ -19,6 +19,7 @@ import Data.List (foldl') import Data.Store (Store (..)) import GHC.Generics (Generic (..)) +import PPrint import IRVariants import Name @@ -888,3 +889,15 @@ instance RenameE AccessInfo instance Hashable UsageInfo instance Store UsageInfo + +-- === instances === + +instance Pretty UsageInfo where + pretty (UsageInfo static (ixDepth, ct)) = + "occurs in" <+> pretty static <+> "places, read" + <+> pretty ct <+> "times, to depth" <+> pretty (show ixDepth) + +instance Pretty Count where + pretty = \case + Bounded ct -> "<=" <+> pretty ct + Unbounded -> "many" diff --git a/src/lib/Optimize.hs b/src/lib/Optimize.hs index 425291cd4..1ed73ff23 100644 --- a/src/lib/Optimize.hs +++ b/src/lib/Optimize.hs @@ -15,6 +15,7 @@ import Control.Monad.State.Strict import Types.Core import Types.Primitives +import Types.Top import MTL1 import Name import Subst diff --git a/src/lib/PPrint.hs b/src/lib/PPrint.hs index 0344bd861..b16559fa5 100644 --- a/src/lib/PPrint.hs +++ b/src/lib/PPrint.hs @@ -6,40 +6,34 @@ {-# LANGUAGE IncoherentInstances #-} -- due to `ConRef` {-# LANGUAGE UndecidableInstances #-} -{-# OPTIONS_GHC -Wno-orphans #-} module PPrint ( - pprint, pprintCanonicalized, pprintList, asStr , atPrec, - PrettyPrec(..), PrecedenceLevel (..), prettyBlock, - printOutput, prettyFromPrettyPrec) where + Pretty (..), Doc, DocPrec, (<+>), pprint, pprintList, asStr , atPrec, + pAppArg, pApp, pArg, hardline, PrettyPrec(..), PrecedenceLevel (..), + docAsStr, parensSep, prettyLines, sep, pLowest, prettyFromPrettyPrec, + indented, commaSep, spaced, spaceIfColinear, encloseSep) where -import GHC.Exts (Constraint) -import GHC.Float import Data.Foldable (toList, fold) -import qualified Data.Map.Strict as M import Data.Text.Prettyprint.Doc.Render.Text import Data.Text.Prettyprint.Doc -import Data.Text (Text, snoc, uncons, unsnoc, unpack) -import qualified Data.Set as S -import Data.String (fromString) -import qualified System.Console.ANSI as ANSI -import System.Console.ANSI hiding (Color) +import Data.Text (unpack) import System.IO.Unsafe import qualified System.Environment as E -import Numeric -import ConcreteSyntax -import Err -import IRVariants -import Name -import Occurrence (Count (Bounded), UsageInfo (..)) -import Occurrence qualified as Occ -import Types.Core -import Types.Imp -import Types.Primitives -import Types.Source -import QueryTypePure -import Util (Tree (..)) +-- === small pretty-printing utils === + +pprint :: Pretty a => a -> String +pprint x = docAsStr $ pretty x +{-# SCC pprint #-} + +docAsStr :: Doc ann -> String +docAsStr doc = unpack $ renderStrict $ layoutPretty layout $ doc + +layout :: LayoutOptions +layout = if unbounded then LayoutOptions Unbounded else defaultLayoutOptions + where unbounded = unsafePerformIO $ (Just "1"==) <$> E.lookupEnv "DEX_PPRINT_UNBOUNDED" + +-- === DocPrec === -- A DocPrec is a slightly context-aware Doc, specifically one that -- knows the precedence level of the immediately enclosing operation, @@ -93,31 +87,12 @@ prettyFromPrettyPrec = pArg pAppArg :: (PrettyPrec a, Foldable f) => Doc ann -> f a -> Doc ann pAppArg name as = align $ name <> group (nest 2 $ foldMap (\a -> line <> pArg a) as) -fromInfix :: Text -> Maybe Text -fromInfix t = do - ('(', t') <- uncons t - (t'', ')') <- unsnoc t' - return t'' - -type PrettyPrecE e = (forall (n::S) . PrettyPrec (e n )) :: Constraint -type PrettyPrecB b = (forall (n::S) (l::S). PrettyPrec (b n l)) :: Constraint - -pprintCanonicalized :: (HoistableE e, RenameE e, PrettyE e) => e n -> String -pprintCanonicalized e = canonicalizeForPrinting e \e' -> pprint e' - pprintList :: Pretty a => [a] -> String -pprintList xs = asStr $ vsep $ punctuate "," (map p xs) - -layout :: LayoutOptions -layout = if unbounded then LayoutOptions Unbounded else defaultLayoutOptions - where unbounded = unsafePerformIO $ (Just "1"==) <$> E.lookupEnv "DEX_PPRINT_UNBOUNDED" +pprintList xs = asStr $ vsep $ punctuate "," (map pretty xs) asStr :: Doc ann -> String asStr doc = unpack $ renderStrict $ layoutPretty layout $ doc -p :: Pretty a => a -> Doc ann -p = pretty - pLowest :: PrettyPrec a => a -> Doc ann pLowest a = prettyPrec a LowestPrec @@ -127,17 +102,8 @@ pApp a = prettyPrec a AppPrec pArg :: PrettyPrec a => a -> Doc ann pArg a = prettyPrec a ArgPrec -prettyBlock :: (IRRep r, PrettyPrec (e l)) => Nest (Decl r) n l -> e l -> Doc ann -prettyBlock Empty expr = group $ line <> pLowest expr -prettyBlock decls expr = prettyLines decls' <> hardline <> pLowest expr - where decls' = fromNest decls - -fromNest :: Nest b n l -> [b UnsafeS UnsafeS] -fromNest Empty = [] -fromNest (Nest b rest) = unsafeCoerceB b : fromNest rest - prettyLines :: (Foldable f, Pretty a) => f a -> Doc ann -prettyLines xs = foldMap (\d -> hardline <> p d) $ toList xs +prettyLines xs = foldMap (\d -> hardline <> pretty d) $ toList xs parensSep :: Doc ann -> [Doc ann] -> Doc ann parensSep separator items = encloseSep "(" ")" separator items @@ -148,907 +114,13 @@ spaceIfColinear = flatAlt "" space instance PrettyPrec a => PrettyPrec [a] where prettyPrec xs = atPrec ArgPrec $ hsep $ map pLowest xs -instance PrettyE ann => Pretty (BinderP c ann n l) - where pretty (b:>ty) = p b <> ":" <> p ty - -instance IRRep r => Pretty (Expr r n) where pretty = prettyFromPrettyPrec -instance IRRep r => PrettyPrec (Expr r n) where - prettyPrec = \case - Atom x -> prettyPrec x - Block _ (Abs decls body) -> atPrec AppPrec $ prettyBlock decls body - App _ f xs -> atPrec AppPrec $ pApp f <+> spaced (toList xs) - TopApp _ f xs -> atPrec AppPrec $ pApp f <+> spaced (toList xs) - TabApp _ f x -> atPrec AppPrec $ pApp f <> brackets (p x) - Case e alts (EffTy effs _) -> prettyPrecCase "case" e alts effs - TabCon _ _ es -> atPrec ArgPrec $ list $ pApp <$> es - PrimOp op -> prettyPrec op - ApplyMethod _ d i xs -> atPrec AppPrec $ "applyMethod" <+> p d <+> p i <+> p xs - Project _ i x -> atPrec AppPrec $ "Project" <+> p i <+> p x - Unwrap _ x -> atPrec AppPrec $ "Unwrap" <+> p x - -prettyPrecCase :: IRRep r => Doc ann -> Atom r n -> [Alt r n] -> EffectRow r n -> DocPrec ann -prettyPrecCase name e alts effs = atPrec LowestPrec $ - name <+> pApp e <+> "of" <> - nest 2 (foldMap (\alt -> hardline <> prettyAlt alt) alts - <> effectLine effs) - where - effectLine :: IRRep r => EffectRow r n -> Doc ann - effectLine Pure = "" - effectLine row = hardline <> "case annotated with effects" <+> p row - -prettyAlt :: IRRep r => Alt r n -> Doc ann -prettyAlt (Abs b body) = prettyBinderNoAnn b <+> "->" <> nest 2 (p body) - -prettyBinderNoAnn :: Binder r n l -> Doc ann -prettyBinderNoAnn (b:>_) = p b - -instance (IRRep r, PrettyPrecE e) => Pretty (Abs (Binder r) e n) where pretty = prettyFromPrettyPrec -instance (IRRep r, PrettyPrecE e) => PrettyPrec (Abs (Binder r) e n) where - prettyPrec (Abs binder body) = atPrec LowestPrec $ "\\" <> p binder <> "." <> pLowest body - -instance IRRep r => Pretty (DeclBinding r n) where - pretty (DeclBinding ann expr) = "Decl" <> p ann <+> p expr - -instance IRRep r => Pretty (Decl r n l) where - pretty (Let b (DeclBinding ann rhs)) = - align $ annDoc <> p (b:>getType rhs) <+> "=" <> (nest 2 $ group $ line <> pLowest rhs) - where annDoc = case ann of NoInlineLet -> pretty ann <> " "; _ -> pretty ann - -instance IRRep r => Pretty (PiType r n) where - pretty (PiType bs (EffTy effs resultTy)) = - (spaced $ fromNest $ bs) <+> "->" <+> "{" <> p effs <> "}" <+> p resultTy - -instance IRRep r => Pretty (LamExpr r n) where pretty = prettyFromPrettyPrec -instance IRRep r => PrettyPrec (LamExpr r n) where - prettyPrec (LamExpr bs body) = - atPrec LowestPrec $ prettyLam (p bs <> ".") body - -instance IRRep r => Pretty (IxType r n) where - pretty (IxType ty dict) = parens $ "IxType" <+> pretty ty <> prettyIxDict dict - -instance IRRep r => Pretty (Dict r n) where - pretty = \case - DictCon con -> pretty con - StuckDict _ stuck -> pretty stuck - -instance IRRep r => Pretty (DictCon r n) where - pretty = \case - InstanceDict _ name args -> "Instance" <+> p name <+> p args - IxFin n -> "Ix (Fin" <+> p n <> ")" - DataData a -> "Data " <+> p a - IxRawFin n -> "Ix (RawFin " <> p n <> ")" - IxSpecialized d xs -> p d <+> p xs - -instance Pretty (DictType n) where - pretty = \case - DictType classSourceName _ params -> p classSourceName <+> spaced params - IxDictType ty -> "Ix" <+> p ty - DataDictType ty -> "Data" <+> p ty - -instance IRRep r => Pretty (DepPairType r n) where pretty = prettyFromPrettyPrec -instance IRRep r => PrettyPrec (DepPairType r n) where - prettyPrec (DepPairType _ b rhs) = - atPrec ArgPrec $ align $ group $ parensSep (spaceIfColinear <> "&> ") [p b, p rhs] - -instance Pretty (CoreLamExpr n) where - pretty (CoreLamExpr _ lam) = p lam - -instance IRRep r => Pretty (Atom r n) where pretty = prettyFromPrettyPrec -instance IRRep r => PrettyPrec (Atom r n) where - prettyPrec atom = case atom of - Con e -> prettyPrec e - Stuck _ e -> prettyPrec e - -instance IRRep r => Pretty (Type r n) where pretty = prettyFromPrettyPrec -instance IRRep r => PrettyPrec (Type r n) where - prettyPrec = \case - TyCon e -> prettyPrec e - StuckTy _ e -> prettyPrec e - -instance IRRep r => Pretty (Stuck r n) where pretty = prettyFromPrettyPrec -instance IRRep r => PrettyPrec (Stuck r n) where - prettyPrec = \case - Var v -> atPrec ArgPrec $ p v - StuckProject i v -> atPrec LowestPrec $ "StuckProject" <+> p i <+> p v - StuckTabApp f xs -> atPrec AppPrec $ pArg f <> "." <> pArg xs - StuckUnwrap v -> atPrec LowestPrec $ "StuckUnwrap" <+> p v - InstantiatedGiven v args -> atPrec LowestPrec $ "Given" <+> p v <+> p (toList args) - SuperclassProj d' i -> atPrec LowestPrec $ "SuperclassProj" <+> p d' <+> p i - PtrVar _ v -> atPrec ArgPrec $ p v - RepValAtom x -> atPrec LowestPrec $ pretty x - ACase e alts _ -> atPrec AppPrec $ "acase" <+> p e <+> p alts - LiftSimp ty x -> atPrec ArgPrec $ " p x <+> " : " <+> p ty <+> ">" - LiftSimpFun ty x -> atPrec ArgPrec $ " p x <+> " : " <+> p ty <+> ">" - TabLam lam -> atPrec AppPrec $ "tablam" <+> p lam - -instance Pretty (RepVal n) where - pretty (RepVal ty tree) = " p tree <+> ":" <+> p ty <> ">" - -instance Pretty a => Pretty (Tree a) where - pretty = \case - Leaf x -> pretty x - Branch xs -> pretty xs - -instance Pretty Projection where - pretty = \case - UnwrapNewtype -> "u" - ProjectProduct i -> p i - -forStr :: ForAnn -> Doc ann -forStr Fwd = "for" -forStr Rev = "rof" - -instance Pretty (CorePiType n) where - pretty (CorePiType appExpl expls bs (EffTy eff resultTy)) = - prettyBindersWithExpl expls bs <+> p appExpl <> prettyEff <> p resultTy - where - prettyEff = case eff of - Pure -> space - _ -> space <> pretty eff <> space - -prettyBindersWithExpl :: forall b n l ann. PrettyB b - => [Explicitness] -> Nest b n l -> Doc ann -prettyBindersWithExpl expls bs = do - let groups = groupByExpl $ zip expls (fromNest bs) - let groups' = case groups of [] -> [(Explicit, [])] - _ -> groups - mconcat [withExplParens expl $ commaSep bsGroup | (expl, bsGroup) <- groups'] - -groupByExpl :: [(Explicitness, b UnsafeS UnsafeS)] -> [(Explicitness, [b UnsafeS UnsafeS])] -groupByExpl [] = [] -groupByExpl ((expl, b):bs) = do - let (matches, rest) = span (\(expl', _) -> expl == expl') bs - let matches' = map snd matches - (expl, b:matches') : groupByExpl rest - -withExplParens :: Explicitness -> Doc ann -> Doc ann -withExplParens Explicit x = parens x -withExplParens (Inferred _ Unify) x = braces $ x -withExplParens (Inferred _ (Synth _)) x = brackets x - -instance IRRep r => Pretty (TabPiType r n) where - pretty (TabPiType dict (b:>ty) body) = let - prettyBody = case body of - TyCon (Pi subpi) -> pretty subpi - _ -> pLowest body - prettyBinder = prettyBinderHelper (b:>ty) body - in prettyBinder <> prettyIxDict dict <> (group $ line <> "=>" <+> prettyBody) - --- A helper to let us turn dict printing on and off. We mostly want it off to --- reduce clutter in prints and error messages, but when debugging synthesis we --- want it on. -prettyIxDict :: IRRep r => IxDict r n -> Doc ann -prettyIxDict dict = if False then " " <> p dict else mempty - -prettyBinderHelper :: IRRep r => HoistableE e => Binder r n l -> e l -> Doc ann -prettyBinderHelper (b:>ty) body = - if binderName b `isFreeIn` body - then parens $ p (b:>ty) - else p ty - -prettyLam :: Pretty a => Doc ann -> a -> Doc ann -prettyLam binders body = - group $ group (nest 4 $ binders) <> group (nest 2 $ p body) - -instance IRRep r => Pretty (EffectRow r n) where - pretty (EffectRow effs t) = - braces $ hsep (punctuate "," (map p (eSetToList effs))) <> p t - -instance IRRep r => Pretty (EffectRowTail r n) where - pretty = \case - NoTail -> mempty - EffectRowTail v -> "|" <> p v - -instance IRRep r => Pretty (Effect r n) where - pretty eff = case eff of - RWSEffect rws h -> p rws <+> p h - ExceptionEffect -> "Except" - IOEffect -> "IO" - InitEffect -> "Init" - -instance Pretty (UEffect n) where - pretty eff = case eff of - URWSEffect rws h -> p rws <+> p h - UExceptionEffect -> "Except" - UIOEffect -> "IO" - -instance PrettyPrec (Name s n) where prettyPrec = atPrec ArgPrec . pretty - -instance PrettyPrec (AtomVar r n) where - prettyPrec (AtomVar v _) = prettyPrec v -instance Pretty (AtomVar r n) where pretty = prettyFromPrettyPrec - -instance IRRep r => Pretty (AtomBinding r n) where - pretty binding = case binding of - LetBound b -> p b - MiscBound t -> p t - SolverBound b -> p b - FFIFunBound s _ -> p s - NoinlineFun ty _ -> "Top function with type: " <+> p ty - TopDataBound (RepVal ty _) -> "Top data with type: " <+> p ty - -instance Pretty (SpecializationSpec n) where - pretty (AppSpecialization f (Abs bs (ListE args))) = - "Specialization" <+> p f <+> p bs <+> p args - -instance Pretty IxMethod where - pretty method = p $ show method - -instance Pretty (SolverBinding n) where - pretty (InfVarBound ty) = "Inference variable of type:" <+> p ty - pretty (SkolemBound ty) = "Skolem variable of type:" <+> p ty - pretty (DictBound ty) = "Dictionary variable of type:" <+> p ty - -instance Pretty (Binding c n) where - pretty b = case b of - -- using `unsafeCoerceIRE` here because otherwise we don't have `IRRep` - -- TODO: can we avoid printing needing IRRep? Presumably it's related to - -- manipulating sets or something, which relies on Eq/Ord, which relies on renaming. - AtomNameBinding info -> "Atom name:" <+> pretty (unsafeCoerceIRE @CoreIR info) - TyConBinding dataDef _ -> "Type constructor: " <+> pretty dataDef - DataConBinding tyConName idx -> "Data constructor:" <+> - pretty tyConName <+> "Constructor index:" <+> pretty idx - ClassBinding classDef -> pretty classDef - InstanceBinding instanceDef _ -> pretty instanceDef - MethodBinding className idx -> "Method" <+> pretty idx <+> "of" <+> pretty className - TopFunBinding f -> pretty f - FunObjCodeBinding _ -> "" - ModuleBinding _ -> "" - PtrBinding _ _ -> "" - SpecializedDictBinding _ -> "" - ImpNameBinding ty -> "Imp name of type: " <+> p ty - -instance Pretty (Module n) where - pretty m = prettyRecord - [ ("moduleSourceName" , p $ moduleSourceName m) - , ("moduleDirectDeps" , p $ S.toList $ moduleDirectDeps m) - , ("moduleTransDeps" , p $ S.toList $ moduleTransDeps m) - , ("moduleExports" , p $ moduleExports m) - , ("moduleSynthCandidates", p $ moduleSynthCandidates m) ] - -instance Pretty (TyConParams n) where - pretty (TyConParams _ _) = undefined - -instance Pretty (TyConDef n) where - pretty (TyConDef name _ bs cons) = "data" <+> p name <+> p bs <> pretty cons - -instance Pretty (DataConDefs n) where - pretty = undefined - -instance Pretty (DataConDef n) where - pretty (DataConDef name _ repTy _) = - p name <+> ":" <+> p repTy - -instance Pretty (ClassDef n) where - pretty (ClassDef classSourceName _ methodNames _ _ params superclasses methodTys) = - "Class:" <+> pretty classSourceName <+> pretty methodNames - <> indented ( - line <> "parameter binders:" <+> pretty params <> - line <> "superclasses:" <+> pretty superclasses <> - line <> "methods:" <+> pretty methodTys) - -instance Pretty ParamRole where - pretty r = p (show r) - -instance Pretty (InstanceDef n) where - pretty (InstanceDef className _ bs params _) = - "Instance" <+> p className <+> pretty bs <+> p params - -deriving instance (forall c n. Pretty (v c n)) => Pretty (RecSubst v o) - -instance Pretty (TopEnv n) where - pretty (TopEnv defs rules cache _ _) = - prettyRecord [ ("Defs" , p defs) - , ("Rules" , p rules) - , ("Cache" , p cache) ] - -instance Pretty (CustomRules n) where - pretty _ = "TODO: Rule printing" - -instance Pretty (ImportStatus n) where - pretty imports = pretty $ S.toList $ directImports imports - -instance Pretty (ModuleEnv n) where - pretty (ModuleEnv imports sm sc) = - prettyRecord [ ("Imports" , p imports) - , ("Source map" , p sm) - , ("Synth candidates", p sc) ] - -instance Pretty (Env n) where - pretty (Env env1 env2) = - prettyRecord [ ("Top env" , p env1) - , ("Module env", p env2)] - -prettyRecord :: [(String, Doc ann)] -> Doc ann -prettyRecord xs = foldMap (\(name, val) -> pretty name <> indented val) xs - -instance Pretty SourceBlock where - pretty block = pretty $ ensureNewline (sbText block) where - -- Force the SourceBlock to end in a newline for echoing, even if - -- it was terminated with EOF in the original program. - ensureNewline t = case unsnoc t of - Nothing -> t - Just (_, '\n') -> t - _ -> t `snoc` '\n' - -instance Pretty Output where - pretty = \case - TextOut s -> pretty s - HtmlOut _ -> "" - SourceInfo _ -> "" - PassInfo _ s -> p s - MiscLog s -> p s - Error e -> p e - -instance Pretty PassName where - pretty x = p $ show x - -instance Pretty Result where - pretty (Result (Outputs outs) r) = vcat (map pretty outs) <> maybeErr - where maybeErr = case r of Failure err -> p err - Success () -> mempty - -instance Pretty (UBinder' c n l) where pretty = prettyFromPrettyPrec -instance PrettyPrec (UBinder' c n l) where - prettyPrec b = atPrec ArgPrec case b of - UBindSource v -> p v - UIgnore -> "_" - UBind v _ -> p v - -instance Pretty e => Pretty (WithSrcs e) where pretty (WithSrcs _ _ x) = p x -instance PrettyPrec e => PrettyPrec (WithSrcs e) where prettyPrec (WithSrcs _ _ x) = prettyPrec x - -instance Pretty e => Pretty (WithSrc e) where pretty (WithSrc _ x) = p x -instance PrettyPrec e => PrettyPrec (WithSrc e) where prettyPrec (WithSrc _ x) = prettyPrec x - -instance PrettyE e => Pretty (WithSrcE e n) where pretty (WithSrcE _ x) = p x -instance PrettyPrecE e => PrettyPrec (WithSrcE e n) where prettyPrec (WithSrcE _ x) = prettyPrec x - -instance PrettyB b => Pretty (WithSrcB b n l) where pretty (WithSrcB _ x) = p x -instance PrettyPrecB b => PrettyPrec (WithSrcB b n l) where prettyPrec (WithSrcB _ x) = prettyPrec x - -instance PrettyE e => Pretty (SourceNameOr e n) where - pretty (SourceName _ v) = p v - pretty (InternalName _ v _) = p v - -instance Pretty (SourceOrInternalName c n) where - pretty (SourceOrInternalName sn) = p sn - -instance Pretty (ULamExpr n) where pretty = prettyFromPrettyPrec -instance PrettyPrec (ULamExpr n) where - prettyPrec (ULamExpr bs _ _ _ body) = atPrec LowestPrec $ - "\\" <> p bs <+> "." <+> indented (p body) - -instance Pretty (UPiExpr n) where pretty = prettyFromPrettyPrec -instance PrettyPrec (UPiExpr n) where - prettyPrec (UPiExpr pats appExpl UPure ty) = atPrec LowestPrec $ align $ - p pats <+> p appExpl <+> pLowest ty - prettyPrec (UPiExpr pats appExpl eff ty) = atPrec LowestPrec $ align $ - p pats <+> p appExpl <+> p eff <+> pLowest ty - -instance Pretty Explicitness where - pretty expl = p (show expl) - -instance Pretty (UTabPiExpr n) where pretty = prettyFromPrettyPrec -instance PrettyPrec (UTabPiExpr n) where - prettyPrec (UTabPiExpr pat ty) = atPrec LowestPrec $ align $ - p pat <+> "=>" <+> pLowest ty - -instance Pretty (UDepPairType n) where pretty = prettyFromPrettyPrec -instance PrettyPrec (UDepPairType n) where - -- TODO: print explicitness info - prettyPrec (UDepPairType _ pat ty) = atPrec LowestPrec $ align $ - p pat <+> "&>" <+> pLowest ty - -instance Pretty (UBlock' n) where - pretty (UBlock decls result) = - prettyLines (fromNest decls) <> hardline <> pLowest result - -instance Pretty (UExpr' n) where pretty = prettyFromPrettyPrec -instance PrettyPrec (UExpr' n) where - prettyPrec expr = case expr of - ULit l -> prettyPrec l - UVar v -> atPrec ArgPrec $ p v - ULam lam -> prettyPrec lam - UApp f xs named -> atPrec AppPrec $ pAppArg (pApp f) xs <+> p named - UTabApp f x -> atPrec AppPrec $ pArg f <> "." <> pArg x - UFor dir (UForExpr binder body) -> - atPrec LowestPrec $ kw <+> p binder <> "." - <+> nest 2 (p body) - where kw = case dir of Fwd -> "for" - Rev -> "rof" - UPi piType -> prettyPrec piType - UTabPi piType -> prettyPrec piType - UDepPairTy depPairType -> prettyPrec depPairType - UDepPair lhs rhs -> atPrec ArgPrec $ parens $ - p lhs <+> ",>" <+> p rhs - UHole -> atPrec ArgPrec "_" - UTypeAnn v ty -> atPrec LowestPrec $ - group $ pApp v <> line <> ":" <+> pApp ty - UTabCon xs -> atPrec ArgPrec $ p xs - UPrim prim xs -> atPrec AppPrec $ p (show prim) <+> p xs - UCase e alts -> atPrec LowestPrec $ "case" <+> p e <> - nest 2 (prettyLines alts) - UFieldAccess x (WithSrc _ f) -> atPrec AppPrec $ p x <> "~" <> p f - UNatLit v -> atPrec ArgPrec $ p v - UIntLit v -> atPrec ArgPrec $ p v - UFloatLit v -> atPrec ArgPrec $ p v - UDo block -> atPrec LowestPrec $ p block - -instance Pretty FieldName' where - pretty = \case - FieldName s -> pretty s - FieldNum n -> pretty n - -instance Pretty (UAlt n) where - pretty (UAlt pat body) = p pat <+> "->" <+> p body - -instance Pretty (UTopDecl n l) where - pretty (UDataDefDecl (UDataDef nm bs dataCons) bTyCon bDataCons) = - "data" <+> p bTyCon <+> p nm <+> spaced (fromNest bs) <+> "where" <> nest 2 - (prettyLines (zip (toList $ fromNest bDataCons) dataCons)) - pretty (UStructDecl bTyCon (UStructDef nm bs fields defs)) = - "struct" <+> p bTyCon <+> p nm <+> spaced (fromNest bs) <+> "where" <> nest 2 - (prettyLines fields <> prettyLines defs) - pretty (UInterface params methodTys interfaceName methodNames) = - "interface" <+> p params <+> p interfaceName - <> hardline <> foldMap (<>hardline) methods - where - methods = [ p b <> ":" <> p (unsafeCoerceE ty) - | (b, ty) <- zip (toList $ fromNest methodNames) methodTys] - pretty (UInstance className bs params methods (RightB UnitB) _) = - "instance" <+> p bs <+> p className <+> spaced params <+> - prettyLines methods - pretty (UInstance className bs params methods (LeftB v) _) = - "named-instance" <+> p v <+> ":" <+> p bs <+> p className <+> p params - <> prettyLines methods - pretty (ULocalDecl decl) = p decl - -instance Pretty (UDecl' n l) where - pretty (ULet ann b _ rhs) = - align $ p ann <+> p b <+> "=" <> (nest 2 $ group $ line <> pLowest rhs) - pretty (UExprDecl expr) = p expr - pretty UPass = "pass" - -instance Pretty (UEffectRow n) where - pretty (UEffectRow x Nothing) = encloseSep "<" ">" "," $ (p <$> toList x) - pretty (UEffectRow x (Just y)) = "{" <> (hsep $ punctuate "," (p <$> toList x)) <+> "|" <+> p y <> "}" - -prettyBinderNest :: PrettyB b => Nest b n l -> Doc ann -prettyBinderNest bs = nest 6 $ line' <> (sep $ map p $ fromNest bs) - -instance Pretty (UDataDefTrail n) where - pretty (UDataDefTrail bs) = p $ fromNest bs - -instance Pretty (UAnnBinder n l) where - pretty (UAnnBinder _ b ty _) = p b <> ":" <> p ty - -instance Pretty (UAnn n) where - pretty (UAnn ty) = ":" <> p ty - pretty UNoAnn = mempty - -instance Pretty (UMethodDef' n) where - pretty (UMethodDef b rhs) = p b <+> "=" <+> p rhs - -instance Pretty (UPat' n l) where pretty = prettyFromPrettyPrec -instance PrettyPrec (UPat' n l) where - prettyPrec pat = case pat of - UPatBinder x -> atPrec ArgPrec $ p x - UPatProd xs -> atPrec ArgPrec $ parens $ commaSep (fromNest xs) - UPatDepPair (PairB x y) -> atPrec ArgPrec $ parens $ p x <> ",> " <> p y - UPatCon con pats -> atPrec AppPrec $ parens $ p con <+> spaced (fromNest pats) - UPatTable pats -> atPrec ArgPrec $ p pats +instance PrettyPrec () where prettyPrec = atPrec ArgPrec . pretty spaced :: (Foldable f, Pretty a) => f a -> Doc ann -spaced xs = hsep $ map p $ toList xs +spaced xs = hsep $ map pretty $ toList xs commaSep :: (Foldable f, Pretty a) => f a -> Doc ann -commaSep xs = fold $ punctuate "," $ map p $ toList xs - -instance Pretty (EnvFrag n l) where - pretty (EnvFrag bindings) = p bindings - -instance Pretty (Cache n) where - pretty (Cache _ _ _ _ _ _) = "" -- TODO - -instance Pretty (SynthCandidates n) where - pretty scs = "instance dicts:" <+> p (M.toList $ instanceDicts scs) - -instance Pretty (LoadedModules n) where - pretty _ = "" +commaSep xs = fold $ punctuate "," $ map pretty $ toList xs indented :: Doc ann -> Doc ann indented doc = nest 2 (hardline <> doc) <> hardline - --- ==== Imp IR === - -instance Pretty (IExpr n) where - pretty (ILit v) = p v - pretty (IVar v _) = p v - pretty (IPtrVar v _) = p v - -instance PrettyPrec (IExpr n) where prettyPrec = atPrec ArgPrec . pretty - -instance Pretty (ImpDecl n l) where - pretty (ImpLet Empty instr) = p instr - pretty (ImpLet (Nest b Empty) instr) = p b <+> "=" <+> p instr - pretty (ImpLet bs instr) = p bs <+> "=" <+> p instr - -instance Pretty IFunType where - pretty (IFunType cc argTys retTys) = - "Fun" <+> p cc <+> p argTys <+> "->" <+> p retTys - -instance Pretty (TopFunDef n) where - pretty = \case - Specialization s -> p s - LinearizationPrimal _ -> "" - LinearizationTangent _ -> "" - -instance Pretty (TopFun n) where - pretty = \case - DexTopFun def lam lowering -> - "Top-level Function" - <> hardline <+> "definition:" <+> pretty def - <> hardline <+> "lambda:" <+> pretty lam - <> hardline <+> "lowering:" <+> pretty lowering - FFITopFun f _ -> p f - -instance IRRep r => Pretty (TopLam r n) where - pretty (TopLam _ _ lam) = pretty lam - -instance Pretty a => Pretty (EvalStatus a) where - pretty = \case - Waiting -> "" - Running -> "" - Finished a -> pretty a - -instance Pretty (ImpFunction n) where - pretty (ImpFunction (IFunType cc _ _) (Abs bs body)) = - "impfun" <+> p cc <+> prettyBinderNest bs - <> nest 2 (hardline <> p body) <> hardline - -instance Pretty (ImpBlock n) where - pretty (ImpBlock Empty []) = mempty - pretty (ImpBlock Empty expr) = group $ line <> pLowest expr - pretty (ImpBlock decls []) = prettyLines $ fromNest decls - pretty (ImpBlock decls expr) = prettyLines decls' <> hardline <> pLowest expr - where decls' = fromNest decls - -instance Pretty (IBinder n l) where - pretty (IBinder b ty) = p b <+> ":" <+> p ty - -instance Pretty (ImpInstr n) where - pretty = \case - IFor a n (Abs i block) -> forStr a <+> p i <+> "<" <+> p n <> - nest 4 (p block) - IWhile body -> "while" <+> nest 2 (p body) - ICond predicate cons alt -> - "if" <+> p predicate <+> "then" <> nest 2 (p cons) <> - hardline <> "else" <> nest 2 (p alt) - IQueryParallelism f s -> "queryParallelism" <+> p f <+> p s - ILaunch f size args -> - "launch" <+> p f <+> p size <+> spaced args - ICastOp t x -> "cast" <+> p x <+> "to" <+> p t - IBitcastOp t x -> "bitcast" <+> p x <+> "to" <+> p t - Store dest val -> "store" <+> p dest <+> p val - Alloc _ t s -> "alloc" <+> p t <> "[" <> sizeStr s <> "]" - StackAlloc t s -> "alloca" <+> p t <> "[" <> sizeStr s <> "]" - MemCopy dest src numel -> "memcopy" <+> p dest <+> p src <+> p numel - InitializeZeros ptr numel -> "initializeZeros" <+> p ptr <+> p numel - GetAllocSize ptr -> "getAllocSize" <+> p ptr - Free ptr -> "free" <+> p ptr - ISyncWorkgroup -> "syncWorkgroup" - IThrowError -> "throwError" - ICall f args -> "call" <+> p f <+> p args - IVectorBroadcast v _ -> "vbroadcast" <+> p v - IVectorIota _ -> "viota" - DebugPrint s x -> "debug_print" <+> p (show s) <+> p x - IPtrLoad ptr -> "load" <+> p ptr - IPtrOffset ptr idx -> p ptr <+> "+>" <+> p idx - IBinOp op x y -> opDefault (UBinOp op) [x, y] - IUnOp op x -> opDefault (UUnOp op) [x] - ISelect x y z -> "select" <+> p x <+> p y <+> p z - IOutputStream -> "outputStream" - IShowScalar ptr x -> "show_scalar" <+> p ptr <+> p x - where opDefault name xs = prettyOpDefault name xs $ AppPrec - -sizeStr :: IExpr n -> Doc ann -sizeStr s = case s of - ILit (Word32Lit x) -> p x -- print in decimal because it's more readable - _ -> p s - -instance Pretty BaseType where pretty = prettyFromPrettyPrec -instance PrettyPrec BaseType where - prettyPrec b = case b of - Scalar sb -> prettyPrec sb - Vector shape sb -> atPrec ArgPrec $ encloseSep "<" ">" "x" $ (p <$> shape) ++ [p sb] - PtrType ty -> atPrec AppPrec $ "Ptr" <+> p ty - -instance Pretty AddressSpace where pretty d = p (show d) - -instance Pretty ScalarBaseType where pretty = prettyFromPrettyPrec -instance PrettyPrec ScalarBaseType where - prettyPrec sb = atPrec ArgPrec $ case sb of - Int64Type -> "Int64" - Int32Type -> "Int32" - Float64Type -> "Float64" - Float32Type -> "Float32" - Word8Type -> "Word8" - Word32Type -> "Word32" - Word64Type -> "Word64" - -instance IRRep r => Pretty (TyCon r n) where pretty = prettyFromPrettyPrec -instance IRRep r => PrettyPrec (TyCon r n) where - prettyPrec con = case con of - BaseType b -> prettyPrec b - ProdType [] -> atPrec ArgPrec $ "()" - ProdType as -> atPrec ArgPrec $ align $ group $ - encloseSep "(" ")" ", " $ fmap pApp as - SumType cs -> atPrec ArgPrec $ align $ group $ - encloseSep "(|" "|)" " | " $ fmap pApp cs - RefType h a -> atPrec AppPrec $ pAppArg "Ref" [h] <+> p a - TypeKind -> atPrec ArgPrec "Type" - HeapType -> atPrec ArgPrec "Heap" - Pi piType -> atPrec LowestPrec $ align $ p piType - TabPi piType -> atPrec LowestPrec $ align $ p piType - DepPairTy ty -> prettyPrec ty - DictTy t -> atPrec LowestPrec $ p t - NewtypeTyCon con' -> prettyPrec con' - -prettyPrecNewtype :: NewtypeCon n -> CAtom n -> DocPrec ann -prettyPrecNewtype con x = case (con, x) of - (NatCon, (IdxRepVal n)) -> atPrec ArgPrec $ pretty n - (_, x') -> prettyPrec x' - -instance Pretty (NewtypeTyCon n) where pretty = prettyFromPrettyPrec -instance PrettyPrec (NewtypeTyCon n) where - prettyPrec = \case - Nat -> atPrec ArgPrec $ "Nat" - Fin n -> atPrec AppPrec $ "Fin" <+> pArg n - EffectRowKind -> atPrec ArgPrec "EffKind" - UserADTType "RangeTo" _ (TyConParams _ [i]) -> atPrec LowestPrec $ ".." <> pApp i - UserADTType "RangeToExc" _ (TyConParams _ [i]) -> atPrec LowestPrec $ "..<" <> pApp i - UserADTType "RangeFrom" _ (TyConParams _ [i]) -> atPrec LowestPrec $ pApp i <> ".." - UserADTType "RangeFromExc" _ (TyConParams _ [i]) -> atPrec LowestPrec $ pApp i <> "<.." - UserADTType name _ (TyConParams infs params) -> case (infs, params) of - ([], []) -> atPrec ArgPrec $ p name - ([Explicit, Explicit], [l, r]) - | Just sym <- fromInfix (fromString $ pprint name) -> - atPrec ArgPrec $ align $ group $ - parens $ flatAlt " " "" <> pApp l <> line <> p sym <+> pApp r - _ -> atPrec LowestPrec $ pAppArg (p name) $ ignoreSynthParams (TyConParams infs params) - -instance IRRep r => Pretty (Con r n) where pretty = prettyFromPrettyPrec -instance IRRep r => PrettyPrec (Con r n) where - prettyPrec = \case - Lit l -> prettyPrec l - ProdCon [x] -> atPrec ArgPrec $ "(" <> pLowest x <> ",)" - ProdCon xs -> atPrec ArgPrec $ align $ group $ - encloseSep "(" ")" ", " $ fmap pLowest xs - SumCon _ tag payload -> atPrec ArgPrec $ - "(" <> p tag <> "|" <+> pApp payload <+> "|)" - HeapVal -> atPrec ArgPrec "HeapValue" - Lam lam -> atPrec LowestPrec $ p lam - DepPair x y _ -> atPrec ArgPrec $ align $ group $ - parens $ p x <+> ",>" <+> p y - Eff e -> atPrec ArgPrec $ p e - DictConAtom d -> atPrec LowestPrec $ p d - NewtypeCon con x -> prettyPrecNewtype con x - TyConAtom ty -> prettyPrec ty - -instance IRRep r => Pretty (PrimOp r n) where pretty = prettyFromPrettyPrec -instance IRRep r => PrettyPrec (PrimOp r n) where - prettyPrec = \case - MemOp op -> prettyPrec op - VectorOp op -> prettyPrec op - DAMOp op -> prettyPrec op - Hof (TypedHof _ hof) -> prettyPrec hof - RefOp ref eff -> atPrec LowestPrec case eff of - MAsk -> "ask" <+> pApp ref - MExtend _ x -> "extend" <+> pApp ref <+> pApp x - MGet -> "get" <+> pApp ref - MPut x -> pApp ref <+> ":=" <+> pApp x - IndexRef _ i -> pApp ref <+> "!" <+> pApp i - ProjRef _ i -> "proj_ref" <+> pApp ref <+> p i - UnOp op x -> prettyOpDefault (UUnOp op) [x] - BinOp op x y -> prettyOpDefault (UBinOp op) [x, y] - MiscOp op -> prettyOpGeneric op - -instance IRRep r => Pretty (MemOp r n) where pretty = prettyFromPrettyPrec -instance IRRep r => PrettyPrec (MemOp r n) where - prettyPrec = \case - PtrOffset ptr idx -> atPrec LowestPrec $ pApp ptr <+> "+>" <+> pApp idx - PtrLoad ptr -> atPrec AppPrec $ pAppArg "load" [ptr] - op -> prettyOpGeneric op - -instance IRRep r => Pretty (VectorOp r n) where pretty = prettyFromPrettyPrec -instance IRRep r => PrettyPrec (VectorOp r n) where - prettyPrec = \case - VectorBroadcast v vty -> atPrec LowestPrec $ "vbroadcast" <+> pApp v <+> pApp vty - VectorIota vty -> atPrec LowestPrec $ "viota" <+> pApp vty - VectorIdx tbl i vty -> atPrec LowestPrec $ "vslice" <+> pApp tbl <+> pApp i <+> pApp vty - VectorSubref ref i _ -> atPrec LowestPrec $ "vrefslice" <+> pApp ref <+> pApp i - -prettyOpDefault :: PrettyPrec a => PrimName -> [a] -> DocPrec ann -prettyOpDefault name args = - case length args of - 0 -> atPrec ArgPrec primName - _ -> atPrec AppPrec $ pAppArg primName args - where primName = p name - -prettyOpGeneric :: (IRRep r, GenericOp op, Show (OpConst op r)) => op r n -> DocPrec ann -prettyOpGeneric op = case fromEGenericOpRep op of - GenericOpRep op' [] [] [] -> atPrec ArgPrec (p $ show op') - GenericOpRep op' ts xs lams -> atPrec AppPrec $ pAppArg (p (show op')) xs <+> p ts <+> p lams - -instance Pretty PrimName where - pretty primName = p $ "%" ++ showPrimName primName - -instance IRRep r => Pretty (Hof r n) where pretty = prettyFromPrettyPrec -instance IRRep r => PrettyPrec (Hof r n) where - prettyPrec hof = atPrec LowestPrec case hof of - For _ _ lam -> "for" <+> pLowest lam - While body -> "while" <+> pArg body - RunReader x body -> "runReader" <+> pArg x <> nest 2 (line <> p body) - RunWriter _ bm body -> "runWriter" <+> pArg bm <> nest 2 (line <> p body) - RunState _ x body -> "runState" <+> pArg x <> nest 2 (line <> p body) - RunIO body -> "runIO" <+> pArg body - RunInit body -> "runInit" <+> pArg body - CatchException _ body -> "catchException" <+> pArg body - Linearize body x -> "linearize" <+> pArg body <+> pArg x - Transpose body x -> "transpose" <+> pArg body <+> pArg x - -instance IRRep r => Pretty (DAMOp r n) where pretty = prettyFromPrettyPrec -instance IRRep r => PrettyPrec (DAMOp r n) where - prettyPrec op = atPrec LowestPrec case op of - Seq _ ann _ c lamExpr -> case lamExpr of - UnaryLamExpr b body -> do - "seq" <+> pApp ann <+> pApp c <+> prettyLam (p b <> ".") body - _ -> p (show op) -- shouldn't happen, but crashing pretty printers make debugging hard - RememberDest _ x y -> "rememberDest" <+> pArg x <+> pArg y - Place r v -> pApp r <+> "r:=" <+> pApp v - Freeze r -> "freeze" <+> pApp r - AllocDest ty -> "alloc" <+> pApp ty - -instance IRRep r => Pretty (BaseMonoid r n) where pretty = prettyFromPrettyPrec -instance IRRep r => PrettyPrec (BaseMonoid r n) where - prettyPrec (BaseMonoid x f) = - atPrec LowestPrec $ "baseMonoid" <+> pArg x <> nest 2 (line <> pArg f) - -instance PrettyPrec Direction where - prettyPrec d = atPrec ArgPrec $ case d of - Fwd -> "fwd" - Rev -> "rev" - -printDouble :: Double -> Doc ann -printDouble x = p (double2Float x) - -printFloat :: Float -> Doc ann -printFloat x = p $ reverse $ dropWhile (=='0') $ reverse $ - showFFloat (Just 6) x "" - -instance Pretty LitVal where pretty = prettyFromPrettyPrec -instance PrettyPrec LitVal where - prettyPrec (Int64Lit x) = atPrec ArgPrec $ p x - prettyPrec (Int32Lit x) = atPrec ArgPrec $ p x - prettyPrec (Float64Lit x) = atPrec ArgPrec $ printDouble x - prettyPrec (Float32Lit x) = atPrec ArgPrec $ printFloat x - prettyPrec (Word8Lit x) = atPrec ArgPrec $ p $ show $ toEnum @Char $ fromIntegral x - prettyPrec (Word32Lit x) = atPrec ArgPrec $ p $ "0x" ++ showHex x "" - prettyPrec (Word64Lit x) = atPrec ArgPrec $ p $ "0x" ++ showHex x "" - prettyPrec (PtrLit ty (PtrLitVal x)) = - atPrec ArgPrec $ "Ptr" <+> p ty <+> p (show x) - prettyPrec (PtrLit _ NullPtr) = atPrec ArgPrec $ "NullPtr" - prettyPrec (PtrLit _ (PtrSnapshot _)) = atPrec ArgPrec "" - -instance Pretty CallingConvention where - pretty = p . show - -instance Pretty LetAnn where - pretty ann = case ann of - PlainLet -> "" - InlineLet -> "%inline" - NoInlineLet -> "%noinline" - LinearLet -> "%linear" - OccInfoPure u -> p u <> line - OccInfoImpure u -> p u <> ", impure" <> line - -instance Pretty UsageInfo where - pretty (UsageInfo static (ixDepth, ct)) = - "occurs in" <+> p static <+> "places, read" - <+> p ct <+> "times, to depth" <+> p (show ixDepth) - -instance Pretty Count where - pretty (Bounded ct) = "<=" <+> pretty ct - pretty Occ.Unbounded = "many" - -instance PrettyPrec () where prettyPrec = atPrec ArgPrec . pretty - -instance Pretty RWS where - pretty eff = case eff of - Reader -> "Read" - Writer -> "Accum" - State -> "State" - -printOutput :: Bool -> Output -> String -printOutput isatty out = case out of - Error _ -> addColor isatty Red $ addPrefix ">" $ pprint out - _ -> addPrefix (addColor isatty Cyan ">") $ pprint $ out - -addPrefix :: String -> String -> String -addPrefix prefix str = unlines $ map prefixLine $ lines str - where prefixLine :: String -> String - prefixLine s = case s of "" -> prefix - _ -> prefix ++ " " ++ s - -addColor :: Bool -> ANSI.Color -> String -> String -addColor False _ s = s -addColor True c s = - setSGRCode [SetConsoleIntensity BoldIntensity, SetColor Foreground Vivid c] - ++ s ++ setSGRCode [Reset] - --- === Concrete syntax rendering === - -instance Pretty SourceBlock' where - pretty (TopDecl decl) = p decl - pretty d = fromString $ show d - -instance Pretty CTopDecl where - pretty (CSDecl ann decl) = annDoc <> p decl - where annDoc = case ann of - PlainLet -> mempty - _ -> p ann <> " " - pretty d = fromString $ show d - -instance Pretty CSDecl where - pretty = undefined - -- pretty (CLet pat blk) = pArg pat <+> "=" <+> p blk - -- pretty (CBind pat blk) = pArg pat <+> "<-" <+> p blk - -- pretty (CDefDecl (CDef name args maybeAnn blk)) = - -- "def " <> fromString name <> " " <> prettyParamGroups args <+> annDoc - -- <> nest 2 (hardline <> p blk) - -- where annDoc = case maybeAnn of Just (expl, ty) -> p expl <+> pArg ty - -- Nothing -> mempty - -- pretty (CInstance header givens methods name) = - -- name' <> p header <> p givens <> nest 2 (hardline <> p methods) where - -- name' = case name of - -- Nothing -> "instance " - -- (Just n) -> "named-instance " <> p n <> " " - -- pretty (CExpr e) = p e - -instance Pretty AppExplicitness where - pretty ExplicitApp = "->" - pretty ImplicitApp = "->>" - -instance Pretty CSBlock where - pretty (IndentedBlock _ decls) = nest 2 $ prettyLines decls - pretty (ExprBlock g) = pArg g - -instance Pretty Group where pretty = prettyFromPrettyPrec -instance PrettyPrec Group where - prettyPrec = undefined - -- prettyPrec (CIdentifier n) = atPrec ArgPrec $ fromString n - -- prettyPrec (CPrim prim args) = prettyOpDefault prim args - -- prettyPrec (CParens blk) = - -- atPrec ArgPrec $ "(" <> p blk <> ")" - -- prettyPrec (CBrackets g) = atPrec ArgPrec $ pretty g - -- prettyPrec (CBin op lhs rhs) = - -- atPrec LowestPrec $ pArg lhs <+> p op <+> pArg rhs - -- prettyPrec (CLambda args body) = - -- atPrec LowestPrec $ "\\" <> spaced args <> "." <> p body - -- prettyPrec (CCase scrut alts) = - -- atPrec LowestPrec $ "case " <> p scrut <> " of " <> prettyLines alts - -- prettyPrec g = atPrec ArgPrec $ fromString $ show g - -instance Pretty Bin where - pretty (EvalBinOp name) = pretty name - pretty DepAmpersand = "&>" - pretty Dot = "." - pretty DepComma = ",>" - pretty Colon = ":" - pretty DoubleColon = "::" - pretty Dollar = "$" - pretty ImplicitArrow = "->>" - pretty FatArrow = "=>" - pretty Pipe = "|" - pretty CSEqual = "=" diff --git a/src/lib/QueryType.hs b/src/lib/QueryType.hs index 79214c57a..a6e3b4b52 100644 --- a/src/lib/QueryType.hs +++ b/src/lib/QueryType.hs @@ -15,6 +15,7 @@ import Data.Functor ((<&>)) import Types.Primitives import Types.Core import Types.Source +import Types.Top import Types.Imp import IRVariants import Core diff --git a/src/lib/QueryTypePure.hs b/src/lib/QueryTypePure.hs index 45a080f79..153ed5449 100644 --- a/src/lib/QueryTypePure.hs +++ b/src/lib/QueryTypePure.hs @@ -8,6 +8,7 @@ module QueryTypePure where import Types.Primitives import Types.Core +import Types.Top import IRVariants import Name diff --git a/src/lib/Runtime.hs b/src/lib/Runtime.hs index 1bac0c11c..20730e332 100644 --- a/src/lib/Runtime.hs +++ b/src/lib/Runtime.hs @@ -29,7 +29,7 @@ import Err import MonadUtil import PPrint () -import Types.Core hiding (DexDestructor) +import Types.Top hiding (DexDestructor) import Types.Source hiding (CInt) import Types.Primitives diff --git a/src/lib/Simplify.hs b/src/lib/Simplify.hs index 129039bdb..789ccb59c 100644 --- a/src/lib/Simplify.hs +++ b/src/lib/Simplify.hs @@ -31,6 +31,7 @@ import RuntimePrint import Transpose import Types.Core import Types.Source +import Types.Top import Types.Primitives import Util (enumerate) diff --git a/src/lib/Simplify.hs-boot b/src/lib/Simplify.hs-boot index c14ae648a..8e1499c3d 100644 --- a/src/lib/Simplify.hs-boot +++ b/src/lib/Simplify.hs-boot @@ -9,5 +9,6 @@ module Simplify (linearizeTopFun) where import Name import Builder import Types.Core +import Types.Top linearizeTopFun :: (Mut n, Fallible1 m, TopBuilder m) => LinearizationSpec n -> m n (TopFunName n, TopFunName n) diff --git a/src/lib/SourceRename.hs b/src/lib/SourceRename.hs index ee2b6f9f4..c6b68d82d 100644 --- a/src/lib/SourceRename.hs +++ b/src/lib/SourceRename.hs @@ -23,7 +23,7 @@ import PPrint () import IRVariants import Types.Source import Types.Primitives -import Types.Core (Env (..), ModuleEnv (..)) +import Types.Top (Env (..), ModuleEnv (..)) renameSourceNamesTopUDecl :: (Fallible1 m, EnvReader m) diff --git a/src/lib/Subst.hs b/src/lib/Subst.hs index 06265b78a..b8124d360 100644 --- a/src/lib/Subst.hs +++ b/src/lib/Subst.hs @@ -19,6 +19,7 @@ import Name import MTL1 import IRVariants import Types.Core +import Types.Top import Core import qualified RawName as R import QueryTypePure diff --git a/src/lib/TopLevel.hs b/src/lib/TopLevel.hs index 78dafb62b..69dc417dd 100644 --- a/src/lib/TopLevel.hs +++ b/src/lib/TopLevel.hs @@ -63,7 +63,6 @@ import Subst import Name import OccAnalysis import Optimize -import PPrint (pprintCanonicalized) import Paths_dex (getDataFileName) import QueryType import Runtime @@ -75,6 +74,7 @@ import Types.Core import Types.Imp import Types.Primitives import Types.Source +import Types.Top import Util ( Tree (..), File (..), readFileWithHash) import Vectorize diff --git a/src/lib/Transpose.hs b/src/lib/Transpose.hs index 10c87d377..e35305bc6 100644 --- a/src/lib/Transpose.hs +++ b/src/lib/Transpose.hs @@ -20,6 +20,7 @@ import Name import Subst import QueryType import Types.Core +import Types.Top import Types.Primitives import Util (enumerate) @@ -36,9 +37,7 @@ transpose lam ct = liftEmitBuilder $ runTransposeM do runTransposeM :: TransposeM n n a -> BuilderM SimpIR n a runTransposeM cont = runSubstReaderT idSubst $ cont -transposeTopFun - :: (MonadFail1 m, EnvReader m) - => STopLam n -> m n (STopLam n) +transposeTopFun :: (MonadFail1 m, EnvReader m) => STopLam n -> m n (STopLam n) transposeTopFun (TopLam False _ lam) = liftBuilder $ runTransposeM do (Abs bsNonlin (Abs bLin body), Abs bsNonlin'' outTy) <- unpackLinearLamExpr lam refreshBinders bsNonlin \bsNonlin' substFrag -> extendRenamer substFrag do diff --git a/src/lib/Types/Core.hs b/src/lib/Types/Core.hs index 0475f6ac8..daee75118 100644 --- a/src/lib/Types/Core.hs +++ b/src/lib/Types/Core.hs @@ -4,20 +4,8 @@ -- license that can be found in the LICENSE file or at -- https://developers.google.com/open-source/licenses/bsd -{-# LANGUAGE DeriveGeneric #-} -{-# LANGUAGE FlexibleContexts #-} -{-# LANGUAGE FlexibleInstances #-} -{-# LANGUAGE StrictData #-} -{-# LANGUAGE DeriveFunctor #-} -{-# LANGUAGE ViewPatterns #-} -{-# LANGUAGE TypeFamilies #-} -{-# LANGUAGE InstanceSigs #-} {-# LANGUAGE UndecidableInstances #-} -{-# LANGUAGE StandaloneDeriving #-} -{-# LANGUAGE DerivingStrategies #-} -{-# LANGUAGE DerivingVia #-} -{-# LANGUAGE DefaultSignatures #-} -{-# LANGUAGE ConstraintKinds #-} +{-# LANGUAGE StrictData #-} -- Core data types for CoreIR and its variations. @@ -25,19 +13,20 @@ module Types.Core (module Types.Core, SymbolicZeros (..)) where import Data.Word import Data.Maybe (fromJust) -import Data.Functor +import Data.Foldable (toList) import Data.Hashable -import Data.Text.Prettyprint.Doc hiding (nest) +import Data.String (fromString) +import Data.Text.Prettyprint.Doc +import Data.Text (Text, unsnoc, uncons) import qualified Data.Map.Strict as M -import qualified Data.Set as S import GHC.Generics (Generic (..)) import Data.Store (Store (..)) -import Foreign.Ptr import Name -import Util (FileHash, SnocList (..), Tree (..)) +import Util (Tree (..)) import IRVariants +import PPrint import qualified Types.OpNames as P import Types.Primitives @@ -141,6 +130,9 @@ data BaseMonoid r n = , baseCombine :: LamExpr r n } deriving (Show, Generic) +data RepVal (n::S) = RepVal (SType n) (Tree (IExpr n)) + deriving (Show, Generic) + data DeclBinding r n = DeclBinding LetAnn (Expr r n) deriving (Show, Generic) data Decl (r::IR) (n::S) (l::S) = Let (AtomNameBinder r n l) (DeclBinding r n) @@ -204,11 +196,6 @@ data TyConParams n = TyConParams [Explicitness] [Atom CoreIR n] type WithDecls (r::IR) = Abs (Decls r) :: E -> E type Block (r::IR) = WithDecls r (Expr r) :: E -type TopBlock = TopLam -- used for nullary lambda -type IsDestLam = Bool -data TopLam (r::IR) (n::S) = TopLam IsDestLam (PiType r n) (LamExpr r n) - deriving (Show, Generic) - data LamExpr (r::IR) (n::S) where LamExpr :: Nest (Binder r) n l -> Expr r l -> LamExpr r n @@ -282,9 +269,6 @@ instance ToBindersAbs TyConDef DataConDefs CoreIR where instance ToBindersAbs ClassDef (Abs (Nest CBinder) (ListE CorePiType)) CoreIR where toAbs (ClassDef _ _ _ _ _ bs scBs tys) = Abs bs (Abs scBs (ListE tys)) -instance ToBindersAbs (TopLam r) (Expr r) r where - toAbs (TopLam _ _ lam) = toAbs lam - -- === GenericOp class === class GenericOp (e::IR->E) where @@ -434,7 +418,6 @@ type CDecl = Decl CoreIR type CDecls = Decls CoreIR type CAtomName = AtomName CoreIR type CAtomVar = AtomVar CoreIR -type CTopLam = TopLam CoreIR type SAtom = Atom SimpIR type SType = Type SimpIR @@ -449,7 +432,6 @@ type SAtomName = AtomName SimpIR type SAtomVar = AtomVar SimpIR type SBinder = Binder SimpIR type SLam = LamExpr SimpIR -type STopLam = TopLam SimpIR -- === newtypes === @@ -522,174 +504,6 @@ data DictCon (r::IR) (n::S) where IxRawFin :: Atom r n -> DictCon r n IxSpecialized :: SpecDictName n -> [SAtom n] -> DictCon SimpIR n --- TODO: Use an IntMap -newtype CustomRules (n::S) = - CustomRules { customRulesMap :: M.Map (AtomName CoreIR n) (AtomRules n) } - deriving (Semigroup, Monoid, Store) -data AtomRules (n::S) = - -- number of implicit args, number of explicit args, linearization function - CustomLinearize Int Int SymbolicZeros (CAtom n) - deriving (Generic) - --- === Runtime representations === - -data RepVal (n::S) = RepVal (SType n) (Tree (IExpr n)) - deriving (Show, Generic) - --- === envs and modules === - --- `ModuleEnv` contains data that only makes sense in the context of evaluating --- a particular module. `TopEnv` contains everything that makes sense "between" --- evaluating modules. -data Env n = Env - { topEnv :: {-# UNPACK #-} TopEnv n - , moduleEnv :: {-# UNPACK #-} ModuleEnv n } - deriving (Generic) - -data TopEnv (n::S) = TopEnv - { envDefs :: RecSubst Binding n - , envCustomRules :: CustomRules n - , envCache :: Cache n - , envLoadedModules :: LoadedModules n - , envLoadedObjects :: LoadedObjects n } - deriving (Generic) - -data SerializedEnv n = SerializedEnv - { serializedEnvDefs :: RecSubst Binding n - , serializedEnvCustomRules :: CustomRules n - , serializedEnvCache :: Cache n } - deriving (Generic) - --- TODO: consider splitting this further into `ModuleEnv` (the env that's --- relevant between top-level decls) and `LocalEnv` (the additional parts of the --- env that's relevant under a lambda binder). Unlike the Top/Module --- distinction, there's some overlap. For example, instances can be defined at --- both the module-level and local level. Similarly, if we start allowing --- top-level effects in `Main` then we'll have module-level effects and local --- effects. -data ModuleEnv (n::S) = ModuleEnv - { envImportStatus :: ImportStatus n - , envSourceMap :: SourceMap n - , envSynthCandidates :: SynthCandidates n } - deriving (Generic) - -data Module (n::S) = Module - { moduleSourceName :: ModuleSourceName - , moduleDirectDeps :: S.Set (ModuleName n) - , moduleTransDeps :: S.Set (ModuleName n) -- XXX: doesn't include the module itself - , moduleExports :: SourceMap n - -- these are just the synth candidates required by this - -- module by itself. We'll usually also need those required by the module's - -- (transitive) dependencies, which must be looked up separately. - , moduleSynthCandidates :: SynthCandidates n } - deriving (Show, Generic) - -data LoadedModules (n::S) = LoadedModules - { fromLoadedModules :: M.Map ModuleSourceName (ModuleName n)} - deriving (Show, Generic) - -emptyModuleEnv :: ModuleEnv n -emptyModuleEnv = ModuleEnv emptyImportStatus (SourceMap mempty) mempty - -emptyLoadedModules :: LoadedModules n -emptyLoadedModules = LoadedModules mempty - -data LoadedObjects (n::S) = LoadedObjects - -- the pointer points to the actual runtime function - { fromLoadedObjects :: M.Map (FunObjCodeName n) NativeFunction} - deriving (Show, Generic) - -emptyLoadedObjects :: LoadedObjects n -emptyLoadedObjects = LoadedObjects mempty - -data ImportStatus (n::S) = ImportStatus - { directImports :: S.Set (ModuleName n) - -- XXX: This are cached for efficiency. It's derivable from `directImports`. - , transImports :: S.Set (ModuleName n) } - deriving (Show, Generic) - -data TopEnvFrag n l = TopEnvFrag (EnvFrag n l) (ModuleEnv l) (SnocList (TopEnvUpdate l)) - -data TopEnvUpdate n = - ExtendCache (Cache n) - | AddCustomRule (CAtomName n) (AtomRules n) - | UpdateLoadedModules ModuleSourceName (ModuleName n) - | UpdateLoadedObjects (FunObjCodeName n) NativeFunction - | FinishDictSpecialization (SpecDictName n) [TopLam SimpIR n] - | LowerDictSpecialization (SpecDictName n) [TopLam SimpIR n] - | UpdateTopFunEvalStatus (TopFunName n) (TopFunEvalStatus n) - | UpdateInstanceDef (InstanceName n) (InstanceDef n) - | UpdateTyConDef (TyConName n) (TyConDef n) - | UpdateFieldDef (TyConName n) SourceName (CAtomName n) - --- TODO: we could add a lot more structure for querying by dict type, caching, etc. -data SynthCandidates n = SynthCandidates - { instanceDicts :: M.Map (ClassName n) [InstanceName n] - , ixInstances :: [InstanceName n] } - deriving (Show, Generic) - -emptyImportStatus :: ImportStatus n -emptyImportStatus = ImportStatus mempty mempty - --- TODO: figure out the additional top-level context we need -- backend, other --- compiler flags etc. We can have a map from those to this. - -data Cache (n::S) = Cache - { specializationCache :: EMap SpecializationSpec TopFunName n - , ixDictCache :: EMap AbsDict SpecDictName n - , linearizationCache :: EMap LinearizationSpec (PairE TopFunName TopFunName) n - , transpositionCache :: EMap TopFunName TopFunName n - -- This is memoizing `parseAndGetDeps :: Text -> [ModuleSourceName]`. But we - -- only want to store one entry per module name as a simple cache eviction - -- policy, so we store it keyed on the module name, with the text hash for - -- the validity check. - , parsedDeps :: M.Map ModuleSourceName (FileHash, [ModuleSourceName]) - , moduleEvaluations :: M.Map ModuleSourceName ((FileHash, [ModuleName n]), ModuleName n) - } deriving (Show, Generic) - --- === runtime function and variable representations === - -type RuntimeEnv = DynamicVarKeyPtrs - -type DexDestructor = FunPtr (IO ()) - -data NativeFunction = NativeFunction - { nativeFunPtr :: FunPtr () - , nativeFunTeardown :: IO () } - -instance Show NativeFunction where - show _ = "" - --- Holds pointers to thread-local storage used to simulate dynamically scoped --- variables, such as the output stream file descriptor. -type DynamicVarKeyPtrs = [(DynamicVar, Ptr ())] - -data DynamicVar = OutStreamDyvar -- TODO: add others as needed - deriving (Enum, Bounded) - -dynamicVarCName :: DynamicVar -> String -dynamicVarCName OutStreamDyvar = "dex_out_stream_dyvar" - -dynamicVarLinkMap :: DynamicVarKeyPtrs -> [(String, Ptr ())] -dynamicVarLinkMap dyvars = dyvars <&> \(v, ptr) -> (dynamicVarCName v, ptr) - --- === bindings - static information we carry about a lexical scope === - --- TODO: consider making this an open union via a typeable-like class -data Binding (c::C) (n::S) where - AtomNameBinding :: AtomBinding r n -> Binding (AtomNameC r) n - TyConBinding :: Maybe (TyConDef n) -> DotMethods n -> Binding TyConNameC n - DataConBinding :: TyConName n -> Int -> Binding DataConNameC n - ClassBinding :: ClassDef n -> Binding ClassNameC n - InstanceBinding :: InstanceDef n -> CorePiType n -> Binding InstanceNameC n - MethodBinding :: ClassName n -> Int -> Binding MethodNameC n - TopFunBinding :: TopFun n -> Binding TopFunNameC n - FunObjCodeBinding :: CFunction n -> Binding FunObjCodeNameC n - ModuleBinding :: Module n -> Binding ModuleNameC n - -- TODO: add a case for abstracted pointers, as used in `ClosedImpFunction` - PtrBinding :: PtrType -> PtrLitVal -> Binding PtrNameC n - SpecializedDictBinding :: SpecializedDictDef n -> Binding SpecializedDictNameC n - ImpNameBinding :: BaseType -> Binding ImpNameC n data EffectOpDef (n::S) where EffectOpDef :: EffectName n -- name of associated effect @@ -748,108 +562,6 @@ instance RenameE EffectOpType deriving instance Show (EffectOpType n) deriving via WrapE EffectOpType n instance Generic (EffectOpType n) -instance GenericE SpecializedDictDef where - type RepE SpecializedDictDef = AbsDict `PairE` MaybeE (ListE (TopLam SimpIR)) - fromE (SpecializedDict ab methods) = ab `PairE` methods' - where methods' = case methods of Just xs -> LeftE (ListE xs) - Nothing -> RightE UnitE - {-# INLINE fromE #-} - toE (ab `PairE` methods) = SpecializedDict ab methods' - where methods' = case methods of LeftE (ListE xs) -> Just xs - RightE UnitE -> Nothing - {-# INLINE toE #-} - -instance SinkableE SpecializedDictDef -instance HoistableE SpecializedDictDef -instance AlphaEqE SpecializedDictDef -instance AlphaHashableE SpecializedDictDef -instance RenameE SpecializedDictDef - -data EvalStatus a = Waiting | Running | Finished a - deriving (Show, Eq, Ord, Generic, Functor, Foldable, Traversable) -type TopFunEvalStatus n = EvalStatus (TopFunLowerings n) - -data TopFun (n::S) = - DexTopFun (TopFunDef n) (TopLam SimpIR n) (TopFunEvalStatus n) - | FFITopFun String IFunType - deriving (Show, Generic) - -data TopFunDef (n::S) = - Specialization (SpecializationSpec n) - | LinearizationPrimal (LinearizationSpec n) - -- Tangent functions all take some number of nonlinear args, then a *single* - -- linear arg. This is so that transposition can be an involution - you apply - -- it twice and you get back to the original function. - | LinearizationTangent (LinearizationSpec n) - deriving (Show, Generic) - -newtype TopFunLowerings (n::S) = TopFunLowerings - { topFunObjCode :: FunObjCodeName n } -- TODO: add optimized, imp etc. as needed - deriving (Show, Generic, SinkableE, HoistableE, RenameE, AlphaEqE, AlphaHashableE, Pretty) - -data AtomBinding (r::IR) (n::S) where - LetBound :: DeclBinding r n -> AtomBinding r n - MiscBound :: Type r n -> AtomBinding r n - TopDataBound :: RepVal n -> AtomBinding SimpIR n - SolverBound :: SolverBinding n -> AtomBinding CoreIR n - NoinlineFun :: CType n -> CAtom n -> AtomBinding CoreIR n - FFIFunBound :: CorePiType n -> TopFunName n -> AtomBinding CoreIR n - -deriving instance IRRep r => Show (AtomBinding r n) -deriving via WrapE (AtomBinding r) n instance IRRep r => Generic (AtomBinding r n) - --- name of function, name of arg -type InferenceArgDesc = (String, String) -data InfVarDesc = - ImplicitArgInfVar InferenceArgDesc - | AnnotationInfVar String -- name of binder - | TypeInstantiationInfVar String -- name of type - | MiscInfVar - deriving (Show, Generic, Eq, Ord) - -data SolverBinding (n::S) = - InfVarBound (CType n) - | SkolemBound (CType n) - | DictBound (CType n) - deriving (Show, Generic) - -newtype EnvFrag (n::S) (l::S) = EnvFrag (RecSubstFrag Binding n l) - deriving (OutFrag) - -instance HasScope Env where - toScope = toScope . envDefs . topEnv - -instance OutMap Env where - emptyOutMap = - Env (TopEnv (RecSubst emptyInFrag) mempty mempty emptyLoadedModules emptyLoadedObjects) - emptyModuleEnv - {-# INLINE emptyOutMap #-} - -instance ExtOutMap Env (RecSubstFrag Binding) where - -- TODO: We might want to reorganize this struct to make this - -- do less explicit sinking etc. It's a hot operation! - extendOutMap (Env (TopEnv defs rules cache loadedM loadedO) moduleEnv) frag = - withExtEvidence frag $ Env - (TopEnv - (defs `extendRecSubst` frag) - (sink rules) - (sink cache) - (sink loadedM) - (sink loadedO)) - (sink moduleEnv) - {-# INLINE extendOutMap #-} - -instance ExtOutMap Env EnvFrag where - extendOutMap = extendEnv - {-# INLINE extendOutMap #-} - -extendEnv :: Distinct l => Env n -> EnvFrag n l -> Env l -extendEnv env (EnvFrag newEnv) = do - case extendOutMap env newEnv of - Env envTop (ModuleEnv imports sm scs) -> do - Env envTop (ModuleEnv imports sm scs) -{-# NOINLINE [1] extendEnv #-} - -- === effects === data Effect (r::IR) (n::S) = @@ -906,31 +618,6 @@ instance IRRep r => Store (EffectRowTail r n) instance IRRep r => Store (EffectRow r n) instance IRRep r => Store (Effect r n) --- === Specialization and generalization === - -type Generalized (r::IR) (e::E) (n::S) = (Abstracted r e n, [Atom r n]) -type Abstracted (r::IR) (e::E) = Abs (Nest (Binder r)) e -type AbsDict = Abstracted CoreIR (Dict CoreIR) - -data SpecializedDictDef n = - SpecializedDict - (AbsDict n) - -- Methods (thunked if nullary), if they're available. - -- We create specialized dict names during simplification, but we don't - -- actually simplify/lower them until we return to TopLevel - (Maybe [TopLam SimpIR n]) - deriving (Show, Generic) - --- TODO: extend with AD-oriented specializations, backend-specific specializations etc. -data SpecializationSpec (n::S) = - AppSpecialization (AtomVar CoreIR n) (Abstracted CoreIR (ListE CAtom) n) - deriving (Show, Generic) - -type Active = Bool -data LinearizationSpec (n::S) = - LinearizationSpec (TopFunName n) [Active] - deriving (Show, Generic) - -- === Binder utils === binderType :: Binder r n l -> Type r n @@ -946,39 +633,6 @@ bindersVars = \case Nest b bs -> withExtEvidence b $ withSubscopeDistinct bs $ sink (binderVar b) : bindersVars bs --- === ToBinding === - -atomBindingToBinding :: AtomBinding r n -> Binding (AtomNameC r) n -atomBindingToBinding b = AtomNameBinding b - -bindingToAtomBinding :: Binding (AtomNameC r) n -> AtomBinding r n -bindingToAtomBinding (AtomNameBinding b) = b - -class (RenameE e, SinkableE e) => ToBinding (e::E) (c::C) | e -> c where - toBinding :: e n -> Binding c n - -instance Color c => ToBinding (Binding c) c where - toBinding = id - -instance IRRep r => ToBinding (AtomBinding r) (AtomNameC r) where - toBinding = atomBindingToBinding - -instance IRRep r => ToBinding (DeclBinding r) (AtomNameC r) where - toBinding = toBinding . LetBound - -instance IRRep r => ToBinding (Type r) (AtomNameC r) where - toBinding = toBinding . MiscBound - -instance ToBinding SolverBinding (AtomNameC CoreIR) where - toBinding = toBinding . SolverBound - -instance IRRep r => ToBinding (IxType r) (AtomNameC r) where - toBinding (IxType t _) = toBinding t - -instance (ToBinding e1 c, ToBinding e2 c) => ToBinding (EitherE e1 e2) c where - toBinding (LeftE e) = toBinding e - toBinding (RightE e) = toBinding e - -- === ToAtom === class ToAtom (e::E) (r::IR) | e -> r where @@ -1168,15 +822,6 @@ pattern TrueAtom = Con (Lit (Word8Lit 1)) -- === Typeclass instances for Name and other Haskell libraries === -instance GenericE AtomRules where - type RepE AtomRules = (LiftE (Int, Int, SymbolicZeros)) `PairE` CAtom - fromE (CustomLinearize ni ne sz a) = LiftE (ni, ne, sz) `PairE` a - toE (LiftE (ni, ne, sz) `PairE` a) = CustomLinearize ni ne sz a -instance SinkableE AtomRules -instance HoistableE AtomRules -instance AlphaEqE AtomRules -instance RenameE AtomRules - instance GenericE RepVal where type RepE RepVal= PairE SType (ComposeE Tree IExpr) fromE (RepVal ty tree) = ty `PairE` ComposeE tree @@ -1188,15 +833,6 @@ instance HoistableE RepVal instance AlphaHashableE RepVal instance AlphaEqE RepVal -instance GenericE CustomRules where - type RepE CustomRules = ListE (PairE (AtomName CoreIR) AtomRules) - fromE (CustomRules m) = ListE $ toPairE <$> M.toList m - toE (ListE l) = CustomRules $ M.fromList $ fromPairE <$> l -instance SinkableE CustomRules -instance HoistableE CustomRules -instance AlphaEqE CustomRules -instance RenameE CustomRules - instance GenericE TyConParams where type RepE TyConParams = PairE (LiftE [Explicitness]) (ListE CAtom) fromE (TyConParams infs xs) = PairE (LiftE infs) (ListE xs) @@ -2014,45 +1650,6 @@ instance IRRep r => AlphaEqE (DictCon r) instance IRRep r => AlphaHashableE (DictCon r) instance IRRep r => RenameE (DictCon r) -instance GenericE Cache where - type RepE Cache = - EMap SpecializationSpec TopFunName - `PairE` EMap AbsDict SpecDictName - `PairE` EMap LinearizationSpec (PairE TopFunName TopFunName) - `PairE` EMap TopFunName TopFunName - `PairE` LiftE (M.Map ModuleSourceName (FileHash, [ModuleSourceName])) - `PairE` ListE ( LiftE ModuleSourceName - `PairE` LiftE FileHash - `PairE` ListE ModuleName - `PairE` ModuleName) - fromE (Cache x y z w parseCache evalCache) = - x `PairE` y `PairE` z `PairE` w `PairE` LiftE parseCache `PairE` - ListE [LiftE sourceName `PairE` LiftE hashVal `PairE` ListE deps `PairE` result - | (sourceName, ((hashVal, deps), result)) <- M.toList evalCache ] - {-# INLINE fromE #-} - toE (x `PairE` y `PairE` z `PairE` w `PairE` LiftE parseCache `PairE` ListE evalCache) = - Cache x y z w parseCache - (M.fromList - [(sourceName, ((hashVal, deps), result)) - | LiftE sourceName `PairE` LiftE hashVal `PairE` ListE deps `PairE` result - <- evalCache]) - {-# INLINE toE #-} - -instance SinkableE Cache -instance HoistableE Cache -instance AlphaEqE Cache -instance RenameE Cache -instance Store (Cache n) - -instance Monoid (Cache n) where - mempty = Cache mempty mempty mempty mempty mempty mempty - mappend = (<>) - -instance Semigroup (Cache n) where - -- right-biased instead of left-biased - Cache x1 x2 x3 x4 x5 x6 <> Cache y1 y2 y3 y4 y5 y6 = - Cache (y1<>x1) (y2<>x2) (y3<>x3) (y4<>x4) (x5<>y5) (x6<>y6) - instance GenericE (LamExpr r) where type RepE (LamExpr r) = Abs (Nest (Binder r)) (Expr r) fromE (LamExpr b block) = Abs b block @@ -2167,228 +1764,6 @@ instance IRRep r => RenameE (DepPairType r) deriving instance IRRep r => Show (DepPairType r n) deriving via WrapE (DepPairType r) n instance IRRep r => Generic (DepPairType r n) -instance GenericE SynthCandidates where - type RepE SynthCandidates = ListE (PairE ClassName (ListE InstanceName)) - `PairE` ListE InstanceName - fromE (SynthCandidates xs ys) = ListE xs' `PairE` ListE ys - where xs' = map (\(k,vs) -> PairE k (ListE vs)) (M.toList xs) - {-# INLINE fromE #-} - toE (ListE xs `PairE` ListE ys) = SynthCandidates xs' ys - where xs' = M.fromList $ map (\(PairE k (ListE vs)) -> (k,vs)) xs - {-# INLINE toE #-} - -instance SinkableE SynthCandidates -instance HoistableE SynthCandidates -instance AlphaEqE SynthCandidates -instance AlphaHashableE SynthCandidates -instance RenameE SynthCandidates - -instance IRRep r => GenericE (AtomBinding r) where - type RepE (AtomBinding r) = - EitherE2 (EitherE3 - (DeclBinding r) -- LetBound - (Type r) -- MiscBound - (WhenCore r SolverBinding) -- SolverBound - ) (EitherE3 - (WhenCore r (PairE CType CAtom)) -- NoinlineFun - (WhenSimp r RepVal) -- TopDataBound - (WhenCore r (CorePiType `PairE` TopFunName)) -- FFIFunBound - ) - - fromE = \case - LetBound x -> Case0 $ Case0 x - MiscBound x -> Case0 $ Case1 x - SolverBound x -> Case0 $ Case2 $ WhenIRE x - NoinlineFun t x -> Case1 $ Case0 $ WhenIRE $ PairE t x - TopDataBound repVal -> Case1 $ Case1 $ WhenIRE repVal - FFIFunBound ty v -> Case1 $ Case2 $ WhenIRE $ ty `PairE` v - {-# INLINE fromE #-} - - toE = \case - Case0 x' -> case x' of - Case0 x -> LetBound x - Case1 x -> MiscBound x - Case2 (WhenIRE x) -> SolverBound x - _ -> error "impossible" - Case1 x' -> case x' of - Case0 (WhenIRE (PairE t x)) -> NoinlineFun t x - Case1 (WhenIRE repVal) -> TopDataBound repVal - Case2 (WhenIRE (ty `PairE` v)) -> FFIFunBound ty v - _ -> error "impossible" - _ -> error "impossible" - {-# INLINE toE #-} - - -instance IRRep r => SinkableE (AtomBinding r) -instance IRRep r => HoistableE (AtomBinding r) -instance IRRep r => RenameE (AtomBinding r) -instance IRRep r => AlphaEqE (AtomBinding r) -instance IRRep r => AlphaHashableE (AtomBinding r) - -instance GenericE TopFunDef where - type RepE TopFunDef = EitherE3 SpecializationSpec LinearizationSpec LinearizationSpec - fromE = \case - Specialization s -> Case0 s - LinearizationPrimal s -> Case1 s - LinearizationTangent s -> Case2 s - {-# INLINE fromE #-} - toE = \case - Case0 s -> Specialization s - Case1 s -> LinearizationPrimal s - Case2 s -> LinearizationTangent s - _ -> error "impossible" - {-# INLINE toE #-} - -instance SinkableE TopFunDef -instance HoistableE TopFunDef -instance RenameE TopFunDef -instance AlphaEqE TopFunDef -instance AlphaHashableE TopFunDef - -instance IRRep r => GenericE (TopLam r) where - type RepE (TopLam r) = LiftE Bool `PairE` PiType r `PairE` LamExpr r - fromE (TopLam d x y) = LiftE d `PairE` x `PairE` y - {-# INLINE fromE #-} - toE (LiftE d `PairE` x `PairE` y) = TopLam d x y - {-# INLINE toE #-} - -instance IRRep r => SinkableE (TopLam r) -instance IRRep r => HoistableE (TopLam r) -instance IRRep r => RenameE (TopLam r) -instance IRRep r => AlphaEqE (TopLam r) -instance IRRep r => AlphaHashableE (TopLam r) - -instance GenericE TopFun where - type RepE TopFun = EitherE - (TopFunDef `PairE` TopLam SimpIR `PairE` ComposeE EvalStatus TopFunLowerings) - (LiftE (String, IFunType)) - fromE = \case - DexTopFun def lam status -> LeftE (def `PairE` lam `PairE` ComposeE status) - FFITopFun name ty -> RightE (LiftE (name, ty)) - {-# INLINE fromE #-} - toE = \case - LeftE (def `PairE` lam `PairE` ComposeE status) -> DexTopFun def lam status - RightE (LiftE (name, ty)) -> FFITopFun name ty - {-# INLINE toE #-} - -instance SinkableE TopFun -instance HoistableE TopFun -instance RenameE TopFun -instance AlphaEqE TopFun -instance AlphaHashableE TopFun - -instance GenericE SpecializationSpec where - type RepE SpecializationSpec = - PairE (AtomVar CoreIR) (Abs (Nest (Binder CoreIR)) (ListE CAtom)) - fromE (AppSpecialization fname (Abs bs args)) = PairE fname (Abs bs args) - {-# INLINE fromE #-} - toE (PairE fname (Abs bs args)) = AppSpecialization fname (Abs bs args) - {-# INLINE toE #-} - -instance HasNameHint (SpecializationSpec n) where - getNameHint (AppSpecialization f _) = getNameHint f - -instance SinkableE SpecializationSpec -instance HoistableE SpecializationSpec -instance RenameE SpecializationSpec -instance AlphaEqE SpecializationSpec -instance AlphaHashableE SpecializationSpec - -instance GenericE LinearizationSpec where - type RepE LinearizationSpec = PairE TopFunName (LiftE [Active]) - fromE (LinearizationSpec fname actives) = PairE fname (LiftE actives) - {-# INLINE fromE #-} - toE (PairE fname (LiftE actives)) = LinearizationSpec fname actives - {-# INLINE toE #-} - -instance SinkableE LinearizationSpec -instance HoistableE LinearizationSpec -instance RenameE LinearizationSpec -instance AlphaEqE LinearizationSpec -instance AlphaHashableE LinearizationSpec - -instance GenericE SolverBinding where - type RepE SolverBinding = EitherE3 - CType - CType - CType - fromE = \case - InfVarBound ty -> Case0 ty - SkolemBound ty -> Case1 ty - DictBound ty -> Case2 ty - {-# INLINE fromE #-} - - toE = \case - Case0 ty -> InfVarBound ty - Case1 ty -> SkolemBound ty - Case2 ty -> DictBound ty - _ -> error "impossible" - {-# INLINE toE #-} - -instance SinkableE SolverBinding -instance HoistableE SolverBinding -instance RenameE SolverBinding -instance AlphaEqE SolverBinding -instance AlphaHashableE SolverBinding - -instance GenericE (Binding c) where - type RepE (Binding c) = - EitherE3 - (EitherE6 - (WhenAtomName c AtomBinding) - (WhenC TyConNameC c (MaybeE TyConDef `PairE` DotMethods)) - (WhenC DataConNameC c (TyConName `PairE` LiftE Int)) - (WhenC ClassNameC c (ClassDef)) - (WhenC InstanceNameC c (InstanceDef `PairE` CorePiType)) - (WhenC MethodNameC c (ClassName `PairE` LiftE Int))) - (EitherE4 - (WhenC TopFunNameC c (TopFun)) - (WhenC FunObjCodeNameC c (CFunction)) - (WhenC ModuleNameC c (Module)) - (WhenC PtrNameC c (LiftE (PtrType, PtrLitVal)))) - (EitherE2 - (WhenC SpecializedDictNameC c (SpecializedDictDef)) - (WhenC ImpNameC c (LiftE BaseType))) - - fromE = \case - AtomNameBinding binding -> Case0 $ Case0 $ WhenAtomName binding - TyConBinding dataDef methods -> Case0 $ Case1 $ WhenC $ toMaybeE dataDef `PairE` methods - DataConBinding dataDefName idx -> Case0 $ Case2 $ WhenC $ dataDefName `PairE` LiftE idx - ClassBinding classDef -> Case0 $ Case3 $ WhenC $ classDef - InstanceBinding instanceDef ty -> Case0 $ Case4 $ WhenC $ instanceDef `PairE` ty - MethodBinding className idx -> Case0 $ Case5 $ WhenC $ className `PairE` LiftE idx - TopFunBinding fun -> Case1 $ Case0 $ WhenC $ fun - FunObjCodeBinding cFun -> Case1 $ Case1 $ WhenC $ cFun - ModuleBinding m -> Case1 $ Case2 $ WhenC $ m - PtrBinding ty p -> Case1 $ Case3 $ WhenC $ LiftE (ty,p) - SpecializedDictBinding def -> Case2 $ Case0 $ WhenC $ def - ImpNameBinding ty -> Case2 $ Case1 $ WhenC $ LiftE ty - {-# INLINE fromE #-} - - toE = \case - Case0 (Case0 (WhenAtomName binding)) -> AtomNameBinding binding - Case0 (Case1 (WhenC (def `PairE` methods))) -> TyConBinding (fromMaybeE def) methods - Case0 (Case2 (WhenC (n `PairE` LiftE idx))) -> DataConBinding n idx - Case0 (Case3 (WhenC (classDef))) -> ClassBinding classDef - Case0 (Case4 (WhenC (instanceDef `PairE` ty))) -> InstanceBinding instanceDef ty - Case0 (Case5 (WhenC ((n `PairE` LiftE i)))) -> MethodBinding n i - Case1 (Case0 (WhenC (fun))) -> TopFunBinding fun - Case1 (Case1 (WhenC (f))) -> FunObjCodeBinding f - Case1 (Case2 (WhenC (m))) -> ModuleBinding m - Case1 (Case3 (WhenC ((LiftE (ty,p))))) -> PtrBinding ty p - Case2 (Case0 (WhenC (def))) -> SpecializedDictBinding def - Case2 (Case1 (WhenC ((LiftE ty)))) -> ImpNameBinding ty - _ -> error "impossible" - {-# INLINE toE #-} - -deriving via WrapE (Binding c) n instance Generic (Binding c n) -instance SinkableV Binding -instance HoistableV Binding -instance RenameV Binding -instance Color c => SinkableE (Binding c) -instance Color c => HoistableE (Binding c) -instance Color c => RenameE (Binding c) - instance GenericE DotMethods where type RepE DotMethods = ListE (LiftE SourceName `PairE` CAtomName) fromE (DotMethods xys) = ListE $ [LiftE x `PairE` y | (x, y) <- M.toList xys] @@ -2506,277 +1881,9 @@ instance IRRep r => BindsOneName (Decl r) (AtomNameC r) where binderName (Let b _) = binderName b {-# INLINE binderName #-} -instance Semigroup (SynthCandidates n) where - SynthCandidates xs ys <> SynthCandidates xs' ys' = - SynthCandidates (M.unionWith (<>) xs xs') (ys <> ys') - -instance Monoid (SynthCandidates n) where - mempty = SynthCandidates mempty mempty - -instance GenericB EnvFrag where - type RepB EnvFrag = RecSubstFrag Binding - fromB (EnvFrag frag) = frag - toB frag = EnvFrag frag - -instance SinkableB EnvFrag -instance HoistableB EnvFrag -instance ProvesExt EnvFrag -instance BindsNames EnvFrag -instance RenameB EnvFrag - -instance GenericE TopEnvUpdate where - type RepE TopEnvUpdate = EitherE2 ( - EitherE4 - {- ExtendCache -} Cache - {- AddCustomRule -} (CAtomName `PairE` AtomRules) - {- UpdateLoadedModules -} (LiftE ModuleSourceName `PairE` ModuleName) - {- UpdateLoadedObjects -} (FunObjCodeName `PairE` LiftE NativeFunction) - ) ( EitherE6 - {- FinishDictSpecialization -} (SpecDictName `PairE` ListE (TopLam SimpIR)) - {- LowerDictSpecialization -} (SpecDictName `PairE` ListE (TopLam SimpIR)) - {- UpdateTopFunEvalStatus -} (TopFunName `PairE` ComposeE EvalStatus TopFunLowerings) - {- UpdateInstanceDef -} (InstanceName `PairE` InstanceDef) - {- UpdateTyConDef -} (TyConName `PairE` TyConDef) - {- UpdateFieldDef -} (TyConName `PairE` LiftE SourceName `PairE` CAtomName) - ) - fromE = \case - ExtendCache x -> Case0 $ Case0 x - AddCustomRule x y -> Case0 $ Case1 (x `PairE` y) - UpdateLoadedModules x y -> Case0 $ Case2 (LiftE x `PairE` y) - UpdateLoadedObjects x y -> Case0 $ Case3 (x `PairE` LiftE y) - FinishDictSpecialization x y -> Case1 $ Case0 (x `PairE` ListE y) - LowerDictSpecialization x y -> Case1 $ Case1 (x `PairE` ListE y) - UpdateTopFunEvalStatus x y -> Case1 $ Case2 (x `PairE` ComposeE y) - UpdateInstanceDef x y -> Case1 $ Case3 (x `PairE` y) - UpdateTyConDef x y -> Case1 $ Case4 (x `PairE` y) - UpdateFieldDef x y z -> Case1 $ Case5 (x `PairE` LiftE y `PairE` z) - - toE = \case - Case0 e -> case e of - Case0 x -> ExtendCache x - Case1 (x `PairE` y) -> AddCustomRule x y - Case2 (LiftE x `PairE` y) -> UpdateLoadedModules x y - Case3 (x `PairE` LiftE y) -> UpdateLoadedObjects x y - _ -> error "impossible" - Case1 e -> case e of - Case0 (x `PairE` ListE y) -> FinishDictSpecialization x y - Case1 (x `PairE` ListE y) -> LowerDictSpecialization x y - Case2 (x `PairE` ComposeE y) -> UpdateTopFunEvalStatus x y - Case3 (x `PairE` y) -> UpdateInstanceDef x y - Case4 (x `PairE` y) -> UpdateTyConDef x y - Case5 (x `PairE` LiftE y `PairE` z) -> UpdateFieldDef x y z - _ -> error "impossible" - _ -> error "impossible" - -instance SinkableE TopEnvUpdate -instance HoistableE TopEnvUpdate -instance RenameE TopEnvUpdate - -instance GenericB TopEnvFrag where - type RepB TopEnvFrag = PairB EnvFrag (LiftB (ModuleEnv `PairE` ListE TopEnvUpdate)) - fromB (TopEnvFrag x y (ReversedList z)) = PairB x (LiftB (y `PairE` ListE z)) - toB (PairB x (LiftB (y `PairE` ListE z))) = TopEnvFrag x y (ReversedList z) - -instance RenameB TopEnvFrag -instance HoistableB TopEnvFrag -instance SinkableB TopEnvFrag -instance ProvesExt TopEnvFrag -instance BindsNames TopEnvFrag - -instance OutFrag TopEnvFrag where - emptyOutFrag = TopEnvFrag emptyOutFrag mempty mempty - {-# INLINE emptyOutFrag #-} - catOutFrags (TopEnvFrag frag1 env1 partial1) - (TopEnvFrag frag2 env2 partial2) = - withExtEvidence frag2 $ - TopEnvFrag - (catOutFrags frag1 frag2) - (sink env1 <> env2) - (sinkSnocList partial1 <> partial2) - {-# INLINE catOutFrags #-} - --- XXX: unlike `ExtOutMap Env EnvFrag` instance, this once doesn't --- extend the synthesis candidates based on the annotated let-bound names. It --- only extends synth candidates when they're supplied explicitly. -instance ExtOutMap Env TopEnvFrag where - extendOutMap env (TopEnvFrag (EnvFrag frag) mEnv' otherUpdates) = do - let newerTopEnv = foldl applyUpdate newTopEnv otherUpdates - Env newerTopEnv newModuleEnv - where - Env (TopEnv defs rules cache loadedM loadedO) mEnv = env - - newTopEnv = withExtEvidence frag $ TopEnv - (defs `extendRecSubst` frag) - (sink rules) (sink cache) (sink loadedM) (sink loadedO) - - newModuleEnv = - ModuleEnv - (imports <> imports') - (sm <> sm' <> newImportedSM) - (scs <> scs' <> newImportedSC) - where - ModuleEnv imports sm scs = withExtEvidence frag $ sink mEnv - ModuleEnv imports' sm' scs' = mEnv' - newDirectImports = S.difference (directImports imports') (directImports imports) - newTransImports = S.difference (transImports imports') (transImports imports) - newImportedSM = flip foldMap newDirectImports $ moduleExports . lookupModulePure - newImportedSC = flip foldMap newTransImports $ moduleSynthCandidates . lookupModulePure - - lookupModulePure v = case lookupEnvPure newTopEnv v of ModuleBinding m -> m - -applyUpdate :: TopEnv n -> TopEnvUpdate n -> TopEnv n -applyUpdate e = \case - ExtendCache cache -> e { envCache = envCache e <> cache} - AddCustomRule x y -> e { envCustomRules = envCustomRules e <> CustomRules (M.singleton x y)} - UpdateLoadedModules x y -> e { envLoadedModules = envLoadedModules e <> LoadedModules (M.singleton x y)} - UpdateLoadedObjects x y -> e { envLoadedObjects = envLoadedObjects e <> LoadedObjects (M.singleton x y)} - FinishDictSpecialization dName methods -> do - let SpecializedDictBinding (SpecializedDict dAbs oldMethods) = lookupEnvPure e dName - case oldMethods of - Nothing -> do - let newBinding = SpecializedDictBinding $ SpecializedDict dAbs (Just methods) - updateEnv dName newBinding e - Just _ -> error "shouldn't be adding methods if we already have them" - LowerDictSpecialization dName methods -> do - let SpecializedDictBinding (SpecializedDict dAbs _) = lookupEnvPure e dName - let newBinding = SpecializedDictBinding $ SpecializedDict dAbs (Just methods) - updateEnv dName newBinding e - UpdateTopFunEvalStatus f s -> do - case lookupEnvPure e f of - TopFunBinding (DexTopFun def lam _) -> - updateEnv f (TopFunBinding $ DexTopFun def lam s) e - _ -> error "can't update ffi function impl" - UpdateInstanceDef name def -> do - case lookupEnvPure e name of - InstanceBinding _ ty -> updateEnv name (InstanceBinding def ty) e - UpdateTyConDef name def -> do - let TyConBinding _ methods = lookupEnvPure e name - updateEnv name (TyConBinding (Just def) methods) e - UpdateFieldDef name sn x -> do - let TyConBinding def methods = lookupEnvPure e name - updateEnv name (TyConBinding def (methods <> DotMethods (M.singleton sn x))) e - -updateEnv :: Color c => Name c n -> Binding c n -> TopEnv n -> TopEnv n -updateEnv v rhs env = - env { envDefs = RecSubst $ updateSubstFrag v rhs bs } - where (RecSubst bs) = envDefs env - -lookupEnvPure :: Color c => TopEnv n -> Name c n -> Binding c n -lookupEnvPure env v = lookupTerminalSubstFrag (fromRecSubst $ envDefs $ env) v - -instance GenericE Module where - type RepE Module = LiftE ModuleSourceName - `PairE` ListE ModuleName - `PairE` ListE ModuleName - `PairE` SourceMap - `PairE` SynthCandidates - - fromE (Module name deps transDeps sm sc) = - LiftE name `PairE` ListE (S.toList deps) `PairE` ListE (S.toList transDeps) - `PairE` sm `PairE` sc - {-# INLINE fromE #-} - - toE (LiftE name `PairE` ListE deps `PairE` ListE transDeps - `PairE` sm `PairE` sc) = - Module name (S.fromList deps) (S.fromList transDeps) sm sc - {-# INLINE toE #-} - -instance SinkableE Module -instance HoistableE Module -instance AlphaEqE Module -instance AlphaHashableE Module -instance RenameE Module - -instance GenericE ImportStatus where - type RepE ImportStatus = ListE ModuleName `PairE` ListE ModuleName - fromE (ImportStatus direct trans) = ListE (S.toList direct) - `PairE` ListE (S.toList trans) - {-# INLINE fromE #-} - toE (ListE direct `PairE` ListE trans) = - ImportStatus (S.fromList direct) (S.fromList trans) - {-# INLINE toE #-} - -instance SinkableE ImportStatus -instance HoistableE ImportStatus -instance AlphaEqE ImportStatus -instance AlphaHashableE ImportStatus -instance RenameE ImportStatus - -instance Semigroup (ImportStatus n) where - ImportStatus direct trans <> ImportStatus direct' trans' = - ImportStatus (direct <> direct') (trans <> trans') - -instance Monoid (ImportStatus n) where - mappend = (<>) - mempty = ImportStatus mempty mempty - -instance GenericE LoadedModules where - type RepE LoadedModules = ListE (PairE (LiftE ModuleSourceName) ModuleName) - fromE (LoadedModules m) = - ListE $ M.toList m <&> \(v,md) -> PairE (LiftE v) md - {-# INLINE fromE #-} - toE (ListE pairs) = - LoadedModules $ M.fromList $ pairs <&> \(PairE (LiftE v) md) -> (v, md) - {-# INLINE toE #-} - -instance SinkableE LoadedModules -instance HoistableE LoadedModules -instance AlphaEqE LoadedModules -instance AlphaHashableE LoadedModules -instance RenameE LoadedModules - -instance GenericE LoadedObjects where - type RepE LoadedObjects = ListE (PairE FunObjCodeName (LiftE NativeFunction)) - fromE (LoadedObjects m) = - ListE $ M.toList m <&> \(v,p) -> PairE v (LiftE p) - {-# INLINE fromE #-} - toE (ListE pairs) = - LoadedObjects $ M.fromList $ pairs <&> \(PairE v (LiftE p)) -> (v, p) - {-# INLINE toE #-} - -instance SinkableE LoadedObjects -instance HoistableE LoadedObjects -instance RenameE LoadedObjects - -instance GenericE ModuleEnv where - type RepE ModuleEnv = ImportStatus - `PairE` SourceMap - `PairE` SynthCandidates - fromE (ModuleEnv imports sm sc) = imports `PairE` sm `PairE` sc - {-# INLINE fromE #-} - toE (imports `PairE` sm `PairE` sc) = ModuleEnv imports sm sc - {-# INLINE toE #-} - -instance SinkableE ModuleEnv -instance HoistableE ModuleEnv -instance AlphaEqE ModuleEnv -instance AlphaHashableE ModuleEnv -instance RenameE ModuleEnv - -instance Semigroup (ModuleEnv n) where - ModuleEnv x1 x2 x3 <> ModuleEnv y1 y2 y3 = - ModuleEnv (x1<>y1) (x2<>y2) (x3<>y3) - -instance Monoid (ModuleEnv n) where - mempty = ModuleEnv mempty mempty mempty - -instance Semigroup (LoadedModules n) where - LoadedModules m1 <> LoadedModules m2 = LoadedModules (m2 <> m1) - -instance Monoid (LoadedModules n) where - mempty = LoadedModules mempty - -instance Semigroup (LoadedObjects n) where - LoadedObjects m1 <> LoadedObjects m2 = LoadedObjects (m2 <> m1) - -instance Monoid (LoadedObjects n) where - mempty = LoadedObjects mempty - -instance Hashable InfVarDesc instance Hashable IxMethod instance Hashable ParamRole instance Hashable BuiltinClassName -instance Hashable a => Hashable (EvalStatus a) instance IRRep r => Store (MiscOp r n) instance IRRep r => Store (VectorOp r n) @@ -2791,24 +1898,18 @@ instance IRRep r => Store (Stuck r n) instance IRRep r => Store (Atom r n) instance IRRep r => Store (AtomVar r n) instance IRRep r => Store (Expr r n) -instance Store (SolverBinding n) -instance IRRep r => Store (AtomBinding r n) -instance Store (SpecializationSpec n) -instance Store (LinearizationSpec n) instance IRRep r => Store (DeclBinding r n) instance IRRep r => Store (Decl r n l) instance Store (TyConParams n) instance Store (DataConDefs n) instance Store (TyConDef n) instance Store (DataConDef n) -instance IRRep r => Store (TopLam r n) instance IRRep r => Store (LamExpr r n) instance IRRep r => Store (IxType r n) instance Store (CorePiType n) instance Store (CoreLamExpr n) instance IRRep r => Store (TabPiType r n) instance IRRep r => Store (DepPairType r n) -instance Store (AtomRules n) instance Store BuiltinClassName instance Store (ClassDef n) instance Store (InstanceDef n) @@ -2819,21 +1920,9 @@ instance Store (EffectDef n) instance Store (EffectOpDef n) instance Store (EffectOpType n) instance Store (EffectOpIdx) -instance Store (SynthCandidates n) -instance Store (Module n) -instance Store (ImportStatus n) -instance Store (TopFunLowerings n) -instance Store a => Store (EvalStatus a) -instance Store (TopFun n) -instance Store (TopFunDef n) -instance Color c => Store (Binding c n) -instance Store (ModuleEnv n) -instance Store (SerializedEnv n) instance Store (ann n) => Store (NonDepNest r ann n l) -instance Store InfVarDesc instance Store IxMethod instance Store ParamRole -instance Store (SpecializedDictDef n) instance IRRep r => Store (Dict r n) instance IRRep r => Store (TypedHof r n) instance IRRep r => Store (Hof r n) @@ -2843,3 +1932,366 @@ instance IRRep r => Store (DAMOp r n) instance Store (NewtypeCon n) instance Store (NewtypeTyCon n) instance Store (DotMethods n) + +-- === Pretty instances === + +instance IRRep r => Pretty (Hof r n) where pretty = prettyFromPrettyPrec +instance IRRep r => PrettyPrec (Hof r n) where + prettyPrec hof = atPrec LowestPrec case hof of + For _ _ lam -> "for" <+> pLowest lam + While body -> "while" <+> pArg body + RunReader x body -> "runReader" <+> pArg x <> nest 2 (line <> p body) + RunWriter _ bm body -> "runWriter" <+> pArg bm <> nest 2 (line <> p body) + RunState _ x body -> "runState" <+> pArg x <> nest 2 (line <> p body) + RunIO body -> "runIO" <+> pArg body + RunInit body -> "runInit" <+> pArg body + CatchException _ body -> "catchException" <+> pArg body + Linearize body x -> "linearize" <+> pArg body <+> pArg x + Transpose body x -> "transpose" <+> pArg body <+> pArg x + where + p :: Pretty a => a -> Doc ann + p = pretty + +instance IRRep r => Pretty (DAMOp r n) where pretty = prettyFromPrettyPrec +instance IRRep r => PrettyPrec (DAMOp r n) where + prettyPrec op = atPrec LowestPrec case op of + Seq _ ann _ c lamExpr -> case lamExpr of + UnaryLamExpr b body -> do + "seq" <+> pApp ann <+> pApp c <+> prettyLam (pretty b <> ".") body + _ -> pretty (show op) -- shouldn't happen, but crashing pretty printers make debugging hard + RememberDest _ x y -> "rememberDest" <+> pArg x <+> pArg y + Place r v -> pApp r <+> "r:=" <+> pApp v + Freeze r -> "freeze" <+> pApp r + AllocDest ty -> "alloc" <+> pApp ty + +instance IRRep r => Pretty (TyCon r n) where pretty = prettyFromPrettyPrec +instance IRRep r => PrettyPrec (TyCon r n) where + prettyPrec con = case con of + BaseType b -> prettyPrec b + ProdType [] -> atPrec ArgPrec $ "()" + ProdType as -> atPrec ArgPrec $ align $ group $ + encloseSep "(" ")" ", " $ fmap pApp as + SumType cs -> atPrec ArgPrec $ align $ group $ + encloseSep "(|" "|)" " | " $ fmap pApp cs + RefType h a -> atPrec AppPrec $ pAppArg "Ref" [h] <+> p a + TypeKind -> atPrec ArgPrec "Type" + HeapType -> atPrec ArgPrec "Heap" + Pi piType -> atPrec LowestPrec $ align $ p piType + TabPi piType -> atPrec LowestPrec $ align $ p piType + DepPairTy ty -> prettyPrec ty + DictTy t -> atPrec LowestPrec $ p t + NewtypeTyCon con' -> prettyPrec con' + where + p :: Pretty a => a -> Doc ann + p = pretty + +prettyPrecNewtype :: NewtypeCon n -> CAtom n -> DocPrec ann +prettyPrecNewtype con x = case (con, x) of + (NatCon, (IdxRepVal n)) -> atPrec ArgPrec $ pretty n + (_, x') -> prettyPrec x' + +instance Pretty (NewtypeTyCon n) where pretty = prettyFromPrettyPrec +instance PrettyPrec (NewtypeTyCon n) where + prettyPrec = \case + Nat -> atPrec ArgPrec $ "Nat" + Fin n -> atPrec AppPrec $ "Fin" <+> pArg n + EffectRowKind -> atPrec ArgPrec "EffKind" + UserADTType name _ (TyConParams infs params) -> case (infs, params) of + ([], []) -> atPrec ArgPrec $ pretty name + ([Explicit, Explicit], [l, r]) + | Just sym <- fromInfix (fromString $ pprint name) -> + atPrec ArgPrec $ align $ group $ + parens $ flatAlt " " "" <> pApp l <> line <> pretty sym <+> pApp r + _ -> atPrec LowestPrec $ pAppArg (pretty name) $ ignoreSynthParams (TyConParams infs params) + where + fromInfix :: Text -> Maybe Text + fromInfix t = do + ('(', t') <- uncons t + (t'', ')') <- unsnoc t' + return t'' + +instance IRRep r => Pretty (Con r n) where pretty = prettyFromPrettyPrec +instance IRRep r => PrettyPrec (Con r n) where + prettyPrec = \case + Lit l -> prettyPrec l + ProdCon [x] -> atPrec ArgPrec $ "(" <> pLowest x <> ",)" + ProdCon xs -> atPrec ArgPrec $ align $ group $ + encloseSep "(" ")" ", " $ fmap pLowest xs + SumCon _ tag payload -> atPrec ArgPrec $ + "(" <> p tag <> "|" <+> pApp payload <+> "|)" + HeapVal -> atPrec ArgPrec "HeapValue" + Lam lam -> atPrec LowestPrec $ p lam + DepPair x y _ -> atPrec ArgPrec $ align $ group $ + parens $ p x <+> ",>" <+> p y + Eff e -> atPrec ArgPrec $ p e + DictConAtom d -> atPrec LowestPrec $ p d + NewtypeCon con x -> prettyPrecNewtype con x + TyConAtom ty -> prettyPrec ty + where + p :: Pretty a => a -> Doc ann + p = pretty + +instance IRRep r => Pretty (PrimOp r n) where pretty = prettyFromPrettyPrec +instance IRRep r => PrettyPrec (PrimOp r n) where + prettyPrec = \case + MemOp op -> prettyPrec op + VectorOp op -> prettyPrec op + DAMOp op -> prettyPrec op + Hof (TypedHof _ hof) -> prettyPrec hof + RefOp ref eff -> atPrec LowestPrec case eff of + MAsk -> "ask" <+> pApp ref + MExtend _ x -> "extend" <+> pApp ref <+> pApp x + MGet -> "get" <+> pApp ref + MPut x -> pApp ref <+> ":=" <+> pApp x + IndexRef _ i -> pApp ref <+> "!" <+> pApp i + ProjRef _ i -> "proj_ref" <+> pApp ref <+> p i + UnOp op x -> prettyOpDefault (UUnOp op) [x] + BinOp op x y -> prettyOpDefault (UBinOp op) [x, y] + MiscOp op -> prettyOpGeneric op + where + p :: Pretty a => a -> Doc ann + p = pretty + +instance IRRep r => Pretty (MemOp r n) where pretty = prettyFromPrettyPrec +instance IRRep r => PrettyPrec (MemOp r n) where + prettyPrec = \case + PtrOffset ptr idx -> atPrec LowestPrec $ pApp ptr <+> "+>" <+> pApp idx + PtrLoad ptr -> atPrec AppPrec $ pAppArg "load" [ptr] + op -> prettyOpGeneric op + +instance IRRep r => Pretty (VectorOp r n) where pretty = prettyFromPrettyPrec +instance IRRep r => PrettyPrec (VectorOp r n) where + prettyPrec = \case + VectorBroadcast v vty -> atPrec LowestPrec $ "vbroadcast" <+> pApp v <+> pApp vty + VectorIota vty -> atPrec LowestPrec $ "viota" <+> pApp vty + VectorIdx tbl i vty -> atPrec LowestPrec $ "vslice" <+> pApp tbl <+> pApp i <+> pApp vty + VectorSubref ref i _ -> atPrec LowestPrec $ "vrefslice" <+> pApp ref <+> pApp i + +prettyOpGeneric :: (IRRep r, GenericOp op, Show (OpConst op r)) => op r n -> DocPrec ann +prettyOpGeneric op = case fromEGenericOpRep op of + GenericOpRep op' [] [] [] -> atPrec ArgPrec (pretty $ show op') + GenericOpRep op' ts xs lams -> atPrec AppPrec $ pAppArg (pretty (show op')) xs <+> pretty ts <+> pretty lams + +instance Pretty IxMethod where + pretty method = pretty $ show method + +instance Pretty (TyConParams n) where + pretty (TyConParams _ _) = undefined + +instance Pretty (TyConDef n) where + pretty (TyConDef name _ bs cons) = "data" <+> pretty name <+> pretty bs <> pretty cons + +instance Pretty (DataConDefs n) where + pretty = undefined + +instance Pretty (DataConDef n) where + pretty (DataConDef name _ repTy _) = pretty name <+> ":" <+> pretty repTy + +instance Pretty (ClassDef n) where + pretty (ClassDef classSourceName _ methodNames _ _ params superclasses methodTys) = + "Class:" <+> pretty classSourceName <+> pretty methodNames + <> indented ( + line <> "parameter binders:" <+> pretty params <> + line <> "superclasses:" <+> pretty superclasses <> + line <> "methods:" <+> pretty methodTys) + +instance Pretty ParamRole where + pretty r = pretty (show r) + +instance Pretty (InstanceDef n) where + pretty (InstanceDef className _ bs params _) = + "Instance" <+> pretty className <+> pretty bs <+> pretty params + +instance IRRep r => Pretty (Expr r n) where pretty = prettyFromPrettyPrec +instance IRRep r => PrettyPrec (Expr r n) where + prettyPrec = \case + Atom x -> prettyPrec x + Block _ (Abs decls body) -> atPrec AppPrec $ prettyBlock decls body + App _ f xs -> atPrec AppPrec $ pApp f <+> spaced (toList xs) + TopApp _ f xs -> atPrec AppPrec $ pApp f <+> spaced (toList xs) + TabApp _ f x -> atPrec AppPrec $ pApp f <> brackets (p x) + Case e alts (EffTy effs _) -> prettyPrecCase "case" e alts effs + TabCon _ _ es -> atPrec ArgPrec $ list $ pApp <$> es + PrimOp op -> prettyPrec op + ApplyMethod _ d i xs -> atPrec AppPrec $ "applyMethod" <+> p d <+> p i <+> p xs + Project _ i x -> atPrec AppPrec $ "Project" <+> p i <+> p x + Unwrap _ x -> atPrec AppPrec $ "Unwrap" <+> p x + where + p :: Pretty a => a -> Doc ann + p = pretty + +prettyPrecCase :: IRRep r => Doc ann -> Atom r n -> [Alt r n] -> EffectRow r n -> DocPrec ann +prettyPrecCase name e alts effs = atPrec LowestPrec $ + name <+> pApp e <+> "of" <> + nest 2 (foldMap (\alt -> hardline <> prettyAlt alt) alts + <> effectLine effs) + where + effectLine :: IRRep r => EffectRow r n -> Doc ann + effectLine Pure = "" + effectLine row = hardline <> "case annotated with effects" <+> pretty row + +prettyAlt :: IRRep r => Alt r n -> Doc ann +prettyAlt (Abs b body) = prettyBinderNoAnn b <+> "->" <> nest 2 (pretty body) + +prettyBinderNoAnn :: Binder r n l -> Doc ann +prettyBinderNoAnn (b:>_) = pretty b + +instance IRRep r => Pretty (DeclBinding r n) where + pretty (DeclBinding ann expr) = "Decl" <> pretty ann <+> pretty expr + +instance IRRep r => Pretty (Decl r n l) where + pretty (Let b (DeclBinding ann rhs)) = + align $ annDoc <> pretty b <+> "=" <> (nest 2 $ group $ line <> pLowest rhs) + where annDoc = case ann of NoInlineLet -> pretty ann <> " "; _ -> pretty ann + +instance IRRep r => Pretty (PiType r n) where + pretty (PiType bs (EffTy effs resultTy)) = + (spaced $ unsafeFromNest $ bs) <+> "->" <+> "{" <> pretty effs <> "}" <+> pretty resultTy + +instance IRRep r => Pretty (LamExpr r n) where pretty = prettyFromPrettyPrec +instance IRRep r => PrettyPrec (LamExpr r n) where + prettyPrec (LamExpr bs body) = atPrec LowestPrec $ prettyLam (pretty bs <> ".") body + +instance IRRep r => Pretty (IxType r n) where + pretty (IxType ty dict) = parens $ "IxType" <+> pretty ty <> prettyIxDict dict + +instance IRRep r => Pretty (Dict r n) where + pretty = \case + DictCon con -> pretty con + StuckDict _ stuck -> pretty stuck + +instance IRRep r => Pretty (DictCon r n) where + pretty = \case + InstanceDict _ name args -> "Instance" <+> pretty name <+> pretty args + IxFin n -> "Ix (Fin" <+> pretty n <> ")" + DataData a -> "Data " <+> pretty a + IxRawFin n -> "Ix (RawFin " <> pretty n <> ")" + IxSpecialized d xs -> pretty d <+> pretty xs + +instance Pretty (DictType n) where + pretty = \case + DictType classSourceName _ params -> pretty classSourceName <+> spaced params + IxDictType ty -> "Ix" <+> pretty ty + DataDictType ty -> "Data" <+> pretty ty + +instance IRRep r => Pretty (DepPairType r n) where pretty = prettyFromPrettyPrec +instance IRRep r => PrettyPrec (DepPairType r n) where + prettyPrec (DepPairType _ b rhs) = + atPrec ArgPrec $ align $ group $ parensSep (spaceIfColinear <> "&> ") [pretty b, pretty rhs] + +instance Pretty (CoreLamExpr n) where + pretty (CoreLamExpr _ lam) = pretty lam + +instance IRRep r => Pretty (Atom r n) where pretty = prettyFromPrettyPrec +instance IRRep r => PrettyPrec (Atom r n) where + prettyPrec atom = case atom of + Con e -> prettyPrec e + Stuck _ e -> prettyPrec e + +instance IRRep r => Pretty (Type r n) where pretty = prettyFromPrettyPrec +instance IRRep r => PrettyPrec (Type r n) where + prettyPrec = \case + TyCon e -> prettyPrec e + StuckTy _ e -> prettyPrec e + +instance IRRep r => Pretty (Stuck r n) where pretty = prettyFromPrettyPrec +instance IRRep r => PrettyPrec (Stuck r n) where + prettyPrec = \case + Var v -> atPrec ArgPrec $ p v + StuckProject i v -> atPrec LowestPrec $ "StuckProject" <+> p i <+> p v + StuckTabApp f xs -> atPrec AppPrec $ pArg f <> "." <> pArg xs + StuckUnwrap v -> atPrec LowestPrec $ "StuckUnwrap" <+> p v + InstantiatedGiven v args -> atPrec LowestPrec $ "Given" <+> p v <+> p (toList args) + SuperclassProj d' i -> atPrec LowestPrec $ "SuperclassProj" <+> p d' <+> p i + PtrVar _ v -> atPrec ArgPrec $ p v + RepValAtom x -> atPrec LowestPrec $ pretty x + ACase e alts _ -> atPrec AppPrec $ "acase" <+> p e <+> p alts + LiftSimp ty x -> atPrec ArgPrec $ " p x <+> " : " <+> p ty <+> ">" + LiftSimpFun ty x -> atPrec ArgPrec $ " p x <+> " : " <+> p ty <+> ">" + TabLam lam -> atPrec AppPrec $ "tablam" <+> p lam + where + p :: Pretty a => a -> Doc ann + p = pretty + +instance PrettyPrec (AtomVar r n) where + prettyPrec (AtomVar v _) = prettyPrec v +instance Pretty (AtomVar r n) where pretty = prettyFromPrettyPrec + +instance IRRep r => Pretty (EffectRow r n) where + pretty (EffectRow effs t) = braces $ hsep (punctuate "," (map pretty (eSetToList effs))) <> pretty t + +instance IRRep r => Pretty (EffectRowTail r n) where + pretty = \case + NoTail -> mempty + EffectRowTail v -> "|" <> pretty v + +instance IRRep r => Pretty (Effect r n) where + pretty eff = case eff of + RWSEffect rws h -> pretty rws <+> pretty h + ExceptionEffect -> "Except" + IOEffect -> "IO" + InitEffect -> "Init" + +prettyLam :: Pretty a => Doc ann -> a -> Doc ann +prettyLam binders body = group $ group (nest 4 $ binders) <> group (nest 2 $ pretty body) + +instance IRRep r => Pretty (TabPiType r n) where + pretty (TabPiType dict (b:>ty) body) = let + prettyBody = case body of + TyCon (Pi subpi) -> pretty subpi + _ -> pLowest body + prettyBinder = prettyBinderHelper (b:>ty) body + in prettyBinder <> prettyIxDict dict <> (group $ line <> "=>" <+> prettyBody) + +-- A helper to let us turn dict printing on and off. We mostly want it off to +-- reduce clutter in prints and error messages, but when debugging synthesis we +-- want it on. +prettyIxDict :: IRRep r => IxDict r n -> Doc ann +prettyIxDict dict = if False then " " <> pretty dict else mempty + +prettyBinderHelper :: IRRep r => HoistableE e => Binder r n l -> e l -> Doc ann +prettyBinderHelper (b:>ty) body = + if binderName b `isFreeIn` body + then parens $ pretty (b:>ty) + else pretty ty + +instance Pretty (CorePiType n) where + pretty (CorePiType appExpl expls bs (EffTy eff resultTy)) = + prettyBindersWithExpl expls bs <+> pretty appExpl <> prettyEff <> pretty resultTy + where + prettyEff = case eff of + Pure -> space + _ -> space <> pretty eff <> space + +prettyBindersWithExpl :: forall b n l ann. PrettyB b + => [Explicitness] -> Nest b n l -> Doc ann +prettyBindersWithExpl expls bs = do + let groups = groupByExpl $ zip expls (unsafeFromNest bs) + let groups' = case groups of [] -> [(Explicit, [])] + _ -> groups + mconcat [withExplParens expl $ commaSep bsGroup | (expl, bsGroup) <- groups'] + +groupByExpl :: [(Explicitness, b UnsafeS UnsafeS)] -> [(Explicitness, [b UnsafeS UnsafeS])] +groupByExpl [] = [] +groupByExpl ((expl, b):bs) = do + let (matches, rest) = span (\(expl', _) -> expl == expl') bs + let matches' = map snd matches + (expl, b:matches') : groupByExpl rest + +withExplParens :: Explicitness -> Doc ann -> Doc ann +withExplParens Explicit x = parens x +withExplParens (Inferred _ Unify) x = braces $ x +withExplParens (Inferred _ (Synth _)) x = brackets x + +instance Pretty (RepVal n) where + pretty (RepVal ty tree) = " pretty tree <+> ":" <+> pretty ty <> ">" + +prettyBlock :: (IRRep r, PrettyPrec (e l)) => Nest (Decl r) n l -> e l -> Doc ann +prettyBlock Empty expr = group $ line <> pLowest expr +prettyBlock decls expr = prettyLines decls' <> hardline <> pLowest expr + where decls' = unsafeFromNest decls + +instance IRRep r => Pretty (BaseMonoid r n) where pretty = prettyFromPrettyPrec +instance IRRep r => PrettyPrec (BaseMonoid r n) where + prettyPrec (BaseMonoid x f) = + atPrec LowestPrec $ "baseMonoid" <+> pArg x <> nest 2 (line <> pArg f) diff --git a/src/lib/Types/Imp.hs b/src/lib/Types/Imp.hs index d99d66c4a..9006745ce 100644 --- a/src/lib/Types/Imp.hs +++ b/src/lib/Types/Imp.hs @@ -27,11 +27,16 @@ import qualified Data.ByteString as BS import GHC.Generics (Generic (..)) import Data.Store (Store (..)) +import Data.String (fromString) +import Data.Text.Prettyprint.Doc (line', nest, group) import Name +import PPrint import Util (IsBool (..)) - import Types.Primitives +import Types.Source + +-- === data types === type ImpName = Name ImpNameC @@ -480,3 +485,91 @@ instance Store LinktimeVals instance Hashable IsCUDARequired instance Hashable CallingConvention instance Hashable IFunType + +instance Pretty CallingConvention where pretty = fromString . show + +instance Pretty (ImpFunction n) where + pretty (ImpFunction (IFunType cc _ _) (Abs bs body)) = + "impfun" <+> pretty cc <+> prettyBinderNest bs + <> nest 2 (hardline <> pretty body) <> hardline + +instance Pretty (ImpBlock n) where + pretty = \case + ImpBlock Empty [] -> mempty + ImpBlock Empty expr -> group $ hardline <> pLowest expr + ImpBlock decls [] -> prettyLines $ fromNest decls + ImpBlock decls expr -> prettyLines decls' <> hardline <> pLowest expr + where decls' = fromNest decls + +instance Pretty (IBinder n l) where + pretty (IBinder b ty) = pretty b <+> ":" <+> pretty ty + +instance Pretty (ImpInstr n) where + pretty = \case + IFor a n (Abs i block) -> forStr a <+> p i <+> "<" <+> p n <> + nest 4 (p block) + IWhile body -> "while" <+> nest 2 (p body) + ICond predicate cons alt -> + "if" <+> p predicate <+> "then" <> nest 2 (p cons) <> + hardline <> "else" <> nest 2 (p alt) + IQueryParallelism f s -> "queryParallelism" <+> p f <+> p s + ILaunch f s args -> "launch" <+> p f <+> p s <+> spaced args + ICastOp t x -> "cast" <+> p x <+> "to" <+> p t + IBitcastOp t x -> "bitcast" <+> p x <+> "to" <+> p t + Store dest val -> "store" <+> p dest <+> p val + Alloc _ t s -> "alloc" <+> p t <> "[" <> sizeStr s <> "]" + StackAlloc t s -> "alloca" <+> p t <> "[" <> sizeStr s <> "]" + MemCopy dest src numel -> "memcopy" <+> p dest <+> p src <+> p numel + InitializeZeros ptr numel -> "initializeZeros" <+> p ptr <+> p numel + GetAllocSize ptr -> "getAllocSize" <+> p ptr + Free ptr -> "free" <+> p ptr + ISyncWorkgroup -> "syncWorkgroup" + IThrowError -> "throwError" + ICall f args -> "call" <+> p f <+> p args + IVectorBroadcast v _ -> "vbroadcast" <+> p v + IVectorIota _ -> "viota" + DebugPrint s x -> "debug_print" <+> p (show s) <+> p x + IPtrLoad ptr -> "load" <+> p ptr + IPtrOffset ptr idx -> p ptr <+> "+>" <+> p idx + IBinOp op x y -> opDefault (UBinOp op) [x, y] + IUnOp op x -> opDefault (UUnOp op) [x] + ISelect x y z -> "select" <+> p x <+> p y <+> p z + IOutputStream -> "outputStream" + IShowScalar ptr x -> "show_scalar" <+> p ptr <+> p x + where opDefault name xs = prettyOpDefault name xs $ AppPrec + p :: Pretty a => a -> Doc ann + p = pretty + forStr :: ForAnn -> Doc ann + forStr = \case + Fwd -> "for" + Rev -> "rof" + +sizeStr :: IExpr n -> Doc ann +sizeStr s = case s of + ILit (Word32Lit x) -> pretty x -- print in decimal because it's more readable + _ -> pretty s + +instance Pretty (IExpr n) where + pretty = \case + ILit v -> pretty v + IVar v _ -> pretty v + IPtrVar v _ -> pretty v + +instance PrettyPrec (IExpr n) where prettyPrec = atPrec ArgPrec . pretty + +instance Pretty (ImpDecl n l) where + pretty = \case + ImpLet Empty instr -> pretty instr + ImpLet (Nest b Empty) instr -> pretty b <+> "=" <+> pretty instr + ImpLet bs instr -> pretty bs <+> "=" <+> pretty instr + +instance Pretty IFunType where + pretty (IFunType cc argTys retTys) = + "Fun" <+> pretty cc <+> pretty argTys <+> "->" <+> pretty retTys + +prettyBinderNest :: PrettyB b => Nest b n l -> Doc ann +prettyBinderNest bs = nest 6 $ line' <> (sep $ map pretty $ fromNest bs) + +fromNest :: Nest b n l -> [b UnsafeS UnsafeS] +fromNest Empty = [] +fromNest (Nest b rest) = unsafeCoerceB b : fromNest rest diff --git a/src/lib/Types/OpNames.hs b/src/lib/Types/OpNames.hs index 178936ec7..344329ac6 100644 --- a/src/lib/Types/OpNames.hs +++ b/src/lib/Types/OpNames.hs @@ -14,6 +14,8 @@ import Data.Hashable import GHC.Generics (Generic (..)) import Data.Store (Store (..)) +import PPrint + data TC = ProdType | SumType | RefType | TypeKind | HeapType data Con = ProdCon | SumCon Int | HeapVal @@ -117,3 +119,8 @@ deriving instance Eq (Hof r) deriving instance Eq DAMOp deriving instance Eq RefOp deriving instance Eq UserEffectOp + +instance Pretty Projection where + pretty = \case + UnwrapNewtype -> "u" + ProjectProduct i -> pretty i diff --git a/src/lib/Types/Primitives.hs b/src/lib/Types/Primitives.hs index 8096f7e6e..f449acba6 100644 --- a/src/lib/Types/Primitives.hs +++ b/src/lib/Types/Primitives.hs @@ -28,12 +28,14 @@ import Data.String (IsString (..)) import Data.Word import Data.Hashable import Data.Store (Store (..)) -import Data.Text.Prettyprint.Doc (Pretty (..)) import qualified Data.Store.Internal as SI import Foreign.Ptr +import Numeric +import GHC.Float import GHC.Generics (Generic (..)) +import PPrint import Occurrence import Types.OpNames (UnOp (..), BinOp (..), CmpOp (..), Projection (..)) import Name @@ -222,3 +224,75 @@ instance Hashable AppExplicitness instance Hashable DepPairExplicitness instance Hashable InferenceMechanism instance Hashable RequiredMethodAccess + +-- === Pretty instances === + +instance Pretty AppExplicitness where + pretty ExplicitApp = "->" + pretty ImplicitApp = "->>" + +instance Pretty RWS where + pretty eff = case eff of + Reader -> "Read" + Writer -> "Accum" + State -> "State" + +instance Pretty LetAnn where + pretty ann = case ann of + PlainLet -> "" + InlineLet -> "%inline" + NoInlineLet -> "%noinline" + LinearLet -> "%linear" + OccInfoPure u -> pretty u <> hardline + OccInfoImpure u -> pretty u <> ", impure" <> hardline + +instance PrettyPrec Direction where + prettyPrec d = atPrec ArgPrec $ case d of + Fwd -> "fwd" + Rev -> "rev" + +printDouble :: Double -> Doc ann +printDouble x = pretty (double2Float x) + +printFloat :: Float -> Doc ann +printFloat x = pretty $ reverse $ dropWhile (=='0') $ reverse $ + showFFloat (Just 6) x "" + +instance Pretty LitVal where pretty = prettyFromPrettyPrec +instance PrettyPrec LitVal where + prettyPrec = \case + Int64Lit x -> atPrec ArgPrec $ p x + Int32Lit x -> atPrec ArgPrec $ p x + Float64Lit x -> atPrec ArgPrec $ printDouble x + Float32Lit x -> atPrec ArgPrec $ printFloat x + Word8Lit x -> atPrec ArgPrec $ p $ show $ toEnum @Char $ fromIntegral x + Word32Lit x -> atPrec ArgPrec $ p $ "0x" ++ showHex x "" + Word64Lit x -> atPrec ArgPrec $ p $ "0x" ++ showHex x "" + PtrLit ty (PtrLitVal x) -> atPrec ArgPrec $ "Ptr" <+> p ty <+> p (show x) + PtrLit _ NullPtr -> atPrec ArgPrec $ "NullPtr" + PtrLit _ (PtrSnapshot _) -> atPrec ArgPrec "" + where p :: Pretty a => a -> Doc ann + p = pretty + +instance Pretty Device where pretty = fromString . show + +instance Pretty BaseType where pretty = prettyFromPrettyPrec +instance PrettyPrec BaseType where + prettyPrec b = case b of + Scalar sb -> prettyPrec sb + Vector shape sb -> atPrec ArgPrec $ encloseSep "<" ">" "x" $ (pretty <$> shape) ++ [pretty sb] + PtrType ty -> atPrec AppPrec $ "Ptr" <+> pretty ty + +instance Pretty ScalarBaseType where pretty = prettyFromPrettyPrec +instance PrettyPrec ScalarBaseType where + prettyPrec sb = atPrec ArgPrec $ case sb of + Int64Type -> "Int64" + Int32Type -> "Int32" + Float64Type -> "Float64" + Float32Type -> "Float32" + Word8Type -> "Word8" + Word32Type -> "Word32" + Word64Type -> "Word64" + +instance Pretty Explicitness where + pretty expl = pretty (show expl) diff --git a/src/lib/Types/Source.hs b/src/lib/Types/Source.hs index 4249b9a92..b43b81a4c 100644 --- a/src/lib/Types/Source.hs +++ b/src/lib/Types/Source.hs @@ -26,13 +26,17 @@ import Data.Foldable import qualified Data.Map.Strict as M import qualified Data.Set as S import Data.Text (Text) -import Data.Text.Prettyprint.Doc (Pretty (..), hardline, (<+>)) import Data.Word +import Data.Text.Prettyprint.Doc (vcat, line, group, parens, nest, align, punctuate, hsep) +import Data.Text (snoc, unsnoc) +import Data.Tuple (swap) import GHC.Generics (Generic (..)) import Data.Store (Store (..)) +import Data.String (fromString) import Err +import PPrint import Name import qualified Types.OpNames as P import IRVariants @@ -650,6 +654,102 @@ data PrimName = | UTuple -- overloaded for type constructor and data constructor, resolved in inference deriving (Show, Eq, Generic) +-- === primitive constructors and operators === + +strToPrimName :: String -> Maybe PrimName +strToPrimName s = M.lookup s primNames + +primNameToStr :: PrimName -> String +primNameToStr prim = case lookup prim $ map swap $ M.toList primNames of + Just s -> s + Nothing -> show prim + +showPrimName :: PrimName -> String +showPrimName prim = primNameToStr prim +{-# NOINLINE showPrimName #-} + +primNames :: M.Map String PrimName +primNames = M.fromList + [ ("ask" , UMAsk), ("mextend", UMExtend) + , ("get" , UMGet), ("put" , UMPut) + , ("while" , UWhile) + , ("linearize", ULinearize), ("linearTranspose", UTranspose) + , ("runReader", URunReader), ("runWriter" , URunWriter), ("runState", URunState) + , ("runIO" , URunIO ), ("catchException" , UCatchException) + , ("iadd" , binary IAdd), ("isub" , binary ISub) + , ("imul" , binary IMul), ("fdiv" , binary FDiv) + , ("fadd" , binary FAdd), ("fsub" , binary FSub) + , ("fmul" , binary FMul), ("idiv" , binary IDiv) + , ("irem" , binary IRem) + , ("fpow" , binary FPow) + , ("and" , binary BAnd), ("or" , binary BOr ) + , ("not" , unary BNot), ("xor" , binary BXor) + , ("shl" , binary BShL), ("shr" , binary BShR) + , ("ieq" , binary (ICmp Equal)), ("feq", binary (FCmp Equal)) + , ("igt" , binary (ICmp Greater)), ("fgt", binary (FCmp Greater)) + , ("ilt" , binary (ICmp Less)), ("flt", binary (FCmp Less)) + , ("fneg" , unary FNeg) + , ("exp" , unary Exp), ("exp2" , unary Exp2) + , ("log" , unary Log), ("log2" , unary Log2), ("log10" , unary Log10) + , ("sin" , unary Sin), ("cos" , unary Cos) + , ("tan" , unary Tan), ("sqrt" , unary Sqrt) + , ("floor", unary Floor), ("ceil" , unary Ceil), ("round", unary Round) + , ("log1p", unary Log1p), ("lgamma", unary LGamma) + , ("erf" , unary Erf), ("erfc" , unary Erfc) + , ("TyKind" , UPrimTC $ P.TypeKind) + , ("Float64" , baseTy $ Scalar Float64Type) + , ("Float32" , baseTy $ Scalar Float32Type) + , ("Int64" , baseTy $ Scalar Int64Type) + , ("Int32" , baseTy $ Scalar Int32Type) + , ("Word8" , baseTy $ Scalar Word8Type) + , ("Word32" , baseTy $ Scalar Word32Type) + , ("Word64" , baseTy $ Scalar Word64Type) + , ("Int32Ptr" , baseTy $ ptrTy $ Scalar Int32Type) + , ("Word8Ptr" , baseTy $ ptrTy $ Scalar Word8Type) + , ("Word32Ptr" , baseTy $ ptrTy $ Scalar Word32Type) + , ("Word64Ptr" , baseTy $ ptrTy $ Scalar Word64Type) + , ("Float32Ptr", baseTy $ ptrTy $ Scalar Float32Type) + , ("PtrPtr" , baseTy $ ptrTy $ ptrTy $ Scalar Word8Type) + , ("Nat" , UNat) + , ("Fin" , UFin) + , ("EffKind" , UEffectRowKind) + , ("NatCon" , UNatCon) + , ("Ref" , UPrimTC $ P.RefType) + , ("HeapType" , UPrimTC $ P.HeapType) + , ("indexRef" , UIndexRef) + , ("alloc" , memOp $ P.IOAlloc) + , ("free" , memOp $ P.IOFree) + , ("ptrOffset", memOp $ P.PtrOffset) + , ("ptrLoad" , memOp $ P.PtrLoad) + , ("ptrStore" , memOp $ P.PtrStore) + , ("throwError" , miscOp $ P.ThrowError) + , ("throwException", miscOp $ P.ThrowException) + , ("dataConTag" , miscOp $ P.SumTag) + , ("toEnum" , miscOp $ P.ToEnum) + , ("outputStream" , miscOp $ P.OutputStream) + , ("cast" , miscOp $ P.CastOp) + , ("bitcast" , miscOp $ P.BitcastOp) + , ("unsafeCoerce" , miscOp $ P.UnsafeCoerce) + , ("garbageVal" , miscOp $ P.GarbageVal) + , ("select" , miscOp $ P.Select) + , ("showAny" , miscOp $ P.ShowAny) + , ("showScalar" , miscOp $ P.ShowScalar) + , ("projNewtype" , UProjNewtype) + , ("applyMethod0" , UApplyMethod 0) + , ("applyMethod1" , UApplyMethod 1) + , ("applyMethod2" , UApplyMethod 2) + , ("explicitApply", UExplicitApply) + , ("monoLit", UMonoLiteral) + ] + where + binary op = UBinOp op + baseTy b = UBaseType b + memOp op = UMemOp op + unary op = UUnOp op + ptrTy ty = PtrType (CPU, ty) + miscOp op = UMiscOp op + + -- === instances === instance Semigroup (SourceMap n) where @@ -862,3 +962,265 @@ deriving instance Ord (UEffectRow n) instance ToJSON SrcId deriving instance ToJSONKey SrcId instance ToJSON LexemeType + +-- === Pretty instances === + + + +instance Pretty CSBlock where + pretty (IndentedBlock _ decls) = nest 2 $ prettyLines decls + pretty (ExprBlock g) = pArg g + +instance Pretty Group where pretty = prettyFromPrettyPrec +instance PrettyPrec Group where + prettyPrec = undefined + -- prettyPrec (CIdentifier n) = atPrec ArgPrec $ fromString n + -- prettyPrec (CPrim prim args) = prettyOpDefault prim args + -- prettyPrec (CParens blk) = + -- atPrec ArgPrec $ "(" <> p blk <> ")" + -- prettyPrec (CBrackets g) = atPrec ArgPrec $ pretty g + -- prettyPrec (CBin op lhs rhs) = + -- atPrec LowestPrec $ pArg lhs <+> p op <+> pArg rhs + -- prettyPrec (CLambda args body) = + -- atPrec LowestPrec $ "\\" <> spaced args <> "." <> p body + -- prettyPrec (CCase scrut alts) = + -- atPrec LowestPrec $ "case " <> p scrut <> " of " <> prettyLines alts + -- prettyPrec g = atPrec ArgPrec $ fromString $ show g + +instance Pretty Bin where + pretty = \case + EvalBinOp name -> pretty name + DepAmpersand -> "&>" + Dot -> "." + DepComma -> ",>" + Colon -> ":" + DoubleColon -> "::" + Dollar -> "$" + ImplicitArrow -> "->>" + FatArrow -> "->>" + Pipe -> "|" + CSEqual -> "=" + +instance Pretty SourceBlock' where + pretty (TopDecl decl) = pretty decl + pretty d = fromString $ show d + +instance Pretty CTopDecl where + pretty (CSDecl ann decl) = annDoc <> pretty decl + where annDoc = case ann of + PlainLet -> mempty + _ -> pretty ann <> " " + pretty d = fromString $ show d + +instance Pretty CSDecl where + pretty = undefined + -- pretty (CLet pat blk) = pArg pat <+> "=" <+> p blk + -- pretty (CBind pat blk) = pArg pat <+> "<-" <+> p blk + -- pretty (CDefDecl (CDef name args maybeAnn blk)) = + -- "def " <> fromString name <> " " <> prettyParamGroups args <+> annDoc + -- <> nest 2 (hardline <> p blk) + -- where annDoc = case maybeAnn of Just (expl, ty) -> p expl <+> pArg ty + -- Nothing -> mempty + -- pretty (CInstance header givens methods name) = + -- name' <> p header <> p givens <> nest 2 (hardline <> p methods) where + -- name' = case name of + -- Nothing -> "instance " + -- (Just n) -> "named-instance " <> p n <> " " + -- pretty (CExpr e) = p e + +instance Pretty PrimName where + pretty primName = pretty $ "%" ++ showPrimName primName + +instance Pretty (UDataDefTrail n) where + pretty (UDataDefTrail bs) = pretty $ unsafeFromNest bs + +instance Pretty (UAnnBinder n l) where + pretty (UAnnBinder _ b ty _) = pretty b <> ":" <> pretty ty + +instance Pretty (UAnn n) where + pretty (UAnn ty) = ":" <> pretty ty + pretty UNoAnn = mempty + +instance Pretty (UMethodDef' n) where + pretty (UMethodDef b rhs) = pretty b <+> "=" <+> pretty rhs + +instance Pretty (UPat' n l) where pretty = prettyFromPrettyPrec +instance PrettyPrec (UPat' n l) where + prettyPrec pat = case pat of + UPatBinder x -> atPrec ArgPrec $ p x + UPatProd xs -> atPrec ArgPrec $ parens $ commaSep (unsafeFromNest xs) + UPatDepPair (PairB x y) -> atPrec ArgPrec $ parens $ p x <> ",> " <> p y + UPatCon con pats -> atPrec AppPrec $ parens $ p con <+> spaced (unsafeFromNest pats) + UPatTable pats -> atPrec ArgPrec $ p pats + where + p :: Pretty a => a -> Doc ann + p = pretty + +instance Pretty (UAlt n) where + pretty (UAlt pat body) = pretty pat <+> "->" <+> pretty body + +instance Pretty (UTopDecl n l) where + pretty = \case + UDataDefDecl (UDataDef nm bs dataCons) bTyCon bDataCons -> + "data" <+> p bTyCon <+> p nm <+> spaced (unsafeFromNest bs) <+> "where" <> nest 2 + (prettyLines (zip (toList $ unsafeFromNest bDataCons) dataCons)) + UStructDecl bTyCon (UStructDef nm bs fields defs) -> + "struct" <+> p bTyCon <+> p nm <+> spaced (unsafeFromNest bs) <+> "where" <> nest 2 + (prettyLines fields <> prettyLines defs) + UInterface params methodTys interfaceName methodNames -> + "interface" <+> p params <+> p interfaceName + <> hardline <> foldMap (<>hardline) methods + where + methods = [ p b <> ":" <> p (unsafeCoerceE ty) + | (b, ty) <- zip (toList $ unsafeFromNest methodNames) methodTys] + UInstance className bs params methods (RightB UnitB) _ -> + "instance" <+> p bs <+> p className <+> spaced params <+> + prettyLines methods + UInstance className bs params methods (LeftB v) _ -> + "named-instance" <+> p v <+> ":" <+> p bs <+> p className <+> p params + <> prettyLines methods + ULocalDecl decl -> p decl + where + p :: Pretty a => a -> Doc ann + p = pretty + +instance Pretty (UDecl' n l) where + pretty = \case + ULet ann b _ rhs -> align $ pretty ann <+> pretty b <+> "=" <> (nest 2 $ group $ line <> pLowest rhs) + UExprDecl expr -> pretty expr + UPass -> "pass" + +instance Pretty (UEffectRow n) where + pretty (UEffectRow x Nothing) = encloseSep "<" ">" "," $ (pretty <$> toList x) + pretty (UEffectRow x (Just y)) = "{" <> (hsep $ punctuate "," (pretty <$> toList x)) <+> "|" <+> pretty y <> "}" + +instance Pretty e => Pretty (WithSrcs e) where pretty (WithSrcs _ _ x) = pretty x +instance PrettyPrec e => PrettyPrec (WithSrcs e) where prettyPrec (WithSrcs _ _ x) = prettyPrec x + +instance Pretty e => Pretty (WithSrc e) where pretty (WithSrc _ x) = pretty x +instance PrettyPrec e => PrettyPrec (WithSrc e) where prettyPrec (WithSrc _ x) = prettyPrec x + +instance PrettyE e => Pretty (WithSrcE e n) where pretty (WithSrcE _ x) = pretty x +instance PrettyPrecE e => PrettyPrec (WithSrcE e n) where prettyPrec (WithSrcE _ x) = prettyPrec x + +instance PrettyB b => Pretty (WithSrcB b n l) where pretty (WithSrcB _ x) = pretty x +instance PrettyPrecB b => PrettyPrec (WithSrcB b n l) where prettyPrec (WithSrcB _ x) = prettyPrec x + +instance PrettyE e => Pretty (SourceNameOr e n) where + pretty (SourceName _ v) = pretty v + pretty (InternalName _ v _) = pretty v + +instance Pretty (SourceOrInternalName c n) where + pretty (SourceOrInternalName sn) = pretty sn + +instance Pretty (ULamExpr n) where pretty = prettyFromPrettyPrec +instance PrettyPrec (ULamExpr n) where + prettyPrec (ULamExpr bs _ _ _ body) = atPrec LowestPrec $ + "\\" <> pretty bs <+> "." <+> indented (pretty body) + +instance Pretty (UPiExpr n) where pretty = prettyFromPrettyPrec +instance PrettyPrec (UPiExpr n) where + prettyPrec (UPiExpr pats appExpl UPure ty) = atPrec LowestPrec $ align $ + pretty pats <+> pretty appExpl <+> pLowest ty + prettyPrec (UPiExpr pats appExpl eff ty) = atPrec LowestPrec $ align $ + pretty pats <+> pretty appExpl <+> pretty eff <+> pLowest ty + +instance Pretty (UTabPiExpr n) where pretty = prettyFromPrettyPrec +instance PrettyPrec (UTabPiExpr n) where + prettyPrec (UTabPiExpr pat ty) = atPrec LowestPrec $ align $ + pretty pat <+> "=>" <+> pLowest ty + +instance Pretty (UDepPairType n) where pretty = prettyFromPrettyPrec +instance PrettyPrec (UDepPairType n) where + -- TODO: print explicitness info + prettyPrec (UDepPairType _ pat ty) = atPrec LowestPrec $ align $ + pretty pat <+> "&>" <+> pLowest ty + +instance Pretty (UBlock' n) where + pretty (UBlock decls result) = + prettyLines (unsafeFromNest decls) <> hardline <> pLowest result + +instance Pretty (UExpr' n) where pretty = prettyFromPrettyPrec +instance PrettyPrec (UExpr' n) where + prettyPrec expr = case expr of + ULit l -> prettyPrec l + UVar v -> atPrec ArgPrec $ p v + ULam lam -> prettyPrec lam + UApp f xs named -> atPrec AppPrec $ pAppArg (pApp f) xs <+> p named + UTabApp f x -> atPrec AppPrec $ pArg f <> "." <> pArg x + UFor dir (UForExpr binder body) -> + atPrec LowestPrec $ kw <+> p binder <> "." + <+> nest 2 (p body) + where kw = case dir of Fwd -> "for" + Rev -> "rof" + UPi piType -> prettyPrec piType + UTabPi piType -> prettyPrec piType + UDepPairTy depPairType -> prettyPrec depPairType + UDepPair lhs rhs -> atPrec ArgPrec $ parens $ + p lhs <+> ",>" <+> p rhs + UHole -> atPrec ArgPrec "_" + UTypeAnn v ty -> atPrec LowestPrec $ + group $ pApp v <> line <> ":" <+> pApp ty + UTabCon xs -> atPrec ArgPrec $ p xs + UPrim prim xs -> atPrec AppPrec $ p (show prim) <+> p xs + UCase e alts -> atPrec LowestPrec $ "case" <+> p e <> + nest 2 (prettyLines alts) + UFieldAccess x (WithSrc _ f) -> atPrec AppPrec $ p x <> "~" <> p f + UNatLit v -> atPrec ArgPrec $ p v + UIntLit v -> atPrec ArgPrec $ p v + UFloatLit v -> atPrec ArgPrec $ p v + UDo block -> atPrec LowestPrec $ p block + where + p :: Pretty a => a -> Doc ann + p = pretty + +instance Pretty SourceBlock where + pretty block = pretty $ ensureNewline (sbText block) where + -- Force the SourceBlock to end in a newline for echoing, even if + -- it was terminated with EOF in the original program. + ensureNewline t = case unsnoc t of + Nothing -> t + Just (_, '\n') -> t + _ -> t `snoc` '\n' + +instance Pretty Output where + pretty = \case + TextOut s -> pretty s + HtmlOut _ -> "" + SourceInfo _ -> "" + PassInfo _ s -> pretty s + MiscLog s -> pretty s + Error e -> pretty e + +instance Pretty PassName where + pretty x = pretty $ show x + +instance Pretty Result where + pretty (Result (Outputs outs) r) = vcat (map pretty outs) <> maybeErr + where maybeErr = case r of Failure err -> pretty err + Success () -> mempty + +instance Pretty (UBinder' c n l) where pretty = prettyFromPrettyPrec +instance PrettyPrec (UBinder' c n l) where + prettyPrec b = atPrec ArgPrec case b of + UBindSource v -> pretty v + UIgnore -> "_" + UBind v _ -> pretty v + +instance Pretty FieldName' where + pretty = \case + FieldName s -> pretty s + FieldNum n -> pretty n + +instance Pretty (UEffect n) where + pretty eff = case eff of + URWSEffect rws h -> pretty rws <+> pretty h + UExceptionEffect -> "Except" + UIOEffect -> "IO" + +prettyOpDefault :: PrettyPrec a => PrimName -> [a] -> DocPrec ann +prettyOpDefault name args = + case length args of + 0 -> atPrec ArgPrec primName + _ -> atPrec AppPrec $ pAppArg primName args + where primName = pretty name diff --git a/src/lib/Types/Top.hs b/src/lib/Types/Top.hs new file mode 100644 index 000000000..fba64b0e1 --- /dev/null +++ b/src/lib/Types/Top.hs @@ -0,0 +1,1046 @@ +-- Copyright 2022 Google LLC +-- +-- Use of this source code is governed by a BSD-style +-- license that can be found in the LICENSE file or at +-- https://developers.google.com/open-source/licenses/bsd + +{-# LANGUAGE UndecidableInstances #-} +{-# LANGUAGE StrictData #-} + +-- Top-level data types + +module Types.Top where + +import Data.Functor ((<&>)) +import Data.Hashable +import Data.Text.Prettyprint.Doc +import qualified Data.Map.Strict as M +import qualified Data.Set as S + +import GHC.Generics (Generic (..)) +import Data.Store (Store (..)) +import Foreign.Ptr + +import Name +import Util (FileHash, SnocList (..)) +import IRVariants +import PPrint + +import Types.Primitives +import Types.Core +import Types.Source +import Types.Imp + +type TopBlock = TopLam -- used for nullary lambda +type IsDestLam = Bool +data TopLam (r::IR) (n::S) = TopLam IsDestLam (PiType r n) (LamExpr r n) + deriving (Show, Generic) +type STopLam = TopLam SimpIR +type CTopLam = TopLam CoreIR + +data EvalStatus a = Waiting | Running | Finished a + deriving (Show, Eq, Ord, Generic, Functor, Foldable, Traversable) +type TopFunEvalStatus n = EvalStatus (TopFunLowerings n) + +data TopFun (n::S) = + DexTopFun (TopFunDef n) (TopLam SimpIR n) (TopFunEvalStatus n) + | FFITopFun String IFunType + deriving (Show, Generic) + +data TopFunDef (n::S) = + Specialization (SpecializationSpec n) + | LinearizationPrimal (LinearizationSpec n) + -- Tangent functions all take some number of nonlinear args, then a *single* + -- linear arg. This is so that transposition can be an involution - you apply + -- it twice and you get back to the original function. + | LinearizationTangent (LinearizationSpec n) + deriving (Show, Generic) + +newtype TopFunLowerings (n::S) = TopFunLowerings + { topFunObjCode :: FunObjCodeName n } -- TODO: add optimized, imp etc. as needed + deriving (Show, Generic, SinkableE, HoistableE, RenameE, AlphaEqE, AlphaHashableE, Pretty) + +data AtomBinding (r::IR) (n::S) where + LetBound :: DeclBinding r n -> AtomBinding r n + MiscBound :: Type r n -> AtomBinding r n + TopDataBound :: RepVal n -> AtomBinding SimpIR n + SolverBound :: SolverBinding n -> AtomBinding CoreIR n + NoinlineFun :: CType n -> CAtom n -> AtomBinding CoreIR n + FFIFunBound :: CorePiType n -> TopFunName n -> AtomBinding CoreIR n + +deriving instance IRRep r => Show (AtomBinding r n) +deriving via WrapE (AtomBinding r) n instance IRRep r => Generic (AtomBinding r n) + +-- name of function, name of arg +type InferenceArgDesc = (String, String) +data InfVarDesc = + ImplicitArgInfVar InferenceArgDesc + | AnnotationInfVar String -- name of binder + | TypeInstantiationInfVar String -- name of type + | MiscInfVar + deriving (Show, Generic, Eq, Ord) + +data SolverBinding (n::S) = + InfVarBound (CType n) + | SkolemBound (CType n) + | DictBound (CType n) + deriving (Show, Generic) + +-- TODO: Use an IntMap +newtype CustomRules (n::S) = + CustomRules { customRulesMap :: M.Map (AtomName CoreIR n) (AtomRules n) } + deriving (Semigroup, Monoid, Store) +data AtomRules (n::S) = + -- number of implicit args, number of explicit args, linearization function + CustomLinearize Int Int SymbolicZeros (CAtom n) + deriving (Generic) + +-- === envs and modules === + +-- `ModuleEnv` contains data that only makes sense in the context of evaluating +-- a particular module. `TopEnv` contains everything that makes sense "between" +-- evaluating modules. +data Env n = Env + { topEnv :: {-# UNPACK #-} TopEnv n + , moduleEnv :: {-# UNPACK #-} ModuleEnv n } + deriving (Generic) + +newtype EnvFrag (n::S) (l::S) = EnvFrag (RecSubstFrag Binding n l) + deriving (OutFrag) + +data TopEnv (n::S) = TopEnv + { envDefs :: RecSubst Binding n + , envCustomRules :: CustomRules n + , envCache :: Cache n + , envLoadedModules :: LoadedModules n + , envLoadedObjects :: LoadedObjects n } + deriving (Generic) + +data SerializedEnv n = SerializedEnv + { serializedEnvDefs :: RecSubst Binding n + , serializedEnvCustomRules :: CustomRules n + , serializedEnvCache :: Cache n } + deriving (Generic) + +-- TODO: consider splitting this further into `ModuleEnv` (the env that's +-- relevant between top-level decls) and `LocalEnv` (the additional parts of the +-- env that's relevant under a lambda binder). Unlike the Top/Module +-- distinction, there's some overlap. For example, instances can be defined at +-- both the module-level and local level. Similarly, if we start allowing +-- top-level effects in `Main` then we'll have module-level effects and local +-- effects. +data ModuleEnv (n::S) = ModuleEnv + { envImportStatus :: ImportStatus n + , envSourceMap :: SourceMap n + , envSynthCandidates :: SynthCandidates n } + deriving (Generic) + +data Module (n::S) = Module + { moduleSourceName :: ModuleSourceName + , moduleDirectDeps :: S.Set (ModuleName n) + , moduleTransDeps :: S.Set (ModuleName n) -- XXX: doesn't include the module itself + , moduleExports :: SourceMap n + -- these are just the synth candidates required by this + -- module by itself. We'll usually also need those required by the module's + -- (transitive) dependencies, which must be looked up separately. + , moduleSynthCandidates :: SynthCandidates n } + deriving (Show, Generic) + +data LoadedModules (n::S) = LoadedModules + { fromLoadedModules :: M.Map ModuleSourceName (ModuleName n)} + deriving (Show, Generic) + +emptyModuleEnv :: ModuleEnv n +emptyModuleEnv = ModuleEnv emptyImportStatus (SourceMap mempty) mempty + +emptyLoadedModules :: LoadedModules n +emptyLoadedModules = LoadedModules mempty + +data LoadedObjects (n::S) = LoadedObjects + -- the pointer points to the actual runtime function + { fromLoadedObjects :: M.Map (FunObjCodeName n) NativeFunction} + deriving (Show, Generic) + +emptyLoadedObjects :: LoadedObjects n +emptyLoadedObjects = LoadedObjects mempty + +data ImportStatus (n::S) = ImportStatus + { directImports :: S.Set (ModuleName n) + -- XXX: This are cached for efficiency. It's derivable from `directImports`. + , transImports :: S.Set (ModuleName n) } + deriving (Show, Generic) + +data TopEnvFrag n l = TopEnvFrag (EnvFrag n l) (ModuleEnv l) (SnocList (TopEnvUpdate l)) + +data TopEnvUpdate n = + ExtendCache (Cache n) + | AddCustomRule (CAtomName n) (AtomRules n) + | UpdateLoadedModules ModuleSourceName (ModuleName n) + | UpdateLoadedObjects (FunObjCodeName n) NativeFunction + | FinishDictSpecialization (SpecDictName n) [TopLam SimpIR n] + | LowerDictSpecialization (SpecDictName n) [TopLam SimpIR n] + | UpdateTopFunEvalStatus (TopFunName n) (TopFunEvalStatus n) + | UpdateInstanceDef (InstanceName n) (InstanceDef n) + | UpdateTyConDef (TyConName n) (TyConDef n) + | UpdateFieldDef (TyConName n) SourceName (CAtomName n) + +-- TODO: we could add a lot more structure for querying by dict type, caching, etc. +data SynthCandidates n = SynthCandidates + { instanceDicts :: M.Map (ClassName n) [InstanceName n] + , ixInstances :: [InstanceName n] } + deriving (Show, Generic) + +emptyImportStatus :: ImportStatus n +emptyImportStatus = ImportStatus mempty mempty + +-- TODO: figure out the additional top-level context we need -- backend, other +-- compiler flags etc. We can have a map from those to this. + +data Cache (n::S) = Cache + { specializationCache :: EMap SpecializationSpec TopFunName n + , ixDictCache :: EMap AbsDict SpecDictName n + , linearizationCache :: EMap LinearizationSpec (PairE TopFunName TopFunName) n + , transpositionCache :: EMap TopFunName TopFunName n + -- This is memoizing `parseAndGetDeps :: Text -> [ModuleSourceName]`. But we + -- only want to store one entry per module name as a simple cache eviction + -- policy, so we store it keyed on the module name, with the text hash for + -- the validity check. + , parsedDeps :: M.Map ModuleSourceName (FileHash, [ModuleSourceName]) + , moduleEvaluations :: M.Map ModuleSourceName ((FileHash, [ModuleName n]), ModuleName n) + } deriving (Show, Generic) + +-- === runtime function and variable representations === + +type RuntimeEnv = DynamicVarKeyPtrs + +type DexDestructor = FunPtr (IO ()) + +data NativeFunction = NativeFunction + { nativeFunPtr :: FunPtr () + , nativeFunTeardown :: IO () } + +instance Show NativeFunction where + show _ = "" + +-- Holds pointers to thread-local storage used to simulate dynamically scoped +-- variables, such as the output stream file descriptor. +type DynamicVarKeyPtrs = [(DynamicVar, Ptr ())] + +data DynamicVar = OutStreamDyvar -- TODO: add others as needed + deriving (Enum, Bounded) + +dynamicVarCName :: DynamicVar -> String +dynamicVarCName OutStreamDyvar = "dex_out_stream_dyvar" + +dynamicVarLinkMap :: DynamicVarKeyPtrs -> [(String, Ptr ())] +dynamicVarLinkMap dyvars = dyvars <&> \(v, ptr) -> (dynamicVarCName v, ptr) + +-- === Specialization and generalization === + +type Generalized (r::IR) (e::E) (n::S) = (Abstracted r e n, [Atom r n]) +type Abstracted (r::IR) (e::E) = Abs (Nest (Binder r)) e +type AbsDict = Abstracted CoreIR (Dict CoreIR) + +data SpecializedDictDef n = + SpecializedDict + (AbsDict n) + -- Methods (thunked if nullary), if they're available. + -- We create specialized dict names during simplification, but we don't + -- actually simplify/lower them until we return to TopLevel + (Maybe [TopLam SimpIR n]) + deriving (Show, Generic) + +-- TODO: extend with AD-oriented specializations, backend-specific specializations etc. +data SpecializationSpec (n::S) = + AppSpecialization (AtomVar CoreIR n) (Abstracted CoreIR (ListE CAtom) n) + deriving (Show, Generic) + +type Active = Bool +data LinearizationSpec (n::S) = LinearizationSpec (TopFunName n) [Active] + deriving (Show, Generic) + +-- === bindings - static information we carry about a lexical scope === + +-- TODO: consider making this an open union via a typeable-like class +data Binding (c::C) (n::S) where + AtomNameBinding :: AtomBinding r n -> Binding (AtomNameC r) n + TyConBinding :: Maybe (TyConDef n) -> DotMethods n -> Binding TyConNameC n + DataConBinding :: TyConName n -> Int -> Binding DataConNameC n + ClassBinding :: ClassDef n -> Binding ClassNameC n + InstanceBinding :: InstanceDef n -> CorePiType n -> Binding InstanceNameC n + MethodBinding :: ClassName n -> Int -> Binding MethodNameC n + TopFunBinding :: TopFun n -> Binding TopFunNameC n + FunObjCodeBinding :: CFunction n -> Binding FunObjCodeNameC n + ModuleBinding :: Module n -> Binding ModuleNameC n + -- TODO: add a case for abstracted pointers, as used in `ClosedImpFunction` + PtrBinding :: PtrType -> PtrLitVal -> Binding PtrNameC n + SpecializedDictBinding :: SpecializedDictDef n -> Binding SpecializedDictNameC n + ImpNameBinding :: BaseType -> Binding ImpNameC n + +-- === ToBinding === + +atomBindingToBinding :: AtomBinding r n -> Binding (AtomNameC r) n +atomBindingToBinding b = AtomNameBinding b + +bindingToAtomBinding :: Binding (AtomNameC r) n -> AtomBinding r n +bindingToAtomBinding (AtomNameBinding b) = b + +class (RenameE e, SinkableE e) => ToBinding (e::E) (c::C) | e -> c where + toBinding :: e n -> Binding c n + +instance Color c => ToBinding (Binding c) c where + toBinding = id + +instance IRRep r => ToBinding (AtomBinding r) (AtomNameC r) where + toBinding = atomBindingToBinding + +instance IRRep r => ToBinding (DeclBinding r) (AtomNameC r) where + toBinding = toBinding . LetBound + +instance IRRep r => ToBinding (Type r) (AtomNameC r) where + toBinding = toBinding . MiscBound + +instance ToBinding SolverBinding (AtomNameC CoreIR) where + toBinding = toBinding . SolverBound + +instance IRRep r => ToBinding (IxType r) (AtomNameC r) where + toBinding (IxType t _) = toBinding t + +instance (ToBinding e1 c, ToBinding e2 c) => ToBinding (EitherE e1 e2) c where + toBinding (LeftE e) = toBinding e + toBinding (RightE e) = toBinding e + +instance ToBindersAbs (TopLam r) (Expr r) r where + toAbs (TopLam _ _ lam) = toAbs lam + +-- === GenericE, GenericB === + +instance GenericE SpecializedDictDef where + type RepE SpecializedDictDef = AbsDict `PairE` MaybeE (ListE (TopLam SimpIR)) + fromE (SpecializedDict ab methods) = ab `PairE` methods' + where methods' = case methods of Just xs -> LeftE (ListE xs) + Nothing -> RightE UnitE + {-# INLINE fromE #-} + toE (ab `PairE` methods) = SpecializedDict ab methods' + where methods' = case methods of LeftE (ListE xs) -> Just xs + RightE UnitE -> Nothing + {-# INLINE toE #-} + +instance SinkableE SpecializedDictDef +instance HoistableE SpecializedDictDef +instance AlphaEqE SpecializedDictDef +instance AlphaHashableE SpecializedDictDef +instance RenameE SpecializedDictDef + +instance HasScope Env where + toScope = toScope . envDefs . topEnv + +instance OutMap Env where + emptyOutMap = + Env (TopEnv (RecSubst emptyInFrag) mempty mempty emptyLoadedModules emptyLoadedObjects) + emptyModuleEnv + {-# INLINE emptyOutMap #-} + +instance ExtOutMap Env (RecSubstFrag Binding) where + -- TODO: We might want to reorganize this struct to make this + -- do less explicit sinking etc. It's a hot operation! + extendOutMap (Env (TopEnv defs rules cache loadedM loadedO) moduleEnv) frag = + withExtEvidence frag $ Env + (TopEnv + (defs `extendRecSubst` frag) + (sink rules) + (sink cache) + (sink loadedM) + (sink loadedO)) + (sink moduleEnv) + {-# INLINE extendOutMap #-} + +instance ExtOutMap Env EnvFrag where + extendOutMap = extendEnv + {-# INLINE extendOutMap #-} + +extendEnv :: Distinct l => Env n -> EnvFrag n l -> Env l +extendEnv env (EnvFrag newEnv) = do + case extendOutMap env newEnv of + Env envTop (ModuleEnv imports sm scs) -> do + Env envTop (ModuleEnv imports sm scs) +{-# NOINLINE [1] extendEnv #-} + + +instance GenericE AtomRules where + type RepE AtomRules = (LiftE (Int, Int, SymbolicZeros)) `PairE` CAtom + fromE (CustomLinearize ni ne sz a) = LiftE (ni, ne, sz) `PairE` a + toE (LiftE (ni, ne, sz) `PairE` a) = CustomLinearize ni ne sz a +instance SinkableE AtomRules +instance HoistableE AtomRules +instance AlphaEqE AtomRules +instance RenameE AtomRules + +instance GenericE CustomRules where + type RepE CustomRules = ListE (PairE (AtomName CoreIR) AtomRules) + fromE (CustomRules m) = ListE $ toPairE <$> M.toList m + toE (ListE l) = CustomRules $ M.fromList $ fromPairE <$> l +instance SinkableE CustomRules +instance HoistableE CustomRules +instance AlphaEqE CustomRules +instance RenameE CustomRules + +instance GenericE Cache where + type RepE Cache = + EMap SpecializationSpec TopFunName + `PairE` EMap AbsDict SpecDictName + `PairE` EMap LinearizationSpec (PairE TopFunName TopFunName) + `PairE` EMap TopFunName TopFunName + `PairE` LiftE (M.Map ModuleSourceName (FileHash, [ModuleSourceName])) + `PairE` ListE ( LiftE ModuleSourceName + `PairE` LiftE FileHash + `PairE` ListE ModuleName + `PairE` ModuleName) + fromE (Cache x y z w parseCache evalCache) = + x `PairE` y `PairE` z `PairE` w `PairE` LiftE parseCache `PairE` + ListE [LiftE sourceName `PairE` LiftE hashVal `PairE` ListE deps `PairE` result + | (sourceName, ((hashVal, deps), result)) <- M.toList evalCache ] + {-# INLINE fromE #-} + toE (x `PairE` y `PairE` z `PairE` w `PairE` LiftE parseCache `PairE` ListE evalCache) = + Cache x y z w parseCache + (M.fromList + [(sourceName, ((hashVal, deps), result)) + | LiftE sourceName `PairE` LiftE hashVal `PairE` ListE deps `PairE` result + <- evalCache]) + {-# INLINE toE #-} + +instance SinkableE Cache +instance HoistableE Cache +instance AlphaEqE Cache +instance RenameE Cache +instance Store (Cache n) + +instance Monoid (Cache n) where + mempty = Cache mempty mempty mempty mempty mempty mempty + mappend = (<>) + +instance Semigroup (Cache n) where + -- right-biased instead of left-biased + Cache x1 x2 x3 x4 x5 x6 <> Cache y1 y2 y3 y4 y5 y6 = + Cache (y1<>x1) (y2<>x2) (y3<>x3) (y4<>x4) (x5<>y5) (x6<>y6) + + +instance GenericE SynthCandidates where + type RepE SynthCandidates = ListE (PairE ClassName (ListE InstanceName)) + `PairE` ListE InstanceName + fromE (SynthCandidates xs ys) = ListE xs' `PairE` ListE ys + where xs' = map (\(k,vs) -> PairE k (ListE vs)) (M.toList xs) + {-# INLINE fromE #-} + toE (ListE xs `PairE` ListE ys) = SynthCandidates xs' ys + where xs' = M.fromList $ map (\(PairE k (ListE vs)) -> (k,vs)) xs + {-# INLINE toE #-} + +instance SinkableE SynthCandidates +instance HoistableE SynthCandidates +instance AlphaEqE SynthCandidates +instance AlphaHashableE SynthCandidates +instance RenameE SynthCandidates + +instance IRRep r => GenericE (AtomBinding r) where + type RepE (AtomBinding r) = + EitherE2 (EitherE3 + (DeclBinding r) -- LetBound + (Type r) -- MiscBound + (WhenCore r SolverBinding) -- SolverBound + ) (EitherE3 + (WhenCore r (PairE CType CAtom)) -- NoinlineFun + (WhenSimp r RepVal) -- TopDataBound + (WhenCore r (CorePiType `PairE` TopFunName)) -- FFIFunBound + ) + + fromE = \case + LetBound x -> Case0 $ Case0 x + MiscBound x -> Case0 $ Case1 x + SolverBound x -> Case0 $ Case2 $ WhenIRE x + NoinlineFun t x -> Case1 $ Case0 $ WhenIRE $ PairE t x + TopDataBound repVal -> Case1 $ Case1 $ WhenIRE repVal + FFIFunBound ty v -> Case1 $ Case2 $ WhenIRE $ ty `PairE` v + {-# INLINE fromE #-} + + toE = \case + Case0 x' -> case x' of + Case0 x -> LetBound x + Case1 x -> MiscBound x + Case2 (WhenIRE x) -> SolverBound x + _ -> error "impossible" + Case1 x' -> case x' of + Case0 (WhenIRE (PairE t x)) -> NoinlineFun t x + Case1 (WhenIRE repVal) -> TopDataBound repVal + Case2 (WhenIRE (ty `PairE` v)) -> FFIFunBound ty v + _ -> error "impossible" + _ -> error "impossible" + {-# INLINE toE #-} + + +instance IRRep r => SinkableE (AtomBinding r) +instance IRRep r => HoistableE (AtomBinding r) +instance IRRep r => RenameE (AtomBinding r) +instance IRRep r => AlphaEqE (AtomBinding r) +instance IRRep r => AlphaHashableE (AtomBinding r) + +instance GenericE TopFunDef where + type RepE TopFunDef = EitherE3 SpecializationSpec LinearizationSpec LinearizationSpec + fromE = \case + Specialization s -> Case0 s + LinearizationPrimal s -> Case1 s + LinearizationTangent s -> Case2 s + {-# INLINE fromE #-} + toE = \case + Case0 s -> Specialization s + Case1 s -> LinearizationPrimal s + Case2 s -> LinearizationTangent s + _ -> error "impossible" + {-# INLINE toE #-} + +instance SinkableE TopFunDef +instance HoistableE TopFunDef +instance RenameE TopFunDef +instance AlphaEqE TopFunDef +instance AlphaHashableE TopFunDef + +instance IRRep r => GenericE (TopLam r) where + type RepE (TopLam r) = LiftE Bool `PairE` PiType r `PairE` LamExpr r + fromE (TopLam d x y) = LiftE d `PairE` x `PairE` y + {-# INLINE fromE #-} + toE (LiftE d `PairE` x `PairE` y) = TopLam d x y + {-# INLINE toE #-} + +instance IRRep r => SinkableE (TopLam r) +instance IRRep r => HoistableE (TopLam r) +instance IRRep r => RenameE (TopLam r) +instance IRRep r => AlphaEqE (TopLam r) +instance IRRep r => AlphaHashableE (TopLam r) + +instance GenericE TopFun where + type RepE TopFun = EitherE + (TopFunDef `PairE` TopLam SimpIR `PairE` ComposeE EvalStatus TopFunLowerings) + (LiftE (String, IFunType)) + fromE = \case + DexTopFun def lam status -> LeftE (def `PairE` lam `PairE` ComposeE status) + FFITopFun name ty -> RightE (LiftE (name, ty)) + {-# INLINE fromE #-} + toE = \case + LeftE (def `PairE` lam `PairE` ComposeE status) -> DexTopFun def lam status + RightE (LiftE (name, ty)) -> FFITopFun name ty + {-# INLINE toE #-} + +instance SinkableE TopFun +instance HoistableE TopFun +instance RenameE TopFun +instance AlphaEqE TopFun +instance AlphaHashableE TopFun + +instance GenericE SpecializationSpec where + type RepE SpecializationSpec = + PairE (AtomVar CoreIR) (Abs (Nest (Binder CoreIR)) (ListE CAtom)) + fromE (AppSpecialization fname (Abs bs args)) = PairE fname (Abs bs args) + {-# INLINE fromE #-} + toE (PairE fname (Abs bs args)) = AppSpecialization fname (Abs bs args) + {-# INLINE toE #-} + +instance HasNameHint (SpecializationSpec n) where + getNameHint (AppSpecialization f _) = getNameHint f + +instance SinkableE SpecializationSpec +instance HoistableE SpecializationSpec +instance RenameE SpecializationSpec +instance AlphaEqE SpecializationSpec +instance AlphaHashableE SpecializationSpec + +instance GenericE LinearizationSpec where + type RepE LinearizationSpec = PairE TopFunName (LiftE [Active]) + fromE (LinearizationSpec fname actives) = PairE fname (LiftE actives) + {-# INLINE fromE #-} + toE (PairE fname (LiftE actives)) = LinearizationSpec fname actives + {-# INLINE toE #-} + +instance SinkableE LinearizationSpec +instance HoistableE LinearizationSpec +instance RenameE LinearizationSpec +instance AlphaEqE LinearizationSpec +instance AlphaHashableE LinearizationSpec + +instance GenericE SolverBinding where + type RepE SolverBinding = EitherE3 + CType + CType + CType + fromE = \case + InfVarBound ty -> Case0 ty + SkolemBound ty -> Case1 ty + DictBound ty -> Case2 ty + {-# INLINE fromE #-} + + toE = \case + Case0 ty -> InfVarBound ty + Case1 ty -> SkolemBound ty + Case2 ty -> DictBound ty + _ -> error "impossible" + {-# INLINE toE #-} + +instance SinkableE SolverBinding +instance HoistableE SolverBinding +instance RenameE SolverBinding +instance AlphaEqE SolverBinding +instance AlphaHashableE SolverBinding + +instance GenericE (Binding c) where + type RepE (Binding c) = + EitherE3 + (EitherE6 + (WhenAtomName c AtomBinding) + (WhenC TyConNameC c (MaybeE TyConDef `PairE` DotMethods)) + (WhenC DataConNameC c (TyConName `PairE` LiftE Int)) + (WhenC ClassNameC c (ClassDef)) + (WhenC InstanceNameC c (InstanceDef `PairE` CorePiType)) + (WhenC MethodNameC c (ClassName `PairE` LiftE Int))) + (EitherE4 + (WhenC TopFunNameC c (TopFun)) + (WhenC FunObjCodeNameC c (CFunction)) + (WhenC ModuleNameC c (Module)) + (WhenC PtrNameC c (LiftE (PtrType, PtrLitVal)))) + (EitherE2 + (WhenC SpecializedDictNameC c (SpecializedDictDef)) + (WhenC ImpNameC c (LiftE BaseType))) + + fromE = \case + AtomNameBinding binding -> Case0 $ Case0 $ WhenAtomName binding + TyConBinding dataDef methods -> Case0 $ Case1 $ WhenC $ toMaybeE dataDef `PairE` methods + DataConBinding dataDefName idx -> Case0 $ Case2 $ WhenC $ dataDefName `PairE` LiftE idx + ClassBinding classDef -> Case0 $ Case3 $ WhenC $ classDef + InstanceBinding instanceDef ty -> Case0 $ Case4 $ WhenC $ instanceDef `PairE` ty + MethodBinding className idx -> Case0 $ Case5 $ WhenC $ className `PairE` LiftE idx + TopFunBinding fun -> Case1 $ Case0 $ WhenC $ fun + FunObjCodeBinding cFun -> Case1 $ Case1 $ WhenC $ cFun + ModuleBinding m -> Case1 $ Case2 $ WhenC $ m + PtrBinding ty p -> Case1 $ Case3 $ WhenC $ LiftE (ty,p) + SpecializedDictBinding def -> Case2 $ Case0 $ WhenC $ def + ImpNameBinding ty -> Case2 $ Case1 $ WhenC $ LiftE ty + {-# INLINE fromE #-} + + toE = \case + Case0 (Case0 (WhenAtomName binding)) -> AtomNameBinding binding + Case0 (Case1 (WhenC (def `PairE` methods))) -> TyConBinding (fromMaybeE def) methods + Case0 (Case2 (WhenC (n `PairE` LiftE idx))) -> DataConBinding n idx + Case0 (Case3 (WhenC (classDef))) -> ClassBinding classDef + Case0 (Case4 (WhenC (instanceDef `PairE` ty))) -> InstanceBinding instanceDef ty + Case0 (Case5 (WhenC ((n `PairE` LiftE i)))) -> MethodBinding n i + Case1 (Case0 (WhenC (fun))) -> TopFunBinding fun + Case1 (Case1 (WhenC (f))) -> FunObjCodeBinding f + Case1 (Case2 (WhenC (m))) -> ModuleBinding m + Case1 (Case3 (WhenC ((LiftE (ty,p))))) -> PtrBinding ty p + Case2 (Case0 (WhenC (def))) -> SpecializedDictBinding def + Case2 (Case1 (WhenC ((LiftE ty)))) -> ImpNameBinding ty + _ -> error "impossible" + {-# INLINE toE #-} + +deriving via WrapE (Binding c) n instance Generic (Binding c n) +instance SinkableV Binding +instance HoistableV Binding +instance RenameV Binding +instance Color c => SinkableE (Binding c) +instance Color c => HoistableE (Binding c) +instance Color c => RenameE (Binding c) + +instance Semigroup (SynthCandidates n) where + SynthCandidates xs ys <> SynthCandidates xs' ys' = + SynthCandidates (M.unionWith (<>) xs xs') (ys <> ys') + +instance Monoid (SynthCandidates n) where + mempty = SynthCandidates mempty mempty + +instance GenericB EnvFrag where + type RepB EnvFrag = RecSubstFrag Binding + fromB (EnvFrag frag) = frag + toB frag = EnvFrag frag + +instance SinkableB EnvFrag +instance HoistableB EnvFrag +instance ProvesExt EnvFrag +instance BindsNames EnvFrag +instance RenameB EnvFrag + +instance GenericE TopEnvUpdate where + type RepE TopEnvUpdate = EitherE2 ( + EitherE4 + {- ExtendCache -} Cache + {- AddCustomRule -} (CAtomName `PairE` AtomRules) + {- UpdateLoadedModules -} (LiftE ModuleSourceName `PairE` ModuleName) + {- UpdateLoadedObjects -} (FunObjCodeName `PairE` LiftE NativeFunction) + ) ( EitherE6 + {- FinishDictSpecialization -} (SpecDictName `PairE` ListE (TopLam SimpIR)) + {- LowerDictSpecialization -} (SpecDictName `PairE` ListE (TopLam SimpIR)) + {- UpdateTopFunEvalStatus -} (TopFunName `PairE` ComposeE EvalStatus TopFunLowerings) + {- UpdateInstanceDef -} (InstanceName `PairE` InstanceDef) + {- UpdateTyConDef -} (TyConName `PairE` TyConDef) + {- UpdateFieldDef -} (TyConName `PairE` LiftE SourceName `PairE` CAtomName) + ) + fromE = \case + ExtendCache x -> Case0 $ Case0 x + AddCustomRule x y -> Case0 $ Case1 (x `PairE` y) + UpdateLoadedModules x y -> Case0 $ Case2 (LiftE x `PairE` y) + UpdateLoadedObjects x y -> Case0 $ Case3 (x `PairE` LiftE y) + FinishDictSpecialization x y -> Case1 $ Case0 (x `PairE` ListE y) + LowerDictSpecialization x y -> Case1 $ Case1 (x `PairE` ListE y) + UpdateTopFunEvalStatus x y -> Case1 $ Case2 (x `PairE` ComposeE y) + UpdateInstanceDef x y -> Case1 $ Case3 (x `PairE` y) + UpdateTyConDef x y -> Case1 $ Case4 (x `PairE` y) + UpdateFieldDef x y z -> Case1 $ Case5 (x `PairE` LiftE y `PairE` z) + + toE = \case + Case0 e -> case e of + Case0 x -> ExtendCache x + Case1 (x `PairE` y) -> AddCustomRule x y + Case2 (LiftE x `PairE` y) -> UpdateLoadedModules x y + Case3 (x `PairE` LiftE y) -> UpdateLoadedObjects x y + _ -> error "impossible" + Case1 e -> case e of + Case0 (x `PairE` ListE y) -> FinishDictSpecialization x y + Case1 (x `PairE` ListE y) -> LowerDictSpecialization x y + Case2 (x `PairE` ComposeE y) -> UpdateTopFunEvalStatus x y + Case3 (x `PairE` y) -> UpdateInstanceDef x y + Case4 (x `PairE` y) -> UpdateTyConDef x y + Case5 (x `PairE` LiftE y `PairE` z) -> UpdateFieldDef x y z + _ -> error "impossible" + _ -> error "impossible" + +instance SinkableE TopEnvUpdate +instance HoistableE TopEnvUpdate +instance RenameE TopEnvUpdate + +instance GenericB TopEnvFrag where + type RepB TopEnvFrag = PairB EnvFrag (LiftB (ModuleEnv `PairE` ListE TopEnvUpdate)) + fromB (TopEnvFrag x y (ReversedList z)) = PairB x (LiftB (y `PairE` ListE z)) + toB (PairB x (LiftB (y `PairE` ListE z))) = TopEnvFrag x y (ReversedList z) + +instance RenameB TopEnvFrag +instance HoistableB TopEnvFrag +instance SinkableB TopEnvFrag +instance ProvesExt TopEnvFrag +instance BindsNames TopEnvFrag + +instance OutFrag TopEnvFrag where + emptyOutFrag = TopEnvFrag emptyOutFrag mempty mempty + {-# INLINE emptyOutFrag #-} + catOutFrags (TopEnvFrag frag1 env1 partial1) + (TopEnvFrag frag2 env2 partial2) = + withExtEvidence frag2 $ + TopEnvFrag + (catOutFrags frag1 frag2) + (sink env1 <> env2) + (sinkSnocList partial1 <> partial2) + {-# INLINE catOutFrags #-} + +-- XXX: unlike `ExtOutMap Env EnvFrag` instance, this once doesn't +-- extend the synthesis candidates based on the annotated let-bound names. It +-- only extends synth candidates when they're supplied explicitly. +instance ExtOutMap Env TopEnvFrag where + extendOutMap env (TopEnvFrag (EnvFrag frag) mEnv' otherUpdates) = do + let newerTopEnv = foldl applyUpdate newTopEnv otherUpdates + Env newerTopEnv newModuleEnv + where + Env (TopEnv defs rules cache loadedM loadedO) mEnv = env + + newTopEnv = withExtEvidence frag $ TopEnv + (defs `extendRecSubst` frag) + (sink rules) (sink cache) (sink loadedM) (sink loadedO) + + newModuleEnv = + ModuleEnv + (imports <> imports') + (sm <> sm' <> newImportedSM) + (scs <> scs' <> newImportedSC) + where + ModuleEnv imports sm scs = withExtEvidence frag $ sink mEnv + ModuleEnv imports' sm' scs' = mEnv' + newDirectImports = S.difference (directImports imports') (directImports imports) + newTransImports = S.difference (transImports imports') (transImports imports) + newImportedSM = flip foldMap newDirectImports $ moduleExports . lookupModulePure + newImportedSC = flip foldMap newTransImports $ moduleSynthCandidates . lookupModulePure + + lookupModulePure v = case lookupEnvPure newTopEnv v of ModuleBinding m -> m + +applyUpdate :: TopEnv n -> TopEnvUpdate n -> TopEnv n +applyUpdate e = \case + ExtendCache cache -> e { envCache = envCache e <> cache} + AddCustomRule x y -> e { envCustomRules = envCustomRules e <> CustomRules (M.singleton x y)} + UpdateLoadedModules x y -> e { envLoadedModules = envLoadedModules e <> LoadedModules (M.singleton x y)} + UpdateLoadedObjects x y -> e { envLoadedObjects = envLoadedObjects e <> LoadedObjects (M.singleton x y)} + FinishDictSpecialization dName methods -> do + let SpecializedDictBinding (SpecializedDict dAbs oldMethods) = lookupEnvPure e dName + case oldMethods of + Nothing -> do + let newBinding = SpecializedDictBinding $ SpecializedDict dAbs (Just methods) + updateEnv dName newBinding e + Just _ -> error "shouldn't be adding methods if we already have them" + LowerDictSpecialization dName methods -> do + let SpecializedDictBinding (SpecializedDict dAbs _) = lookupEnvPure e dName + let newBinding = SpecializedDictBinding $ SpecializedDict dAbs (Just methods) + updateEnv dName newBinding e + UpdateTopFunEvalStatus f s -> do + case lookupEnvPure e f of + TopFunBinding (DexTopFun def lam _) -> + updateEnv f (TopFunBinding $ DexTopFun def lam s) e + _ -> error "can't update ffi function impl" + UpdateInstanceDef name def -> do + case lookupEnvPure e name of + InstanceBinding _ ty -> updateEnv name (InstanceBinding def ty) e + UpdateTyConDef name def -> do + let TyConBinding _ methods = lookupEnvPure e name + updateEnv name (TyConBinding (Just def) methods) e + UpdateFieldDef name sn x -> do + let TyConBinding def methods = lookupEnvPure e name + updateEnv name (TyConBinding def (methods <> DotMethods (M.singleton sn x))) e + +updateEnv :: Color c => Name c n -> Binding c n -> TopEnv n -> TopEnv n +updateEnv v rhs env = + env { envDefs = RecSubst $ updateSubstFrag v rhs bs } + where (RecSubst bs) = envDefs env + +lookupEnvPure :: Color c => TopEnv n -> Name c n -> Binding c n +lookupEnvPure env v = lookupTerminalSubstFrag (fromRecSubst $ envDefs $ env) v + +instance GenericE Module where + type RepE Module = LiftE ModuleSourceName + `PairE` ListE ModuleName + `PairE` ListE ModuleName + `PairE` SourceMap + `PairE` SynthCandidates + + fromE (Module name deps transDeps sm sc) = + LiftE name `PairE` ListE (S.toList deps) `PairE` ListE (S.toList transDeps) + `PairE` sm `PairE` sc + {-# INLINE fromE #-} + + toE (LiftE name `PairE` ListE deps `PairE` ListE transDeps + `PairE` sm `PairE` sc) = + Module name (S.fromList deps) (S.fromList transDeps) sm sc + {-# INLINE toE #-} + +instance SinkableE Module +instance HoistableE Module +instance AlphaEqE Module +instance AlphaHashableE Module +instance RenameE Module + +instance GenericE ImportStatus where + type RepE ImportStatus = ListE ModuleName `PairE` ListE ModuleName + fromE (ImportStatus direct trans) = ListE (S.toList direct) + `PairE` ListE (S.toList trans) + {-# INLINE fromE #-} + toE (ListE direct `PairE` ListE trans) = + ImportStatus (S.fromList direct) (S.fromList trans) + {-# INLINE toE #-} + +instance SinkableE ImportStatus +instance HoistableE ImportStatus +instance AlphaEqE ImportStatus +instance AlphaHashableE ImportStatus +instance RenameE ImportStatus + +instance Semigroup (ImportStatus n) where + ImportStatus direct trans <> ImportStatus direct' trans' = + ImportStatus (direct <> direct') (trans <> trans') + +instance Monoid (ImportStatus n) where + mappend = (<>) + mempty = ImportStatus mempty mempty + +instance GenericE LoadedModules where + type RepE LoadedModules = ListE (PairE (LiftE ModuleSourceName) ModuleName) + fromE (LoadedModules m) = + ListE $ M.toList m <&> \(v,md) -> PairE (LiftE v) md + {-# INLINE fromE #-} + toE (ListE pairs) = + LoadedModules $ M.fromList $ pairs <&> \(PairE (LiftE v) md) -> (v, md) + {-# INLINE toE #-} + +instance SinkableE LoadedModules +instance HoistableE LoadedModules +instance AlphaEqE LoadedModules +instance AlphaHashableE LoadedModules +instance RenameE LoadedModules + +instance GenericE LoadedObjects where + type RepE LoadedObjects = ListE (PairE FunObjCodeName (LiftE NativeFunction)) + fromE (LoadedObjects m) = + ListE $ M.toList m <&> \(v,p) -> PairE v (LiftE p) + {-# INLINE fromE #-} + toE (ListE pairs) = + LoadedObjects $ M.fromList $ pairs <&> \(PairE v (LiftE p)) -> (v, p) + {-# INLINE toE #-} + +instance SinkableE LoadedObjects +instance HoistableE LoadedObjects +instance RenameE LoadedObjects + +instance GenericE ModuleEnv where + type RepE ModuleEnv = ImportStatus + `PairE` SourceMap + `PairE` SynthCandidates + fromE (ModuleEnv imports sm sc) = imports `PairE` sm `PairE` sc + {-# INLINE fromE #-} + toE (imports `PairE` sm `PairE` sc) = ModuleEnv imports sm sc + {-# INLINE toE #-} + +instance SinkableE ModuleEnv +instance HoistableE ModuleEnv +instance AlphaEqE ModuleEnv +instance AlphaHashableE ModuleEnv +instance RenameE ModuleEnv + +instance Semigroup (ModuleEnv n) where + ModuleEnv x1 x2 x3 <> ModuleEnv y1 y2 y3 = + ModuleEnv (x1<>y1) (x2<>y2) (x3<>y3) + +instance Monoid (ModuleEnv n) where + mempty = ModuleEnv mempty mempty mempty + +instance Semigroup (LoadedModules n) where + LoadedModules m1 <> LoadedModules m2 = LoadedModules (m2 <> m1) + +instance Monoid (LoadedModules n) where + mempty = LoadedModules mempty + +instance Semigroup (LoadedObjects n) where + LoadedObjects m1 <> LoadedObjects m2 = LoadedObjects (m2 <> m1) + +instance Monoid (LoadedObjects n) where + mempty = LoadedObjects mempty + + +-- === instance === + +prettyRecord :: [(String, Doc ann)] -> Doc ann +prettyRecord xs = foldMap (\(name, val) -> pretty name <> indented val) xs + +instance Pretty (TopEnv n) where + pretty (TopEnv defs rules cache _ _) = + prettyRecord [ ("Defs" , pretty defs) + , ("Rules" , pretty rules) + , ("Cache" , pretty cache) ] + +instance Pretty (CustomRules n) where + pretty _ = "TODO: Rule printing" + +instance Pretty (ImportStatus n) where + pretty imports = pretty $ S.toList $ directImports imports + +instance Pretty (ModuleEnv n) where + pretty (ModuleEnv imports sm sc) = + prettyRecord [ ("Imports" , pretty imports) + , ("Source map" , pretty sm) + , ("Synth candidates", pretty sc) ] + +instance Pretty (Env n) where + pretty (Env env1 env2) = + prettyRecord [ ("Top env" , pretty env1) + , ("Module env", pretty env2)] + +instance Pretty (SolverBinding n) where + pretty (InfVarBound ty) = "Inference variable of type:" <+> pretty ty + pretty (SkolemBound ty) = "Skolem variable of type:" <+> pretty ty + pretty (DictBound ty) = "Dictionary variable of type:" <+> pretty ty + +instance Pretty (Binding c n) where + pretty b = case b of + -- using `unsafeCoerceIRE` here because otherwise we don't have `IRRep` + -- TODO: can we avoid printing needing IRRep? Presumably it's related to + -- manipulating sets or something, which relies on Eq/Ord, which relies on renaming. + AtomNameBinding info -> "Atom name:" <+> pretty (unsafeCoerceIRE @CoreIR info) + TyConBinding dataDef _ -> "Type constructor: " <+> pretty dataDef + DataConBinding tyConName idx -> "Data constructor:" <+> + pretty tyConName <+> "Constructor index:" <+> pretty idx + ClassBinding classDef -> pretty classDef + InstanceBinding instanceDef _ -> pretty instanceDef + MethodBinding className idx -> "Method" <+> pretty idx <+> "of" <+> pretty className + TopFunBinding f -> pretty f + FunObjCodeBinding _ -> "" + ModuleBinding _ -> "" + PtrBinding _ _ -> "" + SpecializedDictBinding _ -> "" + ImpNameBinding ty -> "Imp name of type: " <+> pretty ty + +instance Pretty (Module n) where + pretty m = prettyRecord + [ ("moduleSourceName" , pretty $ moduleSourceName m) + , ("moduleDirectDeps" , pretty $ S.toList $ moduleDirectDeps m) + , ("moduleTransDeps" , pretty $ S.toList $ moduleTransDeps m) + , ("moduleExports" , pretty $ moduleExports m) + , ("moduleSynthCandidates", pretty $ moduleSynthCandidates m) ] + +instance Pretty a => Pretty (EvalStatus a) where + pretty = \case + Waiting -> "" + Running -> "" + Finished a -> pretty a + +instance Pretty (EnvFrag n l) where + pretty (EnvFrag bindings) = pretty bindings + +instance Pretty (Cache n) where + pretty (Cache _ _ _ _ _ _) = "" -- TODO + +instance Pretty (SynthCandidates n) where + pretty scs = "instance dicts:" <+> pretty (M.toList $ instanceDicts scs) + +instance Pretty (LoadedModules n) where + pretty _ = "" + +instance Pretty (TopFunDef n) where + pretty = \case + Specialization s -> pretty s + LinearizationPrimal _ -> "" + LinearizationTangent _ -> "" + +instance Pretty (TopFun n) where + pretty = \case + DexTopFun def lam lowering -> + "Top-level Function" + <> hardline <+> "definition:" <+> pretty def + <> hardline <+> "lambda:" <+> pretty lam + <> hardline <+> "lowering:" <+> pretty lowering + FFITopFun f _ -> pretty f + +instance IRRep r => Pretty (TopLam r n) where + pretty (TopLam _ _ lam) = pretty lam + +instance IRRep r => Pretty (AtomBinding r n) where + pretty binding = case binding of + LetBound b -> pretty b + MiscBound t -> pretty t + SolverBound b -> pretty b + FFIFunBound s _ -> pretty s + NoinlineFun ty _ -> "Top function with type: " <+> pretty ty + TopDataBound (RepVal ty _) -> "Top data with type: " <+> pretty ty + +instance Pretty (SpecializationSpec n) where + pretty (AppSpecialization f (Abs bs (ListE args))) = + "Specialization" <+> pretty f <+> pretty bs <+> pretty args + +instance Hashable InfVarDesc +instance Hashable a => Hashable (EvalStatus a) + +instance Store (SolverBinding n) +instance IRRep r => Store (AtomBinding r n) +instance IRRep r => Store (TopLam r n) +instance Store (SynthCandidates n) +instance Store (Module n) +instance Store (ImportStatus n) +instance Store (TopFunLowerings n) +instance Store a => Store (EvalStatus a) +instance Store (TopFun n) +instance Store (TopFunDef n) +instance Color c => Store (Binding c n) +instance Store (ModuleEnv n) +instance Store (SerializedEnv n) +instance Store InfVarDesc +instance Store (AtomRules n) +instance Store (LinearizationSpec n) +instance Store (SpecializedDictDef n) +instance Store (SpecializationSpec n) diff --git a/src/lib/Util.hs b/src/lib/Util.hs index 8a44e7234..4dbc43edc 100644 --- a/src/lib/Util.hs +++ b/src/lib/Util.hs @@ -26,6 +26,7 @@ import Data.Store (Store) import qualified Data.List.NonEmpty as NE import qualified Data.ByteString as BS import Data.Foldable +import Data.Text.Prettyprint.Doc (Pretty (..), pretty) import Data.List.NonEmpty (NonEmpty (..)) import GHC.Generics (Generic) @@ -354,6 +355,11 @@ zipTrees (Leaf x) (Leaf y) = Leaf (x, y) zipTrees (Branch xs) (Branch ys) | length xs == length ys = Branch $ zipWith zipTrees xs ys zipTrees _ _ = error "zip error" +instance Pretty a => Pretty (Tree a) where + pretty = \case + Leaf x -> pretty x + Branch xs -> pretty xs + -- === bytestrings paired with their hash digest === -- TODO: use something other than a string to store the digest diff --git a/src/lib/Vectorize.hs b/src/lib/Vectorize.hs index d6fec397e..90e289df9 100644 --- a/src/lib/Vectorize.hs +++ b/src/lib/Vectorize.hs @@ -9,7 +9,7 @@ module Vectorize (vectorizeLoops) where import Prelude hiding ((.)) import Data.Word import Data.Functor -import Data.Text.Prettyprint.Doc (Pretty, pretty, viaShow, (<+>)) +import Data.Text.Prettyprint.Doc (viaShow) import Control.Category import Control.Monad.Reader import Control.Monad.State.Strict @@ -26,6 +26,7 @@ import Subst import PPrint import QueryType import Types.Core +import Types.Top import Types.OpNames qualified as P import Types.Primitives import Util (allM, zipWithZ) From 6595db0dbd23bb0409229f37e99d7f1709f56c62 Mon Sep 17 00:00:00 2001 From: Dougal Date: Sun, 3 Dec 2023 13:59:45 -0500 Subject: [PATCH 36/41] Make a separate ADT case for each user-facing error message. This is preparation for giving better source information in error messages. --- src/dex.hs | 1 + src/lib/AbstractSyntax.hs | 39 ++-- src/lib/Algebra.hs | 3 +- src/lib/Builder.hs | 18 +- src/lib/CheapReduction.hs | 2 +- src/lib/CheckType.hs | 30 +-- src/lib/ConcreteSyntax.hs | 1 + src/lib/Err.hs | 391 +++++++++++++++++++++++++++-------- src/lib/Export.hs | 13 +- src/lib/Generalize.hs | 1 + src/lib/Imp.hs | 1 + src/lib/ImpToLLVM.hs | 1 - src/lib/Inference.hs | 175 +++++++--------- src/lib/Lexing.hs | 3 +- src/lib/Name.hs | 9 +- src/lib/QueryType.hs | 57 +++-- src/lib/RenderHtml.hs | 2 +- src/lib/Runtime.hs | 2 +- src/lib/RuntimePrint.hs | 1 + src/lib/Simplify.hs | 2 +- src/lib/SourceIdTraversal.hs | 1 + src/lib/SourceRename.hs | 43 ++-- src/lib/TopLevel.hs | 40 ++-- src/lib/Transpose.hs | 2 +- src/lib/Types/Source.hs | 13 +- src/lib/Types/Top.hs | 11 - src/lib/Vectorize.hs | 4 +- 27 files changed, 506 insertions(+), 360 deletions(-) diff --git a/src/dex.hs b/src/dex.hs index 623c2bdec..f8bce6475 100644 --- a/src/dex.hs +++ b/src/dex.hs @@ -32,6 +32,7 @@ import ConcreteSyntax (keyWordStrs, preludeImportBlock) import RenderHtml -- import Live.Terminal (runTerminal) import Live.Web (runWeb) +import PPrint hiding (hardline) import Core import Types.Core import Types.Imp diff --git a/src/lib/AbstractSyntax.hs b/src/lib/AbstractSyntax.hs index a21062ff1..6a4b25840 100644 --- a/src/lib/AbstractSyntax.hs +++ b/src/lib/AbstractSyntax.hs @@ -58,7 +58,7 @@ import Data.Text (Text) import ConcreteSyntax import Err import Name -import PPrint () +import PPrint import Types.Primitives import Types.Source import qualified Types.OpNames as P @@ -139,7 +139,7 @@ decl ann (WithSrcs sid _ d) = WithSrcB sid <$> case d of CLet binder rhs -> do (p, ty) <- patOptAnn binder ULet ann p ty <$> asExpr <$> block rhs - CBind _ _ -> throw SyntaxErr "Arrow binder syntax <- not permitted at the top level, because the binding would have unbounded scope." + CBind _ _ -> throw TopLevelArrowBinder CDefDecl def -> do (name, lam) <- aDef def return $ ULet ann (fromSourceNameW name) Nothing (WithSrcE sid (ULam lam)) @@ -199,7 +199,7 @@ withTrailingConstraints g cont = case g of Nest (UAnnBinder expl (WithSrcB sid b) ann cs) bs <- withTrailingConstraints lhs cont s <- case b of UBindSource s -> return s - UIgnore -> throw SyntaxErr "Can't constrain anonymous binders" + UIgnore -> throw CantConstrainAnonBinders UBind _ _ -> error "Shouldn't have internal names until renaming pass" c' <- expr c return $ UnaryNest (UAnnBinder expl (WithSrcB sid b) ann (cs ++ [c'])) @@ -261,7 +261,7 @@ uBinder :: GroupW -> SyntaxM (UBinder c VoidS VoidS) uBinder (WithSrcs sid _ b) = case b of CLeaf (CIdentifier name) -> return $ fromSourceNameW $ WithSrc sid name CLeaf CHole -> return $ WithSrcB sid UIgnore - _ -> throw SyntaxErr "Binder must be an identifier or `_`" + _ -> throw UnexpectedBinder -- Type annotation with an optional binder pattern tyOptPat :: GroupW -> SyntaxM (UAnnBinder VoidS VoidS) @@ -300,8 +300,7 @@ pat (WithSrcs sid _ grp) = WithSrcB sid <$> case grp of CLeaf (CIdentifier name) -> return $ UPatBinder $ fromSourceNameW $ WithSrc sid name CJuxtapose True lhs rhs -> do case lhs of - WithSrcs _ _ (CJuxtapose True _ _) -> - throw SyntaxErr "Only unary constructors can form patterns without parens" + WithSrcs _ _ (CJuxtapose True _ _) -> throw OnlyUnaryWithoutParens _ -> return () name <- identifier "pattern constructor name" lhs arg <- pat rhs @@ -313,11 +312,11 @@ pat (WithSrcs sid _ grp) = WithSrcB sid <$> case grp of gs' <- mapM pat gs return $ UPatCon (fromSourceNameW name) (toNest gs') _ -> error "unexpected postfix group (should be ruled out at grouping stage)" - _ -> throw SyntaxErr "Illegal pattern" + _ -> throw IllegalPattern tyOptBinder :: Explicitness -> GroupW -> SyntaxM (UAnnBinder VoidS VoidS) tyOptBinder expl (WithSrcs sid sids grp) = case grp of - CBin (WithSrc _ Pipe) _ _ -> throw SyntaxErr "Unexpected constraint" + CBin (WithSrc _ Pipe) _ _ -> throw UnexpectedConstraint CBin (WithSrc _ Colon) name ty -> do b <- uBinder name ann <- UAnn <$> expr ty @@ -341,7 +340,7 @@ binderReqTy expl (WithSrcs _ _ (CBin (WithSrc _ Colon) name ty)) = do b <- uBinder name ann <- UAnn <$> expr ty return $ UAnnBinder expl b ann [] -binderReqTy _ _ = throw SyntaxErr $ "Expected an annotated binder" +binderReqTy _ _ = throw ExpectedAnnBinder argList :: [GroupW] -> SyntaxM ([UExpr VoidS], [UNamedArg VoidS]) argList gs = partitionEithers <$> mapM singleArg gs @@ -355,7 +354,7 @@ singleArg = \case identifier :: String -> GroupW -> SyntaxM SourceNameW identifier ctx (WithSrcs sid _ g) = case g of CLeaf (CIdentifier name) -> return $ WithSrc sid name - _ -> throw SyntaxErr $ "Expected " ++ ctx ++ " to be an identifier" + _ -> throw $ ExpectedIdentifier ctx aEffects :: WithSrcs ([GroupW], Maybe GroupW) -> SyntaxM (UEffectRow VoidS) aEffects (WithSrcs _ _ (effs, optEffTail)) = do @@ -375,7 +374,7 @@ effect (WithSrcs _ _ grp) = case grp of return $ URWSEffect State $ fromSourceNameW (WithSrc sid h) CLeaf (CIdentifier "Except") -> return UExceptionEffect CLeaf (CIdentifier "IO" ) -> return UIOEffect - _ -> throw SyntaxErr "Unexpected effect form; expected one of `Read h`, `Accum h`, `State h`, `Except`, `IO`, or the name of a user-defined effect." + _ -> throw UnexpectedEffectForm aMethod :: CSDeclW -> SyntaxM (Maybe (UMethodDef VoidS)) aMethod (WithSrcs _ _ CPass) = return Nothing @@ -386,7 +385,7 @@ aMethod (WithSrcs src _ d) = Just . WithSrcE src <$> case d of CLet (WithSrcs sid _ (CLeaf (CIdentifier name))) rhs -> do rhs' <- ULamExpr Empty ImplicitApp Nothing Nothing <$> block rhs return $ UMethodDef (fromSourceNameW (WithSrc sid name)) rhs' - _ -> throw SyntaxErr "Unexpected method definition. Expected `def` or `x = ...`." + _ -> throw UnexpectedMethodDef asExpr :: UBlock VoidS -> UExpr VoidS asExpr (WithSrcE src b) = case b of @@ -403,7 +402,7 @@ blockDecls :: [CSDeclW] -> SyntaxM (Nest UDecl VoidS VoidS, UExpr VoidS) blockDecls [] = error "shouldn't have empty list of decls" blockDecls [WithSrcs _ _ d] = case d of CExpr g -> (Empty,) <$> expr g - _ -> throw SyntaxErr "Block must end in expression" + _ -> throw BlockWithoutFinalExpr blockDecls (WithSrcs sid _ (CBind b rhs):ds) = do b' <- binderOptTy Explicit b rhs' <- asExpr <$> block rhs @@ -428,7 +427,7 @@ expr (WithSrcs sid _ grp) = WithSrcE sid <$> case grp of -- Table constructors here. Other uses of square brackets -- should be detected upstream, before calling expr. CBrackets gs -> UTabCon <$> mapM expr gs - CGivens _ -> throw SyntaxErr $ "Unexpected `given` clause" + CGivens _ -> throw UnexpectedGivenClause CArrow lhs effs rhs -> do case lhs of WithSrcs _ _ (CParens gs) -> do @@ -436,7 +435,7 @@ expr (WithSrcs sid _ grp) = WithSrcE sid <$> case grp of effs' <- fromMaybeM effs UPure aEffects resultTy <- expr rhs return $ UPi $ UPiExpr bs ExplicitApp effs' resultTy - _ -> throw SyntaxErr "Argument types should be in parentheses" + _ -> throw ArgsShouldHaveParens CDo b -> UDo <$> block b CJuxtapose hasSpace lhs rhs -> case hasSpace of True -> extendAppRight <$> expr lhs <*> expr rhs @@ -459,7 +458,7 @@ expr (WithSrcs sid _ grp) = WithSrcE sid <$> case grp of name <- case rhs' of CLeaf (CIdentifier name) -> return $ FieldName name CLeaf (CNat i ) -> return $ FieldNum $ fromIntegral i - _ -> throw SyntaxErr "Field must be a name or an integer" + _ -> throw BadField return $ UFieldAccess lhs' (WithSrc src name) DoubleColon -> UTypeAnn <$> (expr lhs) <*> expr rhs EvalBinOp s -> evalOp s @@ -467,14 +466,14 @@ expr (WithSrcs sid _ grp) = WithSrcE sid <$> case grp of lhs' <- tyOptPat lhs UDepPairTy . (UDepPairType ExplicitDepPair lhs') <$> expr rhs DepComma -> UDepPair <$> (expr lhs) <*> expr rhs - CSEqual -> throw SyntaxErr "Equal sign must be used as a separator for labels or binders, not a standalone operator" - Colon -> throw SyntaxErr "Colon separates binders from their type annotations, is not a standalone operator.\nIf you are trying to write a dependent type, use parens: (i:Fin 4) => (..i)" + CSEqual -> throw BadEqualSign + Colon -> throw BadColon ImplicitArrow -> case lhs of WithSrcs _ _ (CParens gs) -> do bs <- aPiBinders gs resultTy <- expr rhs return $ UPi $ UPiExpr bs ImplicitApp UPure resultTy - _ -> throw SyntaxErr "Argument types should be in parentheses" + _ -> throw ArgsShouldHaveParens FatArrow -> do lhs' <- tyOptPat lhs UTabPi . (UTabPiExpr lhs') <$> expr rhs @@ -496,7 +495,7 @@ expr (WithSrcs sid _ grp) = WithSrcE sid <$> case grp of WithSrcE _ (UIntLit i) -> UIntLit (-i) WithSrcE _ (UFloatLit i) -> UFloatLit (-i) e -> unaryApp (mkUVar sid "neg") e - _ -> throw SyntaxErr $ "Prefix (" ++ pprint name ++ ") not legal as a bare expression" + _ -> throw $ BadPrefix $ pprint name CLambda params body -> do params' <- explicitBindersOptAnn $ WithSrcs sid [] $ map stripParens params body' <- block body diff --git a/src/lib/Algebra.hs b/src/lib/Algebra.hs index 5ecc05f76..1175d1523 100644 --- a/src/lib/Algebra.hs +++ b/src/lib/Algebra.hs @@ -27,6 +27,7 @@ import MTL1 import Name import Subst import QueryType +import PPrint import Types.Core import Types.Imp import Types.Primitives @@ -55,7 +56,7 @@ sumUsingPolys lim (Abs i body) = do sumAbs <- refreshAbs (Abs i body) \(i':>_) body' -> do exprAsPoly body' >>= \case Just poly' -> return $ Abs i' poly' - Nothing -> throw NotImplementedErr $ + Nothing -> throwInternal $ "Algebraic simplification failed to model index computations:\n" ++ "Trying to sum from 0 to " ++ pprint lim ++ " - 1, \\" ++ pprint i' ++ "." ++ pprint body' diff --git a/src/lib/Builder.hs b/src/lib/Builder.hs index 7415cec74..f3f790f00 100644 --- a/src/lib/Builder.hs +++ b/src/lib/Builder.hs @@ -17,7 +17,6 @@ import Control.Monad.State.Strict (MonadState (..), StateT (..), runStateT) import qualified Data.Map.Strict as M import Data.Foldable (fold) import Data.Graph (graphFromEdges, topSort) -import Data.Text.Prettyprint.Doc (Pretty (..)) import Foreign.Ptr import qualified Unsafe.Coerce as TrulyUnsafe @@ -30,6 +29,7 @@ import MTL1 import Subst import Name import PeepholeOptimize +import PPrint import QueryType import Types.Core import Types.Imp @@ -103,18 +103,6 @@ buildScopedAssumeNoDecls cont = do _ -> error "Expected no decl emissions" {-# INLINE buildScopedAssumeNoDecls #-} -withReducibleEmissions - :: (ScopableBuilder r m, Builder r m, HasNamesE e, SubstE AtomSubstVal e) - => String - -> (forall o' . (Emits o', DExt o o') => m o' (e o')) - -> m o (e o) -withReducibleEmissions msg cont = do - withDecls <- buildScoped cont - reduceWithDecls withDecls >>= \case - Just t -> return t - _ -> throw TypeErr msg -{-# INLINE withReducibleEmissions #-} - -- === "Hoisting" top-level builder class === -- `emitHoistedEnv` lets you emit top env fragments, like cache entries or @@ -926,10 +914,10 @@ symbolicTangentTy elTy = lookupSourceMap "SymbolicTangent" >>= \case Just (UTyConVar symTanName) -> do return $ toType $ UserADTType "SymbolicTangent" symTanName $ TyConParams [Explicit] [toAtom elTy] - Nothing -> throw UnboundVarErr $ + Nothing -> throwInternal $ "Can't define a custom linearization with symbolic zeros: " ++ "the SymbolicTangent type is not in scope." - Just _ -> throw TypeErr "SymbolicTangent should name a `data` type" + Just _ -> throwInternal $ "SymbolicTangent should name a `data` type" symbolicTangentZero :: EnvReader m => SType n -> m n (SAtom n) symbolicTangentZero argTy = return $ toAtom $ SumCon [UnitTy, argTy] 0 UnitVal diff --git a/src/lib/CheapReduction.hs b/src/lib/CheapReduction.hs index 8df743f24..41bd2f1d4 100644 --- a/src/lib/CheapReduction.hs +++ b/src/lib/CheapReduction.hs @@ -29,7 +29,7 @@ import Core import Err import IRVariants import Name -import PPrint () +import PPrint import QueryTypePure import Types.Core import Types.Top diff --git a/src/lib/CheckType.hs b/src/lib/CheckType.hs index 9bcbf029d..580039a84 100644 --- a/src/lib/CheckType.hs +++ b/src/lib/CheckType.hs @@ -22,7 +22,7 @@ import IRVariants import MTL1 import Name import Subst -import PPrint () +import PPrint import QueryType import Types.Core import Types.Primitives @@ -56,7 +56,7 @@ affineUsed name = TyperM $ do case lookupNameMapE name affines of Just (LiftE n) -> if n > 0 then - throw TypeErr $ "Affine name " ++ pprint name ++ " used " ++ show (n + 1) ++ " times." + throwInternal $ "Affine name " ++ pprint name ++ " used " ++ show (n + 1) ++ " times." else put $ insertNameMapE name (LiftE $ n + 1) affines Nothing -> put $ insertNameMapE name (LiftE 1) affines @@ -90,7 +90,7 @@ checkTypesEq reqTy ty = alphaEq reqTy ty >>= \case False -> {-# SCC typeNormalization #-} do alphaEq reqTy ty >>= \case True -> return () - False -> throw TypeErr $ pprint reqTy ++ " != " ++ pprint ty + False -> throwInternal $ pprint reqTy ++ " != " ++ pprint ty {-# INLINE checkTypesEq #-} class SinkableE e => CheckableE (r::IR) (e::E) | e -> r where @@ -407,7 +407,7 @@ instance IRRep r => CheckableE r (Con r) where ProdCon xs -> ProdCon <$> mapM checkE xs SumCon tys tag payload -> do tys' <- mapM checkE tys - unless (0 <= tag && tag < length tys') $ throw TypeErr "Invalid SumType tag" + unless (0 <= tag && tag < length tys') $ throwInternal "Invalid SumType tag" payload' <- payload |: (tys' !! tag) return $ SumCon tys' tag payload' HeapVal -> return HeapVal @@ -570,7 +570,7 @@ instance IRRep r => CheckableWithEffects r (MiscOp r) where case (destTy', sourceTy) of (BaseTy dbt@(Scalar _), BaseTy sbt@(Scalar _)) | sizeOf sbt == sizeOf dbt -> return $ BitcastOp destTy' e' - _ -> throw TypeErr $ "Invalid bitcast: " ++ pprint sourceTy ++ " -> " ++ pprint destTy + _ -> throwInternal $ "Invalid bitcast: " ++ pprint sourceTy ++ " -> " ++ pprint destTy UnsafeCoerce t e -> UnsafeCoerce <$> checkE t <*> renameM e GarbageVal t -> GarbageVal <$> checkE t SumTag x -> do @@ -616,14 +616,14 @@ instance IRRep r => CheckableE r (VectorOp r) where TabTy _ b (BaseTy (Scalar sbt)) <- return $ getType tbl' i' <- i |: binderType b ty'@(BaseTy (Vector _ sbt')) <- checkE ty - unless (sbt == sbt') $ throw TypeErr "Scalar type mismatch" + unless (sbt == sbt') $ throwInternal "Scalar type mismatch" return $ VectorIdx tbl' i' ty' VectorSubref ref i ty -> do ref' <- checkE ref RefTy _ (TabTy _ b (BaseTy (Scalar sbt))) <- return $ getType ref' i' <- i |: binderType b ty'@(BaseTy (Vector _ sbt')) <- checkE ty - unless (sbt == sbt') $ throw TypeErr "Scalar type mismatch" + unless (sbt == sbt') $ throwInternal "Scalar type mismatch" return $ VectorSubref ref' i' ty' checkHof :: IRRep r => EffTy r o -> Hof r i -> TyperM r i o (Hof r o) @@ -706,7 +706,7 @@ instance IRRep r => CheckableWithEffects r (DAMOp r) where checkExtends effs effAnn' ixTy' <- checkE ixTy (carry', carryTy') <- checkAndGetType carry - let badCarry = throw TypeErr $ "Seq carry should be a product of raw references, got: " ++ pprint carryTy' + let badCarry = throwInternal $ "Seq carry should be a product of raw references, got: " ++ pprint carryTy' case carryTy' of TyCon (ProdType refTys) -> forM_ refTys \case RawRefTy _ -> return (); _ -> badCarry _ -> badCarry @@ -773,7 +773,7 @@ checkProject i x = case getType x of TyCon (DepPairTy t) | i == 1 -> do xFst <- reduceProj 0 x checkInstantiation t [xFst] - xTy -> throw TypeErr $ "Not a product type:" ++ pprint xTy + xTy -> throwInternal $ "Not a product type:" ++ pprint xTy checkTabApp :: (IRRep r) => Type r o -> Atom r o -> TyperM r i o (Type r o) checkTabApp ty i = do @@ -794,7 +794,7 @@ checkInstantiation abTop xsTop = do checkTypesEq (getType x) (binderType b) rest <- applySubst (b@>SubstVal x) (Abs bs body) go rest xs - go _ _ = throw ZipErr "Wrong number of args" + go _ _ = throwInternal "Wrong number of args" checkIntBaseType :: Fallible m => BaseType -> m () checkIntBaseType t = case t of @@ -809,7 +809,7 @@ checkIntBaseType t = case t of Word32Type -> return () Word64Type -> return () _ -> notInt - notInt = throw TypeErr $ + notInt = throwInternal $ "Expected a fixed-width scalar integer type, but found: " ++ pprint t checkFloatBaseType :: Fallible m => BaseType -> m () @@ -822,13 +822,13 @@ checkFloatBaseType t = case t of Float64Type -> return () Float32Type -> return () _ -> notFloat - notFloat = throw TypeErr $ + notFloat = throwInternal $ "Expected a fixed-width scalar floating-point type, but found: " ++ pprint t checkValidCast :: (Fallible1 m, IRRep r) => Type r n -> Type r n -> m n () checkValidCast (TyCon (BaseType l)) (TyCon (BaseType r)) = checkValidBaseCast l r checkValidCast sourceTy destTy = - throw TypeErr $ "Can't cast " ++ pprint sourceTy ++ " to " ++ pprint destTy + throwInternal $ "Can't cast " ++ pprint sourceTy ++ " to " ++ pprint destTy checkValidBaseCast :: Fallible m => BaseType -> BaseType -> m () checkValidBaseCast (PtrType _) (PtrType _) = return () @@ -838,13 +838,13 @@ checkValidBaseCast (Scalar _) (Scalar _) = return () checkValidBaseCast sourceTy@(Vector sourceSizes _) destTy@(Vector destSizes _) = assertEq sourceSizes destSizes $ "Can't cast " ++ pprint sourceTy ++ " to " ++ pprint destTy checkValidBaseCast sourceTy destTy = - throw TypeErr $ "Can't cast " ++ pprint sourceTy ++ " to " ++ pprint destTy + throwInternal $ "Can't cast " ++ pprint sourceTy ++ " to " ++ pprint destTy scalarOrVectorLike :: Fallible m => BaseType -> ScalarBaseType -> m BaseType scalarOrVectorLike x sbt = case x of Scalar _ -> return $ Scalar sbt Vector sizes _ -> return $ Vector sizes sbt - _ -> throw CompilerErr "only scalar or vector base types should occur here" + _ -> throwInternal $ "only scalar or vector base types should occur here" data ArgumentType = SomeFloatArg | SomeIntArg | SomeUIntArg diff --git a/src/lib/ConcreteSyntax.hs b/src/lib/ConcreteSyntax.hs index 70bc67b9e..6104c4df1 100644 --- a/src/lib/ConcreteSyntax.hs +++ b/src/lib/ConcreteSyntax.hs @@ -25,6 +25,7 @@ import Data.Void import Text.Megaparsec hiding (Label, State) import Text.Megaparsec.Char hiding (space, eol) +import Err import Lexing import Types.Core import Types.Source diff --git a/src/lib/Err.hs b/src/lib/Err.hs index 51b34eb1f..649525f39 100644 --- a/src/lib/Err.hs +++ b/src/lib/Err.hs @@ -6,14 +6,15 @@ {-# LANGUAGE UndecidableInstances #-} -module Err (Err (..), ErrType (..), Except (..), - Fallible (..), Catchable (..), catchErrExcept, - HardFailM (..), runHardFail, throw, - catchIOExcept, liftExcept, liftExceptAlt, - assertEq, ignoreExcept, - pprint, docAsStr, getCurrentCallStack, printCurrentCallStack, - ExceptT (..) - ) where +module Err ( + Err (..), Except (..), ToErr (..), PrintableErr (..), + ParseErr (..), SyntaxErr (..), NameErr (..), TypeErr (..), MiscErr (..), + Fallible (..), Catchable (..), catchErrExcept, + HardFailM (..), runHardFail, throw, + catchIOExcept, liftExcept, liftExceptAlt, + ignoreExcept, getCurrentCallStack, printCurrentCallStack, + ExceptT (..), rootSrcId, SrcId (..), assertEq, throwInternal, + InferenceArgDesc, InfVarDesc (..)) where import Control.Exception hiding (throw) import Control.Applicative @@ -22,40 +23,285 @@ import Control.Monad.Identity import Control.Monad.Writer.Strict import Control.Monad.State.Strict import Control.Monad.Reader +import Data.Aeson (ToJSON, ToJSONKey) import Data.Coerce +import Data.Hashable +import Data.List (sort) import Data.Foldable (fold) -import Data.Text.Prettyprint.Doc +import Data.Store (Store (..)) import GHC.Stack +import GHC.Generics import PPrint --- === core API === - -data Err = Err ErrType String deriving (Show, Eq) - -data ErrType = NoErr - | ParseErr - | SyntaxErr - | TypeErr - | KindErr - | LinErr - | VarDefErr - | UnboundVarErr - | AmbiguousVarErr - | RepeatedVarErr - | RepeatedPatVarErr - | InvalidPatternErr - | CompilerErr - | IRVariantErr - | NotImplementedErr - | DataIOErr - | MiscErr - | RuntimeErr - | ZipErr - | EscapedNameErr - | ModuleImportErr - | SearchFailure -- used as the identity for `Alternative` instances and for MonadFail - deriving (Show, Eq) +-- === source info === + +-- XXX: 0 is reserved for the root The IDs are generated from left to right in +-- parsing order, so IDs for lexemes are guaranteed to be sorted correctly. +newtype SrcId = SrcId Int deriving (Show, Eq, Ord, Generic) + +rootSrcId :: SrcId +rootSrcId = SrcId 0 + +-- === core errro type === + +data Err = + SearchFailure String -- used as the identity for `Alternative` instances and for MonadFail. + | InternalErr String + | ParseErr ParseErr + | SyntaxErr SyntaxErr + | NameErr NameErr + | TypeErr TypeErr + | RuntimeErr + | MiscErr MiscErr + deriving (Show, Eq) + +type MsgStr = String +type VarStr = String +type TypeStr = String + +data ParseErr = + MiscParseErr MsgStr + deriving (Show, Eq) + +data SyntaxErr = + MiscSyntaxErr MsgStr + | TopLevelArrowBinder + | CantConstrainAnonBinders + | UnexpectedBinder + | OnlyUnaryWithoutParens + | IllegalPattern + | UnexpectedConstraint + | ExpectedIdentifier String + | UnexpectedEffectForm + | UnexpectedMethodDef + | BlockWithoutFinalExpr + | UnexpectedGivenClause + | ArgsShouldHaveParens + | BadEqualSign + | BadColon + | ExpectedAnnBinder + | BadField + | BadPrefix VarStr + deriving (Show, Eq) + +data NameErr = + MiscNameErr MsgStr + | UnboundVarErr VarStr -- name of var + | EscapedNameErr [VarStr] -- names + | RepeatedPatVarErr VarStr + | RepeatedVarErr VarStr + | NotAnOrdinaryVar VarStr + | NotADataCon VarStr + | NotAClassName VarStr + | NotAMethodName VarStr + | AmbiguousVarErr VarStr [String] + | VarDefErr VarStr + deriving (Show, Eq) + +data TypeErr = + MiscTypeErr MsgStr + | CantSynthDict TypeStr + | CantSynthInfVars TypeStr + | NotASynthType TypeStr + | CantUnifySkolem + | OccursCheckFailure VarStr TypeStr + | UnificationFailure TypeStr TypeStr [VarStr] -- expected, actual, inference vars + | DisallowedEffects String String -- allowed, actual + | InferEmptyTable + | ArityErr Int Int -- expected, actual + | PatternArityErr Int Int -- expected, actual + | SumTypeCantFail + | PatTypeErr String String -- expected type constructor (from pattern), actual type (from rhs) + | EliminationErr String String -- expected type constructor, actual type + | IllFormedCasePattern + | NotAMethod VarStr VarStr + | DuplicateMethod VarStr + | MissingMethod VarStr + | WrongArrowErr String String + | AnnotationRequired + | NotAUnaryConstraint TypeStr + | InterfacesNoImplicitParams + | RepeatedOptionalArgs [VarStr] + | UnrecognizedOptionalArgs [VarStr] [VarStr] + | NoFields TypeStr + | TypeMismatch TypeStr TypeStr -- TODO: should we merege this with UnificationFailure + | InferHoleErr + | InferDepPairErr + | InferEmptyCaseEff + | UnexpectedTerm String TypeStr + | CantFindField VarStr TypeStr [VarStr] -- field name, field type, known fields + | TupleLengthMismatch Int Int + | CantReduceType TypeStr + | CantReduceDict + | CantReduceDependentArg + | AmbiguousInferenceVar VarStr TypeStr InfVarDesc + | FFIResultTyErr TypeStr + | FFIArgTyNotScalar TypeStr + deriving (Show, Eq) + +data MiscErr = + MiscMiscErr MsgStr + | ModuleImportErr VarStr + | CantFindModuleSource VarStr + deriving (Show, Eq) + +-- name of function, name of arg +type InferenceArgDesc = (String, String) +data InfVarDesc = + ImplicitArgInfVar InferenceArgDesc + | AnnotationInfVar String -- name of binder + | TypeInstantiationInfVar String -- name of type + | MiscInfVar + deriving (Show, Generic, Eq, Ord) + +-- === ToErr class === + +class ToErr a where + toErr :: a -> Err + +instance ToErr Err where toErr = id +instance ToErr ParseErr where toErr = ParseErr +instance ToErr SyntaxErr where toErr = SyntaxErr +instance ToErr NameErr where toErr = NameErr +instance ToErr TypeErr where toErr = TypeErr +instance ToErr MiscErr where toErr = MiscErr + +-- === Error messages === + +class PrintableErr a where + printErr :: a -> String + +instance PrintableErr Err where + printErr = \case + SearchFailure s -> "Internal search failure: " ++ s + InternalErr s -> "Internal compiler error: " ++ s ++ "\n" ++ + "Please report this at github.com/google-research/dex-lang/issues\n" + ParseErr e -> "Parse error: " ++ printErr e + SyntaxErr e -> "Syntax error: " ++ printErr e + NameErr e -> "Name error: " ++ printErr e + TypeErr e -> "Type error: " ++ printErr e + MiscErr e -> "Error: " ++ printErr e + RuntimeErr -> "Runtime error" + +instance PrintableErr ParseErr where + printErr = \case + MiscParseErr s -> s + +instance PrintableErr SyntaxErr where + printErr = \case + MiscSyntaxErr s -> s + TopLevelArrowBinder -> + "Arrow binder syntax <- not permitted at the top level, because the binding would have unbounded scope." + CantConstrainAnonBinders -> "can't constrain anonymous binders" + UnexpectedBinder -> "binder must be an identifier or `_`" + OnlyUnaryWithoutParens ->"only unary constructors can form patterns without parens" + IllegalPattern -> "illegal pattern" + UnexpectedConstraint -> "unexpected constraint" + ExpectedIdentifier ctx -> "expected " ++ ctx ++ " to be an identifier" + UnexpectedEffectForm -> + "unexpected effect form; expected one of `Read h`, `Accum h`, `State h`, `Except`, `IO`, " + ++ "or the name of a user-defined effect." + UnexpectedMethodDef -> "unexpected method definition. Expected `def` or `x = ...`." + BlockWithoutFinalExpr -> "block must end in expression" + UnexpectedGivenClause -> "unexpected `given` clause" + ArgsShouldHaveParens -> "argument types should be in parentheses" + BadEqualSign -> "equal sign must be used as a separator for labels or binders, not a standalone operator" + BadColon -> + "colon separates binders from their type annotations, is not a standalone operator.\n" + ++ " If you are trying to write a dependent type, use parens: (i:Fin 4) => (..i)" + ExpectedAnnBinder -> "expected an annotated binder" + BadField -> "field must be a name or an integer" + BadPrefix name -> "prefix (" ++ name ++ ") not legal as a bare expression" + +instance PrintableErr NameErr where + printErr = \case + MiscNameErr s -> s + UnboundVarErr v -> "variable not in scope: " ++ v + EscapedNameErr vs -> "leaked local variables: " ++ unwords vs + RepeatedPatVarErr v -> "variable already defined within pattern: " ++ v + RepeatedVarErr v -> "variable already defined : " ++ v + NotAnOrdinaryVar v -> "not an ordinary variable: " ++ v + NotADataCon v -> "not a data constructor: " ++ v + NotAClassName v -> "not a class name: " ++ v + NotAMethodName v -> "not a method name: " ++ v + -- we sort the lines to make the result a bit more deterministic for quine tests + AmbiguousVarErr v defs -> + "ambiguous occurrence: " ++ v ++ " is defined:\n" + ++ unlines (sort defs) + -- TODO: we see this message a lot. We should improve it by including more information. + -- Ideally we'd provide a link to where the original error happened." + VarDefErr v -> "error in (earlier) definition of variable: " ++ v + +instance PrintableErr TypeErr where + printErr = \case + MiscTypeErr s -> s + FFIResultTyErr t -> "FFI result type should be scalar or pair. Got: " ++ t + FFIArgTyNotScalar t -> "FFI function arguments should be scalar. Got: " ++ t + CantSynthDict t -> "can't synthesize a class dictionary for: " ++ t + CantSynthInfVars t -> "can't synthesize a class dictionary for a type with inference vars: " ++ t + NotASynthType t -> "can't synthesize terms of type: " ++ t + CantUnifySkolem -> "can't unify with skolem vars" + OccursCheckFailure v t -> "occurs check failure: " ++ v ++ " occurs in " ++ t + DisallowedEffects r1 r2 -> "\nAllowed: " ++ pprint r1 ++ + "\nRequested: " ++ pprint r2 + UnificationFailure t1 t2 vs -> "\nExpected: " ++ t1 + ++ "\nActual: " ++ t2 ++ case vs of + [] -> "" + _ -> "\n(Solving for: " ++ unwords vs ++ ")" + InferEmptyTable -> "can't infer type of empty table" + ArityErr n1 n2 -> "wrong number of positional arguments provided. Expected " ++ show n1 ++ " but got " ++ show n2 + PatternArityErr n1 n2 -> "unexpected number of pattern binders. Expected " ++ show n1 ++ " but got " ++ show n2 + SumTypeCantFail -> "sum type constructor in can't-fail pattern" + PatTypeErr patTy rhsTy -> "pattern is for a " ++ patTy ++ "but we're matching against a " ++ rhsTy + EliminationErr expected ty -> "expected a " ++ expected ++ ". Got a: " ++ ty + IllFormedCasePattern -> "case patterns must start with a data constructor or variant pattern" + NotAMethod method className -> "unexpected method: " ++ method ++ " is not a method of " ++ className + DuplicateMethod method -> "duplicate method: " ++ method + MissingMethod method -> "missing method: " ++ method + WrongArrowErr expected actual -> "wrong arrow. Expected " ++ expected ++ " got " ++ actual + AnnotationRequired -> "type annotation or constraint required" + NotAUnaryConstraint ty -> "constraint should be a unary function. Got: " ++ ty + InterfacesNoImplicitParams -> "interfaces can't have implicit parameters" + RepeatedOptionalArgs vs -> "repeated names offered:" ++ unwords vs + UnrecognizedOptionalArgs vs accepted -> "unrecognized named arguments: " ++ unwords vs + ++ ". Should be one of: " ++ pprint accepted + NoFields ty -> "can't get fields for type " ++ pprint ty + TypeMismatch expected actual -> "\nExpected: " ++ expected ++ + "\nActual: " ++ actual + InferHoleErr -> "can't infer value of hole" + InferDepPairErr -> "can't infer the type of a dependent pair; please annotate its type" + InferEmptyCaseEff -> "can't infer empty case expressions" + UnexpectedTerm term ty -> "unexpected " ++ term ++ ". Expected: " ++ ty + CantFindField field fieldTy knownFields -> + "can't resolve field " ++ field ++ " of type " ++ fieldTy ++ + "\nKnown fields are: " ++ unwords knownFields + TupleLengthMismatch req actual -> do + "tuple length mismatch. Expected: " ++ show req ++ " but got " ++ show actual + CantReduceType ty -> "Can't reduce type expression: " ++ ty + CantReduceDict -> "Can't reduce dict" + CantReduceDependentArg -> + "dependent functions can only be applied to fully evaluated expressions. " ++ + "Bind the argument to a name before you apply the function." + AmbiguousInferenceVar infVar ty desc -> case desc of + AnnotationInfVar v -> + "couldn't infer type of unannotated binder " <> v + ImplicitArgInfVar (f, argName) -> + "couldn't infer implicit argument `" <> argName <> "` of " <> f + TypeInstantiationInfVar t -> + "couldn't infer instantiation of type " <> t + MiscInfVar -> + "ambiguous type variable: " ++ infVar ++ ": " ++ ty + +instance PrintableErr MiscErr where + printErr = \case + MiscMiscErr s -> s + ModuleImportErr v -> "couldn't import " ++ v + CantFindModuleSource v -> + "couldn't find a source file for module " ++ v ++ + "\nHint: Consider extending --lib-path" + +-- === monads and helpers === class MonadFail m => Fallible m where throwErr :: Err -> m a @@ -68,7 +314,7 @@ catchErrExcept m = catchErr (Success <$> m) (\e -> return $ Failure e) catchSearchFailure :: Catchable m => m a -> m (Maybe a) catchSearchFailure m = (Just <$> m) `catchErr` \case - Err SearchFailure _ -> return Nothing + SearchFailure _ -> return Nothing err -> throwErr err instance Fallible IO where @@ -104,7 +350,7 @@ instance Monad m => Monad (ExceptT m) where {-# INLINE (>>=) #-} instance Monad m => MonadFail (ExceptT m) where - fail s = ExceptT $ return $ Failure $ Err SearchFailure s + fail s = ExceptT $ return $ Failure $ SearchFailure s {-# INLINE fail #-} instance Monad m => Fallible (ExceptT m) where @@ -112,7 +358,7 @@ instance Monad m => Fallible (ExceptT m) where {-# INLINE throwErr #-} instance Monad m => Alternative (ExceptT m) where - empty = throw SearchFailure "" + empty = throwErr $ SearchFailure "" {-# INLINE empty #-} m1 <|> m2 = do catchSearchFailure m1 >>= \case @@ -164,7 +410,7 @@ instance Monad Except where {-# INLINE (>>=) #-} instance Alternative Except where - empty = throw SearchFailure "" + empty = throwErr $ SearchFailure "" {-# INLINE empty #-} m1 <|> m2 = do catchSearchFailure m1 >>= \case @@ -218,8 +464,8 @@ instance Fallible HardFailM where -- === convenience layer === -throw :: Fallible m => ErrType -> String -> m a -throw errTy s = throwErr $ Err errTy s +throw :: (ToErr e, Fallible m) => e -> m a +throw e = throwErr $ toErr e {-# INLINE throw #-} getCurrentCallStack :: () -> Maybe [String] @@ -240,12 +486,12 @@ printCurrentCallStack (Just frames) = fold frames catchIOExcept :: MonadIO m => IO a -> m (Except a) catchIOExcept m = liftIO $ (liftM Success m) `catches` [ Handler \(e::Err) -> return $ Failure e - , Handler \(e::IOError) -> return $ Failure $ Err DataIOErr $ show e + , Handler \(e::IOError) -> return $ Failure $ MiscErr $ MiscMiscErr $ show e -- Propagate asynchronous exceptions like ThreadKilled; they are -- part of normal operation (of the live evaluation modes), not -- compiler bugs. , Handler \(e::AsyncException) -> liftIO $ throwIO e - , Handler \(e::SomeException) -> return $ Failure $ Err CompilerErr $ show e + , Handler \(e::SomeException) -> return $ Failure $ InternalErr $ show e ] liftExcept :: Fallible m => Except a -> m a @@ -266,10 +512,12 @@ ignoreExcept (Success x) = x assertEq :: (HasCallStack, Fallible m, Show a, Pretty a, Eq a) => a -> a -> String -> m () assertEq x y s = if x == y then return () - else throw CompilerErr msg + else throwInternal msg where msg = "assertion failure (" ++ s ++ "):\n" - ++ pprint x ++ " != " ++ pprint y ++ "\n\n" - ++ prettyCallStack callStack ++ "\n" + ++ pprint x ++ " != " ++ pprint y + +throwInternal :: (HasCallStack, Fallible m) => String -> m a +throwInternal s = throwErr $ InternalErr $ s ++ "\n" ++ prettyCallStack callStack ++ "\n" instance (Monoid w, Fallible m) => Fallible (WriterT w m) where throwErr errs = lift $ throwErr errs @@ -290,49 +538,11 @@ instance Fallible Except where {-# INLINE throwErr #-} instance MonadFail Except where - fail s = Failure $ Err SearchFailure s + fail s = Failure $ SearchFailure s {-# INLINE fail #-} instance Exception Err -instance Pretty Err where - pretty (Err e s) = pretty e <> pretty s - -instance Pretty a => Pretty (Except a) where - pretty (Success x) = "Success:" <+> pretty x - pretty (Failure e) = "Failure:" <+> pretty e - -instance Pretty ErrType where - pretty e = case e of - -- NoErr tags a chunk of output that was promoted into the Err ADT - -- by appending Results. - NoErr -> "" - ParseErr -> "Parse error:" - SyntaxErr -> "Syntax error: " - TypeErr -> "Type error:" - KindErr -> "Kind error:" - LinErr -> "Linearity error: " - IRVariantErr -> "Internal IR validation error: " - VarDefErr -> "Error in (earlier) definition of variable: " - UnboundVarErr -> "Error: variable not in scope: " - AmbiguousVarErr -> "Error: ambiguous variable: " - RepeatedVarErr -> "Error: variable already defined: " - RepeatedPatVarErr -> "Error: variable already defined within pattern: " - InvalidPatternErr -> "Error: not a valid pattern: " - NotImplementedErr -> - "Not implemented:" <> line <> - "Please report this at github.com/google-research/dex-lang/issues\n" <> line - CompilerErr -> - "Compiler bug!" <> line <> - "Please report this at github.com/google-research/dex-lang/issues\n" <> line - DataIOErr -> "IO error: " - MiscErr -> "Error:" - RuntimeErr -> "Runtime error" - ZipErr -> "Zipping error" - EscapedNameErr -> "Leaked local variables:" - ModuleImportErr -> "Module import error: " - SearchFailure -> "Search error (internal error)" - instance Fallible m => Fallible (ReaderT r m) where throwErr errs = lift $ throwErr errs {-# INLINE throwErr #-} @@ -348,3 +558,12 @@ instance Fallible m => Fallible (StateT s m) where instance Catchable m => Catchable (StateT s m) where StateT f `catchErr` handler = StateT \s -> f s `catchErr` \e -> runStateT (handler e) s + +instance Pretty Err where + pretty e = pretty $ printErr e + +instance ToJSON SrcId +deriving instance ToJSONKey SrcId + +instance Hashable InfVarDesc +instance Store InfVarDesc diff --git a/src/lib/Export.hs b/src/lib/Export.hs index 1108507fb..854595720 100644 --- a/src/lib/Export.hs +++ b/src/lib/Export.hs @@ -21,6 +21,7 @@ import Foreign.Ptr import Builder import Core import Err +import PPrint import IRVariants import Name import QueryType @@ -47,11 +48,11 @@ prepareFunctionForExport :: (Mut n, Topper m) prepareFunctionForExport cc f = do naryPi <- case getType f of TyCon (Pi piTy) -> return piTy - _ -> throw TypeErr "Only first-order functions can be exported" + _ -> throw $ MiscMiscErr "Only first-order functions can be exported" sig <- liftExportSigM $ corePiToExportSig cc naryPi closedSig <- case hoistToTop sig of HoistFailure _ -> - throw TypeErr $ "Types of exported functions have to be closed terms. Got: " ++ pprint naryPi + throw $ MiscMiscErr $ "Types of exported functions have to be closed terms. Got: " ++ pprint naryPi HoistSuccess s -> return s f' <- liftBuilder $ buildCoreLam naryPi \xs -> naryApp (sink f) (toAtom <$> xs) fSimp <- simplifyTopFunction $ coreLamToTopLam f' @@ -67,7 +68,7 @@ prepareSLamForExport cc f@(TopLam _ naryPi _) = do sig <- liftExportSigM $ simpPiToExportSig cc naryPi closedSig <- case hoistToTop sig of HoistFailure _ -> - throw TypeErr $ "Types of exported functions have to be closed terms. Got: " ++ pprint naryPi + throw $ MiscMiscErr $ "Types of exported functions have to be closed terms. Got: " ++ pprint naryPi HoistSuccess s -> return s fImp <- compileTopLevelFun cc f nativeFun <- toCFunction "userFunc" fImp >>= emitObjFile >>= loadObject @@ -104,7 +105,7 @@ corePiToExportSig :: CallingConvention corePiToExportSig cc (CorePiType _ expls tbs (EffTy effs resultTy)) = do case effs of Pure -> return () - _ -> throw TypeErr "Only pure functions can be exported" + _ -> throw $ MiscMiscErr "Only pure functions can be exported" goArgs cc Empty [] (zipAttrs expls tbs) resultTy simpPiToExportSig :: CallingConvention @@ -112,7 +113,7 @@ simpPiToExportSig :: CallingConvention simpPiToExportSig cc (PiType bs (EffTy effs resultTy)) = do case effs of Pure -> return () - _ -> throw TypeErr "Only pure functions can be exported" + _ -> throw $ MiscMiscErr "Only pure functions can be exported" bs' <- return $ fmapNest (\b -> WithAttrB Explicit b) bs goArgs cc Empty [] bs' resultTy @@ -163,7 +164,7 @@ toExportType ty = case ty of Nothing -> unsupported Just ety -> return ety _ -> unsupported - where unsupported = throw TypeErr $ "Unsupported type of argument in exported function: " ++ pprint ty + where unsupported = throw $ MiscMiscErr $ "Unsupported type of argument in exported function: " ++ pprint ty {-# INLINE toExportType #-} parseTabTy :: IRRep r => Type r i -> ExportSigM r i o (Maybe (ExportType o)) diff --git a/src/lib/Generalize.hs b/src/lib/Generalize.hs index 7ace599c4..fc120af2e 100644 --- a/src/lib/Generalize.hs +++ b/src/lib/Generalize.hs @@ -11,6 +11,7 @@ import Data.Maybe (fromJust) import Core import Err +import PPrint import Types.Core import Inference import IRVariants diff --git a/src/lib/Imp.hs b/src/lib/Imp.hs index 07fada480..18f3f7004 100644 --- a/src/lib/Imp.hs +++ b/src/lib/Imp.hs @@ -37,6 +37,7 @@ import Err import IRVariants import MTL1 import Name +import PPrint import Subst import QueryType import Types.Core diff --git a/src/lib/ImpToLLVM.hs b/src/lib/ImpToLLVM.hs index f333c1f08..e5736b35e 100644 --- a/src/lib/ImpToLLVM.hs +++ b/src/lib/ImpToLLVM.hs @@ -45,7 +45,6 @@ import qualified Data.Set as S import CUDA (getCudaArchitecture) import Core -import Err import Imp import LLVM.CUDA (LLVMKernel (..), compileCUDAKernel, ptxDataLayout, ptxTargetTriple) import Subst diff --git a/src/lib/Inference.hs b/src/lib/Inference.hs index bea45b654..67227f3dc 100644 --- a/src/lib/Inference.hs +++ b/src/lib/Inference.hs @@ -22,7 +22,6 @@ import Data.Foldable (toList, asum) import Data.Functor ((<&>)) import Data.List (sortOn) import Data.Maybe (fromJust, fromMaybe, catMaybes) -import Data.Text.Prettyprint.Doc (Pretty (..)) import Data.Word import qualified Data.HashMap.Strict as HM import qualified Data.Map.Strict as M @@ -38,6 +37,7 @@ import MonadUtil import MTL1 import Name import Subst +import PPrint import QueryType import Types.Core import Types.Imp @@ -214,17 +214,6 @@ applySolverSubst subst e = do return $ fmapNames env (lookupSolverSubst subst) e {-# INLINE applySolverSubst #-} -formatAmbiguousVarErr :: CAtomName n -> CType n' -> InfVarDesc -> String -formatAmbiguousVarErr infVar ty = \case - AnnotationInfVar v -> - "Couldn't infer type of unannotated binder " <> v - ImplicitArgInfVar (f, argName) -> - "Couldn't infer implicit argument `" <> argName <> "` of " <> f - TypeInstantiationInfVar t -> - "Couldn't infer instantiation of type " <> t - MiscInfVar -> - "Ambiguous type variable: " ++ pprint infVar ++ ": " ++ pprint ty - withFreshBinderInf :: NameHint -> Explicitness -> CType o -> InfererCPSB CBinder i o a withFreshBinderInf hint expl ty cont = withFreshBinder hint ty \b -> do @@ -290,7 +279,7 @@ withFreshUnificationVar desc k cont = do ans <- toAtomVar v >>= cont soln <- (M.lookup v <$> fromSolverSubst <$> getDiffState) >>= \case Just soln -> return soln - Nothing -> throw TypeErr $ formatAmbiguousVarErr v k desc + Nothing -> throw $ AmbiguousInferenceVar (pprint v) (pprint k) desc return (ans, soln) {-# INLINE withFreshUnificationVar #-} @@ -362,6 +351,18 @@ emitTypeInfo sid ty = do InfererM $ liftSubstReaderT $ lift11 $ lift1 $ lift do modify \(TypeInfo m) -> TypeInfo $ M.insert sid ty m +withReducibleEmissions + :: (HasNamesE e, SubstE AtomSubstVal e, ToErr err) + => err + -> (forall o' . (Emits o', DExt o o') => InfererM i o' (e o')) + -> InfererM i o (e o) +withReducibleEmissions msg cont = do + withDecls <- buildScoped cont + reduceWithDecls withDecls >>= \case + Just t -> return t + _ -> throw msg +{-# INLINE withReducibleEmissions #-} + -- === actual inference pass === data RequiredTy (n::S) = @@ -420,13 +421,13 @@ topDownExplicit :: forall i o. Emits o => CType o -> UExpr i -> InfererM i o (CA topDownExplicit reqTy exprWithSrc@(WithSrcE _ expr) = case expr of ULam lamExpr -> case reqTy of TyCon (Pi piTy) -> toAtom <$> Lam <$> checkULam lamExpr piTy - _ -> throw TypeErr $ "Unexpected lambda. Expected: " ++ pprint reqTy + _ -> throw $ UnexpectedTerm "lambda" (pprint reqTy) UFor dir uFor -> case reqTy of TyCon (TabPi tabPiTy) -> do lam@(UnaryLamExpr b' _) <- checkUForExpr uFor tabPiTy ixTy <- asIxType $ binderType b' emitHof $ For dir ixTy lam - _ -> throw TypeErr $ "Unexpected `for` expression. Expected: " ++ pprint reqTy + _ -> throw $ UnexpectedTerm "`for` expression" (pprint reqTy) UApp f posArgs namedArgs -> do f' <- bottomUpExplicit f checkOrInferApp f' posArgs namedArgs (Check reqTy) @@ -436,7 +437,7 @@ topDownExplicit reqTy exprWithSrc@(WithSrcE _ expr) = case expr of rhsTy <- instantiate ty [lhs'] rhs' <- topDown rhsTy rhs return $ toAtom $ DepPair lhs' rhs' ty - _ -> throw TypeErr $ "Unexpected dependent pair. Expected: " ++ pprint reqTy + _ -> throw $ UnexpectedTerm "dependent pair" (pprint reqTy) UCase scrut alts -> do scrut' <- bottomUp scrut let scrutTy = getType scrut' @@ -446,15 +447,15 @@ topDownExplicit reqTy exprWithSrc@(WithSrcE _ expr) = case expr of UTabCon xs -> do case reqTy of TyCon (TabPi tabPiTy) -> checkTabCon tabPiTy xs - _ -> throw TypeErr $ "Unexpected table constructor. Expected: " ++ pprint reqTy + _ -> throw $ UnexpectedTerm "table constructor" (pprint reqTy) UNatLit x -> fromNatLit x reqTy UIntLit x -> fromIntLit x reqTy UPrim UTuple xs -> case reqTy of TyKind -> toAtom . ProdType <$> mapM checkUType xs TyCon (ProdType reqTys) -> do - when (length reqTys /= length xs) $ throw TypeErr "Tuple length mismatch" + when (length reqTys /= length xs) $ throw $ TupleLengthMismatch (length reqTys) (length xs) toAtom <$> ProdCon <$> forM (zip reqTys xs) \(reqTy', x) -> topDown reqTy' x - _ -> throw TypeErr $ "Unexpected tuple. Expected: " ++ pprint reqTy + _ -> throw $ UnexpectedTerm "tuple" (pprint reqTy) UFieldAccess _ _ -> infer UVar _ -> infer UTypeAnn _ _ -> infer @@ -465,7 +466,7 @@ topDownExplicit reqTy exprWithSrc@(WithSrcE _ expr) = case expr of UPi _ -> infer UTabPi _ -> infer UDepPairTy _ -> infer - UHole -> throw TypeErr "Can't infer value of hole" + UHole -> throw InferHoleErr where infer :: InfererM i o (CAtom o) infer = do @@ -495,9 +496,7 @@ bottomUpExplicit (WithSrcE sid expr) = case expr of method' <- toAtomVar method resultTy <- partialAppType (getType method') (params ++ [x']) return $ SigmaPartialApp resultTy (toAtom method') (params ++ [x']) - Nothing -> throw TypeErr $ - "Can't resolve field " ++ pprint field ++ " of type " ++ pprint ty ++ - "\nKnown fields are: " ++ pprint (M.keys fields) + Nothing -> throw $ CantFindField (pprint field) (pprint ty) (map pprint $ M.keys fields) ULam lamExpr -> SigmaAtom Nothing <$> toAtom <$> inferULam lamExpr UFor dir uFor -> do lam@(UnaryLamExpr b' _) <- inferUForExpr uFor @@ -525,8 +524,7 @@ bottomUpExplicit (WithSrcE sid expr) = case expr of withUBinder b \(WithAttrB _ b') -> do rhs' <- checkUType rhs return $ SigmaAtom Nothing $ toAtom $ DepPairTy $ DepPairType expl b' rhs' - UDepPair _ _ -> throw TypeErr $ - "Can't infer the type of a dependent pair; please annotate its type" + UDepPair _ _ -> throw InferDepPairErr UCase scrut (alt:alts) -> do scrut' <- bottomUp scrut let scrutTy = getType scrut' @@ -536,7 +534,7 @@ bottomUpExplicit (WithSrcE sid expr) = case expr of resultTy <- liftHoistExcept $ hoist b ty alts' <- mapM (checkCaseAlt (Check resultTy) scrutTy) alts SigmaAtom Nothing <$> buildSortedCase scrut' (alt':alts') resultTy - UCase _ [] -> throw TypeErr "Can't infer empty case expressions" + UCase _ [] -> throw InferEmptyCaseEff UDo block -> withBlockDecls block \result -> bottomUpExplicit result UTabCon xs -> liftM (SigmaAtom Nothing) $ inferTabCon xs UTypeAnn val ty -> do @@ -548,7 +546,7 @@ bottomUpExplicit (WithSrcE sid expr) = case expr of UPrim UMonoLiteral [WithSrcE _ l] -> case l of UIntLit x -> return $ SigmaAtom Nothing $ Con $ Lit $ Int32Lit $ fromIntegral x UNatLit x -> return $ SigmaAtom Nothing $ Con $ Lit $ Word32Lit $ fromIntegral x - _ -> throw MiscErr "argument to %monoLit must be a literal" + _ -> throwInternal "argument to %monoLit must be a literal" UPrim UExplicitApply (f:xs) -> do f' <- bottomUpExplicit f xs' <- mapM bottomUp xs @@ -567,13 +565,12 @@ bottomUpExplicit (WithSrcE sid expr) = case expr of UNatLit l -> liftM (SigmaAtom Nothing) $ fromNatLit l NatTy UIntLit l -> liftM (SigmaAtom Nothing) $ fromIntLit l (BaseTy $ Scalar Int32Type) UFloatLit x -> return $ SigmaAtom Nothing $ Con $ Lit $ Float32Lit $ realToFrac x - UHole -> throw TypeErr "Can't infer value of hole" + UHole -> throw InferHoleErr expectEq :: (PrettyE e, AlphaEqE e) => e o -> e o -> InfererM i o () expectEq reqTy actualTy = alphaEq reqTy actualTy >>= \case True -> return () - False -> throw TypeErr $ "Expected: " ++ pprint reqTy ++ - "\nActual: " ++ pprint actualTy + False -> throw $ TypeMismatch (pprint reqTy) (pprint actualTy) {-# INLINE expectEq #-} fromIntLit :: Emits o => Int -> CType o -> InfererM i o (CAtom o) @@ -692,7 +689,7 @@ data FieldDef (n::S) = getFieldDefs :: CType n -> InfererM i n (M.Map FieldName' (FieldDef n)) getFieldDefs ty = case ty of - StuckTy _ _ -> noFields "" + StuckTy _ _ -> noFields TyCon con -> case con of NewtypeTyCon (UserADTType _ tyName params) -> do TyConBinding ~(Just tyDef) (DotMethods dotMethods) <- lookupEnv tyName @@ -703,9 +700,9 @@ getFieldDefs ty = case ty of let methodFields = M.toList dotMethods <&> \(field, f) -> (FieldName field, FieldDotMethod f params) return $ M.fromList $ concat projFields ++ methodFields - ADTCons _ -> noFields "" + ADTCons _ -> noFields RefType _ valTy -> case valTy of - RefTy _ _ -> noFields "" + RefTy _ _ -> noFields _ -> do valFields <- getFieldDefs valTy return $ M.filter isProj valFields @@ -713,10 +710,9 @@ getFieldDefs ty = case ty of FieldProj _ -> True _ -> False ProdType ts -> return $ M.fromList $ enumerate ts <&> \(i, _) -> (FieldNum i, FieldProj i) - TabPi _ -> noFields "\nArray indexing uses [] now." - _ -> noFields "" - where - noFields s = throw TypeErr $ "Can't get fields for type " ++ pprint ty ++ s + TabPi _ -> noFields + _ -> noFields + where noFields = throw $ NoFields $ pprint ty projectField :: Emits o => Int -> CAtom o -> InfererM i o (CAtom o) projectField i x = case getType x of @@ -773,7 +769,7 @@ checkOrInferApp f' posArgs namedArgs reqTy = do args <- inferMixedArgs fDesc expls bsConstrained (posArgs, namedArgs) applySigmaAtom f args ImplicitApp -> error "should already have handled this case" - ty -> throw TypeErr $ "Expected a function type. Got: " ++ pprint ty + ty -> throw $ EliminationErr "function type" (pprint ty) where fDesc :: SourceName fDesc = getSourceName f' @@ -891,9 +887,7 @@ checkExplicitArity :: [Explicitness] -> [a] -> InfererM i o () checkExplicitArity expls args = do let arity = length [() | Explicit <- expls] let numArgs = length args - when (numArgs /= arity) do - throw TypeErr $ "Wrong number of positional arguments provided. Expected " ++ - pprint arity ++ " but got " ++ pprint numArgs + when (numArgs /= arity) $ throw $ ArityErr arity numArgs type MixedArgs arg = ([arg], [(SourceName, arg)]) -- positional args, named args data Constraint (n::S) = @@ -1022,12 +1016,10 @@ checkNamedArgValidity expls offeredNames = do Inferred v _ -> v let acceptedNames = catMaybes $ map explToMaybeName expls let duplicates = repeated offeredNames - when (not $ null duplicates) do - throw TypeErr $ "Repeated names offered" ++ pprint duplicates + when (not $ null duplicates) $ throw $ RepeatedOptionalArgs $ map pprint duplicates let unrecognizedNames = filter (not . (`elem` acceptedNames)) offeredNames when (not $ null unrecognizedNames) do - throw TypeErr $ "Unrecognized named arguments: " ++ pprint unrecognizedNames - ++ "\nShould be one of: " ++ pprint acceptedNames + throw $ UnrecognizedOptionalArgs (map pprint unrecognizedNames) (map pprint acceptedNames) inferPrimArg :: Emits o => UExpr i -> InfererM i o (CAtom o) inferPrimArg x = do @@ -1035,7 +1027,7 @@ inferPrimArg x = do case getType xBlock of TyKind -> reduceExpr xBlock >>= \case Just reduced -> return reduced - _ -> throw CompilerErr "Type args to primops must be reducible" + _ -> throwInternal "Type args to primops must be reducible" _ -> emit xBlock matchPrimApp :: Emits o => PrimName -> [CAtom o] -> InfererM i o (CAtom o) @@ -1077,7 +1069,7 @@ matchPrimApp = \case combiner' <- lam2 combiner f' <- lam2 f emitHof $ RunWriter Nothing (BaseMonoid idVal combiner') f' - p -> \case xs -> throw TypeErr $ "Bad primitive application: " ++ show (p, xs) + p -> \case xs -> throwInternal $ "Bad primitive application: " ++ show (p, xs) where lam2 :: Fallible m => CAtom n -> m (LamExpr CoreIR n) lam2 x = do @@ -1127,14 +1119,11 @@ inferNaryTabAppArgs tabTy (arg:rest) = case tabTy of resultTy' <- applySubst (b @> SubstVal arg') resultTy rest' <- inferNaryTabAppArgs resultTy' rest return $ arg':rest' - _ -> throw TypeErr $ "Expected a table type but got: " ++ pprint tabTy + _ -> throw $ EliminationErr "table type" (pprint tabTy) checkSigmaDependent :: UExpr i -> PartialType o -> InfererM i o (CAtom o) -checkSigmaDependent e ty = withReducibleEmissions depFunErrMsg $ topDownPartial (sink ty) e - where - depFunErrMsg = - "Dependent functions can only be applied to fully evaluated expressions. " ++ - "Bind the argument to a name before you apply the function." +checkSigmaDependent e ty = withReducibleEmissions CantReduceDependentArg $ + topDownPartial (sink ty) e -- === sorting case alternatives === @@ -1285,7 +1274,7 @@ inferClassDef className methodNames paramBs methodTys = do PairB paramBs'' superclassBs <- partitionBinders (zipAttrs roleExpls paramBs') $ \b@(WithAttrB (_, expl) b') -> case expl of Explicit -> return $ LeftB b - Inferred _ Unify -> throw TypeErr "Interfaces can't have implicit parameters" + Inferred _ Unify -> throw InterfacesNoImplicitParams Inferred _ (Synth _) -> return $ RightB b' let (roleExpls', paramBs''') = unzipAttrs paramBs'' builtinName <- case className of @@ -1354,16 +1343,15 @@ inferAnn ann cs = case ann of WithSrcE _ (UVar ~(InternalName _ _ v)):_ -> do renameM v >>= getUVarType >>= \case TyCon (Pi (CorePiType ExplicitApp [Explicit] (UnaryNest (_:>ty)) _)) -> return ty - ty -> throw TypeErr $ "Constraint should be a unary function. Got: " ++ pprint ty - _ -> throw TypeErr "Type annotation or constraint required" + ty -> throw $ NotAUnaryConstraint $ pprint ty + _ -> throw AnnotationRequired checkULamPartial :: PartialPiType o -> ULamExpr i -> InfererM i o (CoreLamExpr o) checkULamPartial partialPiTy lamExpr = do PartialPiType piAppExpl expls piBs piEffs piReqTy <- return partialPiTy ULamExpr lamBs lamAppExpl lamEffs lamResultTy body <- return lamExpr checkExplicitArity expls (nestToList (const ()) lamBs) - when (piAppExpl /= lamAppExpl) $ throw TypeErr $ "Wrong arrow. Expected " - ++ pprint piAppExpl ++ " got " ++ pprint lamAppExpl + when (piAppExpl /= lamAppExpl) $ throw $ WrongArrowErr (pprint piAppExpl) (pprint lamAppExpl) checkLamBinders expls piBs lamBs \lamBs' -> do PairE piEffs' piReqTy' <- applyRename (piBs @@> (atomVarName <$> bindersVars lamBs')) (PairE piEffs piReqTy) resultTy <- case (lamResultTy, piReqTy') of @@ -1470,10 +1458,8 @@ checkInstanceBody className params methods = do ListE methodTys'' <- applySubst (scBs'@@>(SubstVal<$>superclassDicts)) methodTys' methodsChecked <- mapM (checkMethodDef className methodTys'') methods let (idxs, methods') = unzip $ sortOn fst $ methodsChecked - forM_ (repeated idxs) \i -> - throw TypeErr $ "Duplicate method: " ++ pprint (methodNames!!i) - forM_ ([0..(length methodTys'' - 1)] `listDiff` idxs) \i -> - throw TypeErr $ "Missing method: " ++ pprint (methodNames!!i) + forM_ (repeated idxs) \i -> throw $ DuplicateMethod $ pprint (methodNames!!i) + forM_ ([0..(length methodTys''-1)] `listDiff` idxs) \i -> throw $ MissingMethod $ pprint (methodNames!!i) return $ InstanceBody superclassDicts methods' superclassDictTys :: Nest CBinder o o' -> InfererM i o [CType o] @@ -1488,7 +1474,7 @@ checkMethodDef className methodTys (WithSrcE _ m) = do MethodBinding className' i <- renameM v >>= lookupEnv when (className /= className') do ClassBinding classDef <- lookupEnv className - throw TypeErr $ pprint sourceName ++ " is not a method of " ++ pprint (getSourceName classDef) + throw $ NotAMethod (pprint sourceName) (pprint $ getSourceName classDef) (i,) <$> toAtom <$> Lam <$> checkULam rhs (methodTys !! i) checkUEffRow :: UEffectRow i -> InfererM i o (EffectRow CoreIR o) @@ -1525,7 +1511,7 @@ getCaseAltIndex (WithSrcB _ pat) = case pat of UPatCon ~(InternalName _ _ conName) _ -> do (_, con) <- renameM conName >>= lookupDataCon return con - _ -> throw TypeErr $ "Case patterns must start with a data constructor or variant pattern" + _ -> throw IllFormedCasePattern checkCasePat :: Emits o @@ -1539,15 +1525,13 @@ checkCasePat (WithSrcB _ pat) scrutineeTy cont = case pat of params <- inferParams scrutineeTy dataDefName ADTCons cons <- instantiateTyConDef tyConDef params DataConDef _ _ repTy idxs <- return $ cons !! con - when (length idxs /= nestLength ps) $ throw TypeErr $ - "Unexpected number of pattern binders. Expected " ++ show (length idxs) - ++ " got " ++ show (nestLength ps) + when (length idxs /= nestLength ps) $ throw $ PatternArityErr (length idxs) (nestLength ps) withFreshBinderInf noHint Explicit repTy \b -> Abs b <$> do buildBlock do args <- forM idxs \projs -> do emitToVar =<< applyProjectionsReduced (init projs) (sink $ toAtom $ binderVar b) bindLetPats ps args $ cont - _ -> throw TypeErr $ "Case patterns must start with a data constructor or variant pattern" + _ -> throw IllFormedCasePattern inferParams :: Emits o => CType o -> TyConName o -> InfererM i o (TyConParams o) inferParams ty dataDefName = do @@ -1582,13 +1566,13 @@ bindLetPat (WithSrcB _ pat) v cont = case pat of let n = nestLength ps case getType v of TyCon (ProdType ts) | length ts == n -> return () - ty -> throw TypeErr $ "Expected a product type but got: " ++ pprint ty + ty -> throw $ PatTypeErr "product type" (pprint ty) xs <- forM (iota n) \i -> proj i (toAtom v) >>= emitInline bindLetPats ps xs cont UPatDepPair (PairB p1 p2) -> do case getType v of TyCon (DepPairTy _) -> return () - ty -> throw TypeErr $ "Expected a dependent pair, but got: " ++ pprint ty + ty -> throw $ PatTypeErr "dependent pair" (pprint ty) -- XXX: we're careful here to reduce the projection because of the dependent -- types. We do the same in the `UPatCon` case. x1 <- reduceProj 0 (toAtom v) >>= emitInline @@ -1601,18 +1585,17 @@ bindLetPat (WithSrcB _ pat) v cont = case pat of TyConDef _ _ _ cons <- lookupTyCon dataDefName case cons of ADTCons [DataConDef _ _ _ idxss] -> do - when (length idxss /= nestLength ps) $ throw TypeErr $ - "Unexpected number of pattern binders. Expected " ++ show (length idxss) - ++ " got " ++ show (nestLength ps) + when (length idxss /= nestLength ps) $ + throw $ PatternArityErr (length idxss) (nestLength ps) void $ inferParams (getType $ toAtom v) dataDefName xs <- forM idxss \idxs -> applyProjectionsReduced idxs (toAtom v) >>= emitInline bindLetPats ps xs cont - _ -> throw TypeErr $ "sum type constructor in can't-fail pattern" + _ -> throw SumTypeCantFail UPatTable ps -> do let n = fromIntegral (nestLength ps) :: Word32 case getType v of TyCon (TabPi (TabPiType _ (_:>FinConst n') _)) | n == n' -> return () - ty -> throw TypeErr $ "Expected a Fin " ++ show n ++ " table type but got: " ++ pprint ty + ty -> throw $ PatTypeErr ("Fin " ++ show n ++ " table") (pprint ty) xs <- forM [0 .. n - 1] \i -> do emitToVar =<< mkTabApp (toAtom v) (toAtom $ NewtypeCon (FinCon (NatVal n)) (NatVal $ fromIntegral i)) bindLetPats ps xs cont @@ -1628,14 +1611,14 @@ checkUType t = do checkUParam :: Kind CoreIR o -> UType i -> InfererM i o (CAtom o) checkUParam k uty = withReducibleEmissions msg $ withAllowedEffects Pure $ topDownExplicit (sink k) uty - where msg = "Can't reduce type expression: " ++ pprint uty + where msg = CantReduceType $ pprint uty inferTabCon :: forall i o. Emits o => [UExpr i] -> InfererM i o (CAtom o) inferTabCon xs = do let n = fromIntegral (length xs) :: Word32 let finTy = FinConst n elemTy <- case xs of - [] -> throw TypeErr "Can't infer type of empty table" + [] -> throw InferEmptyTable x:_ -> getType <$> bottomUp x ixTy <- asIxType finTy let tabTy = ixTy ==> elemTy @@ -1697,8 +1680,7 @@ applyConstraint = \case -- any inference variables in r2's explicit effects because we don't know -- how they line up with r1's. So this is just about figuring out r2's tail. r2 <- zonk r2' - let msg = "Allowed effects: " ++ pprint r1 ++ - "\nRequested effects: " ++ pprint r2 + let msg = DisallowedEffects (pprint r1) (pprint r2) case checkExtends r1 r2 of Success () -> return () Failure _ -> searchFailureAsTypeErr msg do @@ -1713,18 +1695,14 @@ constrainEq :: ToAtom e CoreIR => e o -> e o -> SolverM i o () constrainEq t1 t2 = do t1' <- zonk $ toAtom t1 t2' <- zonk $ toAtom t2 - msg <- liftEnvReaderM $ do + msg <- liftEnvReaderM do ab <- renameForPrinting $ PairE t1' t2' return $ canonicalizeForPrinting ab \(Abs infVars (PairE t1Pretty t2Pretty)) -> - "Expected: " ++ pprint t1Pretty - ++ "\n Actual: " ++ pprint t2Pretty - ++ (case infVars of - Empty -> "" - _ -> "\n(Solving for: " ++ pprint (nestToList pprint infVars) ++ ")") + UnificationFailure (pprint t1Pretty) (pprint t2Pretty) (nestToList pprint infVars) void $ searchFailureAsTypeErr msg $ unify t1' t2' -searchFailureAsTypeErr :: String -> SolverM i n a -> SolverM i n a -searchFailureAsTypeErr msg cont = cont <|> throw TypeErr msg +searchFailureAsTypeErr :: ToErr e => e -> SolverM i n a -> SolverM i n a +searchFailureAsTypeErr msg cont = cont <|> throw msg {-# INLINE searchFailureAsTypeErr #-} class AlphaEqE e => Unifiable (e::E) where @@ -1964,13 +1942,13 @@ extendSolution :: CAtomVar n -> CAtom n -> SolverM i n () extendSolution (AtomVar v _) t = isUnificationName v >>= \case True -> do - when (v `isFreeIn` t) $ throw TypeErr $ "Occurs check failure: " ++ pprint (v, t) + when (v `isFreeIn` t) $ throw $ OccursCheckFailure (pprint v) (pprint t) -- When we unify under a pi binder we replace its occurrences with a -- skolem variable. We don't want to unify with terms containing these -- variables because that would mean inferring dependence, which is a can -- of worms. forM_ (freeAtomVarsList t) \fv -> - whenM (isSkolemName fv) $ throw TypeErr $ "Can't unify with skolem vars" + whenM (isSkolemName fv) $ throw CantUnifySkolem addConstraint v t False -> empty @@ -2090,11 +2068,11 @@ emitInstanceDef instanceDef@(InstanceDef className _ _ _ _) = do trySynthTerm :: CType n -> RequiredMethodAccess -> InfererM i n (SynthAtom n) trySynthTerm ty reqMethodAccess = do hasInferenceVars ty >>= \case - True -> throw TypeErr $ "Can't synthesize a dictionary for a type with inference vars: " ++ pprint ty + True -> throw $ CantSynthInfVars $ pprint ty False -> withVoidSubst do synthTy <- liftExcept $ typeAsSynthType ty synthTerm synthTy reqMethodAccess - <|> throw TypeErr ("Couldn't synthesize a class dictionary for: " ++ pprint ty) + <|> (throw $ CantSynthDict $ pprint ty) {-# SCC trySynthTerm #-} hasInferenceVars :: (EnvReader m, HoistableE e) => e n -> m n Bool @@ -2139,7 +2117,7 @@ typeAsSynthType = \case TyCon (DictTy dictTy) -> return $ SynthDictType dictTy TyCon (Pi (CorePiType ImplicitApp expls bs (EffTy Pure (TyCon (DictTy d))))) -> return $ SynthPiType (expls, Abs bs d) - ty -> Failure $ Err TypeErr $ "Can't synthesize terms of type: " ++ pprint ty + ty -> Failure $ toErr $ NotASynthType $ pprint ty {-# SCC typeAsSynthType #-} getSuperclassClosure :: EnvReader m => Givens n -> [SynthAtom n] -> m n (Givens n) @@ -2254,7 +2232,7 @@ addInstanceSynthCandidate className maybeBuiltin instanceName = do instantiateSynthArgs :: DictType n -> SynthPiType n -> InfererM i n [CAtom n] instantiateSynthArgs target (expls, synthPiTy) = do - liftM fromListE $ withReducibleEmissions "dict args" do + liftM fromListE $ withReducibleEmissions CantReduceDict do bsConstrained <- buildConstraints (sink synthPiTy) \_ resultTy -> do return [TypeConstraint (TyCon $ DictTy $ sink target) (TyCon $ DictTy resultTy)] ListE <$> inferMixedArgs "dict" expls bsConstrained emptyMixedArgs @@ -2263,10 +2241,9 @@ emptyMixedArgs :: MixedArgs (CAtom n) emptyMixedArgs = ([], []) typeErrAsSearchFailure :: InfererM i n a -> InfererM i n a -typeErrAsSearchFailure cont = cont `catchErr` \err@(Err errTy _) -> do - case errTy of - TypeErr -> empty - _ -> throwErr err +typeErrAsSearchFailure cont = cont `catchErr` \case + TypeErr _ -> empty + e -> throwErr e synthDictForData :: forall i n. DictType n -> InfererM i n (SynthAtom n) synthDictForData dictTy@(DataDictType ty) = case ty of @@ -2352,7 +2329,7 @@ checkFFIFunTypeM _ = error "expected at least one argument" checkScalar :: (IRRep r, Fallible m) => Type r n -> m BaseType checkScalar (BaseTy ty) = return ty -checkScalar ty = throw TypeErr $ pprint ty +checkScalar ty = throw $ FFIArgTyNotScalar $ pprint ty checkScalarOrPairType :: (IRRep r, Fallible m) => Type r n -> m [BaseType] checkScalarOrPairType (PairTy a b) = do @@ -2360,7 +2337,7 @@ checkScalarOrPairType (PairTy a b) = do tys2 <- checkScalarOrPairType b return $ tys1 ++ tys2 checkScalarOrPairType (BaseTy ty) = return [ty] -checkScalarOrPairType ty = throw TypeErr $ pprint ty +checkScalarOrPairType ty = throw $ FFIResultTyErr $ pprint ty -- === instances === diff --git a/src/lib/Lexing.hs b/src/lib/Lexing.hs index 4d3b6dc8e..111027ff0 100644 --- a/src/lib/Lexing.hs +++ b/src/lib/Lexing.hs @@ -24,6 +24,7 @@ import qualified Text.Megaparsec.Char.Lexer as L import Text.Megaparsec.Debug import Err +import PPrint import Types.Primitives import Types.Source import Util (toSnocList) @@ -43,7 +44,7 @@ type Parser = StateT ParseCtx (Parsec Void Text) parseit :: Text -> Parser a -> Except a parseit s p = case parse (fst <$> runStateT p initParseCtx) "" s of - Left e -> throw ParseErr $ errorBundlePretty e + Left e -> throw $ MiscParseErr $ errorBundlePretty e Right x -> return x mustParseit :: Text -> Parser a -> a diff --git a/src/lib/Name.hs b/src/lib/Name.hs index fd23def5e..4f025384c 100644 --- a/src/lib/Name.hs +++ b/src/lib/Name.hs @@ -2757,12 +2757,7 @@ pprintCanonicalized e = canonicalizeForPrinting e \e' -> pprint e' liftHoistExcept :: Fallible m => HoistExcept a -> m a liftHoistExcept (HoistSuccess x) = return x -liftHoistExcept (HoistFailure vs) = throw EscapedNameErr (pprint vs) - -liftHoistExcept' :: Fallible m => String -> HoistExcept a -> m a -liftHoistExcept' _ (HoistSuccess x) = return x -liftHoistExcept' msg (HoistFailure vs) = - throw EscapedNameErr $ (pprint vs) ++ "\n" ++ msg +liftHoistExcept (HoistFailure vs) = throw $ EscapedNameErr $ map pprint vs ignoreHoistFailure :: HasCallStack => HoistExcept a -> a ignoreHoistFailure (HoistSuccess x) = x @@ -2864,7 +2859,7 @@ partitionBinders bs assignBinder = go bs where RightB b2 -> withSubscopeDistinct bs2 case exchangeBs (PairB b2 bs1) of HoistSuccess (PairB bs1' b2') -> return $ PairB bs1' (Nest b2' bs2) - HoistFailure vs -> throw EscapedNameErr $ (pprint vs) + HoistFailure vs -> throw $ EscapedNameErr $ map pprint vs -- NameBinder has no free vars, so there's no risk associated with hoisting. -- The scope is completely distinct, so their exchange doesn't create any accidental diff --git a/src/lib/QueryType.hs b/src/lib/QueryType.hs index a6e3b4b52..acbff02c7 100644 --- a/src/lib/QueryType.hs +++ b/src/lib/QueryType.hs @@ -8,6 +8,7 @@ module QueryType (module QueryType, module QueryTypePure, toAtomVar) where import Control.Category ((>>>)) import Control.Monad +import Control.Applicative import Data.List (elemIndex) import Data.Maybe (fromJust) import Data.Functor ((<&>)) @@ -23,14 +24,14 @@ import Err import Name hiding (withFreshM) import Subst import Util -import PPrint () +import PPrint import QueryTypePure import CheapReduction sourceNameType :: (EnvReader m, Fallible1 m) => SourceName -> m n (Type CoreIR n) sourceNameType v = do lookupSourceMap v >>= \case - Nothing -> throw UnboundVarErr $ pprint v + Nothing -> throw $ UnboundVarErr $ pprint v Just uvar -> getUVarType uvar -- === Exposed helpers for querying types and effects === @@ -347,35 +348,29 @@ liftIFunType (IFunType _ argTys resultTys) = liftEnvReaderM $ go argTys where isData :: EnvReader m => Type CoreIR n -> m n Bool isData ty = do - result <- liftEnvReaderT $ withSubstReaderT $ checkDataLike ty + result <- liftEnvReaderT $ withSubstReaderT $ go ty case result of - Success () -> return True - Failure _ -> return False - -checkDataLike :: Type CoreIR i -> SubstReaderT Name (EnvReaderT Except) i o () -checkDataLike ty = case ty of - StuckTy _ _ -> notData - TyCon con -> case con of - TabPi (TabPiType _ b eltTy) -> do - renameBinders b \_ -> - checkDataLike eltTy - DepPairTy (DepPairType _ b@(_:>l) r) -> do - recur l - renameBinders b \_ -> checkDataLike r - NewtypeTyCon nt -> do - (_, ty') <- unwrapNewtypeType =<< renameM nt - dropSubst $ recur ty' - BaseType _ -> return () - ProdType as -> mapM_ recur as - SumType cs -> mapM_ recur cs - RefType _ _ -> return () - HeapType -> return () - TypeKind -> notData - DictTy _ -> notData - Pi _ -> notData + Just () -> return True + Nothing -> return False where - recur = checkDataLike - notData = throw TypeErr $ pprint ty + go :: Type CoreIR i -> SubstReaderT Name (EnvReaderT Maybe) i o () + go = \case + StuckTy _ _ -> notData + TyCon con -> case con of + TabPi (TabPiType _ b eltTy) -> renameBinders b \_ -> go eltTy + DepPairTy (DepPairType _ b@(_:>l) r) -> go l >> renameBinders b \_ -> go r + NewtypeTyCon nt -> do + (_, ty') <- unwrapNewtypeType =<< renameM nt + dropSubst $ go ty' + BaseType _ -> return () + ProdType as -> mapM_ go as + SumType cs -> mapM_ go cs + RefType _ _ -> return () + HeapType -> return () + TypeKind -> notData + DictTy _ -> notData + Pi _ -> notData + where notData = empty checkExtends :: (Fallible m, IRRep r) => EffectRow r n -> EffectRow r n -> m () checkExtends allowed (EffectRow effs effTail) = do @@ -384,6 +379,6 @@ checkExtends allowed (EffectRow effs effTail) = do EffectRowTail _ -> assertEq allowedEffTail effTail "" NoTail -> return () forM_ (eSetToList effs) \eff -> unless (eff `eSetMember` allowedEffs) $ - throw CompilerErr $ "Unexpected effect: " ++ pprint eff ++ - "\nAllowed: " ++ pprint allowed + throwInternal $ "Unexpected effect: " ++ pprint eff ++ + "\nAllowed: " ++ pprint allowed {-# INLINE checkExtends #-} diff --git a/src/lib/RenderHtml.hs b/src/lib/RenderHtml.hs index 3cbc25c86..e4015c6a1 100644 --- a/src/lib/RenderHtml.hs +++ b/src/lib/RenderHtml.hs @@ -30,7 +30,7 @@ import GHC.Generics import Err import Paths_dex (getDataFileName) -import PPrint () +import PPrint import Types.Source import Util (unsnoc, foldJusts) diff --git a/src/lib/Runtime.hs b/src/lib/Runtime.hs index 20730e332..102019098 100644 --- a/src/lib/Runtime.hs +++ b/src/lib/Runtime.hs @@ -72,7 +72,7 @@ checkedCallFunPtr :: FD -> Ptr () -> Ptr () -> DexExecutable -> IO () checkedCallFunPtr fd argsPtr resultPtr fPtr = do let (CInt fd') = fdFD fd exitCode <- callFunPtr fPtr fd' argsPtr resultPtr - unless (exitCode == 0) $ throw RuntimeErr "" + unless (exitCode == 0) $ throw RuntimeErr withPipeToLogger :: PassLogger -> (FD -> IO a) -> IO a withPipeToLogger logger writeAction = do diff --git a/src/lib/RuntimePrint.hs b/src/lib/RuntimePrint.hs index 15fa1c86b..dd12d67a5 100644 --- a/src/lib/RuntimePrint.hs +++ b/src/lib/RuntimePrint.hs @@ -16,6 +16,7 @@ import IRVariants import MTL1 import Name import CheapReduction +import PPrint import Types.Core import Types.Source import Types.Primitives diff --git a/src/lib/Simplify.hs b/src/lib/Simplify.hs index 789ccb59c..f44aa1099 100644 --- a/src/lib/Simplify.hs +++ b/src/lib/Simplify.hs @@ -14,7 +14,6 @@ import Control.Category ((>>>)) import Control.Monad import Control.Monad.Reader import Data.Maybe -import Data.Text.Prettyprint.Doc (Pretty (..), hardline) import Builder import CheapReduction @@ -26,6 +25,7 @@ import IRVariants import Linearize import Name import Subst +import PPrint import QueryType import RuntimePrint import Transpose diff --git a/src/lib/SourceIdTraversal.hs b/src/lib/SourceIdTraversal.hs index 19ca2f8ca..7e2436200 100644 --- a/src/lib/SourceIdTraversal.hs +++ b/src/lib/SourceIdTraversal.hs @@ -11,6 +11,7 @@ import Data.Functor ((<&>)) import Types.Source import Types.Primitives +import Err getGroupTree :: SourceBlock' -> GroupTree getGroupTree b = mkGroupTree False rootSrcId $ runTreeM $ visit b diff --git a/src/lib/SourceRename.hs b/src/lib/SourceRename.hs index c6b68d82d..1fcaa73d0 100644 --- a/src/lib/SourceRename.hs +++ b/src/lib/SourceRename.hs @@ -10,7 +10,6 @@ module SourceRename ( renameSourceNamesTopUDecl, uDeclErrSourceMap , renameSourceNamesUExpr ) where import Prelude hiding (id, (.)) -import Data.List (sort) import Control.Category import Control.Monad.Except hiding (Except) import qualified Data.Set as S @@ -19,7 +18,7 @@ import qualified Data.Map.Strict as M import Err import Name import Core (EnvReader (..), withEnv, lookupSourceMapPure) -import PPrint () +import PPrint import IRVariants import Types.Source import Types.Primitives @@ -107,45 +106,40 @@ lookupSourceName :: Renamer m => SourceName -> m n (UVar n) lookupSourceName v = do sm <- askSourceMap case lookupSourceMapPure sm v of - [] -> throw UnboundVarErr $ pprint v + [] -> throw $ UnboundVarErr $ pprint v LocalVar v' : _ -> return v' [ModuleVar _ maybeV] -> case maybeV of Just v' -> return v' - Nothing -> throw VarDefErr $ pprint v - vs -> throw AmbiguousVarErr $ ambiguousVarErrMsg v vs - -ambiguousVarErrMsg :: SourceName -> [SourceNameDef n] -> String -ambiguousVarErrMsg v defs = - -- we sort the lines to make the result a bit more deterministic for quine tests - pprint v ++ " is defined:\n" ++ unlines (sort $ map defsPretty defs) - where - defsPretty :: SourceNameDef n -> String - defsPretty (ModuleVar mname _) = case mname of - Main -> "in this file" - Prelude -> "in the prelude" - OrdinaryModule mname' -> "in " ++ pprint mname' - defsPretty (LocalVar _) = - error "shouldn't be possible because module vars can't shadow local ones" + Nothing -> throw $ VarDefErr $ pprint v + vs -> throw $ AmbiguousVarErr (pprint v) (map wherePretty vs) + where + wherePretty :: SourceNameDef n -> String + wherePretty (ModuleVar mname _) = case mname of + Main -> "in this file" + Prelude -> "in the prelude" + OrdinaryModule mname' -> "in " ++ pprint mname' + wherePretty (LocalVar _) = + error "shouldn't be possible because module vars can't shadow local ones" instance SourceRenamableE (SourceNameOr (Name (AtomNameC CoreIR))) where sourceRenameE (SourceName pos sourceName) = do lookupSourceName sourceName >>= \case UAtomVar v -> return $ InternalName pos sourceName v - _ -> throw TypeErr $ "Not an ordinary variable: " ++ pprint sourceName + _ -> throw $ NotAnOrdinaryVar $ pprint sourceName sourceRenameE _ = error "Shouldn't be source-renaming internal names" instance SourceRenamableE (SourceNameOr (Name DataConNameC)) where sourceRenameE (SourceName pos sourceName) = do lookupSourceName sourceName >>= \case UDataConVar v -> return $ InternalName pos sourceName v - _ -> throw TypeErr $ "Not a data constructor: " ++ pprint sourceName + _ -> throw $ NotADataCon $ pprint sourceName sourceRenameE _ = error "Shouldn't be source-renaming internal names" instance SourceRenamableE (SourceNameOr (Name ClassNameC)) where sourceRenameE (SourceName pos sourceName) = do lookupSourceName sourceName >>= \case UClassVar v -> return $ InternalName pos sourceName v - _ -> throw TypeErr $ "Not a class name: " ++ pprint sourceName + _ -> throw $ NotAClassName $ pprint sourceName sourceRenameE _ = error "Shouldn't be source-renaming internal names" instance SourceRenamableE (SourceNameOr (Name c)) => SourceRenamableE (SourceOrInternalName c) where @@ -310,8 +304,7 @@ sourceRenameUBinder' asUVar ubinder cont = case ubinder of SourceMap sm <- askSourceMap mayShadow <- askMayShadow let shadows = M.member b sm - when (not mayShadow && shadows) $ - throw RepeatedVarErr $ pprint b + when (not mayShadow && shadows) $ throw (RepeatedVarErr $ pprint b) withFreshM (getNameHint b) \freshName -> do Distinct <- getDistinct extendSourceMap b (asUVar $ binderName freshName) $ @@ -367,7 +360,7 @@ instance SourceRenamableE UMethodDef' where sourceRenameE (UMethodDef ~(SourceName pos v) expr) = do lookupSourceName v >>= \case UMethodVar v' -> UMethodDef (InternalName pos v v') <$> sourceRenameE expr - _ -> throw TypeErr $ "not a method name: " ++ pprint v + _ -> throw $ NotAMethodName $ pprint v instance SourceRenamableB b => SourceRenamableB (Nest b) where sourceRenameB (Nest b bs) cont = @@ -394,7 +387,7 @@ instance SourceRenamablePat (UBinder' (AtomNameC CoreIR)) where sourceRenamePat sibs ubinder cont = do newSibs <- case ubinder of UBindSource b -> do - when (S.member b sibs) $ throw RepeatedPatVarErr $ pprint b + when (S.member b sibs) $ throw $ RepeatedPatVarErr $ pprint b return $ S.singleton b UIgnore -> return mempty UBind _ _ -> error "Shouldn't be source-renaming internal names" diff --git a/src/lib/TopLevel.hs b/src/lib/TopLevel.hs index 69dc417dd..47b132c37 100644 --- a/src/lib/TopLevel.hs +++ b/src/lib/TopLevel.hs @@ -70,6 +70,7 @@ import Serialize (takePtrSnapshot, restorePtrSnapshot) import Simplify import SourceRename import SourceIdTraversal +import PPrint import Types.Core import Types.Imp import Types.Primitives @@ -262,7 +263,7 @@ evalSourceBlock' mname block = case sbContents block of DeclareForeign fname (WithSrc _ dexName) cTy -> do ty <- evalUType =<< parseExpr cTy asFFIFunType ty >>= \case - Nothing -> throw TypeErr + Nothing -> throw $ MiscMiscErr "FFI functions must be n-ary first order functions with the IO effect" Just (impFunTy, naryPiTy) -> do -- TODO: query linking stuff and check the function is actually available @@ -274,15 +275,15 @@ evalSourceBlock' mname block = case sbContents block of DeclareCustomLinearization fname zeros g -> do expr <- parseExpr g lookupSourceMap (withoutSrc fname) >>= \case - Nothing -> throw UnboundVarErr $ pprint fname + Nothing -> throw $ UnboundVarErr $ pprint fname Just (UAtomVar fname') -> do lookupCustomRules fname' >>= \case Nothing -> return () - Just _ -> throw TypeErr + Just _ -> throw $ MiscMiscErr $ pprint fname ++ " already has a custom linearization" lookupAtomName fname' >>= \case NoinlineFun _ _ -> return () - _ -> throw TypeErr "Custom linearizations only apply to @noinline functions" + _ -> throw $ MiscMiscErr "Custom linearizations only apply to @noinline functions" -- We do some special casing to avoid instantiating polymorphic functions. impl <- case expr of WithSrcE _ (UVar _) -> @@ -295,14 +296,13 @@ evalSourceBlock' mname block = case sbContents block of liftEnvReaderT (impl `checkTypeIs` linFunTy) >>= \case Failure _ -> do let implTy = getType impl - throw TypeErr $ unlines + throw $ MiscMiscErr $ unlines [ "Expected the custom linearization to have type:" , "" , pprint linFunTy , "" , "but it has type:" , "" , pprint implTy] Success () -> return () updateTopEnv $ AddCustomRule fname' $ CustomLinearize nimplicit nexplicit zeros impl - Just _ -> throw TypeErr - $ "Custom linearization can only be defined for functions" - UnParseable _ s -> throw ParseErr s + Just _ -> throw $ MiscMiscErr $ "Custom linearization can only be defined for functions" + UnParseable _ s -> throw $ MiscParseErr s Misc m -> case m of GetNameType v -> do ty <- sourceNameType (withoutSrc v) @@ -327,11 +327,11 @@ runEnvQuery query = do DumpSubst -> logTop $ TextOut $ pprint $ env InternalNameInfo name -> case lookupSubstFragRaw (fromRecSubst $ envDefs $ topEnv env) name of - Nothing -> throw UnboundVarErr $ pprint name + Nothing -> throw $ UnboundVarErr $ pprint name Just binding -> logTop $ TextOut $ pprint binding SourceNameInfo name -> do lookupSourceMap name >>= \case - Nothing -> throw UnboundVarErr $ pprint name + Nothing -> throw $ UnboundVarErr $ pprint name Just uvar -> do logTop $ TextOut $ pprint uvar info <- case uvar of @@ -400,7 +400,7 @@ evalPartiallyParsedUModuleCached md@(UModulePartialParse name deps source) = do directDeps <- forM deps \dep -> do lookupLoadedModule dep >>= \case Just depVal -> return depVal - Nothing -> throw CompilerErr $ pprint dep ++ " isn't loaded" + Nothing -> throwInternal $ pprint dep ++ " isn't loaded" let req = (fHash source, directDeps) case M.lookup name cache of Just (cachedReq, result) | cachedReq == req -> return result @@ -434,7 +434,7 @@ evalUModule (UModule name _ blocks) = do importModule :: (Mut n, TopBuilder m, Fallible1 m) => ModuleSourceName -> m n () importModule name = do lookupLoadedModule name >>= \case - Nothing -> throw ModuleImportErr $ "Couldn't import " ++ pprint name + Nothing -> throw $ ModuleImportErr $ pprint name Just name' -> do Module _ _ transImports' _ _ <- lookupModule name' let importStatus = ImportStatus (S.singleton name') @@ -693,13 +693,7 @@ loadModuleSource config moduleName = do fsPaths <- liftIO $ traverse resolveBuiltinPath $ libPaths config liftIO (findFile fsPaths fname) >>= \case Just fpath -> return fpath - Nothing -> throw ModuleImportErr $ unlines - [ "Couldn't find a source file for module " ++ - (case moduleName of - OrdinaryModule n -> pprint n; Prelude -> "prelude"; Main -> error "") - , "Hint: Consider extending --lib-path?" - ] - + Nothing -> throw $ CantFindModuleSource $ pprint moduleName resolveBuiltinPath = \case LibBuiltinPath -> liftIO $ getDataFileName "lib" LibDirectory dir -> return dir @@ -838,14 +832,14 @@ getLinearizationType zeros = \case Just tty -> case zeros of InstantiateZeros -> return tty SymbolicZeros -> symbolicTangentTy tty - Nothing -> throw TypeErr $ "No tangent type for: " ++ pprint t + Nothing -> throw $ MiscMiscErr $ "No tangent type for: " ++ pprint t resultTanTy <- maybeTangentType resultTy' >>= \case Just rtt -> return rtt - Nothing -> throw TypeErr $ "No tangent type for: " ++ pprint resultTy' + Nothing -> throw $ MiscMiscErr $ "No tangent type for: " ++ pprint resultTy' let tanFunTy = toType $ Pi $ nonDepPiType argTanTys Pure resultTanTy let fullTy = CorePiType ExplicitApp expls bs' $ EffTy Pure (PairTy resultTy' tanFunTy) return (numIs, numEs, toType $ Pi fullTy) - _ -> throw TypeErr $ "Can't define a custom linearization for implicit or impure functions" + _ -> throw $ MiscMiscErr $ "Can't define a custom linearization for implicit or impure functions" where getNumImplicits :: Fallible m => [Explicitness] -> m (Int, Int) getNumImplicits = \case @@ -856,4 +850,4 @@ getLinearizationType zeros = \case Inferred _ _ -> return (ni + 1, ne) Explicit -> case ni of 0 -> return (0, ne + 1) - _ -> throw TypeErr "All implicit args must precede implicit args" + _ -> throw $ MiscMiscErr "All implicit args must precede implicit args" diff --git a/src/lib/Transpose.hs b/src/lib/Transpose.hs index e35305bc6..3e361d0d3 100644 --- a/src/lib/Transpose.hs +++ b/src/lib/Transpose.hs @@ -13,10 +13,10 @@ import GHC.Stack import Builder import Core -import Err import Imp import IRVariants import Name +import PPrint import Subst import QueryType import Types.Core diff --git a/src/lib/Types/Source.hs b/src/lib/Types/Source.hs index b43b81a4c..87a6e676a 100644 --- a/src/lib/Types/Source.hs +++ b/src/lib/Types/Source.hs @@ -20,7 +20,7 @@ module Types.Source where -import Data.Aeson (ToJSON, ToJSONKey) +import Data.Aeson (ToJSON) import Data.Hashable import Data.Foldable import qualified Data.Map.Strict as M @@ -61,13 +61,6 @@ newtype SourceOrInternalName (c::C) (n::S) = SourceOrInternalName (SourceNameOr -- === Source Info === --- XXX: 0 is reserved for the root The IDs are generated from left to right in --- parsing order, so IDs for lexemes are guaranteed to be sorted correctly. -newtype SrcId = SrcId Int deriving (Show, Eq, Ord, Generic) - -rootSrcId :: SrcId -rootSrcId = SrcId 0 - -- This is just for syntax highlighting. It won't be needed if we have -- a separate lexing pass where we have a complete lossless data type for -- lexemes. @@ -959,14 +952,10 @@ deriving instance Show (UEffectRow n) deriving instance Eq (UEffectRow n) deriving instance Ord (UEffectRow n) -instance ToJSON SrcId -deriving instance ToJSONKey SrcId instance ToJSON LexemeType -- === Pretty instances === - - instance Pretty CSBlock where pretty (IndentedBlock _ decls) = nest 2 $ prettyLines decls pretty (ExprBlock g) = pArg g diff --git a/src/lib/Types/Top.hs b/src/lib/Types/Top.hs index fba64b0e1..b67fe357f 100644 --- a/src/lib/Types/Top.hs +++ b/src/lib/Types/Top.hs @@ -71,15 +71,6 @@ data AtomBinding (r::IR) (n::S) where deriving instance IRRep r => Show (AtomBinding r n) deriving via WrapE (AtomBinding r) n instance IRRep r => Generic (AtomBinding r n) --- name of function, name of arg -type InferenceArgDesc = (String, String) -data InfVarDesc = - ImplicitArgInfVar InferenceArgDesc - | AnnotationInfVar String -- name of binder - | TypeInstantiationInfVar String -- name of type - | MiscInfVar - deriving (Show, Generic, Eq, Ord) - data SolverBinding (n::S) = InfVarBound (CType n) | SkolemBound (CType n) @@ -1023,7 +1014,6 @@ instance Pretty (SpecializationSpec n) where pretty (AppSpecialization f (Abs bs (ListE args))) = "Specialization" <+> pretty f <+> pretty bs <+> pretty args -instance Hashable InfVarDesc instance Hashable a => Hashable (EvalStatus a) instance Store (SolverBinding n) @@ -1039,7 +1029,6 @@ instance Store (TopFunDef n) instance Color c => Store (Binding c n) instance Store (ModuleEnv n) instance Store (SerializedEnv n) -instance Store InfVarDesc instance Store (AtomRules n) instance Store (LinearizationSpec n) instance Store (SpecializedDictDef n) diff --git a/src/lib/Vectorize.hs b/src/lib/Vectorize.hs index 90e289df9..daa606fc6 100644 --- a/src/lib/Vectorize.hs +++ b/src/lib/Vectorize.hs @@ -117,7 +117,7 @@ liftTopVectorizeM vectorByteWidth action = do Success (a, (LiftE errs)) -> return $ (a, errs) throwVectErr :: Fallible m => String -> m a -throwVectErr msg = throw MiscErr msg +throwVectErr msg = throwInternal msg askVectorByteWidth :: TopVectorizeM i o Word32 askVectorByteWidth = TopVectorizeM $ liftSubstReaderT $ lift11 (fromLiftE <$> ask) @@ -581,7 +581,7 @@ promoteTypeByStability ty = \case Varying -> getVectorType ty ProdStability stabs -> case ty of TyCon (ProdType elts) -> TyCon <$> ProdType <$> zipWithZ promoteTypeByStability elts stabs - _ -> throw ZipErr "Type and stability" + _ -> throwInternal "Zip error" -- === computing byte widths === From d10e03cc91e4578cefe75e74e572fd4e2f95ab6c Mon Sep 17 00:00:00 2001 From: Dougal Date: Mon, 4 Dec 2023 21:23:20 -0500 Subject: [PATCH 37/41] Start adding SrcIds to user-facing errors --- src/lib/AbstractSyntax.hs | 58 ++--- src/lib/Err.hs | 9 +- src/lib/Export.hs | 12 +- src/lib/Inference.hs | 467 ++++++++++++++++++++------------------ src/lib/Lexing.hs | 2 +- src/lib/Name.hs | 13 +- src/lib/QueryType.hs | 5 - src/lib/Runtime.hs | 2 +- src/lib/SourceRename.hs | 137 +++++------ src/lib/TopLevel.hs | 37 +-- src/lib/Types/Source.hs | 8 + 11 files changed, 382 insertions(+), 368 deletions(-) diff --git a/src/lib/AbstractSyntax.hs b/src/lib/AbstractSyntax.hs index 6a4b25840..e7a396ce6 100644 --- a/src/lib/AbstractSyntax.hs +++ b/src/lib/AbstractSyntax.hs @@ -139,7 +139,7 @@ decl ann (WithSrcs sid _ d) = WithSrcB sid <$> case d of CLet binder rhs -> do (p, ty) <- patOptAnn binder ULet ann p ty <$> asExpr <$> block rhs - CBind _ _ -> throw TopLevelArrowBinder + CBind _ _ -> throw sid TopLevelArrowBinder CDefDecl def -> do (name, lam) <- aDef def return $ ULet ann (fromSourceNameW name) Nothing (WithSrcE sid (ULam lam)) @@ -199,7 +199,7 @@ withTrailingConstraints g cont = case g of Nest (UAnnBinder expl (WithSrcB sid b) ann cs) bs <- withTrailingConstraints lhs cont s <- case b of UBindSource s -> return s - UIgnore -> throw CantConstrainAnonBinders + UIgnore -> throw sid CantConstrainAnonBinders UBind _ _ -> error "Shouldn't have internal names until renaming pass" c' <- expr c return $ UnaryNest (UAnnBinder expl (WithSrcB sid b) ann (cs ++ [c'])) @@ -261,7 +261,7 @@ uBinder :: GroupW -> SyntaxM (UBinder c VoidS VoidS) uBinder (WithSrcs sid _ b) = case b of CLeaf (CIdentifier name) -> return $ fromSourceNameW $ WithSrc sid name CLeaf CHole -> return $ WithSrcB sid UIgnore - _ -> throw UnexpectedBinder + _ -> throw sid UnexpectedBinder -- Type annotation with an optional binder pattern tyOptPat :: GroupW -> SyntaxM (UAnnBinder VoidS VoidS) @@ -300,7 +300,7 @@ pat (WithSrcs sid _ grp) = WithSrcB sid <$> case grp of CLeaf (CIdentifier name) -> return $ UPatBinder $ fromSourceNameW $ WithSrc sid name CJuxtapose True lhs rhs -> do case lhs of - WithSrcs _ _ (CJuxtapose True _ _) -> throw OnlyUnaryWithoutParens + WithSrcs lhsId _ (CJuxtapose True _ _) -> throw lhsId OnlyUnaryWithoutParens _ -> return () name <- identifier "pattern constructor name" lhs arg <- pat rhs @@ -312,11 +312,11 @@ pat (WithSrcs sid _ grp) = WithSrcB sid <$> case grp of gs' <- mapM pat gs return $ UPatCon (fromSourceNameW name) (toNest gs') _ -> error "unexpected postfix group (should be ruled out at grouping stage)" - _ -> throw IllegalPattern + _ -> throw sid IllegalPattern tyOptBinder :: Explicitness -> GroupW -> SyntaxM (UAnnBinder VoidS VoidS) tyOptBinder expl (WithSrcs sid sids grp) = case grp of - CBin (WithSrc _ Pipe) _ _ -> throw UnexpectedConstraint + CBin (WithSrc _ Pipe) _ rhs -> throw (getSrcId rhs) UnexpectedConstraint CBin (WithSrc _ Colon) name ty -> do b <- uBinder name ann <- UAnn <$> expr ty @@ -340,7 +340,7 @@ binderReqTy expl (WithSrcs _ _ (CBin (WithSrc _ Colon) name ty)) = do b <- uBinder name ann <- UAnn <$> expr ty return $ UAnnBinder expl b ann [] -binderReqTy _ _ = throw ExpectedAnnBinder +binderReqTy _ g = throw (getSrcId g) ExpectedAnnBinder argList :: [GroupW] -> SyntaxM ([UExpr VoidS], [UNamedArg VoidS]) argList gs = partitionEithers <$> mapM singleArg gs @@ -354,7 +354,7 @@ singleArg = \case identifier :: String -> GroupW -> SyntaxM SourceNameW identifier ctx (WithSrcs sid _ g) = case g of CLeaf (CIdentifier name) -> return $ WithSrc sid name - _ -> throw $ ExpectedIdentifier ctx + _ -> throw sid $ ExpectedIdentifier ctx aEffects :: WithSrcs ([GroupW], Maybe GroupW) -> SyntaxM (UEffectRow VoidS) aEffects (WithSrcs _ _ (effs, optEffTail)) = do @@ -364,7 +364,7 @@ aEffects (WithSrcs _ _ (effs, optEffTail)) = do return $ UEffectRow (S.fromList lhs) rhs effect :: GroupW -> SyntaxM (UEffect VoidS) -effect (WithSrcs _ _ grp) = case grp of +effect (WithSrcs grpSid _ grp) = case grp of CParens [g] -> effect g CJuxtapose True (Identifier "Read" ) (WithSrcs sid _ (CLeaf (CIdentifier h))) -> return $ URWSEffect Reader $ fromSourceNameW (WithSrc sid h) @@ -374,18 +374,18 @@ effect (WithSrcs _ _ grp) = case grp of return $ URWSEffect State $ fromSourceNameW (WithSrc sid h) CLeaf (CIdentifier "Except") -> return UExceptionEffect CLeaf (CIdentifier "IO" ) -> return UIOEffect - _ -> throw UnexpectedEffectForm + _ -> throw grpSid UnexpectedEffectForm aMethod :: CSDeclW -> SyntaxM (Maybe (UMethodDef VoidS)) aMethod (WithSrcs _ _ CPass) = return Nothing -aMethod (WithSrcs src _ d) = Just . WithSrcE src <$> case d of +aMethod (WithSrcs sid _ d) = Just . WithSrcE sid <$> case d of CDefDecl def -> do - (WithSrc sid name, lam) <- aDef def - return $ UMethodDef (SourceName sid name) lam - CLet (WithSrcs sid _ (CLeaf (CIdentifier name))) rhs -> do + (WithSrc nameSid name, lam) <- aDef def + return $ UMethodDef (SourceName nameSid name) lam + CLet (WithSrcs lhsSid _ (CLeaf (CIdentifier name))) rhs -> do rhs' <- ULamExpr Empty ImplicitApp Nothing Nothing <$> block rhs - return $ UMethodDef (fromSourceNameW (WithSrc sid name)) rhs' - _ -> throw UnexpectedMethodDef + return $ UMethodDef (fromSourceNameW (WithSrc lhsSid name)) rhs' + _ -> throw sid UnexpectedMethodDef asExpr :: UBlock VoidS -> UExpr VoidS asExpr (WithSrcE src b) = case b of @@ -400,9 +400,9 @@ block (IndentedBlock sid decls) = do blockDecls :: [CSDeclW] -> SyntaxM (Nest UDecl VoidS VoidS, UExpr VoidS) blockDecls [] = error "shouldn't have empty list of decls" -blockDecls [WithSrcs _ _ d] = case d of +blockDecls [WithSrcs sid _ d] = case d of CExpr g -> (Empty,) <$> expr g - _ -> throw BlockWithoutFinalExpr + _ -> throw sid BlockWithoutFinalExpr blockDecls (WithSrcs sid _ (CBind b rhs):ds) = do b' <- binderOptTy Explicit b rhs' <- asExpr <$> block rhs @@ -427,7 +427,7 @@ expr (WithSrcs sid _ grp) = WithSrcE sid <$> case grp of -- Table constructors here. Other uses of square brackets -- should be detected upstream, before calling expr. CBrackets gs -> UTabCon <$> mapM expr gs - CGivens _ -> throw UnexpectedGivenClause + CGivens _ -> throw sid UnexpectedGivenClause CArrow lhs effs rhs -> do case lhs of WithSrcs _ _ (CParens gs) -> do @@ -435,7 +435,7 @@ expr (WithSrcs sid _ grp) = WithSrcE sid <$> case grp of effs' <- fromMaybeM effs UPure aEffects resultTy <- expr rhs return $ UPi $ UPiExpr bs ExplicitApp effs' resultTy - _ -> throw ArgsShouldHaveParens + WithSrcs lhsSid _ _ -> throw lhsSid ArgsShouldHaveParens CDo b -> UDo <$> block b CJuxtapose hasSpace lhs rhs -> case hasSpace of True -> extendAppRight <$> expr lhs <*> expr rhs @@ -454,26 +454,26 @@ expr (WithSrcs sid _ grp) = WithSrcE sid <$> case grp of Pipe -> extendAppLeft <$> expr lhs <*> expr rhs Dot -> do lhs' <- expr lhs - WithSrcs src _ rhs' <- return rhs + WithSrcs rhsSid _ rhs' <- return rhs name <- case rhs' of CLeaf (CIdentifier name) -> return $ FieldName name CLeaf (CNat i ) -> return $ FieldNum $ fromIntegral i - _ -> throw BadField - return $ UFieldAccess lhs' (WithSrc src name) + _ -> throw rhsSid BadField + return $ UFieldAccess lhs' (WithSrc rhsSid name) DoubleColon -> UTypeAnn <$> (expr lhs) <*> expr rhs EvalBinOp s -> evalOp s DepAmpersand -> do lhs' <- tyOptPat lhs UDepPairTy . (UDepPairType ExplicitDepPair lhs') <$> expr rhs DepComma -> UDepPair <$> (expr lhs) <*> expr rhs - CSEqual -> throw BadEqualSign - Colon -> throw BadColon + CSEqual -> throw opSid BadEqualSign + Colon -> throw opSid BadColon ImplicitArrow -> case lhs of WithSrcs _ _ (CParens gs) -> do bs <- aPiBinders gs resultTy <- expr rhs return $ UPi $ UPiExpr bs ImplicitApp UPure resultTy - _ -> throw ArgsShouldHaveParens + WithSrcs lhsSid _ _ -> throw lhsSid ArgsShouldHaveParens FatArrow -> do lhs' <- tyOptPat lhs UTabPi . (UTabPiExpr lhs') <$> expr rhs @@ -483,7 +483,7 @@ expr (WithSrcs sid _ grp) = WithSrcE sid <$> case grp of lhs' <- expr lhs rhs' <- expr rhs return $ explicitApp f [lhs', rhs'] - CPrefix (WithSrc _ name) g -> do + CPrefix (WithSrc prefixSid name) g -> do case name of "+" -> (withoutSrc <$> expr g) <&> \case UNatLit i -> UIntLit (fromIntegral i) @@ -494,8 +494,8 @@ expr (WithSrcs sid _ grp) = WithSrcE sid <$> case grp of WithSrcE _ (UNatLit i) -> UIntLit (-(fromIntegral i)) WithSrcE _ (UIntLit i) -> UIntLit (-i) WithSrcE _ (UFloatLit i) -> UFloatLit (-i) - e -> unaryApp (mkUVar sid "neg") e - _ -> throw $ BadPrefix $ pprint name + e -> unaryApp (mkUVar prefixSid "neg") e + _ -> throw prefixSid $ BadPrefix $ pprint name CLambda params body -> do params' <- explicitBindersOptAnn $ WithSrcs sid [] $ map stripParens params body' <- block body diff --git a/src/lib/Err.hs b/src/lib/Err.hs index 649525f39..8a7037d22 100644 --- a/src/lib/Err.hs +++ b/src/lib/Err.hs @@ -14,7 +14,7 @@ module Err ( catchIOExcept, liftExcept, liftExceptAlt, ignoreExcept, getCurrentCallStack, printCurrentCallStack, ExceptT (..), rootSrcId, SrcId (..), assertEq, throwInternal, - InferenceArgDesc, InfVarDesc (..)) where + InferenceArgDesc, InfVarDesc (..), HasSrcId (..)) where import Control.Exception hiding (throw) import Control.Applicative @@ -43,6 +43,9 @@ newtype SrcId = SrcId Int deriving (Show, Eq, Ord, Generic) rootSrcId :: SrcId rootSrcId = SrcId 0 +class HasSrcId a where + getSrcId :: a -> SrcId + -- === core errro type === data Err = @@ -464,8 +467,8 @@ instance Fallible HardFailM where -- === convenience layer === -throw :: (ToErr e, Fallible m) => e -> m a -throw e = throwErr $ toErr e +throw :: (ToErr e, Fallible m) => SrcId -> e -> m a +throw _ e = throwErr $ toErr e {-# INLINE throw #-} getCurrentCallStack :: () -> Maybe [String] diff --git a/src/lib/Export.hs b/src/lib/Export.hs index 854595720..67e356f6c 100644 --- a/src/lib/Export.hs +++ b/src/lib/Export.hs @@ -48,11 +48,11 @@ prepareFunctionForExport :: (Mut n, Topper m) prepareFunctionForExport cc f = do naryPi <- case getType f of TyCon (Pi piTy) -> return piTy - _ -> throw $ MiscMiscErr "Only first-order functions can be exported" + _ -> throw rootSrcId $ MiscMiscErr "Only first-order functions can be exported" sig <- liftExportSigM $ corePiToExportSig cc naryPi closedSig <- case hoistToTop sig of HoistFailure _ -> - throw $ MiscMiscErr $ "Types of exported functions have to be closed terms. Got: " ++ pprint naryPi + throw rootSrcId $ MiscMiscErr $ "Types of exported functions have to be closed terms. Got: " ++ pprint naryPi HoistSuccess s -> return s f' <- liftBuilder $ buildCoreLam naryPi \xs -> naryApp (sink f) (toAtom <$> xs) fSimp <- simplifyTopFunction $ coreLamToTopLam f' @@ -68,7 +68,7 @@ prepareSLamForExport cc f@(TopLam _ naryPi _) = do sig <- liftExportSigM $ simpPiToExportSig cc naryPi closedSig <- case hoistToTop sig of HoistFailure _ -> - throw $ MiscMiscErr $ "Types of exported functions have to be closed terms. Got: " ++ pprint naryPi + throw rootSrcId $ MiscMiscErr $ "Types of exported functions have to be closed terms. Got: " ++ pprint naryPi HoistSuccess s -> return s fImp <- compileTopLevelFun cc f nativeFun <- toCFunction "userFunc" fImp >>= emitObjFile >>= loadObject @@ -105,7 +105,7 @@ corePiToExportSig :: CallingConvention corePiToExportSig cc (CorePiType _ expls tbs (EffTy effs resultTy)) = do case effs of Pure -> return () - _ -> throw $ MiscMiscErr "Only pure functions can be exported" + _ -> throw rootSrcId $ MiscMiscErr "Only pure functions can be exported" goArgs cc Empty [] (zipAttrs expls tbs) resultTy simpPiToExportSig :: CallingConvention @@ -113,7 +113,7 @@ simpPiToExportSig :: CallingConvention simpPiToExportSig cc (PiType bs (EffTy effs resultTy)) = do case effs of Pure -> return () - _ -> throw $ MiscMiscErr "Only pure functions can be exported" + _ -> throw rootSrcId $ MiscMiscErr "Only pure functions can be exported" bs' <- return $ fmapNest (\b -> WithAttrB Explicit b) bs goArgs cc Empty [] bs' resultTy @@ -164,7 +164,7 @@ toExportType ty = case ty of Nothing -> unsupported Just ety -> return ety _ -> unsupported - where unsupported = throw $ MiscMiscErr $ "Unsupported type of argument in exported function: " ++ pprint ty + where unsupported = throw rootSrcId $ MiscMiscErr $ "Unsupported type of argument in exported function: " ++ pprint ty {-# INLINE toExportType #-} parseTabTy :: IRRep r => Type r i -> ExportSigM r i o (Maybe (ExportType o)) diff --git a/src/lib/Inference.hs b/src/lib/Inference.hs index 67227f3dc..cc10b1718 100644 --- a/src/lib/Inference.hs +++ b/src/lib/Inference.hs @@ -47,6 +47,9 @@ import Types.Top import qualified Types.OpNames as P import Util hiding (group) +sidtodo :: SrcId +sidtodo = rootSrcId + -- === Top-level interface === checkTopUType :: (Fallible1 m, TopLogger m, EnvReader m) => UType n -> m n (CType n) @@ -193,6 +196,9 @@ type Zonkable e = (HasNamesE e, SubstE AtomSubstVal e) liftSolverM :: SolverM i o a -> InfererM i o a liftSolverM cont = fst <$> runDiffStateT1 emptySolverSubst cont +solverFail :: SolverM i o a +solverFail = empty + zonk :: Zonkable e => e n -> SolverM i n (e n) zonk e = do s <- getDiffState @@ -279,7 +285,7 @@ withFreshUnificationVar desc k cont = do ans <- toAtomVar v >>= cont soln <- (M.lookup v <$> fromSolverSubst <$> getDiffState) >>= \case Just soln -> return soln - Nothing -> throw $ AmbiguousInferenceVar (pprint v) (pprint k) desc + Nothing -> throw sidtodo $ AmbiguousInferenceVar (pprint v) (pprint k) desc return (ans, soln) {-# INLINE withFreshUnificationVar #-} @@ -322,11 +328,11 @@ withFreshDictVarNoEmits dictTy synthIt cont = diffStateT1 \s -> do {-# INLINE withFreshDictVarNoEmits #-} withDict - :: (Zonkable e, Emits o) => Kind CoreIR o + :: (Zonkable e, Emits o) => SrcId -> CType o -> (forall o'. (Emits o', DExt o o') => CAtom o' -> SolverM i o' (e o')) -> SolverM i o (e o) -withDict dictTy cont = withFreshDictVar dictTy - (\dictTy' -> lift11 $ trySynthTerm dictTy' Full) +withDict sid dictTy cont = withFreshDictVar dictTy + (\dictTy' -> lift11 $ trySynthTerm sid dictTy' Full) cont {-# INLINE withDict#-} @@ -353,14 +359,14 @@ emitTypeInfo sid ty = do withReducibleEmissions :: (HasNamesE e, SubstE AtomSubstVal e, ToErr err) - => err + => SrcId -> err -> (forall o' . (Emits o', DExt o o') => InfererM i o' (e o')) -> InfererM i o (e o) -withReducibleEmissions msg cont = do +withReducibleEmissions sid msg cont = do withDecls <- buildScoped cont reduceWithDecls withDecls >>= \case Just t -> return t - _ -> throw msg + _ -> throw sid msg {-# INLINE withReducibleEmissions #-} -- === actual inference pass === @@ -391,13 +397,13 @@ topDown :: forall i o. Emits o => CType o -> UExpr i -> InfererM i o (CAtom o) topDown ty uexpr = topDownPartial (typeAsPartialType ty) uexpr topDownPartial :: Emits o => PartialType o -> UExpr i -> InfererM i o (CAtom o) -topDownPartial partialTy exprWithSrc@(WithSrcE _ expr) = +topDownPartial partialTy exprWithSrc@(WithSrcE sid expr) = case partialTy of PartialType partialPiTy -> case expr of - ULam lam -> toAtom <$> Lam <$> checkULamPartial partialPiTy lam + ULam lam -> toAtom <$> Lam <$> checkULamPartial partialPiTy sid lam _ -> toAtom <$> Lam <$> etaExpandPartialPi partialPiTy \resultTy explicitArgs -> do expr' <- bottomUpExplicit exprWithSrc - dropSubst $ checkOrInferApp expr' explicitArgs [] resultTy + dropSubst $ checkOrInferApp sid sid expr' explicitArgs [] resultTy FullType ty -> topDownExplicit ty exprWithSrc -- Creates a lambda for all args and returns (via CPA) the explicit args @@ -418,26 +424,26 @@ etaExpandPartialPi (PartialPiType appExpl expls bs effs reqTy) cont = do -- Doesn't introduce implicit pi binders or dependent pairs topDownExplicit :: forall i o. Emits o => CType o -> UExpr i -> InfererM i o (CAtom o) -topDownExplicit reqTy exprWithSrc@(WithSrcE _ expr) = case expr of +topDownExplicit reqTy exprWithSrc@(WithSrcE sid expr) = case expr of ULam lamExpr -> case reqTy of - TyCon (Pi piTy) -> toAtom <$> Lam <$> checkULam lamExpr piTy - _ -> throw $ UnexpectedTerm "lambda" (pprint reqTy) - UFor dir uFor -> case reqTy of + TyCon (Pi piTy) -> toAtom <$> Lam <$> checkULam sid lamExpr piTy + _ -> throw sid $ UnexpectedTerm "lambda" (pprint reqTy) + UFor dir uFor@(UForExpr b _) -> case reqTy of TyCon (TabPi tabPiTy) -> do - lam@(UnaryLamExpr b' _) <- checkUForExpr uFor tabPiTy - ixTy <- asIxType $ binderType b' + lam@(UnaryLamExpr b' _) <- checkUForExpr sid uFor tabPiTy + ixTy <- asIxType (getSrcId b) $ binderType b' emitHof $ For dir ixTy lam - _ -> throw $ UnexpectedTerm "`for` expression" (pprint reqTy) + _ -> throw sid $ UnexpectedTerm "`for` expression" (pprint reqTy) UApp f posArgs namedArgs -> do f' <- bottomUpExplicit f - checkOrInferApp f' posArgs namedArgs (Check reqTy) + checkOrInferApp sid (getSrcId f) f' posArgs namedArgs (Check reqTy) UDepPair lhs rhs -> case reqTy of TyCon (DepPairTy ty@(DepPairType _ (_ :> lhsTy) _)) -> do lhs' <- checkSigmaDependent lhs (FullType lhsTy) rhsTy <- instantiate ty [lhs'] rhs' <- topDown rhsTy rhs return $ toAtom $ DepPair lhs' rhs' ty - _ -> throw $ UnexpectedTerm "dependent pair" (pprint reqTy) + _ -> throw sid $ UnexpectedTerm "dependent pair" (pprint reqTy) UCase scrut alts -> do scrut' <- bottomUp scrut let scrutTy = getType scrut' @@ -446,16 +452,16 @@ topDownExplicit reqTy exprWithSrc@(WithSrcE _ expr) = case expr of UDo block -> withBlockDecls block \result -> topDownExplicit (sink reqTy) result UTabCon xs -> do case reqTy of - TyCon (TabPi tabPiTy) -> checkTabCon tabPiTy xs - _ -> throw $ UnexpectedTerm "table constructor" (pprint reqTy) - UNatLit x -> fromNatLit x reqTy - UIntLit x -> fromIntLit x reqTy + TyCon (TabPi tabPiTy) -> checkTabCon tabPiTy sid xs + _ -> throw sid $ UnexpectedTerm "table constructor" (pprint reqTy) + UNatLit x -> fromNatLit sid x reqTy + UIntLit x -> fromIntLit sid x reqTy UPrim UTuple xs -> case reqTy of TyKind -> toAtom . ProdType <$> mapM checkUType xs TyCon (ProdType reqTys) -> do - when (length reqTys /= length xs) $ throw $ TupleLengthMismatch (length reqTys) (length xs) + when (length reqTys /= length xs) $ throw sid $ TupleLengthMismatch (length reqTys) (length xs) toAtom <$> ProdCon <$> forM (zip reqTys xs) \(reqTy', x) -> topDown reqTy' x - _ -> throw $ UnexpectedTerm "tuple" (pprint reqTy) + _ -> throw sid $ UnexpectedTerm "tuple" (pprint reqTy) UFieldAccess _ _ -> infer UVar _ -> infer UTypeAnn _ _ -> infer @@ -466,15 +472,15 @@ topDownExplicit reqTy exprWithSrc@(WithSrcE _ expr) = case expr of UPi _ -> infer UTabPi _ -> infer UDepPairTy _ -> infer - UHole -> throw InferHoleErr + UHole -> throw sid InferHoleErr where infer :: InfererM i o (CAtom o) infer = do sigmaAtom <- maybeInterpretPunsAsTyCons (Check reqTy) =<< bottomUpExplicit exprWithSrc - instantiateSigma (Check reqTy) sigmaAtom + instantiateSigma sid (Check reqTy) sigmaAtom bottomUp :: Emits o => UExpr i -> InfererM i o (CAtom o) -bottomUp expr = bottomUpExplicit expr >>= instantiateSigma Infer +bottomUp expr = bottomUpExplicit expr >>= instantiateSigma (getSrcId expr) Infer -- Doesn't instantiate implicit args bottomUpExplicit :: Emits o => UExpr i -> InfererM i o (SigmaAtom o) @@ -488,7 +494,7 @@ bottomUpExplicit (WithSrcE sid expr) = case expr of UFieldAccess x (WithSrc _ field) -> do x' <- bottomUp x ty <- return $ getType x' - fields <- getFieldDefs ty + fields <- getFieldDefs sid ty case M.lookup field fields of Just def -> case def of FieldProj i -> SigmaAtom Nothing <$> projectField i x' @@ -496,18 +502,18 @@ bottomUpExplicit (WithSrcE sid expr) = case expr of method' <- toAtomVar method resultTy <- partialAppType (getType method') (params ++ [x']) return $ SigmaPartialApp resultTy (toAtom method') (params ++ [x']) - Nothing -> throw $ CantFindField (pprint field) (pprint ty) (map pprint $ M.keys fields) + Nothing -> throw sid $ CantFindField (pprint field) (pprint ty) (map pprint $ M.keys fields) ULam lamExpr -> SigmaAtom Nothing <$> toAtom <$> inferULam lamExpr - UFor dir uFor -> do + UFor dir uFor@(UForExpr b _) -> do lam@(UnaryLamExpr b' _) <- inferUForExpr uFor - ixTy <- asIxType $ binderType b' + ixTy <- asIxType (getSrcId b) $ binderType b' liftM (SigmaAtom Nothing) $ emitHof $ For dir ixTy lam UApp f posArgs namedArgs -> do f' <- bottomUpExplicit f - SigmaAtom Nothing <$> checkOrInferApp f' posArgs namedArgs Infer + SigmaAtom Nothing <$> checkOrInferApp sid (getSrcId f) f' posArgs namedArgs Infer UTabApp tab args -> do tab' <- bottomUp tab - SigmaAtom Nothing <$> inferTabApp tab' args + SigmaAtom Nothing <$> inferTabApp (getSrcId tab) tab' args UPi (UPiExpr bs appExpl effs ty) -> do -- TODO: check explicitness constraints withUBinders bs \(ZipB expls bs') -> do @@ -517,26 +523,26 @@ bottomUpExplicit (WithSrcE sid expr) = case expr of UTabPi (UTabPiExpr b ty) -> do Abs b' ty' <- withUBinder b \(WithAttrB _ b') -> liftM (Abs b') $ checkUType ty - d <- getIxDict $ binderType b' + d <- getIxDict (getSrcId b) $ binderType b' let piTy = TabPiType d b' ty' return $ SigmaAtom Nothing $ toAtom $ TabPi piTy UDepPairTy (UDepPairType expl b rhs) -> do withUBinder b \(WithAttrB _ b') -> do rhs' <- checkUType rhs return $ SigmaAtom Nothing $ toAtom $ DepPairTy $ DepPairType expl b' rhs' - UDepPair _ _ -> throw InferDepPairErr + UDepPair _ _ -> throw sid InferDepPairErr UCase scrut (alt:alts) -> do scrut' <- bottomUp scrut let scrutTy = getType scrut' alt'@(IndexedAlt _ altAbs) <- checkCaseAlt Infer scrutTy alt Abs b ty <- liftEnvReaderM $ refreshAbs altAbs \b body -> do return $ Abs b (getType body) - resultTy <- liftHoistExcept $ hoist b ty + resultTy <- liftHoistExcept sid $ hoist b ty alts' <- mapM (checkCaseAlt (Check resultTy) scrutTy) alts SigmaAtom Nothing <$> buildSortedCase scrut' (alt':alts') resultTy - UCase _ [] -> throw InferEmptyCaseEff + UCase _ [] -> throw sid InferEmptyCaseEff UDo block -> withBlockDecls block \result -> bottomUpExplicit result - UTabCon xs -> liftM (SigmaAtom Nothing) $ inferTabCon xs + UTabCon xs -> liftM (SigmaAtom Nothing) $ inferTabCon sid xs UTypeAnn val ty -> do ty' <- checkUType ty liftM (SigmaAtom Nothing) $ topDown ty' val @@ -550,7 +556,7 @@ bottomUpExplicit (WithSrcE sid expr) = case expr of UPrim UExplicitApply (f:xs) -> do f' <- bottomUpExplicit f xs' <- mapM bottomUp xs - SigmaAtom Nothing <$> applySigmaAtom f' xs' + SigmaAtom Nothing <$> applySigmaAtom sid f' xs' UPrim UProjNewtype [x] -> do x' <- bottomUp x >>= unwrapNewtype return $ SigmaAtom Nothing x' @@ -562,56 +568,56 @@ bottomUpExplicit (WithSrcE sid expr) = case expr of _ -> return $ toAtom v x' -> return x' liftM (SigmaAtom Nothing) $ matchPrimApp prim xs' - UNatLit l -> liftM (SigmaAtom Nothing) $ fromNatLit l NatTy - UIntLit l -> liftM (SigmaAtom Nothing) $ fromIntLit l (BaseTy $ Scalar Int32Type) + UNatLit l -> liftM (SigmaAtom Nothing) $ fromNatLit sid l NatTy + UIntLit l -> liftM (SigmaAtom Nothing) $ fromIntLit sid l (BaseTy $ Scalar Int32Type) UFloatLit x -> return $ SigmaAtom Nothing $ Con $ Lit $ Float32Lit $ realToFrac x - UHole -> throw InferHoleErr + UHole -> throw sid InferHoleErr -expectEq :: (PrettyE e, AlphaEqE e) => e o -> e o -> InfererM i o () -expectEq reqTy actualTy = alphaEq reqTy actualTy >>= \case +expectEq :: (PrettyE e, AlphaEqE e) => SrcId -> e o -> e o -> InfererM i o () +expectEq sid reqTy actualTy = alphaEq reqTy actualTy >>= \case True -> return () - False -> throw $ TypeMismatch (pprint reqTy) (pprint actualTy) + False -> throw sid $ TypeMismatch (pprint reqTy) (pprint actualTy) {-# INLINE expectEq #-} -fromIntLit :: Emits o => Int -> CType o -> InfererM i o (CAtom o) -fromIntLit x ty = do +fromIntLit :: Emits o => SrcId -> Int -> CType o -> InfererM i o (CAtom o) +fromIntLit sid x ty = do let litVal = Con $ Lit $ Int64Lit $ fromIntegral x - applyFromLiteralMethod ty "from_integer" litVal + applyFromLiteralMethod sid ty "from_integer" litVal -fromNatLit :: Emits o => Word64 -> CType o -> InfererM i o (CAtom o) -fromNatLit x ty = do +fromNatLit :: Emits o => SrcId -> Word64 -> CType o -> InfererM i o (CAtom o) +fromNatLit sid x ty = do let litVal = Con $ Lit $ Word64Lit $ fromIntegral x - applyFromLiteralMethod ty "from_unsigned_integer" litVal - -matchReq :: Ext o o' => RequiredTy o -> CAtom o' -> InfererM i o' (CAtom o') -matchReq (Check reqTy) x = do - reqTy' <- sinkM reqTy - return x <* expectEq reqTy' (getType x) -matchReq Infer x = return x -{-# INLINE matchReq #-} + applyFromLiteralMethod sid ty "from_unsigned_integer" litVal -instantiateSigma :: Emits o => RequiredTy o -> SigmaAtom o -> InfererM i o (CAtom o) -instantiateSigma reqTy sigmaAtom = case sigmaAtom of +instantiateSigma :: Emits o => SrcId -> RequiredTy o -> SigmaAtom o -> InfererM i o (CAtom o) +instantiateSigma sid reqTy sigmaAtom = case sigmaAtom of SigmaUVar _ _ _ -> case getType sigmaAtom of TyCon (Pi (CorePiType ImplicitApp expls bs (EffTy _ resultTy))) -> do bsConstrained <- buildConstraints (Abs bs resultTy) \_ resultTy' -> do case reqTy of Infer -> return [] - Check reqTy' -> return [TypeConstraint (sink reqTy') resultTy'] + Check reqTy' -> return [TypeConstraint sidtodo (sink reqTy') resultTy'] args <- inferMixedArgs @UExpr fDesc expls bsConstrained ([], []) - applySigmaAtom sigmaAtom args + applySigmaAtom sidtodo sigmaAtom args _ -> fallback _ -> fallback where - fallback = forceSigmaAtom sigmaAtom >>= matchReq reqTy + fallback = forceSigmaAtom sigmaAtom >>= matchReq sid reqTy fDesc = getSourceName sigmaAtom +matchReq :: Ext o o' => SrcId -> RequiredTy o -> CAtom o' -> InfererM i o' (CAtom o') +matchReq sid (Check reqTy) x = do + reqTy' <- sinkM reqTy + return x <* expectEq sid reqTy' (getType x) +matchReq _ Infer x = return x +{-# INLINE matchReq #-} + forceSigmaAtom :: Emits o => SigmaAtom o -> InfererM i o (CAtom o) forceSigmaAtom sigmaAtom = case sigmaAtom of SigmaAtom _ x -> return x SigmaUVar _ _ v -> case v of UAtomVar v' -> inlineTypeAliases v' - _ -> applySigmaAtom sigmaAtom [] + _ -> applySigmaAtom sidtodo sigmaAtom [] SigmaPartialApp _ _ _ -> error "not implemented" -- better error message? withBlockDecls @@ -649,14 +655,14 @@ considerInlineAnn PlainLet (TyCon (Pi (CorePiType _ _ _ (EffTy Pure TyKind)))) = considerInlineAnn ann _ = ann applyFromLiteralMethod - :: Emits n => CType n -> SourceName -> CAtom n -> InfererM i n (CAtom n) -applyFromLiteralMethod resultTy methodName litVal = + :: Emits n => SrcId -> CType n -> SourceName -> CAtom n -> InfererM i n (CAtom n) +applyFromLiteralMethod sid resultTy methodName litVal = lookupSourceMap methodName >>= \case Nothing -> error $ "prelude function not found: " ++ pprint methodName Just ~(UMethodVar methodName') -> do MethodBinding className _ <- lookupEnv methodName' dictTy <- toType <$> dictType className [toAtom resultTy] - Just d <- toMaybeDict <$> trySynthTerm dictTy Full + Just d <- toMaybeDict <$> trySynthTerm sid dictTy Full emit =<< mkApplyMethod d 0 [litVal] -- atom that requires instantiation to become a rho type @@ -687,8 +693,8 @@ data FieldDef (n::S) = | FieldDotMethod (CAtomName n) (TyConParams n) deriving (Show, Generic) -getFieldDefs :: CType n -> InfererM i n (M.Map FieldName' (FieldDef n)) -getFieldDefs ty = case ty of +getFieldDefs :: SrcId -> CType n -> InfererM i n (M.Map FieldName' (FieldDef n)) +getFieldDefs sid ty = case ty of StuckTy _ _ -> noFields TyCon con -> case con of NewtypeTyCon (UserADTType _ tyName params) -> do @@ -704,7 +710,7 @@ getFieldDefs ty = case ty of RefType _ valTy -> case valTy of RefTy _ _ -> noFields _ -> do - valFields <- getFieldDefs valTy + valFields <- getFieldDefs sid valTy return $ M.filter isProj valFields where isProj = \case FieldProj _ -> True @@ -712,7 +718,7 @@ getFieldDefs ty = case ty of ProdType ts -> return $ M.fromList $ enumerate ts <&> \(i, _) -> (FieldNum i, FieldProj i) TabPi _ -> noFields _ -> noFields - where noFields = throw $ NoFields $ pprint ty + where noFields = throw sid $ NoFields $ pprint ty projectField :: Emits o => Int -> CAtom o -> InfererM i o (CAtom o) projectField i x = case getType x of @@ -751,40 +757,40 @@ checkCAtom :: CAtom i -> PartialType o -> InfererM i o (CAtom o) checkCAtom arg argTy = do arg' <- renameM arg case argTy of - FullType argTy' -> expectEq argTy' (getType arg') + FullType argTy' -> expectEq sidtodo argTy' (getType arg') PartialType _ -> return () -- TODO? return arg' checkOrInferApp :: forall i o arg . (Emits o, ExplicitArg arg) - => SigmaAtom o -> [arg i] -> [(SourceName, arg i)] + => SrcId -> SrcId -> SigmaAtom o -> [arg i] -> [(SourceName, arg i)] -> RequiredTy o -> InfererM i o (CAtom o) -checkOrInferApp f' posArgs namedArgs reqTy = do +checkOrInferApp appSrcId funSrcId f' posArgs namedArgs reqTy = do f <- maybeInterpretPunsAsTyCons reqTy f' case getType f of TyCon (Pi piTy@(CorePiType appExpl expls _ _)) -> case appExpl of ExplicitApp -> do - checkExplicitArity expls posArgs - bsConstrained <- buildAppConstraints reqTy piTy + checkExplicitArity appSrcId expls posArgs + bsConstrained <- buildAppConstraints appSrcId reqTy piTy args <- inferMixedArgs fDesc expls bsConstrained (posArgs, namedArgs) - applySigmaAtom f args + applySigmaAtom appSrcId f args ImplicitApp -> error "should already have handled this case" - ty -> throw $ EliminationErr "function type" (pprint ty) + ty -> throw funSrcId $ EliminationErr "function type" (pprint ty) where fDesc :: SourceName fDesc = getSourceName f' -buildAppConstraints :: RequiredTy n -> CorePiType n -> InfererM i n (ConstrainedBinders n) -buildAppConstraints reqTy (CorePiType _ _ bs effTy) = do +buildAppConstraints :: SrcId -> RequiredTy n -> CorePiType n -> InfererM i n (ConstrainedBinders n) +buildAppConstraints appSrcId reqTy (CorePiType _ _ bs effTy) = do effsAllowed <- infEffects <$> getInfState buildConstraints (Abs bs effTy) \_ (EffTy effs resultTy) -> do resultTyConstraints <- return case reqTy of Infer -> [] - Check reqTy' -> [TypeConstraint (sink reqTy') resultTy] + Check reqTy' -> [TypeConstraint appSrcId (sink reqTy') resultTy] EffectRow _ t <- return effs effConstraints <- case t of NoTail -> return [] - EffectRowTail _ -> return [EffectConstraint (sink effsAllowed) effs] + EffectRowTail _ -> return [EffectConstraint appSrcId (sink effsAllowed) effs] return $ resultTyConstraints ++ effConstraints maybeInterpretPunsAsTyCons :: RequiredTy n -> SigmaAtom n -> InfererM i n (SigmaAtom n) @@ -802,12 +808,12 @@ inlineTypeAliases v = do LetBound (DeclBinding InlineLet (Atom e)) -> return e _ -> toAtom <$> toAtomVar v -applySigmaAtom :: Emits o => SigmaAtom o -> [CAtom o] -> InfererM i o (CAtom o) -applySigmaAtom (SigmaAtom _ f) args = emitWithEffects =<< mkApp f args -applySigmaAtom (SigmaUVar _ _ f) args = case f of +applySigmaAtom :: Emits o => SrcId -> SigmaAtom o -> [CAtom o] -> InfererM i o (CAtom o) +applySigmaAtom appSrcId (SigmaAtom _ f) args = emitWithEffects appSrcId =<< mkApp f args +applySigmaAtom appSrcId (SigmaUVar _ _ f) args = case f of UAtomVar f' -> do f'' <- inlineTypeAliases f' - emitWithEffects =<< mkApp f'' args + emitWithEffects appSrcId =<< mkApp f'' args UTyConVar f' -> do TyConDef sn roleExpls _ _ <- lookupTyCon f' let expls = snd <$> roleExpls @@ -836,9 +842,9 @@ applySigmaAtom (SigmaUVar _ _ f) args = case f of let numParams = nestLength paramBs -- params aren't needed because they're already implied by the dict argument let (dictArg:args') = drop numParams args - emitWithEffects =<< mkApplyMethod (fromJust $ toMaybeDict dictArg) methodIdx args' -applySigmaAtom (SigmaPartialApp _ f prevArgs) args = - emitWithEffects =<< mkApp f (prevArgs ++ args) + emitWithEffects appSrcId =<< mkApplyMethod (fromJust $ toMaybeDict dictArg) methodIdx args' +applySigmaAtom appSrcId (SigmaPartialApp _ f prevArgs) args = + emitWithEffects appSrcId =<< mkApp f (prevArgs ++ args) splitParamPrefix :: EnvReader m => TyConName n -> [CAtom n] -> m n (TyConParams n, [CAtom n]) splitParamPrefix tc args = do @@ -878,22 +884,22 @@ applyDataCon tc conIx topArgs = do where h:t = args _ -> error $ "Unexpected data con representation type: " ++ pprint rty -emitWithEffects :: Emits o => CExpr o -> InfererM i o (CAtom o) -emitWithEffects expr = do - addEffects $ getEffects expr +emitWithEffects :: Emits o => SrcId -> CExpr o -> InfererM i o (CAtom o) +emitWithEffects sid expr = do + addEffects sid $ getEffects expr emit expr -checkExplicitArity :: [Explicitness] -> [a] -> InfererM i o () -checkExplicitArity expls args = do +checkExplicitArity :: SrcId -> [Explicitness] -> [a] -> InfererM i o () +checkExplicitArity sid expls args = do let arity = length [() | Explicit <- expls] let numArgs = length args - when (numArgs /= arity) $ throw $ ArityErr arity numArgs + when (numArgs /= arity) $ throw sid $ ArityErr arity numArgs type MixedArgs arg = ([arg], [(SourceName, arg)]) -- positional args, named args data Constraint (n::S) = - TypeConstraint (CType n) (CType n) + TypeConstraint SrcId (CType n) (CType n) -- permitted effects (no inference vars), proposed effects - | EffectConstraint (EffectRow CoreIR n) (EffectRow CoreIR n) + | EffectConstraint SrcId (EffectRow CoreIR n) (EffectRow CoreIR n) type Constraints = ListE Constraint type ConstrainedBinders n = ([IsDependent], Abs (Nest CBinder) Constraints n) @@ -944,8 +950,8 @@ inferMixedArgs fSourceName explsTop (dependenceTop, bsAbs) argsTop@(_, namedArgs eagerlyApplyConstraints bs (ListE cs) = ListE <$> forMFilter cs \c -> do case hoist bs c of HoistSuccess c' -> case c' of - TypeConstraint _ _ -> applyConstraint c' >> return Nothing - EffectConstraint _ (EffectRow specificEffs _) -> + TypeConstraint _ _ _ -> applyConstraint c' >> return Nothing + EffectConstraint _ _ (EffectRow specificEffs _) -> hasInferenceVars specificEffs >>= \case False -> applyConstraint c' >> return Nothing -- we delay applying the constraint in this case because we might @@ -980,7 +986,7 @@ inferMixedArgs fSourceName explsTop (dependenceTop, bsAbs) argsTop@(_, namedArgs withDistinct $ cont arg' args Nothing -> case infMech of Unify -> withFreshUnificationVar (ImplicitArgInfVar desc) argTy \v -> cont (toAtom v) args - Synth _ -> withDict argTy \d -> cont d args + Synth _ -> withDict sidtodo argTy \d -> cont d args checkOrInferExplicitArg :: Emits oo => Bool -> arg i -> CType oo -> SolverM i oo (CAtom oo) checkOrInferExplicitArg isDependent arg argTy = do @@ -989,7 +995,7 @@ inferMixedArgs fSourceName explsTop (dependenceTop, bsAbs) argsTop@(_, namedArgs True -> checkExplicitDependentArg arg partialTy False -> checkExplicitNonDependentArg arg partialTy Nothing -> inferExplicitArg arg - constrainEq argTy (getType arg') + constrainEq sidtodo argTy (getType arg') return arg' lookupNamedArg :: MixedArgs x -> Maybe SourceName -> Maybe x @@ -1016,10 +1022,11 @@ checkNamedArgValidity expls offeredNames = do Inferred v _ -> v let acceptedNames = catMaybes $ map explToMaybeName expls let duplicates = repeated offeredNames - when (not $ null duplicates) $ throw $ RepeatedOptionalArgs $ map pprint duplicates + -- here and below we should be able to get a per-name src id + when (not $ null duplicates) $ throw sidtodo $ RepeatedOptionalArgs $ map pprint duplicates let unrecognizedNames = filter (not . (`elem` acceptedNames)) offeredNames when (not $ null unrecognizedNames) do - throw $ UnrecognizedOptionalArgs (map pprint unrecognizedNames) (map pprint acceptedNames) + throw sidtodo $ UnrecognizedOptionalArgs (map pprint unrecognizedNames) (map pprint acceptedNames) inferPrimArg :: Emits o => UExpr i -> InfererM i o (CAtom o) inferPrimArg x = do @@ -1101,15 +1108,15 @@ pattern ExplicitCoreLam bs body <- Con (Lam (CoreLamExpr _ (LamExpr bs body))) -- === n-ary applications === -inferTabApp :: Emits o => CAtom o -> [UExpr i] -> InfererM i o (CAtom o) -inferTabApp tab args = do +inferTabApp :: Emits o => SrcId -> CAtom o -> [UExpr i] -> InfererM i o (CAtom o) +inferTabApp tabSrcId tab args = do tabTy <- return $ getType tab - args' <- inferNaryTabAppArgs tabTy args + args' <- inferNaryTabAppArgs tabSrcId tabTy args naryTabApp tab args' -inferNaryTabAppArgs :: Emits o => CType o -> [UExpr i] -> InfererM i o [CAtom o] -inferNaryTabAppArgs _ [] = return [] -inferNaryTabAppArgs tabTy (arg:rest) = case tabTy of +inferNaryTabAppArgs :: Emits o => SrcId -> CType o -> [UExpr i] -> InfererM i o [CAtom o] +inferNaryTabAppArgs _ _ [] = return [] +inferNaryTabAppArgs tabSrcId tabTy (arg:rest) = case tabTy of TyCon (TabPi (TabPiType _ b resultTy)) -> do let ixTy = binderType b let isDependent = binderName b `isFreeIn` resultTy @@ -1117,12 +1124,12 @@ inferNaryTabAppArgs tabTy (arg:rest) = case tabTy of then checkSigmaDependent arg (FullType ixTy) else topDown ixTy arg resultTy' <- applySubst (b @> SubstVal arg') resultTy - rest' <- inferNaryTabAppArgs resultTy' rest + rest' <- inferNaryTabAppArgs tabSrcId resultTy' rest return $ arg':rest' - _ -> throw $ EliminationErr "table type" (pprint tabTy) + _ -> throw tabSrcId $ EliminationErr "table type" (pprint tabTy) checkSigmaDependent :: UExpr i -> PartialType o -> InfererM i o (CAtom o) -checkSigmaDependent e ty = withReducibleEmissions CantReduceDependentArg $ +checkSigmaDependent e ty = withReducibleEmissions (getSrcId e) CantReduceDependentArg $ topDownPartial (sink ty) e -- === sorting case alternatives === @@ -1271,10 +1278,11 @@ inferClassDef className methodNames paramBs methodTys = do checkUType m >>= \case TyCon (Pi t) -> return t t -> return $ CorePiType ImplicitApp [] Empty (EffTy Pure t) - PairB paramBs'' superclassBs <- partitionBinders (zipAttrs roleExpls paramBs') $ + PairB paramBs'' superclassBs <- partitionBinders rootSrcId (zipAttrs roleExpls paramBs') $ \b@(WithAttrB (_, expl) b') -> case expl of Explicit -> return $ LeftB b - Inferred _ Unify -> throw InterfacesNoImplicitParams + -- TODO: Add a proper SrcId here. We'll need to plumb it through from the original UBinders + Inferred _ Unify -> throw rootSrcId InterfacesNoImplicitParams Inferred _ (Synth _) -> return $ RightB b' let (roleExpls', paramBs''') = unzipAttrs paramBs'' builtinName <- case className of @@ -1287,7 +1295,7 @@ inferClassDef className methodNames paramBs methodTys = do withUBinder :: UAnnBinder i i' -> InfererCPSB2 (WithExpl CBinder) i i' o a withUBinder (UAnnBinder expl b ann cs) cont = do - ty <- inferAnn ann cs + ty <- inferAnn (getSrcId b) ann cs withFreshBinderInf (getNameHint b) expl ty \b' -> extendSubst (b@>binderName b') $ cont (WithAttrB expl b') @@ -1307,7 +1315,7 @@ inferUBinders Empty cont = withDistinct $ Abs Empty <$> cont [] inferUBinders (Nest (UAnnBinder expl b ann cs) bs) cont = do -- TODO: factor out the common part of each case (requires an annotated -- `where` clause because of the rank-2 type) - ty <- inferAnn ann cs + ty <- inferAnn (getSrcId b) ann cs withFreshBinderInf (getNameHint b) expl ty \b' -> do extendSubst (b@>binderName b') do Abs bs' e <- inferUBinders bs \vs -> cont (sink (binderName b') : vs) @@ -1336,32 +1344,32 @@ withRoleUBinders bs cont = do False -> return DataParam {-# INLINE inferRole #-} -inferAnn :: UAnn i -> [UConstraint i] -> InfererM i o (CType o) -inferAnn ann cs = case ann of +inferAnn :: SrcId -> UAnn i -> [UConstraint i] -> InfererM i o (CType o) +inferAnn binderSrcId ann cs = case ann of UAnn ty -> checkUType ty UNoAnn -> case cs of - WithSrcE _ (UVar ~(InternalName _ _ v)):_ -> do + WithSrcE sid (UVar ~(InternalName _ _ v)):_ -> do renameM v >>= getUVarType >>= \case TyCon (Pi (CorePiType ExplicitApp [Explicit] (UnaryNest (_:>ty)) _)) -> return ty - ty -> throw $ NotAUnaryConstraint $ pprint ty - _ -> throw AnnotationRequired + ty -> throw sid $ NotAUnaryConstraint $ pprint ty + _ -> throw binderSrcId AnnotationRequired -checkULamPartial :: PartialPiType o -> ULamExpr i -> InfererM i o (CoreLamExpr o) -checkULamPartial partialPiTy lamExpr = do +checkULamPartial :: PartialPiType o -> SrcId -> ULamExpr i -> InfererM i o (CoreLamExpr o) +checkULamPartial partialPiTy sid lamExpr = do PartialPiType piAppExpl expls piBs piEffs piReqTy <- return partialPiTy ULamExpr lamBs lamAppExpl lamEffs lamResultTy body <- return lamExpr - checkExplicitArity expls (nestToList (const ()) lamBs) - when (piAppExpl /= lamAppExpl) $ throw $ WrongArrowErr (pprint piAppExpl) (pprint lamAppExpl) + checkExplicitArity sid expls (nestToList (const ()) lamBs) + when (piAppExpl /= lamAppExpl) $ throw sid $ WrongArrowErr (pprint piAppExpl) (pprint lamAppExpl) checkLamBinders expls piBs lamBs \lamBs' -> do PairE piEffs' piReqTy' <- applyRename (piBs @@> (atomVarName <$> bindersVars lamBs')) (PairE piEffs piReqTy) resultTy <- case (lamResultTy, piReqTy') of (Nothing, Infer ) -> return Infer (Just t , Infer ) -> Check <$> checkUType t (Nothing, Check t) -> Check <$> return t - (Just t , Check t') -> checkUType t >>= expectEq t' >> return (Check t') + (Just t , Check t') -> checkUType t >>= expectEq (getSrcId t) t' >> return (Check t') forM_ lamEffs \lamEffs' -> do lamEffs'' <- checkUEffRow lamEffs' - expectEq (Eff piEffs') (Eff lamEffs'') + expectEq sid (Eff piEffs') (Eff lamEffs'') -- TODO: add source annotations to lambda effects too body' <- withAllowedEffects piEffs' do buildBlock $ withBlockDecls body \result -> checkOrInfer (sink resultTy) result resultTy' <- case resultTy of @@ -1383,7 +1391,7 @@ checkULamPartial partialPiTy lamExpr = do Explicit -> case lamBs of Nest (UAnnBinder _ lamB lamAnn _) lamBsRest -> do case lamAnn of - UAnn lamAnn' -> checkUType lamAnn' >>= expectEq piAnn + UAnn lamAnn' -> checkUType lamAnn' >>= expectEq (getSrcId lamAnn') piAnn UNoAnn -> return () withFreshBinderInf (getNameHint lamB) Explicit piAnn \b -> do Abs piBs' UnitE <- applyRename (piB@>binderName b) (EmptyAbs piBs) @@ -1398,13 +1406,13 @@ inferUForExpr (UForExpr b body) = do body' <- buildBlock $ withBlockDecls body \result -> bottomUp result return $ LamExpr (UnaryNest b') body' -checkUForExpr :: Emits o => UForExpr i -> TabPiType CoreIR o -> InfererM i o (LamExpr CoreIR o) -checkUForExpr (UForExpr bFor body) (TabPiType _ bPi resultTy) = do +checkUForExpr :: Emits o => SrcId -> UForExpr i -> TabPiType CoreIR o -> InfererM i o (LamExpr CoreIR o) +checkUForExpr sid (UForExpr bFor body) (TabPiType _ bPi resultTy) = do let uLamExpr = ULamExpr (UnaryNest bFor) ExplicitApp Nothing Nothing body effsAllowed <- infEffects <$> getInfState partialPi <- liftEnvReaderM $ refreshAbs (Abs bPi resultTy) \bPi' resultTy' -> do return $ PartialPiType ExplicitApp [Explicit] (UnaryNest bPi') (sink effsAllowed) (Check resultTy') - CoreLamExpr _ lamExpr <- checkULamPartial partialPi uLamExpr + CoreLamExpr _ lamExpr <- checkULamPartial partialPi sid uLamExpr return lamExpr inferULam :: ULamExpr i -> InfererM i o (CoreLamExpr o) @@ -1421,8 +1429,8 @@ inferULam (ULamExpr bs appExpl effs resultTy body) = do return $ PairE effTy body' return $ CoreLamExpr (CorePiType appExpl expls bs' effTy) (LamExpr bs' body') -checkULam :: ULamExpr i -> CorePiType o -> InfererM i o (CoreLamExpr o) -checkULam ulam piTy = checkULamPartial (piAsPartialPi piTy) ulam +checkULam :: SrcId -> ULamExpr i -> CorePiType o -> InfererM i o (CoreLamExpr o) +checkULam sid ulam piTy = checkULamPartial (piAsPartialPi piTy) sid ulam piAsPartialPi :: CorePiType n -> PartialPiType n piAsPartialPi (CorePiType appExpl expls bs (EffTy effs ty)) = @@ -1451,48 +1459,52 @@ checkInstanceBody :: ClassName o -> [CAtom o] -> [UMethodDef i] -> InfererM i o (InstanceBody o) checkInstanceBody className params methods = do + -- instances are top-level so it's ok to have imprecise root srcIds here + let sid = rootSrcId ClassDef _ _ methodNames _ _ paramBs scBs methodTys <- lookupClassDef className Abs scBs' methodTys' <- applySubst (paramBs @@> (SubstVal <$> params)) $ Abs scBs $ ListE methodTys superclassTys <- superclassDictTys scBs' - superclassDicts <- mapM (flip trySynthTerm Full) superclassTys + superclassDicts <- mapM (flip (trySynthTerm sid) Full) superclassTys ListE methodTys'' <- applySubst (scBs'@@>(SubstVal<$>superclassDicts)) methodTys' methodsChecked <- mapM (checkMethodDef className methodTys'') methods let (idxs, methods') = unzip $ sortOn fst $ methodsChecked - forM_ (repeated idxs) \i -> throw $ DuplicateMethod $ pprint (methodNames!!i) - forM_ ([0..(length methodTys''-1)] `listDiff` idxs) \i -> throw $ MissingMethod $ pprint (methodNames!!i) + forM_ (repeated idxs) \i -> + throw sid $ DuplicateMethod $ pprint (methodNames!!i) + forM_ ([0..(length methodTys''-1)] `listDiff` idxs) \i -> + throw sid $ MissingMethod $ pprint (methodNames!!i) return $ InstanceBody superclassDicts methods' superclassDictTys :: Nest CBinder o o' -> InfererM i o [CType o] superclassDictTys Empty = return [] superclassDictTys (Nest b bs) = do - Abs bs' UnitE <- liftHoistExcept $ hoist b $ Abs bs UnitE + Abs bs' UnitE <- liftHoistExcept sidtodo $ hoist b $ Abs bs UnitE (binderType b:) <$> superclassDictTys bs' checkMethodDef :: ClassName o -> [CorePiType o] -> UMethodDef i -> InfererM i o (Int, CAtom o) -checkMethodDef className methodTys (WithSrcE _ m) = do +checkMethodDef className methodTys (WithSrcE sid m) = do UMethodDef ~(InternalName _ sourceName v) rhs <- return m MethodBinding className' i <- renameM v >>= lookupEnv when (className /= className') do ClassBinding classDef <- lookupEnv className - throw $ NotAMethod (pprint sourceName) (pprint $ getSourceName classDef) - (i,) <$> toAtom <$> Lam <$> checkULam rhs (methodTys !! i) + throw sid $ NotAMethod (pprint sourceName) (pprint $ getSourceName classDef) + (i,) <$> toAtom <$> Lam <$> checkULam sid rhs (methodTys !! i) checkUEffRow :: UEffectRow i -> InfererM i o (EffectRow CoreIR o) checkUEffRow (UEffectRow effs t) = do effs' <- liftM eSetFromList $ mapM checkUEff $ toList effs t' <- case t of Nothing -> return NoTail - Just (SourceOrInternalName ~(InternalName _ _ v)) -> do + Just (SourceOrInternalName ~(InternalName sid _ v)) -> do v' <- toAtomVar =<< renameM v - expectEq EffKind (getType v') + expectEq sid EffKind (getType v') return $ EffectRowTail v' return $ EffectRow effs' t' checkUEff :: UEffect i -> InfererM i o (Effect CoreIR o) checkUEff eff = case eff of - URWSEffect rws (SourceOrInternalName ~(InternalName _ _ region)) -> do + URWSEffect rws (SourceOrInternalName ~(InternalName sid _ region)) -> do region' <- renameM region >>= toAtomVar - expectEq (TyCon HeapType) (getType region') + expectEq sid (TyCon HeapType) (getType region') return $ RWSEffect rws (toAtom region') UExceptionEffect -> return ExceptionEffect UIOEffect -> return IOEffect @@ -1507,31 +1519,31 @@ checkCaseAlt reqTy scrutineeTy (UAlt pat body) = do return $ IndexedAlt idx alt getCaseAltIndex :: UPat i i' -> InfererM i o CaseAltIndex -getCaseAltIndex (WithSrcB _ pat) = case pat of +getCaseAltIndex (WithSrcB sid pat) = case pat of UPatCon ~(InternalName _ _ conName) _ -> do (_, con) <- renameM conName >>= lookupDataCon return con - _ -> throw IllFormedCasePattern + _ -> throw sid IllFormedCasePattern checkCasePat :: Emits o => UPat i i' -> CType o -> (forall o'. (Emits o', Ext o o') => InfererM i' o' (CAtom o')) -> InfererM i o (Alt CoreIR o) -checkCasePat (WithSrcB _ pat) scrutineeTy cont = case pat of +checkCasePat (WithSrcB sid pat) scrutineeTy cont = case pat of UPatCon ~(InternalName _ _ conName) ps -> do (dataDefName, con) <- renameM conName >>= lookupDataCon tyConDef <- lookupTyCon dataDefName params <- inferParams scrutineeTy dataDefName ADTCons cons <- instantiateTyConDef tyConDef params DataConDef _ _ repTy idxs <- return $ cons !! con - when (length idxs /= nestLength ps) $ throw $ PatternArityErr (length idxs) (nestLength ps) + when (length idxs /= nestLength ps) $ throw sid $ PatternArityErr (length idxs) (nestLength ps) withFreshBinderInf noHint Explicit repTy \b -> Abs b <$> do buildBlock do args <- forM idxs \projs -> do emitToVar =<< applyProjectionsReduced (init projs) (sink $ toAtom $ binderVar b) bindLetPats ps args $ cont - _ -> throw IllFormedCasePattern + _ -> throw sid IllFormedCasePattern inferParams :: Emits o => CType o -> TyConName o -> InfererM i o (TyConParams o) inferParams ty dataDefName = do @@ -1542,7 +1554,7 @@ inferParams ty dataDefName = do expl -> expl paramBsAbs <- buildConstraints (Abs paramBs UnitE) \params _ -> do let ty' = toType $ UserADTType sourceName (sink dataDefName) $ TyConParams paramExpls params - return [TypeConstraint (sink ty) ty'] + return [TypeConstraint sidtodo (sink ty) ty'] args <- inferMixedArgs sourceName inferenceExpls paramBsAbs emptyMixedArgs return $ TyConParams paramExpls args @@ -1560,19 +1572,19 @@ bindLetPat => UPat i i' -> CAtomVar o -> (forall o'. (Emits o', DExt o o') => InfererM i' o' (e o')) -> InfererM i o (e o) -bindLetPat (WithSrcB _ pat) v cont = case pat of +bindLetPat (WithSrcB sid pat) v cont = case pat of UPatBinder b -> getDistinct >>= \Distinct -> extendSubst (b @> atomVarName v) cont UPatProd ps -> do let n = nestLength ps case getType v of TyCon (ProdType ts) | length ts == n -> return () - ty -> throw $ PatTypeErr "product type" (pprint ty) + ty -> throw sid $ PatTypeErr "product type" (pprint ty) xs <- forM (iota n) \i -> proj i (toAtom v) >>= emitInline bindLetPats ps xs cont UPatDepPair (PairB p1 p2) -> do case getType v of TyCon (DepPairTy _) -> return () - ty -> throw $ PatTypeErr "dependent pair" (pprint ty) + ty -> throw sid $ PatTypeErr "dependent pair" (pprint ty) -- XXX: we're careful here to reduce the projection because of the dependent -- types. We do the same in the `UPatCon` case. x1 <- reduceProj 0 (toAtom v) >>= emitInline @@ -1586,16 +1598,16 @@ bindLetPat (WithSrcB _ pat) v cont = case pat of case cons of ADTCons [DataConDef _ _ _ idxss] -> do when (length idxss /= nestLength ps) $ - throw $ PatternArityErr (length idxss) (nestLength ps) + throw sid $ PatternArityErr (length idxss) (nestLength ps) void $ inferParams (getType $ toAtom v) dataDefName xs <- forM idxss \idxs -> applyProjectionsReduced idxs (toAtom v) >>= emitInline bindLetPats ps xs cont - _ -> throw SumTypeCantFail + _ -> throw sid SumTypeCantFail UPatTable ps -> do let n = fromIntegral (nestLength ps) :: Word32 case getType v of TyCon (TabPi (TabPiType _ (_:>FinConst n') _)) | n == n' -> return () - ty -> throw $ PatTypeErr ("Fin " ++ show n ++ " table") (pprint ty) + ty -> throw sid $ PatTypeErr ("Fin " ++ show n ++ " table") (pprint ty) xs <- forM [0 .. n - 1] \i -> do emitToVar =<< mkTabApp (toAtom v) (toAtom $ NewtypeCon (FinCon (NatVal n)) (NatVal $ fromIntegral i)) bindLetPats ps xs cont @@ -1610,28 +1622,28 @@ checkUType t = do checkUParam :: Kind CoreIR o -> UType i -> InfererM i o (CAtom o) checkUParam k uty = - withReducibleEmissions msg $ withAllowedEffects Pure $ topDownExplicit (sink k) uty + withReducibleEmissions (getSrcId uty) msg $ withAllowedEffects Pure $ topDownExplicit (sink k) uty where msg = CantReduceType $ pprint uty -inferTabCon :: forall i o. Emits o => [UExpr i] -> InfererM i o (CAtom o) -inferTabCon xs = do +inferTabCon :: forall i o. Emits o => SrcId -> [UExpr i] -> InfererM i o (CAtom o) +inferTabCon sid xs = do let n = fromIntegral (length xs) :: Word32 let finTy = FinConst n elemTy <- case xs of - [] -> throw InferEmptyTable + [] -> throw sid InferEmptyTable x:_ -> getType <$> bottomUp x - ixTy <- asIxType finTy + ixTy <- asIxType sid finTy let tabTy = ixTy ==> elemTy xs' <- forM xs \x -> topDown elemTy x let dTy = toType $ DataDictType elemTy - Just dataDict <- toMaybeDict <$> trySynthTerm dTy Full + Just dataDict <- toMaybeDict <$> trySynthTerm sid dTy Full emit $ TabCon (Just $ WhenIRE dataDict) tabTy xs' -checkTabCon :: forall i o. Emits o => TabPiType CoreIR o -> [UExpr i] -> InfererM i o (CAtom o) -checkTabCon tabTy@(TabPiType _ b elemTy) xs = do +checkTabCon :: forall i o. Emits o => TabPiType CoreIR o -> SrcId -> [UExpr i] -> InfererM i o (CAtom o) +checkTabCon tabTy@(TabPiType _ b elemTy) sid xs = do let n = fromIntegral (length xs) :: Word32 let finTy = FinConst n - expectEq (binderType b) finTy + expectEq sid (binderType b) finTy xs' <- forM (enumerate xs) \(i, x) -> do let i' = toAtom (NewtypeCon (FinCon (NatVal n)) (NatVal $ fromIntegral i)) :: CAtom o elemTy' <- applySubst (b@>SubstVal i') elemTy @@ -1643,22 +1655,22 @@ checkTabCon tabTy@(TabPiType _ b elemTy) xs = do elemTy' <- applyRename (b@>binderName b') elemTy let dTy = toType $ DataDictType elemTy' return $ toType $ CorePiType ImplicitApp [Inferred Nothing Unify] (UnaryNest b') (EffTy Pure dTy) - Just dataDict <- toMaybeDict <$> trySynthTerm dTy Full + Just dataDict <- toMaybeDict <$> trySynthTerm sid dTy Full emit $ TabCon (Just $ WhenIRE dataDict) (TyCon (TabPi tabTy)) xs' -addEffects :: EffectRow CoreIR o -> InfererM i o () -addEffects Pure = return () -addEffects eff = do +addEffects :: SrcId -> EffectRow CoreIR o -> InfererM i o () +addEffects _ Pure = return () +addEffects sid eff = do effsAllowed <- infEffects <$> getInfState case checkExtends effsAllowed eff of Success () -> return () - Failure _ -> expectEq (Eff effsAllowed) (Eff eff) + Failure _ -> expectEq sid (Eff effsAllowed) (Eff eff) -getIxDict :: CType o -> InfererM i o (IxDict CoreIR o) -getIxDict t = fromJust <$> toMaybeDict <$> trySynthTerm (toType $ IxDictType t) Full +getIxDict :: SrcId -> CType o -> InfererM i o (IxDict CoreIR o) +getIxDict sid t = fromJust <$> toMaybeDict <$> trySynthTerm sid (toType $ IxDictType t) Full -asIxType :: CType o -> InfererM i o (IxType CoreIR o) -asIxType ty = IxType ty <$> getIxDict ty +asIxType :: SrcId -> CType o -> InfererM i o (IxType CoreIR o) +asIxType sid ty = IxType ty <$> getIxDict sid ty -- === Solver === @@ -1674,8 +1686,8 @@ lookupSolverSubst (SolverSubst m) name = applyConstraint :: Constraint o -> SolverM i o () applyConstraint = \case - TypeConstraint t1 t2 -> constrainEq t1 t2 - EffectConstraint r1 r2' -> do + TypeConstraint sid t1 t2 -> constrainEq sid t1 t2 + EffectConstraint sid r1 r2' -> do -- r1 shouldn't have inference variables. And we can't infer anything about -- any inference variables in r2's explicit effects because we don't know -- how they line up with r1's. So this is just about figuring out r2's tail. @@ -1683,7 +1695,7 @@ applyConstraint = \case let msg = DisallowedEffects (pprint r1) (pprint r2) case checkExtends r1 r2 of Success () -> return () - Failure _ -> searchFailureAsTypeErr msg do + Failure _ -> searchFailureAsTypeErr sid msg do EffectRow effs1 t1 <- return r1 EffectRow effs2 (EffectRowTail v2) <- return r2 guard =<< isUnificationName (atomVarName v2) @@ -1691,18 +1703,18 @@ applyConstraint = \case let extras1 = effs1 `eSetDifference` effs2 extendSolution v2 (toAtom $ EffectRow extras1 t1) -constrainEq :: ToAtom e CoreIR => e o -> e o -> SolverM i o () -constrainEq t1 t2 = do +constrainEq :: ToAtom e CoreIR => SrcId -> e o -> e o -> SolverM i o () +constrainEq sid t1 t2 = do t1' <- zonk $ toAtom t1 t2' <- zonk $ toAtom t2 msg <- liftEnvReaderM do ab <- renameForPrinting $ PairE t1' t2' return $ canonicalizeForPrinting ab \(Abs infVars (PairE t1Pretty t2Pretty)) -> UnificationFailure (pprint t1Pretty) (pprint t2Pretty) (nestToList pprint infVars) - void $ searchFailureAsTypeErr msg $ unify t1' t2' + void $ searchFailureAsTypeErr sid msg $ unify t1' t2' -searchFailureAsTypeErr :: ToErr e => e -> SolverM i n a -> SolverM i n a -searchFailureAsTypeErr msg cont = cont <|> throw msg +searchFailureAsTypeErr :: ToErr e => SrcId -> e -> SolverM i n a -> SolverM i n a +searchFailureAsTypeErr sid msg cont = cont <|> throw sid msg {-# INLINE searchFailureAsTypeErr #-} class AlphaEqE e => Unifiable (e::E) where @@ -1933,22 +1945,25 @@ withFreshSkolemName ty cont = diffStateT1 \s -> do (ans, diff) <- runDiffStateT1 (sink s) do v <- toAtomVar $ binderName b ans <- cont v >>= zonk - liftHoistExcept $ hoist b ans - diff' <- liftHoistExcept $ hoist b diff - return (ans, diff') + case hoist b ans of + HoistSuccess ans' -> return ans' + HoistFailure _ -> empty + case hoist b diff of + HoistSuccess diff' -> return (ans, diff') + HoistFailure _ -> empty {-# INLINE withFreshSkolemName #-} extendSolution :: CAtomVar n -> CAtom n -> SolverM i n () extendSolution (AtomVar v _) t = isUnificationName v >>= \case True -> do - when (v `isFreeIn` t) $ throw $ OccursCheckFailure (pprint v) (pprint t) + when (v `isFreeIn` t) solverFail -- occurs check -- When we unify under a pi binder we replace its occurrences with a -- skolem variable. We don't want to unify with terms containing these -- variables because that would mean inferring dependence, which is a can -- of worms. forM_ (freeAtomVarsList t) \fv -> - whenM (isSkolemName fv) $ throw CantUnifySkolem + whenM (isSkolemName fv) solverFail -- can't unify with skolems addConstraint v t False -> empty @@ -2019,7 +2034,9 @@ generalizeDictRec targetTy (DictCon dict) = case dict of InstanceDef _ roleExpls bs _ _ <- lookupInstanceDef instanceName liftSolverM $ generalizeInstanceArgs roleExpls bs args \args' -> do d <- mkInstanceDict (sink instanceName) args' - constrainEq (sink $ toAtom targetTy) (toAtom $ getType d) + -- We use rootSrcId here because we only call this after type inference so + -- precise source info isn't needed. + constrainEq rootSrcId (sink $ toAtom targetTy) (toAtom $ getType d) return d IxFin _ -> do TyCon (DictTy (IxDictType (TyCon (NewtypeTyCon (Fin n))))) <- return targetTy @@ -2065,14 +2082,14 @@ emitInstanceDef instanceDef@(InstanceDef className _ _ _ _) = do emitBinding (getNameHint className) $ InstanceBinding instanceDef ty -- main entrypoint to dictionary synthesizer -trySynthTerm :: CType n -> RequiredMethodAccess -> InfererM i n (SynthAtom n) -trySynthTerm ty reqMethodAccess = do +trySynthTerm :: SrcId -> CType n -> RequiredMethodAccess -> InfererM i n (SynthAtom n) +trySynthTerm sid ty reqMethodAccess = do hasInferenceVars ty >>= \case - True -> throw $ CantSynthInfVars $ pprint ty + True -> throw sid $ CantSynthInfVars $ pprint ty False -> withVoidSubst do synthTy <- liftExcept $ typeAsSynthType ty - synthTerm synthTy reqMethodAccess - <|> (throw $ CantSynthDict $ pprint ty) + synthTerm sid synthTy reqMethodAccess + <|> (throw sid $ CantSynthDict $ pprint ty) {-# SCC trySynthTerm #-} hasInferenceVars :: (EnvReader m, HoistableE e) => e n -> m n Bool @@ -2159,11 +2176,11 @@ getSuperclassClosurePure env givens newGivens = forM (enumerate superclasses) \(i, _) -> do reduceSuperclassProj i $ fromJust (toMaybeDict synthExpr) -synthTerm :: SynthType n -> RequiredMethodAccess -> InfererM i n (SynthAtom n) -synthTerm targetTy reqMethodAccess = confuseGHC >>= \_ -> case targetTy of +synthTerm :: SrcId -> SynthType n -> RequiredMethodAccess -> InfererM i n (SynthAtom n) +synthTerm sid targetTy reqMethodAccess = confuseGHC >>= \_ -> case targetTy of SynthPiType (expls, ab) -> do ab' <- withFreshBindersInf expls ab \bs' targetTy' -> do - Abs bs' <$> synthTerm (SynthDictType targetTy') reqMethodAccess + Abs bs' <$> synthTerm sid (SynthDictType targetTy') reqMethodAccess Abs bs' synthExpr <- return ab' let piTy = CorePiType ImplicitApp expls bs' (EffTy Pure (getType synthExpr)) let lamExpr = LamExpr bs' (Atom synthExpr) @@ -2171,10 +2188,10 @@ synthTerm targetTy reqMethodAccess = confuseGHC >>= \_ -> case targetTy of SynthDictType dictTy -> case dictTy of IxDictType (TyCon (NewtypeTyCon (Fin n))) -> return $ toAtom $ IxFin n DataDictType t -> do - void (synthDictForData dictTy <|> synthDictFromGiven dictTy) + void (synthDictForData sid dictTy <|> synthDictFromGiven sid dictTy) return $ toAtom $ DataData t _ -> do - dict <- synthDictFromInstance dictTy <|> synthDictFromGiven dictTy + dict <- synthDictFromInstance sid dictTy <|> synthDictFromGiven sid dictTy case dict of Con (DictConAtom (InstanceDict _ instanceName _)) -> do isReqMethodAccessAllowed <- reqMethodAccess `isMethodAccessAllowedBy` instanceName @@ -2194,8 +2211,8 @@ isMethodAccessAllowedBy access instanceName = do Full -> return $ numClassMethods == numInstanceMethods Partial numReqMethods -> return $ numReqMethods <= numInstanceMethods -synthDictFromGiven :: DictType n -> InfererM i n (SynthAtom n) -synthDictFromGiven targetTy = do +synthDictFromGiven :: SrcId -> DictType n -> InfererM i n (SynthAtom n) +synthDictFromGiven sid targetTy = do givens <- ((HM.elems . fromGivens) <$> getGivens) asum $ givens <&> \given -> do case getSynthType given of @@ -2203,15 +2220,15 @@ synthDictFromGiven targetTy = do guard =<< alphaEq targetTy givenDictTy return given SynthPiType givenPiTy -> typeErrAsSearchFailure do - args <- instantiateSynthArgs targetTy givenPiTy + args <- instantiateSynthArgs sid targetTy givenPiTy reduceInstantiateGiven given args -synthDictFromInstance :: DictType n -> InfererM i n (SynthAtom n) -synthDictFromInstance targetTy = do +synthDictFromInstance :: SrcId -> DictType n -> InfererM i n (SynthAtom n) +synthDictFromInstance sid targetTy = do instances <- getInstanceDicts targetTy asum $ instances <&> \candidate -> typeErrAsSearchFailure do CorePiType _ expls bs (EffTy _ (TyCon (DictTy candidateTy))) <- lookupInstanceTy candidate - args <- instantiateSynthArgs targetTy (expls, Abs bs candidateTy) + args <- instantiateSynthArgs sid targetTy (expls, Abs bs candidateTy) return $ toAtom $ InstanceDict (toType targetTy) candidate args getInstanceDicts :: EnvReader m => DictType n -> m n [InstanceName n] @@ -2230,11 +2247,11 @@ addInstanceSynthCandidate className maybeBuiltin instanceName = do Just Data -> mempty emitLocalModuleEnv $ mempty {envSynthCandidates = sc} -instantiateSynthArgs :: DictType n -> SynthPiType n -> InfererM i n [CAtom n] -instantiateSynthArgs target (expls, synthPiTy) = do - liftM fromListE $ withReducibleEmissions CantReduceDict do +instantiateSynthArgs :: SrcId -> DictType n -> SynthPiType n -> InfererM i n [CAtom n] +instantiateSynthArgs sid target (expls, synthPiTy) = do + liftM fromListE $ withReducibleEmissions sid CantReduceDict do bsConstrained <- buildConstraints (sink synthPiTy) \_ resultTy -> do - return [TypeConstraint (TyCon $ DictTy $ sink target) (TyCon $ DictTy resultTy)] + return [TypeConstraint sid (TyCon $ DictTy $ sink target) (TyCon $ DictTy resultTy)] ListE <$> inferMixedArgs "dict" expls bsConstrained emptyMixedArgs emptyMixedArgs :: MixedArgs (CAtom n) @@ -2245,11 +2262,11 @@ typeErrAsSearchFailure cont = cont `catchErr` \case TypeErr _ -> empty e -> throwErr e -synthDictForData :: forall i n. DictType n -> InfererM i n (SynthAtom n) -synthDictForData dictTy@(DataDictType ty) = case ty of +synthDictForData :: forall i n. SrcId -> DictType n -> InfererM i n (SynthAtom n) +synthDictForData sid dictTy@(DataDictType ty) = case ty of -- TODO Deduplicate vs CheckType.checkDataLike -- The "Stuck" case is different - StuckTy _ _ -> synthDictFromGiven dictTy + StuckTy _ _ -> synthDictFromGiven sid dictTy TyCon con -> case con of TabPi (TabPiType _ b eltTy) -> recurBinder (Abs b eltTy) >> success DepPairTy (DepPairType _ b@(_:>l) r) -> do @@ -2264,16 +2281,16 @@ synthDictForData dictTy@(DataDictType ty) = case ty of HeapType -> success _ -> notData where - recur ty' = synthDictForData $ DataDictType ty' + recur ty' = synthDictForData sid $ DataDictType ty' recurBinder :: Abs CBinder CType n -> InfererM i n (SynthAtom n) recurBinder (Abs b body) = withFreshBinderInf noHint Explicit (binderType b) \b' -> do body' <- applyRename (b@>binderName b') body - ans <- synthDictForData $ DataDictType (toType body') + ans <- synthDictForData sid $ DataDictType (toType body') return $ ignoreHoistFailure $ hoist b' ans notData = empty success = return $ toAtom $ DataData ty -synthDictForData dictTy = error $ "Malformed Data dictTy " ++ pprint dictTy +synthDictForData _ dictTy = error $ "Malformed Data dictTy " ++ pprint dictTy instance GenericE Givens where type RepE Givens = HashMapE (EKey SynthType) SynthAtom @@ -2329,7 +2346,7 @@ checkFFIFunTypeM _ = error "expected at least one argument" checkScalar :: (IRRep r, Fallible m) => Type r n -> m BaseType checkScalar (BaseTy ty) = return ty -checkScalar ty = throw $ FFIArgTyNotScalar $ pprint ty +checkScalar ty = throw rootSrcId $ FFIArgTyNotScalar $ pprint ty checkScalarOrPairType :: (IRRep r, Fallible m) => Type r n -> m [BaseType] checkScalarOrPairType (PairTy a b) = do @@ -2337,7 +2354,7 @@ checkScalarOrPairType (PairTy a b) = do tys2 <- checkScalarOrPairType b return $ tys1 ++ tys2 checkScalarOrPairType (BaseTy ty) = return [ty] -checkScalarOrPairType ty = throw $ FFIResultTyErr $ pprint ty +checkScalarOrPairType ty = throw rootSrcId $ FFIResultTyErr $ pprint ty -- === instances === @@ -2417,14 +2434,16 @@ instance RenameE SynthType instance SubstE AtomSubstVal SynthType instance GenericE Constraint where - type RepE Constraint = EitherE - (PairE CType CType) - (PairE (EffectRow CoreIR) (EffectRow CoreIR)) - fromE (TypeConstraint t1 t2) = LeftE (PairE t1 t2) - fromE (EffectConstraint e1 e2) = RightE (PairE e1 e2) + type RepE Constraint = PairE + (LiftE SrcId) + (EitherE + (PairE CType CType) + (PairE (EffectRow CoreIR) (EffectRow CoreIR))) + fromE (TypeConstraint sid t1 t2) = LiftE sid `PairE` LeftE (PairE t1 t2) + fromE (EffectConstraint sid e1 e2) = LiftE sid `PairE` RightE (PairE e1 e2) {-# INLINE fromE #-} - toE (LeftE (PairE t1 t2)) = TypeConstraint t1 t2 - toE (RightE (PairE e1 e2)) = EffectConstraint e1 e2 + toE (LiftE sid `PairE` LeftE (PairE t1 t2)) = TypeConstraint sid t1 t2 + toE (LiftE sid `PairE` RightE (PairE e1 e2)) = EffectConstraint sid e1 e2 {-# INLINE toE #-} instance SinkableE Constraint diff --git a/src/lib/Lexing.hs b/src/lib/Lexing.hs index 111027ff0..e8ad7c7f3 100644 --- a/src/lib/Lexing.hs +++ b/src/lib/Lexing.hs @@ -44,7 +44,7 @@ type Parser = StateT ParseCtx (Parsec Void Text) parseit :: Text -> Parser a -> Except a parseit s p = case parse (fst <$> runStateT p initParseCtx) "" s of - Left e -> throw $ MiscParseErr $ errorBundlePretty e + Left e -> throw rootSrcId $ MiscParseErr $ errorBundlePretty e Right x -> return x mustParseit :: Text -> Parser a -> a diff --git a/src/lib/Name.hs b/src/lib/Name.hs index 4f025384c..9aa9402ca 100644 --- a/src/lib/Name.hs +++ b/src/lib/Name.hs @@ -2755,9 +2755,9 @@ canonicalizeForPrinting e cont = do pprintCanonicalized :: (HoistableE e, RenameE e, PrettyE e) => e n -> String pprintCanonicalized e = canonicalizeForPrinting e \e' -> pprint e' -liftHoistExcept :: Fallible m => HoistExcept a -> m a -liftHoistExcept (HoistSuccess x) = return x -liftHoistExcept (HoistFailure vs) = throw $ EscapedNameErr $ map pprint vs +liftHoistExcept :: Fallible m => SrcId -> HoistExcept a -> m a +liftHoistExcept _ (HoistSuccess x) = return x +liftHoistExcept sid (HoistFailure vs) = throw sid $ EscapedNameErr $ map pprint vs ignoreHoistFailure :: HasCallStack => HoistExcept a -> a ignoreHoistFailure (HoistSuccess x) = x @@ -2845,10 +2845,11 @@ exchangeBs (PairB b1 b2) = partitionBinders :: forall b b1 b2 m n l - . (SinkableB b2, HoistableB b1, BindsNames b2, Fallible m, Distinct l) => Nest b n l + . (SinkableB b2, HoistableB b1, BindsNames b2, Fallible m, Distinct l) + => SrcId -> Nest b n l -> (forall n' l'. b n' l' -> m (EitherB b1 b2 n' l')) -> m (PairB (Nest b1) (Nest b2) n l) -partitionBinders bs assignBinder = go bs where +partitionBinders sid bs assignBinder = go bs where go :: Distinct l' => Nest b n' l' -> m (PairB (Nest b1) (Nest b2) n' l') go = \case Empty -> return $ PairB Empty Empty @@ -2859,7 +2860,7 @@ partitionBinders bs assignBinder = go bs where RightB b2 -> withSubscopeDistinct bs2 case exchangeBs (PairB b2 bs1) of HoistSuccess (PairB bs1' b2') -> return $ PairB bs1' (Nest b2' bs2) - HoistFailure vs -> throw $ EscapedNameErr $ map pprint vs + HoistFailure vs -> throw sid $ EscapedNameErr $ map pprint vs -- NameBinder has no free vars, so there's no risk associated with hoisting. -- The scope is completely distinct, so their exchange doesn't create any accidental diff --git a/src/lib/QueryType.hs b/src/lib/QueryType.hs index acbff02c7..6229a1e8d 100644 --- a/src/lib/QueryType.hs +++ b/src/lib/QueryType.hs @@ -28,11 +28,6 @@ import PPrint import QueryTypePure import CheapReduction -sourceNameType :: (EnvReader m, Fallible1 m) => SourceName -> m n (Type CoreIR n) -sourceNameType v = do - lookupSourceMap v >>= \case - Nothing -> throw $ UnboundVarErr $ pprint v - Just uvar -> getUVarType uvar -- === Exposed helpers for querying types and effects === diff --git a/src/lib/Runtime.hs b/src/lib/Runtime.hs index 102019098..011604d8e 100644 --- a/src/lib/Runtime.hs +++ b/src/lib/Runtime.hs @@ -72,7 +72,7 @@ checkedCallFunPtr :: FD -> Ptr () -> Ptr () -> DexExecutable -> IO () checkedCallFunPtr fd argsPtr resultPtr fPtr = do let (CInt fd') = fdFD fd exitCode <- callFunPtr fPtr fd' argsPtr resultPtr - unless (exitCode == 0) $ throw RuntimeErr + unless (exitCode == 0) $ throw rootSrcId RuntimeErr withPipeToLogger :: PassLogger -> (FD -> IO a) -> IO a withPipeToLogger logger writeAction = do diff --git a/src/lib/SourceRename.hs b/src/lib/SourceRename.hs index 1fcaa73d0..d5420dab7 100644 --- a/src/lib/SourceRename.hs +++ b/src/lib/SourceRename.hs @@ -98,20 +98,20 @@ class SourceRenamableB (b :: B) where -> m o a instance SourceRenamableE (SourceNameOr UVar) where - sourceRenameE (SourceName pos sourceName) = - InternalName pos sourceName <$> lookupSourceName sourceName + sourceRenameE (SourceName sid sourceName) = + InternalName sid sourceName <$> lookupSourceName sid sourceName sourceRenameE _ = error "Shouldn't be source-renaming internal names" -lookupSourceName :: Renamer m => SourceName -> m n (UVar n) -lookupSourceName v = do +lookupSourceName :: Renamer m => SrcId -> SourceName -> m n (UVar n) +lookupSourceName sid v = do sm <- askSourceMap case lookupSourceMapPure sm v of - [] -> throw $ UnboundVarErr $ pprint v + [] -> throw sid $ UnboundVarErr $ pprint v LocalVar v' : _ -> return v' [ModuleVar _ maybeV] -> case maybeV of Just v' -> return v' - Nothing -> throw $ VarDefErr $ pprint v - vs -> throw $ AmbiguousVarErr (pprint v) (map wherePretty vs) + Nothing -> throw sid $ VarDefErr $ pprint v + vs -> throw sid $ AmbiguousVarErr (pprint v) (map wherePretty vs) where wherePretty :: SourceNameDef n -> String wherePretty (ModuleVar mname _) = case mname of @@ -122,24 +122,24 @@ lookupSourceName v = do error "shouldn't be possible because module vars can't shadow local ones" instance SourceRenamableE (SourceNameOr (Name (AtomNameC CoreIR))) where - sourceRenameE (SourceName pos sourceName) = do - lookupSourceName sourceName >>= \case - UAtomVar v -> return $ InternalName pos sourceName v - _ -> throw $ NotAnOrdinaryVar $ pprint sourceName + sourceRenameE (SourceName sid sourceName) = do + lookupSourceName sid sourceName >>= \case + UAtomVar v -> return $ InternalName sid sourceName v + _ -> throw sid $ NotAnOrdinaryVar $ pprint sourceName sourceRenameE _ = error "Shouldn't be source-renaming internal names" instance SourceRenamableE (SourceNameOr (Name DataConNameC)) where - sourceRenameE (SourceName pos sourceName) = do - lookupSourceName sourceName >>= \case - UDataConVar v -> return $ InternalName pos sourceName v - _ -> throw $ NotADataCon $ pprint sourceName + sourceRenameE (SourceName sid sourceName) = do + lookupSourceName sid sourceName >>= \case + UDataConVar v -> return $ InternalName sid sourceName v + _ -> throw sid $ NotADataCon $ pprint sourceName sourceRenameE _ = error "Shouldn't be source-renaming internal names" instance SourceRenamableE (SourceNameOr (Name ClassNameC)) where - sourceRenameE (SourceName pos sourceName) = do - lookupSourceName sourceName >>= \case - UClassVar v -> return $ InternalName pos sourceName v - _ -> throw $ NotAClassName $ pprint sourceName + sourceRenameE (SourceName sid sourceName) = do + lookupSourceName sid sourceName >>= \case + UClassVar v -> return $ InternalName sid sourceName v + _ -> throw sid $ NotAClassName $ pprint sourceName sourceRenameE _ = error "Shouldn't be source-renaming internal names" instance SourceRenamableE (SourceNameOr (Name c)) => SourceRenamableE (SourceOrInternalName c) where @@ -148,8 +148,8 @@ instance SourceRenamableE (SourceNameOr (Name c)) => SourceRenamableE (SourceOrI instance (SourceRenamableE e, SourceRenamableB b) => SourceRenamableE (Abs b e) where sourceRenameE (Abs b e) = sourceRenameB b \b' -> Abs b' <$> sourceRenameE e -instance SourceRenamableB (UBinder' (AtomNameC CoreIR)) where - sourceRenameB b cont = sourceRenameUBinder' UAtomVar b cont +instance SourceRenamableB (UBinder (AtomNameC CoreIR)) where + sourceRenameB b cont = sourceRenameUBinder UAtomVar b cont instance SourceRenamableE UAnn where sourceRenameE UNoAnn = return UNoAnn @@ -161,8 +161,8 @@ instance SourceRenamableB UAnnBinder where cs' <- mapM sourceRenameE cs sourceRenameB b \b' -> cont $ UAnnBinder expl b' ann' cs' -instance SourceRenamableE UExpr' where - sourceRenameE expr = setMayShadow True case expr of +instance SourceRenamableE UExpr where + sourceRenameE (WithSrcE sid expr) = liftM (WithSrcE sid) $ setMayShadow True case expr of UVar v -> UVar <$> sourceRenameE v ULit l -> return $ ULit l ULam lam -> ULam <$> sourceRenameE lam @@ -211,12 +211,6 @@ instance SourceRenamableE UEffect where sourceRenameE UExceptionEffect = return UExceptionEffect sourceRenameE UIOEffect = return UIOEffect -instance SourceRenamableE a => SourceRenamableE (WithSrcE a) where - sourceRenameE (WithSrcE pos e) = WithSrcE pos <$> sourceRenameE e - -instance SourceRenamableB a => SourceRenamableB (WithSrcB a) where - sourceRenameB (WithSrcB pos b) cont = sourceRenameB b \b' -> cont $ WithSrcB pos b' - instance SourceRenamableB UTopDecl where sourceRenameB decl cont = case decl of ULocalDecl d -> sourceRenameB d \d' -> cont $ ULocalDecl d' @@ -244,15 +238,17 @@ instance SourceRenamableB UTopDecl where sourceRenameB instanceName \instanceName' -> cont $ UInstance className' conditions' params' methodDefs' instanceName' expl -instance SourceRenamableB UDecl' where - sourceRenameB decl cont = case decl of +instance SourceRenamableB UDecl where + sourceRenameB (WithSrcB sid decl) cont = case decl of ULet ann pat ty expr -> do expr' <- sourceRenameE expr ty' <- mapM sourceRenameE ty sourceRenameB pat \pat' -> - cont $ ULet ann pat' ty' expr' - UExprDecl e -> cont =<< (UExprDecl <$> sourceRenameE e) - UPass -> cont UPass + cont $ WithSrcB sid $ ULet ann pat' ty' expr' + UExprDecl e -> do + e' <- UExprDecl <$> sourceRenameE e + cont $ WithSrcB sid e' + UPass -> cont $ WithSrcB sid UPass instance SourceRenamableE ULamExpr where sourceRenameE (ULamExpr args expl effs resultTy body) = @@ -262,11 +258,11 @@ instance SourceRenamableE ULamExpr where <*> mapM sourceRenameE resultTy <*> sourceRenameE body -instance SourceRenamableE UBlock' where - sourceRenameE (UBlock decls result) = +instance SourceRenamableE UBlock where + sourceRenameE (WithSrcE sid (UBlock decls result)) = sourceRenameB decls \decls' -> do result' <- sourceRenameE result - return $ UBlock decls' result' + return $ WithSrcE sid $ UBlock decls' result' instance SourceRenamableB UnitB where sourceRenameB UnitB cont = cont UnitB @@ -294,31 +290,24 @@ sourceRenameUBinderNest asUVar (Nest b bs) cont = sourceRenameUBinderNest asUVar bs \bs' -> cont $ Nest b' bs' -sourceRenameUBinder' :: (Color c, Distinct o, Renamer m) - => (forall l. Name c l -> UVar l) - -> UBinder' c i i' - -> (forall o'. DExt o o' => UBinder' c o o' -> m o' a) - -> m o a -sourceRenameUBinder' asUVar ubinder cont = case ubinder of +sourceRenameUBinder + :: (Color c, Distinct o, Renamer m) + => (forall l. Name c l -> UVar l) + -> UBinder c i i' + -> (forall o'. DExt o o' => UBinder c o o' -> m o' a) + -> m o a +sourceRenameUBinder asUVar (WithSrcB sid ubinder) cont = case ubinder of UBindSource b -> do SourceMap sm <- askSourceMap mayShadow <- askMayShadow let shadows = M.member b sm - when (not mayShadow && shadows) $ throw (RepeatedVarErr $ pprint b) + when (not mayShadow && shadows) $ throw sid $ RepeatedVarErr $ pprint b withFreshM (getNameHint b) \freshName -> do Distinct <- getDistinct extendSourceMap b (asUVar $ binderName freshName) $ - cont $ UBind b freshName + cont $ WithSrcB sid $ UBind b freshName UBind _ _ -> error "Shouldn't be source-renaming internal names" - UIgnore -> cont $ UIgnore - -sourceRenameUBinder :: (Color c, Distinct o, Renamer m) - => (forall l. Name c l -> UVar l) - -> UBinder c i i' - -> (forall o'. DExt o o' => UBinder c o o' -> m o' a) - -> m o a -sourceRenameUBinder asUVar (WithSrcB sid ubinder) cont = - sourceRenameUBinder' asUVar ubinder \ubinder' -> cont (WithSrcB sid ubinder') + UIgnore -> cont $ WithSrcB sid $ UIgnore instance SourceRenamableE UDataDef where sourceRenameE (UDataDef tyConName paramBs dataCons) = do @@ -356,11 +345,11 @@ instance SourceRenamableE e => SourceRenamableE (ListE e) where instance SourceRenamableE UnitE where sourceRenameE UnitE = return UnitE -instance SourceRenamableE UMethodDef' where - sourceRenameE (UMethodDef ~(SourceName pos v) expr) = do - lookupSourceName v >>= \case - UMethodVar v' -> UMethodDef (InternalName pos v v') <$> sourceRenameE expr - _ -> throw $ NotAMethodName $ pprint v +instance SourceRenamableE UMethodDef where + sourceRenameE (WithSrcE sid ((UMethodDef ~(SourceName vSid v) expr))) = WithSrcE sid <$> do + lookupSourceName vSid v >>= \case + UMethodVar v' -> UMethodDef (InternalName vSid v v') <$> sourceRenameE expr + _ -> throw vSid $ NotAMethodName $ pprint v instance SourceRenamableB b => SourceRenamableB (Nest b) where sourceRenameB (Nest b bs) cont = @@ -383,32 +372,33 @@ class SourceRenamablePat (pat::B) where -> (forall o'. DExt o o' => SiblingSet -> pat o o' -> m o' a) -> m o a -instance SourceRenamablePat (UBinder' (AtomNameC CoreIR)) where - sourceRenamePat sibs ubinder cont = do +instance SourceRenamablePat (UBinder (AtomNameC CoreIR)) where + sourceRenamePat sibs (WithSrcB sid ubinder) cont = do newSibs <- case ubinder of UBindSource b -> do - when (S.member b sibs) $ throw $ RepeatedPatVarErr $ pprint b + when (S.member b sibs) $ throw sid $ RepeatedPatVarErr $ pprint b return $ S.singleton b UIgnore -> return mempty UBind _ _ -> error "Shouldn't be source-renaming internal names" - sourceRenameB ubinder \ubinder' -> + sourceRenameB (WithSrcB sid ubinder) \ubinder' -> cont (sibs <> newSibs) ubinder' -instance SourceRenamablePat UPat' where - sourceRenamePat sibs pat cont = case pat of - UPatBinder b -> sourceRenamePat sibs b \sibs' b' -> cont sibs' $ UPatBinder b' +instance SourceRenamablePat UPat where + sourceRenamePat sibs (WithSrcB sid pat) cont = case pat of + UPatBinder b -> sourceRenamePat sibs b \sibs' b' -> + cont sibs' $ WithSrcB sid $ UPatBinder b' UPatCon con bs -> do -- TODO Deduplicate this against the code for sourceRenameE of -- the SourceName case of SourceNameOr con' <- sourceRenameE con sourceRenamePat sibs bs \sibs' bs' -> - cont sibs' $ UPatCon con' bs' + cont sibs' $ WithSrcB sid $ UPatCon con' bs' UPatDepPair (PairB p1 p2) -> sourceRenamePat sibs p1 \sibs' p1' -> sourceRenamePat sibs' p2 \sibs'' p2' -> - cont sibs'' $ UPatDepPair $ PairB p1' p2' - UPatProd bs -> sourceRenamePat sibs bs \sibs' bs' -> cont sibs' $ UPatProd bs' - UPatTable ps -> sourceRenamePat sibs ps \sibs' ps' -> cont sibs' $ UPatTable ps' + cont sibs'' $ WithSrcB sid $ UPatDepPair $ PairB p1' p2' + UPatProd bs -> sourceRenamePat sibs bs \sibs' bs' -> cont sibs' $ WithSrcB sid $ UPatProd bs' + UPatTable ps -> sourceRenamePat sibs ps \sibs' ps' -> cont sibs' $ WithSrcB sid $ UPatTable ps' instance SourceRenamablePat UnitB where sourceRenamePat sibs UnitB cont = cont sibs UnitB @@ -429,11 +419,6 @@ instance (SourceRenamablePat p1, SourceRenamablePat p2) sourceRenamePat sibs p \sibs' p' -> cont sibs' $ RightB p' -instance SourceRenamablePat p => SourceRenamablePat (WithSrcB p) where - sourceRenamePat sibs (WithSrcB pos pat) cont = do - sourceRenamePat sibs pat \sibs' pat' -> - cont sibs' $ WithSrcB pos pat' - instance SourceRenamablePat p => SourceRenamablePat (Nest p) where sourceRenamePat sibs (Nest b bs) cont = sourceRenamePat sibs b \sibs' b' -> @@ -441,7 +426,7 @@ instance SourceRenamablePat p => SourceRenamablePat (Nest p) where cont sibs'' $ Nest b' bs' sourceRenamePat sibs Empty cont = cont sibs Empty -instance SourceRenamableB UPat' where +instance SourceRenamableB UPat where sourceRenameB pat cont = sourceRenamePat mempty pat \_ pat' -> cont pat' diff --git a/src/lib/TopLevel.hs b/src/lib/TopLevel.hs index 47b132c37..5f0f84eae 100644 --- a/src/lib/TopLevel.hs +++ b/src/lib/TopLevel.hs @@ -263,7 +263,7 @@ evalSourceBlock' mname block = case sbContents block of DeclareForeign fname (WithSrc _ dexName) cTy -> do ty <- evalUType =<< parseExpr cTy asFFIFunType ty >>= \case - Nothing -> throw $ MiscMiscErr + Nothing -> throw rootSrcId $ MiscMiscErr "FFI functions must be n-ary first order functions with the IO effect" Just (impFunTy, naryPiTy) -> do -- TODO: query linking stuff and check the function is actually available @@ -275,15 +275,15 @@ evalSourceBlock' mname block = case sbContents block of DeclareCustomLinearization fname zeros g -> do expr <- parseExpr g lookupSourceMap (withoutSrc fname) >>= \case - Nothing -> throw $ UnboundVarErr $ pprint fname + Nothing -> throw rootSrcId $ UnboundVarErr $ pprint fname Just (UAtomVar fname') -> do lookupCustomRules fname' >>= \case Nothing -> return () - Just _ -> throw $ MiscMiscErr + Just _ -> throw rootSrcId $ MiscMiscErr $ pprint fname ++ " already has a custom linearization" lookupAtomName fname' >>= \case NoinlineFun _ _ -> return () - _ -> throw $ MiscMiscErr "Custom linearizations only apply to @noinline functions" + _ -> throw rootSrcId $ MiscMiscErr "Custom linearizations only apply to @noinline functions" -- We do some special casing to avoid instantiating polymorphic functions. impl <- case expr of WithSrcE _ (UVar _) -> @@ -296,17 +296,20 @@ evalSourceBlock' mname block = case sbContents block of liftEnvReaderT (impl `checkTypeIs` linFunTy) >>= \case Failure _ -> do let implTy = getType impl - throw $ MiscMiscErr $ unlines + throw rootSrcId $ MiscMiscErr $ unlines [ "Expected the custom linearization to have type:" , "" , pprint linFunTy , "" , "but it has type:" , "" , pprint implTy] Success () -> return () updateTopEnv $ AddCustomRule fname' $ CustomLinearize nimplicit nexplicit zeros impl - Just _ -> throw $ MiscMiscErr $ "Custom linearization can only be defined for functions" - UnParseable _ s -> throw $ MiscParseErr s + Just _ -> throw rootSrcId $ MiscMiscErr $ "Custom linearization can only be defined for functions" + UnParseable _ s -> throw rootSrcId $ MiscParseErr s Misc m -> case m of GetNameType v -> do - ty <- sourceNameType (withoutSrc v) - logTop $ TextOut $ pprintCanonicalized ty + lookupSourceMap (withoutSrc v) >>= \case + Nothing -> throw rootSrcId $ UnboundVarErr $ pprint v + Just uvar -> do + ty <- getUVarType uvar + logTop $ TextOut $ pprintCanonicalized ty ImportModule moduleName -> importModule moduleName QueryEnv query -> void $ runEnvQuery query $> UnitE ProseBlock _ -> return () @@ -327,11 +330,11 @@ runEnvQuery query = do DumpSubst -> logTop $ TextOut $ pprint $ env InternalNameInfo name -> case lookupSubstFragRaw (fromRecSubst $ envDefs $ topEnv env) name of - Nothing -> throw $ UnboundVarErr $ pprint name + Nothing -> throw rootSrcId $ UnboundVarErr $ pprint name Just binding -> logTop $ TextOut $ pprint binding SourceNameInfo name -> do lookupSourceMap name >>= \case - Nothing -> throw $ UnboundVarErr $ pprint name + Nothing -> throw rootSrcId $ UnboundVarErr $ pprint name Just uvar -> do logTop $ TextOut $ pprint uvar info <- case uvar of @@ -434,7 +437,7 @@ evalUModule (UModule name _ blocks) = do importModule :: (Mut n, TopBuilder m, Fallible1 m) => ModuleSourceName -> m n () importModule name = do lookupLoadedModule name >>= \case - Nothing -> throw $ ModuleImportErr $ pprint name + Nothing -> throw rootSrcId $ ModuleImportErr $ pprint name Just name' -> do Module _ _ transImports' _ _ <- lookupModule name' let importStatus = ImportStatus (S.singleton name') @@ -693,7 +696,7 @@ loadModuleSource config moduleName = do fsPaths <- liftIO $ traverse resolveBuiltinPath $ libPaths config liftIO (findFile fsPaths fname) >>= \case Just fpath -> return fpath - Nothing -> throw $ CantFindModuleSource $ pprint moduleName + Nothing -> throw rootSrcId $ CantFindModuleSource $ pprint moduleName resolveBuiltinPath = \case LibBuiltinPath -> liftIO $ getDataFileName "lib" LibDirectory dir -> return dir @@ -832,14 +835,14 @@ getLinearizationType zeros = \case Just tty -> case zeros of InstantiateZeros -> return tty SymbolicZeros -> symbolicTangentTy tty - Nothing -> throw $ MiscMiscErr $ "No tangent type for: " ++ pprint t + Nothing -> throw rootSrcId $ MiscMiscErr $ "No tangent type for: " ++ pprint t resultTanTy <- maybeTangentType resultTy' >>= \case Just rtt -> return rtt - Nothing -> throw $ MiscMiscErr $ "No tangent type for: " ++ pprint resultTy' + Nothing -> throw rootSrcId $ MiscMiscErr $ "No tangent type for: " ++ pprint resultTy' let tanFunTy = toType $ Pi $ nonDepPiType argTanTys Pure resultTanTy let fullTy = CorePiType ExplicitApp expls bs' $ EffTy Pure (PairTy resultTy' tanFunTy) return (numIs, numEs, toType $ Pi fullTy) - _ -> throw $ MiscMiscErr $ "Can't define a custom linearization for implicit or impure functions" + _ -> throw rootSrcId $ MiscMiscErr $ "Can't define a custom linearization for implicit or impure functions" where getNumImplicits :: Fallible m => [Explicitness] -> m (Int, Int) getNumImplicits = \case @@ -850,4 +853,4 @@ getLinearizationType zeros = \case Inferred _ _ -> return (ni + 1, ne) Explicit -> case ni of 0 -> return (0, ne + 1) - _ -> throw $ MiscMiscErr "All implicit args must precede implicit args" + _ -> throw rootSrcId $ MiscMiscErr "All implicit args must precede implicit args" diff --git a/src/lib/Types/Source.hs b/src/lib/Types/Source.hs index 87a6e676a..b20a0e09a 100644 --- a/src/lib/Types/Source.hs +++ b/src/lib/Types/Source.hs @@ -480,6 +480,14 @@ data WithSrcE (a::E) (n::S) = WithSrcE SrcId (a n) data WithSrcB (binder::B) (n::S) (l::S) = WithSrcB SrcId (binder n l) deriving (Show, Generic) +instance HasSrcId (WithSrc a ) where getSrcId (WithSrc sid _ ) = sid +instance HasSrcId (WithSrcs a ) where getSrcId (WithSrcs sid _ _) = sid +instance HasSrcId (WithSrcE e n ) where getSrcId (WithSrcE sid _ ) = sid +instance HasSrcId (WithSrcB b n l) where getSrcId (WithSrcB sid _ ) = sid + +instance HasSrcId (UAnnBinder n l) where + getSrcId (UAnnBinder _ b _ _) = getSrcId b + class HasSrcPos withSrc a | withSrc -> a where srcPos :: withSrc -> SrcId withoutSrc :: withSrc -> a From e551ed0a7e20fac8496a5bbfaa0e7c18617a5089 Mon Sep 17 00:00:00 2001 From: Dougal Date: Tue, 5 Dec 2023 10:38:56 -0500 Subject: [PATCH 38/41] Fix bug in applying highlighting updates --- static/index.js | 32 ++++++++++++++++++-------------- static/style.css | 2 +- 2 files changed, 19 insertions(+), 15 deletions(-) diff --git a/static/index.js b/static/index.js index 4a9044db1..3be1aea57 100644 --- a/static/index.js +++ b/static/index.js @@ -162,31 +162,36 @@ function spansBetween(l, r) { return spans } function setCellStatus(cell, status) { - cell.className = "class" + cell.className = "cell" cell.classList.add(getStatusClass(status)) } - -function setCellContents(cellId, cell, contents) { +function addChild(cell, className, innerHTML) { + let child = document.createElement("div") + child.innerHTML = innerHTML + child.className = className + cell.appendChild(child) +} +function initializeCellContents(cellId, cell, contents) { let [source, status, result] = contents; let lineNum = source["rsbLine"]; let sourceText = source["rsbHtml"]; - let lineNumDiv = document.createElement("div"); - lineNumDiv.innerHTML = lineNum.toString(); - lineNumDiv.className = "line-num"; - cell.innerHTML = "" - cell.appendChild(lineNumDiv) + highlightMap[cellId] = {}; + hoverInfoMap[cellId] = {}; + addChild(cell, "line-num" , lineNum.toString()) + addChild(cell, "code-block" , sourceText) + addChild(cell, "cell-results", "") setCellStatus(cell, status) - cell.innerHTML += sourceText renderLaTeX(cell) extendCellResult(cellId, cell, result) } function extendCellResult(cellId, cell, result) { let resultText = result["rrHtml"] if (resultText !== "") { - cell.innerHTML += resultText + let bodyDiv = cell.querySelector(".cell-results") + bodyDiv.innerHTML += resultText } - highlightMap[cellId] = result["rrHighlightMap"] - hoverInfoMap[cellId] = result["rrHoverInfoMap"] + Object.assign(highlightMap[cellId], result["rrHighlightMap"]) + Object.assign(hoverInfoMap[cellId], result["rrHoverInfoMap"]) } function updateCellContents(cellId, cell, contents) { let [statusUpdate, result] = contents; @@ -198,7 +203,6 @@ function processUpdate(msg) { let cellUpdates = msg["nodeMapUpdate"]["mapUpdates"]; let numDropped = msg["orderedNodesUpdate"]["numDropped"]; let newTail = msg["orderedNodesUpdate"]["newTail"]; - // drop_dead_cells for (i = 0; i < numDropped; i++) { body.lastElementChild.remove();} @@ -210,7 +214,7 @@ function processUpdate(msg) { if (tag == "Create" || tag == "Replace") { let cell = document.createElement("div"); cells[cellId] = cell; - setCellContents(cellId, cell, contents) + initializeCellContents(cellId, cell, contents) } else if (tag == "Update") { let cell = cells[cellId]; updateCellContents(cellId, cell, contents); diff --git a/static/style.css b/static/style.css index e9f78a6d9..3d8de7c67 100644 --- a/static/style.css +++ b/static/style.css @@ -33,7 +33,7 @@ body { .code-block { } -.code-block, .err-block, .result-block { +.code-block, .cell-results, .err-block, .result-block { margin: 0em 0em 0em 4em; padding: 0em 0em 0em 2em; display: block; From 534be19547572e6fcfda86838d03dc6f71c0709a Mon Sep 17 00:00:00 2001 From: Dougal Date: Tue, 5 Dec 2023 17:01:16 -0500 Subject: [PATCH 39/41] Highlight error source locations --- src/lib/Actor.hs | 34 ++++++++++++++++-- src/lib/Err.hs | 29 +++++++--------- src/lib/Export.hs | 12 +++---- src/lib/IncState.hs | 8 ++++- src/lib/Inference.hs | 12 +++---- src/lib/Lexing.hs | 2 +- src/lib/Live/Eval.hs | 2 +- src/lib/Live/Web.hs | 24 ++++++++----- src/lib/RenderHtml.hs | 72 +++++++++++++++++++++++++++----------- src/lib/Runtime.hs | 2 +- src/lib/TopLevel.hs | 24 ++++++------- static/index.js | 81 +++++++++++++++++++++++-------------------- static/style.css | 5 ++- 13 files changed, 195 insertions(+), 112 deletions(-) diff --git a/src/lib/Actor.hs b/src/lib/Actor.hs index 1da61268e..59ff089a5 100644 --- a/src/lib/Actor.hs +++ b/src/lib/Actor.hs @@ -10,12 +10,12 @@ module Actor ( ActorM, Actor (..), launchActor, send, selfMailbox, messageLoop, sliceMailbox, SubscribeMsg (..), IncServer, IncServerT, FileWatcher, StateServer, flushDiffs, handleSubscribeMsg, subscribe, subscribeIO, sendSync, - runIncServerT, launchFileWatcher, Mailbox + runIncServerT, launchFileWatcher, Mailbox, launchIncFunctionEvaluator ) where import Control.Concurrent import Control.Monad -import Control.Monad.State.Strict hiding (get) +import Control.Monad.State.Strict import Control.Monad.Reader import qualified Data.ByteString as BS import Data.IORef @@ -162,6 +162,36 @@ runIncServerT s cont = do ref <- newRef $ IncServerState [] mempty s runReaderT (runIncServerT' cont) ref +-- === Incremental function server === + +-- If you just need something that computes a function incrementally and doesn't +-- need to maintain any other state then this will do. + +data IncFunctionEvaluatorMsg da b db = + Subscribe_IFEM (SubscribeMsg b db) + | Update_IFEM da + deriving (Show) + +launchIncFunctionEvaluator + :: (IncState b db, Show da, MonadIO m) + => StateServer a da + -> (a -> (b,s)) + -> (b -> s -> da -> (db, s)) + -> m (StateServer b db) +launchIncFunctionEvaluator server fInit fUpdate = + sliceMailbox Subscribe_IFEM <$> launchActor do + x0 <- subscribe Update_IFEM server + let (y0, s0) = fInit x0 + flip evalStateT s0 $ runIncServerT y0 $ messageLoop \case + Subscribe_IFEM msg -> handleSubscribeMsg msg + Update_IFEM dx -> do + y <- getl It + s <- lift get + let (dy, s') = fUpdate y s dx + lift $ put s' + update dy + flushDiffs + -- === Refs === -- Just a wrapper around IORef lifted to `MonadIO` diff --git a/src/lib/Err.hs b/src/lib/Err.hs index 8a7037d22..7405cfbf9 100644 --- a/src/lib/Err.hs +++ b/src/lib/Err.hs @@ -52,11 +52,11 @@ data Err = SearchFailure String -- used as the identity for `Alternative` instances and for MonadFail. | InternalErr String | ParseErr ParseErr - | SyntaxErr SyntaxErr - | NameErr NameErr - | TypeErr TypeErr + | SyntaxErr SrcId SyntaxErr + | NameErr SrcId NameErr + | TypeErr SrcId TypeErr | RuntimeErr - | MiscErr MiscErr + | MiscErr MiscErr deriving (Show, Eq) type MsgStr = String @@ -161,14 +161,11 @@ data InfVarDesc = -- === ToErr class === class ToErr a where - toErr :: a -> Err + toErr :: SrcId -> a -> Err -instance ToErr Err where toErr = id -instance ToErr ParseErr where toErr = ParseErr instance ToErr SyntaxErr where toErr = SyntaxErr instance ToErr NameErr where toErr = NameErr instance ToErr TypeErr where toErr = TypeErr -instance ToErr MiscErr where toErr = MiscErr -- === Error messages === @@ -180,12 +177,12 @@ instance PrintableErr Err where SearchFailure s -> "Internal search failure: " ++ s InternalErr s -> "Internal compiler error: " ++ s ++ "\n" ++ "Please report this at github.com/google-research/dex-lang/issues\n" - ParseErr e -> "Parse error: " ++ printErr e - SyntaxErr e -> "Syntax error: " ++ printErr e - NameErr e -> "Name error: " ++ printErr e - TypeErr e -> "Type error: " ++ printErr e - MiscErr e -> "Error: " ++ printErr e - RuntimeErr -> "Runtime error" + ParseErr e -> "Parse error: " ++ printErr e + SyntaxErr _ e -> "Syntax error: " ++ printErr e + NameErr _ e -> "Name error: " ++ printErr e + TypeErr _ e -> "Type error: " ++ printErr e + MiscErr e -> "Error: " ++ printErr e + RuntimeErr -> "Runtime error" instance PrintableErr ParseErr where printErr = \case @@ -257,7 +254,7 @@ instance PrintableErr TypeErr where PatternArityErr n1 n2 -> "unexpected number of pattern binders. Expected " ++ show n1 ++ " but got " ++ show n2 SumTypeCantFail -> "sum type constructor in can't-fail pattern" PatTypeErr patTy rhsTy -> "pattern is for a " ++ patTy ++ "but we're matching against a " ++ rhsTy - EliminationErr expected ty -> "expected a " ++ expected ++ ". Got a: " ++ ty + EliminationErr expected ty -> "expected a " ++ expected ++ ". Got: " ++ ty IllFormedCasePattern -> "case patterns must start with a data constructor or variant pattern" NotAMethod method className -> "unexpected method: " ++ method ++ " is not a method of " ++ className DuplicateMethod method -> "duplicate method: " ++ method @@ -468,7 +465,7 @@ instance Fallible HardFailM where -- === convenience layer === throw :: (ToErr e, Fallible m) => SrcId -> e -> m a -throw _ e = throwErr $ toErr e +throw sid e = throwErr $ toErr sid e {-# INLINE throw #-} getCurrentCallStack :: () -> Maybe [String] diff --git a/src/lib/Export.hs b/src/lib/Export.hs index 67e356f6c..466afbb5d 100644 --- a/src/lib/Export.hs +++ b/src/lib/Export.hs @@ -48,11 +48,11 @@ prepareFunctionForExport :: (Mut n, Topper m) prepareFunctionForExport cc f = do naryPi <- case getType f of TyCon (Pi piTy) -> return piTy - _ -> throw rootSrcId $ MiscMiscErr "Only first-order functions can be exported" + _ -> throwErr $ MiscErr $ MiscMiscErr "Only first-order functions can be exported" sig <- liftExportSigM $ corePiToExportSig cc naryPi closedSig <- case hoistToTop sig of HoistFailure _ -> - throw rootSrcId $ MiscMiscErr $ "Types of exported functions have to be closed terms. Got: " ++ pprint naryPi + throwErr $ MiscErr $ MiscMiscErr $ "Types of exported functions have to be closed terms. Got: " ++ pprint naryPi HoistSuccess s -> return s f' <- liftBuilder $ buildCoreLam naryPi \xs -> naryApp (sink f) (toAtom <$> xs) fSimp <- simplifyTopFunction $ coreLamToTopLam f' @@ -68,7 +68,7 @@ prepareSLamForExport cc f@(TopLam _ naryPi _) = do sig <- liftExportSigM $ simpPiToExportSig cc naryPi closedSig <- case hoistToTop sig of HoistFailure _ -> - throw rootSrcId $ MiscMiscErr $ "Types of exported functions have to be closed terms. Got: " ++ pprint naryPi + throwErr $ MiscErr $ MiscMiscErr $ "Types of exported functions have to be closed terms. Got: " ++ pprint naryPi HoistSuccess s -> return s fImp <- compileTopLevelFun cc f nativeFun <- toCFunction "userFunc" fImp >>= emitObjFile >>= loadObject @@ -105,7 +105,7 @@ corePiToExportSig :: CallingConvention corePiToExportSig cc (CorePiType _ expls tbs (EffTy effs resultTy)) = do case effs of Pure -> return () - _ -> throw rootSrcId $ MiscMiscErr "Only pure functions can be exported" + _ -> throwErr $ MiscErr $ MiscMiscErr "Only pure functions can be exported" goArgs cc Empty [] (zipAttrs expls tbs) resultTy simpPiToExportSig :: CallingConvention @@ -113,7 +113,7 @@ simpPiToExportSig :: CallingConvention simpPiToExportSig cc (PiType bs (EffTy effs resultTy)) = do case effs of Pure -> return () - _ -> throw rootSrcId $ MiscMiscErr "Only pure functions can be exported" + _ -> throwErr $ MiscErr $ MiscMiscErr "Only pure functions can be exported" bs' <- return $ fmapNest (\b -> WithAttrB Explicit b) bs goArgs cc Empty [] bs' resultTy @@ -164,7 +164,7 @@ toExportType ty = case ty of Nothing -> unsupported Just ety -> return ety _ -> unsupported - where unsupported = throw rootSrcId $ MiscMiscErr $ "Unsupported type of argument in exported function: " ++ pprint ty + where unsupported = throwErr $ MiscErr $ MiscMiscErr $ "Unsupported type of argument in exported function: " ++ pprint ty {-# INLINE toExportType #-} parseTabTy :: IRRep r => Type r i -> ExportSigM r i o (Maybe (ExportType o)) diff --git a/src/lib/IncState.hs b/src/lib/IncState.hs index b5eb01b24..e825bf1e3 100644 --- a/src/lib/IncState.hs +++ b/src/lib/IncState.hs @@ -9,7 +9,7 @@ module IncState ( IncState (..), MapEltUpdate (..), MapUpdate (..), Overwrite (..), TailUpdate (..), Unchanging (..), Overwritable (..), - mapUpdateMapWithKey) where + mapUpdateMapWithKey, MonoidState (..)) where import Data.Aeson (ToJSON, ToJSONKey) import qualified Data.Map.Strict as M @@ -122,6 +122,12 @@ instance IncState (Overwritable a) (Overwrite a) where NoChange -> s OverwriteWith s' -> Overwritable s' +-- Case when the diff and the state are the same +newtype MonoidState a = MonoidState a + +instance Monoid a => IncState (MonoidState a) a where + applyDiff (MonoidState d) d' = MonoidState $ d <> d' + -- Trivial diff that works for any type - just replace the old value with a completely new one. newtype Unchanging a = Unchanging { fromUnchanging :: a } deriving (Show, Eq, Ord) diff --git a/src/lib/Inference.hs b/src/lib/Inference.hs index cc10b1718..8f6e24fb9 100644 --- a/src/lib/Inference.hs +++ b/src/lib/Inference.hs @@ -2087,7 +2087,7 @@ trySynthTerm sid ty reqMethodAccess = do hasInferenceVars ty >>= \case True -> throw sid $ CantSynthInfVars $ pprint ty False -> withVoidSubst do - synthTy <- liftExcept $ typeAsSynthType ty + synthTy <- liftExcept $ typeAsSynthType sid ty synthTerm sid synthTy reqMethodAccess <|> (throw sid $ CantSynthDict $ pprint ty) {-# SCC trySynthTerm #-} @@ -2126,15 +2126,15 @@ extendGivens newGivens cont = do {-# INLINE extendGivens #-} getSynthType :: SynthAtom n -> SynthType n -getSynthType x = ignoreExcept $ typeAsSynthType (getType x) +getSynthType x = ignoreExcept $ typeAsSynthType rootSrcId (getType x) {-# INLINE getSynthType #-} -typeAsSynthType :: CType n -> Except (SynthType n) -typeAsSynthType = \case +typeAsSynthType :: SrcId -> CType n -> Except (SynthType n) +typeAsSynthType sid = \case TyCon (DictTy dictTy) -> return $ SynthDictType dictTy TyCon (Pi (CorePiType ImplicitApp expls bs (EffTy Pure (TyCon (DictTy d))))) -> return $ SynthPiType (expls, Abs bs d) - ty -> Failure $ toErr $ NotASynthType $ pprint ty + ty -> Failure $ toErr sid $ NotASynthType $ pprint ty {-# SCC typeAsSynthType #-} getSuperclassClosure :: EnvReader m => Givens n -> [SynthAtom n] -> m n (Givens n) @@ -2259,7 +2259,7 @@ emptyMixedArgs = ([], []) typeErrAsSearchFailure :: InfererM i n a -> InfererM i n a typeErrAsSearchFailure cont = cont `catchErr` \case - TypeErr _ -> empty + TypeErr _ _ -> empty e -> throwErr e synthDictForData :: forall i n. SrcId -> DictType n -> InfererM i n (SynthAtom n) diff --git a/src/lib/Lexing.hs b/src/lib/Lexing.hs index e8ad7c7f3..462902825 100644 --- a/src/lib/Lexing.hs +++ b/src/lib/Lexing.hs @@ -44,7 +44,7 @@ type Parser = StateT ParseCtx (Parsec Void Text) parseit :: Text -> Parser a -> Except a parseit s p = case parse (fst <$> runStateT p initParseCtx) "" s of - Left e -> throw rootSrcId $ MiscParseErr $ errorBundlePretty e + Left e -> throwErr $ ParseErr $ MiscParseErr $ errorBundlePretty e Right x -> return x mustParseit :: Text -> Parser a -> a diff --git a/src/lib/Live/Eval.hs b/src/lib/Live/Eval.hs index 89fcd3e3b..97f99761c 100644 --- a/src/lib/Live/Eval.hs +++ b/src/lib/Live/Eval.hs @@ -7,7 +7,7 @@ {-# LANGUAGE UndecidableInstances #-} module Live.Eval ( - watchAndEvalFile, EvalServer, EvalUpdate, CellsUpdate, fmapCellsUpdate, + watchAndEvalFile, EvalServer, EvalUpdate, CellsState, CellsUpdate, fmapCellsUpdate, NodeList (..), NodeListUpdate (..), subscribeIO, nodeListAsUpdate) where import Control.Concurrent diff --git a/src/lib/Live/Web.hs b/src/lib/Live/Web.hs index 4e23d805c..0f5739a8e 100644 --- a/src/lib/Live/Web.hs +++ b/src/lib/Live/Web.hs @@ -22,16 +22,18 @@ import qualified Data.ByteString as BS import Live.Eval import RenderHtml +import IncState +import Actor import TopLevel import Types.Source runWeb :: FilePath -> EvalConfig -> TopStateEx -> IO () runWeb fname opts env = do - resultsChan <- watchAndEvalFile fname opts env + resultsChan <- watchAndEvalFile fname opts env >>= renderResults putStrLn "Streaming output to http://localhost:8000/" run 8000 $ serveResults resultsChan -serveResults :: EvalServer -> Application +serveResults :: RenderedResultsServer -> Application serveResults resultsSubscribe request respond = do print (pathInfo request) case pathInfo request of @@ -50,14 +52,15 @@ serveResults resultsSubscribe request respond = do -- fname <- getDataFileName dataFname respond $ responseFile status200 [("Content-Type", ctype)] fname Nothing -resultStream :: EvalServer -> StreamingBody +type RenderedResultsServer = StateServer (MonoidState RenderedResults) RenderedResults +type RenderedResults = CellsUpdate RenderedSourceBlock RenderedOutputs + +resultStream :: RenderedResultsServer -> StreamingBody resultStream resultsServer write flush = do sendUpdate ("start"::String) - (initResult, resultsChan) <- subscribeIO resultsServer - sendUpdate $ renderEvalUpdate $ nodeListAsUpdate initResult - forever do - nextUpdate <- readChan resultsChan - sendUpdate $ renderEvalUpdate nextUpdate + (MonoidState initResult, resultsChan) <- subscribeIO resultsServer + sendUpdate initResult + forever $ readChan resultsChan >>= sendUpdate where sendUpdate :: ToJSON a => a -> IO () sendUpdate x = write (fromByteString $ encodePacket x) >> flush @@ -66,6 +69,11 @@ encodePacket :: ToJSON a => a -> BS.ByteString encodePacket = toStrict . wrap . encode where wrap s = "data:" <> s <> "\n\n" +renderResults :: EvalServer -> IO RenderedResultsServer +renderResults evalServer = launchIncFunctionEvaluator evalServer + (\x -> (MonoidState $ renderEvalUpdate $ nodeListAsUpdate x, ())) + (\_ () dx -> (renderEvalUpdate dx, ())) + renderEvalUpdate :: CellsUpdate SourceBlock Outputs -> CellsUpdate RenderedSourceBlock RenderedOutputs renderEvalUpdate cellsUpdate = fmapCellsUpdate cellsUpdate (\k b -> renderSourceBlock k b) diff --git a/src/lib/RenderHtml.hs b/src/lib/RenderHtml.hs index e4015c6a1..da8e8a28c 100644 --- a/src/lib/RenderHtml.hs +++ b/src/lib/RenderHtml.hs @@ -19,6 +19,7 @@ import Data.Aeson (ToJSON) import qualified Data.Map.Strict as M import Control.Monad.State.Strict import Control.Monad.Writer.Strict +import Data.Foldable (fold) import Data.Functor ((<&>)) import Data.Maybe (fromJust) import Data.String (fromString) @@ -29,10 +30,11 @@ import System.IO.Unsafe import GHC.Generics import Err +import IncState import Paths_dex (getDataFileName) import PPrint import Types.Source -import Util (unsnoc, foldJusts) +import Util (unsnoc) -- === rendering results === @@ -50,15 +52,22 @@ data RenderedSourceBlock = RenderedSourceBlock data RenderedOutputs = RenderedOutputs { rrHtml :: String + , rrLexemeSpans :: SpanMap , rrHighlightMap :: HighlightMap - , rrHoverInfoMap :: HoverInfoMap } + , rrHoverInfoMap :: HoverInfoMap + , rrErrorSrcIds :: [SrcId] } deriving (Generic) renderOutputs :: Outputs -> RenderedOutputs -renderOutputs r = RenderedOutputs - { rrHtml = pprintHtml r +renderOutputs (Outputs outputs) = fold $ map renderOutput outputs + +renderOutput :: Output -> RenderedOutputs +renderOutput r = RenderedOutputs + { rrHtml = pprintHtml r + , rrLexemeSpans = computeSpanMap r , rrHighlightMap = computeHighlights r - , rrHoverInfoMap = computeHoverInfo r } + , rrHoverInfoMap = computeHoverInfo r + , rrErrorSrcIds = computeErrSrcIds r} renderSourceBlock :: BlockId -> SourceBlock -> RenderedSourceBlock renderSourceBlock n b = RenderedSourceBlock @@ -83,38 +92,60 @@ instance ToMarkup Output where instance ToJSON RenderedOutputs instance ToJSON RenderedSourceBlock +instance Semigroup RenderedOutputs where + RenderedOutputs x1 y1 z1 w1 v1 <> RenderedOutputs x2 y2 z2 w2 v2 = + RenderedOutputs (x1<>x2) (y1<>y2) (z1<>z2) (w1<>w2) (v1<>v2) + +instance Monoid RenderedOutputs where + mempty = RenderedOutputs mempty mempty mempty mempty mempty + -- === textual information on hover === type HoverInfo = String newtype HoverInfoMap = HoverInfoMap (M.Map LexemeId HoverInfo) deriving (ToJSON, Semigroup, Monoid) -computeHoverInfo :: Outputs -> HoverInfoMap -computeHoverInfo (Outputs outputs) = do - let typeInfo = foldJusts outputs \case - SourceInfo (SITypeInfo m) -> Just m - _ -> Nothing - HoverInfoMap $ fromTypeInfo typeInfo +computeHoverInfo :: Output -> HoverInfoMap +computeHoverInfo (SourceInfo (SITypeInfo m)) = HoverInfoMap $ fromTypeInfo m +computeHoverInfo _ = mempty -- === highlighting on hover === -newtype FocusMap = FocusMap (M.Map LexemeId SrcId) deriving (ToJSON, Semigroup, Monoid) +newtype SpanMap = SpanMap (M.Map SrcId LexemeSpan) deriving (ToJSON, Semigroup, Monoid) newtype HighlightMap = HighlightMap (M.Map SrcId Highlights) deriving (ToJSON, Semigroup, Monoid) -type Highlights = [(HighlightType, LexemeSpan)] +type Highlights = [(HighlightType, SrcId)] data HighlightType = HighlightGroup | HighlightLeaf deriving Generic instance ToJSON HighlightType -computeHighlights :: Outputs -> HighlightMap -computeHighlights (Outputs outputs) = do - execWriter $ mapM go $ foldJusts outputs \case - SourceInfo (SIGroupTree t) -> Just t - _ -> Nothing - where +computeErrSrcIds :: Output -> [SrcId] +computeErrSrcIds (Error err) = case err of + SearchFailure _ -> [] + InternalErr _ -> [] + ParseErr _ -> [] + SyntaxErr sid _ -> [sid] + NameErr sid _ -> [sid] + TypeErr sid _ -> [sid] + RuntimeErr -> [] + MiscErr _ -> [] +computeErrSrcIds _ = [] + +computeSpanMap :: Output -> SpanMap +computeSpanMap (SourceInfo (SIGroupTree (OverwriteWith tree))) = + execWriter $ go tree where + go :: GroupTree -> Writer SpanMap () + go t = do + tell $ SpanMap $ M.singleton (gtSrcId t) (gtSpan t) + mapM_ go $ gtChildren t +computeSpanMap _ = mempty + +computeHighlights :: Output -> HighlightMap +computeHighlights (SourceInfo (SIGroupTree (OverwriteWith tree))) = + execWriter $ go tree where go :: GroupTree -> Writer HighlightMap () go t = do let children = gtChildren t let highlights = children <&> \child -> - (getHighlightType (gtIsAtomicLexeme child), gtSpan child) + (getHighlightType (gtIsAtomicLexeme child), gtSrcId child) forM_ children \child-> do tell $ HighlightMap $ M.singleton (gtSrcId child) highlights go child @@ -122,6 +153,7 @@ computeHighlights (Outputs outputs) = do getHighlightType :: Bool -> HighlightType getHighlightType True = HighlightLeaf getHighlightType False = HighlightGroup +computeHighlights _ = mempty -- ----------------- diff --git a/src/lib/Runtime.hs b/src/lib/Runtime.hs index 011604d8e..885088c21 100644 --- a/src/lib/Runtime.hs +++ b/src/lib/Runtime.hs @@ -72,7 +72,7 @@ checkedCallFunPtr :: FD -> Ptr () -> Ptr () -> DexExecutable -> IO () checkedCallFunPtr fd argsPtr resultPtr fPtr = do let (CInt fd') = fdFD fd exitCode <- callFunPtr fPtr fd' argsPtr resultPtr - unless (exitCode == 0) $ throw rootSrcId RuntimeErr + unless (exitCode == 0) $ throwErr RuntimeErr withPipeToLogger :: PassLogger -> (FD -> IO a) -> IO a withPipeToLogger logger writeAction = do diff --git a/src/lib/TopLevel.hs b/src/lib/TopLevel.hs index 5f0f84eae..d1932bbf8 100644 --- a/src/lib/TopLevel.hs +++ b/src/lib/TopLevel.hs @@ -263,7 +263,7 @@ evalSourceBlock' mname block = case sbContents block of DeclareForeign fname (WithSrc _ dexName) cTy -> do ty <- evalUType =<< parseExpr cTy asFFIFunType ty >>= \case - Nothing -> throw rootSrcId $ MiscMiscErr + Nothing -> throwErr $ MiscErr $ MiscMiscErr "FFI functions must be n-ary first order functions with the IO effect" Just (impFunTy, naryPiTy) -> do -- TODO: query linking stuff and check the function is actually available @@ -279,11 +279,11 @@ evalSourceBlock' mname block = case sbContents block of Just (UAtomVar fname') -> do lookupCustomRules fname' >>= \case Nothing -> return () - Just _ -> throw rootSrcId $ MiscMiscErr + Just _ -> throwErr $ MiscErr $ MiscMiscErr $ pprint fname ++ " already has a custom linearization" lookupAtomName fname' >>= \case NoinlineFun _ _ -> return () - _ -> throw rootSrcId $ MiscMiscErr "Custom linearizations only apply to @noinline functions" + _ -> throwErr $ MiscErr $ MiscMiscErr "Custom linearizations only apply to @noinline functions" -- We do some special casing to avoid instantiating polymorphic functions. impl <- case expr of WithSrcE _ (UVar _) -> @@ -296,13 +296,13 @@ evalSourceBlock' mname block = case sbContents block of liftEnvReaderT (impl `checkTypeIs` linFunTy) >>= \case Failure _ -> do let implTy = getType impl - throw rootSrcId $ MiscMiscErr $ unlines + throwErr $ MiscErr $ MiscMiscErr $ unlines [ "Expected the custom linearization to have type:" , "" , pprint linFunTy , "" , "but it has type:" , "" , pprint implTy] Success () -> return () updateTopEnv $ AddCustomRule fname' $ CustomLinearize nimplicit nexplicit zeros impl - Just _ -> throw rootSrcId $ MiscMiscErr $ "Custom linearization can only be defined for functions" - UnParseable _ s -> throw rootSrcId $ MiscParseErr s + Just _ -> throwErr $ MiscErr $ MiscMiscErr $ "Custom linearization can only be defined for functions" + UnParseable _ s -> throwErr $ ParseErr $ MiscParseErr s Misc m -> case m of GetNameType v -> do lookupSourceMap (withoutSrc v) >>= \case @@ -437,7 +437,7 @@ evalUModule (UModule name _ blocks) = do importModule :: (Mut n, TopBuilder m, Fallible1 m) => ModuleSourceName -> m n () importModule name = do lookupLoadedModule name >>= \case - Nothing -> throw rootSrcId $ ModuleImportErr $ pprint name + Nothing -> throwErr $ MiscErr $ ModuleImportErr $ pprint name Just name' -> do Module _ _ transImports' _ _ <- lookupModule name' let importStatus = ImportStatus (S.singleton name') @@ -696,7 +696,7 @@ loadModuleSource config moduleName = do fsPaths <- liftIO $ traverse resolveBuiltinPath $ libPaths config liftIO (findFile fsPaths fname) >>= \case Just fpath -> return fpath - Nothing -> throw rootSrcId $ CantFindModuleSource $ pprint moduleName + Nothing -> throwErr $ MiscErr $ CantFindModuleSource $ pprint moduleName resolveBuiltinPath = \case LibBuiltinPath -> liftIO $ getDataFileName "lib" LibDirectory dir -> return dir @@ -835,14 +835,14 @@ getLinearizationType zeros = \case Just tty -> case zeros of InstantiateZeros -> return tty SymbolicZeros -> symbolicTangentTy tty - Nothing -> throw rootSrcId $ MiscMiscErr $ "No tangent type for: " ++ pprint t + Nothing -> throwErr $ MiscErr $ MiscMiscErr $ "No tangent type for: " ++ pprint t resultTanTy <- maybeTangentType resultTy' >>= \case Just rtt -> return rtt - Nothing -> throw rootSrcId $ MiscMiscErr $ "No tangent type for: " ++ pprint resultTy' + Nothing -> throwErr $ MiscErr $ MiscMiscErr $ "No tangent type for: " ++ pprint resultTy' let tanFunTy = toType $ Pi $ nonDepPiType argTanTys Pure resultTanTy let fullTy = CorePiType ExplicitApp expls bs' $ EffTy Pure (PairTy resultTy' tanFunTy) return (numIs, numEs, toType $ Pi fullTy) - _ -> throw rootSrcId $ MiscMiscErr $ "Can't define a custom linearization for implicit or impure functions" + _ -> throwErr $ MiscErr $ MiscMiscErr $ "Can't define a custom linearization for implicit or impure functions" where getNumImplicits :: Fallible m => [Explicitness] -> m (Int, Int) getNumImplicits = \case @@ -853,4 +853,4 @@ getLinearizationType zeros = \case Inferred _ _ -> return (ni + 1, ne) Explicit -> case ni of 0 -> return (0, ne + 1) - _ -> throw rootSrcId $ MiscMiscErr "All implicit args must precede implicit args" + _ -> throwErr $ MiscErr $ MiscMiscErr "All implicit args must precede implicit args" diff --git a/static/index.js b/static/index.js index 3be1aea57..6aa93849d 100644 --- a/static/index.js +++ b/static/index.js @@ -29,18 +29,10 @@ function renderLaTeX(root) { ); } -/** - * HTML rendering mode. - * Static rendering is used for static HTML pages. - * Dynamic rendering is used for dynamic HTML pages via `dex web`. - * - * @enum {string} - */ var RENDER_MODE = Object.freeze({ STATIC: "static", DYNAMIC: "dynamic", }) - var body = document.getElementById("main-output"); var hoverInfoDiv = document.getElementById("hover-info"); @@ -48,8 +40,8 @@ var hoverInfoDiv = document.getElementById("hover-info"); var cells = {} var frozenHover = false; var curHighlights = []; // HTML elements currently highlighted -var focusMap = {} var highlightMap = {} +var spanMap = {} var hoverInfoMap = {} function removeHover() { @@ -77,16 +69,25 @@ function applyHoverInfo(cellId, srcId) { hoverInfoDiv.innerHTML = hoverInfo } } +function getSpan(cellId, srcId) { + return lookupSrcMap(spanMap, cellId, srcId) +} function applyHoverHighlights(cellId, srcId) { let highlights = lookupSrcMap(highlightMap, cellId, srcId) if (highlights == null) return highlights.map(function (highlight) { - let [highlightType, [l, r]] = highlight - let spans = spansBetween(selectSpan(cellId, l), selectSpan(cellId, r)); + let [highlightType, highlightSrcId] = highlight let highlightClass = getHighlightClass(highlightType) + addClass(cellId, highlightSrcId, highlightClass)}) +} +function addClass(cellId, srcId, className) { + let span = getSpan(cellId, srcId) + if (span !== undefined) { + let [l, r] = span + let spans = spansBetween(selectSpan(cellId, l), selectSpan(cellId, r)); spans.map(function (span) { - span.classList.add(highlightClass) - curHighlights.push(span)})}) + span.classList.add(className) + curHighlights.push(span)})} } function toggleFrozenHover() { if (frozenHover) { @@ -103,30 +104,6 @@ function attachHovertip(cellId, srcId) { span.addEventListener("mouseout" , function (event) { event.stopPropagation() removeHover()})} - -/** - * Renders the webpage. - * @param {RENDER_MODE} renderMode The render mode, either static or dynamic. - */ -function render(renderMode) { - if (renderMode == RENDER_MODE.STATIC) { - // For static pages, simply call rendering functions once. - renderLaTeX(document); - } else { - // For dynamic pages (via `dex web`), listen to update events. - var source = new EventSource("/getnext"); - source.onmessage = function(event) { - var msg = JSON.parse(event.data); - if (msg == "start") { - body.innerHTML = "" - body.addEventListener("click", function (event) { - event.stopPropagation() - toggleFrozenHover()}) - cells = {} - return - } else { - processUpdate(msg)}};} -} function selectSpan(cellId, srcId) { return cells[cellId].querySelector("#span_".concat(cellId, "_", srcId)) } @@ -177,6 +154,7 @@ function initializeCellContents(cellId, cell, contents) { let sourceText = source["rsbHtml"]; highlightMap[cellId] = {}; hoverInfoMap[cellId] = {}; + spanMap[cellId] = {}; addChild(cell, "line-num" , lineNum.toString()) addChild(cell, "code-block" , sourceText) addChild(cell, "cell-results", "") @@ -192,6 +170,11 @@ function extendCellResult(cellId, cell, result) { } Object.assign(highlightMap[cellId], result["rrHighlightMap"]) Object.assign(hoverInfoMap[cellId], result["rrHoverInfoMap"]) + Object.assign(spanMap[cellId] , result["rrLexemeSpans"]) + + let errSrcIds = result["rrErrorSrcIds"] + errSrcIds.map(function (srcId) { + addClass(cellId, srcId, "err-span")}) } function updateCellContents(cellId, cell, contents) { let [statusUpdate, result] = contents; @@ -238,3 +221,27 @@ function processUpdate(msg) { let lexemeList = source["rsbLexemeList"]; lexemeList.map(function (lexemeId) {attachHovertip(cellId, lexemeId.toString())})}}); } + +/** + * Renders the webpage. + * @param {RENDER_MODE} renderMode The render mode, either static or dynamic. + */ +function render(renderMode) { + if (renderMode == RENDER_MODE.STATIC) { + // For static pages, simply call rendering functions once. + renderLaTeX(document); + } else { + // For dynamic pages (via `dex web`), listen to update events. + var source = new EventSource("/getnext"); + source.onmessage = function(event) { + var msg = JSON.parse(event.data); + if (msg == "start") { + body.innerHTML = "" + body.addEventListener("click", function (event) { + event.stopPropagation() + toggleFrozenHover()}) + cells = {} + return + } else { + processUpdate(msg)}};} +} diff --git a/static/style.css b/static/style.css index 3d8de7c67..450f70bfd 100644 --- a/static/style.css +++ b/static/style.css @@ -40,7 +40,10 @@ body { font-family: monospace; white-space: pre; } - +.err-span { + text-decoration: red wavy underline; + text-decoration-skip-ink: none; +} code { background-color: #F0F0F0; } From 2c9e55780908a16a25212fa1d8c6dcbb995e8b59 Mon Sep 17 00:00:00 2001 From: Dougal Date: Tue, 5 Dec 2023 21:32:15 -0500 Subject: [PATCH 40/41] Use # instead of -- for comments --- lib/prelude.dx | 242 +++++++++++++++++++++++----------------------- misc/dex.el | 5 +- src/lib/Lexing.hs | 6 +- 3 files changed, 125 insertions(+), 128 deletions(-) diff --git a/lib/prelude.dx b/lib/prelude.dx index be288660a..8c62e15b5 100644 --- a/lib/prelude.dx +++ b/lib/prelude.dx @@ -321,7 +321,7 @@ def size(n:Type|Ix) -> Nat = size'(n=n) def Fin(n:Nat) -> Type = %Fin(n) --- version of subtraction on Nats that clamps at zero +# version of subtraction on Nats that clamps at zero def (-|)(x: Nat, y:Nat) -> Nat = x' = nat_to_rep x y' = nat_to_rep y @@ -333,7 +333,7 @@ def unsafe_nat_diff(x:Nat, y:Nat) -> Nat = y' = nat_to_rep y rep_to_nat %isub(x', y') --- TODO: need to a way to indicate constructor as private +# TODO: need to a way to indicate constructor as private struct RangeFrom(i:q) given (q:Type) = val : Nat struct RangeFromExc(i:q) given (q:Type) = val : Nat @@ -377,16 +377,16 @@ instance Add(n=>a) given (a|Add, n|Ix) instance Sub(n=>a) given (a|Sub, n|Ix) def (-)(xs, ys) = for i. xs[i] - ys[i] -instance Add((i:n) => RangeFrom i => a) given (a|Add, n|Ix) -- Upper triangular tables +instance Add((i:n) => RangeFrom i => a) given (a|Add, n|Ix) # Upper triangular tables def (+)(xs, ys) = for i. xs[i] + ys[i] zero = for _. zero -instance Sub((i:n) => RangeFrom i => a) given (a|Sub, n|Ix) -- Upper triangular tables +instance Sub((i:n) => RangeFrom i => a) given (a|Sub, n|Ix) # Upper triangular tables def (-)(xs, ys) = for i. xs[i] - ys[i] -instance Add((i:n) => RangeTo i => a) given (a|Add, n|Ix) -- Lower triangular tables +instance Add((i:n) => RangeTo i => a) given (a|Add, n|Ix) # Lower triangular tables def (+)(xs, ys) = for i. xs[i] + ys[i] zero = for _. zero -instance Sub((i:n) => RangeTo i => a) given (a|Sub, n|Ix) -- Lower triangular tables +instance Sub((i:n) => RangeTo i => a) given (a|Sub, n|Ix) # Lower triangular tables def (-)(xs, ys) = for i. xs[i] - ys[i] instance Add((i:n) => RangeToExc i => a) given (a|Add, n|Ix) @@ -516,8 +516,8 @@ def not(x:Bool) -> Bool = '## More Boolean operations TODO: move these with the others? --- Can't use `%select` because it lowers to `ISelect`, which requires --- `a` to be a `BaseTy`. +# Can't use `%select` because it lowers to `ISelect`, which requires +# `a` to be a `BaseTy`. def select(p:Bool, x:a, y:a) -> a given (a:Type) = case p of True -> x @@ -570,14 +570,14 @@ instance Ix(Either(a, b)) given (a|Ix, b|Ix) def unsafe_from_ordinal(o) = as = nat_to_rep $ size a o' = nat_to_rep o - -- TODO: Reshuffle the prelude to be able to use (<) here + # TODO: Reshuffle the prelude to be able to use (<) here case w8_to_b $ %ilt(o', as) of True -> Left $ unsafe_from_ordinal(n=a, o) - -- TODO: Reshuffle the prelude to be able to use `diff_nat` here + # TODO: Reshuffle the prelude to be able to use `diff_nat` here False -> Right $ unsafe_from_ordinal(n=b, rep_to_nat (%isub(o', as))) '## Subtraction on Nats --- TODO: think more about the right API here +# TODO: think more about the right API here def unsafe_i_to_n(x:Int) -> Nat = rep_to_nat $ internal_cast x @@ -846,7 +846,7 @@ instance Ord(Nat) def (>)(x, y) = nat_to_rep x > nat_to_rep y def (<)(x, y) = nat_to_rep x < nat_to_rep y --- TODO: we want Eq and Ord for all index sets, not just `Fin n` +# TODO: we want Eq and Ord for all index sets, not just `Fin n` instance Eq(Fin n) given (n:Nat) def (==)(x, y) = ordinal x == ordinal y @@ -886,10 +886,10 @@ instance NonEmpty((a,b)) given (a|NonEmpty, b|NonEmpty) instance NonEmpty(Either(a,b)) given (a|NonEmpty, b|Ix) first_ix = unsafe_from_ordinal 0 --- The below instance is valid, but causes "multiple candidate dictionaries" --- errors if both Left and Right are NonEmpty. --- instance NonEmpty (a|b) given {a b} [Ix a, NonEmpty b] --- first_ix = unsafe_from_ordinal _ 0 +# The below instance is valid, but causes 'multiple candidate dictionaries' +# errors if both Left and Right are NonEmpty. +# instance NonEmpty (a|b) given {a b} [Ix a, NonEmpty b] +# first_ix = unsafe_from_ordinal _ 0 instance NonEmpty(Maybe a) given (a|Ix) first_ix = unsafe_from_ordinal 0 @@ -955,7 +955,7 @@ interface Subset(subset:Type, superset:Type) project' : (superset) -> Maybe subset unsafe_project' : (superset) -> subset --- wrappers with more helpful implicit arg names +# wrappers with more helpful implicit arg names def inject(x:from) -> to given (to:Type, from:Type) (Subset(from, to)) = inject'(x) def project(x:from) -> Maybe to given (to:Type, from:Type) (Subset(to, from)) = project'(x) def unsafe_project(x:from) -> to given (to:Type, from:Type) (Subset(to, from)) = unsafe_project'(x) @@ -1053,14 +1053,14 @@ interface Floating(a:Type) def lbeta(x:a, y:a) -> a given (a|Sub|Floating) = lgamma x + lgamma y - lgamma (x + y) --- Todo: better numerics for very large and small values. --- Using %exp here to avoid circular definition problems. +# Todo: better numerics for very large and small values. +# Using %exp here to avoid circular definition problems. def float32_sinh(x:Float32) -> Float32 = %fdiv(%fsub(%exp(x), %exp(%fsub(0.0,x))), 2.0) def float32_cosh(x:Float32) -> Float32 = %fdiv(%fadd(%exp(x), %exp(%fsub(0.0,x))), 2.0) def float32_tanh(x:Float32) -> Float32 = %fdiv(%fsub(%exp(x), %exp(%fsub(0.0,x))) ,%fadd(%exp(x), %exp(%fsub(0.0,x)))) --- Todo: unify this with float32 functions. +# Todo: unify this with float32 functions. def float64_sinh(x:Float64) -> Float64 = %fdiv(%fsub(%exp(x), %exp(%fsub(f_to_f64 0.0, x))), f_to_f64 2.0) def float64_cosh(x:Float64) -> Float64 = %fdiv(%fadd(%exp(x), %exp(%fsub(f_to_f64 0.0, x))), f_to_f64 2.0) def float64_tanh(x:Float64) -> Float64 = %fdiv(%fsub(%exp(x), %exp(%fsub(f_to_f64 0.0, x))) @@ -1150,9 +1150,9 @@ instance Storable(Nat) instance Storable(Ptr a) given (a:Type) def store(ptr, x) = %ptrStore(internal_cast(to=%PtrPtr(), ptr.val), x.val) def load(ptr) = Ptr(%ptrLoad(internal_cast(to=%PtrPtr(), ptr))) - def storage_size() = 8 -- TODO: something more portable? + def storage_size() = 8 # TODO: something more portable? --- TODO: Storable instances for other types +# TODO: Storable instances for other types def malloc(n:Nat) -> {IO} (Ptr a) given (a|Storable) = numBytes = storage_size(a=a) * n @@ -1164,7 +1164,7 @@ def (+>>)(ptr:Ptr a, i:Nat) -> Ptr a given (a|Storable) = i' = nat_to_rep $ i * storage_size(a=a) Ptr(%ptrOffset(ptr.val, i')) --- TODO: consider making a Storable instance for tables instead +# TODO: consider making a Storable instance for tables instead def store_table(ptr: Ptr a, tab:n=>a) -> {IO} () given (a|Storable, n|Ix) = for_ i:n. store(ptr +>> ordinal i, tab[i]) @@ -1173,8 +1173,8 @@ def memcpy(dest:Ptr a, src:Ptr a, n:Nat) -> {IO} () given (a|Storable) = i' = ordinal i store(dest +>> i', load $ src +>> i') --- TODO: generalize these brackets to allow other effects --- TODO: make sure that freeing happens even if there are run-time errors +# TODO: generalize these brackets to allow other effects +# TODO: make sure that freeing happens even if there are run-time errors def with_alloc( a|Storable, n:Nat, @@ -1202,7 +1202,7 @@ pi : Float = 3.141592653589793 def id(x:a) -> a given (a:Type) = x def dup(x:a) -> (a, a) given (a:Type) = (x, x) --- map, flipped so that the function goes last +# map, flipped so that the function goes last def each(xs: n=>a, f:(a)->{|eff} b) -> {|eff} (n=>b) given (a:Type, b:Type, n|Ix, eff:Effects) = for i. f xs[i] @@ -1266,10 +1266,10 @@ def scan( def fold(init:c, xs:n=>a, body:(n, a, c)-> c) -> c given (a:Type, n|Ix, c|Data) = snd $ scan(init, xs) \i x carry. ((), body(i, x, carry)) --- `combine` should be a commutative and associative, and form a --- commutative monoid with `identity` +# `combine` should be a commutative and associative, and form a +# commutative monoid with `identity` def reduce(xs:n=>a, identity:a, combine:(a,a)->a) -> a given (a|Data, n|Ix) = - -- TODO: implement with the accumulator effect + # TODO: implement with the accumulator effect fold(identity, xs) \i x c. combine(c, x) def fsum(xs:n=>Float) -> Float given (n|Ix) = @@ -1309,7 +1309,7 @@ def cumsum_low(xs: n=>a) -> n=>a given (n|Ix, a|Add) = '### AD operations --- TODO: add vector space constraints +# TODO: add vector space constraints def linearize(f:(a)->b, x:a) -> (b, (a)->b) given (a:Type, b:Type) = %linearize(\x:a. f x, x) @@ -1327,8 +1327,8 @@ def deriv(f:(Float)->Float, x:Float) -> Float = jvp(f, x, 1.0) def deriv_rev(f:(Float)->Float, x:Float) -> Float = (snd vjp(f, x))(1.0) --- XXX: Watch out when editing this data type! We depend on its structure --- deep inside the compiler (mostly in linearization and during rule registration). +# XXX: Watch out when editing this data type! We depend on its structure +# deep inside the compiler (mostly in linearization and during rule registration). data SymbolicTangent(a:Type) = ZeroTangent SomeTangent(a) @@ -1433,13 +1433,13 @@ named-instance ListMonoid (a|Data) -> Monoid(List a) mempty = mempty def (<>)(x, y) = x <> y --- TODO Eliminate or reimplement this operation, since it costs O(n) --- where n is the length of the list held in the reference. +# TODO Eliminate or reimplement this operation, since it costs O(n) +# where n is the length of the list held in the reference. def append(list: Ref(h, List a), x:a) -> {Accum h} () given (a|Data, h:Heap) (AccumMonoid(h, List a)) = list += to_list [x] --- TODO: replace `slice` with this? +# TODO: replace `slice` with this? def post_slice(xs:n=>a, start:Post n, end:Post n) -> List a given (n|Ix, a:Type) = slice_size = unsafe_nat_diff(ordinal end, ordinal start) to_list for i:(Fin slice_size). @@ -1452,13 +1452,13 @@ String : Type = List Char def string_from_char_ptr(n:Word32, ptr:Ptr Char) -> {IO} String = AsList(rep_to_nat n, table_from_ptr ptr) --- TODO. This is ASCII code point. It really should be Int32 for Unicode codepoint +# TODO. This is ASCII code point. It really should be Int32 for Unicode codepoint def codepoint(c:Char) -> Int = w8_to_i c struct CString = ptr : RawPtr --- TODO: check the string contains no nulls +# TODO: check the string contains no nulls def with_c_string( s:String, action: (CString) -> {IO} a @@ -1563,15 +1563,15 @@ def copysign(a:Float, b:Float) -> Float = True -> (-a) False -> 0.0 --- Todo: use IEEE floating-point builtins. +# Todo: use IEEE floating-point builtins. infinity = 1.0 / 0.0 nan = 0.0 / 0.0 --- Todo: use IEEE floating-point builtins. +# Todo: use IEEE floating-point builtins. def isinf(x:Float) -> Bool = (x == infinity) || (x == -infinity) def isnan(x:Float) -> Bool = not (x >= x && x <= x) --- Todo: use IEEE-754R 5.11: Floating Point Comparison Relation cmpUnordered. +# Todo: use IEEE-754R 5.11: Floating Point Comparison Relation cmpUnordered. def either_is_nan(x:Float, y:Float) -> Bool = (isnan x) || (isnan y) '## File system operations @@ -1633,7 +1633,7 @@ data IterResult(a|Data) = Continue Done(a) --- A little iteration combinator +# A little iteration combinator def iter(body: (Nat) -> {|eff} IterResult a) -> {|eff} a given (a|Data, eff:Effects) = result = yield_state (Nothing::Maybe a) \resultRef. i <- with_state (0::Nat) @@ -1693,14 +1693,14 @@ def from_ordinal(i:Nat) -> n given (n|Ix) = True -> unsafe_from_ordinal i False -> error $ from_ordinal_error(i, size n) --- TODO: should this be called `from_ordinal`? +# TODO: should this be called `from_ordinal`? def to_ix(i:Nat) -> Maybe n given (n|Ix) = case i < size n of True -> Just $ unsafe_from_ordinal i False -> Nothing --- TODO: could make an `unsafeCastIndex` and this could avoid the runtime copy --- TODO: safe (runtime-checked) and unsafe versions +# TODO: could make an `unsafeCastIndex` and this could avoid the runtime copy +# TODO: safe (runtime-checked) and unsafe versions def cast_table(xs:to=>a) -> from=>a given (from|Ix, to|Ix, a|Data) = case size from == size to of True -> unsafe_cast_table xs @@ -1726,19 +1726,19 @@ Dex's PRNG system is modelled directly after [JAX's](https://github.com/google/j '### Key functions --- TODO: newtype +# TODO: newtype Key = Word64 @noinline def threefry_2x32(k:Word64, count:Word64) -> Word64 = - -- Based on jax's threefry_2x32 by Matt Johnson and Peter Hawkins + # Based on jax's threefry_2x32 by Matt Johnson and Peter Hawkins rotations1 : Fin 4 => Int32 = [13, 15, 26, 6] rotations2 : Fin 4 => Int32 = [17, 29, 16, 24] k0 = low_word k k1 = high_word k - -- TODO: add a fromHex - k2 = k0 .^. k1 .^. (n_to_w32 466688986) -- 0x1BD11BDA + # TODO: add a fromHex + k2 = k0 .^. k1 .^. (n_to_w32 466688986) # 0x1BD11BDA x = low_word count y = high_word count @@ -1776,8 +1776,8 @@ These functions generate samples taken from, different distributions. Such as `rand_mat` with samples from the distribution of floating point matrices where each element is taken from a i.i.d. uniform distribution. Note that additional standard distributions are provided by the `stats` library. def rand(k:Key) -> Float = - exponent_bits : Word32 = 1065353216 -- 1065353216 = 127 << 23 - mantissa_bits = (high_word k .&. 8388607) -- 8388607 == (1 << 23) - 1 + exponent_bits : Word32 = 1065353216 # 1065353216 = 127 << 23 + mantissa_bits = (high_word k .&. 8388607) # 8388607 == (1 << 23) - 1 bits = exponent_bits .|. mantissa_bits %bitcast(Float, bits) - 1.0 @@ -1789,13 +1789,13 @@ def rand_mat(n:Nat, m:Nat, f: (Key) -> a, k: Key) -> Fin n => Fin m => a given ( def randn(k:Key) -> Float = [k1, k2] = split_key(n=2, k) - -- rand is uniform between 0 and 1, but implemented such that it rounds to 0 - -- (in float32) once every few million draws, but never rounds to 1. + # rand is uniform between 0 and 1, but implemented such that it rounds to 0 + # (in float32) once every few million draws, but never rounds to 1. u1 = 1.0 - (rand k1) u2 = rand k2 sqrt ((-2.0) * log u1) * cos (2.0 * pi * u2) --- TODO: Make this better... +# TODO: Make this better... def rand_int(k:Key) -> Nat = w64_to_n k `mod` 2147483647 def randn_vec(k:Key) -> n=>Float given (n|Ix) = @@ -1916,7 +1916,7 @@ def minimum(xs:n=>o) -> o given (n|Ix, o|Ord) = minimum_by(xs, id) def maximum(xs:n=>o) -> o given (n|Ix, o|Ord) = maximum_by(xs, id) '### argmin/argmax --- TODO: put in same section as `searchsorted` +# TODO: put in same section as `searchsorted` def argscan(xs:n=>a, comp:(a,a)->Bool) -> n given (a|Ord, n|Ix) = AccumTy : Type = (n, a) @@ -1937,18 +1937,18 @@ def lexical_order( compareElements:(n,n)->Bool, compareLengths: (Nat,Nat)->Bool ) -> Bool given (n|Ord) = - -- Orders Lists according to the order of their elements, - -- in the same way a dictionary does. - -- For example, this lets us sort Strings. - -- - -- More precisely, it returns True iff compareElements xs.i ys.i is true - -- at the first location they differ. - -- - -- This function operates serially and short-circuits - -- at the first difference. One could also write this - -- function as a parallel reduction, but it would be - -- wasteful in the case where there is an early difference, - -- because we can't short circuit. + # Orders Lists according to the order of their elements, + # in the same way a dictionary does. + # For example, this lets us sort Strings. + # + # More precisely, it returns True iff compareElements xs.i ys.i is true + # at the first location they differ. + # + # This function operates serially and short-circuits + # at the first difference. One could also write this + # function as a parallel reduction, but it would be + # wasteful in the case where there is an early difference, + # because we can't short circuit. AsList(nx, xs) = xList AsList(ny, ys) = yList iter \i. @@ -1978,10 +1978,10 @@ TODO: these should be with the other Elementary/Special Functions ### atan/atan2 def atan_inner(x:Float) -> Float = - -- From "Computing accurate Horner form approximations to - -- special functions in finite precision arithmetic" - -- https://arxiv.org/abs/1508.03211 - -- Only accurate in the range [-1, 1] + # From "Computing accurate Horner form approximations to + # special functions in finite precision arithmetic" + # https://arxiv.org/abs/1508.03211 + # Only accurate in the range [-1, 1] s = x * x r = 0.0027856871 r = r * s - 0.0158660002 @@ -1996,13 +1996,13 @@ def atan_inner(x:Float) -> Float = def min_and_max(x:a, y:a) -> (a, a) given (a|Ord) = - select(x < y, (x, y), (y, x)) -- get both with one comparison. + select(x < y, (x, y), (y, x)) # get both with one comparison. def atan2(y:Float, x:Float) -> Float = - -- Based off of the Tensorflow implementation at - -- github.com/tensorflow/mlir-hlo/blob/master/lib/ - -- Dialect/mhlo/transforms/legalize_trigonometric_to_approximation.cc#L147 - -- With a fix to the nan propagation. + # Based off of the Tensorflow implementation at + # github.com/tensorflow/mlir-hlo/blob/master/lib/ + # Dialect/mhlo/transforms/legalize_trigonometric_to_approximation.cc#L147 + # With a fix to the nan propagation. abs_x = abs x abs_y = abs y (min_abs_x_y, max_abs_x_y) = min_and_max(abs_x, abs_y) @@ -2012,9 +2012,9 @@ def atan2(y:Float, x:Float) -> Float = t = select(x < 0.0, pi, 0.0) a = select(y == 0.0, t, a) t = select(x < 0.0, 3.0 * pi / 4.0, pi / 4.0) - a = select(isinf x && isinf y, t, a) -- Handle infinite inputs. + a = select(isinf x && isinf y, t, a) # Handle infinite inputs. a = copysign(a, y) - select(either_is_nan(x, y), nan, a) -- Propagate NaNs. + select(either_is_nan(x, y), nan, a) # Propagate NaNs. def atan(x:Float) -> Float = atan2(x, 1.0) @@ -2045,22 +2045,22 @@ def is_odd(x:Nat) -> Bool = rem(x, 2) == 1 def is_even(x:Nat) -> Bool = rem(x, 2) == 0 def is_power_of_2(x:Nat) -> Bool = - -- A fast trick based on bitwise AND. - -- This works on integer types larger than 8 bits. - -- Note: The bitwise and operator (.&.) - -- is only defined for Byte, which is why - -- we use %and here. TODO: Make (.&.) polymorphic. + # A fast trick based on bitwise AND. + # This works on integer types larger than 8 bits. + # Note: The bitwise and operator (.&.) + # is only defined for Byte, which is why + # we use %and here. TODO: Make (.&.) polymorphic. x' = nat_to_rep x if x' == 0 then False else %and(x', (%isub(x', 1::NatRep))) == 0 --- This computes the integer part of the binary logarithm of the input. --- TODO: natlog2 0 should do something other than underflow the answer. --- TODO: Use LLVM ctlz intrinsic instead. It needs a slightly new --- code path in ImpToLLVM, because it's the first LLVM intrinsic --- we have with a fixed-point argument. --- https://llvm.org/docs/LangRef.html#llvm-ctlz-intrinsic +# This computes the integer part of the binary logarithm of the input. +# TODO: natlog2 0 should do something other than underflow the answer. +# TODO: Use LLVM ctlz intrinsic instead. It needs a slightly new +# code path in ImpToLLVM, because it's the first LLVM intrinsic +# we have with a fixed-point argument. +# https://llvm.org/docs/LangRef.html#llvm-ctlz-intrinsic def natlog2(x:Nat) -> Nat = tmp = yield_state (0::Nat) \ans. cmp <- run_state (1::Nat) @@ -2072,7 +2072,7 @@ def natlog2(x:Nat) -> Nat = True else False - unsafe_nat_diff(tmp, 1) -- TODO: something less horrible + unsafe_nat_diff(tmp, 1) # TODO: something less horrible def nextpow2(x:Nat) -> Nat = case is_power_of_2 x of @@ -2085,9 +2085,9 @@ def general_integer_power( power:Nat ) -> a given (a|Data) = iters : Nat = if power == 0 then 0 else 1 + natlog2 power - -- Implements exponentiation by squaring. - -- This could be nicer if there were a way to explicitly - -- specify which typelcass instance to use for Mul. + # Implements exponentiation by squaring. + # This could be nicer if there were a way to explicitly + # specify which typelcass instance to use for Mul. yield_state one \ans. pow <- with_state power z <- with_state base @@ -2105,8 +2105,8 @@ def from_just(x:Maybe a) -> a given (a:Type) = case x of Just(x') -> x' def any_sat(xs:n=>a, f:(a)->Bool) -> Bool given (a:Type, n|Ix) = any(each xs f) def seq_maybes(xs: n=>Maybe a) -> Maybe (n => a) given (n|Ix, a:Type) = - -- is it possible to implement this safely? (i.e. without using partial - -- functions) + # is it possible to implement this safely? (i.e. without using partial + # functions) case any_sat(xs, is_nothing) of True -> Nothing False -> Just $ each xs from_just @@ -2121,8 +2121,8 @@ def list_length(l:List a) -> Nat given (a:Type) = AsList(n, _) = l n --- This is for efficiency (rather than using `<>` repeatedly) --- TODO: we want this for any monoid but this implementation won't work. +# This is for efficiency (rather than using `<>` repeatedly) +# TODO: we want this for any monoid but this implementation won't work. def concat(lists:n=>(List a)) -> List a given (a:Type, n|Ix) = totalSize = sum for i:n. list_length lists[i] to_list $ with_state (0::Nat) \listIdx. @@ -2155,8 +2155,8 @@ def cat_maybes(xs:n=>Maybe a) -> List a given (n|Ix, a|Data) = case res_inds[unsafe_from_ordinal $ ordinal i] of Just(j) -> case xs[j] of Just(x) -> x - Nothing -> todo -- Impossible - Nothing -> todo -- Impossible + Nothing -> todo # Impossible + Nothing -> todo # Impossible def filter(xs:n=>a, condition:(a)->Bool) -> List a given (a|Data, n|Ix) = cat_maybes $ for i:n. if condition xs[i] then Just xs[i] else Nothing @@ -2164,7 +2164,7 @@ def filter(xs:n=>a, condition:(a)->Bool) -> List a given (a|Data, n|Ix) = def arg_filter(xs:n=>a, condition:(a)->Bool) -> List n given (a|Data, n|Ix) = cat_maybes $ for i:n. if condition xs[i] then Just i else Nothing --- TODO: use `ix_offset : [Ix n] -> n -> Int -> Maybe n` instead +# TODO: use `ix_offset : [Ix n] -> n -> Int -> Maybe n` instead def prev_ix(i:n) -> Maybe n given (n|Ix) = case i_to_n (n_to_i (ordinal i) - 1) of Nothing -> Nothing @@ -2185,7 +2185,7 @@ def lines(source:String) -> List String = '## Probability --- cdf should include 0.0 but not 1.0 +# cdf should include 0.0 but not 1.0 def categorical_from_cdf(cdf: n=>Float, key: Key) -> n given (n|Ix) = r = rand key from_just $ left_fence $ search_sorted(cdf, r) @@ -2199,8 +2199,8 @@ def cdf_for_categorical(logprobs: n=>Float) -> n=>Float given (n|Ix) = def categorical(logprobs: n=>Float, key: Key) -> n given (n|Ix) = categorical_from_cdf(cdf_for_categorical logprobs, key) --- batch variant to share the work of forming the cumsum --- (alternatively we could rely on hoisting of loop constants) +# batch variant to share the work of forming the cumsum +# (alternatively we could rely on hoisting of loop constants) def categorical_batch(logprobs: n=>Float, key: Key) -> m=>n given (n|Ix, m|Ix) = cdf = cdf_for_categorical logprobs for i:m. categorical_from_cdf(cdf, ixkey(key, i)) @@ -2223,11 +2223,11 @@ def softmax(x: n=>Float) -> n=>Float given (n|Ix) = TODO: Move this somewhere else def evalpoly(coeffs:n=>v, x:Float) -> v given (n|Ix, v|VSpace) = - -- Evaluate a polynomial at x. Same as Numpy's polyval. + # Evaluate a polynomial at x. Same as Numpy's polyval. fold zero coeffs \i coeff c. coeff + x .* c '## Exception effect --- TODO: move `error` and `todo` to here. +# TODO: move `error` and `todo` to here. def catch(f:() -> {Except|eff} a) -> {|eff} Maybe a given (a:Type, eff:Effects)= f' : (() -> {Except|eff} a) = \. f() @@ -2277,8 +2277,8 @@ def reversed_digits_to_int(digits: a=>b) -> Nat given (a|Ix, b|Ix) = (next_k, next_base) instance Ix(a=>b) given (a|Ix, b|Ix) - -- 0@a is the least significant digit, - -- while (size a - 1)@a is the most significant digit. + # 0@a is the least significant digit, + # while (size a - 1)@a is the most significant digit. def size'() = size b `intpow` size a def ordinal(i) = reversed_digits_to_int i def unsafe_from_ordinal(i) = int_to_reversed_digits i @@ -2401,12 +2401,12 @@ def check_env(name:String) -> {IO} Bool = '## Testing Helpers --- -- Reliably causes a segfault if pointers aren't initialized to zero. --- -- TODO: add this test when we cache modules --- justSomeDataToTestCaching = toList for i:(Fin 100). --- if ordinal i == 0 --- then Left (toList [1,2,3]) --- else Right 1 +# # Reliably causes a segfault if pointers aren't initialized to zero. +# # TODO: add this test when we cache modules +# justSomeDataToTestCaching = toList for i:(Fin 100). +# if ordinal i == 0 +# then Left (toList [1,2,3]) +# else Right 1 '### TestMode @@ -2416,7 +2416,7 @@ def dex_test_mode() -> Bool = unsafe_io \. check_env "DEX_TEST_MODE" '### More Stream IO def fread(stream:Stream ReadMode) -> {IO} String = - -- TODO: allow reading longer files! + # TODO: allow reading longer files! n : Nat = 4096 ptr <- with_alloc(Char, n) stack <- with_stack Char @@ -2508,9 +2508,9 @@ def dot(s:n=>Float, vs:n=>v) -> v given (n|Ix, v|VSpace) = sum for j:n. s[j] .* def naive_matmul(x: l=>m=>Float, y: m=>n=>Float) -> (l=>n=>Float) given (l|Ix, m|Ix, n|Ix) = for i k. fsum for j:m. x[i,j] * y[j,k] --- A `FullTileIx` type represents `tile_ix`th full tile (of size --- `tile_size`) iterating over the index set `n`. --- This type is only well formed when tile_ix * tile_size < size n. +# A `FullTileIx` type represents `tile_ix`th full tile (of size +# `tile_size`) iterating over the index set `n`. +# This type is only well formed when tile_ix * tile_size < size n. struct FullTileIx(n|Ix, tile_size:Nat, tile_ix:Nat) = unwrap : Fin tile_size @@ -2524,9 +2524,9 @@ instance Subset(FullTileIx(n, tile_size, tile_ix), n) given (n|Ix, tile_size:Nat def project'(i) = todo def unsafe_project'(i) = todo --- A `CodaIx` type represents the last few elements of the index set `n`, --- as might be left over after iterating by tiles. --- This type is only well formed when size n == coda_offset + coda_size +# A `CodaIx` type represents the last few elements of the index set `n`, +# as might be left over after iterating by tiles. +# This type is only well formed when size n == coda_offset + coda_size struct CodaIx(n|Ix, coda_offset:Nat, coda_size:Nat) = unwrap : Fin coda_size @@ -2558,7 +2558,7 @@ def tiled_matmul( x: l=>m=>Float, y: m=>n=>Float ) -> l=>n=>Float given (l|Ix, m|Ix, n|Ix) = - -- Tile sizes picked for axch's laptop + # Tile sizes picked for axch's laptop l_tile_size : Nat = 32 n_tile_size : Nat = 128 m_tile_size : Nat = 8 @@ -2574,7 +2574,7 @@ def tiled_matmul( n_ix = inject(to=n, n_offset) result!l_ix!n_ix += x[l_ix][m_ix] * y[m_ix][n_ix] --- matmul. Better symbol to use? `@`? +# matmul. Better symbol to use? `@`? def (**)( x: l=>m=>Float, y: m=>n=>Float diff --git a/misc/dex.el b/misc/dex.el index bc56b2934..de892c45e 100644 --- a/misc/dex.el +++ b/misc/dex.el @@ -5,7 +5,7 @@ ;; https://developers.google.com/open-source/licenses/bsd (setq dex-highlights - `(("--\\([^o].*$\\|$\\)" . font-lock-comment-face) + `(("#.*$" . font-lock-comment-face) ("^> .*$" . font-lock-comment-face) ("^'\\(.\\|\n.\\)*\n\n" . font-lock-comment-face) ("\\w+:" . font-lock-comment-face) @@ -18,7 +18,6 @@ "\\bwith\\b\\|\\bself\\b\\|" "\\bimport\\b\\|\\bforeign\\b\\|\\bsatisfying\\b") . font-lock-keyword-face) - ("--o" . font-lock-variable-name-face) ("[-.,!;$^&*:~+/=<>|?\\\\]" . font-lock-variable-name-face) ("\\b[[:upper:]][[:alnum:]]*\\b" . font-lock-type-face) ("^@[[:alnum:]]*\\b" . font-lock-keyword-face) @@ -36,7 +35,7 @@ (define-derived-mode dex-mode fundamental-mode "dex" (setq font-lock-defaults '(dex-highlights)) - (setq-local comment-start "--") + (setq-local comment-start "#") (setq-local comment-end "") (setq-local syntax-propertize-function (syntax-propertize-rules (".>\\( +\\)" (1 ".")))) diff --git a/src/lib/Lexing.hs b/src/lib/Lexing.hs index 462902825..ec916f749 100644 --- a/src/lib/Lexing.hs +++ b/src/lib/Lexing.hs @@ -167,7 +167,7 @@ knownSymStrs :: HS.HashSet String knownSymStrs = HS.fromList [ ".", ":", "::", "!", "=", "-", "+", "||", "&&" , "$", "&>", "|", ",", ",>", "<-", "+=", ":=" - , "->", "->>", "=>", "?->", "?=>", "--o", "--", "<<<", ">>>" + , "->", "->>", "=>", "?->", "?=>", "<<<", ">>>" , "..", "<..", "..<", "..<", "<..<", "?", "#", "##", "#?", "#&", "#|", "@"] sym :: Text -> Lexer () @@ -228,9 +228,7 @@ sc = (skipSome s >> recordWhitespace) <|> return () where s = hidden space <|> hidden lineComment lineComment :: Parser () -lineComment = do - try $ string "--" >> notFollowedBy (void (char 'o')) - void (takeWhileP (Just "char") (/= '\n')) +lineComment = string "#" >> void (takeWhileP (Just "char") (/= '\n')) outputLines :: Parser () outputLines = void $ many (symbol ">" >> takeWhileP Nothing (/= '\n') >> ((eol >> return ()) <|> eof)) From 3129592023f6776d9a9f0cb44a8257c5a74c6faa Mon Sep 17 00:00:00 2001 From: Dougal Date: Tue, 5 Dec 2023 22:51:55 -0500 Subject: [PATCH 41/41] More source IDs during inference --- src/lib/Inference.hs | 81 ++++++++++++++++++++++---------------------- 1 file changed, 40 insertions(+), 41 deletions(-) diff --git a/src/lib/Inference.hs b/src/lib/Inference.hs index 8f6e24fb9..50c936f2c 100644 --- a/src/lib/Inference.hs +++ b/src/lib/Inference.hs @@ -47,9 +47,6 @@ import Types.Top import qualified Types.OpNames as P import Util hiding (group) -sidtodo :: SrcId -sidtodo = rootSrcId - -- === Top-level interface === checkTopUType :: (Fallible1 m, TopLogger m, EnvReader m) => UType n -> m n (CType n) @@ -277,26 +274,26 @@ withInferenceVar hint binding cont = diffStateT1 \s -> do {-# INLINE withInferenceVar #-} withFreshUnificationVar - :: (Zonkable e, Emits o) => InfVarDesc -> Kind CoreIR o + :: (Zonkable e, Emits o) => SrcId -> InfVarDesc -> Kind CoreIR o -> (forall o'. (Emits o', DExt o o') => CAtomVar o' -> SolverM i o' (e o')) -> SolverM i o (e o) -withFreshUnificationVar desc k cont = do +withFreshUnificationVar sid desc k cont = do withInferenceVar "_unif_" (InfVarBound k) \v -> do ans <- toAtomVar v >>= cont soln <- (M.lookup v <$> fromSolverSubst <$> getDiffState) >>= \case Just soln -> return soln - Nothing -> throw sidtodo $ AmbiguousInferenceVar (pprint v) (pprint k) desc + Nothing -> throw sid $ AmbiguousInferenceVar (pprint v) (pprint k) desc return (ans, soln) {-# INLINE withFreshUnificationVar #-} withFreshUnificationVarNoEmits - :: (Zonkable e) => InfVarDesc -> Kind CoreIR o + :: (Zonkable e) => SrcId -> InfVarDesc -> Kind CoreIR o -> (forall o'. (DExt o o') => CAtomVar o' -> SolverM i o' (e o')) -> SolverM i o (e o) -withFreshUnificationVarNoEmits desc k cont = diffStateT1 \s -> do +withFreshUnificationVarNoEmits sid desc k cont = diffStateT1 \s -> do Abs Empty resultAndDiff <- buildScoped do liftM toPairE $ runDiffStateT1 (sink s) $ - withFreshUnificationVar desc (sink k) cont + withFreshUnificationVar sid desc (sink k) cont return $ fromPairE resultAndDiff withFreshDictVar @@ -596,13 +593,13 @@ instantiateSigma sid reqTy sigmaAtom = case sigmaAtom of bsConstrained <- buildConstraints (Abs bs resultTy) \_ resultTy' -> do case reqTy of Infer -> return [] - Check reqTy' -> return [TypeConstraint sidtodo (sink reqTy') resultTy'] - args <- inferMixedArgs @UExpr fDesc expls bsConstrained ([], []) - applySigmaAtom sidtodo sigmaAtom args + Check reqTy' -> return [TypeConstraint sid (sink reqTy') resultTy'] + args <- inferMixedArgs @UExpr sid fDesc expls bsConstrained ([], []) + applySigmaAtom sid sigmaAtom args _ -> fallback _ -> fallback where - fallback = forceSigmaAtom sigmaAtom >>= matchReq sid reqTy + fallback = forceSigmaAtom sid sigmaAtom >>= matchReq sid reqTy fDesc = getSourceName sigmaAtom matchReq :: Ext o o' => SrcId -> RequiredTy o -> CAtom o' -> InfererM i o' (CAtom o') @@ -612,12 +609,12 @@ matchReq sid (Check reqTy) x = do matchReq _ Infer x = return x {-# INLINE matchReq #-} -forceSigmaAtom :: Emits o => SigmaAtom o -> InfererM i o (CAtom o) -forceSigmaAtom sigmaAtom = case sigmaAtom of +forceSigmaAtom :: Emits o => SrcId -> SigmaAtom o -> InfererM i o (CAtom o) +forceSigmaAtom sid sigmaAtom = case sigmaAtom of SigmaAtom _ x -> return x SigmaUVar _ _ v -> case v of UAtomVar v' -> inlineTypeAliases v' - _ -> applySigmaAtom sidtodo sigmaAtom [] + _ -> applySigmaAtom sid sigmaAtom [] SigmaPartialApp _ _ _ -> error "not implemented" -- better error message? withBlockDecls @@ -738,6 +735,7 @@ class PrettyE e => ExplicitArg (e::E) where checkExplicitDependentArg :: e i -> PartialType o -> InfererM i o (CAtom o) inferExplicitArg :: Emits o => e i -> InfererM i o (CAtom o) isHole :: e n -> Bool + explicitArgSrcId :: e n -> SrcId instance ExplicitArg UExpr where checkExplicitDependentArg arg argTy = checkSigmaDependent arg argTy @@ -746,18 +744,20 @@ instance ExplicitArg UExpr where isHole = \case WithSrcE _ UHole -> True _ -> False + explicitArgSrcId = getSrcId instance ExplicitArg CAtom where checkExplicitDependentArg = checkCAtom checkExplicitNonDependentArg = checkCAtom inferExplicitArg arg = renameM arg isHole _ = False + explicitArgSrcId _ = rootSrcId checkCAtom :: CAtom i -> PartialType o -> InfererM i o (CAtom o) checkCAtom arg argTy = do arg' <- renameM arg case argTy of - FullType argTy' -> expectEq sidtodo argTy' (getType arg') + FullType argTy' -> expectEq rootSrcId argTy' (getType arg') PartialType _ -> return () -- TODO? return arg' @@ -772,7 +772,7 @@ checkOrInferApp appSrcId funSrcId f' posArgs namedArgs reqTy = do ExplicitApp -> do checkExplicitArity appSrcId expls posArgs bsConstrained <- buildAppConstraints appSrcId reqTy piTy - args <- inferMixedArgs fDesc expls bsConstrained (posArgs, namedArgs) + args <- inferMixedArgs appSrcId fDesc expls bsConstrained (posArgs, namedArgs) applySigmaAtom appSrcId f args ImplicitApp -> error "should already have handled this case" ty -> throw funSrcId $ EliminationErr "function type" (pprint ty) @@ -922,12 +922,12 @@ buildConstraints ab cont = liftEnvReaderM do -- TODO: check that there are no extra named args provided inferMixedArgs :: forall arg i o . (Emits o, ExplicitArg arg) - => SourceName + => SrcId -> SourceName -> [Explicitness] -> ConstrainedBinders o -> MixedArgs (arg i) -> InfererM i o [CAtom o] -inferMixedArgs fSourceName explsTop (dependenceTop, bsAbs) argsTop@(_, namedArgsTop) = do - checkNamedArgValidity explsTop (map fst namedArgsTop) +inferMixedArgs appSrcId fSourceName explsTop (dependenceTop, bsAbs) argsTop@(_, namedArgsTop) = do + checkNamedArgValidity appSrcId explsTop (map fst namedArgsTop) liftSolverM $ fromListE <$> go explsTop dependenceTop bsAbs argsTop where go :: Emits oo @@ -973,7 +973,7 @@ inferMixedArgs fSourceName explsTop (dependenceTop, bsAbs) argsTop@(_, namedArgs if isHole arg then do let desc = (pprint fSourceName, "_") - withFreshUnificationVar (ImplicitArgInfVar desc) argTy \v -> + withFreshUnificationVar appSrcId (ImplicitArgInfVar desc) argTy \v -> cont (toAtom v) (argsRest, namedArgs) else do arg' <- checkOrInferExplicitArg isDependent arg argTy @@ -985,8 +985,8 @@ inferMixedArgs fSourceName explsTop (dependenceTop, bsAbs) argsTop@(_, namedArgs arg' <- checkOrInferExplicitArg isDependent arg argTy withDistinct $ cont arg' args Nothing -> case infMech of - Unify -> withFreshUnificationVar (ImplicitArgInfVar desc) argTy \v -> cont (toAtom v) args - Synth _ -> withDict sidtodo argTy \d -> cont d args + Unify -> withFreshUnificationVar appSrcId (ImplicitArgInfVar desc) argTy \v -> cont (toAtom v) args + Synth _ -> withDict appSrcId argTy \d -> cont d args checkOrInferExplicitArg :: Emits oo => Bool -> arg i -> CType oo -> SolverM i oo (CAtom oo) checkOrInferExplicitArg isDependent arg argTy = do @@ -995,7 +995,7 @@ inferMixedArgs fSourceName explsTop (dependenceTop, bsAbs) argsTop@(_, namedArgs True -> checkExplicitDependentArg arg partialTy False -> checkExplicitNonDependentArg arg partialTy Nothing -> inferExplicitArg arg - constrainEq sidtodo argTy (getType arg') + constrainEq (explicitArgSrcId arg) argTy (getType arg') return arg' lookupNamedArg :: MixedArgs x -> Maybe SourceName -> Maybe x @@ -1015,18 +1015,17 @@ inferMixedArgs fSourceName explsTop (dependenceTop, bsAbs) argsTop@(_, namedArgs True -> return Nothing False -> return $ Just x -checkNamedArgValidity :: Fallible m => [Explicitness] -> [SourceName] -> m () -checkNamedArgValidity expls offeredNames = do +checkNamedArgValidity :: Fallible m => SrcId -> [Explicitness] -> [SourceName] -> m () +checkNamedArgValidity sid expls offeredNames = do let explToMaybeName = \case Explicit -> Nothing Inferred v _ -> v let acceptedNames = catMaybes $ map explToMaybeName expls let duplicates = repeated offeredNames - -- here and below we should be able to get a per-name src id - when (not $ null duplicates) $ throw sidtodo $ RepeatedOptionalArgs $ map pprint duplicates + when (not $ null duplicates) $ throw sid $ RepeatedOptionalArgs $ map pprint duplicates let unrecognizedNames = filter (not . (`elem` acceptedNames)) offeredNames when (not $ null unrecognizedNames) do - throw sidtodo $ UnrecognizedOptionalArgs (map pprint unrecognizedNames) (map pprint acceptedNames) + throw sid $ UnrecognizedOptionalArgs (map pprint unrecognizedNames) (map pprint acceptedNames) inferPrimArg :: Emits o => UExpr i -> InfererM i o (CAtom o) inferPrimArg x = do @@ -1477,7 +1476,7 @@ checkInstanceBody className params methods = do superclassDictTys :: Nest CBinder o o' -> InfererM i o [CType o] superclassDictTys Empty = return [] superclassDictTys (Nest b bs) = do - Abs bs' UnitE <- liftHoistExcept sidtodo $ hoist b $ Abs bs UnitE + Abs bs' UnitE <- liftHoistExcept rootSrcId $ hoist b $ Abs bs UnitE (binderType b:) <$> superclassDictTys bs' checkMethodDef :: ClassName o -> [CorePiType o] -> UMethodDef i -> InfererM i o (Int, CAtom o) @@ -1534,7 +1533,7 @@ checkCasePat (WithSrcB sid pat) scrutineeTy cont = case pat of UPatCon ~(InternalName _ _ conName) ps -> do (dataDefName, con) <- renameM conName >>= lookupDataCon tyConDef <- lookupTyCon dataDefName - params <- inferParams scrutineeTy dataDefName + params <- inferParams sid scrutineeTy dataDefName ADTCons cons <- instantiateTyConDef tyConDef params DataConDef _ _ repTy idxs <- return $ cons !! con when (length idxs /= nestLength ps) $ throw sid $ PatternArityErr (length idxs) (nestLength ps) @@ -1545,8 +1544,8 @@ checkCasePat (WithSrcB sid pat) scrutineeTy cont = case pat of bindLetPats ps args $ cont _ -> throw sid IllFormedCasePattern -inferParams :: Emits o => CType o -> TyConName o -> InfererM i o (TyConParams o) -inferParams ty dataDefName = do +inferParams :: Emits o => SrcId -> CType o -> TyConName o -> InfererM i o (TyConParams o) +inferParams sid ty dataDefName = do TyConDef sourceName roleExpls paramBs _ <- lookupTyCon dataDefName let paramExpls = snd <$> roleExpls let inferenceExpls = paramExpls <&> \case @@ -1554,8 +1553,8 @@ inferParams ty dataDefName = do expl -> expl paramBsAbs <- buildConstraints (Abs paramBs UnitE) \params _ -> do let ty' = toType $ UserADTType sourceName (sink dataDefName) $ TyConParams paramExpls params - return [TypeConstraint sidtodo (sink ty) ty'] - args <- inferMixedArgs sourceName inferenceExpls paramBsAbs emptyMixedArgs + return [TypeConstraint sid (sink ty) ty'] + args <- inferMixedArgs sid sourceName inferenceExpls paramBsAbs emptyMixedArgs return $ TyConParams paramExpls args bindLetPats @@ -1599,7 +1598,7 @@ bindLetPat (WithSrcB sid pat) v cont = case pat of ADTCons [DataConDef _ _ _ idxss] -> do when (length idxss /= nestLength ps) $ throw sid $ PatternArityErr (length idxss) (nestLength ps) - void $ inferParams (getType $ toAtom v) dataDefName + void $ inferParams sid (getType $ toAtom v) dataDefName xs <- forM idxss \idxs -> applyProjectionsReduced idxs (toAtom v) >>= emitInline bindLetPats ps xs cont _ -> throw sid SumTypeCantFail @@ -1887,7 +1886,7 @@ withFreshEff => (forall o'. DExt o o' => EffectRow CoreIR o' -> SolverM i o' (e o')) -> SolverM i o (e o) withFreshEff cont = - withFreshUnificationVarNoEmits MiscInfVar EffKind \v -> do + withFreshUnificationVarNoEmits rootSrcId MiscInfVar EffKind \v -> do cont $ EffectRow mempty $ EffectRowTail v {-# INLINE withFreshEff #-} @@ -2069,12 +2068,12 @@ generalizeInstanceArg role ty arg cont = case role of -- that it's valid to implement `generalizeDict` by synthesizing an entirely -- fresh dictionary, and if we were to do that, we would infer this type -- parameter exactly as we do here, using inference. - TypeParam -> withFreshUnificationVarNoEmits MiscInfVar TyKind \v -> cont $ toAtom v + TypeParam -> withFreshUnificationVarNoEmits rootSrcId MiscInfVar TyKind \v -> cont $ toAtom v DictParam -> withFreshDictVarNoEmits ty ( \ty' -> case toMaybeDict (sink arg) of Just d -> liftM toAtom $ lift11 $ generalizeDictRec ty' d _ -> error "not a dict") cont - DataParam -> withFreshUnificationVarNoEmits MiscInfVar ty \v -> cont $ toAtom v + DataParam -> withFreshUnificationVarNoEmits rootSrcId MiscInfVar ty \v -> cont $ toAtom v emitInstanceDef :: (Mut n, TopBuilder m) => InstanceDef n -> m n (Name InstanceNameC n) emitInstanceDef instanceDef@(InstanceDef className _ _ _ _) = do @@ -2252,7 +2251,7 @@ instantiateSynthArgs sid target (expls, synthPiTy) = do liftM fromListE $ withReducibleEmissions sid CantReduceDict do bsConstrained <- buildConstraints (sink synthPiTy) \_ resultTy -> do return [TypeConstraint sid (TyCon $ DictTy $ sink target) (TyCon $ DictTy resultTy)] - ListE <$> inferMixedArgs "dict" expls bsConstrained emptyMixedArgs + ListE <$> inferMixedArgs sid "dict" expls bsConstrained emptyMixedArgs emptyMixedArgs :: MixedArgs (CAtom n) emptyMixedArgs = ([], [])