Skip to content

Commit 82f0ad2

Browse files
committed
fix: bring elaborator in line with kernel for primitive projections
The kernel supports primitive projections for all inductive types with one construtor. The elaborator was assuming primitive projections only work for "structure-likes", non-recursive inductive types with no indices. Enables numeric projection notation for general one-constructor inductives. Extracted from #5783.
1 parent 66dbad9 commit 82f0ad2

File tree

8 files changed

+134
-40
lines changed

8 files changed

+134
-40
lines changed

src/Lean/Compiler/LCNF/InferType.lean

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -168,13 +168,12 @@ mutual
168168
/- TODO: after we erase universe variables, we can just extract a better type using just `structName` and `idx`. -/
169169
return erasedExpr
170170
else
171-
matchConstStruct structType.getAppFn failed fun structVal structLvls ctorVal =>
172-
let n := structVal.numParams
173-
let structParams := structType.getAppArgs
174-
if n != structParams.size then
171+
matchConstStructure structType.getAppFn failed fun structVal structLvls ctorVal =>
172+
let structTypeArgs := structType.getAppArgs
173+
if structVal.numParams + structVal.numIndices != structTypeArgs.size then
175174
failed ()
176175
else do
177-
let mut ctorType ← inferAppType (mkAppN (mkConst ctorVal.name structLvls) structParams)
176+
let mut ctorType ← inferAppType (mkAppN (mkConst ctorVal.name structLvls) structTypeArgs[:structVal.numParams])
178177
for _ in [:idx] do
179178
match ctorType with
180179
| .forallE _ _ body _ =>

src/Lean/Elab/App.lean

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1188,19 +1188,19 @@ private def resolveLValAux (e : Expr) (eType : Expr) (lval : LVal) : TermElabM L
11881188
if idx == 0 then
11891189
throwError "invalid projection, index must be greater than 0"
11901190
let env ← getEnv
1191-
unless isStructureLike env structName do
1192-
throwLValError e eType "invalid projection, structure expected"
1193-
let numFields := getStructureLikeNumFields env structName
1194-
if idx - 1 < numFields then
1195-
if isStructure env structName then
1196-
let fieldNames := getStructureFields env structName
1197-
return LValResolution.projFn structName structName fieldNames[idx - 1]!
1191+
let failK _ := throwLValError e eType "invalid projection, structure expected"
1192+
matchConstStructure eType.getAppFn failK fun _ _ ctorVal => do
1193+
let numFields := ctorVal.numFields
1194+
if idx - 1 < numFields then
1195+
if isStructure env structName then
1196+
let fieldNames := getStructureFields env structName
1197+
return LValResolution.projFn structName structName fieldNames[idx - 1]!
1198+
else
1199+
/- `structName` was declared using `inductive` command.
1200+
So, we don't projection functions for it. Thus, we use `Expr.proj` -/
1201+
return LValResolution.projIdx structName (idx - 1)
11981202
else
1199-
/- `structName` was declared using `inductive` command.
1200-
So, we don't projection functions for it. Thus, we use `Expr.proj` -/
1201-
return LValResolution.projIdx structName (idx - 1)
1202-
else
1203-
throwLValError e eType m!"invalid projection, structure has only {numFields} field(s)"
1203+
throwLValError e eType m!"invalid projection, structure has only {numFields} field(s)"
12041204
| some structName, LVal.fieldName _ fieldName _ _ =>
12051205
let env ← getEnv
12061206
let searchEnv : Unit → TermElabM LValResolution := fun _ => do

