diff --git a/pgdog/src/backend/pool/cluster.rs b/pgdog/src/backend/pool/cluster.rs index 4b022bbd..9d5413e1 100644 --- a/pgdog/src/backend/pool/cluster.rs +++ b/pgdog/src/backend/pool/cluster.rs @@ -15,7 +15,7 @@ use tracing::{error, info}; use crate::{ backend::{ databases::{databases, User as DatabaseUser}, - replication::{ReplicationConfig, ShardedColumn, ShardedSchemas}, + replication::{ReplicationConfig, ShardedSchemas}, Schema, ShardedTables, }, config::{ @@ -369,11 +369,6 @@ impl Cluster { self.pub_sub_channel_size > 0 } - /// Find sharded column position, if the table and columns match the configuration. - pub fn sharded_column(&self, table: &str, columns: &[&str]) -> Option { - self.sharded_tables.sharded_column(table, columns) - } - /// A cluster is read_only if zero shards have a primary. pub fn read_only(&self) -> bool { for shard in &self.shards { diff --git a/pgdog/src/backend/pool/connection/lazy.rs b/pgdog/src/backend/pool/connection/lazy.rs new file mode 100644 index 00000000..8c156fa3 --- /dev/null +++ b/pgdog/src/backend/pool/connection/lazy.rs @@ -0,0 +1,3 @@ +//! Lazy connection guard. +//! +//! Handles server synchronization and lazy connection creation. diff --git a/pgdog/src/backend/replication/sharded_tables.rs b/pgdog/src/backend/replication/sharded_tables.rs index 3cf54541..03de2b14 100644 --- a/pgdog/src/backend/replication/sharded_tables.rs +++ b/pgdog/src/backend/replication/sharded_tables.rs @@ -3,7 +3,7 @@ use pgdog_config::OmnishardedTable; use crate::{ config::{DataType, ShardedTable}, - frontend::router::sharding::Mapping, + frontend::router::{parser::Column, sharding::Mapping}, net::messages::Vector, }; use std::{ @@ -103,6 +103,39 @@ impl ShardedTables { .find(|t| t.name.as_deref() == Some(name)) } + /// Determine if the column is sharded and return its data type, + /// as declared in the schema. + pub fn get_table(&self, column: Column<'_>) -> Option<&ShardedTable> { + // Only fully-qualified columns can be matched. 
+ let table = if let Some(table) = column.table() { + table + } else { + return None; + }; + + for candidate in &self.inner.tables { + if let Some(table_name) = candidate.name.as_ref() { + if !table.name_match(table_name) { + continue; + } + } + + if let Some(schema_name) = candidate.schema.as_ref() { + if let Some(schema) = table.schema() { + if schema.name != schema_name { + continue; + } + } + } + + if column.name == candidate.column { + return Some(candidate); + } + } + + None + } + /// Find out which column (if any) is sharded in the given table. pub fn sharded_column(&self, table: &str, columns: &[&str]) -> Option { let with_names = self diff --git a/pgdog/src/frontend/router/parser/column.rs b/pgdog/src/frontend/router/parser/column.rs index cc1d5138..60566065 100644 --- a/pgdog/src/frontend/router/parser/column.rs +++ b/pgdog/src/frontend/router/parser/column.rs @@ -6,7 +6,7 @@ use pg_query::{ }; use std::fmt::{Display, Formatter, Result as FmtResult}; -use super::Table; +use super::{Error, Table}; use crate::util::escape_identifier; /// Column name extracted from a query. @@ -44,9 +44,7 @@ impl<'a> Column<'a> { pub fn to_owned(&self) -> OwnedColumn { OwnedColumn::from(*self) } -} -impl<'a> Column<'a> { pub fn from_string(string: &'a Node) -> Result { match &string.node { Some(NodeEnum::String(protobuf::String { sval })) => Ok(Self { @@ -57,6 +55,14 @@ impl<'a> Column<'a> { _ => Err(()), } } + + /// Fully-qualify this column with a table. 
+ pub fn qualify(&mut self, table: Table<'a>) { + if self.table.is_none() { + self.table = Some(table.name); + self.schema = table.schema; + } + } } impl<'a> Display for Column<'a> { @@ -114,7 +120,7 @@ impl<'a> From<&'a OwnedColumn> for Column<'a> { } impl<'a> TryFrom<&'a Node> for Column<'a> { - type Error = (); + type Error = Error; fn try_from(value: &'a Node) -> Result { Column::try_from(&value.node) @@ -122,7 +128,7 @@ impl<'a> TryFrom<&'a Node> for Column<'a> { } impl<'a> TryFrom<&'a Option> for Column<'a> { - type Error = (); + type Error = Error; fn try_from(value: &'a Option) -> Result { fn from_node(node: &Node) -> Option<&str> { @@ -133,12 +139,15 @@ impl<'a> TryFrom<&'a Option> for Column<'a> { } } - fn from_slice<'a>(nodes: &'a [Node]) -> Result, ()> { + fn from_slice<'a>(nodes: &'a [Node]) -> Result, Error> { match nodes.len() { 3 => { let schema = nodes.first().and_then(from_node); let table = nodes.get(1).and_then(from_node); - let name = nodes.get(2).and_then(from_node).ok_or(())?; + let name = nodes + .get(2) + .and_then(from_node) + .ok_or(Error::ColumnDecode)?; Ok(Column { schema, @@ -149,7 +158,10 @@ impl<'a> TryFrom<&'a Option> for Column<'a> { 2 => { let table = nodes.first().and_then(from_node); - let name = nodes.get(1).and_then(from_node).ok_or(())?; + let name = nodes + .get(1) + .and_then(from_node) + .ok_or(Error::ColumnDecode)?; Ok(Column { schema: None, @@ -159,7 +171,10 @@ impl<'a> TryFrom<&'a Option> for Column<'a> { } 1 => { - let name = nodes.first().and_then(from_node).ok_or(())?; + let name = nodes + .first() + .and_then(from_node) + .ok_or(Error::ColumnDecode)?; Ok(Column { name, @@ -167,7 +182,7 @@ impl<'a> TryFrom<&'a Option> for Column<'a> { }) } - _ => Err(()), + _ => Err(Error::ColumnDecode), } } @@ -186,26 +201,26 @@ impl<'a> TryFrom<&'a Option> for Column<'a> { if let Some(ref node) = list.arg { Ok(Column::try_from(&node.node)?) 
} else { - Err(()) + Err(Error::ColumnDecode) } } else { - Err(()) + Err(Error::ColumnDecode) } } - _ => Err(()), + _ => Err(Error::ColumnDecode), } } } impl<'a> TryFrom<&Option<&'a Node>> for Column<'a> { - type Error = (); + type Error = Error; fn try_from(value: &Option<&'a Node>) -> Result { if let Some(value) = value { (*value).try_into() } else { - Err(()) + Err(Error::ColumnDecode) } } } @@ -224,7 +239,7 @@ impl<'a> From<&'a str> for Column<'a> { mod test { use pg_query::{parse, NodeEnum}; - use super::Column; + use super::{Column, Error}; #[test] fn test_column() { @@ -236,7 +251,7 @@ mod test { .cols .iter() .map(Column::try_from) - .collect::, ()>>() + .collect::, Error>>() .unwrap(); assert_eq!( columns, diff --git a/pgdog/src/frontend/router/parser/error.rs b/pgdog/src/frontend/router/parser/error.rs index 2d360f83..caee9223 100644 --- a/pgdog/src/frontend/router/parser/error.rs +++ b/pgdog/src/frontend/router/parser/error.rs @@ -99,4 +99,16 @@ pub enum Error { #[error("prepared statement \"{0}\" doesn't exist")] PreparedStatementDoesntExist(String), + + #[error("column decode error")] + ColumnDecode, + + #[error("table decode error")] + TableDecode, + + #[error("parameter ${0} not in bind")] + BindParameterMissing(i32), + + #[error("statement is not a SELECT")] + NotASelect, } diff --git a/pgdog/src/frontend/router/parser/insert.rs b/pgdog/src/frontend/router/parser/insert.rs index be04e0fb..bba332aa 100644 --- a/pgdog/src/frontend/router/parser/insert.rs +++ b/pgdog/src/frontend/router/parser/insert.rs @@ -52,7 +52,7 @@ impl<'a> Insert<'a> { .cols .iter() .map(Column::try_from) - .collect::>, ()>>() + .collect::>, Error>>() .ok() .unwrap_or(vec![]) } diff --git a/pgdog/src/frontend/router/parser/mod.rs b/pgdog/src/frontend/router/parser/mod.rs index aa8402e4..3c3ccad8 100644 --- a/pgdog/src/frontend/router/parser/mod.rs +++ b/pgdog/src/frontend/router/parser/mod.rs @@ -28,6 +28,7 @@ pub mod rewrite_plan; pub mod route; pub mod schema; pub mod 
sequence; +pub mod statement; pub mod table; pub mod tuple; pub mod value; @@ -59,6 +60,7 @@ pub use rewrite_plan::{HelperKind, HelperMapping, QueryRewriter, RewriteOutput, pub use route::{Route, Shard}; pub use schema::Schema; pub use sequence::{OwnedSequence, Sequence}; +pub use statement::StatementParser; pub use table::{OwnedTable, Table}; pub use tuple::Tuple; pub use value::Value; diff --git a/pgdog/src/frontend/router/parser/query/delete.rs b/pgdog/src/frontend/router/parser/query/delete.rs index 33a4fab4..387e74ab 100644 --- a/pgdog/src/frontend/router/parser/query/delete.rs +++ b/pgdog/src/frontend/router/parser/query/delete.rs @@ -1,6 +1,4 @@ -use crate::frontend::router::parser::where_clause::TablesSource; - -use super::shared::ConvergeAlgorithm; +use super::StatementParser; use super::*; impl QueryParser { @@ -9,47 +7,32 @@ impl QueryParser { stmt: &DeleteStmt, context: &QueryParserContext, ) -> Result { - let table = stmt.relation.as_ref().map(Table::from); - - if let Some(table) = table { - // Schema-based sharding. 
- if let Some(schema) = context.sharding_schema.schemas.get(table.schema()) { - let shard: Shard = schema.shard().into(); - + let shard = StatementParser::from_delete( + stmt, + context.router_context.bind, + &context.sharding_schema, + self.recorder_mut(), + ) + .shard()?; + + let shard = match shard { + Some(shard) => { if let Some(recorder) = self.recorder_mut() { recorder.record_entry( Some(shard.clone()), - format!("DELETE matched schema {}", schema.name()), + "DELETE matched WHERE clause for sharding key", ); } - - return Ok(Command::Query(Route::write(shard))); + shard } - - let source = TablesSource::from(table); - let where_clause = WhereClause::new(&source, &stmt.where_clause); - - if let Some(where_clause) = where_clause { - let shards = Self::where_clause( - &context.sharding_schema, - &where_clause, - context.router_context.bind, - &mut self.explain_recorder, - )?; - let shard = Self::converge(&shards, ConvergeAlgorithm::default()); + None => { if let Some(recorder) = self.recorder_mut() { - recorder.record_entry( - Some(shard.clone()), - "DELETE matched WHERE clause for sharding key", - ); + recorder.record_entry(None, "DELETE fell back to broadcast"); } - return Ok(Command::Query(Route::write(shard))); + Shard::default() } - } + }; - if let Some(recorder) = self.recorder_mut() { - recorder.record_entry(None, "DELETE fell back to broadcast"); - } - Ok(Command::Query(Route::write(None))) + Ok(Command::Query(Route::write(shard))) } } diff --git a/pgdog/src/frontend/router/parser/query/select.rs b/pgdog/src/frontend/router/parser/query/select.rs index bf2fdd0a..d4b27818 100644 --- a/pgdog/src/frontend/router/parser/query/select.rs +++ b/pgdog/src/frontend/router/parser/query/select.rs @@ -39,27 +39,29 @@ impl QueryParser { )); } + let mut shards = HashSet::new(); + + let shard = StatementParser::from_select( + stmt, + context.router_context.bind, + &context.sharding_schema, + self.recorder_mut(), + ) + .shard()?; + if let Some(shard) = shard { + 
shards.insert(shard); + } + // `SELECT NOW()`, `SELECT 1`, etc. - if stmt.from_clause.is_empty() { + if shards.is_empty() && stmt.from_clause.is_empty() { return Ok(Command::Query( Route::read(Some(round_robin::next() % context.shards)).set_write(writes), )); } let order_by = Self::select_sort(&stmt.sort_clause, context.router_context.bind); - let mut shards = HashSet::new(); let from_clause = TablesSource::from(FromClause::new(&stmt.from_clause)); - let where_clause = WhereClause::new(&from_clause, &stmt.where_clause); - - if let Some(ref where_clause) = where_clause { - shards = Self::where_clause( - &context.sharding_schema, - where_clause, - context.router_context.bind, - &mut self.explain_recorder, - )?; - } // Schema-based sharding. let mut schema_sharder = SchemaSharder::default(); diff --git a/pgdog/src/frontend/router/parser/query/test.rs b/pgdog/src/frontend/router/parser/query/test.rs index 3a009834..05d99aa9 100644 --- a/pgdog/src/frontend/router/parser/query/test.rs +++ b/pgdog/src/frontend/router/parser/query/test.rs @@ -968,3 +968,38 @@ fn test_set_comments() { ); assert!(matches!(command, Command::Set { .. 
})); } + +#[test] +fn test_subqueries() { + println!( + "{:#?}", + pg_query::parse(r#" + SELECT + count(*) AS "count" + FROM + ( + SELECT "companies".* FROM + ( + SELECT "companies".* + FROM "companies" + INNER JOIN "organizations_relevant_companies" ON + ("organizations_relevant_companies"."company_id" = "companies"."id") + WHERE + ( + ("organizations_relevant_companies"."org_id" = 1) + AND NOT + ( + EXISTS + ( + SELECT * FROM "hidden_globals" + WHERE + ( + ("hidden_globals"."org_id" = 1) + AND ("hidden_globals"."global_company_id" = "organizations_relevant_companies"."company_id") + ) + ) + ) + ) + ) AS "companies" OFFSET 0) AS "t1" LIMIT 1"#).unwrap() + ); +} diff --git a/pgdog/src/frontend/router/parser/statement.rs b/pgdog/src/frontend/router/parser/statement.rs new file mode 100644 index 00000000..3df1f3a0 --- /dev/null +++ b/pgdog/src/frontend/router/parser/statement.rs @@ -0,0 +1,1591 @@ +use std::collections::{HashMap, HashSet}; + +use pg_query::{ + protobuf::{ + AExprKind, BoolExprType, DeleteStmt, InsertStmt, RangeVar, RawStmt, SelectStmt, UpdateStmt, + }, + Node, NodeEnum, +}; + +use super::{ + super::sharding::Value as ShardingValue, explain_trace::ExplainRecorder, Column, Error, Table, + Value, +}; +use crate::{ + backend::ShardingSchema, + frontend::router::{parser::Shard, sharding::ContextBuilder}, + net::Bind, +}; + +/// Context for searching a SELECT statement, tracking table aliases. +#[derive(Debug, Default, Clone)] +struct SearchContext<'a> { + /// Maps alias -> full Table (including schema) + aliases: HashMap<&'a str, Table<'a>>, + /// The primary table from the FROM clause (if simple) + table: Option>, +} + +impl<'a> SearchContext<'a> { + /// Build context from a FROM clause, extracting table aliases. 
+ fn from_from_clause(nodes: &'a [Node]) -> Self { + let mut ctx = Self::default(); + ctx.extract_aliases(nodes); + + // Try to get the primary table for simple queries + if nodes.len() == 1 { + if let Some(table) = nodes.first().and_then(|n| Table::try_from(n).ok()) { + ctx.table = Some(table); + } + } + + ctx + } + + fn extract_aliases(&mut self, nodes: &'a [Node]) { + for node in nodes { + self.extract_alias_from_node(node); + } + } + + fn extract_alias_from_node(&mut self, node: &'a Node) { + match &node.node { + Some(NodeEnum::RangeVar(range_var)) => { + if let Some(ref alias) = range_var.alias { + let table = Table::from(range_var); + self.aliases.insert(alias.aliasname.as_str(), table); + } + } + Some(NodeEnum::JoinExpr(join)) => { + if let Some(ref larg) = join.larg { + self.extract_alias_from_node(larg); + } + if let Some(ref rarg) = join.rarg { + self.extract_alias_from_node(rarg); + } + } + Some(NodeEnum::RangeSubselect(subselect)) => { + if let Some(ref alias) = subselect.alias { + // For subselects, we don't have a real table name + // but we record the alias anyway for future use + self.aliases.insert( + alias.aliasname.as_str(), + Table { + name: alias.aliasname.as_str(), + schema: None, + alias: None, + }, + ); + } + } + _ => {} + } + } + + /// Resolve a table reference (which may be an alias) to the actual Table. 
+ fn resolve_table(&self, name: &str) -> Option> { + self.aliases.get(name).copied() + } +} + +#[derive(Debug)] +enum SearchResult<'a> { + Column(Column<'a>), + Value(Value<'a>), + Values(Vec>), + Match(Shard), + Matches(Vec), + None, +} + +struct ValueIterator<'a, 'b> { + source: &'b SearchResult<'a>, + pos: usize, +} + +impl<'a, 'b> Iterator for ValueIterator<'a, 'b> { + type Item = &'b Value<'a>; + + fn next(&mut self) -> Option { + let next = match self.source { + SearchResult::Value(ref val) => { + if self.pos == 0 { + Some(val) + } else { + None + } + } + SearchResult::Values(values) => values.get(self.pos), + _ => None, + }; + + self.pos += 1; + + next + } +} + +impl<'a> SearchResult<'a> { + fn is_none(&self) -> bool { + matches!(self, Self::None) + } + + fn is_match(&self) -> bool { + matches!(self, Self::Match(_) | Self::Matches(_)) + } + + fn merge(self, other: Self) -> Self { + match (self, other) { + (Self::Match(first), Self::Match(second)) => Self::Matches(vec![first, second]), + (Self::Match(shard), Self::Matches(mut shards)) + | (Self::Matches(mut shards), Self::Match(shard)) => Self::Matches({ + shards.push(shard); + shards + }), + (Self::None, other) | (other, Self::None) => other, + _ => Self::None, + } + } + + fn iter<'b>(&'b self) -> ValueIterator<'a, 'b> { + ValueIterator { + source: self, + pos: 0, + } + } +} + +enum Statement<'a> { + Select(&'a SelectStmt), + Update(&'a UpdateStmt), + Delete(&'a DeleteStmt), + Insert(&'a InsertStmt), +} + +pub struct StatementParser<'a, 'b, 'c> { + stmt: Statement<'a>, + bind: Option<&'b Bind>, + schema: &'b ShardingSchema, + recorder: Option<&'c mut ExplainRecorder>, +} + +impl<'a, 'b, 'c> StatementParser<'a, 'b, 'c> { + fn new( + stmt: Statement<'a>, + bind: Option<&'b Bind>, + schema: &'b ShardingSchema, + recorder: Option<&'c mut ExplainRecorder>, + ) -> Self { + Self { + stmt, + bind, + schema, + recorder, + } + } + + pub fn from_select( + stmt: &'a SelectStmt, + bind: Option<&'b Bind>, + schema: &'b 
ShardingSchema, + recorder: Option<&'c mut ExplainRecorder>, + ) -> Self { + Self::new(Statement::Select(stmt), bind, schema, recorder) + } + + pub fn from_update( + stmt: &'a UpdateStmt, + bind: Option<&'b Bind>, + schema: &'b ShardingSchema, + recorder: Option<&'c mut ExplainRecorder>, + ) -> Self { + Self::new(Statement::Update(stmt), bind, schema, recorder) + } + + pub fn from_delete( + stmt: &'a DeleteStmt, + bind: Option<&'b Bind>, + schema: &'b ShardingSchema, + recorder: Option<&'c mut ExplainRecorder>, + ) -> Self { + Self::new(Statement::Delete(stmt), bind, schema, recorder) + } + + pub fn from_insert( + stmt: &'a InsertStmt, + bind: Option<&'b Bind>, + schema: &'b ShardingSchema, + recorder: Option<&'c mut ExplainRecorder>, + ) -> Self { + Self::new(Statement::Insert(stmt), bind, schema, recorder) + } + + /// Record a sharding key match. + fn record_sharding_key(&mut self, shard: &Shard, column: Column<'_>, value: &Value<'_>) { + if let Some(recorder) = self.recorder.as_mut() { + let col_str = if let Some(table) = column.table { + format!("{}.{}", table, column.name) + } else { + column.name.to_string() + }; + let description = match value { + Value::Placeholder(pos) => { + format!("matched sharding key {} using parameter ${}", col_str, pos) + } + _ => format!("matched sharding key {} using constant", col_str), + }; + recorder.record_entry(Some(shard.clone()), description); + } + } + + pub fn from_raw( + raw: &'a RawStmt, + bind: Option<&'b Bind>, + schema: &'b ShardingSchema, + recorder: Option<&'c mut ExplainRecorder>, + ) -> Result { + match raw.stmt.as_ref().and_then(|n| n.node.as_ref()) { + Some(NodeEnum::SelectStmt(stmt)) => Ok(Self::from_select(stmt, bind, schema, recorder)), + Some(NodeEnum::UpdateStmt(stmt)) => Ok(Self::from_update(stmt, bind, schema, recorder)), + Some(NodeEnum::DeleteStmt(stmt)) => Ok(Self::from_delete(stmt, bind, schema, recorder)), + Some(NodeEnum::InsertStmt(stmt)) => Ok(Self::from_insert(stmt, bind, schema, recorder)), + _ 
=> Err(Error::NotASelect), + } + } + + pub fn shard(&mut self) -> Result, Error> { + match self.stmt { + Statement::Select(stmt) => self.shard_select(stmt), + Statement::Update(stmt) => self.shard_update(stmt), + Statement::Delete(stmt) => self.shard_delete(stmt), + Statement::Insert(stmt) => self.shard_insert(stmt), + } + } + + fn shard_select(&mut self, stmt: &'a SelectStmt) -> Result, Error> { + let ctx = SearchContext::from_from_clause(&stmt.from_clause); + let result = self.search_select_stmt(stmt, &ctx)?; + + match result { + SearchResult::Match(shard) => Ok(Some(shard)), + SearchResult::Matches(shards) => Ok(Self::converge(&shards)), + _ => Ok(None), + } + } + + fn shard_update(&mut self, stmt: &'a UpdateStmt) -> Result, Error> { + let ctx = self.context_from_relation(&stmt.relation); + let result = self.search_update_stmt(stmt, &ctx)?; + + match result { + SearchResult::Match(shard) => Ok(Some(shard)), + SearchResult::Matches(shards) => Ok(Self::converge(&shards)), + _ => Ok(None), + } + } + + fn shard_delete(&mut self, stmt: &'a DeleteStmt) -> Result, Error> { + let ctx = self.context_from_relation(&stmt.relation); + let result = self.search_delete_stmt(stmt, &ctx)?; + + match result { + SearchResult::Match(shard) => Ok(Some(shard)), + SearchResult::Matches(shards) => Ok(Self::converge(&shards)), + _ => Ok(None), + } + } + + fn shard_insert(&mut self, stmt: &'a InsertStmt) -> Result, Error> { + let ctx = self.context_from_relation(&stmt.relation); + let result = self.search_insert_stmt(stmt, &ctx)?; + + match result { + SearchResult::Match(shard) => Ok(Some(shard)), + SearchResult::Matches(shards) => Ok(Self::converge(&shards)), + _ => Ok(None), + } + } + + fn context_from_relation(&self, relation: &'a Option) -> SearchContext<'a> { + let mut ctx = SearchContext::default(); + if let Some(ref range_var) = relation { + let table = Table::from(range_var); + ctx.table = Some(table); + if let Some(ref alias) = range_var.alias { + 
ctx.aliases.insert(alias.aliasname.as_str(), table); + } + } + ctx + } + + fn converge(shards: &[Shard]) -> Option { + let shards: HashSet = shards.into_iter().cloned().collect(); + match shards.len() { + 0 => None, + 1 => shards.into_iter().next(), + _ => { + let mut multi = vec![]; + for shard in shards.into_iter() { + match shard { + Shard::All => return Some(Shard::All), + Shard::Direct(direct) => multi.push(direct), + Shard::Multi(many) => multi.extend(many), + } + } + Some(Shard::Multi(multi)) + } + } + } + + fn compute_shard( + &mut self, + column: Column<'a>, + value: Value<'a>, + ) -> Result, Error> { + if let Some(table) = self.schema.tables().get_table(column) { + let context = ContextBuilder::new(table); + let shard = match value { + Value::Placeholder(pos) => { + let param = self + .bind + .map(|bind| bind.parameter(pos as usize - 1)) + .transpose()? + .flatten(); + // Expect params to be accurate. + let param = if let Some(param) = param { + param + } else { + return Ok(None); + }; + let value = ShardingValue::from_param(¶m, table.data_type)?; + Some( + context + .value(value) + .shards(self.schema.shards) + .build()? + .apply()?, + ) + } + + Value::String(val) => Some( + context + .data(val) + .shards(self.schema.shards) + .build()? + .apply()?, + ), + + Value::Integer(val) => Some( + context + .data(val) + .shards(self.schema.shards) + .build()? 
+ .apply()?, + ), + Value::Null => None, + _ => None, + }; + + Ok(shard) + } else { + Ok(None) + } + } + + fn select_search( + &mut self, + node: &'a Node, + ctx: &SearchContext<'a>, + ) -> Result, Error> { + match node.node { + // Value types - these are leaf nodes representing actual values + Some(NodeEnum::AConst(_)) + | Some(NodeEnum::ParamRef(_)) + | Some(NodeEnum::FuncCall(_)) => { + if let Ok(value) = Value::try_from(&node.node) { + return Ok(SearchResult::Value(value)); + } + Ok(SearchResult::None) + } + + Some(NodeEnum::TypeCast(ref cast)) => { + if let Some(ref arg) = cast.arg { + return self.select_search(arg, ctx); + } + Ok(SearchResult::None) + } + + Some(NodeEnum::SelectStmt(ref stmt)) => { + // Build context with aliases from the FROM clause + let ctx = SearchContext::from_from_clause(&stmt.from_clause); + self.search_select_stmt(stmt, &ctx) + } + + Some(NodeEnum::RangeSubselect(ref subselect)) => { + if let Some(ref node) = subselect.subquery { + return self.select_search(node, ctx); + } else { + Ok(SearchResult::None) + } + } + + Some(NodeEnum::ColumnRef(_)) => { + let mut column = Column::try_from(&node.node)?; + + // If column has no table, qualify with context table + if column.table().is_none() { + if let Some(ref table) = ctx.table { + column.qualify(*table); + } + } + + Ok(SearchResult::Column(column)) + } + + Some(NodeEnum::AExpr(ref expr)) => { + let kind = expr.kind(); + let mut supported = false; + + if matches!( + kind, + AExprKind::AexprOp | AExprKind::AexprIn | AExprKind::AexprOpAny + ) { + supported = expr + .name + .first() + .map(|node| match node.node { + Some(NodeEnum::String(ref string)) => string.sval.as_str(), + _ => "", + }) + .unwrap_or_default() + == "="; + } + + if !supported { + return Ok(SearchResult::None); + } + + let is_any = matches!(kind, AExprKind::AexprOpAny); + + let mut results = vec![]; + + if let Some(ref left) = expr.lexpr { + results.push(self.select_search(left, ctx)?); + } + + if let Some(ref right) = 
expr.rexpr { + results.push(self.select_search(right, ctx)?); + } + + if results.len() != 2 { + Ok(SearchResult::None) + } else { + let right = results.pop().unwrap(); + let left = results.pop().unwrap(); + + // If either side is already a match (from subquery), return it + if right.is_match() { + return Ok(right); + } + if left.is_match() { + return Ok(left); + } + + match (right, left) { + (SearchResult::Column(column), values) + | (values, SearchResult::Column(column)) => { + // For ANY expressions with sharding columns, we can't reliably + // parse array literals or parameters, so route to all shards. + if is_any { + if matches!(values, SearchResult::Value(_)) { + if self.schema.tables().get_table(column).is_some() { + return Ok(SearchResult::Match(Shard::All)); + } + } + } + + let mut shards = HashSet::new(); + for value in values.iter() { + if let Some(shard) = + self.compute_shard_with_ctx(column, value.clone(), ctx)? + { + shards.insert(shard); + } + } + + match shards.len() { + 0 => Ok(SearchResult::None), + 1 => Ok(SearchResult::Match(shards.into_iter().next().unwrap())), + _ => Ok(SearchResult::Matches(shards.into_iter().collect())), + } + } + _ => Ok(SearchResult::None), + } + } + } + + Some(NodeEnum::List(ref list)) => { + let mut values = vec![]; + + for value in &list.items { + if let Ok(value) = Value::try_from(&value.node) { + values.push(value); + } + } + + Ok(SearchResult::Values(values)) + } + + Some(NodeEnum::WithClause(ref with_clause)) => { + for cte in &with_clause.ctes { + let result = self.select_search(cte, ctx)?; + if !result.is_none() { + return Ok(result); + } + } + + Ok(SearchResult::None) + } + + Some(NodeEnum::JoinExpr(ref join)) => { + let mut results = vec![]; + + if let Some(ref left) = join.larg { + results.push(self.select_search(left, ctx)?); + } + if let Some(ref right) = join.rarg { + results.push(self.select_search(right, ctx)?); + } + + results.retain(|result| result.is_match()); + + let result = results + .into_iter() + 
.fold(SearchResult::None, |acc, x| acc.merge(x)); + + Ok(result) + } + + Some(NodeEnum::BoolExpr(ref expr)) => { + // Only AND expressions can determine a shard. + // OR expressions could route to multiple shards. + if expr.boolop() != BoolExprType::AndExpr { + return Ok(SearchResult::None); + } + + for arg in &expr.args { + let result = self.select_search(arg, ctx)?; + if result.is_match() { + return Ok(result); + } + } + + Ok(SearchResult::None) + } + + Some(NodeEnum::SubLink(ref sublink)) => { + if let Some(ref subselect) = sublink.subselect { + return self.select_search(subselect, ctx); + } + Ok(SearchResult::None) + } + + Some(NodeEnum::CommonTableExpr(ref cte)) => { + if let Some(ref ctequery) = cte.ctequery { + return self.select_search(ctequery, ctx); + } + Ok(SearchResult::None) + } + + _ => Ok(SearchResult::None), + } + } + + /// Search a SELECT statement with its own context. + fn search_select_stmt( + &mut self, + stmt: &'a SelectStmt, + ctx: &SearchContext<'a>, + ) -> Result, Error> { + // Handle UNION/INTERSECT/EXCEPT (set operations) + // These have larg and rarg instead of a regular SELECT structure + if let Some(ref larg) = stmt.larg { + let larg_ctx = SearchContext::from_from_clause(&larg.from_clause); + let result = self.search_select_stmt(larg, &larg_ctx)?; + if !result.is_none() { + return Ok(result); + } + } + if let Some(ref rarg) = stmt.rarg { + let rarg_ctx = SearchContext::from_from_clause(&rarg.from_clause); + let result = self.search_select_stmt(rarg, &rarg_ctx)?; + if !result.is_none() { + return Ok(result); + } + } + + if let Some(ref with_clause) = stmt.with_clause { + for cte in &with_clause.ctes { + let result = self.select_search(cte, ctx)?; + if !result.is_none() { + return Ok(result); + } + } + } + + // Search WHERE clause + if let Some(ref where_clause) = stmt.where_clause { + let result = self.select_search(where_clause, ctx)?; + if !result.is_none() { + return Ok(result); + } + } + + for from_ in &stmt.from_clause { + let 
result = self.select_search(from_, ctx)?; + if !result.is_none() { + return Ok(result); + } + } + + Ok(SearchResult::None) + } + + /// Compute shard with alias resolution from context. + fn compute_shard_with_ctx( + &mut self, + column: Column<'a>, + value: Value<'a>, + ctx: &SearchContext<'a>, + ) -> Result, Error> { + // Resolve table alias if present + let resolved_column = if let Some(table_ref) = column.table() { + if let Some(resolved) = ctx.resolve_table(table_ref.name) { + Column { + name: column.name, + table: Some(resolved.name), + schema: resolved.schema, + } + } else { + column + } + } else { + column + }; + + let shard = self.compute_shard(resolved_column, value.clone())?; + if let Some(ref shard) = shard { + self.record_sharding_key(shard, resolved_column, &value); + } + Ok(shard) + } + + /// Search an UPDATE statement for sharding keys. + fn search_update_stmt( + &mut self, + stmt: &'a UpdateStmt, + ctx: &SearchContext<'a>, + ) -> Result, Error> { + // Handle CTEs (WITH clause) + if let Some(ref with_clause) = stmt.with_clause { + for cte in &with_clause.ctes { + let result = self.select_search(cte, ctx)?; + if !result.is_none() { + return Ok(result); + } + } + } + + // Search WHERE clause + if let Some(ref where_clause) = stmt.where_clause { + let result = self.select_search(where_clause, ctx)?; + if !result.is_none() { + return Ok(result); + } + } + + // Search FROM clause (UPDATE ... FROM ...) + for from_ in &stmt.from_clause { + let result = self.select_search(from_, ctx)?; + if !result.is_none() { + return Ok(result); + } + } + + Ok(SearchResult::None) + } + + /// Search a DELETE statement for sharding keys. 
+ fn search_delete_stmt( + &mut self, + stmt: &'a DeleteStmt, + ctx: &SearchContext<'a>, + ) -> Result, Error> { + // Handle CTEs (WITH clause) + if let Some(ref with_clause) = stmt.with_clause { + for cte in &with_clause.ctes { + let result = self.select_search(cte, ctx)?; + if !result.is_none() { + return Ok(result); + } + } + } + + // Search WHERE clause + if let Some(ref where_clause) = stmt.where_clause { + let result = self.select_search(where_clause, ctx)?; + if !result.is_none() { + return Ok(result); + } + } + + // Search USING clause (DELETE ... USING ...) + for using_ in &stmt.using_clause { + let result = self.select_search(using_, ctx)?; + if !result.is_none() { + return Ok(result); + } + } + + Ok(SearchResult::None) + } + + /// Search an INSERT statement for sharding keys. + fn search_insert_stmt( + &mut self, + stmt: &'a InsertStmt, + ctx: &SearchContext<'a>, + ) -> Result, Error> { + // Handle CTEs (WITH clause) + if let Some(ref with_clause) = stmt.with_clause { + for cte in &with_clause.ctes { + let result = self.select_search(cte, ctx)?; + if !result.is_none() { + return Ok(result); + } + } + } + + // Get the column names from INSERT INTO table (col1, col2, ...) 
+ let columns: Vec<&str> = stmt + .cols + .iter() + .filter_map(|node| match &node.node { + Some(NodeEnum::ResTarget(target)) => Some(target.name.as_str()), + _ => None, + }) + .collect(); + + // The select_stmt field contains either VALUES or a SELECT subquery + if let Some(ref select_node) = stmt.select_stmt { + if let Some(NodeEnum::SelectStmt(ref select_stmt)) = select_node.node { + // Check if this is VALUES (has values_lists) - need special handling + // to match column positions with sharding keys + if !select_stmt.values_lists.is_empty() { + for values_list in &select_stmt.values_lists { + if let Some(NodeEnum::List(ref list)) = values_list.node { + for (pos, value_node) in list.items.iter().enumerate() { + // Check if this position corresponds to a sharding key column + if let Some(column_name) = columns.get(pos) { + let column = Column { + name: column_name, + table: ctx.table.map(|t| t.name), + schema: ctx.table.and_then(|t| t.schema), + }; + + if self.schema.tables().get_table(column).is_some() { + // Try to extract the value directly + if let Ok(value) = Value::try_from(&value_node.node) { + if let Some(shard) = + self.compute_shard_with_ctx(column, value, ctx)? + { + return Ok(SearchResult::Match(shard)); + } + } + } + } + + // Search subqueries in values recursively + let result = self.select_search(value_node, ctx)?; + if result.is_match() { + return Ok(result); + } + } + } + } + } + + // Handle INSERT ... 
#[cfg(test)]
mod test {
    use pgdog_config::{FlexibleType, Mapping, ShardedMapping, ShardedMappingKind, ShardedTable};

    use crate::backend::ShardedTables;
    use crate::net::messages::{Bind, Parameter};

    use super::*;

    /// Parse `stmt`, build a statement parser over a fixed three-shard
    /// schema and return whatever `StatementParser::shard` reports.
    ///
    /// The schema declares four sharded columns: `sharded.id`,
    /// `sharded_id` and `list_id` on any table (the latter with a list
    /// mapping for the values 1 and 2), and `myschema.schema_sharded.tenant_id`.
    fn run_test(stmt: &str, bind: Option<&Bind>) -> Result<Option<Shard>, Error> {
        let named_table = ShardedTable {
            column: "id".into(),
            name: Some("sharded".into()),
            ..Default::default()
        };
        let any_table = ShardedTable {
            column: "sharded_id".into(),
            ..Default::default()
        };
        let list_mapped = ShardedTable {
            column: "list_id".into(),
            mapping: Mapping::new(&[ShardedMapping {
                kind: ShardedMappingKind::List,
                values: vec![FlexibleType::Integer(1), FlexibleType::Integer(2)]
                    .into_iter()
                    .collect(),
                ..Default::default()
            }]),
            ..Default::default()
        };
        // Schema-qualified sharded table with a different column name.
        let schema_qualified = ShardedTable {
            column: "tenant_id".into(),
            name: Some("schema_sharded".into()),
            schema: Some("myschema".into()),
            ..Default::default()
        };

        let schema = ShardingSchema {
            shards: 3,
            tables: ShardedTables::new(
                vec![named_table, any_table, list_mapped, schema_qualified],
                vec![],
            ),
            ..Default::default()
        };

        let parsed = pg_query::parse(stmt).unwrap();
        let raw = parsed.protobuf.stmts.first().cloned().unwrap();
        let mut parser = StatementParser::from_raw(&raw, bind, &schema, None)?;
        parser.shard()
    }

    /// Bind message carrying the single parameter `1`.
    fn bind_one() -> Bind {
        Bind::new_params("", &[Parameter::new(b"1")])
    }

    /// Assert that routing `stmt` yields some shard.
    #[track_caller]
    fn assert_routed(stmt: &str, bind: Option<&Bind>) {
        assert!(run_test(stmt, bind).unwrap().is_some());
    }

    /// Assert that routing `stmt` yields no shard.
    #[track_caller]
    fn assert_unrouted(stmt: &str, bind: Option<&Bind>) {
        assert!(run_test(stmt, bind).unwrap().is_none());
    }

    /// Assert that routing `stmt` yields `Shard::All`.
    #[track_caller]
    fn assert_all_shards(stmt: &str, bind: Option<&Bind>) {
        assert_eq!(run_test(stmt, bind).unwrap(), Some(Shard::All));
    }

    // SELECT with literal values

    #[test]
    fn test_simple_select() {
        assert_routed("SELECT * FROM sharded WHERE id = 1", None);
    }

    #[test]
    fn test_select_with_and() {
        assert_routed("SELECT * FROM sharded WHERE id = 1 AND name = 'foo'", None);
    }

    #[test]
    fn test_select_with_or_returns_none() {
        // An OR expression cannot pin the query to a single shard.
        assert_unrouted("SELECT * FROM sharded WHERE id = 1 OR id = 2", None);
    }

    #[test]
    fn test_select_with_subquery() {
        assert_routed(
            "SELECT * FROM sharded WHERE id IN (SELECT sharded_id FROM other WHERE sharded_id = 1)",
            None,
        );
    }

    #[test]
    fn test_select_with_cte() {
        assert_routed(
            "WITH cte AS (SELECT * FROM sharded WHERE id = 1) SELECT * FROM cte",
            None,
        );
    }

    #[test]
    fn test_select_with_join() {
        assert_routed(
            "SELECT * FROM sharded s JOIN other o ON s.id = o.sharded_id WHERE s.id = 1",
            None,
        );
    }

    #[test]
    fn test_select_with_type_cast() {
        assert_routed("SELECT * FROM sharded WHERE id = '1'::int", None);
    }

    #[test]
    fn test_select_from_subquery() {
        assert_routed(
            "SELECT * FROM (SELECT * FROM sharded WHERE id = 1) AS sub",
            None,
        );
    }

    #[test]
    fn test_select_with_nested_cte() {
        assert_routed(
            "WITH cte1 AS (SELECT * FROM sharded WHERE id = 1), \
             cte2 AS (SELECT * FROM cte1) \
             SELECT * FROM cte2",
            None,
        );
    }

    #[test]
    fn test_select_no_where_returns_none() {
        assert_unrouted("SELECT * FROM sharded", None);
    }

    #[test]
    fn test_select_with_in_list() {
        // An IN list with several values still produces a shard match.
        assert_routed("SELECT * FROM sharded WHERE id IN (1, 2, 3)", None);
    }

    #[test]
    fn test_select_with_not_equals_returns_none() {
        // The != operator is not supported for sharding.
        assert_unrouted("SELECT * FROM sharded WHERE id != 1", None);
    }

    #[test]
    fn test_select_with_greater_than_returns_none() {
        // The > operator is not supported for sharding.
        assert_unrouted("SELECT * FROM sharded WHERE id > 1", None);
    }

    #[test]
    fn test_select_with_complex_and() {
        assert_routed(
            "SELECT * FROM sharded WHERE id = 1 AND status = 'active' AND created_at > now()",
            None,
        );
    }

    #[test]
    fn test_select_with_left_join() {
        assert_routed(
            "SELECT * FROM sharded s LEFT JOIN other o ON s.id = o.sharded_id WHERE s.id = 1",
            None,
        );
    }

    #[test]
    fn test_select_with_multiple_joins() {
        assert_routed(
            "SELECT * FROM sharded s \
             JOIN other o ON s.id = o.sharded_id \
             JOIN third t ON o.id = t.other_id \
             WHERE s.id = 1",
            None,
        );
    }

    #[test]
    fn test_select_with_exists_subquery() {
        // The shard condition lives inside the EXISTS subquery.
        assert_routed(
            "SELECT * FROM sharded WHERE EXISTS (SELECT 1 FROM other WHERE sharded_id = 1)",
            None,
        );
    }

    #[test]
    fn test_select_with_scalar_subquery() {
        // The shard is determined by the scalar subquery's own WHERE
        // clause (`sharded_id = 1`).
        assert_routed(
            "SELECT * FROM sharded WHERE id = (SELECT sharded_id FROM other WHERE sharded_id = 1 LIMIT 1)",
            None,
        );
    }

    #[test]
    fn test_select_with_recursive_cte() {
        // Recursive CTEs contain a UNION; the base case carries `id = 1`.
        assert_routed(
            "WITH RECURSIVE cte AS ( \
                SELECT * FROM sharded WHERE id = 1 \
                UNION ALL \
                SELECT s.* FROM sharded s JOIN cte c ON s.parent_id = c.id \
            ) SELECT * FROM cte",
            None,
        );
    }

    #[test]
    fn test_select_with_union() {
        // A UNION query should find at least one shard.
        assert_routed(
            "SELECT * FROM sharded WHERE id = 1 UNION SELECT * FROM sharded WHERE id = 2",
            None,
        );
    }

    #[test]
    fn test_select_with_nested_subselects() {
        // The innermost subquery carries `sharded_id = 1`.
        assert_routed(
            "SELECT * FROM sharded WHERE id IN ( \
                SELECT * FROM other WHERE x IN ( \
                    SELECT y FROM third WHERE sharded_id = 1 \
                ) \
            )",
            None,
        );
    }

    // SELECT with bound parameters

    #[test]
    fn test_bound_simple_select() {
        let bind = bind_one();
        assert_routed("SELECT * FROM sharded WHERE id = $1", Some(&bind));
    }

    #[test]
    fn test_bound_select_with_and() {
        let bind = bind_one();
        assert_routed(
            "SELECT * FROM sharded WHERE id = $1 AND name = 'foo'",
            Some(&bind),
        );
    }

    #[test]
    fn test_bound_select_with_or_returns_none() {
        let bind = Bind::new_params("", &[Parameter::new(b"1"), Parameter::new(b"2")]);
        assert_unrouted(
            "SELECT * FROM sharded WHERE id = $1 OR id = $2",
            Some(&bind),
        );
    }

    #[test]
    fn test_bound_select_with_subquery() {
        let bind = bind_one();
        assert_routed(
            "SELECT * FROM sharded WHERE id IN (SELECT sharded_id FROM other WHERE sharded_id = $1)",
            Some(&bind),
        );
    }

    #[test]
    fn test_bound_select_with_cte() {
        let bind = bind_one();
        assert_routed(
            "WITH cte AS (SELECT * FROM sharded WHERE id = $1) SELECT * FROM cte",
            Some(&bind),
        );
    }

    #[test]
    fn test_bound_select_with_join() {
        let bind = bind_one();
        assert_routed(
            "SELECT * FROM sharded s JOIN other o ON s.id = o.sharded_id WHERE s.id = $1",
            Some(&bind),
        );
    }

    #[test]
    fn test_bound_select_with_type_cast() {
        let bind = bind_one();
        assert_routed("SELECT * FROM sharded WHERE id = $1::int", Some(&bind));
    }

    #[test]
    fn test_bound_select_from_subquery() {
        let bind = bind_one();
        assert_routed(
            "SELECT * FROM (SELECT * FROM sharded WHERE id = $1) AS sub",
            Some(&bind),
        );
    }

    #[test]
    fn test_bound_select_with_nested_cte() {
        let bind = bind_one();
        assert_routed(
            "WITH cte1 AS (SELECT * FROM sharded WHERE id = $1), \
             cte2 AS (SELECT * FROM cte1) \
             SELECT * FROM cte2",
            Some(&bind),
        );
    }

    #[test]
    fn test_bound_select_with_in_list() {
        let bind = Bind::new_params(
            "",
            &[
                Parameter::new(b"1"),
                Parameter::new(b"2"),
                Parameter::new(b"3"),
            ],
        );
        assert_routed(
            "SELECT * FROM sharded WHERE id IN ($1, $2, $3)",
            Some(&bind),
        );
    }

    #[test]
    fn test_bound_select_with_any_array() {
        // `$1` is a single array value such as '{1,2,3}'. Array parameters
        // are routed to all shards because they cannot be parsed reliably.
        let bind = Bind::new_params("", &[Parameter::new(b"{1,2,3}")]);
        assert_all_shards("SELECT * FROM sharded WHERE id = ANY($1)", Some(&bind));
    }

    #[test]
    fn test_bound_select_with_not_equals_returns_none() {
        let bind = bind_one();
        assert_unrouted("SELECT * FROM sharded WHERE id != $1", Some(&bind));
    }

    #[test]
    fn test_bound_select_with_greater_than_returns_none() {
        let bind = bind_one();
        assert_unrouted("SELECT * FROM sharded WHERE id > $1", Some(&bind));
    }

    #[test]
    fn test_bound_select_with_complex_and() {
        let bind = bind_one();
        assert_routed(
            "SELECT * FROM sharded WHERE id = $1 AND status = 'active' AND created_at > now()",
            Some(&bind),
        );
    }

    #[test]
    fn test_bound_select_with_left_join() {
        let bind = bind_one();
        assert_routed(
            "SELECT * FROM sharded s LEFT JOIN other o ON s.id = o.sharded_id WHERE s.id = $1",
            Some(&bind),
        );
    }

    #[test]
    fn test_bound_select_with_multiple_joins() {
        let bind = bind_one();
        assert_routed(
            "SELECT * FROM sharded s \
             JOIN other o ON s.id = o.sharded_id \
             JOIN third t ON o.id = t.other_id \
             WHERE s.id = $1",
            Some(&bind),
        );
    }

    #[test]
    fn test_bound_select_with_exists_subquery() {
        let bind = bind_one();
        assert_routed(
            "SELECT * FROM sharded WHERE EXISTS (SELECT 1 FROM other WHERE sharded_id = $1)",
            Some(&bind),
        );
    }

    #[test]
    fn test_bound_select_with_scalar_subquery() {
        let bind = bind_one();
        assert_routed(
            "SELECT * FROM sharded WHERE id = (SELECT sharded_id FROM other WHERE sharded_id = $1 LIMIT 1)",
            Some(&bind),
        );
    }

    #[test]
    fn test_bound_select_with_recursive_cte() {
        let bind = bind_one();
        assert_routed(
            "WITH RECURSIVE cte AS ( \
                SELECT * FROM sharded WHERE id = $1 \
                UNION ALL \
                SELECT s.* FROM sharded s JOIN cte c ON s.parent_id = c.id \
            ) SELECT * FROM cte",
            Some(&bind),
        );
    }

    #[test]
    fn test_bound_select_with_union() {
        let bind = Bind::new_params("", &[Parameter::new(b"1"), Parameter::new(b"2")]);
        assert_routed(
            "SELECT * FROM sharded WHERE id = $1 UNION SELECT * FROM sharded WHERE id = $2",
            Some(&bind),
        );
    }

    #[test]
    fn test_bound_select_with_nested_subselects() {
        let bind = bind_one();
        assert_routed(
            "SELECT * FROM sharded WHERE id IN ( \
                SELECT * FROM other WHERE x IN ( \
                    SELECT y FROM third WHERE sharded_id = $1 \
                ) \
            )",
            Some(&bind),
        );
    }

    #[test]
    fn test_bound_select_with_cte_and_subquery() {
        let bind = bind_one();
        assert_routed(
            "WITH cte AS (SELECT * FROM sharded WHERE id = $1) \
             SELECT * FROM cte WHERE id IN (SELECT sharded_id FROM other)",
            Some(&bind),
        );
    }

    #[test]
    fn test_bound_select_with_multiple_ctes_and_subquery() {
        let bind = bind_one();
        assert_routed(
            "WITH cte1 AS (SELECT * FROM sharded WHERE id = $1), \
             cte2 AS (SELECT * FROM other WHERE sharded_id IN (SELECT id FROM cte1)) \
             SELECT * FROM cte2",
            Some(&bind),
        );
    }

    #[test]
    fn test_bound_select_with_cte_subquery_and_join() {
        let bind = bind_one();
        assert_routed(
            "WITH cte AS (SELECT * FROM sharded WHERE id = $1) \
             SELECT c.*, o.* FROM cte c \
             JOIN other o ON c.id = o.sharded_id \
             WHERE o.x IN (SELECT y FROM third)",
            Some(&bind),
        );
    }

    // Schema-qualified tables

    #[test]
    fn test_select_with_schema_qualified_table() {
        assert_routed(
            "SELECT * FROM myschema.schema_sharded WHERE tenant_id = 1",
            None,
        );
    }

    #[test]
    fn test_select_with_schema_qualified_alias() {
        assert_routed(
            "SELECT * FROM myschema.schema_sharded s WHERE s.tenant_id = 1",
            None,
        );
    }

    #[test]
    fn test_bound_select_with_schema_qualified_alias() {
        let bind = bind_one();
        assert_routed(
            "SELECT * FROM myschema.schema_sharded s WHERE s.tenant_id = $1",
            Some(&bind),
        );
    }

    #[test]
    fn test_select_with_schema_qualified_join() {
        assert_routed(
            "SELECT * FROM myschema.schema_sharded s \
             JOIN other o ON s.id = o.sharded_id WHERE s.tenant_id = 1",
            None,
        );
    }

    #[test]
    fn test_select_wrong_schema_returns_none() {
        assert_unrouted(
            "SELECT * FROM otherschema.schema_sharded WHERE tenant_id = 1",
            None,
        );
    }

    #[test]
    fn test_select_wrong_schema_alias_returns_none() {
        assert_unrouted(
            "SELECT * FROM otherschema.schema_sharded s WHERE s.tenant_id = 1",
            None,
        );
    }

    #[test]
    fn test_select_with_any_array_literal() {
        // ANY with an array literal routes to all shards.
        assert_all_shards("SELECT * FROM sharded WHERE id = ANY('{1, 2, 3}')", None);
    }

    // UPDATE statements

    #[test]
    fn test_simple_update() {
        assert_routed("UPDATE sharded SET name = 'foo' WHERE id = 1", None);
    }

    #[test]
    fn test_update_with_bound_param() {
        let bind = bind_one();
        assert_routed("UPDATE sharded SET name = 'foo' WHERE id = $1", Some(&bind));
    }

    #[test]
    fn test_update_with_and() {
        assert_routed(
            "UPDATE sharded SET name = 'foo' WHERE id = 1 AND status = 'active'",
            None,
        );
    }

    #[test]
    fn test_update_no_where_returns_none() {
        assert_unrouted("UPDATE sharded SET name = 'foo'", None);
    }

    #[test]
    fn test_update_with_subquery() {
        assert_routed(
            "UPDATE sharded SET name = 'foo' WHERE id IN (SELECT sharded_id FROM other WHERE sharded_id = 1)",
            None,
        );
    }

    // DELETE statements

    #[test]
    fn test_simple_delete() {
        assert_routed("DELETE FROM sharded WHERE id = 1", None);
    }

    #[test]
    fn test_delete_with_bound_param() {
        let bind = bind_one();
        assert_routed("DELETE FROM sharded WHERE id = $1", Some(&bind));
    }

    #[test]
    fn test_delete_with_and() {
        assert_routed(
            "DELETE FROM sharded WHERE id = 1 AND status = 'active'",
            None,
        );
    }

    #[test]
    fn test_delete_no_where_returns_none() {
        assert_unrouted("DELETE FROM sharded", None);
    }

    #[test]
    fn test_delete_with_subquery() {
        assert_routed(
            "DELETE FROM sharded WHERE id IN (SELECT sharded_id FROM other WHERE sharded_id = 1)",
            None,
        );
    }

    #[test]
    fn test_delete_with_cte() {
        assert_routed(
            "WITH to_delete AS (SELECT id FROM sharded WHERE id = 1) DELETE FROM sharded WHERE id IN (SELECT id FROM to_delete)",
            None,
        );
    }

    // INSERT statements

    #[test]
    fn test_simple_insert_with_value() {
        assert_routed("INSERT INTO sharded (id, name) VALUES (1, 'foo')", None);
    }

    #[test]
    fn test_insert_with_bound_param() {
        let bind = Bind::new_params("", &[Parameter::new(b"1"), Parameter::new(b"foo")]);
        assert_routed(
            "INSERT INTO sharded (id, name) VALUES ($1, $2)",
            Some(&bind),
        );
    }

    #[test]
    fn test_insert_with_subquery_in_values() {
        assert_routed(
            "INSERT INTO sharded (id, name) VALUES ((SELECT sharded_id FROM other WHERE sharded_id = 1), 'foo')",
            None,
        );
    }

    #[test]
    fn test_insert_with_subquery_in_values_param() {
        let bind = bind_one();
        assert_routed(
            "INSERT INTO sharded (id, name) VALUES ((SELECT sharded_id FROM other WHERE sharded_id = $1), 'foo')",
            Some(&bind),
        );
    }

    #[test]
    fn test_insert_select() {
        assert_routed(
            "INSERT INTO sharded (id, name) SELECT sharded_id, name FROM other WHERE sharded_id = 1",
            None,
        );
    }

    #[test]
    fn test_insert_select_with_param() {
        let bind = bind_one();
        assert_routed(
            "INSERT INTO sharded (id, name) SELECT sharded_id, name FROM other WHERE sharded_id = $1",
            Some(&bind),
        );
    }

    #[test]
    fn test_insert_with_cte() {
        assert_routed(
            "WITH src AS (SELECT id, name FROM sharded WHERE id = 1) INSERT INTO sharded (id, name) SELECT id, name FROM src",
            None,
        );
    }

    #[test]
    fn test_insert_no_sharding_key_returns_none() {
        assert_unrouted("INSERT INTO sharded (name) VALUES ('foo')", None);
    }
}
@@ -100,7 +100,7 @@ impl<'a> TryFrom<&'a Node> for Table<'a> { } impl<'a> TryFrom<&'a Vec> for Table<'a> { - type Error = (); + type Error = Error; fn try_from(value: &'a Vec) -> Result { match value.len() { @@ -116,7 +116,7 @@ impl<'a> TryFrom<&'a Vec> for Table<'a> { }) }) .flatten() - .ok_or(())?; + .ok_or(Error::TableDecode)?; return table; } @@ -149,7 +149,7 @@ impl<'a> TryFrom<&'a Vec> for Table<'a> { _ => (), } - Err(()) + Err(Error::TableDecode) } } @@ -173,7 +173,8 @@ impl<'a> From<&'a RangeVar> for Table<'a> { } impl<'a> TryFrom<&'a List> for Table<'a> { - type Error = (); + type Error = Error; + fn try_from(value: &'a List) -> Result { fn str_value(list: &List, pos: usize) -> Option<&str> { if let Some(NodeEnum::String(ref schema)) = list.items.get(pos).unwrap().node { @@ -186,7 +187,7 @@ impl<'a> TryFrom<&'a List> for Table<'a> { match value.items.len() { 2 => { let schema = str_value(value, 0); - let name = str_value(value, 1).ok_or(())?; + let name = str_value(value, 1).ok_or(Error::TableDecode)?; Ok(Table { schema, name, @@ -195,7 +196,7 @@ impl<'a> TryFrom<&'a List> for Table<'a> { } 1 => { - let name = str_value(value, 0).ok_or(())?; + let name = str_value(value, 0).ok_or(Error::TableDecode)?; Ok(Table { schema: None, name, @@ -203,7 +204,7 @@ impl<'a> TryFrom<&'a List> for Table<'a> { }) } - _ => Err(()), + _ => Err(Error::TableDecode), } } } diff --git a/pgdog/src/frontend/router/sharding/context_builder.rs b/pgdog/src/frontend/router/sharding/context_builder.rs index 10572ac7..d52edcec 100644 --- a/pgdog/src/frontend/router/sharding/context_builder.rs +++ b/pgdog/src/frontend/router/sharding/context_builder.rs @@ -1,3 +1,8 @@ +//! Context builder for sharding a value. +//! +//! Manages mapping a value (integer, string, etc.) +//! to a shard number, given a sharded mapping in pgdog.toml. +//! 
use crate::{ backend::ShardingSchema, config::{DataType, Hasher as HasherConfig, ShardedTable}, @@ -6,6 +11,7 @@ use crate::{ use super::{Centroids, Context, Data, Error, Hasher, Lists, Operator, Ranges, Value}; +/// Sharding context builder. #[derive(Debug)] pub struct ContextBuilder<'a> { data_type: DataType, @@ -21,6 +27,8 @@ pub struct ContextBuilder<'a> { } impl<'a> ContextBuilder<'a> { + /// Create new context builder from a sharded table + /// in the config. pub fn new(table: &'a ShardedTable) -> Self { Self { data_type: table.data_type, @@ -42,7 +50,8 @@ impl<'a> ContextBuilder<'a> { } } - /// Infer sharding function from config. + /// Infer sharding function from config, iff + /// only one sharding function is configured. pub fn infer_from_from_and_config( value: &'a str, sharding_schema: &'a ShardingSchema, @@ -141,6 +150,7 @@ impl<'a> ContextBuilder<'a> { } } + /// Set the number of shards in the configuration. pub fn shards(mut self, shards: usize) -> Self { if let Some(centroids) = self.centroids.take() { self.operator = Some(Operator::Centroids { diff --git a/pgdog/src/frontend/router/sharding/value.rs b/pgdog/src/frontend/router/sharding/value.rs index a59bc0a1..9b7c292c 100644 --- a/pgdog/src/frontend/router/sharding/value.rs +++ b/pgdog/src/frontend/router/sharding/value.rs @@ -54,6 +54,8 @@ impl<'a> Value<'a> { } } + /// Convert parameter to value, given the data type + /// and known encoding. pub fn from_param( param: &'a ParameterWithFormat<'a>, data_type: DataType,