Skip to content

Commit 91a6d1a

Browse files
authored
feat: implement schema from rows (#2231)
1 parent 7d7891d commit 91a6d1a

File tree

2 files changed

+222
-0
lines changed

2 files changed

+222
-0
lines changed

core/src/databases/table_schema.rs

Lines changed: 219 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,219 @@
1+
use std::collections::HashMap;
2+
3+
use anyhow::{anyhow, Result};
4+
5+
use serde::{Deserialize, Serialize};
6+
use serde_json::Value;
7+
8+
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
9+
#[serde(rename_all = "lowercase")]
10+
pub enum TableSchemaFieldType {
11+
Int,
12+
Float,
13+
Text,
14+
Bool,
15+
}
16+
17+
/// Schema of a table: maps each column name to its `TableSchemaFieldType`.
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
18+
pub struct TableSchema(HashMap<String, TableSchemaFieldType>);
19+
20+
impl TableSchema {
21+
pub fn from_rows(rows: Vec<Value>) -> Result<Self> {
22+
let mut schema = HashMap::new();
23+
24+
for (row_index, row) in rows.iter().enumerate() {
25+
let object = row
26+
.as_object()
27+
.ok_or_else(|| anyhow!("Row {} is not an object", row_index))?;
28+
29+
for (k, v) in object {
30+
if v.is_null() {
31+
continue;
32+
}
33+
34+
let value_type = match v {
35+
Value::Bool(_) => TableSchemaFieldType::Bool,
36+
Value::Number(x) => {
37+
if x.is_i64() {
38+
TableSchemaFieldType::Int
39+
} else {
40+
TableSchemaFieldType::Float
41+
}
42+
}
43+
Value::String(_) | Value::Object(_) | Value::Array(_) => {
44+
TableSchemaFieldType::Text
45+
}
46+
_ => unreachable!(),
47+
};
48+
49+
if let Some(existing_type) = schema.get(k) {
50+
if existing_type != &value_type {
51+
return Err(anyhow!(
52+
"Field {} has conflicting types on row {}: {:?} and {:?}",
53+
k,
54+
row_index,
55+
existing_type,
56+
value_type
57+
));
58+
}
59+
} else {
60+
schema.insert(k.clone(), value_type);
61+
}
62+
}
63+
}
64+
65+
Ok(Self(schema))
66+
}
67+
68+
pub fn to_sql_string(&self, table_name: &str) -> String {
69+
let mut create_table = format!("CREATE TABLE {} (", table_name);
70+
71+
for (name, field_type) in &self.0 {
72+
let sql_type = match field_type {
73+
TableSchemaFieldType::Int => "INT",
74+
TableSchemaFieldType::Float => "REAL",
75+
TableSchemaFieldType::Text => "TEXT",
76+
TableSchemaFieldType::Bool => "BOOLEAN",
77+
};
78+
79+
create_table.push_str(&format!("{} {}, ", name, sql_type));
80+
}
81+
82+
// Remove the trailing comma and space, then close the parentheses.
83+
let len = create_table.len();
84+
create_table.truncate(len - 2);
85+
create_table.push_str(");");
86+
87+
create_table
88+
}
89+
}
90+
91+
#[cfg(test)]
mod tests {
    use super::*;
    use rusqlite::Connection;
    use serde_json::json;

    /// Builds a `TableSchema` from `(column name, column type)` pairs.
    fn schema_of(fields: &[(&str, TableSchemaFieldType)]) -> TableSchema {
        TableSchema(
            fields
                .iter()
                .map(|(field_id, field_type)| (field_id.to_string(), field_type.clone()))
                .collect(),
        )
    }

    #[test]
    fn test_table_schema_from_rows() -> Result<()> {
        let rows = vec![
            json!({
                "field1": 1,
                "field2": 1.2,
                "field3": "text",
                "field4": true,
                // Null on the first row: the type must come from a later row.
                "field5": null,
                "field6": ["array", "elements"],
                "field7": {"key": "value"}
            }),
            json!({
                "field1": 2,
                "field2": 2.4,
                "field3": "more text",
                "field4": false,
                "field5": "not null anymore",
                "field6": ["more", "elements"],
                "field7": {"anotherKey": "anotherValue"}
            }),
        ];

        let schema = TableSchema::from_rows(rows)?;
        let expected_schema = schema_of(&[
            ("field1", TableSchemaFieldType::Int),
            ("field2", TableSchemaFieldType::Float),
            ("field3", TableSchemaFieldType::Text),
            ("field4", TableSchemaFieldType::Bool),
            ("field5", TableSchemaFieldType::Text),
            ("field6", TableSchemaFieldType::Text),
            ("field7", TableSchemaFieldType::Text),
        ]);

        assert_eq!(schema, expected_schema);

        Ok(())
    }

    #[test]
    fn test_table_schema_from_rows_conflicting_types() {
        let rows = vec![
            json!({
                "field1": 1,
                "field2": 1.2,
                "field3": "text",
                "field4": true,
                "field6": ["array", "elements"],
                "field7": {"key": "value"}
            }),
            json!({
                "field1": 2,
                "field2": 2.4,
                "field3": "more text",
                "field4": "this was a bool before",
                "field5": "not null anymore",
                "field6": ["more", "elements"],
                "field7": {"anotherKey": "anotherValue"}
            }),
            json!({
                "field1": "now it's a text field",
            }),
        ];

        let schema = TableSchema::from_rows(rows);

        assert!(
            schema.is_err(),
            "Schema should have failed due to conflicting types."
        );
    }

    #[test]
    fn test_table_schema_from_non_object_row() {
        let rows = vec![json!({"field1": 1}), json!(["not", "an", "object"])];

        let schema = TableSchema::from_rows(rows);

        assert!(
            schema.is_err(),
            "Schema should have failed because a row is not an object."
        );
    }

    #[test]
    fn test_table_schema_from_empty_rows() -> Result<()> {
        let schema = TableSchema::from_rows(vec![])?;

        // No rows means no columns, not merely "no error".
        assert_eq!(schema, TableSchema(HashMap::new()));

        Ok(())
    }

    #[test]
    fn test_table_schema_to_string() -> Result<()> {
        let schema = schema_of(&[
            ("field1", TableSchemaFieldType::Int),
            ("field2", TableSchemaFieldType::Float),
            ("field3", TableSchemaFieldType::Text),
            ("field4", TableSchemaFieldType::Bool),
        ]);

        let sql = schema.to_sql_string("test_table");

        // Execute the generated DDL against a real engine and read the columns back.
        let conn = Connection::open_in_memory()?;
        conn.execute(&sql, [])?;

        let mut stmt = conn.prepare("PRAGMA table_info(test_table);")?;
        // Columns 1 and 2 of `table_info` are the column name and its declared type.
        let rows = stmt.query_map([], |row| Ok((row.get(1)?, row.get(2)?)))?;

        let mut actual_schema: HashMap<String, String> = HashMap::new();
        for row in rows {
            let (name, ty): (String, String) = row?;
            actual_schema.insert(name, ty);
        }

        assert_eq!(actual_schema.len(), 4);
        assert_eq!(actual_schema["field1"], "INT");
        assert_eq!(actual_schema["field2"], "REAL");
        assert_eq!(actual_schema["field3"], "TEXT");
        assert_eq!(actual_schema["field4"], "BOOLEAN");

        Ok(())
    }
}

core/src/lib.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,9 @@ pub mod data_sources {
1515
pub mod data_source;
1616
pub mod splitter;
1717
}
18+
/// Database-related modules.
pub mod databases {
19+
/// Table schema types and schema inference from JSON rows.
pub mod table_schema;
20+
}
1821
pub mod project;
1922
pub mod run;
2023
pub mod utils;

0 commit comments

Comments
 (0)