Skip to content

Commit

Permalink
feat(core): support query forward. (#487)
Browse files Browse the repository at this point in the history
* refactor: extract session.rs

* feat: support query forward.

* update tests.

* ci: cli test use env DATABEND_PORT.
  • Loading branch information
youngsofun authored Oct 22, 2024
1 parent 826c8b0 commit b996aef
Show file tree
Hide file tree
Showing 8 changed files with 124 additions and 62 deletions.
3 changes: 2 additions & 1 deletion cli/test.sh
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ CARGO_TARGET_DIR=${CARGO_TARGET_DIR:-./target}
DATABEND_USER=${DATABEND_USER:-root}
DATABEND_PASSWORD=${DATABEND_PASSWORD:-}
DATABEND_HOST=${DATABEND_HOST:-localhost}
DATABEND_PORT=${DATABEND_PORT:-8000}

TEST_HANDLER=$1

Expand All @@ -32,7 +33,7 @@ case $TEST_HANDLER in
;;
"http")
echo "==> Testing REST API handler"
export BENDSQL_DSN="databend+http://${DATABEND_USER}:${DATABEND_PASSWORD}@${DATABEND_HOST}:8000/?sslmode=disable&presign=on"
export BENDSQL_DSN="databend+http://${DATABEND_USER}:${DATABEND_PASSWORD}@${DATABEND_HOST}:${DATABEND_PORT}/?sslmode=disable&presign=on"
;;
*)
echo "Usage: $0 [flight|http]"
Expand Down
1 change: 1 addition & 0 deletions core/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ serde_json = { version = "1.0", default-features = false, features = ["std"] }
tokio = { version = "1.34", features = ["macros"] }
tokio-retry = "0.3"
tokio-util = { version = "0.7", features = ["io-util"] }
parking_lot = "0.12.3"
url = { version = "2.5", default-features = false }
uuid = { version = "1.6", features = ["v4"] }

Expand Down
45 changes: 38 additions & 7 deletions core/src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,15 +31,17 @@ use url::Url;

use crate::auth::{AccessTokenAuth, AccessTokenFileAuth, Auth, BasicAuth};
use crate::presign::{presign_upload_to_stage, PresignMode, PresignedResponse, Reader};
use crate::session::SessionState;
use crate::stage::StageLocation;
use crate::{
error::{Error, Result},
request::{PaginationConfig, QueryRequest, SessionState, StageAttachmentConfig},
request::{PaginationConfig, QueryRequest, StageAttachmentConfig},
response::{QueryError, QueryResponse},
};

const HEADER_QUERY_ID: &str = "X-DATABEND-QUERY-ID";
const HEADER_TENANT: &str = "X-DATABEND-TENANT";
const HEADER_STICKY_NODE: &str = "X-DATABEND-STICKY-NODE";
const HEADER_WAREHOUSE: &str = "X-DATABEND-WAREHOUSE";
const HEADER_STAGE_NAME: &str = "X-DATABEND-STAGE-NAME";
const HEADER_ROUTE_HINT: &str = "X-DATABEND-ROUTE-HINT";
Expand Down Expand Up @@ -76,6 +78,7 @@ pub struct APIClient {
tls_ca_file: Option<String>,

presign: PresignMode,
last_node_id: Arc<parking_lot::Mutex<Option<String>>>,
}

impl APIClient {
Expand Down Expand Up @@ -283,6 +286,13 @@ impl APIClient {
}
}

/// Stores `node_id` as the most recently seen node id, replacing any
/// previously stored value.
///
/// Takes `&self`: the value lives behind a shared mutex, so no exclusive
/// borrow of the client is needed.
pub fn set_last_node_id(&self, node_id: String) {
    let mut slot = self.last_node_id.lock();
    *slot = Some(node_id);
}
/// Returns a copy of the most recently stored node id, or `None` if no
/// node id has been stored yet.
pub fn last_node_id(&self) -> Option<String> {
    let slot = self.last_node_id.lock();
    // Clone the inner `Option<String>` so the caller gets an owned value.
    (*slot).clone()
}