src/Lean/Meta/ExprDefEq.lean

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1957,6 +1957,9 @@ private def isDefEqProj : Expr → Expr → MetaM Bool
19571957
where
19581958
/-- If `structName` is a structure with a single field and `(?m ...).1 =?= v`, then solve constraint as `?m ... =?= ⟨v⟩` -/
19591959
isDefEqSingleton (structName : Name) (s : Expr) (v : Expr) : MetaM Bool := do
1960+
let some ctorVal := getStructureLikeCtor? (← getEnv) structName | return false
1961+
if ctorVal.numFields != 1 then
1962+
return false -- It is not a structure with a single field.
19601963
if isClass (← getEnv) structName then
19611964
/-
19621965
We disable this feature if `structName` is a class. See issue #2011.
@@ -1975,9 +1978,6 @@ where
19751978
assign `?m`.
19761979
-/
19771980
return false
1978-
let ctorVal := getStructureCtor (← getEnv) structName
1979-
if ctorVal.numFields != 1 then
1980-
return false -- It is not a structure with a single field.
19811981
let sType ← whnf (← inferType s)
19821982
let sTypeFn := sType.getAppFn
19831983
if !sTypeFn.isConstOf structName then
@@ -2013,7 +2013,7 @@ private def isDefEqApp (t s : Expr) : MetaM Bool := do
20132013
/-- Return `true` if the type of the given expression is an inductive datatype with a single constructor with no fields. -/
20142014
private def isDefEqUnitLike (t : Expr) (s : Expr) : MetaM Bool := do
20152015
let tType ← whnf (← inferType t)
2016-
matchConstStruct tType.getAppFn (fun _ => return false) fun _ _ ctorVal => do
2016+
matchConstStructureLike tType.getAppFn (fun _ => return false) fun _ _ ctorVal => do
20172017
if ctorVal.numFields != 0 then
20182018
return false
20192019
else if (← useEtaStruct ctorVal.induct) then

src/Lean/Meta/InferType.lean

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -99,13 +99,12 @@ private def inferProjType (structName : Name) (idx : Nat) (e : Expr) : MetaM Exp
9999
let structType ← whnf structType
100100
let failed {α} : Unit → MetaM α := fun _ =>
101101
throwError "invalid projection{indentExpr (mkProj structName idx e)} from type {structType}"
102-
matchConstStruct structType.getAppFn failed fun structVal structLvls ctorVal =>
103-
let n := structVal.numParams
104-
let structParams := structType.getAppArgs
105-
if n != structParams.size then
102+
matchConstStructure structType.getAppFn failed fun structVal structLvls ctorVal =>
103+
let structTypeArgs := structType.getAppArgs
104+
if structVal.numParams + structVal.numIndices != structTypeArgs.size then
106105
failed ()
107106
else do
108-
let mut ctorType ← inferAppType (mkConst ctorVal.name structLvls) structParams
107+
let mut ctorType ← inferAppType (mkConst ctorVal.name structLvls) structTypeArgs[:structVal.numParams]
109108
for i in [:idx] do
110109
ctorType ← whnf ctorType
111110
match ctorType with

src/Lean/Meta/Tactic/Constructor.lean

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ def _root_.Lean.MVarId.existsIntro (mvarId : MVarId) (w : Expr) : MetaM MVarId :
3232
mvarId.withContext do
3333
mvarId.checkNotAssigned `exists
3434
let target ← mvarId.getType'
35-
matchConstStruct target.getAppFn
35+
matchConstStructure target.getAppFn
3636
(fun _ => throwTacticEx `exists mvarId "target is not an inductive datatype with one constructor")
3737
fun _ us cval => do
3838
if cval.numFields < 2 then

src/Lean/MonadEnv.lean

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -118,7 +118,26 @@ def getConstInfoRec [Monad m] [MonadEnv m] [MonadError m] (constName : Name) : m
118118
| ConstantInfo.recInfo v => pure v
119119
| _ => throwError "'{mkConst constName}' is not a recursor"
120120

121-
@[inline] def matchConstStruct [Monad m] [MonadEnv m] [MonadError m] (e : Expr) (failK : Unit → m α) (k : InductiveVal → List Level → ConstructorVal → m α) : m α :=
121+
/--
122+
Matches if `e` is a constant that is an inductive type with one constructor.
123+
Such types can be used with primitive projections.
124+
See also `Lean.matchConstStructLike` for a more restrictive version.
125+
-/
126+
@[inline] def matchConstStructure [Monad m] [MonadEnv m] [MonadError m] (e : Expr) (failK : Unit → m α) (k : InductiveVal → List Level → ConstructorVal → m α) : m α :=
127+
matchConstInduct e failK fun ival us => do
128+
match ival.ctors with
129+
| [ctor] =>
130+
match (← getConstInfo ctor) with
131+
| ConstantInfo.ctorInfo cval => k ival us cval
132+
| _ => failK ()
133+
| _ => failK ()
134+
135+
/--
136+
Matches if `e` is a constant that is an non-recursive inductive type with no indices and with one constructor.
137+
Such a type satisfies `Lean.isStructureLike`.
138+
See also `Lean.matchConstStructure` for a less restrictive version.
139+
-/
140+
@[inline] def matchConstStructureLike [Monad m] [MonadEnv m] [MonadError m] (e : Expr) (failK : Unit → m α) (k : InductiveVal → List Level → ConstructorVal → m α) : m α :=
122141
matchConstInduct e failK fun ival us => do
123142
if ival.isRec || ival.numIndices != 0 then failK ()
124143
else match ival.ctors with

