Skip to content

Commit 7300e78

Browse files
committed
feat: add core support for database_schema block
1 parent d888aa0 commit 7300e78

File tree

6 files changed

+240
-96
lines changed

6 files changed

+240
-96
lines changed

core/src/blocks/block.rs

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
use crate::blocks::{
22
browser::Browser, chat::Chat, code::Code, curl::Curl, data::Data, data_source::DataSource,
3-
end::End, input::Input, llm::LLM, map::Map, r#while::While, reduce::Reduce, search::Search,
3+
database_schema::DatabaseSchema, end::End, input::Input, llm::LLM, map::Map, r#while::While,
4+
reduce::Reduce, search::Search,
45
};
56
use crate::data_sources::qdrant::QdrantClients;
67
use crate::project::Project;
@@ -73,6 +74,7 @@ pub enum BlockType {
7374
Browser,
7475
While,
7576
End,
77+
DatabaseSchema,
7678
}
7779

7880
impl ToString for BlockType {
@@ -91,6 +93,7 @@ impl ToString for BlockType {
9193
BlockType::Browser => String::from("browser"),
9294
BlockType::While => String::from("while"),
9395
BlockType::End => String::from("end"),
96+
BlockType::DatabaseSchema => String::from("database_schema"),
9497
}
9598
}
9699
}
@@ -192,6 +195,7 @@ pub fn parse_block(t: BlockType, block_pair: Pair<Rule>) -> Result<Box<dyn Block
192195
BlockType::Browser => Ok(Box::new(Browser::parse(block_pair)?)),
193196
BlockType::While => Ok(Box::new(While::parse(block_pair)?)),
194197
BlockType::End => Ok(Box::new(End::parse(block_pair)?)),
198+
BlockType::DatabaseSchema => Ok(Box::new(DatabaseSchema::parse(block_pair)?)),
195199
}
196200
}
197201

core/src/blocks/data_source.rs

Lines changed: 2 additions & 95 deletions
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,15 @@
11
use crate::blocks::block::{
22
parse_pair, replace_variables_in_string, Block, BlockResult, BlockType, Env,
33
};
4+
use crate::blocks::helpers::get_data_source_project;
45
use crate::data_sources::data_source::{Document, SearchFilter};
5-
use crate::project::Project;
66
use crate::Rule;
77
use anyhow::{anyhow, Result};
88
use async_trait::async_trait;
9-
use hyper::header;
10-
use hyper::{body::Buf, http::StatusCode, Body, Client, Method, Request};
11-
use hyper_tls::HttpsConnector;
129
use pest::iterators::Pair;
1310
use serde_json::Value;
1411
use std::collections::HashMap;
15-
use std::io::prelude::*;
1612
use tokio::sync::mpsc::UnboundedSender;
17-
use url::Url;
18-
use urlencoding::encode;
1913

