Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: support setting name for each client #344

Merged
merged 1 commit into from
Feb 1, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions bindings/nodejs/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ napi = { version = "2.14", default-features = false, features = [
"chrono_date",
] }
napi-derive = "2.14"
once_cell = "1.18"
tokio-stream = "0.1"

[build-dependencies]
Expand Down
9 changes: 8 additions & 1 deletion bindings/nodejs/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,14 @@ extern crate napi_derive;

use chrono::{NaiveDate, NaiveDateTime, NaiveTime};
use napi::bindgen_prelude::*;
use once_cell::sync::Lazy;
use tokio_stream::StreamExt;

static VERSION: Lazy<String> = Lazy::new(|| {
let version = option_env!("CARGO_PKG_VERSION").unwrap_or("unknown");
version.to_string()
});

#[napi]
pub struct Client(databend_driver::Client);

Expand Down Expand Up @@ -279,7 +285,8 @@ impl Client {
/// Create a new databend client with a given DSN.
#[napi(constructor)]
pub fn new(dsn: String) -> Self {
let client = databend_driver::Client::new(dsn);
let name = format!("databend-driver-nodejs/{}", VERSION.as_str());
let client = databend_driver::Client::new(dsn).with_name(name);
Self(client)
}

Expand Down
1 change: 1 addition & 0 deletions bindings/python/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ doc = false

[dependencies]
databend-driver = { workspace = true, features = ["rustls", "flight-sql"] }
once_cell = "1.18"
pyo3 = { version = "0.20", features = ["abi3-py37"] }
pyo3-asyncio = { version = "0.20", features = ["tokio-runtime"] }
tokio = "1.34"
Expand Down
5 changes: 3 additions & 2 deletions bindings/python/src/asyncio.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
use pyo3::prelude::*;
use pyo3_asyncio::tokio::future_into_py;

use crate::types::{ConnectionInfo, DriverError, Row, RowIterator, ServerStats};
use crate::types::{ConnectionInfo, DriverError, Row, RowIterator, ServerStats, VERSION};

#[pyclass(module = "databend_driver")]
pub struct AsyncDatabendClient(databend_driver::Client);
Expand All @@ -25,7 +25,8 @@ impl AsyncDatabendClient {
#[new]
#[pyo3(signature = (dsn))]
pub fn new(dsn: String) -> PyResult<Self> {
let client = databend_driver::Client::new(dsn);
let name = format!("databend-driver-python/{}", VERSION.as_str());
let client = databend_driver::Client::new(dsn).with_name(name);
Ok(Self(client))
}

Expand Down
5 changes: 3 additions & 2 deletions bindings/python/src/blocking.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

use pyo3::prelude::*;

use crate::types::{ConnectionInfo, DriverError, Row, RowIterator, ServerStats};
use crate::types::{ConnectionInfo, DriverError, Row, RowIterator, ServerStats, VERSION};

#[pyclass(module = "databend_driver")]
pub struct BlockingDatabendClient(databend_driver::Client);
Expand All @@ -24,7 +24,8 @@ impl BlockingDatabendClient {
#[new]
#[pyo3(signature = (dsn))]
pub fn new(dsn: String) -> PyResult<Self> {
let client = databend_driver::Client::new(dsn);
let name = format!("databend-driver-python/{}", VERSION.as_str());
let client = databend_driver::Client::new(dsn).with_name(name);
Ok(Self(client))
}

Expand Down
6 changes: 6 additions & 0 deletions bindings/python/src/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,19 @@

use std::sync::Arc;

use once_cell::sync::Lazy;
use pyo3::exceptions::{PyException, PyStopAsyncIteration, PyStopIteration};
use pyo3::prelude::*;
use pyo3::types::{PyDict, PyList, PyTuple};
use pyo3_asyncio::tokio::future_into_py;
use tokio::sync::Mutex;
use tokio_stream::StreamExt;

pub static VERSION: Lazy<String> = Lazy::new(|| {
let version = option_env!("CARGO_PKG_VERSION").unwrap_or("unknown");
version.to_string()
});

pub struct Value(databend_driver::Value);

impl IntoPy<PyObject> for Value {
Expand Down
30 changes: 25 additions & 5 deletions cli/src/session.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ use anyhow::Result;
use chrono::NaiveDateTime;
use databend_driver::ServerStats;
use databend_driver::{Client, Connection};
use once_cell::sync::Lazy;
use rustyline::config::Builder;
use rustyline::error::ReadlineError;
use rustyline::history::DefaultHistory;
Expand All @@ -40,6 +41,15 @@ use crate::VERSION;

static PROMPT_SQL: &str = "select name from system.tables union all select name from system.columns union all select name from system.databases union all select name from system.functions";

static VERSION_SHORT: Lazy<String> = Lazy::new(|| {
let version = option_env!("CARGO_PKG_VERSION").unwrap_or("unknown");
let sha = option_env!("VERGEN_GIT_SHA").unwrap_or("dev");
match option_env!("BENDSQL_BUILD_INFO") {
Some(info) => format!("{}-{}", version, info),
None => format!("{}-{}", version, sha),
}
});

pub struct Session {
client: Client,
conn: Box<dyn Connection>,
Expand All @@ -54,16 +64,26 @@ pub struct Session {

impl Session {
pub async fn try_new(dsn: String, settings: Settings, is_repl: bool) -> Result<Self> {
let client = Client::new(dsn);
let client = Client::new(dsn).with_name(format!("bendsql/{}", VERSION_SHORT.as_str()));
let conn = client.get_conn().await?;
let info = conn.info().await;
let mut keywords = Vec::with_capacity(1024);
if is_repl {
println!("Welcome to BendSQL {}.", VERSION.as_str());
println!(
"Connecting to {}:{} as user {}.",
info.host, info.port, info.user
);
match info.warehouse {
Some(ref warehouse) => {
println!(
"Connecting to {}:{} with warehouse {} as user {}",
info.host, info.port, warehouse, info.user
);
}
None => {
println!(
"Connecting to {}:{} as user {}.",
info.host, info.port, info.user
);
}
}
let version = conn.version().await?;
println!("Connected to {}", version);
println!();
Expand Down
54 changes: 35 additions & 19 deletions core/src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ static VERSION: Lazy<String> = Lazy::new(|| {
#[derive(Clone)]
pub struct APIClient {
pub cli: HttpClient,
scheme: String,
endpoint: Url,
pub host: String,
pub port: u16,
Expand All @@ -72,7 +73,14 @@ pub struct APIClient {
}

impl APIClient {
pub async fn from_dsn(dsn: &str) -> Result<Self> {
pub async fn new(dsn: &str, name: Option<String>) -> Result<Self> {
let mut client = Self::from_dsn(dsn).await?;
client.build_client(name).await?;
client.check_presign().await?;
Ok(client)
}

async fn from_dsn(dsn: &str) -> Result<Self> {
let u = Url::parse(dsn)?;
let mut client = Self::default();
if let Some(host) = u.host_str() {
Expand Down Expand Up @@ -176,34 +184,40 @@ impl APIClient {
_ => unreachable!(),
},
};
client.scheme = scheme.to_string();

let mut cli_builder = HttpClient::builder()
.user_agent(format!("databend-client-rust/{}", VERSION.as_str()))
.pool_idle_timeout(Duration::from_secs(1));
#[cfg(any(feature = "rustls", feature = "native-tls"))]
if scheme == "https" {
if let Some(ref ca_file) = client.tls_ca_file {
let cert_pem = tokio::fs::read(ca_file).await?;
let cert = reqwest::Certificate::from_pem(&cert_pem)?;
cli_builder = cli_builder.add_root_certificate(cert);
}
}
client.cli = cli_builder.build()?;
client.endpoint = Url::parse(&format!("{}://{}:{}", scheme, client.host, client.port))?;

client.session_state = Arc::new(Mutex::new(
SessionState::default()
.with_settings(Some(session_settings))
.with_role(role)
.with_database(database),
));

client.init_presign().await?;

Ok(client)
}

async fn init_presign(&mut self) -> Result<()> {
async fn build_client(&mut self, name: Option<String>) -> Result<()> {
let ua = match name {
Some(n) => n,
None => format!("databend-client-rust/{}", VERSION.as_str()),
};
let mut cli_builder = HttpClient::builder()
.user_agent(ua)
.pool_idle_timeout(Duration::from_secs(1));
#[cfg(any(feature = "rustls", feature = "native-tls"))]
if self.scheme == "https" {
if let Some(ref ca_file) = self.tls_ca_file {
let cert_pem = tokio::fs::read(ca_file).await?;
let cert = reqwest::Certificate::from_pem(&cert_pem)?;
cli_builder = cli_builder.add_root_certificate(cert);
}
}
self.cli = cli_builder.build()?;
Ok(())
}

async fn check_presign(&mut self) -> Result<()> {
match self.presign {
PresignMode::Auto => {
if self.host.ends_with(".databend.com") || self.host.ends_with(".databend.cn") {
Expand All @@ -212,7 +226,7 @@ impl APIClient {
self.presign = PresignMode::Off;
}
}
PresignMode::Detect => match self.get_presigned_upload_url("~/.bendsql/check").await {
PresignMode::Detect => match self.get_presigned_upload_url("@~/.bendsql/check").await {
Ok(_) => self.presign = PresignMode::On,
Err(e) => {
warn!("presign mode off with error detected: {}", e);
Expand Down Expand Up @@ -344,7 +358,8 @@ impl APIClient {
}
let resp: QueryResponse = resp.json().await?;
self.handle_session(&resp.session).await;
// TODO: duplicate warnings with start_query, maybe we should only print warnings on final response
// TODO: duplicate warnings with start_query,
// maybe we should only print warnings on final response
// self.handle_warnings(&resp);
match resp.error {
Some(err) => Err(Error::InvalidResponse(err)),
Expand Down Expand Up @@ -570,6 +585,7 @@ impl Default for APIClient {
fn default() -> Self {
Self {
cli: HttpClient::new(),
scheme: "http".to_string(),
endpoint: Url::parse("http://localhost:8080").unwrap(),
host: "localhost".to_string(),
port: 8000,
Expand Down
2 changes: 1 addition & 1 deletion core/tests/core/simple.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ use crate::common::DEFAULT_DSN;
#[tokio::test]
async fn select_simple() {
let dsn = option_env!("TEST_DATABEND_DSN").unwrap_or(DEFAULT_DSN);
let client = APIClient::from_dsn(dsn).await.unwrap();
let client = APIClient::new(dsn, None).await.unwrap();
let resp = client.start_query("select 15532").await.unwrap();
assert_eq!(resp.data, [["15532"]]);
}
4 changes: 2 additions & 2 deletions core/tests/core/stage.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,11 @@ use crate::common::DEFAULT_DSN;
async fn insert_with_stage(presign: bool) {
let dsn = option_env!("TEST_DATABEND_DSN").unwrap_or(DEFAULT_DSN);
let client = if presign {
APIClient::from_dsn(&format!("{}&presign=on", dsn))
APIClient::new(&format!("{}&presign=on", dsn), None)
.await
.unwrap()
} else {
APIClient::from_dsn(&format!("{}&presign=off", dsn))
APIClient::new(&format!("{}&presign=off", dsn), None)
.await
.unwrap()
};
Expand Down
1 change: 1 addition & 0 deletions driver/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ csv = "1.3"
dyn-clone = "1.0"
glob = "0.3"
log = "0.4"
once_cell = "1.18"
percent-encoding = "2.3"
serde_json = { version = "1.0", default-features = false, features = ["std"] }
tokio = { version = "1.34", features = ["macros"] }
Expand Down
19 changes: 16 additions & 3 deletions driver/src/conn.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ use std::sync::Arc;

use async_trait::async_trait;
use dyn_clone::DynClone;
use once_cell::sync::Lazy;
use tokio::io::AsyncRead;
use tokio_stream::StreamExt;
use url::Url;
Expand All @@ -34,26 +35,38 @@ use databend_sql::value::{NumberValue, Value};

use crate::rest_api::RestAPIConnection;

static VERSION: Lazy<String> = Lazy::new(|| {
let version = option_env!("CARGO_PKG_VERSION").unwrap_or("unknown");
version.to_string()
});

#[derive(Clone)]
pub struct Client {
dsn: String,
name: String,
}

impl Client {
pub fn new(dsn: String) -> Self {
Self { dsn }
let name = format!("databend-driver-rust/{}", VERSION.as_str());
Self { dsn, name }
}

pub fn with_name(mut self, name: String) -> Self {
self.name = name;
self
}

pub async fn get_conn(&self) -> Result<Box<dyn Connection>> {
let u = Url::parse(&self.dsn)?;
match u.scheme() {
"databend" | "databend+http" | "databend+https" => {
let conn = RestAPIConnection::try_create(&self.dsn).await?;
let conn = RestAPIConnection::try_create(&self.dsn, self.name.clone()).await?;
Ok(Box::new(conn))
}
#[cfg(feature = "flight-sql")]
"databend+flight" | "databend+grpc" => {
let conn = FlightSQLConnection::try_create(&self.dsn).await?;
let conn = FlightSQLConnection::try_create(&self.dsn, self.name.clone()).await?;
Ok(Box::new(conn))
}
_ => Err(Error::Parsing(format!(
Expand Down
7 changes: 4 additions & 3 deletions driver/src/flight_sql.rs
Original file line number Diff line number Diff line change
Expand Up @@ -146,8 +146,8 @@ impl Connection for FlightSQLConnection {
}

impl FlightSQLConnection {
pub async fn try_create(dsn: &str) -> Result<Self> {
let (args, endpoint) = Self::parse_dsn(dsn).await?;
pub async fn try_create(dsn: &str, name: String) -> Result<Self> {
let (args, endpoint) = Self::parse_dsn(dsn, name).await?;
let channel = endpoint.connect_lazy();
let mut client = FlightSqlServiceClient::new(channel);
// enable progress
Expand Down Expand Up @@ -178,10 +178,11 @@ impl FlightSQLConnection {
Ok(())
}

async fn parse_dsn(dsn: &str) -> Result<(Args, Endpoint)> {
async fn parse_dsn(dsn: &str, name: String) -> Result<(Args, Endpoint)> {
let u = Url::parse(dsn)?;
let args = Args::from_url(&u)?;
let mut endpoint = Endpoint::new(args.uri.clone())?
.user_agent(name)?
.connect_timeout(args.connect_timeout)
.timeout(args.query_timeout)
.tcp_nodelay(args.tcp_nodelay)
Expand Down
4 changes: 2 additions & 2 deletions driver/src/rest_api.rs
Original file line number Diff line number Diff line change
Expand Up @@ -190,8 +190,8 @@ impl Connection for RestAPIConnection {
}

impl<'o> RestAPIConnection {
pub async fn try_create(dsn: &str) -> Result<Self> {
let client = APIClient::from_dsn(dsn).await?;
pub async fn try_create(dsn: &str, name: String) -> Result<Self> {
let client = APIClient::new(dsn, Some(name)).await?;
Ok(Self { client })
}

Expand Down
Loading