From c7f2ecb77ea1ae8d46571d15b523c136d00fa3b1 Mon Sep 17 00:00:00 2001
From: Giacomo Cavalieri <giacomo.cavalieri@icloud.com>
Date: Thu, 20 Jun 2024 17:24:15 +0200
Subject: [PATCH] fix-3250

---
 CHANGELOG.md                                  |  4 +
 compiler-core/src/erlang.rs                   | 57 +++++++++-----
 compiler-core/src/erlang/pattern.rs           | 76 ++++++++++---------
 compiler-core/src/erlang/tests/patterns.rs    | 46 +++++++++++
 ...ring_prefix_as_pattern_with_assertion.snap | 23 ++++++
 ...s__string_prefix_as_pattern_with_list.snap | 18 +++++
 ...fix_as_pattern_with_multiple_subjects.snap | 18 +++++
 ...tern_with_multiple_subjects_and_guard.snap | 18 +++++
 ...__string_prefix_assignment_with_guard.snap |  2 +-
 9 files changed, 205 insertions(+), 57 deletions(-)
 create mode 100644 compiler-core/src/erlang/tests/snapshots/gleam_core__erlang__tests__patterns__string_prefix_as_pattern_with_assertion.snap
 create mode 100644 compiler-core/src/erlang/tests/snapshots/gleam_core__erlang__tests__patterns__string_prefix_as_pattern_with_list.snap
 create mode 100644 compiler-core/src/erlang/tests/snapshots/gleam_core__erlang__tests__patterns__string_prefix_as_pattern_with_multiple_subjects.snap
 create mode 100644 compiler-core/src/erlang/tests/snapshots/gleam_core__erlang__tests__patterns__string_prefix_as_pattern_with_multiple_subjects_and_guard.snap

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 9692f9d847b..cfc2edcae18 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -154,6 +154,10 @@
   they referenced a non-existent type.
   ([Gears](https://github.com/gearsdatapacks))
 
+- Fixed a bug where the compiler would generate invalid Erlang when pattern
+  matching on strings with an `as` pattern.
+  ([Giacomo Cavalieri](https://github.com/giacomocavalieri))
+
 ## v1.2.1 - 2024-05-30
 
 ### Bug Fixes
diff --git a/compiler-core/src/erlang.rs b/compiler-core/src/erlang.rs
index 8df27bdb4d0..09783b6373f 100644
--- a/compiler-core/src/erlang.rs
+++ b/compiler-core/src/erlang.rs
@@ -24,7 +24,7 @@ use ecow::EcoString;
 use heck::ToSnakeCase;
 use im::HashSet;
 use itertools::Itertools;
-use pattern::{pattern, requires_guard};
+use pattern::pattern;
 use regex::{Captures, Regex};
 use std::sync::OnceLock;
 use std::{collections::HashMap, ops::Deref, str::FromStr, sync::Arc};
@@ -871,10 +871,17 @@ fn let_assert<'a>(value: &'a TypedExpr, pat: &'a TypedPattern, env: &mut Env<'a>
         let definition = docvec![var.clone(), " = ", body, ",", line()];
         (var, definition)
     };
-    let check_pattern = pattern::to_doc_discarding_all(pat, &mut vars, env);
-    let assign_pattern = pattern::to_doc(pat, &mut vars, env);
+
+    let mut guards = vec![];
+    let check_pattern = pattern::to_doc_discarding_all(pat, &mut vars, env, &mut guards);
+    let clause_guard = optional_clause_guard(None, guards, env);
+
+    // We don't take the guards from the assign pattern or we would end up with
+    // all the same guards repeated twice!
+    let assign_pattern = pattern::to_doc(pat, &mut vars, env, &mut vec![]);
     let clauses = docvec![
         check_pattern.clone(),
+        clause_guard,
         " -> ",
         subject_var.clone(),
         ";",
@@ -908,7 +915,8 @@ fn let_assert<'a>(value: &'a TypedExpr, pat: &'a TypedPattern, env: &mut Env<'a>
 
 fn let_<'a>(value: &'a TypedExpr, pat: &'a TypedPattern, env: &mut Env<'a>) -> Document<'a> {
     let body = maybe_block_expr(value, env).group();
-    pattern(pat, env).append(" = ").append(body)
+    let mut guards = vec![];
+    pattern(pat, env, &mut guards).append(" = ").append(body)
 }
 
 fn float<'a>(value: &str) -> Document<'a> {
@@ -1089,17 +1097,21 @@ fn clause<'a>(clause: &'a TypedClause, env: &mut Env<'a>) -> Document<'a> {
         std::iter::once(pat)
             .chain(alternative_patterns)
             .map(|patterns| {
+                let mut additional_guards = vec![];
                 env.erl_function_scope_vars = initial_erlang_vars.clone();
 
                 let patterns_doc = if patterns.len() == 1 {
                     let p = patterns.first().expect("Single pattern clause printing");
-                    pattern(p, env)
+                    pattern(p, env, &mut additional_guards)
                 } else {
-                    tuple(patterns.iter().map(|p| pattern(p, env)))
+                    tuple(
+                        patterns
+                            .iter()
+                            .map(|p| pattern(p, env, &mut additional_guards)),
+                    )
                 };
 
-                let new_guard = !patterns.iter().any(requires_guard);
-                let guard = optional_clause_guard(guard.as_ref(), new_guard, env);
+                let guard = optional_clause_guard(guard.as_ref(), additional_guards, env);
                 if then_doc.is_none() {
                     then_doc = Some(clause_consequence(then, env));
                     end_erlang_vars = env.erl_function_scope_vars.clone();
@@ -1127,20 +1139,25 @@ fn clause_consequence<'a>(consequence: &'a TypedExpr, env: &mut Env<'a>) -> Docu
 
 fn optional_clause_guard<'a>(
     guard: Option<&'a TypedClauseGuard>,
-    new: bool,
+    additional_guards: Vec<Document<'a>>,
     env: &mut Env<'a>,
 ) -> Document<'a> {
-    guard
-        .map(|guard| {
-            if new {
-                " when ".to_doc().append(bare_clause_guard(guard, env))
-            } else {
-                " andalso "
-                    .to_doc()
-                    .append(bare_clause_guard(guard, env).surround("(", ")"))
-            }
-        })
-        .unwrap_or_else(nil)
+    let guard_doc = guard.map(|guard| bare_clause_guard(guard, env));
+
+    let guards_count = guard_doc.iter().len() + additional_guards.len();
+    let guards_docs = additional_guards.into_iter().chain(guard_doc).map(|guard| {
+        if guards_count > 1 {
+            guard.surround("(", ")")
+        } else {
+            guard
+        }
+    });
+    let doc = join(guards_docs, " andalso ".to_doc());
+    if doc.is_empty() {
+        doc
+    } else {
+        " when ".to_doc().append(doc)
+    }
 }
 
 fn bare_clause_guard<'a>(guard: &'a TypedClauseGuard, env: &mut Env<'a>) -> Document<'a> {
diff --git a/compiler-core/src/erlang/pattern.rs b/compiler-core/src/erlang/pattern.rs
index c637e78353f..b1dd2cc0294 100644
--- a/compiler-core/src/erlang/pattern.rs
+++ b/compiler-core/src/erlang/pattern.rs
@@ -2,19 +2,13 @@ use crate::analyse::Inferred;
 
 use super::*;
 
-pub(super) fn pattern<'a>(p: &'a TypedPattern, env: &mut Env<'a>) -> Document<'a> {
+pub(super) fn pattern<'a>(
+    p: &'a TypedPattern,
+    env: &mut Env<'a>,
+    guards: &mut Vec<Document<'a>>,
+) -> Document<'a> {
     let mut vars = vec![];
-    to_doc(p, &mut vars, env)
-}
-
-pub(super) fn requires_guard(p: &TypedPattern) -> bool {
-    match p {
-        Pattern::StringPrefix {
-            left_side_assignment: Some(_),
-            ..
-        } => true,
-        _ => false,
-    }
+    to_doc(p, &mut vars, env, guards)
 }
 
 fn print<'a>(
@@ -22,22 +16,28 @@ fn print<'a>(
     vars: &mut Vec<&'a str>,
     define_variables: bool,
     env: &mut Env<'a>,
+    guards: &mut Vec<Document<'a>>,
 ) -> Document<'a> {
     match p {
         Pattern::Assign {
             name, pattern: p, ..
         } if define_variables => {
             vars.push(name);
-            print(p, vars, define_variables, env)
+            print(p, vars, define_variables, env, guards)
                 .append(" = ")
                 .append(env.next_local_var_name(name))
         }
 
-        Pattern::Assign { pattern: p, .. } => print(p, vars, define_variables, env),
+        Pattern::Assign { pattern: p, .. } => print(p, vars, define_variables, env, guards),
 
-        Pattern::List { elements, tail, .. } => {
-            pattern_list(elements, tail.as_deref(), vars, define_variables, env)
-        }
+        Pattern::List { elements, tail, .. } => pattern_list(
+            elements,
+            tail.as_deref(),
+            vars,
+            define_variables,
+            env,
+            guards,
+        ),
 
         Pattern::Discard { .. } => "_".to_doc(),
 
@@ -73,7 +73,7 @@ fn print<'a>(
             arguments: args,
             constructor: Inferred::Known(PatternConstructor { name, .. }),
             ..
-        } => tag_tuple_pattern(name, args, vars, define_variables, env),
+        } => tag_tuple_pattern(name, args, vars, define_variables, env, guards),
 
         Pattern::Constructor {
             constructor: Inferred::Unknown,
@@ -82,16 +82,18 @@ fn print<'a>(
             panic!("Erlang generation performed with uninferred pattern constructor")
         }
 
-        Pattern::Tuple { elems, .. } => {
-            tuple(elems.iter().map(|p| print(p, vars, define_variables, env)))
-        }
-
-        Pattern::BitArray { segments, .. } => bit_array(
-            segments
+        Pattern::Tuple { elems, .. } => tuple(
+            elems
                 .iter()
-                .map(|s| pattern_segment(&s.value, &s.options, vars, define_variables, env)),
+                .map(|p| print(p, vars, define_variables, env, guards)),
         ),
 
+        Pattern::BitArray { segments, .. } => {
+            bit_array(segments.iter().map(|s| {
+                pattern_segment(&s.value, &s.options, vars, define_variables, env, guards)
+            }))
+        }
+
         Pattern::StringPrefix {
             left_side_string,
             right_side_assignment,
@@ -117,6 +119,7 @@ fn print<'a>(
                     //   bytes, then use a guard clause to verify the content.
                     //
                     let name = env.next_local_var_name(left_name);
+                    guards.push(docvec![name.clone(), " =:= ", string(left_side_string)]);
                     docvec![
                         "<<",
                         name.clone(),
@@ -126,10 +129,6 @@ fn print<'a>(
                         ", ",
                         right,
                         "/binary>>",
-                        " when ",
-                        name,
-                        " =:= ",
-                        string(left_side_string)
                     ]
                 }
                 None => docvec![
@@ -151,16 +150,18 @@ pub(super) fn to_doc<'a>(
     p: &'a TypedPattern,
     vars: &mut Vec<&'a str>,
     env: &mut Env<'a>,
+    guards: &mut Vec<Document<'a>>,
 ) -> Document<'a> {
-    print(p, vars, true, env)
+    print(p, vars, true, env, guards)
 }
 
 pub(super) fn to_doc_discarding_all<'a>(
     p: &'a TypedPattern,
     vars: &mut Vec<&'a str>,
     env: &mut Env<'a>,
+    guards: &mut Vec<Document<'a>>,
 ) -> Document<'a> {
-    print(p, vars, false, env)
+    print(p, vars, false, env, guards)
 }
 
 fn tag_tuple_pattern<'a>(
@@ -169,6 +170,7 @@ fn tag_tuple_pattern<'a>(
     vars: &mut Vec<&'a str>,
     define_variables: bool,
     env: &mut Env<'a>,
+    guards: &mut Vec<Document<'a>>,
 ) -> Document<'a> {
     if args.is_empty() {
         atom_string(name.to_snake_case())
@@ -176,7 +178,7 @@ fn tag_tuple_pattern<'a>(
         tuple(
             [atom_string(name.to_snake_case())].into_iter().chain(
                 args.iter()
-                    .map(|p| print(&p.value, vars, define_variables, env)),
+                    .map(|p| print(&p.value, vars, define_variables, env, guards)),
             ),
         )
     }