2014
#[derive(Clone)]
2115
pub struct DataSource {
@@ -88,94 +82,7 @@ impl DataSource {
8882
) -> Result<Vec<Document>> {
8983
let data_source_project = match workspace_id {
9084
Some(workspace_id) => {
91-
let dust_workspace_id = match env.credentials.get("DUST_WORKSPACE_ID") {
92-
None => Err(anyhow!(
93-
"DUST_WORKSPACE_ID credentials missing, but `workspace_id` \
94-
is set in `data_source` block config"
95-
))?,
96-
Some(v) => v.clone(),
97-
};
98-
let registry_secret = match std::env::var("DUST_REGISTRY_SECRET") {
99-
Ok(key) => key,
100-
Err(_) => Err(anyhow!(
101-
"Environment variable `DUST_REGISTRY_SECRET` is not set."
102-
))?,
103-
};
104-
let front_api = match std::env::var("DUST_FRONT_API") {
105-
Ok(key) => key,
106-
Err(_) => Err(anyhow!("Environment variable `DUST_FRONT_API` is not set."))?,
107-
};
108-
109-
let url = format!(
110-
"{}/api/registry/data_sources/lookup?workspace_id={}&data_source_id={}",
111-
front_api.as_str(),
112-
encode(&workspace_id),
113-
encode(&data_source_id),
114-
);
115-
let parsed_url = Url::parse(url.as_str())?;
116-
117-
let mut req = Request::builder().method(Method::GET).uri(url.as_str());
118-
119-
{
120-
let headers = match req.headers_mut() {
121-
Some(h) => h,
122-
None => Err(anyhow!("Invalid URL: {}", url.as_str()))?,
123-
};
124-
headers.insert(
125-
header::AUTHORIZATION,
126-
header::HeaderValue::from_bytes(
127-
format!("Bearer {}", registry_secret.as_str()).as_bytes(),
128-
)?,
129-
);
130-
headers.insert(
131-
header::HeaderName::from_bytes("X-Dust-Workspace-Id".as_bytes())?,
132-
header::HeaderValue::from_bytes(dust_workspace_id.as_bytes())?,
133-
);
134-
}
135-
let req = req.body(Body::empty())?;
136-
137-
let res = match parsed_url.scheme() {
138-
"https" => {
139-
let https = HttpsConnector::new();
140-
let cli = Client::builder().build::<_, hyper::Body>(https);
141-
cli.request(req).await?
142-
}
143-
"http" => {
144-
let cli = Client::new();
145-
cli.request(req).await?
146-
}
147-
_ => Err(anyhow!(
148-
"Only the `http` and `https` schemes are authorized."
149-
))?,
150-
};
151-
152-
let status = res.status();
153-
if status != StatusCode::OK {
154-
Err(anyhow!(
155-
"Failed to retrieve DataSource `{} > {}`",
156-
workspace_id,
157-
data_source_id,
158-
))?;
159-
}
160-
161-
let body = hyper::body::aggregate(res).await?;
162-
let mut b: Vec<u8> = vec![];
163-
body.reader().read_to_end(&mut b)?;
164-
165-
let response_body = String::from_utf8_lossy(&b).into_owned();
166-
167-
let body = match serde_json::from_str::<serde_json::Value>(&response_body) {
168-
Ok(body) => body,
169-
Err(_) => Err(anyhow!("Failed to parse registry response"))?,
170-
};
171-
172-
match body.get("project_id") {
173-
Some(Value::Number(p)) => match p.as_i64() {
174-
Some(p) => Project::new_from_id(p),
175-
None => Err(anyhow!("Failed to parse registry response"))?,
176-
},
177-
_ => Err(anyhow!("Failed to parse registry response"))?,
178-
}
85+
get_data_source_project(&workspace_id, &data_source_id, env).await?
17986
}
18087
None => env.project.clone(),
18188
};

core/src/blocks/database_schema.rs

Lines changed: 125 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,125 @@
1+
use crate::blocks::block::{Block, BlockResult, BlockType, Env};
2+
use crate::Rule;
3+
use anyhow::{anyhow, Result};
4+
use async_trait::async_trait;
5+
use futures::future::try_join_all;
6+
use pest::iterators::Pair;
7+
use serde_json::Value;
8+
use tokio::sync::mpsc::UnboundedSender;
9+
10+
use super::helpers::get_data_source_project;
11+
12+
#[derive(Clone)]
13+
pub struct DatabaseSchema {}
14+
15+
impl DatabaseSchema {
16+
pub fn parse(_block_pair: Pair<Rule>) -> Result<Self> {
17+
Ok(DatabaseSchema {})
18+
}
19+
}
20+
21+
#[async_trait]
22+
impl Block for DatabaseSchema {
23+
fn block_type(&self) -> BlockType {
24+
BlockType::DatabaseSchema
25+
}
26+
27+
fn inner_hash(&self) -> String {
28+
let mut hasher = blake3::Hasher::new();
29+
hasher.update("database_schema".as_bytes());
30+
format!("{}", hasher.finalize().to_hex())
31+
}
32+
33+
async fn execute(
34+
&self,
35+
name: &str,
36+
env: &Env,
37+
_event_sender: Option<UnboundedSender<Value>>,
38+
) -> Result<BlockResult> {
39+
let config = env.config.config_for_block(name);
40+
41+
let err_msg = format!(
42+
"Invalid or missing `databases` in configuration for \
43+
`database_schema` block `{}` expecting `{{ \"databases\": \
44+
[ {{ \"workspace_id\": ..., \"data_source_id\": ..., \"database_id\": ... }}, ... ] }}`",
45+
name
46+
);
47+
48+
let databases = match config {
49+
Some(v) => match v.get("databases") {
50+
Some(Value::Array(a)) => a
51+
.iter()
52+
.map(|v| {
53+
let workspace_id = match v.get("workspace_id") {
54+
Some(Value::String(s)) => s,
55+
_ => Err(anyhow!(err_msg.clone()))?,
56+
};
57+
let data_source_id = match v.get("data_source_id") {
58+
Some(Value::String(s)) => s,
59+
_ => Err(anyhow!(err_msg.clone()))?,
60+
};
61+
let database_id = match v.get("database_id") {
62+
Some(Value::String(s)) => s,
63+
_ => Err(anyhow!(err_msg.clone()))?,
64+
};
65+
66+
Ok((workspace_id, data_source_id, database_id))
67+
})
68+
.collect::<Result<Vec<_>>>(),
69+
_ => Err(anyhow!(err_msg)),
70+
},
71+
None => Err(anyhow!(err_msg)),
72+
}?;
73+
74+
let schemas = try_join_all(databases.iter().map(
75+
|(workspace_id, data_source_id, database_id)| {
76+
get_database_schema(workspace_id, data_source_id, database_id, env)
77+
},
78+
))
79+
.await?;
80+
81+
Ok(BlockResult {
82+
value: serde_json::to_value(schemas)?,
83+
meta: None,
84+
})
85+
}
86+
87+
fn clone_box(&self) -> Box<dyn Block + Sync + Send> {
88+
Box::new(self.clone())
89+
}
90+
91+
fn as_any(&self) -> &dyn std::any::Any {
92+
self
93+
}
94+
}
95+
96+
async fn get_database_schema(
97+
workspace_id: &String,
98+
data_source_id: &String,
99+
database_id: &String,
100+
env: &Env,
101+
) -> Result<crate::databases::database::DatabaseSchema> {
102+
let project = get_data_source_project(workspace_id, data_source_id, env).await?;
103+
let database = match env
104+
.store
105+
.load_database(&project, data_source_id, database_id)
106+
.await?
107+
{
108+
Some(d) => d,
109+
None => Err(anyhow!(
110+
"Database `{}` not found in data source `{}`",
111+
database_id,
112+
data_source_id
113+
))?,
114+
};
115+
116+
match database.get_schema(&project, env.store.clone()).await {
117+
Ok(s) => Ok(s),
118+
Err(e) => Err(anyhow!(
119+
"Error getting schema for database `{}` in data source `{}`: {}",
120+
database_id,
121+
data_source_id,
122+
e
123+
)),
124+
}
125+
}

core/src/blocks/helpers.rs

Lines changed: 105 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,105 @@
1+
use super::block::Env;
2+
use crate::project::Project;
3+
use anyhow::{anyhow, Result};
4+
use hyper::header;
5+
use hyper::{body::Buf, http::StatusCode, Body, Client, Method, Request};
6+
use hyper_tls::HttpsConnector;
7+
use serde_json::Value;
8+
use std::io::prelude::*;
9+
use url::Url;
10+
use urlencoding::encode;
11+
12+
pub async fn get_data_source_project(
13+
workspace_id: &String,
14+
data_source_id: &String,
15+
env: &Env,
16+
) -> Result<Project> {
17+
let dust_workspace_id = match env.credentials.get("DUST_WORKSPACE_ID") {
18+
None => Err(anyhow!(
19+
"DUST_WORKSPACE_ID credentials missing, but `workspace_id` \
20+
is set in `data_source` block config"
21+
))?,
22+
Some(v) => v.clone(),
23+
};
24+
let registry_secret = match std::env::var("DUST_REGISTRY_SECRET") {
25+
Ok(key) => key,
26+
Err(_) => Err(anyhow!(
27+
"Environment variable `DUST_REGISTRY_SECRET` is not set."
28+
))?,
29+
};
30+
let front_api = match std::env::var("DUST_FRONT_API") {
31+
Ok(key) => key,
32+
Err(_) => Err(anyhow!("Environment variable `DUST_FRONT_API` is not set."))?,
33+
};
34+
35+
let url = format!(
36+
"{}/api/registry/data_sources/lookup?workspace_id={}&data_source_id={}",
37+
front_api.as_str(),
38+
encode(&workspace_id),
39+
encode(&data_source_id),
40+
);
41+
let parsed_url = Url::parse(url.as_str())?;
42+
43+
let mut req = Request::builder().method(Method::GET).uri(url.as_str());
44+
45+
{
46+
let headers = match req.headers_mut() {
47+
Some(h) => h,
48+
None => Err(anyhow!("Invalid URL: {}", url.as_str()))?,
49+
};
50+
headers.insert(
51+
header::AUTHORIZATION,
52+
header::HeaderValue::from_bytes(
53+
format!("Bearer {}", registry_secret.as_str()).as_bytes(),
54+
)?,
55+
);
56+
headers.insert(
57+
header::HeaderName::from_bytes("X-Dust-Workspace-Id".as_bytes())?,
58+
header::HeaderValue::from_bytes(dust_workspace_id.as_bytes())?,
59+
);
60+
}
61+
let req = req.body(Body::empty())?;
62+
63+
let res = match parsed_url.scheme() {
64+
"https" => {
65+
let https = HttpsConnector::new();
66+
let cli = Client::builder().build::<_, hyper::Body>(https);
67+
cli.request(req).await?
68+
}
69+
"http" => {
70+
let cli = Client::new();
71+
cli.request(req).await?
72+
}
73+
_ => Err(anyhow!(
74+
"Only the `http` and `https` schemes are authorized."
75+
))?,
76+
};
77+
78+
let status = res.status();
79+
if status != StatusCode::OK {
80+
Err(anyhow!(
81+
"Failed to retrieve DataSource `{} > {}`",
82+
workspace_id,
83+
data_source_id,
84+
))?;
85+
}
86+
87+
let body = hyper::body::aggregate(res).await?;
88+
let mut b: Vec<u8> = vec![];
89+
body.reader().read_to_end(&mut b)?;
90+
91+
let response_body = String::from_utf8_lossy(&b).into_owned();
92+
93+
let body = match serde_json::from_str::<serde_json::Value>(&response_body) {
94+
Ok(body) => body,
95+
Err(_) => Err(anyhow!("Failed to parse registry response"))?,
96+
};
97+
98+
match body.get("project_id") {
99+
Some(Value::Number(p)) => match p.as_i64() {
100+
Some(p) => Ok(Project::new_from_id(p)),
101+
None => Err(anyhow!("Failed to parse registry response")),
102+
},
103+
_ => Err(anyhow!("Failed to parse registry response")),
104+
}
105+
}

core/src/lib.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,9 @@ pub mod blocks {
4848
pub mod curl;
4949
pub mod data;
5050
pub mod data_source;
51+
pub mod database_schema;
5152
pub mod end;
53+
pub mod helpers;
5254
pub mod input;
5355
pub mod llm;
5456
pub mod map;

0 commit comments

Comments
 (0)