diff --git a/src/executor.rs b/src/executor.rs index c4ab490..fb1abda 100644 --- a/src/executor.rs +++ b/src/executor.rs @@ -9,7 +9,7 @@ use crate::{ }; use self::{ - aggregate_executor::{AggregateExecutor, AggregateTable}, + aggregate_executor::{AggregateExecutor, AggregateTable, AggregateTableValue}, delete_executor::DeleteExecutor, filter_executor::FilterExecutor, insert_executor::InsertExecutor, @@ -107,7 +107,11 @@ impl ExecutorEngine { plan: plan.clone(), child: Box::new(self.create_executor(&plan.child)), executor_context: &self.context, - aggregate_table: AggregateTable::new(), + aggregate_table_value: if plan.group_by.len() == 0 { + AggregateTableValue::Value(vec![vec![]; plan.aggregate_functions.len()]) + } else { + AggregateTableValue::Table(AggregateTable::new()) + }, result: vec![], index: 0, }) diff --git a/src/executor/aggregate_executor.rs b/src/executor/aggregate_executor.rs index 1b37125..a8e487a 100644 --- a/src/executor/aggregate_executor.rs +++ b/src/executor/aggregate_executor.rs @@ -15,7 +15,7 @@ pub struct AggregateExecutor<'a> { pub plan: AggregatePlan, pub child: Box>, pub executor_context: &'a ExecutorContext, - pub aggregate_table: AggregateTable, + pub aggregate_table_value: AggregateTableValue, pub result: Vec>, pub index: usize, } @@ -37,15 +37,17 @@ impl AggregateExecutor<'_> { .ok_or(anyhow::anyhow!("SUM argument error"))? .eval(&vec![&tuple], &vec![&self.plan.child.schema()])?, }; - self.aggregate_table.add( - keys.clone(), - value, - i, - self.plan.aggregate_functions.len(), - ); + match &mut self.aggregate_table_value { + AggregateTableValue::Table(table) => { + table.add(keys.clone(), value, i, self.plan.aggregate_functions.len()); + } + AggregateTableValue::Value(values_list) => { + values_list[i].push(value); + } + } } } - self.result = self.aggregate_table.aggregate( + self.result = self.aggregate_table_value.aggregate( self.plan .aggregate_functions .clone() @@ -68,7 +70,7 @@ impl AggregateExecutor<'_> { pub struct AggregateTable { map: Box>, } -enum AggregateTableValue { +pub enum AggregateTableValue { Table(AggregateTable), Value(Vec>), } @@ -100,110 +102,111 @@ impl AggregateTable { } } } +} +impl AggregateTableValue { fn aggregate(&self, function_names: Vec) -> Result>> { - let mut result = vec![]; - for (key, value) in self.map.iter() { - match value { - AggregateTableValue::Table(table) => { - let mut rows = table.aggregate(function_names.clone())?; + match self { + AggregateTableValue::Table(table) => { + let mut result = vec![]; + for (key, value) in table.map.iter() { + let mut rows = value.aggregate(function_names.clone())?; for row in &mut rows { row.insert(0, key.clone()); } result.append(&mut rows); } - AggregateTableValue::Value(values_list) => { - let mut row = vec![]; - row.push(key.clone()); - for (i, values) in values_list.iter().enumerate() { - match &*function_names[i] { - "COUNT" => { - let mut sum = 0; - for value in values { - if value.is_null_value() { - continue; - } - sum += 1; + Ok(result) + } + AggregateTableValue::Value(values_list) => { + let mut row = vec![]; + for (i, values) in values_list.iter().enumerate() { + match &*function_names[i] { + "COUNT" => { + let mut sum = 0; + for value in values { + if value.is_null_value() { + continue; } - row.push(Value::Integer(IntegerValue(sum))); + sum += 1; } - "SUM" => { - let mut sum = 0; - for value in values { - if value.is_null_value() { - continue; - } - match value.convert_to(&DataType::Integer)? { - Value::Integer(v) => { - sum += v.0; - } - _ => unimplemented!(), + row.push(Value::Integer(IntegerValue(sum))); + } + "SUM" => { + let mut sum = 0; + for value in values { + if value.is_null_value() { + continue; + } + match value.convert_to(&DataType::Integer)? { + Value::Integer(v) => { + sum += v.0; } + _ => unimplemented!(), } - row.push(Value::Integer(IntegerValue(sum))); } - "MAX" => { - let mut max = Value::Integer(IntegerValue(i64::MIN)); - for value in values { - if value.is_null_value() { - continue; - } - match value.convert_to(&DataType::Integer)? { - Value::Integer(v) => { - if value.perform_greater_than(&max)? - == Value::Boolean(BooleanValue(true)) - { - max = Value::Integer(v); - } + row.push(Value::Integer(IntegerValue(sum))); + } + "MAX" => { + let mut max = Value::Integer(IntegerValue(i64::MIN)); + for value in values { + if value.is_null_value() { + continue; + } + match value.convert_to(&DataType::Integer)? { + Value::Integer(v) => { + if value.perform_greater_than(&max)? + == Value::Boolean(BooleanValue(true)) + { + max = Value::Integer(v); } - _ => unimplemented!(), } + _ => unimplemented!(), } - row.push(max); } - "MIN" => { - let mut min = Value::Integer(IntegerValue(i64::MAX)); - for value in values { - if value.is_null_value() { - continue; - } - match value.convert_to(&DataType::Integer)? { - Value::Integer(v) => { - if value.perform_less_than(&min)? - == Value::Boolean(BooleanValue(true)) - { - min = Value::Integer(v); - } + row.push(max); + } + "MIN" => { + let mut min = Value::Integer(IntegerValue(i64::MAX)); + for value in values { + if value.is_null_value() { + continue; + } + match value.convert_to(&DataType::Integer)? { + Value::Integer(v) => { + if value.perform_less_than(&min)? + == Value::Boolean(BooleanValue(true)) + { + min = Value::Integer(v); } - _ => unimplemented!(), } + _ => unimplemented!(), } - row.push(min); } - "AVG" => { - let mut sum = 0; - let mut count = 0; - for value in values { - if value.is_null_value() { - continue; - } - match value.convert_to(&DataType::Integer)? { - Value::Integer(v) => { - sum += v.0; - count += 1; - } - _ => unimplemented!(), + row.push(min); + } + "AVG" => { + let mut sum = 0; + let mut count = 0; + for value in values { + if value.is_null_value() { + continue; + } + match value.convert_to(&DataType::Integer)? { + Value::Integer(v) => { + sum += v.0; + count += 1; } + _ => unimplemented!(), } - row.push(Value::Integer(IntegerValue(sum / count))); } - _ => Err(anyhow::anyhow!("unknown aggregate function error"))?, + row.push(Value::Integer(IntegerValue(sum / count))); } + _ => Err(anyhow::anyhow!("unknown aggregate function error"))?, } - result.push(row); } + Ok(vec![row]) } } - Ok(result) } }