diff --git a/site/primitives.json b/site/primitives.json index 79c74a77c..76b572e5b 100644 --- a/site/primitives.json +++ b/site/primitives.json @@ -441,7 +441,7 @@ "csv": { "args": 1, "outputs": 1, - "class": "Misc", + "class": "Encoding", "description": "Encode an array into a CSV string" }, "deal": { @@ -685,7 +685,7 @@ "json": { "args": 1, "outputs": 1, - "class": "Misc", + "class": "Encoding", "description": "Encode an array into a JSON string" }, "keep": { @@ -1171,7 +1171,7 @@ "utf": { "args": 1, "outputs": 1, - "class": "Misc", + "class": "Encoding", "description": "Convert a string to UTF-8 bytes" }, "wait": { @@ -1197,7 +1197,7 @@ "xlsx": { "args": 1, "outputs": 1, - "class": "Misc", + "class": "Encoding", "description": "Encode an array into XLSX bytes" } } \ No newline at end of file diff --git a/src/algorithm/dyadic/mod.rs b/src/algorithm/dyadic/mod.rs index bebbd37c2..9dedc433a 100644 --- a/src/algorithm/dyadic/mod.rs +++ b/src/algorithm/dyadic/mod.rs @@ -519,84 +519,20 @@ impl Array { /// `keep` this array with some counts pub fn list_keep(mut self, counts: &[usize], env: &Uiua) -> UiuaResult { self.take_map_keys(); - let mut amount = Cow::Borrowed(counts); - match amount.len().cmp(&self.row_count()) { - Ordering::Equal => {} - Ordering::Less => match env.num_array_fill() { - Ok(fill) => { - if let Some(n) = fill.data.iter().find(|&&n| n < 0.0 || n.fract() != 0.0) { - return Err(env.error(format!( - "Fill value for keep must be an array of \ - non-negative integers, but one of the \ - values is {n}" - ))); - } - match fill.rank() { - 0 => { - let fill = fill.data[0] as usize; - let mut new_amount = amount.to_vec(); - new_amount.extend(repeat(fill).take(self.row_count() - amount.len())); - amount = new_amount.into(); - } - 1 => { - let mut new_amount = amount.to_vec(); - new_amount.extend( - (fill.data.iter().map(|&n| n as usize).cycle()) - .take(self.row_count() - amount.len()), - ); - amount = new_amount.into(); - } - _ => { - return Err(env.error(format!( - "Fill value for keep must be a scalar or a 1D array, \ - but it has shape {}", - fill.shape - ))); - } - } - } - Err(e) => { - return Err(env.error(format!( - "Cannot keep array with shape {} with array of shape {}{e}", - self.shape(), - FormatShape(&[amount.len()]) - ))); - } - }, - Ordering::Greater => { - return Err(env.error(match env.num_array_fill() { - Ok(_) => { - format!( - "Cannot keep array with shape {} with array of shape {}. \ - A fill value is available, but keep can only been filled \ - if there are fewer counts than rows.", - self.shape(), - FormatShape(&[amount.len()]) - ) - } - Err(e) => { - format!( - "Cannot keep array with shape {} with array of shape {}{e}", - self.shape(), - FormatShape(&[amount.len()]) - ) - } - })) - } - } + let counts = pad_keep_counts(counts, self.row_count(), env)?; if self.rank() == 0 { - if amount.len() != 1 { + if counts.len() != 1 { return Err(env.error("Scalar array can only be kept with a single number")); } - let mut new_data = EcoVec::with_capacity(amount[0]); - for _ in 0..amount[0] { + let mut new_data = EcoVec::with_capacity(counts[0]); + for _ in 0..counts[0] { new_data.push(self.data[0].clone()); } self = new_data.into(); } else { let mut all_bools = true; let mut true_count = 0; - for &n in amount.iter() { + for &n in counts.iter() { match n { 0 => {} 1 => true_count += 1, @@ -611,7 +547,7 @@ impl Array { let new_flat_len = true_count * row_len; let mut new_data = CowSlice::with_capacity(new_flat_len); if row_len > 0 { - for (b, r) in amount.iter().zip(self.data.chunks_exact(row_len)) { + for (b, r) in counts.iter().zip(self.data.chunks_exact(row_len)) { if *b == 1 { new_data.extend_from_slice(r); } @@ -623,14 +559,14 @@ impl Array { let mut new_data = CowSlice::new(); let mut new_len = 0; if row_len > 0 { - for (n, r) in amount.iter().zip(self.data.chunks_exact(row_len)) { + for (n, r) in counts.iter().zip(self.data.chunks_exact(row_len)) { new_len += *n; for _ in 0..*n { new_data.extend_from_slice(r); } } } else { - new_len = amount.iter().sum(); + new_len = counts.iter().sum(); } self.data = new_data; self.shape[0] = new_len; @@ -640,6 +576,7 @@ impl Array { Ok(self) } fn undo_keep(self, counts: &[usize], into: Self, env: &Uiua) -> UiuaResult { + let counts = pad_keep_counts(counts, into.row_count(), env)?; if counts.iter().any(|&n| n > 1) { return Err(env.error("Cannot invert keep with non-boolean counts")); } @@ -670,6 +607,79 @@ impl Array { } } +fn pad_keep_counts<'a>( + counts: &'a [usize], + len: usize, + env: &Uiua, +) -> UiuaResult> { + let mut amount = Cow::Borrowed(counts); + match amount.len().cmp(&len) { + Ordering::Equal => {} + Ordering::Less => match env.num_array_fill() { + Ok(fill) => { + if let Some(n) = fill.data.iter().find(|&&n| n < 0.0 || n.fract() != 0.0) { + return Err(env.error(format!( + "Fill value for keep must be an array of \ + non-negative integers, but one of the \ + values is {n}" + ))); + } + match fill.rank() { + 0 => { + let fill = fill.data[0] as usize; + let mut new_amount = amount.to_vec(); + new_amount.extend(repeat(fill).take(len - amount.len())); + amount = new_amount.into(); + } + 1 => { + let mut new_amount = amount.to_vec(); + new_amount.extend( + (fill.data.iter().map(|&n| n as usize).cycle()) + .take(len - amount.len()), + ); + amount = new_amount.into(); + } + _ => { + return Err(env.error(format!( + "Fill value for keep must be a scalar or a 1D array, \ + but it has shape {}", + fill.shape + ))); + } + } + } + Err(e) => { + return Err(env.error(format!( + "Cannot keep array with shape {} with array of shape {}{e}", + len, + FormatShape(&[amount.len()]) + ))) + } + }, + Ordering::Greater => { + return Err(env.error(match env.num_array_fill() { + Ok(_) => { + format!( + "Cannot keep array with shape {} with array of shape {}. \ + A fill value is available, but keep can only been filled \ + if there are fewer counts than rows.", + len, + FormatShape(&[amount.len()]) + ) + } + Err(e) => { + format!( + "Cannot keep array with shape {} with array of shape {}{e}", + len, + FormatShape(&[amount.len()]) + ) + } + })) + } + } + Ok(amount) +} + impl Value { /// Use this value to `rotate` another pub fn rotate(&self, rotated: Self, env: &Uiua) -> UiuaResult { diff --git a/tests/under.ua b/tests/under.ua index 41d8677a3..c0bddc75c 100644 --- a/tests/under.ua +++ b/tests/under.ua @@ -36,6 +36,7 @@ # Keep ⍤⟜≍: [10 2 30 4 50] ⍜▽(×10) ◿2. [1 2 3 4 5] ⍤⟜≍: [10 2 30 4 50] ⍜(▽◿2.|×10) [1 2 3 4 5] +⍤⟜≍: [0 9 2 7 4 5 6 3 8 1] ⍜⬚0_1▽⇌ [] ⇡10 # Both uncouple ⍤⟜≍: [2_1 3_4] [⍜∩°⊟: 1_2 3_4]