@@ -188,6 +190,7 @@ fn pattern_segment<'a>(
     vars: &mut Vec<&'a str>,
     define_variables: bool,
     env: &mut Env<'a>,
+    guards: &mut Vec<Document<'a>>,
 ) -> Document<'a> {
     let document = match value {
         // Skip the normal <<value/utf8>> surrounds
@@ -197,7 +200,7 @@ fn pattern_segment<'a>(
         Pattern::Discard { .. }
         | Pattern::Variable { .. }
         | Pattern::Int { .. }
-        | Pattern::Float { .. } => print(value, vars, define_variables, env),
+        | Pattern::Float { .. } => print(value, vars, define_variables, env, guards),
 
         // No other pattern variants are allowed in pattern bit array segments
         _ => panic!("Pattern segment match not recognised"),
@@ -206,7 +209,7 @@ fn pattern_segment<'a>(
     let size = |value: &'a TypedPattern, env: &mut Env<'a>| {
         Some(
             ":".to_doc()
-                .append(print(value, vars, define_variables, env)),
+                .append(print(value, vars, define_variables, env, guards)),
         )
     };
 
@@ -221,13 +224,14 @@ fn pattern_list<'a>(
     vars: &mut Vec<&'a str>,
     define_variables: bool,
     env: &mut Env<'a>,
+    guards: &mut Vec<Document<'a>>,
 ) -> Document<'a> {
     let elements = join(
         elements
             .iter()
-            .map(|e| print(e, vars, define_variables, env)),
+            .map(|e| print(e, vars, define_variables, env, guards)),
         break_(",", ", "),
     );