pub fn handle_warnings(&self, resp: &QueryResponse) {
if let Some(warnings) = &resp.warnings {
for w in warnings {
Expand All @@ -297,12 +307,18 @@ impl APIClient {
self.route_hint.next();
}
let session_state = self.session_state().await;
let need_sticky = session_state.need_sticky.unwrap_or(false);
let req = QueryRequest::new(sql)
.with_pagination(self.make_pagination())
.with_session(Some(session_state));
let endpoint = self.endpoint.join("v1/query")?;
let query_id = self.gen_query_id();
let headers = self.make_headers(&query_id).await?;
let mut headers = self.make_headers(&query_id).await?;
if need_sticky {
if let Some(node_id) = self.last_node_id() {
headers.insert(HEADER_STICKY_NODE, node_id.parse()?);
}
}
let mut builder = self.cli.post(endpoint.clone()).json(&req);
builder = self.auth.wrap(builder).await?;
let mut resp = builder.headers(headers.clone()).send().await?;
Expand Down Expand Up @@ -344,7 +360,12 @@ impl APIClient {
Ok(result)
}

pub async fn query_page(&self, query_id: &str, next_uri: &str) -> Result<QueryResponse> {
pub async fn query_page(
&self,
query_id: &str,
next_uri: &str,
node_id: &str,
) -> Result<QueryResponse> {
info!("query page: {}", next_uri);
let endpoint = self.endpoint.join(next_uri)?;
let headers = self.make_headers(query_id).await?;
Expand All @@ -354,6 +375,7 @@ impl APIClient {
builder = self.auth.wrap(builder).await?;
builder
.headers(headers.clone())
.header(HEADER_STICKY_NODE, node_id)
.timeout(self.page_request_timeout)
.send()
.await
Expand Down Expand Up @@ -410,12 +432,14 @@ impl APIClient {

pub async fn wait_for_query(&self, resp: QueryResponse) -> Result<QueryResponse> {
info!("wait for query: {}", resp.id);
let node_id = resp.node_id.clone();
self.set_last_node_id(node_id.clone());
if let Some(next_uri) = &resp.next_uri {
let schema = resp.schema;
let mut data = resp.data;
let mut resp = self.query_page(&resp.id, next_uri).await?;
let mut resp = self.query_page(&resp.id, next_uri, &node_id).await?;
while let Some(next_uri) = &resp.next_uri {
resp = self.query_page(&resp.id, next_uri).await?;
resp = self.query_page(&resp.id, next_uri, &node_id).await?;
data.append(&mut resp.data);
}
resp.schema = schema;
Expand Down Expand Up @@ -487,6 +511,8 @@ impl APIClient {
sql, file_format_options, copy_options
);
let session_state = self.session_state().await;
let need_sticky = session_state.need_sticky.unwrap_or(false);

let stage_attachment = Some(StageAttachmentConfig {
location: stage,
file_format_options: Some(file_format_options),
Expand All @@ -498,8 +524,12 @@ impl APIClient {
.with_stage_attachment(stage_attachment);
let endpoint = self.endpoint.join("v1/query")?;
let query_id = self.gen_query_id();
let headers = self.make_headers(&query_id).await?;

let mut headers = self.make_headers(&query_id).await?;
if need_sticky {
if let Some(node_id) = self.last_node_id() {
headers.insert(HEADER_STICKY_NODE, node_id.parse()?);
}
}
let mut builder = self.cli.post(endpoint.clone()).json(&req);
builder = self.auth.wrap(builder).await?;
let mut resp = builder.headers(headers.clone()).send().await?;
Expand Down Expand Up @@ -626,6 +656,7 @@ impl Default for APIClient {
tls_ca_file: None,
presign: PresignMode::Auto,
route_hint: Arc::new(RouteHintGenerator::new()),
last_node_id: Arc::new(Default::default()),
}
}
}
Expand Down
1 change: 1 addition & 0 deletions core/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ pub mod error;
pub mod presign;
pub mod request;
pub mod response;
pub mod session;
pub mod stage;

pub use client::APIClient;
51 changes: 7 additions & 44 deletions core/src/request.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,48 +12,16 @@
// See the License for the specific language governing permissions and
// limitations under the License.

use std::collections::{BTreeMap, HashMap};
use std::collections::BTreeMap;

use crate::session::SessionState;
use serde::{Deserialize, Serialize};

/// Describes a server by an `id` and a `start_time`, both carried as strings.
///
/// Derives serde (de)serialization together with `Debug`, `Clone` and
/// `PartialEq`, so two values compare equal when both fields match.
#[derive(Deserialize, Serialize, Debug, Clone, PartialEq)]
pub struct ServerInfo {
/// Server identifier, kept as an opaque string.
pub id: String,
/// Server start time as a string.
/// NOTE(review): the timestamp format is not fixed by this type — confirm against the server.
pub start_time: String,
}
#[derive(Deserialize, Serialize, Debug, Default, Clone)]
pub struct SessionState {
#[serde(skip_serializing_if = "Option::is_none")]
pub database: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub settings: Option<BTreeMap<String, String>>,
#[serde(skip_serializing_if = "Option::is_none")]
pub role: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub secondary_roles: Option<Vec<String>>,
#[serde(skip_serializing_if = "Option::is_none")]
pub txn_state: Option<String>,

// hide fields of no interest (but need to send back to server in next query)
#[serde(flatten)]
additional_fields: HashMap<String, serde_json::Value>,
}

impl SessionState {
pub fn with_settings(mut self, settings: Option<BTreeMap<String, String>>) -> Self {
self.settings = settings;
self
}

pub fn with_database(mut self, database: Option<String>) -> Self {
self.database = database;
self
}

pub fn with_role(mut self, role: Option<String>) -> Self {
self.role = role;
self
}
}

#[derive(Serialize, Debug)]
pub struct QueryRequest<'a> {
Expand Down Expand Up @@ -122,14 +90,9 @@ mod test {
#[test]
fn build_request() -> Result<()> {
let req = QueryRequest::new("select 1")
.with_session(Some(SessionState {
database: Some("default".to_string()),
settings: Some(BTreeMap::new()),
role: None,
secondary_roles: None,
txn_state: None,
additional_fields: Default::default(),
}))
.with_session(Some(
SessionState::default().with_database(Some("default".to_string())),
))
.with_pagination(Some(PaginationConfig {
wait_time_secs: Some(1),
max_rows_in_buffer: Some(1),
Expand All @@ -142,7 +105,7 @@ mod test {
}));
assert_eq!(
serde_json::to_string(&req)?,
r#"{"session":{"database":"default","settings":{}},"sql":"select 1","pagination":{"wait_time_secs":1,"max_rows_in_buffer":1,"max_rows_per_page":1},"stage_attachment":{"location":"@~/my_location"}}"#
r#"{"session":{"database":"default"},"sql":"select 1","pagination":{"wait_time_secs":1,"max_rows_in_buffer":1,"max_rows_per_page":1},"stage_attachment":{"location":"@~/my_location"}}"#
);
Ok(())
}
Expand Down
3 changes: 2 additions & 1 deletion core/src/response.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

use serde::Deserialize;

use crate::request::SessionState;
use crate::session::SessionState;

#[derive(Deserialize, Debug)]
pub struct QueryError {
Expand Down Expand Up @@ -55,6 +55,7 @@ pub struct SchemaField {
#[derive(Deserialize, Debug)]
pub struct QueryResponse {
pub id: String,
pub node_id: String,
pub session_id: Option<String>,
pub session: Option<SessionState>,
pub schema: Vec<SchemaField>,
Expand Down
53 changes: 53 additions & 0 deletions core/src/session.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
// Copyright 2021 Datafuse Labs
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

use serde::{Deserialize, Serialize};
use std::collections::{BTreeMap, HashMap};

/// Session state exchanged with the server: it is attached to outgoing query
/// requests and read back from query responses.
///
/// Every named field is optional and is left out of the serialized form when
/// it is `None`. Keys that are not modelled as named fields are kept in
/// `additional_fields` and written back on serialization.
#[derive(Deserialize, Serialize, Debug, Default, Clone)]
pub struct SessionState {
/// Current database, if one is set.
#[serde(skip_serializing_if = "Option::is_none")]
pub database: Option<String>,
/// Session settings as key/value strings; a `BTreeMap` keeps keys in sorted order.
#[serde(skip_serializing_if = "Option::is_none")]
pub settings: Option<BTreeMap<String, String>>,
/// Current role, if one is set.
#[serde(skip_serializing_if = "Option::is_none")]
pub role: Option<String>,
/// Secondary roles, if any.
/// NOTE(review): exact semantics are defined by the server — confirm.
#[serde(skip_serializing_if = "Option::is_none")]
pub secondary_roles: Option<Vec<String>>,
/// Transaction state carried as a string; this type treats it as opaque.
#[serde(skip_serializing_if = "Option::is_none")]
pub txn_state: Option<String>,
/// Sticky-routing flag. The client reads it with `unwrap_or(false)` and,
/// when true, sends the last seen node id in the sticky-node header of
/// new queries.
#[serde(skip_serializing_if = "Option::is_none")]
pub need_sticky: Option<bool>,

// Fields this client has no interest in are hidden here; they still need to
// be sent back to the server with the next query, so they are captured
// verbatim via `flatten` and re-emitted on serialization.
#[serde(flatten)]
additional_fields: HashMap<String, serde_json::Value>,
}

/// Consuming builder-style setters: each takes `self` by value, overwrites
/// one field and returns the updated state, so calls can be chained.
impl SessionState {
    /// Replaces `settings`; passing `None` clears it.
    pub fn with_settings(self, settings: Option<BTreeMap<String, String>>) -> Self {
        Self { settings, ..self }
    }

    /// Replaces `database`; passing `None` clears it.
    pub fn with_database(self, database: Option<String>) -> Self {
        Self { database, ..self }
    }

    /// Replaces `role`; passing `None` clears it.
    pub fn with_role(self, role: Option<String>) -> Self {
        Self { role, ..self }
    }
}
29 changes: 20 additions & 9 deletions driver/src/rest_api.rs
Original file line number Diff line number Diff line change
Expand Up @@ -58,8 +58,12 @@ impl Connection for RestAPIConnection {
async fn exec(&self, sql: &str) -> Result<i64> {
info!("exec: {}", sql);
let mut resp = self.client.start_query(sql).await?;
let node_id = resp.node_id.clone();
while let Some(next_uri) = resp.next_uri {
resp = self.client.query_page(&resp.id, &next_uri).await?;
resp = self
.client
.query_page(&resp.id, &next_uri, &node_id)
.await?;
}
Ok(resp.stats.progresses.write_progress.rows as i64)
}
Expand Down Expand Up @@ -201,14 +205,19 @@ impl<'o> RestAPIConnection {
Ok(Self { client })
}

async fn wait_for_schema(&self, pre: QueryResponse) -> Result<QueryResponse> {
if !pre.data.is_empty() || !pre.schema.is_empty() {
return Ok(pre);
async fn wait_for_schema(&self, resp: QueryResponse) -> Result<QueryResponse> {
if !resp.data.is_empty() || !resp.schema.is_empty() {
return Ok(resp);
}
let mut result = pre;
// preserve schema since it is no included in the final response
let node_id = resp.node_id.clone();
self.client.set_last_node_id(node_id.clone());
let mut result = resp;
// preserve schema since it is not included in the final response
while let Some(next_uri) = result.next_uri {
result = self.client.query_page(&result.id, &next_uri).await?;
result = self
.client
.query_page(&result.id, &next_uri, &node_id)
.await?;
if !result.data.is_empty() || !result.schema.is_empty() {
break;
}
Expand Down Expand Up @@ -240,6 +249,7 @@ pub struct RestAPIRows {
data: VecDeque<Vec<Option<String>>>,
stats: Option<ServerStats>,
query_id: String,
node_id: String,
next_uri: Option<String>,
next_page: Option<PageFut>,
}
Expand All @@ -250,6 +260,7 @@ impl RestAPIRows {
let rows = Self {
client,
query_id: resp.id,
node_id: resp.node_id,
next_uri: resp.next_uri,
schema: Arc::new(schema.clone()),
data: resp.data.into(),
Expand Down Expand Up @@ -278,7 +289,6 @@ impl Stream for RestAPIRows {
if self.schema.fields().is_empty() {
self.schema = Arc::new(resp.schema.try_into()?);
}
self.query_id = resp.id;
self.next_uri = resp.next_uri;
self.next_page = None;
self.stats = Some(ServerStats::from(resp.stats));
Expand All @@ -295,9 +305,10 @@ impl Stream for RestAPIRows {
let client = self.client.clone();
let next_uri = next_uri.clone();
let query_id = self.query_id.clone();
let node_id = self.node_id.clone();
self.next_page = Some(Box::pin(async move {
client
.query_page(&query_id, &next_uri)
.query_page(&query_id, &next_uri, &node_id)
.await
.map_err(|e| e.into())
}));
Expand Down

0 comments on commit b996aef

Please sign in to comment.