Skip to content

Commit

Permalink
feat: async modes for environment access
Browse files Browse the repository at this point in the history
  • Loading branch information
Kha committed Jan 31, 2025
1 parent 6865329 commit d171691
Show file tree
Hide file tree
Showing 3 changed files with 81 additions and 24 deletions.
98 changes: 77 additions & 21 deletions src/Lean/Environment.lean
Original file line number Diff line number Diff line change
Expand Up @@ -404,10 +404,10 @@ def AsyncConsts.add (aconsts : AsyncConsts) (aconst : AsyncConst) : AsyncConsts
def AsyncConsts.find? (aconsts : AsyncConsts) (declName : Name) : Option AsyncConst :=
aconsts.map.find? declName

/-- Checks whether the name of any constant in the collection is a prefix of `declName`. -/
def AsyncConsts.hasPrefix (aconsts : AsyncConsts) (declName : Name) : Bool :=
/-- Finds the constant in the collection that is a prefix of `declName`, if any. -/
def AsyncConsts.findPrefix? (aconsts : AsyncConsts) (declName : Name) : Option AsyncConst :=
-- as macro scopes are a strict suffix,
aconsts.normalizedTrie.findLongestPrefix? (privateToUserName declName.eraseMacroScopes) |>.isSome
aconsts.normalizedTrie.findLongestPrefix? (privateToUserName declName.eraseMacroScopes)

