diff --git a/core/src/blocks/block.rs b/core/src/blocks/block.rs index b147ec08b97a6..a266e182a2db8 100644 --- a/core/src/blocks/block.rs +++ b/core/src/blocks/block.rs @@ -1,6 +1,7 @@ use crate::blocks::{ browser::Browser, chat::Chat, code::Code, curl::Curl, data::Data, data_source::DataSource, - end::End, input::Input, llm::LLM, map::Map, r#while::While, reduce::Reduce, search::Search, + database_schema::DatabaseSchema, end::End, input::Input, llm::LLM, map::Map, r#while::While, + reduce::Reduce, search::Search, }; use crate::data_sources::qdrant::QdrantClients; use crate::project::Project; @@ -73,6 +74,7 @@ pub enum BlockType { Browser, While, End, + DatabaseSchema, } impl ToString for BlockType { @@ -91,6 +93,7 @@ impl ToString for BlockType { BlockType::Browser => String::from("browser"), BlockType::While => String::from("while"), BlockType::End => String::from("end"), + BlockType::DatabaseSchema => String::from("database_schema"), } } } @@ -192,6 +195,7 @@ pub fn parse_block(t: BlockType, block_pair: Pair) -> Result Ok(Box::new(Browser::parse(block_pair)?)), BlockType::While => Ok(Box::new(While::parse(block_pair)?)), BlockType::End => Ok(Box::new(End::parse(block_pair)?)), + BlockType::DatabaseSchema => Ok(Box::new(DatabaseSchema::parse(block_pair)?)), } } diff --git a/core/src/blocks/data_source.rs b/core/src/blocks/data_source.rs index 2cd4a3b1d15e0..d9b4afae76321 100644 --- a/core/src/blocks/data_source.rs +++ b/core/src/blocks/data_source.rs @@ -1,21 +1,15 @@ use crate::blocks::block::{ parse_pair, replace_variables_in_string, Block, BlockResult, BlockType, Env, }; +use crate::blocks::helpers::get_data_source_project; use crate::data_sources::data_source::{Document, SearchFilter}; -use crate::project::Project; use crate::Rule; use anyhow::{anyhow, Result}; use async_trait::async_trait; -use hyper::header; -use hyper::{body::Buf, http::StatusCode, Body, Client, Method, Request}; -use hyper_tls::HttpsConnector; use pest::iterators::Pair; use serde_json::Value; use std::collections::HashMap; -use std::io::prelude::*; use tokio::sync::mpsc::UnboundedSender; -use url::Url; -use urlencoding::encode; #[derive(Clone)] pub struct DataSource { @@ -88,94 +82,7 @@ impl DataSource { ) -> Result> { let data_source_project = match workspace_id { Some(workspace_id) => { - let dust_workspace_id = match env.credentials.get("DUST_WORKSPACE_ID") { - None => Err(anyhow!( - "DUST_WORKSPACE_ID credentials missing, but `workspace_id` \ - is set in `data_source` block config" - ))?, - Some(v) => v.clone(), - }; - let registry_secret = match std::env::var("DUST_REGISTRY_SECRET") { - Ok(key) => key, - Err(_) => Err(anyhow!( - "Environment variable `DUST_REGISTRY_SECRET` is not set." - ))?, - }; - let front_api = match std::env::var("DUST_FRONT_API") { - Ok(key) => key, - Err(_) => Err(anyhow!("Environment variable `DUST_FRONT_API` is not set."))?, - }; - - let url = format!( - "{}/api/registry/data_sources/lookup?workspace_id={}&data_source_id={}", - front_api.as_str(), - encode(&workspace_id), - encode(&data_source_id), - ); - let parsed_url = Url::parse(url.as_str())?; - - let mut req = Request::builder().method(Method::GET).uri(url.as_str()); - - { - let headers = match req.headers_mut() { - Some(h) => h, - None => Err(anyhow!("Invalid URL: {}", url.as_str()))?, - }; - headers.insert( - header::AUTHORIZATION, - header::HeaderValue::from_bytes( - format!("Bearer {}", registry_secret.as_str()).as_bytes(), - )?, - ); - headers.insert( - header::HeaderName::from_bytes("X-Dust-Workspace-Id".as_bytes())?, - header::HeaderValue::from_bytes(dust_workspace_id.as_bytes())?, - ); - } - let req = req.body(Body::empty())?; - - let res = match parsed_url.scheme() { - "https" => { - let https = HttpsConnector::new(); - let cli = Client::builder().build::<_, hyper::Body>(https); - cli.request(req).await? - } - "http" => { - let cli = Client::new(); - cli.request(req).await? - } - _ => Err(anyhow!( - "Only the `http` and `https` schemes are authorized." - ))?, - }; - - let status = res.status(); - if status != StatusCode::OK { - Err(anyhow!( - "Failed to retrieve DataSource `{} > {}`", - workspace_id, - data_source_id, - ))?; - } - - let body = hyper::body::aggregate(res).await?; - let mut b: Vec = vec![]; - body.reader().read_to_end(&mut b)?; - - let response_body = String::from_utf8_lossy(&b).into_owned(); - - let body = match serde_json::from_str::(&response_body) { - Ok(body) => body, - Err(_) => Err(anyhow!("Failed to parse registry response"))?, - }; - - match body.get("project_id") { - Some(Value::Number(p)) => match p.as_i64() { - Some(p) => Project::new_from_id(p), - None => Err(anyhow!("Failed to parse registry response"))?, - }, - _ => Err(anyhow!("Failed to parse registry response"))?, - } + get_data_source_project(&workspace_id, &data_source_id, env).await? } None => env.project.clone(), }; diff --git a/core/src/blocks/database_schema.rs b/core/src/blocks/database_schema.rs new file mode 100644 index 0000000000000..c9fc61cdb45bf --- /dev/null +++ b/core/src/blocks/database_schema.rs @@ -0,0 +1,126 @@ +use crate::blocks::block::{Block, BlockResult, BlockType, Env}; +use crate::Rule; +use anyhow::{anyhow, Result}; +use async_trait::async_trait; +use futures::future::try_join_all; +use pest::iterators::Pair; +use serde_json::Value; +use tokio::sync::mpsc::UnboundedSender; + +use super::helpers::get_data_source_project; + +#[derive(Clone)] +pub struct DatabaseSchema {} + +impl DatabaseSchema { + pub fn parse(_block_pair: Pair) -> Result { + Ok(DatabaseSchema {}) + } +} + +#[async_trait] +impl Block for DatabaseSchema { + fn block_type(&self) -> BlockType { + BlockType::DatabaseSchema + } + + fn inner_hash(&self) -> String { + let mut hasher = blake3::Hasher::new(); + hasher.update("database_schema".as_bytes()); + format!("{}", hasher.finalize().to_hex()) + } + + async fn execute( + &self, + name: &str, + env: &Env, + _event_sender: Option>, + ) -> Result { + let config = env.config.config_for_block(name); + + // TODO: finish error message + let err_msg = format!( + "Invalid or missing `databases` in configuration for \ + `database_schema` block `{}` expecting `{{ \"databases\": \ + [ {{ \"workspace_id\": ..., \"data_source_id\": ..., \"database_id\": ... }}, ... ] }}`", + name + ); + + let databases = match config { + Some(v) => match v.get("databases") { + Some(Value::Array(a)) => a + .iter() + .map(|v| { + let workspace_id = match v.get("workspace_id") { + Some(Value::String(s)) => s, + _ => Err(anyhow!(err_msg.clone()))?, + }; + let data_source_id = match v.get("data_source_id") { + Some(Value::String(s)) => s, + _ => Err(anyhow!(err_msg.clone()))?, + }; + let database_id = match v.get("database_id") { + Some(Value::String(s)) => s, + _ => Err(anyhow!(err_msg.clone()))?, + }; + + Ok((workspace_id, data_source_id, database_id)) + }) + .collect::>>(), + _ => Err(anyhow!(err_msg)), + }, + None => Err(anyhow!(err_msg)), + }?; + + let schemas = try_join_all(databases.iter().map( + |(workspace_id, data_source_id, database_id)| { + get_database_schema(workspace_id, data_source_id, database_id, env) + }, + )) + .await?; + + Ok(BlockResult { + value: serde_json::to_value(schemas)?, + meta: None, + }) + } + + fn clone_box(&self) -> Box { + Box::new(self.clone()) + } + + fn as_any(&self) -> &dyn std::any::Any { + self + } +} + +async fn get_database_schema( + workspace_id: &String, + data_source_id: &String, + database_id: &String, + env: &Env, +) -> Result { + let project = get_data_source_project(workspace_id, data_source_id, env).await?; + let database = match env + .store + .load_database(&project, data_source_id, database_id) + .await? + { + Some(d) => d, + None => Err(anyhow!( + "Database `{}` not found in data source `{}`", + database_id, + data_source_id + ))?, + }; + + match database.get_schema(&project, env.store.clone()).await { + Ok(s) => Ok(s), + Err(e) => Err(anyhow!( + "Error getting schema for database `{}` in data source `{}`: {}", + database_id, + data_source_id, + e + )), + } +} diff --git a/core/src/blocks/helpers.rs b/core/src/blocks/helpers.rs new file mode 100644 index 0000000000000..97983aafa6690 --- /dev/null +++ b/core/src/blocks/helpers.rs @@ -0,0 +1,105 @@ +use super::block::Env; +use crate::project::Project; +use anyhow::{anyhow, Result}; +use hyper::header; +use hyper::{body::Buf, http::StatusCode, Body, Client, Method, Request}; +use hyper_tls::HttpsConnector; +use serde_json::Value; +use std::io::prelude::*; +use url::Url; +use urlencoding::encode; + +pub async fn get_data_source_project( + workspace_id: &String, + data_source_id: &String, + env: &Env, +) -> Result { + let dust_workspace_id = match env.credentials.get("DUST_WORKSPACE_ID") { + None => Err(anyhow!( + "DUST_WORKSPACE_ID credentials missing, but `workspace_id` \ + is set in `data_source` block config" + ))?, + Some(v) => v.clone(), + }; + let registry_secret = match std::env::var("DUST_REGISTRY_SECRET") { + Ok(key) => key, + Err(_) => Err(anyhow!( + "Environment variable `DUST_REGISTRY_SECRET` is not set." + ))?, + }; + let front_api = match std::env::var("DUST_FRONT_API") { + Ok(key) => key, + Err(_) => Err(anyhow!("Environment variable `DUST_FRONT_API` is not set."))?, + }; + + let url = format!( + "{}/api/registry/data_sources/lookup?workspace_id={}&data_source_id={}", + front_api.as_str(), + encode(&workspace_id), + encode(&data_source_id), + ); + let parsed_url = Url::parse(url.as_str())?; + + let mut req = Request::builder().method(Method::GET).uri(url.as_str()); + + { + let headers = match req.headers_mut() { + Some(h) => h, + None => Err(anyhow!("Invalid URL: {}", url.as_str()))?, + }; + headers.insert( + header::AUTHORIZATION, + header::HeaderValue::from_bytes( + format!("Bearer {}", registry_secret.as_str()).as_bytes(), + )?, + ); + headers.insert( + header::HeaderName::from_bytes("X-Dust-Workspace-Id".as_bytes())?, + header::HeaderValue::from_bytes(dust_workspace_id.as_bytes())?, + ); + } + let req = req.body(Body::empty())?; + + let res = match parsed_url.scheme() { + "https" => { + let https = HttpsConnector::new(); + let cli = Client::builder().build::<_, hyper::Body>(https); + cli.request(req).await? + } + "http" => { + let cli = Client::new(); + cli.request(req).await? + } + _ => Err(anyhow!( + "Only the `http` and `https` schemes are authorized." + ))?, + }; + + let status = res.status(); + if status != StatusCode::OK { + Err(anyhow!( + "Failed to retrieve DataSource `{} > {}`", + workspace_id, + data_source_id, + ))?; + } + + let body = hyper::body::aggregate(res).await?; + let mut b: Vec = vec![]; + body.reader().read_to_end(&mut b)?; + + let response_body = String::from_utf8_lossy(&b).into_owned(); + + let body = match serde_json::from_str::(&response_body) { + Ok(body) => body, + Err(_) => Err(anyhow!("Failed to parse registry response"))?, + }; + + match body.get("project_id") { + Some(Value::Number(p)) => match p.as_i64() { + Some(p) => Ok(Project::new_from_id(p)), + None => Err(anyhow!("Failed to parse registry response")), + }, + _ => Err(anyhow!("Failed to parse registry response")), + } +} diff --git a/core/src/lib.rs b/core/src/lib.rs index f1a7d0987fe59..2d23f50cdb43f 100644 --- a/core/src/lib.rs +++ b/core/src/lib.rs @@ -48,7 +48,9 @@ pub mod blocks { pub mod curl; pub mod data; pub mod data_source; + pub mod database_schema; pub mod end; + pub mod helpers; pub mod input; pub mod llm; pub mod map; diff --git a/core/src/run.rs b/core/src/run.rs index e23d93bc3f37a..7ad888e5fb1b9 100644 --- a/core/src/run.rs +++ b/core/src/run.rs @@ -56,6 +56,7 @@ impl RunConfig { BlockType::Browser => 8, BlockType::While => 64, BlockType::End => 64, + BlockType::DatabaseSchema => 8, } } }