Skip to content

Commit 2b3b220

Browse files
authored
feat: Handle edge case with corr with single row and NaN (#18677)
## Which issue does this PR close? <!-- We generally require a GitHub issue to be filed for all bug fixes and enhancements and this helps us generate change logs for our releases. You can link an issue to this PR using the GitHub syntax. For example `Closes #123` indicates that this PR will close issue #123. --> - Closes #18659. ## Rationale for this change Fix an edge case in `corr` and `NaN` <!-- Why are you proposing this change? If this is already explained clearly in the issue then this section is not needed. Explaining clearly why changes are proposed helps reviewers understand your changes and offer better suggestions for fixes. --> ## What changes are included in this PR? <!-- There is no need to duplicate the description in the issue here but it is sometimes worth providing a summary of the individual changes in this PR. --> ## Are these changes tested? <!-- We typically require tests for all PRs in order to: 1. Prevent the code from being accidentally broken by subsequent changes 2. Serve as another way to document the expected behavior of the code If tests are not included in your PR, please explain why (for example, are they covered by existing tests)? --> ## Are there any user-facing changes? <!-- If there are user-facing changes then we may require documentation to be updated before approving the PR. --> <!-- If there are any breaking changes to public APIs, please add the `api change` label. -->
1 parent 1dddf03 commit 2b3b220

File tree

2 files changed

+142
-59
lines changed

2 files changed

+142
-59
lines changed

datafusion/functions-aggregate/src/correlation.rs

Lines changed: 78 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -196,15 +196,24 @@ impl Accumulator for CorrelationAccumulator {
196196
}
197197

198198
fn evaluate(&mut self) -> Result<ScalarValue> {
199-
let n = self.covar.get_count();
200-
if n < 2 {
201-
return Ok(ScalarValue::Float64(None));
202-
}
203-
204199
let covar = self.covar.evaluate()?;
205200
let stddev1 = self.stddev1.evaluate()?;
206201
let stddev2 = self.stddev2.evaluate()?;
207202

203+
// First check if we have NaN values by examining the internal state
204+
// This handles the case where both inputs are NaN even with count=1
205+
let mean1 = self.covar.get_mean1();
206+
let mean2 = self.covar.get_mean2();
207+
208+
// If both means are NaN, then both input columns contain only NaN values
209+
if mean1.is_nan() && mean2.is_nan() {
210+
return Ok(ScalarValue::Float64(Some(f64::NAN)));
211+
}
212+
let n = self.covar.get_count();
213+
if mean1.is_nan() || mean2.is_nan() || n < 2 {
214+
return Ok(ScalarValue::Float64(None));
215+
}
216+
208217
if let ScalarValue::Float64(Some(c)) = covar {
209218
if let ScalarValue::Float64(Some(s1)) = stddev1 {
210219
if let ScalarValue::Float64(Some(s2)) = stddev2 {
@@ -402,54 +411,6 @@ impl GroupsAccumulator for CorrelationGroupsAccumulator {
402411
Ok(())
403412
}
404413

405-
fn merge_batch(
406-
&mut self,
407-
values: &[ArrayRef],
408-
group_indices: &[usize],
409-
opt_filter: Option<&BooleanArray>,
410-
total_num_groups: usize,
411-
) -> Result<()> {
412-
// Resize vectors to accommodate total number of groups
413-
self.count.resize(total_num_groups, 0);
414-
self.sum_x.resize(total_num_groups, 0.0);
415-
self.sum_y.resize(total_num_groups, 0.0);
416-
self.sum_xy.resize(total_num_groups, 0.0);
417-
self.sum_xx.resize(total_num_groups, 0.0);
418-
self.sum_yy.resize(total_num_groups, 0.0);
419-
420-
// Extract arrays from input values
421-
let partial_counts = values[0].as_primitive::<UInt64Type>();
422-
let partial_sum_x = values[1].as_primitive::<Float64Type>();
423-
let partial_sum_y = values[2].as_primitive::<Float64Type>();
424-
let partial_sum_xy = values[3].as_primitive::<Float64Type>();
425-
let partial_sum_xx = values[4].as_primitive::<Float64Type>();
426-
let partial_sum_yy = values[5].as_primitive::<Float64Type>();
427-
428-
assert!(opt_filter.is_none(), "aggregate filter should be applied in partial stage, there should be no filter in final stage");
429-
430-
accumulate_correlation_states(
431-
group_indices,
432-
(
433-
partial_counts,
434-
partial_sum_x,
435-
partial_sum_y,
436-
partial_sum_xy,
437-
partial_sum_xx,
438-
partial_sum_yy,
439-
),
440-
|group_index, count, values| {
441-
self.count[group_index] += count;
442-
self.sum_x[group_index] += values[0];
443-
self.sum_y[group_index] += values[1];
444-
self.sum_xy[group_index] += values[2];
445-
self.sum_xx[group_index] += values[3];
446-
self.sum_yy[group_index] += values[4];
447-
},
448-
);
449-
450-
Ok(())
451-
}
452-
453414
fn evaluate(&mut self, emit_to: EmitTo) -> Result<ArrayRef> {
454415
let n = match emit_to {
455416
EmitTo::All => self.count.len(),
@@ -465,21 +426,31 @@ impl GroupsAccumulator for CorrelationGroupsAccumulator {
465426
// - Correlation can't be calculated when a group only has 1 record, or when
466427
// the `denominator` state is 0. In these cases, the final aggregation
467428
// result should be `Null` (according to PostgreSQL's behavior).
429+
// - However, if any of the accumulated values contain NaN, the result should
430+
// be NaN regardless of the count (even for single-row groups).
468431
//
469432
for i in 0..n {
470-
if self.count[i] < 2 {
471-
values.push(0.0);
472-
nulls.append_null();
473-
continue;
474-
}
475-
476433
let count = self.count[i];
477434
let sum_x = self.sum_x[i];
478435
let sum_y = self.sum_y[i];
479436
let sum_xy = self.sum_xy[i];
480437
let sum_xx = self.sum_xx[i];
481438
let sum_yy = self.sum_yy[i];
482439

440+
// If BOTH sum_x AND sum_y are NaN, then both input values are NaN → return NaN
441+
// If only ONE of them is NaN, then only one input value is NaN → return NULL
442+
if sum_x.is_nan() && sum_y.is_nan() {
443+
// Both inputs are NaN → return NaN
444+
values.push(f64::NAN);
445+
nulls.append_non_null();
446+
continue;
447+
} else if count < 2 || sum_x.is_nan() || sum_y.is_nan() {
448+
// Only one input is NaN → return NULL
449+
values.push(0.0);
450+
nulls.append_null();
451+
continue;
452+
}
453+
483454
let mean_x = sum_x / count as f64;
484455
let mean_y = sum_y / count as f64;
485456

@@ -515,6 +486,54 @@ impl GroupsAccumulator for CorrelationGroupsAccumulator {
515486
])
516487
}
517488

489+
fn merge_batch(
490+
&mut self,
491+
values: &[ArrayRef],
492+
group_indices: &[usize],
493+
opt_filter: Option<&BooleanArray>,
494+
total_num_groups: usize,
495+
) -> Result<()> {
496+
// Resize vectors to accommodate total number of groups
497+
self.count.resize(total_num_groups, 0);
498+
self.sum_x.resize(total_num_groups, 0.0);
499+
self.sum_y.resize(total_num_groups, 0.0);
500+
self.sum_xy.resize(total_num_groups, 0.0);
501+
self.sum_xx.resize(total_num_groups, 0.0);
502+
self.sum_yy.resize(total_num_groups, 0.0);
503+
504+
// Extract arrays from input values
505+
let partial_counts = values[0].as_primitive::<UInt64Type>();
506+
let partial_sum_x = values[1].as_primitive::<Float64Type>();
507+
let partial_sum_y = values[2].as_primitive::<Float64Type>();
508+
let partial_sum_xy = values[3].as_primitive::<Float64Type>();
509+
let partial_sum_xx = values[4].as_primitive::<Float64Type>();
510+
let partial_sum_yy = values[5].as_primitive::<Float64Type>();
511+
512+
assert!(opt_filter.is_none(), "aggregate filter should be applied in partial stage, there should be no filter in final stage");
513+
514+
accumulate_correlation_states(
515+
group_indices,
516+
(
517+
partial_counts,
518+
partial_sum_x,
519+
partial_sum_y,
520+
partial_sum_xy,
521+
partial_sum_xx,
522+
partial_sum_yy,
523+
),
524+
|group_index, count, values| {
525+
self.count[group_index] += count;
526+
self.sum_x[group_index] += values[0];
527+
self.sum_y[group_index] += values[1];
528+
self.sum_xy[group_index] += values[2];
529+
self.sum_xx[group_index] += values[3];
530+
self.sum_yy[group_index] += values[4];
531+
},
532+
);
533+
534+
Ok(())
535+
}
536+
518537
fn size(&self) -> usize {
519538
size_of_val(&self.count)
520539
+ size_of_val(&self.sum_x)

datafusion/sqllogictest/test_files/aggregate.slt

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -607,6 +607,70 @@ from data
607607
----
608608
1
609609

610+
# group correlation_query_with_nans_f32
611+
query IR
612+
select id, corr(f, b)
613+
from values
614+
(1, 1, 'nan'::float),
615+
(2, 'nan'::float, 1),
616+
(3, 'nan'::float, null),
617+
(4, null, 'nan'::float),
618+
(5, 'nan'::float, 'nan'::float),
619+
(5, 1, 1),
620+
(5, 2, 2),
621+
(6, 'nan'::float, 'nan'::float) t(id, f, b)
622+
group by id
623+
order by id
624+
----
625+
1 NULL
626+
2 NULL
627+
3 NULL
628+
4 NULL
629+
5 NaN
630+
6 NaN
631+
632+
# correlation_query_with_nans_f32
633+
query RR
634+
with data as (
635+
select 'nan'::float as f, 'nan'::float as b
636+
)
637+
select corr(f, b), corr('nan'::float, 'nan'::float)
638+
from data
639+
----
640+
NaN NaN
641+
642+
# group correlation_query_with_nans_f64
643+
query IR
644+
select id, corr(f, b)
645+
from values
646+
(1, 1, 'nan'::double),
647+
(2, 'nan'::double, 1),
648+
(3, 'nan'::double, null),
649+
(4, null, 'nan'::float),
650+
(5, 'nan'::double, 'nan'::double),
651+
(5, 1, 1),
652+
(5, 2, 2),
653+
(6, 'nan'::double, 'nan'::double) t(id, f, b)
654+
group by id
655+
order by id
656+
----
657+
1 NULL
658+
2 NULL
659+
3 NULL
660+
4 NULL
661+
5 NaN
662+
6 NaN
663+
664+
# correlation_query_with_nans_f64
665+
query RR
666+
with data as (
667+
select 'nan'::double as f, 'nan'::double as b
668+
)
669+
select corr(f, b), corr('nan'::double, 'nan'::double)
670+
from data
671+
----
672+
NaN NaN
673+
610674
# csv_query_variance_1
611675
query R
612676
SELECT var_pop(c2) FROM aggregate_test_100

0 commit comments

Comments
 (0)