/--
Elaboration-specific extension of `Kernel.Environment` that adds tracking of asynchronously
Expand Down Expand Up @@ -463,6 +463,9 @@ private def modifyCheckedAsync (env : Environment) (f : Kernel.Environment → K
private def setCheckedSync (env : Environment) (newChecked : Kernel.Environment) : Environment :=
{ env with checked := .pure newChecked, checkedWithoutAsync := newChecked }

def asyncMayContain (env : Environment) (declName : Name) : Bool :=
env.asyncCtx?.all (·.mayContain declName)

@[extern "lean_elab_add_decl"]
private opaque addDeclCheck (env : Environment) (maxHeartbeats : USize) (decl : @& Declaration)
(cancelTk? : @& Option IO.CancelToken) : Except Kernel.Exception Environment
Expand Down Expand Up @@ -515,7 +518,7 @@ def addExtraName (env : Environment) (name : Name) : Environment :=

/-- Find base case: name did not match any asynchronous declaration. -/
private def findNoAsync (env : Environment) (n : Name) : Option ConstantInfo := do
if env.asyncConsts.hasPrefix n then
if let some _ := env.asyncConsts.findPrefix? n then
-- Constant generated in a different environment branch: wait for final kernel environment. Rare
-- case when only proofs are elaborated asynchronously as they are rarely inspected. Could be
-- optimized in the future by having the elaboration thread publish an (incremental?) map of
Expand Down Expand Up @@ -756,13 +759,38 @@ def instantiateValueLevelParams! (c : ConstantInfo) (ls : List Level) : Expr :=

end ConstantInfo

/--
Async access mode for environment extensions used in `EnvironmentExtension.get/set/modifyState`.
Depending on their specific uses, extensions may opt out of the strict `sync` access mode in order
to avoid blocking parallel elaboration and/or to optimize accesses. The access mode is set at
environment extension registration time but can be overriden at `EnvironmentExtension.getState` in
order to weaken it for specific accesses.
In all modes, the state stored into the `.olean` file for persistent environment extensions is the
result of `getState` called on the main environment branch at the end of the file, i.e. it
encompasses all modifications for all modes but `local`.
-/
inductive EnvExtension.AsyncMode where
/--
Default access mode, writing and reading the extension state to/from the `checked` environment
branch. This mode ensures the observed state is identical independently of whether or how parallel
elaboration is used but `getState` will block on all prior environment branches by waiting for
`checked`. `setState` and `modifyState` do not block.
-/
| sync
| async
| local
| mainOnly
deriving Inhabited

/--
Environment extension, can only be generated by `registerEnvExtension` that allocates a unique index
for this extension into each environment's extension state's array.
-/
structure EnvExtension (σ : Type) where private mk ::
idx : Nat
mkInitial : IO σ
asyncMode : EnvExtension.AsyncMode
deriving Inhabited

namespace EnvExtension
Expand Down Expand Up @@ -817,23 +845,36 @@ def mkInitialExtStates : IO (Array EnvExtensionState) := do
let exts ← envExtensionsRef.get
exts.mapM fun ext => ext.mkInitial

-- TODO: store extension state in `checked`

def setState {σ : Type} (ext : EnvExtension σ) (env : Environment) (s : σ) : Environment :=
-- safety: `ext`'s constructor is private, so we can assume the entry at `ext.idx` is of type `σ`
{ env with checkedWithoutAsync.extensions := unsafe ext.setStateImpl env.checkedWithoutAsync.extensions s }

def modifyState {σ : Type} (ext : EnvExtension σ) (env : Environment) (f : σ → σ) : Environment :=
-- safety: `ext`'s constructor is private, so we can assume the entry at `ext.idx` is of type `σ`
{ env with checkedWithoutAsync.extensions := unsafe ext.modifyStateImpl env.checkedWithoutAsync.extensions f }
match ext.asyncMode with
| .mainOnly =>
if let some asyncCtx := env.asyncCtx? then
let _ : Inhabited Environment := ⟨env⟩
panic! s!"Environment.modifyState: environment extension is marked as `mainOnly` but used in \
async context '{asyncCtx.declPrefix}'"
else
{ env with checkedWithoutAsync.extensions := unsafe ext.modifyStateImpl env.checkedWithoutAsync.extensions f }
| .local =>
{ env with checkedWithoutAsync.extensions := unsafe ext.modifyStateImpl env.checkedWithoutAsync.extensions f }
| _ =>
env.modifyCheckedAsync fun env =>
{ env with extensions := unsafe ext.modifyStateImpl env.extensions f }

def setState {σ : Type} (ext : EnvExtension σ) (env : Environment) (s : σ) : Environment :=
inline <| modifyState ext env fun _ => s

-- `unsafe` fails to infer `Nonempty` here
unsafe def getStateUnsafe {σ : Type} [Inhabited σ] (ext : EnvExtension σ) (env : Environment) : σ :=
private unsafe def getStateUnsafe {σ : Type} [Inhabited σ] (ext : EnvExtension σ)
(env : Environment) (asyncMode := ext.asyncMode) : σ :=
-- safety: `ext`'s constructor is private, so we can assume the entry at `ext.idx` is of type `σ`
ext.getStateImpl env.checkedWithoutAsync.extensions
match asyncMode with
| .sync => ext.getStateImpl env.checked.get.extensions
| _ => ext.getStateImpl env.checkedWithoutAsync.extensions

@[implemented_by getStateUnsafe]
opaque getState {σ : Type} [Inhabited σ] (ext : EnvExtension σ) (env : Environment) : σ
opaque getState {σ : Type} [Inhabited σ] (ext : EnvExtension σ) (env : Environment)
(asyncMode := ext.asyncMode) : σ

end EnvExtension

Expand All @@ -844,15 +885,13 @@ end EnvExtension
Note that by default, extension state is *not* stored in .olean files and will not propagate across `import`s.
For that, you need to register a persistent environment extension. -/
def registerEnvExtension {σ : Type} (mkInitial : IO σ) : IO (EnvExtension σ) := do
def registerEnvExtension {σ : Type} (mkInitial : IO σ)
(asyncMode : EnvExtension.AsyncMode := .mainOnly) : IO (EnvExtension σ) := do
unless (← initializing) do
throw (IO.userError "failed to register environment, extensions can only be registered during initialization")
let exts ← EnvExtension.envExtensionsRef.get
let idx := exts.size
let ext : EnvExtension σ := {
idx := idx,
mkInitial := mkInitial,
}
let ext : EnvExtension σ := { idx, mkInitial, asyncMode }
-- safety: `EnvExtensionState` is opaque, so we can upcast to it
EnvExtension.envExtensionsRef.modify fun exts => exts.push (unsafe unsafeCast ext)
pure ext
Expand Down Expand Up @@ -953,7 +992,8 @@ instance {α β σ} [Inhabited σ] : Inhabited (PersistentEnvExtension α β σ)
namespace PersistentEnvExtension

def getModuleEntries {α β σ : Type} [Inhabited σ] (ext : PersistentEnvExtension α β σ) (env : Environment) (m : ModuleIdx) : Array α :=
(ext.toEnvExtension.getState env).importedEntries.get! m
-- `importedEntries` is identical on all environment branches, so `local` is sufficient
(ext.toEnvExtension.getState (asyncMode := .local) env).importedEntries.get! m

def addEntry {α β σ : Type} (ext : PersistentEnvExtension α β σ) (env : Environment) (b : β) : Environment :=
ext.toEnvExtension.modifyState env fun s =>
Expand All @@ -972,6 +1012,19 @@ def setState {α β σ : Type} (ext : PersistentEnvExtension α β σ) (env : En
def modifyState {α β σ : Type} (ext : PersistentEnvExtension α β σ) (env : Environment) (f : σ → σ) : Environment :=
ext.toEnvExtension.modifyState env fun ps => { ps with state := f (ps.state) }

-- `unsafe` fails to infer `Nonempty` here
private unsafe def findStateAsyncUnsafe {α β σ : Type} [Inhabited σ]
(ext : PersistentEnvExtension α β σ) (env : Environment) (declName : Name) : σ :=
-- safety: `ext`'s constructor is private, so we can assume the entry at `ext.idx` is of type `σ`
if let some { exts? := some exts, .. } := env.asyncConsts.findPrefix? declName then
ext.toEnvExtension.getStateImpl exts.get |>.state
else
ext.toEnvExtension.getStateImpl env.checkedWithoutAsync.extensions |>.state

@[implemented_by findStateAsyncUnsafe]
opaque findStateAsync {α β σ : Type} [Inhabited σ] (ext : PersistentEnvExtension α β σ)
(env : Environment) (declName : Name) : σ

end PersistentEnvExtension

builtin_initialize persistentEnvExtensionsRef : IO.Ref (Array (PersistentEnvExtension EnvExtensionEntry EnvExtensionEntry EnvExtensionState)) ← IO.mkRef #[]
Expand All @@ -983,11 +1036,12 @@ structure PersistentEnvExtensionDescr (α β σ : Type) where
addEntryFn : σ → β → σ
exportEntriesFn : σ → Array α
statsFn : σ → Format := fun _ => Format.nil
asyncMode : EnvExtension.AsyncMode := .mainOnly

unsafe def registerPersistentEnvExtensionUnsafe {α β σ : Type} [Inhabited σ] (descr : PersistentEnvExtensionDescr α β σ) : IO (PersistentEnvExtension α β σ) := do
let pExts ← persistentEnvExtensionsRef.get
if pExts.any (fun ext => ext.name == descr.name) then throw (IO.userError s!"invalid environment extension, '{descr.name}' has already been used")
let ext ← registerEnvExtension do
let ext ← registerEnvExtension (asyncMode := descr.asyncMode) do
let initial ← descr.mkInitial
let s : PersistentEnvExtensionState α σ := {
importedEntries := #[],
Expand Down Expand Up @@ -1019,6 +1073,7 @@ structure SimplePersistentEnvExtensionDescr (α σ : Type) where
addEntryFn : σ → α → σ
addImportedFn : Array (Array α) → σ
toArrayFn : List α → Array α := fun es => es.toArray
asyncMode : EnvExtension.AsyncMode := .mainOnly

def registerSimplePersistentEnvExtension {α σ : Type} [Inhabited σ] (descr : SimplePersistentEnvExtensionDescr α σ) : IO (SimplePersistentEnvExtension α σ) :=
registerPersistentEnvExtension {
Expand All @@ -1029,6 +1084,7 @@ def registerSimplePersistentEnvExtension {α σ : Type} [Inhabited σ] (descr :
| (entries, s) => (e::entries, descr.addEntryFn s e),
exportEntriesFn := fun s => descr.toArrayFn s.1.reverse,
statsFn := fun s => format "number of local entries: " ++ format s.1.length
asyncMode := descr.asyncMode
}

namespace SimplePersistentEnvExtension
Expand Down
4 changes: 2 additions & 2 deletions src/Lean/Meta/LazyDiscrTree.lean
Original file line number Diff line number Diff line change
Expand Up @@ -980,8 +980,8 @@ def findImportMatches
let ngen ← getNGen
let (cNGen, ngen) := ngen.mkChild
setNGen ngen
let dummy : IO.Ref (Option (LazyDiscrTree α)) ← IO.mkRef none
let ref := @EnvExtension.getState _ ⟨dummy⟩ ext (←getEnv)
let _ : Inhabited (IO.Ref (Option (LazyDiscrTree α))) := ⟨← IO.mkRef none
let ref := ext.getState (←getEnv)
let importTree ← (←ref.get).getDM $ do
profileitM Exception "lazy discriminator import initialization" (←getOptions) $ do
let t ← createImportedDiscrTree (createTreeCtx cctx) cNGen (←getEnv) addEntry
Expand Down
3 changes: 2 additions & 1 deletion src/Lean/Parser/Basic.lean
Original file line number Diff line number Diff line change
Expand Up @@ -1705,7 +1705,8 @@ builtin_initialize categoryParserFnRef : IO.Ref CategoryParserFn ← IO.mkRef fu
builtin_initialize categoryParserFnExtension : EnvExtension CategoryParserFn ← registerEnvExtension $ categoryParserFnRef.get

def categoryParserFn (catName : Name) : ParserFn := fun ctx s =>
categoryParserFnExtension.getState ctx.env catName ctx s
let fn := categoryParserFnExtension.getState ctx.env
fn catName ctx s

def categoryParser (catName : Name) (prec : Nat) : Parser where
fn := adaptCacheableContextFn ({ · with prec }) (withCacheFn catName (categoryParserFn catName))
Expand Down

0 comments on commit d171691

Please sign in to comment.