-    let tail = tail.map(|tail| print(tail, vars, define_variables, env));
+    let tail = tail.map(|tail| print(tail, vars, define_variables, env, guards));
     list(elements, tail)
 }
diff --git a/compiler-core/src/erlang/tests/patterns.rs b/compiler-core/src/erlang/tests/patterns.rs
index fc7ea2e815e..af7a15ecbdb 100644
--- a/compiler-core/src/erlang/tests/patterns.rs
+++ b/compiler-core/src/erlang/tests/patterns.rs
@@ -76,3 +76,49 @@ fn pattern_as() {
 }"
     );
 }
+
+#[test]
+fn string_prefix_as_pattern_with_multiple_subjects() {
+    assert_erl!(
+        "pub fn a(x) {
+  case x, x {
+    _, \"a\" as a <> _  -> a
+    _, _ -> \"a\"
+  }
+}"
+    );
+}
+
+#[test]
+fn string_prefix_as_pattern_with_multiple_subjects_and_guard() {
+    assert_erl!(
+        "pub fn a(x) {
+  case x, x {
+    _, \"a\" as a <> rest if rest == \"a\" -> a
+    _, _ -> \"a\"
+  }
+}"
+    );
+}
+
+#[test]
+fn string_prefix_as_pattern_with_list() {
+    assert_erl!(
+        "pub fn a(x) {
+  case x {
+    [\"a\" as a <> _, \"b\" as b <> _] -> a <> b
+    _ -> \"\"
+  }
+}"
+    );
+}
+
+#[test]
+fn string_prefix_as_pattern_with_assertion() {
+    assert_erl!(
+        "pub fn a(x) {
+  let assert \"a\" as a <> rest = \"wibble\"
+  a
+}"
+    );
+}
diff --git a/compiler-core/src/erlang/tests/snapshots/gleam_core__erlang__tests__patterns__string_prefix_as_pattern_with_assertion.snap b/compiler-core/src/erlang/tests/snapshots/gleam_core__erlang__tests__patterns__string_prefix_as_pattern_with_assertion.snap
new file mode 100644
index 00000000000..d38a9f29ba1
--- /dev/null
+++ b/compiler-core/src/erlang/tests/snapshots/gleam_core__erlang__tests__patterns__string_prefix_as_pattern_with_assertion.snap
@@ -0,0 +1,23 @@
+---
+source: compiler-core/src/erlang/tests/patterns.rs
+expression: "pub fn a(x) {\n  let assert \"a\" as a <> rest = \"wibble\"\n  a\n}"
+---
+-module(my@mod).
+-compile([no_auto_import, nowarn_unused_vars, nowarn_unused_function, nowarn_nomatch]).
+
+-export([a/1]).
+
+-spec a(any()) -> binary().
+a(X) ->
+    _assert_subject = <<"wibble"/utf8>>,
+    <<A@1:1/binary, Rest/binary>> = case _assert_subject of
+        <<A:1/binary, _/binary>> when A =:= <<"a"/utf8>> -> _assert_subject;
+        _assert_fail ->
+            erlang:error(#{gleam_error => let_assert,
+                        message => <<"Assertion pattern match failed"/utf8>>,
+                        value => _assert_fail,
+                        module => <<"my/mod"/utf8>>,
+                        function => <<"a"/utf8>>,
+                        line => 2})
+    end,
+    A@1.
diff --git a/compiler-core/src/erlang/tests/snapshots/gleam_core__erlang__tests__patterns__string_prefix_as_pattern_with_list.snap b/compiler-core/src/erlang/tests/snapshots/gleam_core__erlang__tests__patterns__string_prefix_as_pattern_with_list.snap
new file mode 100644
index 00000000000..7c05bab83cd
--- /dev/null
+++ b/compiler-core/src/erlang/tests/snapshots/gleam_core__erlang__tests__patterns__string_prefix_as_pattern_with_list.snap
@@ -0,0 +1,18 @@
+---
+source: compiler-core/src/erlang/tests/patterns.rs
+expression: "pub fn a(x) {\n  case x {\n    [\"a\" as a <> _, \"b\" as b <> _] -> a <> b\n    _ -> \"\"\n  }\n}"
+---
+-module(my@mod).
+-compile([no_auto_import, nowarn_unused_vars, nowarn_unused_function, nowarn_nomatch]).
+
+-export([a/1]).
+
+-spec a(list(binary())) -> binary().
+a(X) ->
+    case X of
+        [<<A:1/binary, _/binary>>, <<B:1/binary, _/binary>>] when (A =:= <<"a"/utf8>>) andalso (B =:= <<"b"/utf8>>) ->
+            <<A/binary, B/binary>>;
+
+        _ ->
+            <<""/utf8>>
+    end.
diff --git a/compiler-core/src/erlang/tests/snapshots/gleam_core__erlang__tests__patterns__string_prefix_as_pattern_with_multiple_subjects.snap b/compiler-core/src/erlang/tests/snapshots/gleam_core__erlang__tests__patterns__string_prefix_as_pattern_with_multiple_subjects.snap
new file mode 100644
index 00000000000..c9ce5985978
--- /dev/null
+++ b/compiler-core/src/erlang/tests/snapshots/gleam_core__erlang__tests__patterns__string_prefix_as_pattern_with_multiple_subjects.snap
@@ -0,0 +1,18 @@
+---
+source: compiler-core/src/erlang/tests/patterns.rs
+expression: "pub fn a(x) {\n  case x, x {\n    _, \"a\" as a <> _  -> a\n    _, _ -> \"a\"\n  }\n}"
+---
+-module(my@mod).
+-compile([no_auto_import, nowarn_unused_vars, nowarn_unused_function, nowarn_nomatch]).
+
+-export([a/1]).
+
+-spec a(binary()) -> binary().
+a(X) ->
+    case {X, X} of
+        {_, <<A:1/binary, _/binary>>} when A =:= <<"a"/utf8>> ->
+            A;
+
+        {_, _} ->
+            <<"a"/utf8>>
+    end.
diff --git a/compiler-core/src/erlang/tests/snapshots/gleam_core__erlang__tests__patterns__string_prefix_as_pattern_with_multiple_subjects_and_guard.snap b/compiler-core/src/erlang/tests/snapshots/gleam_core__erlang__tests__patterns__string_prefix_as_pattern_with_multiple_subjects_and_guard.snap
new file mode 100644
index 00000000000..2e7f966ea34
--- /dev/null
+++ b/compiler-core/src/erlang/tests/snapshots/gleam_core__erlang__tests__patterns__string_prefix_as_pattern_with_multiple_subjects_and_guard.snap
@@ -0,0 +1,18 @@
+---
+source: compiler-core/src/erlang/tests/patterns.rs
+expression: "pub fn a(x) {\n  case x, x {\n    _, \"a\" as a <> rest if rest == \"a\" -> a\n    _, _ -> \"a\"\n  }\n}"
+---
+-module(my@mod).
+-compile([no_auto_import, nowarn_unused_vars, nowarn_unused_function, nowarn_nomatch]).
+
+-export([a/1]).
+
+-spec a(binary()) -> binary().
+a(X) ->
+    case {X, X} of
+        {_, <<A:1/binary, Rest/binary>>} when (A =:= <<"a"/utf8>>) andalso (Rest =:= <<"a"/utf8>>) ->
+            A;
+
+        {_, _} ->
+            <<"a"/utf8>>
+    end.
diff --git a/compiler-core/src/erlang/tests/snapshots/gleam_core__erlang__tests__strings__string_prefix_assignment_with_guard.snap b/compiler-core/src/erlang/tests/snapshots/gleam_core__erlang__tests__strings__string_prefix_assignment_with_guard.snap
index cff3316d2a0..5be11ccce79 100644
--- a/compiler-core/src/erlang/tests/snapshots/gleam_core__erlang__tests__strings__string_prefix_assignment_with_guard.snap
+++ b/compiler-core/src/erlang/tests/snapshots/gleam_core__erlang__tests__strings__string_prefix_assignment_with_guard.snap
@@ -10,7 +10,7 @@ expression: "\npub fn go(x) {\n  case x {\n    \"Hello, \" as greeting <> name i
 -spec go(binary()) -> binary().
 go(X) ->
     case X of
-        <<Greeting:7/binary, Name/binary>> when Greeting =:= <<"Hello, "/utf8>> andalso (Name =:= <<"Dude"/utf8>>) ->
+        <<Greeting:7/binary, Name/binary>> when (Greeting =:= <<"Hello, "/utf8>>) andalso (Name =:= <<"Dude"/utf8>>) ->
             <<Greeting/binary, "Mate"/utf8>>;
 
         <<Greeting@1:7/binary, Name@1/binary>> when Greeting@1 =:= <<"Hello, "/utf8>> ->