diff --git a/Cargo.toml b/Cargo.toml index 8573e33f..3d7c4b1c 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -64,7 +64,7 @@ clap = { version = "4.5", features = ["derive"], optional = tru env_logger = { version = "0.11", optional = true } futures = { version = "0.3", optional = true } log = { version = "0.4", optional = true } -pgwire = { version = "0.19", optional = true } +pgwire = { version = "0.28.0", optional = true } tokio = { version = "1.36", features = ["full"], optional = true } diff --git a/src/bin/server.rs b/src/bin/server.rs index 6cbf773e..dbf0d28a 100644 --- a/src/bin/server.rs +++ b/src/bin/server.rs @@ -3,19 +3,16 @@ use clap::Parser; use fnck_sql::db::{DBTransaction, DataBaseBuilder, Database, ResultIter}; use fnck_sql::errors::DatabaseError; use fnck_sql::storage::rocksdb::RocksStorage; -use fnck_sql::types::tuple::{Schema, Tuple}; +use fnck_sql::types::tuple::{Schema, SchemaRef, Tuple}; use fnck_sql::types::LogicalType; use futures::stream; use log::{error, info, LevelFilter}; use parking_lot::Mutex; use pgwire::api::auth::noop::NoopStartupHandler; -use pgwire::api::auth::StartupHandler; -use pgwire::api::query::{ - ExtendedQueryHandler, PlaceholderExtendedQueryHandler, SimpleQueryHandler, -}; +use pgwire::api::copy::NoopCopyHandler; +use pgwire::api::query::{PlaceholderExtendedQueryHandler, SimpleQueryHandler}; use pgwire::api::results::{DataRowEncoder, FieldFormat, FieldInfo, QueryResponse, Response, Tag}; -use pgwire::api::MakeHandler; -use pgwire::api::{ClientInfo, StatelessMakeHandler, Type}; +use pgwire::api::{ClientInfo, NoopErrorHandler, PgWireServerHandlers, Type}; use pgwire::error::{ErrorInfo, PgWireError, PgWireResult}; use pgwire::tokio::process_socket; use std::fmt::Debug; @@ -83,29 +80,67 @@ pub struct FnckSQLBackend { inner: Arc>, } +impl FnckSQLBackend { + pub fn new(path: impl Into + Send) -> Result { + let database = DataBaseBuilder::path(path).build()?; + + Ok(FnckSQLBackend { + inner: Arc::new(database), + }) + } +} + pub struct SessionBackend { inner: Arc>, tx: Mutex>, } -impl MakeHandler for FnckSQLBackend { - type Handler = Arc; - - fn make(&self) -> Self::Handler { - Arc::new(SessionBackend { - inner: Arc::clone(&self.inner), +impl SessionBackend { + pub fn new(inner: Arc>) -> SessionBackend { + SessionBackend { + inner, tx: Mutex::new(None), - }) + } } } -impl FnckSQLBackend { - pub fn new(path: impl Into + Send) -> Result { - let database = DataBaseBuilder::path(path).build()?; +impl NoopStartupHandler for SessionBackend {} - Ok(FnckSQLBackend { - inner: Arc::new(database), - }) +struct CustomBackendFactory { + handler: Arc, +} + +impl CustomBackendFactory { + pub fn new(handler: Arc) -> CustomBackendFactory { + CustomBackendFactory { handler } + } +} + +impl PgWireServerHandlers for CustomBackendFactory { + type StartupHandler = SessionBackend; + type SimpleQueryHandler = SessionBackend; + type ExtendedQueryHandler = PlaceholderExtendedQueryHandler; + type CopyHandler = NoopCopyHandler; + type ErrorHandler = NoopErrorHandler; + + fn simple_query_handler(&self) -> Arc { + self.handler.clone() + } + + fn extended_query_handler(&self) -> Arc { + Arc::new(PlaceholderExtendedQueryHandler) + } + + fn startup_handler(&self) -> Arc { + self.handler.clone() + } + + fn copy_handler(&self) -> Arc { + Arc::new(NoopCopyHandler) + } + + fn error_handler(&self) -> Arc { + Arc::new(NoopErrorHandler) } } @@ -179,7 +214,10 @@ impl SimpleQueryHandler for SessionBackend { for tuple in iter.by_ref() { tuples.push(tuple.map_err(|e| PgWireError::ApiError(Box::new(e)))?); } - encode_tuples(iter.schema(), tuples)? + let schema = iter.schema().clone(); + iter.done() + .map_err(|e| PgWireError::ApiError(Box::new(e)))?; + encode_tuples(&schema, tuples)? } else { let mut iter = self .inner @@ -188,7 +226,10 @@ impl SimpleQueryHandler for SessionBackend { for tuple in iter.by_ref() { tuples.push(tuple.map_err(|e| PgWireError::ApiError(Box::new(e)))?); } - encode_tuples(iter.schema(), tuples)? + let schema = iter.schema().clone(); + iter.done() + .map_err(|e| PgWireError::ApiError(Box::new(e)))?; + encode_tuples(&schema, tuples)? }; Ok(vec![Response::Query(response)]) } @@ -196,7 +237,7 @@ impl SimpleQueryHandler for SessionBackend { } } -fn encode_tuples<'a>(schema: &Schema, tuples: Vec) -> PgWireResult> { +fn encode_tuples<'a>(schema: &SchemaRef, tuples: Vec) -> PgWireResult> { if tuples.is_empty() { return Ok(QueryResponse::new(Arc::new(vec![]), stream::empty())); } @@ -268,7 +309,7 @@ fn into_pg_type(data_type: &LogicalType) -> PgWireResult { LogicalType::Date | LogicalType::DateTime => Type::DATE, LogicalType::Char(..) => Type::CHAR, LogicalType::Time => Type::TIME, - LogicalType::Decimal(_, _) => Type::FLOAT8, + LogicalType::Decimal(_, _) => Type::NUMERIC, _ => { return Err(PgWireError::UserError(Box::new(ErrorInfo::new( "ERROR".to_owned(), @@ -318,17 +359,14 @@ async fn main() { ); let backend = FnckSQLBackend::new(args.path).unwrap(); - let processor = Arc::new(backend); - // We have not implemented extended query in this server, use placeholder instead - let placeholder = Arc::new(StatelessMakeHandler::new(Arc::new( - PlaceholderExtendedQueryHandler, - ))); - let authenticator = Arc::new(StatelessMakeHandler::new(Arc::new(NoopStartupHandler))); + let factory = Arc::new(CustomBackendFactory::new(Arc::new(SessionBackend::new( + backend.inner, + )))); let server_addr = format!("{}:{}", args.ip, args.port); let listener = TcpListener::bind(server_addr).await.unwrap(); tokio::select! { - res = server_run(processor, placeholder, authenticator, listener) => { + res = server_run(listener,factory) => { if let Err(err) = res { error!("[Listener][Failed To Accept]: {}", err); } @@ -337,32 +375,16 @@ async fn main() { } } -async fn server_run< - A: MakeHandler>, - Q: MakeHandler>, - EQ: MakeHandler>, ->( - processor: Arc, - placeholder: Arc, - authenticator: Arc, +async fn server_run( listener: TcpListener, + factory_ref: Arc, ) -> io::Result<()> { loop { let incoming_socket = listener.accept().await?; - let authenticator_ref = authenticator.make(); - let processor_ref = processor.make(); - let placeholder_ref = placeholder.make(); + let factory_ref = factory_ref.clone(); tokio::spawn(async move { - if let Err(err) = process_socket( - incoming_socket.0, - None, - authenticator_ref, - processor_ref, - placeholder_ref, - ) - .await - { + if let Err(err) = process_socket(incoming_socket.0, None, factory_ref).await { error!("Failed To Process: {}", err); } });