src/Lean/Structure.lean

Lines changed: 33 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -67,15 +67,23 @@ def getStructureInfo? (env : Environment) (structName : Name) : Option Structure
6767
| some modIdx => structureExt.getModuleEntries env modIdx |>.binSearch { structName } StructureInfo.lt
6868
| none => structureExt.getState env |>.map.find? structName
6969

70+
/--
71+
Gets the constructor of an inductive type that has exactly one constructor.
72+
This is meant to be used with types that have had been registered as a structure by `registerStructure`,
73+
but this is not checked.
74+
75+
Warning: these do *not* need to be "structure-likes". A structure-like is non-recursive,
76+
and structure-likes have special kernel support.
77+
-/
7078
def getStructureCtor (env : Environment) (constName : Name) : ConstructorVal :=
7179
match env.find? constName with
72-
| some (.inductInfo { isRec := false, ctors := [ctorName], .. }) =>
80+
| some (.inductInfo { ctors := [ctorName], .. }) =>
7381
match env.find? ctorName with
7482
| some (ConstantInfo.ctorInfo val) => val
7583
| _ => panic! "ill-formed environment"
7684
| _ => panic! "structure expected"
7785

78-
/-- Get direct field names for the given structure. -/
86+
/-- Gets the direct field names for the given structure, including subobject fields. -/
7987
def getStructureFields (env : Environment) (structName : Name) : Array Name :=
8088
if let some info := getStructureInfo? env structName then
8189
info.fieldNames
@@ -88,22 +96,22 @@ def getFieldInfo? (env : Environment) (structName : Name) (fieldName : Name) : O
8896
else
8997
none
9098

91-
/-- If `fieldName` represents the relation to a parent structure `S`, return `S` -/
99+
/-- If `fieldName` represents the relation to a parent structure `S`, returns `S` -/
92100
def isSubobjectField? (env : Environment) (structName : Name) (fieldName : Name) : Option Name :=
93101
if let some fieldInfo := getFieldInfo? env structName fieldName then
94102
fieldInfo.subobject?
95103
else
96104
none
97105

