From d1716914ce5de8c42244bb1367fed04f51249550 Mon Sep 17 00:00:00 2001 From: Sebastian Ullrich Date: Wed, 29 Jan 2025 13:39:43 +0100 Subject: [PATCH] feat: async modes for environment access --- src/Lean/Environment.lean | 98 +++++++++++++++++++++++++------- src/Lean/Meta/LazyDiscrTree.lean | 4 +- src/Lean/Parser/Basic.lean | 3 +- 3 files changed, 81 insertions(+), 24 deletions(-) diff --git a/src/Lean/Environment.lean b/src/Lean/Environment.lean index 6350662db662..1436493833d2 100644 --- a/src/Lean/Environment.lean +++ b/src/Lean/Environment.lean @@ -404,10 +404,10 @@ def AsyncConsts.add (aconsts : AsyncConsts) (aconst : AsyncConst) : AsyncConsts def AsyncConsts.find? (aconsts : AsyncConsts) (declName : Name) : Option AsyncConst := aconsts.map.find? declName -/-- Checks whether the name of any constant in the collection is a prefix of `declName`. -/ -def AsyncConsts.hasPrefix (aconsts : AsyncConsts) (declName : Name) : Bool := +/-- Finds the constant in the collection that is a prefix of `declName`, if any. -/ +def AsyncConsts.findPrefix? (aconsts : AsyncConsts) (declName : Name) : Option AsyncConst := -- as macro scopes are a strict suffix, - aconsts.normalizedTrie.findLongestPrefix? (privateToUserName declName.eraseMacroScopes) |>.isSome + aconsts.normalizedTrie.findLongestPrefix? (privateToUserName declName.eraseMacroScopes) /-- Elaboration-specific extension of `Kernel.Environment` that adds tracking of asynchronously @@ -463,6 +463,9 @@ private def modifyCheckedAsync (env : Environment) (f : Kernel.Environment → K private def setCheckedSync (env : Environment) (newChecked : Kernel.Environment) : Environment := { env with checked := .pure newChecked, checkedWithoutAsync := newChecked } +def asyncMayContain (env : Environment) (declName : Name) : Bool := + env.asyncCtx?.all (·.mayContain declName) + @[extern "lean_elab_add_decl"] private opaque addDeclCheck (env : Environment) (maxHeartbeats : USize) (decl : @& Declaration) (cancelTk? : @& Option IO.CancelToken) : Except Kernel.Exception Environment @@ -515,7 +518,7 @@ def addExtraName (env : Environment) (name : Name) : Environment := /-- Find base case: name did not match any asynchronous declaration. -/ private def findNoAsync (env : Environment) (n : Name) : Option ConstantInfo := do - if env.asyncConsts.hasPrefix n then + if let some _ := env.asyncConsts.findPrefix? n then -- Constant generated in a different environment branch: wait for final kernel environment. Rare -- case when only proofs are elaborated asynchronously as they are rarely inspected. Could be -- optimized in the future by having the elaboration thread publish an (incremental?) map of @@ -756,6 +759,30 @@ def instantiateValueLevelParams! (c : ConstantInfo) (ls : List Level) : Expr := end ConstantInfo +/-- +Async access mode for environment extensions used in `EnvironmentExtension.get/set/modifyState`. +Depending on their specific uses, extensions may opt out of the strict `sync` access mode in order +to avoid blocking parallel elaboration and/or to optimize accesses. The access mode is set at +environment extension registration time but can be overriden at `EnvironmentExtension.getState` in +order to weaken it for specific accesses. + +In all modes, the state stored into the `.olean` file for persistent environment extensions is the +result of `getState` called on the main environment branch at the end of the file, i.e. it +encompasses all modifications for all modes but `local`. +-/ +inductive EnvExtension.AsyncMode where + /-- + Default access mode, writing and reading the extension state to/from the `checked` environment + branch. This mode ensures the observed state is identical independently of whether or how parallel + elaboration is used but `getState` will block on all prior environment branches by waiting for + `checked`. `setState` and `modifyState` do not block. + -/ + | sync + | async + | local + | mainOnly + deriving Inhabited + /-- Environment extension, can only be generated by `registerEnvExtension` that allocates a unique index for this extension into each environment's extension state's array. @@ -763,6 +790,7 @@ for this extension into each environment's extension state's array. structure EnvExtension (σ : Type) where private mk :: idx : Nat mkInitial : IO σ + asyncMode : EnvExtension.AsyncMode deriving Inhabited namespace EnvExtension @@ -817,23 +845,36 @@ def mkInitialExtStates : IO (Array EnvExtensionState) := do let exts ← envExtensionsRef.get exts.mapM fun ext => ext.mkInitial --- TODO: store extension state in `checked` - -def setState {σ : Type} (ext : EnvExtension σ) (env : Environment) (s : σ) : Environment := - -- safety: `ext`'s constructor is private, so we can assume the entry at `ext.idx` is of type `σ` - { env with checkedWithoutAsync.extensions := unsafe ext.setStateImpl env.checkedWithoutAsync.extensions s } - def modifyState {σ : Type} (ext : EnvExtension σ) (env : Environment) (f : σ → σ) : Environment := -- safety: `ext`'s constructor is private, so we can assume the entry at `ext.idx` is of type `σ` - { env with checkedWithoutAsync.extensions := unsafe ext.modifyStateImpl env.checkedWithoutAsync.extensions f } + match ext.asyncMode with + | .mainOnly => + if let some asyncCtx := env.asyncCtx? then + let _ : Inhabited Environment := ⟨env⟩ + panic! s!"Environment.modifyState: environment extension is marked as `mainOnly` but used in \ + async context '{asyncCtx.declPrefix}'" + else + { env with checkedWithoutAsync.extensions := unsafe ext.modifyStateImpl env.checkedWithoutAsync.extensions f } + | .local => + { env with checkedWithoutAsync.extensions := unsafe ext.modifyStateImpl env.checkedWithoutAsync.extensions f } + | _ => + env.modifyCheckedAsync fun env => + { env with extensions := unsafe ext.modifyStateImpl env.extensions f } + +def setState {σ : Type} (ext : EnvExtension σ) (env : Environment) (s : σ) : Environment := + inline <| modifyState ext env fun _ => s -- `unsafe` fails to infer `Nonempty` here -unsafe def getStateUnsafe {σ : Type} [Inhabited σ] (ext : EnvExtension σ) (env : Environment) : σ := +private unsafe def getStateUnsafe {σ : Type} [Inhabited σ] (ext : EnvExtension σ) + (env : Environment) (asyncMode := ext.asyncMode) : σ := -- safety: `ext`'s constructor is private, so we can assume the entry at `ext.idx` is of type `σ` - ext.getStateImpl env.checkedWithoutAsync.extensions + match asyncMode with + | .sync => ext.getStateImpl env.checked.get.extensions + | _ => ext.getStateImpl env.checkedWithoutAsync.extensions @[implemented_by getStateUnsafe] -opaque getState {σ : Type} [Inhabited σ] (ext : EnvExtension σ) (env : Environment) : σ +opaque getState {σ : Type} [Inhabited σ] (ext : EnvExtension σ) (env : Environment) + (asyncMode := ext.asyncMode) : σ end EnvExtension @@ -844,15 +885,13 @@ end EnvExtension Note that by default, extension state is *not* stored in .olean files and will not propagate across `import`s. For that, you need to register a persistent environment extension. -/ -def registerEnvExtension {σ : Type} (mkInitial : IO σ) : IO (EnvExtension σ) := do +def registerEnvExtension {σ : Type} (mkInitial : IO σ) + (asyncMode : EnvExtension.AsyncMode := .mainOnly) : IO (EnvExtension σ) := do unless (← initializing) do throw (IO.userError "failed to register environment, extensions can only be registered during initialization") let exts ← EnvExtension.envExtensionsRef.get let idx := exts.size - let ext : EnvExtension σ := { - idx := idx, - mkInitial := mkInitial, - } + let ext : EnvExtension σ := { idx, mkInitial, asyncMode } -- safety: `EnvExtensionState` is opaque, so we can upcast to it EnvExtension.envExtensionsRef.modify fun exts => exts.push (unsafe unsafeCast ext) pure ext @@ -953,7 +992,8 @@ instance {α β σ} [Inhabited σ] : Inhabited (PersistentEnvExtension α β σ) namespace PersistentEnvExtension def getModuleEntries {α β σ : Type} [Inhabited σ] (ext : PersistentEnvExtension α β σ) (env : Environment) (m : ModuleIdx) : Array α := - (ext.toEnvExtension.getState env).importedEntries.get! m + -- `importedEntries` is identical on all environment branches, so `local` is sufficient + (ext.toEnvExtension.getState (asyncMode := .local) env).importedEntries.get! m def addEntry {α β σ : Type} (ext : PersistentEnvExtension α β σ) (env : Environment) (b : β) : Environment := ext.toEnvExtension.modifyState env fun s => @@ -972,6 +1012,19 @@ def setState {α β σ : Type} (ext : PersistentEnvExtension α β σ) (env : En def modifyState {α β σ : Type} (ext : PersistentEnvExtension α β σ) (env : Environment) (f : σ → σ) : Environment := ext.toEnvExtension.modifyState env fun ps => { ps with state := f (ps.state) } +-- `unsafe` fails to infer `Nonempty` here +private unsafe def findStateAsyncUnsafe {α β σ : Type} [Inhabited σ] + (ext : PersistentEnvExtension α β σ) (env : Environment) (declName : Name) : σ := + -- safety: `ext`'s constructor is private, so we can assume the entry at `ext.idx` is of type `σ` + if let some { exts? := some exts, .. } := env.asyncConsts.findPrefix? declName then + ext.toEnvExtension.getStateImpl exts.get |>.state + else + ext.toEnvExtension.getStateImpl env.checkedWithoutAsync.extensions |>.state + +@[implemented_by findStateAsyncUnsafe] +opaque findStateAsync {α β σ : Type} [Inhabited σ] (ext : PersistentEnvExtension α β σ) + (env : Environment) (declName : Name) : σ + end PersistentEnvExtension builtin_initialize persistentEnvExtensionsRef : IO.Ref (Array (PersistentEnvExtension EnvExtensionEntry EnvExtensionEntry EnvExtensionState)) ← IO.mkRef #[] @@ -983,11 +1036,12 @@ structure PersistentEnvExtensionDescr (α β σ : Type) where addEntryFn : σ → β → σ exportEntriesFn : σ → Array α statsFn : σ → Format := fun _ => Format.nil + asyncMode : EnvExtension.AsyncMode := .mainOnly unsafe def registerPersistentEnvExtensionUnsafe {α β σ : Type} [Inhabited σ] (descr : PersistentEnvExtensionDescr α β σ) : IO (PersistentEnvExtension α β σ) := do let pExts ← persistentEnvExtensionsRef.get if pExts.any (fun ext => ext.name == descr.name) then throw (IO.userError s!"invalid environment extension, '{descr.name}' has already been used") - let ext ← registerEnvExtension do + let ext ← registerEnvExtension (asyncMode := descr.asyncMode) do let initial ← descr.mkInitial let s : PersistentEnvExtensionState α σ := { importedEntries := #[], @@ -1019,6 +1073,7 @@ structure SimplePersistentEnvExtensionDescr (α σ : Type) where addEntryFn : σ → α → σ addImportedFn : Array (Array α) → σ toArrayFn : List α → Array α := fun es => es.toArray + asyncMode : EnvExtension.AsyncMode := .mainOnly def registerSimplePersistentEnvExtension {α σ : Type} [Inhabited σ] (descr : SimplePersistentEnvExtensionDescr α σ) : IO (SimplePersistentEnvExtension α σ) := registerPersistentEnvExtension { @@ -1029,6 +1084,7 @@ def registerSimplePersistentEnvExtension {α σ : Type} [Inhabited σ] (descr : | (entries, s) => (e::entries, descr.addEntryFn s e), exportEntriesFn := fun s => descr.toArrayFn s.1.reverse, statsFn := fun s => format "number of local entries: " ++ format s.1.length + asyncMode := descr.asyncMode } namespace SimplePersistentEnvExtension diff --git a/src/Lean/Meta/LazyDiscrTree.lean b/src/Lean/Meta/LazyDiscrTree.lean index abb7f612efa9..33c39f7801c7 100644 --- a/src/Lean/Meta/LazyDiscrTree.lean +++ b/src/Lean/Meta/LazyDiscrTree.lean @@ -980,8 +980,8 @@ def findImportMatches let ngen ← getNGen let (cNGen, ngen) := ngen.mkChild setNGen ngen - let dummy : IO.Ref (Option (LazyDiscrTree α)) ← IO.mkRef none - let ref := @EnvExtension.getState _ ⟨dummy⟩ ext (←getEnv) + let _ : Inhabited (IO.Ref (Option (LazyDiscrTree α))) := ⟨← IO.mkRef none⟩ + let ref := ext.getState (←getEnv) let importTree ← (←ref.get).getDM $ do profileitM Exception "lazy discriminator import initialization" (←getOptions) $ do let t ← createImportedDiscrTree (createTreeCtx cctx) cNGen (←getEnv) addEntry diff --git a/src/Lean/Parser/Basic.lean b/src/Lean/Parser/Basic.lean index 857b47f40396..a55cf4dcbaa8 100644 --- a/src/Lean/Parser/Basic.lean +++ b/src/Lean/Parser/Basic.lean @@ -1705,7 +1705,8 @@ builtin_initialize categoryParserFnRef : IO.Ref CategoryParserFn ← IO.mkRef fu builtin_initialize categoryParserFnExtension : EnvExtension CategoryParserFn ← registerEnvExtension $ categoryParserFnRef.get def categoryParserFn (catName : Name) : ParserFn := fun ctx s => - categoryParserFnExtension.getState ctx.env catName ctx s + let fn := categoryParserFnExtension.getState ctx.env + fn catName ctx s def categoryParser (catName : Name) (prec : Nat) : Parser where fn := adaptCacheableContextFn ({ · with prec }) (withCacheFn catName (categoryParserFn catName))