Skip to content

Commit

Permalink
feat: add core support for database_schema block
Browse files Browse the repository at this point in the history
  • Loading branch information
fontanierh committed Nov 17, 2023
1 parent d888aa0 commit 9adbdc4
Show file tree
Hide file tree
Showing 6 changed files with 241 additions and 96 deletions.
6 changes: 5 additions & 1 deletion core/src/blocks/block.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use crate::blocks::{
browser::Browser, chat::Chat, code::Code, curl::Curl, data::Data, data_source::DataSource,
end::End, input::Input, llm::LLM, map::Map, r#while::While, reduce::Reduce, search::Search,
database_schema::DatabaseSchema, end::End, input::Input, llm::LLM, map::Map, r#while::While,
reduce::Reduce, search::Search,
};
use crate::data_sources::qdrant::QdrantClients;
use crate::project::Project;
Expand Down Expand Up @@ -73,6 +74,7 @@ pub enum BlockType {
Browser,
While,
End,
DatabaseSchema,
}

impl ToString for BlockType {
Expand All @@ -91,6 +93,7 @@ impl ToString for BlockType {
BlockType::Browser => String::from("browser"),
BlockType::While => String::from("while"),
BlockType::End => String::from("end"),
BlockType::DatabaseSchema => String::from("database_schema"),
}
}
}
Expand Down Expand Up @@ -192,6 +195,7 @@ pub fn parse_block(t: BlockType, block_pair: Pair<Rule>) -> Result<Box<dyn Block
BlockType::Browser => Ok(Box::new(Browser::parse(block_pair)?)),
BlockType::While => Ok(Box::new(While::parse(block_pair)?)),
BlockType::End => Ok(Box::new(End::parse(block_pair)?)),
BlockType::DatabaseSchema => Ok(Box::new(DatabaseSchema::parse(block_pair)?)),
}
}

Expand Down
97 changes: 2 additions & 95 deletions core/src/blocks/data_source.rs
Original file line number Diff line number Diff line change
@@ -1,21 +1,15 @@
use crate::blocks::block::{
parse_pair, replace_variables_in_string, Block, BlockResult, BlockType, Env,
};
use crate::blocks::helpers::get_data_source_project;
use crate::data_sources::data_source::{Document, SearchFilter};
use crate::project::Project;
use crate::Rule;
use anyhow::{anyhow, Result};
use async_trait::async_trait;
use hyper::header;
use hyper::{body::Buf, http::StatusCode, Body, Client, Method, Request};
use hyper_tls::HttpsConnector;
use pest::iterators::Pair;
use serde_json::Value;
use std::collections::HashMap;
use std::io::prelude::*;
use tokio::sync::mpsc::UnboundedSender;
use url::Url;
use urlencoding::encode;

