Skip to content

Commit

Permalink
allow join with rank differences > 2
Browse files Browse the repository at this point in the history
  • Loading branch information
kaikalii committed Jun 3, 2024
1 parent f3004c1 commit 8e2e260
Show file tree
Hide file tree
Showing 11 changed files with 163 additions and 121 deletions.
1 change: 1 addition & 0 deletions changelog.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ This version is not yet released. If you are reading this on the website, then t
- Add the experimental [`astar`](https://uiua.org/docs/astar) modifier, which performs the A* pathfinding algorithm
- [`keep ▽`](https://uiua.org/docs/keep) will now cycle counts if the counts array is shorter than the counted array
- [`keep ▽`](https://uiua.org/docs/keep) now works with non-integer scalar counts to scale an array
- [`join ⊂`](https://uiua.org/docs/join) will rank differences greater than 1 can now extend the smaller array

## 0.11.0 - 2024-06-02
You can find the release announcement [here](https://uiua.org/blog/uiua-0.11.0).
Expand Down
159 changes: 94 additions & 65 deletions src/algorithm/dyadic/combine.rs
Original file line number Diff line number Diff line change
Expand Up @@ -86,84 +86,91 @@ impl<T: ArrayValue> Array<T> {

impl Value {
/// `join` the array with another
pub fn join(self, other: Self, env: &Uiua) -> UiuaResult<Self> {
self.join_impl(other, env)
///
/// `allow_ext` allows extending one of the arrays if they have different shapes
pub fn join(self, other: Self, allow_ext: bool, env: &Uiua) -> UiuaResult<Self> {
self.join_impl(other, allow_ext, env)
}
/// `join` the array with another
///
/// # Panics
/// Panics if the arrays have incompatible shapes
pub fn join_infallible(self, other: Self) -> Self {
self.join_impl(other, &()).unwrap()
pub fn join_infallible(self, other: Self, allow_ext: bool) -> Self {
self.join_impl(other, allow_ext, &()).unwrap()
}
pub(crate) fn join_impl<C: FillContext>(self, other: Self, ctx: &C) -> Result<Self, C::Error> {
fn join_impl<C: FillContext>(self, other: Self, ext: bool, ctx: &C) -> Result<Self, C::Error> {
Ok(match (self, other) {
(Value::Num(a), Value::Num(b)) => a.join_impl(b, ctx)?.into(),
(Value::Num(a), Value::Num(b)) => a.join_impl(b, ext, ctx)?.into(),
(Value::Byte(a), Value::Byte(b)) => op2_bytes_retry_fill::<_, C>(
a,
b,
ctx,
|a, b| Ok(a.join_impl(b, ctx)?.into()),
|a, b| Ok(a.join_impl(b, ctx)?.into()),
|a, b| Ok(a.join_impl(b, ext, ctx)?.into()),
|a, b| Ok(a.join_impl(b, ext, ctx)?.into()),
)?,
(Value::Complex(a), Value::Complex(b)) => a.join_impl(b, ctx)?.into(),
(Value::Char(a), Value::Char(b)) => a.join_impl(b, ctx)?.into(),
(Value::Byte(a), Value::Num(b)) => a.convert().join_impl(b, ctx)?.into(),
(Value::Num(a), Value::Byte(b)) => a.join_impl(b.convert(), ctx)?.into(),
(Value::Complex(a), Value::Num(b)) => a.join_impl(b.convert(), ctx)?.into(),
(Value::Num(a), Value::Complex(b)) => a.convert().join_impl(b, ctx)?.into(),
(Value::Complex(a), Value::Byte(b)) => a.join_impl(b.convert(), ctx)?.into(),
(Value::Byte(a), Value::Complex(b)) => a.convert().join_impl(b, ctx)?.into(),
(Value::Complex(a), Value::Complex(b)) => a.join_impl(b, ext, ctx)?.into(),
(Value::Char(a), Value::Char(b)) => a.join_impl(b, ext, ctx)?.into(),
(Value::Byte(a), Value::Num(b)) => a.convert().join_impl(b, ext, ctx)?.into(),
(Value::Num(a), Value::Byte(b)) => a.join_impl(b.convert(), ext, ctx)?.into(),
(Value::Complex(a), Value::Num(b)) => a.join_impl(b.convert(), ext, ctx)?.into(),
(Value::Num(a), Value::Complex(b)) => a.convert().join_impl(b, ext, ctx)?.into(),
(Value::Complex(a), Value::Byte(b)) => a.join_impl(b.convert(), ext, ctx)?.into(),
(Value::Byte(a), Value::Complex(b)) => a.convert().join_impl(b, ext, ctx)?.into(),
(a, b) => a.bin_coerce_to_boxes(
b,
ctx,
|a, b, env| Ok(a.join_impl(b, env)?.into()),
|a, b, env| Ok(a.join_impl(b, ext, env)?.into()),
|a, b| format!("Cannot join {a} array and {b} array"),
)?,
})
}
pub(crate) fn append<C: FillContext>(&mut self, other: Self, ctx: &C) -> Result<(), C::Error> {
pub(crate) fn append<C: FillContext>(
&mut self,
other: Self,
ext: bool,
ctx: &C,
) -> Result<(), C::Error> {
match (&mut *self, other) {
(Value::Num(a), Value::Num(b)) => a.append(b, ctx)?,
(Value::Num(a), Value::Num(b)) => a.append(b, ext, ctx)?,
(Value::Byte(a), Value::Byte(b)) => {
*self = op2_bytes_retry_fill::<_, C>(
a.clone(),
b,
ctx,
|mut a, b| {
a.append(b, ctx)?;
a.append(b, ext, ctx)?;
Ok(a.into())
},
|mut a, b| {
a.append(b, ctx)?;
a.append(b, ext, ctx)?;
Ok(a.into())
},
)?;
}
(Value::Complex(a), Value::Complex(b)) => a.append(b, ctx)?,
(Value::Char(a), Value::Char(b)) => a.append(b, ctx)?,
(Value::Complex(a), Value::Complex(b)) => a.append(b, ext, ctx)?,
(Value::Char(a), Value::Char(b)) => a.append(b, ext, ctx)?,
(Value::Byte(a), Value::Num(b)) => {
let mut a = a.convert_ref();
a.append(b, ctx)?;
a.append(b, ext, ctx)?;
*self = a.into();
}
(Value::Num(a), Value::Byte(b)) => a.append(b.convert(), ctx)?,
(Value::Complex(a), Value::Num(b)) => a.append(b.convert(), ctx)?,
(Value::Num(a), Value::Byte(b)) => a.append(b.convert(), ext, ctx)?,
(Value::Complex(a), Value::Num(b)) => a.append(b.convert(), ext, ctx)?,
(Value::Num(a), Value::Complex(b)) => {
let mut a = a.convert_ref();
a.append(b, ctx)?;
a.append(b, ext, ctx)?;
*self = a.into();
}
(Value::Complex(a), Value::Byte(b)) => a.append(b.convert(), ctx)?,
(Value::Complex(a), Value::Byte(b)) => a.append(b.convert(), ext, ctx)?,
(Value::Byte(a), Value::Complex(b)) => {
let mut a = a.convert_ref();
a.append(b, ctx)?;
a.append(b, ext, ctx)?;
*self = a.into();
}
(a, b) => a.bin_coerce_to_boxes_mut(
b,
ctx,
|a, b, env| a.append(b, env),
|a, b, env| a.append(b, ext, env),
|a, b| format!("Cannot add {b} row to {a} array"),
)?,
}
Expand Down Expand Up @@ -232,17 +239,22 @@ impl Value {

impl<T: ArrayValue> Array<T> {
/// `join` the array with another
pub fn join(self, other: Self, env: &Uiua) -> UiuaResult<Self> {
self.join_impl(other, env)
pub fn join(self, other: Self, allow_ext: bool, env: &Uiua) -> UiuaResult<Self> {
self.join_impl(other, allow_ext, env)
}
/// `join` the array with another
///
/// # Panics
/// Panics if the arrays have incompatible shapes
pub fn join_infallible(self, other: Self) -> Self {
self.join_impl(other, &()).unwrap()
pub fn join_infallible(self, other: Self, allow_ext: bool) -> Self {
self.join_impl(other, allow_ext, &()).unwrap()
}
fn join_impl<C: FillContext>(mut self, mut other: Self, ctx: &C) -> Result<Self, C::Error> {
fn join_impl<C: FillContext>(
mut self,
mut other: Self,
allow_ext: bool,
ctx: &C,
) -> Result<Self, C::Error> {
crate::profile_function!();
let res = match self.rank().cmp(&other.rank()) {
Ordering::Less => {
Expand All @@ -262,19 +274,25 @@ impl<T: ArrayValue> Array<T> {
target_shape
}
Err(e) => {
if other.rank() - self.rank() > 1 {
return Err(C::fill_error(ctx.error(format!(
"Cannot join rank {} array with rank {} array{e}",
self.rank(),
other.rank()
))));
}
if self.shape() != other.shape()[1..] {
return Err(C::fill_error(ctx.error(format!(
"Cannot join arrays of shapes {} and {}{e}",
self.shape(),
other.shape()
))));
if allow_ext && other.shape.ends_with(&self.shape) {
for &b_dim in other.shape[1..other.rank() - self.rank()].iter().rev() {
self.reshape_scalar_integer(b_dim);
}
} else {
if other.rank() - self.rank() > 1 {
return Err(C::fill_error(ctx.error(format!(
"Cannot join rank {} array with rank {} array{e}",
self.rank(),
other.rank()
))));
}
if self.shape() != other.shape()[1..] {
return Err(C::fill_error(ctx.error(format!(
"Cannot join arrays of shapes {} and {}{e}",
self.shape(),
other.shape()
))));
}
}
other.shape
}
Expand All @@ -291,7 +309,7 @@ impl<T: ArrayValue> Array<T> {
if other.shape() == 0 {
return Ok(self);
}
self.append(other, ctx)?;
self.append(other, allow_ext, ctx)?;
self
}
Ordering::Equal => {
Expand Down Expand Up @@ -358,7 +376,12 @@ impl<T: ArrayValue> Array<T> {
res.validate_shape();
Ok(res)
}
fn append<C: FillContext>(&mut self, mut other: Self, ctx: &C) -> Result<(), C::Error> {
fn append<C: FillContext>(
&mut self,
mut other: Self,
allow_ext: bool,
ctx: &C,
) -> Result<(), C::Error> {
self.combine_meta(other.meta());
let target_shape = match ctx.scalar_fill::<T>() {
Ok(fill) => {
Expand All @@ -372,19 +395,25 @@ impl<T: ArrayValue> Array<T> {
target_shape
}
Err(e) => {
if self.rank() <= other.rank() || self.rank() - other.rank() > 1 {
return Err(C::fill_error(ctx.error(format!(
"Cannot add rank {} row to rank {} array{e}",
other.rank(),
self.rank()
))));
}
if &self.shape()[1..] != other.shape() {
return Err(C::fill_error(ctx.error(format!(
"Cannot add shape {} row to array with shape {} rows{e}",
other.shape(),
FormatShape(&self.shape()[1..]),
))));
if allow_ext && self.shape.ends_with(&other.shape) {
for &a_dim in self.shape[1..self.rank() - other.rank()].iter().rev() {
other.reshape_scalar_integer(a_dim);
}
} else {
if self.rank() <= other.rank() || self.rank() - other.rank() > 1 {
return Err(C::fill_error(ctx.error(format!(
"Cannot add rank {} row to rank {} array{e}",
other.rank(),
self.rank()
))));
}
if &self.shape()[1..] != other.shape() {
return Err(C::fill_error(ctx.error(format!(
"Cannot add shape {} row to array with shape {} rows{e}",
other.shape(),
FormatShape(&self.shape()[1..]),
))));
}
}
take(&mut self.shape)
}
Expand Down Expand Up @@ -688,7 +717,7 @@ impl Value {
value.reserve_min(total_elements);
value.couple_impl(row, ctx)?;
for row in row_values {
value.append(row, ctx)?;
value.append(row, false, ctx)?;
}
} else {
value.shape_mut().insert(0, 1);
Expand Down Expand Up @@ -736,7 +765,7 @@ impl<T: ArrayValue> Array<T> {
arr.data.reserve_min(total_elements);
arr.couple_impl(row, ctx)?;
for row in row_values {
arr.append(row, ctx)?;
arr.append(row, false, ctx)?;
}
} else {
arr.shape.insert(0, 1);
Expand Down
Loading

0 comments on commit 8e2e260

Please sign in to comment.