diff --git a/base/src/main/java/org/aya/normalize/LetReplacer.java b/base/src/main/java/org/aya/normalize/LetReplacer.java index f072f8cdc9..8519ce5d1c 100644 --- a/base/src/main/java/org/aya/normalize/LetReplacer.java +++ b/base/src/main/java/org/aya/normalize/LetReplacer.java @@ -9,6 +9,8 @@ import org.aya.tyck.ctx.LocalLet; import org.jetbrains.annotations.NotNull; +/// This implements [FreeTerm] substitution. The substitution object is represented using a +/// [LocalLet] for convenience -- for the functionality we only need [LocalLet#contains] and [LocalLet#get]. public record LetReplacer(@NotNull LocalLet let) implements UnaryOperator { @Override public Term apply(Term term) { return switch (term) { diff --git a/base/src/main/java/org/aya/tyck/StmtTycker.java b/base/src/main/java/org/aya/tyck/StmtTycker.java index 0cc5138e1b..2bf208ab6b 100644 --- a/base/src/main/java/org/aya/tyck/StmtTycker.java +++ b/base/src/main/java/org/aya/tyck/StmtTycker.java @@ -46,6 +46,19 @@ import org.jetbrains.annotations.NotNull; import org.jetbrains.annotations.Nullable; +/// Checks definitions. All the dirty telescope manipulation are here. +/// This class glues the type checking of exprs and patterns together. +/// +/// Note that we handle mutual recursions, so we support checking the _signature_ of a decl +/// without checking its body. It is like checking forward declarations, but the forward +/// declarations are just a part of the real decl. This is done in [#checkHeader]. +/// +/// For [PrimDef] and [ConDef], they only have headers. The body of a [DataDef] is the header +/// of all of its [DataCon]s, but just checking the [DataCon]s themselves is not enough: they need to be +/// put together and added to the [DataDef], so we can use them for exhaustiveness checking. +/// +/// For [ClassDef] and [MemberDef], they both only have headers, because they don't allow mutual recursion. +/// The header of a [ClassDef] is just all of its [ClassMember]s. public record StmtTycker( @NotNull SuppressingReporter reporter, @NotNull ModulePath fileModule, @NotNull ShapeFactory shapeFactory, @NotNull PrimFactory primFactory @@ -241,11 +254,9 @@ private void checkMember(@NotNull ClassMember member, @NotNull ExprTycker tycker } } - /** - * Kitsune says kon! - * - * @apiNote invoke this method after loading the telescope of data! - */ + /// Kitsune says kon! Checks the data constructor. + /// + /// @apiNote invoke this method after loading the telescope of data! private void checkKitsune(@NotNull DataCon con, @NotNull ExprTycker tycker) { var ref = con.ref; if (ref.core != null) return; diff --git a/base/src/main/java/org/aya/tyck/pat/ClauseTycker.java b/base/src/main/java/org/aya/tyck/pat/ClauseTycker.java index 20a528dcdd..404085dbab 100644 --- a/base/src/main/java/org/aya/tyck/pat/ClauseTycker.java +++ b/base/src/main/java/org/aya/tyck/pat/ClauseTycker.java @@ -122,20 +122,8 @@ public record Worker( if (clauses.get(i).expr.isEmpty()) continue; var currentClasses = usages.get(i); if (currentClasses.sizeEquals(1)) { - var curLhs = lhs.get(i); - var curCls = currentClasses.get(0); - var lets = new PatBinder().apply(curLhs.freePats(), curCls.term()); - if (lets.let().let().allMatch((_, j) -> j.wellTyped() instanceof FreeTerm)) - continue; - var sibling = Objects.requireNonNull(curLhs.localCtx.parent()).derive(); - var newPatterns = curCls.pat().map(pat -> pat.descentTerm(lets)); - newPatterns.forEach(pat -> pat.consumeBindings(sibling::put)); - curLhs.asSubst.let().replaceAll((_, t) -> t.map(lets)); - var paramSubst = curLhs.paramSubst.map(jdg -> jdg.map(lets)); - lets.let().let().forEach(curLhs.asSubst::put); - lhs.set(i, new LhsResult( - sibling, lets.apply(curLhs.result), curLhs.unpiParamSize, newPatterns, - curLhs.sourcePos, curLhs.body, paramSubst, curLhs.asSubst, curLhs.hasError)); + var newLhs = refinePattern(lhs.get(i), currentClasses.get(0)); + if (newLhs != null) lhs.set(i, newLhs); } } } @@ -152,6 +140,37 @@ public record Worker( return new WorkerResult(wellTyped, hasError); } + /// When we realize (in first-match only) that a clause is only reachable for a single leaf in the case tree, + /// we try to specialize the patterns according to the case tree leaf. For example, + /// ``` + /// f zero = body1 + /// f x = body2 + ///``` + /// The `x` in the second case is only reachable for input `suc y`, + /// and we can realize this by inspecting the result of [PatClassifier#firstMatchDomination]. + /// So, we can replace `x` with `suc y` to help computing the result type. + /// A more realistic motivating example can be found + /// [here](https://twitter.com/zornsllama/status/1465435870861926400). + /// + /// However, we cannot just simply replace the patterns -- the localCtx obtained by checking the patterns, + /// the result type [LhsResult#result], the types in the patterns, and [LhsResult#paramSubst], + /// all of these need to be changed accordingly. + /// This method performs these changes. + private @Nullable LhsResult refinePattern(LhsResult curLhs, PatClass.Seq curCls) { + var lets = new PatBinder().apply(curLhs.freePats(), curCls.term()); + if (lets.let().let().allMatch((_, j) -> j.wellTyped() instanceof FreeTerm)) + return null; + var sibling = Objects.requireNonNull(curLhs.localCtx.parent()).derive(); + var newPatterns = curCls.pat().map(pat -> pat.descentTerm(lets)); + newPatterns.forEach(pat -> pat.consumeBindings(sibling::put)); + curLhs.asSubst.let().replaceAll((_, t) -> t.map(lets)); + var paramSubst = curLhs.paramSubst.map(jdg -> jdg.map(lets)); + lets.let().let().forEach(curLhs.asSubst::put); + return new LhsResult( + sibling, lets.apply(curLhs.result), curLhs.unpiParamSize, newPatterns, + curLhs.sourcePos, curLhs.body, paramSubst, curLhs.asSubst, curLhs.hasError); + } + public @NotNull MutableSeq checkAllLhs() { return parent.checkAllLhs(() -> SignatureIterator.make(telescope, unpi, teleVars, elims), diff --git a/base/src/main/java/org/aya/tyck/pat/PatClassifier.java b/base/src/main/java/org/aya/tyck/pat/PatClassifier.java index 5992e22394..5e2cf4216a 100644 --- a/base/src/main/java/org/aya/tyck/pat/PatClassifier.java +++ b/base/src/main/java/org/aya/tyck/pat/PatClassifier.java @@ -8,7 +8,6 @@ import kala.collection.SeqView; import kala.collection.immutable.ImmutableSeq; import kala.collection.immutable.primitive.ImmutableIntSeq; -import kala.collection.mutable.MutableArrayList; import kala.collection.mutable.MutableList; import kala.collection.mutable.MutableSeq; import kala.control.Result; @@ -38,6 +37,22 @@ import org.jetbrains.annotations.NotNull; import org.jetbrains.annotations.Nullable; +/// Coverage checking & case tree generation. Part of the code is generalized and moved to [ClassifierUtil], +/// which is reusable. The main subroutine of coverage checking is _splitting_, i.e. look at a list of clauses, +/// group them by those who match the same constructor head, and recurse -- hence the name _classifier_. +/// +/// Note that catch-all patterns will be put into all groups, and literals will have their special classification +/// rules -- if you have very large literals, turning them into constructors and classify with the constructors will +/// be very slow. For pure literal pattern matching, we will only split on literals, and ask for a catch-all. +/// +/// There are 3 variants of this task: +/// +/// * Look at a single pattern and split according to the head. This is [#classify1], and is language-specific +/// (it depends on what type formers does a language have), so cannot be generalized. +/// * Look at a list of patterns and split them monadically. This is [ClassifierUtil#classifyN], +/// which simply calls [#classify1] and flatMap on the results, so it's language-independent, and is generalized. +/// * Look at a pair of patterns and split them. This is [ClassifierUtil#classify2], +/// which is just a special case of [#classifyN], but implemented for convenience of dealing with binary tuples. public record PatClassifier( @NotNull AbstractTycker delegate, @NotNull SourcePos pos ) implements ClassifierUtil, Term, Param, Pat>, Stateful, Problematic { @@ -112,14 +127,12 @@ case DepTypeTerm(var kind, var lT, var rT) when kind == DTKind.Sigma -> { var binds = Indexed.indices(clauses.filter(cl -> cl.pat() instanceof Pat.Bind)); if (clauses.isNotEmpty() && lits.size() + binds.size() == clauses.size()) { // There is only literals and bind patterns, no constructor patterns + // So we do not turn them into constructors, but split the literals directly var classes = ImmutableSeq.from(lits.collect( Collectors.groupingBy(i -> i.pat().repr())).values()) .map(i -> simple(i.getFirst().pat(), Indexed.indices(Seq.wrapJava(i)).concat(binds))); - var ml = MutableArrayList.>create(classes.size() + 1); - ml.appendAll(classes); - ml.append(simple(param.toFreshPat(), binds)); - return ml.toImmutableSeq(); + return classes.appended(simple(param.toFreshPat(), binds)); } var buffer = MutableList.>create(); diff --git a/base/src/main/java/org/aya/unify/TermComparator.java b/base/src/main/java/org/aya/unify/TermComparator.java index ceab7484e8..05e7ad11bf 100644 --- a/base/src/main/java/org/aya/unify/TermComparator.java +++ b/base/src/main/java/org/aya/unify/TermComparator.java @@ -1,7 +1,12 @@ -// Copyright (c) 2020-2024 Tesla (Yinsen) Zhang. +// Copyright (c) 2020-2025 Tesla (Yinsen) Zhang. // Use of this source code is governed by the MIT license that can be found in the LICENSE.md file. package org.aya.unify; +import java.util.Objects; +import java.util.function.Function; +import java.util.function.Supplier; +import java.util.function.UnaryOperator; + import kala.collection.immutable.ImmutableSeq; import kala.collection.mutable.MutableList; import kala.collection.mutable.MutableStack; @@ -33,11 +38,6 @@ import org.jetbrains.annotations.NotNull; import org.jetbrains.annotations.Nullable; -import java.util.Objects; -import java.util.function.Function; -import java.util.function.Supplier; -import java.util.function.UnaryOperator; - public abstract sealed class TermComparator extends AbstractTycker permits Unifier { protected final @NotNull SourcePos pos; protected @NotNull Ordering cmp; @@ -143,15 +143,14 @@ case Pair(MemberCall lMem, MemberCall rMem) -> { // TODO: type info? if (!compare(lMem.of(), rMem.of(), null)) yield null; yield compareMany(lMem.args(), rMem.args(), - lMem.ref().signature().inst(ImmutableSeq.of(lMem.of())).lift(Math.min(lMem.ulift(), rMem.ulift()))); + lMem.ref().signature().inst(ImmutableSeq.of(lMem.of())).lift(Math.min(lMem.ulift(), rMem.ulift()))); } default -> null; }; } - /** - * Compare the arguments of two callable ONLY, this method will NOT try to normalize and then compare (while the old project does). - */ + /// Compare the arguments of two callable ONLY, this method will NOT try to + /// normalize and then compare (while the old version of Aya does). private @Nullable Term compareCallApprox(@NotNull Callable.Tele lhs, @NotNull Callable.Tele rhs) { if (!lhs.ref().equals(rhs.ref())) return null; return compareMany(lhs.args(), rhs.args(), @@ -165,11 +164,9 @@ private R swapped(@NotNull Supplier callback) { return result; } - /** - * Compare two terms with the given {@param type} (if not null) - * - * @return true if they are 'the same' under {@param type}, false otherwise. - */ + /// Compare two terms with the given {@param type} (if not null) + /// + /// @return true if they are 'the same' under {@param type}, false otherwise. public boolean compare(@NotNull Term preLhs, @NotNull Term preRhs, @Nullable Term type) { if (preLhs == preRhs || preLhs instanceof ErrorTerm || preRhs instanceof ErrorTerm) return true; if (checkApproxResult(type, compareApprox(preLhs, preRhs))) return true; @@ -223,12 +220,10 @@ private boolean checkApproxResult(@Nullable Term type, Term approxResult) { } else return false; } - /** - * Compare whnf {@param lhs} and whnf {@param rhs} with {@param type} information - * - * @param type the whnf type. - * @return whether they are 'the same' and their types are {@param type} - */ + /// Compare whnf {@param lhs} and whnf {@param rhs} with {@param type} information + /// + /// @param type the type in whnf. + /// @return whether they are 'the same' and their types are {@param type} private boolean doCompareTyped(@NotNull Term lhs, @NotNull Term rhs, @NotNull Term type) { return switch (whnf(type)) { case LamTerm _, ConCallLike _, TupTerm _ -> Panic.unreachable(); @@ -286,11 +281,9 @@ case DepTypeTerm(_, var lTy, var rTy) -> { }; } - /** - * Compare head-normalized {@param preLhs} and whnfed {@param preRhs} without type information. - * - * @return the head-normalized type of {@param preLhs} and {@param preRhs} if they are 'the same', null otherwise. - */ + /// Compare head-normalized {@param preLhs} and whnfed {@param preRhs} without type information. + /// + /// @return the head-normalized type of {@param preLhs} and {@param preRhs} if they are _the same_, null otherwise. private @Nullable Term compareUntyped(@NotNull Term preLhs, @NotNull Term preRhs) { { var result = compareApprox(preLhs, preRhs); diff --git a/syntax/src/main/java/org/aya/syntax/core/term/Term.java b/syntax/src/main/java/org/aya/syntax/core/term/Term.java index a2efa3aaab..af9fac92e4 100644 --- a/syntax/src/main/java/org/aya/syntax/core/term/Term.java +++ b/syntax/src/main/java/org/aya/syntax/core/term/Term.java @@ -2,6 +2,10 @@ // Use of this source code is governed by the MIT license that can be found in the LICENSE.md file. package org.aya.syntax.core.term; +import java.io.Serializable; +import java.util.function.Consumer; +import java.util.function.UnaryOperator; + import kala.collection.SeqView; import kala.collection.immutable.ImmutableSeq; import kala.function.IndexedFunction; @@ -22,10 +26,8 @@ import org.jetbrains.annotations.ApiStatus; import org.jetbrains.annotations.NotNull; -import java.io.Serializable; -import java.util.function.Consumer; -import java.util.function.UnaryOperator; - +/// The core syntax of Aya. To understand how locally nameless works, see [#bindAllFrom] and [#replaceAllFrom], +/// together with their overrides in [LocalTerm] and [FreeTerm]. public sealed interface Term extends Serializable, AyaDocile permits ClassCastTerm, LocalTerm, Callable, BetaRedex, Formation, StableWHNF, TyckInternal, CoeTerm { @@ -46,27 +48,23 @@ public sealed interface Term extends Serializable, AyaDocile return descent((i, t) -> t.bindAllFrom(vars, fromDepth + i)); } - /** - * Corresponds to abstract operator in [MM 2004]. - * However, abstract is a keyword in Java, so we can't - * use it as a method name. - *
-   * abstract :: Name → Expr → Scope
-   * 
- * - * @apiNote bind preserve the term former unless it's a {@link FreeTerm}. - * @see Closure#apply(Term) - * @see Closure#mkConst - */ + /// Corresponds to _abstract_ operator in \[MM 2004\]. + /// However, `abstract` is a keyword in Java, so we can't + /// use it as a method name. + /// ```haskell + /// abstract :: Name → Expr → Scope + /// ``` + /// + /// @apiNote bind preserve the term former unless it's a [FreeTerm]. + /// @see Closure#apply(Term) + /// @see Closure#mkConst default @NotNull Closure.Locns bind(@NotNull LocalVar var) { return new Closure.Locns(bindAt(var, 0)); } - /** - * Used nontrivially for pattern match expressions, where the clauses are lifted to a global definition, - * so after binding the pattern-introduced variables, we need to bind all the free vars, - * which will be indexed from the bindCount, rather than 0. - */ + /// Used nontrivially for pattern match expressions, where the clauses are lifted to a global definition, + /// so after binding the pattern-introduced variables, we need to bind all the free vars, + /// which will be indexed from the bindCount, rather than 0. default @NotNull Term bindTele(int depth, @NotNull SeqView teleVars) { if (teleVars.isEmpty()) return this; return bindAllFrom(teleVars.reversed().toImmutableSeq(), depth); @@ -86,30 +84,25 @@ public sealed interface Term extends Serializable, AyaDocile return descent((i, t) -> t.replaceAllFrom(from + i, list)); } - /** - * @see #replaceAllFrom(int, ImmutableSeq) - * @see #instTele(SeqView) - */ + /// @see #replaceAllFrom(int, ImmutableSeq) + /// @see #instTele(SeqView) default @NotNull Term instTeleFrom(int from, @NotNull SeqView tele) { return replaceAllFrom(from, tele.reversed().toImmutableSeq()); } - /** - * Corresponds to instantiate operator in [MM 2004]. - * Could be called apply similar to Mini-TT. - */ + /// Corresponds to _instantiate_ operator in \[MM 2004\]. + /// Could be called `apply` similar to Mini-TT, but `apply` is used a lot as method name in Java. @ApiStatus.Internal default @NotNull Term instantiate(Term arg) { return instTeleFrom(0, SeqView.of(arg)); } - /** - * Instantiate in telescope-order. For example:
- * Consider a signature {@code (?2 : Nat) (?1 : Bool) (?0 : True) -> P ?2 ?0 ?1}, - * we can instantiate the result {@code P ?2 ?0 ?1} by some argument {@code [ 114514 , false , tt ] }, - * now it becomes {@code P 114514 tt false}. - * Without this method, we need to reverse the list. - */ + /// Instantiate in telescope-order. For example: + /// + /// Consider a signature `(?2 : Nat) (?1 : Bool) (?0 : True) -> P ?2 ?0 ?1`, + /// we can instantiate the result `P ?2 ?0 ?1` by some argument `[ 114514 , false , tt ]`, + /// now it becomes `P 114514 tt false`. + /// Without this method, we need to reverse the list. default @NotNull Term instTele(@NotNull SeqView tele) { return instTeleFrom(0, tele); } @@ -151,11 +144,9 @@ public sealed interface Term extends Serializable, AyaDocile return this.descent((_, t) -> f.apply(t)); } - /** - * Lift the sort level of this term - * - * @param level level, should be non-negative - */ + /// Lift the sort level of this term + /// + /// @param level level, should be non-negative @ApiStatus.NonExtendable default @NotNull Term elevate(int level) { assert level >= 0 : "level >= 0";