Skip to content

Commit

Permalink
Fix Ctr patterns not being renamed in imports (#725)
Browse files Browse the repository at this point in the history
  • Loading branch information
developedby authored Oct 7, 2024
1 parent 53d5262 commit 4d6f461
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 8 deletions.
25 changes: 18 additions & 7 deletions src/imports/book.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use super::{BindMap, ImportsMap, PackageLoader};
use crate::{
diagnostics::{Diagnostics, DiagnosticsConfig},
fun::{
parser::ParseBook, Adt, AdtCtr, Book, Definition, HvmDefinition, Name, Rule, Source, SourceKind, Term,
parser::ParseBook, Adt, AdtCtr, Book, Definition, HvmDefinition, Name, Pattern, Source, SourceKind, Term,
},
imp::{self, Expr, MatchArm, Stmt},
imports::packages::Packages,
Expand Down Expand Up @@ -346,19 +346,30 @@ trait Def {

impl Def for Definition {
fn apply_binds(&mut self, maybe_constructor: bool, binds: &BindMap) {
fn rename_ctr_patterns(rule: &mut Rule, binds: &BindMap) {
for pat in &mut rule.pats {
for bind in pat.binds_mut().flatten() {
if let Some(alias) = binds.get(bind) {
*bind = alias.clone();
fn rename_ctr_pattern(pat: &mut Pattern, binds: &BindMap) {
for pat in pat.children_mut() {
rename_ctr_pattern(pat, binds);
}
match pat {
Pattern::Ctr(nam, _) => {
if let Some(alias) = binds.get(nam) {
*nam = alias.clone();
}
}
Pattern::Var(Some(nam)) => {
if let Some(alias) = binds.get(nam) {
*nam = alias.clone();
}
}
_ => {}
}
}

for rule in &mut self.rules {
if maybe_constructor {
rename_ctr_patterns(rule, binds);
for pat in &mut rule.pats {
rename_ctr_pattern(pat, binds);
}
}
let bod = std::mem::take(&mut rule.body);
rule.body = bod.fold_uses(binds.iter().rev());
Expand Down
6 changes: 5 additions & 1 deletion tests/golden_tests/import_system/import_type.bend
Original file line number Diff line number Diff line change
@@ -1,7 +1,11 @@
from lib/MyOption import (MyOption, MyOption/bind, MyOption/wrap)

unwrap (val : (MyOption u24)) : u24
unwrap (MyOption/Some x) = x
unwrap (MyOption/None) = 0

def main() -> MyOption((u24, u24)):
with MyOption:
a <- MyOption/Some(1)
b <- MyOption/Some(2)
b = unwrap(MyOption/Some(2))
return wrap((a, b))

0 comments on commit 4d6f461

Please sign in to comment.