feat: async modes for environment access #6852

Merged
merged 1 commit into from
Jan 31, 2025
167 changes: 147 additions & 20 deletions src/Lean/Environment.lean
@@ -404,10 +404,10 @@ def AsyncConsts.add (aconsts : AsyncConsts) (aconst : AsyncConst) : AsyncConsts
def AsyncConsts.find? (aconsts : AsyncConsts) (declName : Name) : Option AsyncConst :=
aconsts.map.find? declName

/-- Checks whether the name of any constant in the collection is a prefix of `declName`. -/
def AsyncConsts.hasPrefix (aconsts : AsyncConsts) (declName : Name) : Bool :=
/-- Finds the constant in the collection that is a prefix of `declName`, if any. -/
def AsyncConsts.findPrefix? (aconsts : AsyncConsts) (declName : Name) : Option AsyncConst :=
-- as macro scopes are a strict suffix,
aconsts.normalizedTrie.findLongestPrefix? (privateToUserName declName.eraseMacroScopes) |>.isSome
aconsts.normalizedTrie.findLongestPrefix? (privateToUserName declName.eraseMacroScopes)

/--
Elaboration-specific extension of `Kernel.Environment` that adds tracking of asynchronously
@@ -463,6 +463,18 @@ private def modifyCheckedAsync (env : Environment) (f : Kernel.Environment → K
private def setCheckedSync (env : Environment) (newChecked : Kernel.Environment) : Environment :=
{ env with checked := .pure newChecked, checkedWithoutAsync := newChecked }

/--
Checks whether the given declaration name may potentially be added, or have been added, to the current
environment branch, which is the case either if this is the main branch or if the top-level
declaration name for which this branch was created is a prefix (modulo privacy and hygiene
information) of the declaration name.

This function should always be checked before modifying an `AsyncMode.async` environment extension
to ensure `findStateAsync` will be able to find the modification from other branches.
-/
def asyncMayContain (env : Environment) (declName : Name) : Bool :=
env.asyncCtx?.all (·.mayContain declName)
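
A minimal sketch of the guard this docstring asks for, assuming the API exactly as it appears in this diff. `noteExt` and `addNote` are hypothetical names used only for illustration; they are not part of the PR.

```lean
import Lean
open Lean

/-- Hypothetical extension mapping declaration names to notes, registered in `async` mode. -/
initialize noteExt : EnvExtension (NameMap String) ←
  registerEnvExtension (pure {}) (asyncMode := .async)

/--
Stores `note` under `declName`, but only if the current environment branch may contain
`declName`. Otherwise the entry could end up on a branch where `findStateAsync` never looks.
-/
def addNote (env : Environment) (declName : Name) (note : String) : Except String Environment :=
  if env.asyncMayContain declName then
    .ok (noteExt.modifyState env fun notes => notes.insert declName note)
  else
    .error s!"cannot store a note for '{declName}' on this environment branch"
```

Returning an `Except` here is a design choice of the sketch; a real extension might prefer to throw in its surrounding monad instead.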

@[extern "lean_elab_add_decl"]
private opaque addDeclCheck (env : Environment) (maxHeartbeats : USize) (decl : @& Declaration)
(cancelTk? : @& Option IO.CancelToken) : Except Kernel.Exception Environment
@@ -515,7 +527,7 @@ def addExtraName (env : Environment) (name : Name) : Environment :=

/-- Find base case: name did not match any asynchronous declaration. -/
private def findNoAsync (env : Environment) (n : Name) : Option ConstantInfo := do
if env.asyncConsts.hasPrefix n then
if let some _ := env.asyncConsts.findPrefix? n then
-- Constant generated in a different environment branch: wait for final kernel environment. Rare
-- case when only proofs are elaborated asynchronously as they are rarely inspected. Could be
-- optimized in the future by having the elaboration thread publish an (incremental?) map of
@@ -756,13 +768,76 @@ def instantiateValueLevelParams! (c : ConstantInfo) (ls : List Level) : Expr :=

end ConstantInfo

/--
Async access mode for environment extensions used in `EnvironmentExtension.get/set/modifyState`.
Depending on their specific uses, extensions may opt out of the strict `sync` access mode in order
to avoid blocking parallel elaboration and/or to optimize accesses. The access mode is set at
environment extension registration time but can be overridden at `EnvironmentExtension.getState` in
order to weaken it for specific accesses.

In all modes, the state stored into the `.olean` file for persistent environment extensions is the
result of `getState` called on the main environment branch at the end of the file, i.e. it
encompasses all modifications for all modes but `local`.
-/
inductive EnvExtension.AsyncMode where
/--
Default access mode, writing and reading the extension state to/from the full `checked`
environment. This mode ensures the observed state is identical independently of whether or how
parallel elaboration is used but `getState` will block on all prior environment branches by
waiting for `checked`. `setState` and `modifyState` do not block.

While a safe default, any extension that reasonably could be used in parallel elaboration contexts
should opt for a weaker mode to avoid blocking unless there is no way to access the correct state
without waiting for all prior environment branches, in which case its data management should be
restructured if at all possible.
-/
| sync
/--
Accesses only the state of the current environment branch. Modifications on other branches are not
visible and are ultimately discarded except for the main branch. Provides the fastest accessors,
will never block.

This mode is particularly suitable for extensions where state does not escape from lexical scopes
even without parallelism, e.g. `ScopedEnvExtension`s when setting local entries.
-/
| local
/--
Like `local` but panics when trying to modify the state on anything but the main environment
branch. For extensions that fulfill this requirement, all modes functionally coincide but this
is the safest and most efficient choice in that case, preventing accidental misuse.

This mode is suitable for extensions that are modified only at the command elaboration level
before any environment forks in the command, and in particular for extensions that are modified
only at the very beginning of the file.
-/
| mainOnly
/--
Accumulates modifications in the `checked` environment like `sync`, but `getState` will panic
instead of blocking. Instead `findStateAsync` should be used, which will access the state of the
environment branch corresponding to the passed declaration name, if any, or otherwise the state
of the current branch. In other words, at most one environment branch will be blocked on instead
of all prior branches. The local state can still be accessed by calling `getState` with mode
`local` explicitly.

This mode is suitable for extensions with map-like state where the key uniquely identifies the
top-level declaration where it could have been set, e.g. because the key on modification is always
the surrounding declaration's name. Any calls to `modifyState`/`setState` should assert
`asyncMayContain` with that key to ensure state is never accidentally stored in a branch where it
cannot be found by `findStateAsync`. In particular, this mode is closest to how the environment's
own constant map works which asserts the same predicate on modification and provides `findAsync?`
for block-avoiding access.
-/
| async
deriving Inhabited
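
As a compact recap of the four docstrings above, the following sketch paraphrases each mode's read and write behavior. It assumes the four constructors exactly as declared in this diff; `describeAsyncMode` is a hypothetical helper, not part of the PR.

```lean
import Lean
open Lean

/-- Hypothetical helper: one-line summary of each mode, paraphrasing the docstrings. -/
def describeAsyncMode : EnvExtension.AsyncMode → String
  | .sync     => "getState blocks on all prior branches; modifications go to `checked`"
  | .local    => "never blocks; reads and modifies only the current branch"
  | .mainOnly => "like `local`, but modifying outside the main branch panics"
  | .async    => "getState panics; findStateAsync blocks on at most one branch"
```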

/--
Environment extension, can only be generated by `registerEnvExtension` that allocates a unique index
for this extension into each environment's extension state's array.
-/
structure EnvExtension (σ : Type) where private mk ::
idx : Nat
mkInitial : IO σ
asyncMode : EnvExtension.AsyncMode
deriving Inhabited

namespace EnvExtension
@@ -817,23 +892,55 @@ def mkInitialExtStates : IO (Array EnvExtensionState) := do
let exts ← envExtensionsRef.get
exts.mapM fun ext => ext.mkInitial

-- TODO: store extension state in `checked`

def setState {σ : Type} (ext : EnvExtension σ) (env : Environment) (s : σ) : Environment :=
-- safety: `ext`'s constructor is private, so we can assume the entry at `ext.idx` is of type `σ`
{ env with checkedWithoutAsync.extensions := unsafe ext.setStateImpl env.checkedWithoutAsync.extensions s }
/--
Applies the given function to the extension state. See `AsyncMode` for details on how modifications
from different environment branches are reconciled.

Note that in modes `sync` and `async`, `f` will be called twice, on the local and on the `checked`
state.
-/
def modifyState {σ : Type} (ext : EnvExtension σ) (env : Environment) (f : σ → σ) : Environment :=
-- safety: `ext`'s constructor is private, so we can assume the entry at `ext.idx` is of type `σ`
{ env with checkedWithoutAsync.extensions := unsafe ext.modifyStateImpl env.checkedWithoutAsync.extensions f }
match ext.asyncMode with
| .mainOnly =>
if let some asyncCtx := env.asyncCtx? then
let _ : Inhabited Environment := ⟨env⟩
panic! s!"Environment.modifyState: environment extension is marked as `mainOnly` but used in \
async context '{asyncCtx.declPrefix}'"
else
{ env with checkedWithoutAsync.extensions := unsafe ext.modifyStateImpl env.checkedWithoutAsync.extensions f }
| .local =>
{ env with checkedWithoutAsync.extensions := unsafe ext.modifyStateImpl env.checkedWithoutAsync.extensions f }
| _ =>
env.modifyCheckedAsync fun env =>
{ env with extensions := unsafe ext.modifyStateImpl env.extensions f }

/--
Sets the extension state to the given value. See `AsyncMode` for details on how modifications from
different environment branches are reconciled.
-/
def setState {σ : Type} (ext : EnvExtension σ) (env : Environment) (s : σ) : Environment :=
inline <| modifyState ext env fun _ => s

-- `unsafe` fails to infer `Nonempty` here
unsafe def getStateUnsafe {σ : Type} [Inhabited σ] (ext : EnvExtension σ) (env : Environment) : σ :=
private unsafe def getStateUnsafe {σ : Type} [Inhabited σ] (ext : EnvExtension σ)
(env : Environment) (asyncMode := ext.asyncMode) : σ :=
-- safety: `ext`'s constructor is private, so we can assume the entry at `ext.idx` is of type `σ`
ext.getStateImpl env.checkedWithoutAsync.extensions
match asyncMode with
| .sync => ext.getStateImpl env.checked.get.extensions
| .async => panic! "EnvExtension.getState: called on `async` extension, use `findStateAsync` \
instead or pass `(asyncMode := .local)` to explicitly access local state"
| _ => ext.getStateImpl env.checkedWithoutAsync.extensions

/--
Returns the current extension state. See `AsyncMode` for details on how modifications from
different environment branches are reconciled. Panics if the extension is marked as `async`; see its
documentation for more details. Overriding the extension's default `AsyncMode` is usually not
recommended and should be considered only for important optimizations.
-/
@[implemented_by getStateUnsafe]
opaque getState {σ : Type} [Inhabited σ] (ext : EnvExtension σ) (env : Environment) : σ
opaque getState {σ : Type} [Inhabited σ] (ext : EnvExtension σ) (env : Environment)
(asyncMode := ext.asyncMode) : σ

end EnvExtension
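
A short sketch of the accessors above from a caller's point of view, assuming the signatures as they appear in this diff. `hitCountExt` and the three helper functions are hypothetical names.

```lean
import Lean
open Lean

/-- Hypothetical counter extension using the strict `sync` mode. -/
initialize hitCountExt : EnvExtension Nat ←
  registerEnvExtension (pure 0) (asyncMode := .sync)

/-- Reads with the registered mode; in `sync` mode this waits for `checked`. -/
def hitCount (env : Environment) : Nat :=
  hitCountExt.getState env

/-- Weakens this one access to `local`: never blocks, but sees only the current branch. -/
def hitCountLocal (env : Environment) : Nat :=
  hitCountExt.getState (asyncMode := .local) env

/-- Modification does not block. In `sync` mode the function is applied to both the local
and the `checked` state, so it should be pure and cheap. -/
def recordHit (env : Environment) : Environment :=
  hitCountExt.modifyState env (· + 1)
```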

@@ -844,15 +951,13 @@ end EnvExtension

Note that by default, extension state is *not* stored in .olean files and will not propagate across `import`s.
For that, you need to register a persistent environment extension. -/
def registerEnvExtension {σ : Type} (mkInitial : IO σ) : IO (EnvExtension σ) := do
def registerEnvExtension {σ : Type} (mkInitial : IO σ)
(asyncMode : EnvExtension.AsyncMode := .mainOnly) : IO (EnvExtension σ) := do
unless (← initializing) do
throw (IO.userError "failed to register environment, extensions can only be registered during initialization")
let exts ← EnvExtension.envExtensionsRef.get
let idx := exts.size
let ext : EnvExtension σ := {
idx := idx,
mkInitial := mkInitial,
}
let ext : EnvExtension σ := { idx, mkInitial, asyncMode }
-- safety: `EnvExtensionState` is opaque, so we can upcast to it
EnvExtension.envExtensionsRef.modify fun exts => exts.push (unsafe unsafeCast ext)
pure ext
@@ -953,7 +1058,8 @@ instance {α β σ} [Inhabited σ] : Inhabited (PersistentEnvExtension α β σ)
namespace PersistentEnvExtension

def getModuleEntries {α β σ : Type} [Inhabited σ] (ext : PersistentEnvExtension α β σ) (env : Environment) (m : ModuleIdx) : Array α :=
(ext.toEnvExtension.getState env).importedEntries.get! m
-- `importedEntries` is identical on all environment branches, so `local` is always sufficient
(ext.toEnvExtension.getState (asyncMode := .local) env).importedEntries.get! m

def addEntry {α β σ : Type} (ext : PersistentEnvExtension α β σ) (env : Environment) (b : β) : Environment :=
ext.toEnvExtension.modifyState env fun s =>
@@ -972,6 +1078,24 @@ def setState {α β σ : Type} (ext : PersistentEnvExtension α β σ) (env : En
def modifyState {α β σ : Type} (ext : PersistentEnvExtension α β σ) (env : Environment) (f : σ → σ) : Environment :=
ext.toEnvExtension.modifyState env fun ps => { ps with state := f (ps.state) }

-- `unsafe` fails to infer `Nonempty` here
private unsafe def findStateAsyncUnsafe {α β σ : Type} [Inhabited σ]
(ext : PersistentEnvExtension α β σ) (env : Environment) (declPrefix : Name) : σ :=
-- safety: `ext`'s constructor is private, so we can assume the entry at `ext.idx` is of type `σ`
if let some { exts? := some exts, .. } := env.asyncConsts.findPrefix? declPrefix then
ext.toEnvExtension.getStateImpl exts.get |>.state
else
ext.toEnvExtension.getStateImpl env.checkedWithoutAsync.extensions |>.state

/--
Returns the final extension state on the environment branch corresponding to the passed declaration
name, if any, or otherwise the state on the current branch. In other words, at most one environment
branch will be blocked on.
-/
@[implemented_by findStateAsyncUnsafe]
opaque findStateAsync {α β σ : Type} [Inhabited σ] (ext : PersistentEnvExtension α β σ)
(env : Environment) (declPrefix : Name) : σ
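
A minimal sketch of a lookup built on `findStateAsync`, assuming the signature shown in this diff and an extension registered with `asyncMode := .async` whose state is keyed by the surrounding declaration's name. `findNote?` and the extension's entry and state types are hypothetical.

```lean
import Lean
open Lean

/--
Hypothetical lookup: asks for the state on the branch that elaborated `declName` and
blocks on at most that one branch. If no asynchronous constant matches, the state of the
current branch is used instead.
-/
def findNote?
    (ext : PersistentEnvExtension (Name × String) (Name × String) (NameMap String))
    (env : Environment) (declName : Name) : Option String :=
  (ext.findStateAsync env declName).find? declName
```

Passing the same name as both the branch selector and the map key only works when every write for that key happened on the matching branch, which is what the `asyncMayContain` assertion on modification is meant to ensure.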

end PersistentEnvExtension

builtin_initialize persistentEnvExtensionsRef : IO.Ref (Array (PersistentEnvExtension EnvExtensionEntry EnvExtensionEntry EnvExtensionState)) ← IO.mkRef #[]
@@ -983,11 +1107,12 @@ structure PersistentEnvExtensionDescr (α β σ : Type) where
addEntryFn : σ → β → σ
exportEntriesFn : σ → Array α
statsFn : σ → Format := fun _ => Format.nil
asyncMode : EnvExtension.AsyncMode := .mainOnly

unsafe def registerPersistentEnvExtensionUnsafe {α β σ : Type} [Inhabited σ] (descr : PersistentEnvExtensionDescr α β σ) : IO (PersistentEnvExtension α β σ) := do
let pExts ← persistentEnvExtensionsRef.get
if pExts.any (fun ext => ext.name == descr.name) then throw (IO.userError s!"invalid environment extension, '{descr.name}' has already been used")
let ext ← registerEnvExtension do
let ext ← registerEnvExtension (asyncMode := descr.asyncMode) do
let initial ← descr.mkInitial
let s : PersistentEnvExtensionState α σ := {
importedEntries := #[],
@@ -1019,6 +1144,7 @@ structure SimplePersistentEnvExtensionDescr (α σ : Type) where
addEntryFn : σ → α → σ
addImportedFn : Array (Array α) → σ
toArrayFn : List α → Array α := fun es => es.toArray
asyncMode : EnvExtension.AsyncMode := .mainOnly

def registerSimplePersistentEnvExtension {α σ : Type} [Inhabited σ] (descr : SimplePersistentEnvExtensionDescr α σ) : IO (SimplePersistentEnvExtension α σ) :=
registerPersistentEnvExtension {
@@ -1029,6 +1155,7 @@ def registerSimplePersistentEnvExtension {α σ : Type} [Inhabited σ] (descr :
| (entries, s) => (e::entries, descr.addEntryFn s e),
exportEntriesFn := fun s => descr.toArrayFn s.1.reverse,
statsFn := fun s => format "number of local entries: " ++ format s.1.length
asyncMode := descr.asyncMode
}
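
The new `asyncMode` field is simply forwarded to the underlying extension. A sketch of opting out of the `mainOnly` default, assuming the descriptor fields as listed in this diff; `markedExt` is a hypothetical extension name.

```lean
import Lean
open Lean

/-- Hypothetical extension recording a set of marked declaration names. -/
initialize markedExt : SimplePersistentEnvExtension Name NameSet ←
  registerSimplePersistentEnvExtension {
    addEntryFn := fun s n => s.insert n
    -- Merge the entries of all imported modules into one set.
    addImportedFn := fun modules =>
      modules.foldl (fun s entries => entries.foldl (fun s n => s.insert n) s) {}
    -- Assumption for this sketch: entries added on non-main branches may be discarded,
    -- so `local` is acceptable; the default would be `mainOnly`.
    asyncMode := .local
  }
```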

namespace SimplePersistentEnvExtension
4 changes: 2 additions & 2 deletions src/Lean/Meta/LazyDiscrTree.lean
@@ -980,8 +980,8 @@ def findImportMatches
let ngen ← getNGen
let (cNGen, ngen) := ngen.mkChild
setNGen ngen
let dummy : IO.Ref (Option (LazyDiscrTree α)) ← IO.mkRef none
let ref := @EnvExtension.getState _ ⟨dummy⟩ ext (←getEnv)
let _ : Inhabited (IO.Ref (Option (LazyDiscrTree α))) := ⟨← IO.mkRef none⟩
let ref := ext.getState (←getEnv)
let importTree ← (←ref.get).getDM $ do
profileitM Exception "lazy discriminator import initialization" (←getOptions) $ do
let t ← createImportedDiscrTree (createTreeCtx cctx) cNGen (←getEnv) addEntry
3 changes: 2 additions & 1 deletion src/Lean/Parser/Basic.lean
@@ -1705,7 +1705,8 @@ builtin_initialize categoryParserFnRef : IO.Ref CategoryParserFn ← IO.mkRef fu
builtin_initialize categoryParserFnExtension : EnvExtension CategoryParserFn ← registerEnvExtension $ categoryParserFnRef.get

def categoryParserFn (catName : Name) : ParserFn := fun ctx s =>
categoryParserFnExtension.getState ctx.env catName ctx s
let fn := categoryParserFnExtension.getState ctx.env
fn catName ctx s

def categoryParser (catName : Name) (prec : Nat) : Parser where
fn := adaptCacheableContextFn ({ · with prec }) (withCacheFn catName (categoryParserFn catName))