#[derive(Clone)]
pub struct DataSource {
Expand Down Expand Up @@ -88,94 +82,7 @@ impl DataSource {
) -> Result<Vec<Document>> {
let data_source_project = match workspace_id {
Some(workspace_id) => {
let dust_workspace_id = match env.credentials.get("DUST_WORKSPACE_ID") {
None => Err(anyhow!(
"DUST_WORKSPACE_ID credentials missing, but `workspace_id` \
is set in `data_source` block config"
))?,
Some(v) => v.clone(),
};
let registry_secret = match std::env::var("DUST_REGISTRY_SECRET") {
Ok(key) => key,
Err(_) => Err(anyhow!(
"Environment variable `DUST_REGISTRY_SECRET` is not set."
))?,
};
let front_api = match std::env::var("DUST_FRONT_API") {
Ok(key) => key,
Err(_) => Err(anyhow!("Environment variable `DUST_FRONT_API` is not set."))?,
};

let url = format!(
"{}/api/registry/data_sources/lookup?workspace_id={}&data_source_id={}",
front_api.as_str(),
encode(&workspace_id),
encode(&data_source_id),
);
let parsed_url = Url::parse(url.as_str())?;

let mut req = Request::builder().method(Method::GET).uri(url.as_str());

{
let headers = match req.headers_mut() {
Some(h) => h,
None => Err(anyhow!("Invalid URL: {}", url.as_str()))?,
};
headers.insert(
header::AUTHORIZATION,
header::HeaderValue::from_bytes(
format!("Bearer {}", registry_secret.as_str()).as_bytes(),
)?,
);
headers.insert(
header::HeaderName::from_bytes("X-Dust-Workspace-Id".as_bytes())?,
header::HeaderValue::from_bytes(dust_workspace_id.as_bytes())?,
);
}
let req = req.body(Body::empty())?;

let res = match parsed_url.scheme() {
"https" => {
let https = HttpsConnector::new();
let cli = Client::builder().build::<_, hyper::Body>(https);
cli.request(req).await?
}
"http" => {
let cli = Client::new();
cli.request(req).await?
}
_ => Err(anyhow!(
"Only the `http` and `https` schemes are authorized."
))?,
};

let status = res.status();
if status != StatusCode::OK {
Err(anyhow!(
"Failed to retrieve DataSource `{} > {}`",
workspace_id,
data_source_id,
))?;
}

let body = hyper::body::aggregate(res).await?;
let mut b: Vec<u8> = vec![];
body.reader().read_to_end(&mut b)?;

let response_body = String::from_utf8_lossy(&b).into_owned();

let body = match serde_json::from_str::<serde_json::Value>(&response_body) {
Ok(body) => body,
Err(_) => Err(anyhow!("Failed to parse registry response"))?,
};

match body.get("project_id") {
Some(Value::Number(p)) => match p.as_i64() {
Some(p) => Project::new_from_id(p),
None => Err(anyhow!("Failed to parse registry response"))?,
},
_ => Err(anyhow!("Failed to parse registry response"))?,
}
get_data_source_project(&workspace_id, &data_source_id, env).await?
}
None => env.project.clone(),
};
Expand Down
126 changes: 126 additions & 0 deletions core/src/blocks/database_schema.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,126 @@
use crate::blocks::block::{Block, BlockResult, BlockType, Env};
use crate::Rule;
use anyhow::{anyhow, Result};
use async_trait::async_trait;
use futures::future::try_join_all;
use pest::iterators::Pair;
use serde_json::Value;
use tokio::sync::mpsc::UnboundedSender;

use super::helpers::get_data_source_project;

/// Block that returns the schemas of the databases listed in its configuration.
///
/// The block itself is stateless: the list of databases is read from the run
/// configuration when the block executes, not from the block definition.
#[derive(Clone)]
pub struct DatabaseSchema {}

impl DatabaseSchema {
    /// Builds the block from its parsed definition.
    ///
    /// The parsed pair is ignored, so this never returns an error.
    pub fn parse(_block_pair: Pair<Rule>) -> Result<Self> {
        Ok(Self {})
    }
}

#[async_trait]
impl Block for DatabaseSchema {
    /// Reports this block as a `database_schema` block.
    fn block_type(&self) -> BlockType {
        BlockType::DatabaseSchema
    }

    /// Hash of the block definition.
    ///
    /// The block has no fields, so only the block kind is fed to the hasher.
    fn inner_hash(&self) -> String {
        let mut hasher = blake3::Hasher::new();
        hasher.update(b"database_schema");
        hasher.finalize().to_hex().to_string()
    }

    /// Fetches the schema of every database listed under `databases` in the
    /// configuration of the block called `name`.
    ///
    /// Each entry must be an object with string `workspace_id`, `data_source_id`
    /// and `database_id` fields. The lookups are joined with `try_join_all` and
    /// the collected schemas are serialized as the block's value.
    ///
    /// # Errors
    ///
    /// Fails when the block has no configuration, when `databases` is missing or
    /// not an array, when an entry lacks one of the three string fields, when any
    /// schema lookup fails, or when the schemas cannot be serialized.
    async fn execute(
        &self,
        name: &str,
        env: &Env,
        _event_sender: Option<UnboundedSender<Value>>,
    ) -> Result<BlockResult> {
        // Every configuration problem is reported with this same message.
        let err_msg = format!(
            "Invalid or missing `databases` in configuration for \
             `database_schema` block `{}` expecting `{{ \"databases\": \
             [ {{ \"workspace_id\": ..., \"data_source_id\": ..., \"database_id\": ... }}, ... ] }}`",
            name
        );

        let entries = match env
            .config
            .config_for_block(name)
            .and_then(|block_config| block_config.get("databases"))
        {
            Some(Value::Array(entries)) => entries,
            _ => return Err(anyhow!(err_msg)),
        };

        // Borrow the three identifiers of each entry; the first malformed entry
        // aborts the whole collection.
        let databases = entries
            .iter()
            .map(|entry| {
                match (
                    entry.get("workspace_id"),
                    entry.get("data_source_id"),
                    entry.get("database_id"),
                ) {
                    (
                        Some(Value::String(workspace_id)),
                        Some(Value::String(data_source_id)),
                        Some(Value::String(database_id)),
                    ) => Ok((workspace_id, data_source_id, database_id)),
                    _ => Err(anyhow!(err_msg.clone())),
                }
            })
            .collect::<Result<Vec<_>>>()?;

        let schemas = try_join_all(databases.iter().map(
            |(workspace_id, data_source_id, database_id)| {
                get_database_schema(workspace_id, data_source_id, database_id, env)
            },
        ))
        .await?;

        let value = serde_json::to_value(schemas)?;
        Ok(BlockResult { value, meta: None })
    }

    /// Returns a boxed copy of this block.
    fn clone_box(&self) -> Box<dyn Block + Sync + Send> {
        Box::new(self.clone())
    }

    /// Exposes the block as `Any` so callers can downcast it.
    fn as_any(&self) -> &dyn std::any::Any {
        self
    }
}

/// Fetches the schema of a single database.
///
/// Resolves the project behind `data_source_id` in `workspace_id`, loads the
/// database `database_id` from the store, then asks that database for its schema.
///
/// # Errors
///
/// Fails when the project lookup fails, when the store lookup fails or finds no
/// such database, or when the schema itself cannot be obtained.
async fn get_database_schema(
    workspace_id: &String,
    data_source_id: &String,
    database_id: &String,
    env: &Env,
) -> Result<crate::databases::database::DatabaseSchema> {
    let project = get_data_source_project(workspace_id, data_source_id, env).await?;

    let database = env
        .store
        .load_database(&project, data_source_id, database_id)
        .await?
        .ok_or_else(|| {
            anyhow!(
                "Database `{}` not found in data source `{}`",
                database_id,
                data_source_id
            )
        })?;

    // Re-wrap the underlying error so the message says which database failed.
    let schema = database.get_schema(&project, env.store.clone()).await;
    schema.map_err(|e| {
        anyhow!(
            "Error getting schema for database `{}` in data source `{}`: {}",
            database_id,
            data_source_id,
            e
        )
    })
}
105 changes: 105 additions & 0 deletions core/src/blocks/helpers.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
use super::block::Env;
use crate::project::Project;
use anyhow::{anyhow, Result};
use hyper::header;
use hyper::{body::Buf, http::StatusCode, Body, Client, Method, Request};
use hyper_tls::HttpsConnector;
use serde_json::Value;
use std::io::prelude::*;
use url::Url;
use urlencoding::encode;

/// Resolves the project that backs a data source by querying the front registry.
///
/// Sends `GET {DUST_FRONT_API}/api/registry/data_sources/lookup` with the
/// URL-encoded `workspace_id` and `data_source_id` as query parameters. The
/// request carries `DUST_REGISTRY_SECRET` as a bearer token and the caller's
/// `DUST_WORKSPACE_ID` credential in the `X-Dust-Workspace-Id` header. The numeric
/// `project_id` field of the JSON response is turned into a [`Project`].
///
/// # Errors
///
/// Fails when:
/// - the `DUST_WORKSPACE_ID` credential is missing from `env`,
/// - the `DUST_REGISTRY_SECRET` or `DUST_FRONT_API` environment variable is unset,
/// - the resulting URL is invalid or uses a scheme other than `http`/`https`,
/// - the request fails or the registry answers with a status other than 200,
/// - the response is not JSON with an integer `project_id`.
pub async fn get_data_source_project(
    workspace_id: &str,
    data_source_id: &str,
    env: &Env,
) -> Result<Project> {
    // This helper is shared by several block types, so the message does not name one.
    let dust_workspace_id = env.credentials.get("DUST_WORKSPACE_ID").ok_or_else(|| {
        anyhow!(
            "DUST_WORKSPACE_ID credentials missing, but `workspace_id` \
             is set in block config"
        )
    })?;
    let registry_secret = std::env::var("DUST_REGISTRY_SECRET")
        .map_err(|_| anyhow!("Environment variable `DUST_REGISTRY_SECRET` is not set."))?;
    let front_api = std::env::var("DUST_FRONT_API")
        .map_err(|_| anyhow!("Environment variable `DUST_FRONT_API` is not set."))?;

    let url = format!(
        "{}/api/registry/data_sources/lookup?workspace_id={}&data_source_id={}",
        front_api.as_str(),
        encode(workspace_id),
        encode(data_source_id),
    );
    let parsed_url = Url::parse(url.as_str())?;

    let mut req = Request::builder().method(Method::GET).uri(url.as_str());

    {
        // `headers_mut` returns `None` when the builder already holds an error,
        // which at this point can only come from the URI.
        let headers = req
            .headers_mut()
            .ok_or_else(|| anyhow!("Invalid URL: {}", url.as_str()))?;
        headers.insert(
            header::AUTHORIZATION,
            header::HeaderValue::from_bytes(
                format!("Bearer {}", registry_secret.as_str()).as_bytes(),
            )?,
        );
        headers.insert(
            header::HeaderName::from_bytes(b"X-Dust-Workspace-Id")?,
            header::HeaderValue::from_bytes(dust_workspace_id.as_bytes())?,
        );
    }
    let req = req.body(Body::empty())?;

    // The plain client cannot speak TLS, so pick the connector from the scheme.
    let res = match parsed_url.scheme() {
        "https" => {
            let https = HttpsConnector::new();
            let cli = Client::builder().build::<_, hyper::Body>(https);
            cli.request(req).await?
        }
        "http" => {
            let cli = Client::new();
            cli.request(req).await?
        }
        _ => {
            return Err(anyhow!(
                "Only the `http` and `https` schemes are authorized."
            ))
        }
    };

    let status = res.status();
    if status != StatusCode::OK {
        return Err(anyhow!(
            "Failed to retrieve DataSource `{} > {}` (registry returned status {})",
            workspace_id,
            data_source_id,
            status,
        ));
    }

    let body = hyper::body::aggregate(res).await?;
    let mut b: Vec<u8> = vec![];
    body.reader().read_to_end(&mut b)?;

    let response_body = String::from_utf8_lossy(&b).into_owned();

    let body = serde_json::from_str::<serde_json::Value>(&response_body)
        .map_err(|_| anyhow!("Failed to parse registry response"))?;

    // `as_i64` yields `None` for anything that is not an integer fitting in `i64`.
    body.get("project_id")
        .and_then(Value::as_i64)
        .map(Project::new_from_id)
        .ok_or_else(|| anyhow!("Failed to parse registry response"))
}
2 changes: 2 additions & 0 deletions core/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,9 @@ pub mod blocks {
pub mod curl;
pub mod data;
pub mod data_source;
pub mod database_schema;
pub mod end;
pub mod helpers;
pub mod input;
pub mod llm;
pub mod map;
Expand Down
Loading

0 comments on commit 9adbdc4

Please sign in to comment.