98-
/-- Return immediate parent structures -/
106+
/-- Returns immediate parent structures. -/
99107
def getParentStructures (env : Environment) (structName : Name) : Array Name :=
100108
let fieldNames := getStructureFields env structName;
101109
fieldNames.foldl (init := #[]) fun acc fieldName =>
102110
match isSubobjectField? env structName fieldName with
103111
| some parentStructName => acc.push parentStructName
104112
| none => acc
105113

106-
/-- Return all parent structures -/
114+
/-- Returns all parent structures. -/
107115
partial def getAllParentStructures (env : Environment) (structName : Name) : Array Name :=
108116
visit structName |>.run #[] |>.2
109117
where
@@ -127,7 +135,8 @@ private partial def getStructureFieldsFlattenedAux (env : Environment) (structNa
127135
getStructureFieldsFlattenedAux env parentStructName fullNames includeSubobjectFields
128136
| none => fullNames.push fieldName
129137

130-
/-- Return field names for the given structure, including "flattened" fields from parent
138+
/--
139+
Returns field names for the given structure, including "flattened" fields from parent
131140
structures. To omit `toParent` projections, set `includeSubobjectFields := false`.
132141
133142
For example, given `Bar` such that
@@ -140,11 +149,11 @@ def getStructureFieldsFlattened (env : Environment) (structName : Name) (include
140149
getStructureFieldsFlattenedAux env structName #[] includeSubobjectFields
141150

142151
/--
143-
Return true if `constName` is the name of an inductive datatype
144-
created using the `structure` or `class` commands.
152+
Returns true if `constName` is the name of an inductive datatype
153+
created using the `structure` or `class` commands.
145154
146-
We perform the check by testing whether auxiliary projection functions
147-
have been created. -/
155+
These are inductive types for which structure information has been registered with `registerStructure`.
156+
-/
148157
def isStructure (env : Environment) (constName : Name) : Bool :=
149158
getStructureInfo? env constName |>.isSome
150159

@@ -186,18 +195,30 @@ partial def getPathToBaseStructureAux (env : Environment) (baseStructName : Name
186195
| some projFn => getPathToBaseStructureAux env baseStructName parentStructName (projFn :: path)
187196

188197
/--
189-
If `baseStructName` is an ancestor structure for `structName`, then return a sequence of projection functions
198+
If `baseStructName` is an ancestor structure for `structName`, then returns a sequence of projection functions
190199
to go from `structName` to `baseStructName`.
191200
-/
192201
def getPathToBaseStructure? (env : Environment) (baseStructName : Name) (structName : Name) : Option (List Name) :=
193202
getPathToBaseStructureAux env baseStructName structName []
194203

195-
/-- Return true iff `constName` is the a non-recursive inductive datatype that has only one constructor. -/
204+
/--
205+
Returns true iff `constName` is a non-recursive inductive datatype that has only one constructor and no indices.
206+
207+
Such types have special kernel support. This must be in sync with `is_structure_like`.
208+
-/
196209
def isStructureLike (env : Environment) (constName : Name) : Bool :=
197210
match env.find? constName with
198211
| some (.inductInfo { isRec := false, ctors := [_], numIndices := 0, .. }) => true
199212
| _ => false
200213

214+
def getStructureLikeCtor? (env : Environment) (constName : Name) : Option ConstructorVal :=
215+
match env.find? constName with
216+
| some (.inductInfo { isRec := false, ctors := [ctorName], numIndices := 0, .. }) =>
217+
match env.find? ctorName with
218+
| some (ConstantInfo.ctorInfo val) => val
219+
| _ => panic! "ill-formed environment"
220+
| _ => none
221+
201222
/-- Return number of fields for a structure-like type -/
202223
def getStructureLikeNumFields (env : Environment) (constName : Name) : Nat :=
203224
match env.find? constName with
Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
/-!
2+
# Tests for numeric projections of inductive types
3+
-/
4+
5+
/-!
6+
Non-recursive, no indices.
7+
-/
8+
inductive I0 where
9+
| mk (x : Nat) (xs : List Nat)
10+
/-- info: fun v => v.1 : I0 → Nat -/
11+
#guard_msgs in #check fun (v : I0) => v.1
12+
/-- info: fun v => v.2 : I0 → List Nat -/
13+
#guard_msgs in #check fun (v : I0) => v.2
14+
15+
/-!
16+
Recursive, no indices.
17+
-/
18+
inductive I1 where
19+
| mk (x : Nat) (xs : I1)
20+
/-- info: fun v => v.1 : I1 → Nat -/
21+
#guard_msgs in #check fun (v : I1) => v.1
22+
/-- info: fun v => v.2 : I1 → I1 -/
23+
#guard_msgs in #check fun (v : I1) => v.2
24+
25+
/-!
26+
Non-recursive, indices.
27+
-/
28+
inductive I2 : Nat → Type where
29+
| mk (x : Nat) (xs : List (Fin x)) : I2 (x + 1)
30+
/-- info: fun v => v.1 : I2 2 → Nat -/
31+
#guard_msgs in #check fun (v : I2 2) => v.1
32+
/-- info: fun v => v.2 : (v : I2 2) → List (Fin v.1) -/
33+
#guard_msgs in #check fun (v : I2 2) => v.2
34+
35+
/-!
36+
Recursive, indices.
37+
-/
38+
inductive I3 : Nat → Type where
39+
| mk (x : Nat) (xs : I3 (x + 1)) : I3 x
40+
/-- info: fun v => v.1 : I3 2 → Nat -/
41+
#guard_msgs in #check fun (v : I3 2) => v.1
42+
/-- info: fun v => v.2 : (v : I3 2) → I3 (v.1 + 1) -/
43+
#guard_msgs in #check fun (v : I3 2) => v.2
44+
45+
46+
/-!
47+
Make sure these can be compiled.
48+
-/
49+
def f0_1 (v : I0) : Nat := v.1
50+
def f0_2 (v : I0) : List Nat := v.2
51+
def f1_1 (v : I1) : Nat := v.1
52+
def f1_2 (v : I1) : I1 := v.2
53+
def f2_1 (v : I2 n) : Nat := v.1
54+
def f2_2 (v : I2 n) : List (Fin (f2_1 v)) := v.2
55+
def f3_1 (v : I3 n) : Nat := v.1
56+
def f3_2 (v : I3 n) : I3 (f3_1 v + 1) := v.2

0 commit comments

Comments
 (0)