Skip to content

Commit

Permalink
version bump to 0.35; add extract relations
Browse files Browse the repository at this point in the history
  • Loading branch information
wseaton committed Jul 3, 2023
1 parent 1c23624 commit edb19fc
Show file tree
Hide file tree
Showing 9 changed files with 334 additions and 290 deletions.
2 changes: 1 addition & 1 deletion .bumpversion.cfg
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
[bumpversion]
current_version = 0.1.33
current_version = 0.1.35
commit = True
tag = True

Expand Down
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
# 0.1.35

- Added `extract_relations` function to assist in extracting table references from the AST in Rust.
10 changes: 5 additions & 5 deletions Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[package]
name = "sqloxide"
version = "0.1.33"
version = "0.1.35"
authors = ["Will Eaton <me@wseaton.com>"]
edition = "2018"

Expand All @@ -9,12 +9,12 @@ name = "sqloxide"
crate-type = ["cdylib"]

[dependencies]
pythonize = "0.17"
pythonize = "0.19"

[dependencies.pyo3]
version = "0.17.1"
version = "0.19.0"
features = ["extension-module"]

[dependencies.sqlparser]
version = "0.33.0"
features = ["json_example"]
version = "0.35.0"
features = ["serde", "visitor"]
527 changes: 253 additions & 274 deletions poetry.lock

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "sqloxide"
version = "0.1.33"
version = "0.1.35"
repository = "https://github.com/wseaton/sqloxide"
license = "MIT"
description = "Python bindings for sqlparser-rs"
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

setup_kwargs = {
"name": "sqloxide",
"version": "0.1.33",
"version": "0.1.35",
"description": "Python bindings for sqlparser-rs",
"long_description": open("readme.md").read(),
"long_description_content_type": "text/markdown",
Expand Down
4 changes: 2 additions & 2 deletions sqloxide/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
from .sqloxide import parse_sql
from .sqloxide import parse_sql, extract_relations

__all__ = ["parse_sql"]
__all__ = ["parse_sql", "extract_relations"]
50 changes: 49 additions & 1 deletion src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
use pyo3::exceptions::PyValueError;
use pyo3::prelude::*;
use pyo3::wrap_pyfunction;
use pythonize::PythonizeError;
use sqlparser::ast::Statement;

use core::ops::ControlFlow;
use pythonize::pythonize;

use sqlparser::ast::visit_relations;
use sqlparser::dialect::*;
use sqlparser::parser::Parser;

Expand Down Expand Up @@ -65,8 +68,53 @@ fn parse_sql(py: Python, sql: &str, dialect: &str) -> PyResult<PyObject> {
Ok(output)
}

///
/// Function to extract relations from a parsed query.
/// Returns a nested list of relations, one list per query statement.
///
/// Example:
/// ```python
/// from sqloxide import parse_sql, extract_relations
///
/// sql = "SELECT * FROM table1 JOIN table2 ON table1.id = table2.id"
/// parsed_query = parse_sql(sql, "generic")
/// relations = extract_relations(parsed_query)
/// print(relations)
/// ```
///
#[pyfunction]
#[pyo3(text_signature = "(parsed_query)")]
fn extract_relations(py: Python, parsed_query: &PyAny) -> PyResult<PyObject> {
let parse_result: Result<Vec<Statement>, PythonizeError> = pythonize::depythonize(parsed_query);

let mut relations = Vec::new();

match parse_result {
Ok(statements) => {
for statement in statements {
visit_relations(&statement, |relation| {
relations.push(relation.clone());
ControlFlow::<()>::Continue(())
});
}
}
Err(_e) => {
let msg = _e.to_string();
return Err(PyValueError::new_err(format!(
"Query serialization failed.\n\t{}",
msg
)));
}
};

let output = pythonize(py, &relations).expect("Internal python deserialization failed.");

Ok(output)
}

#[pymodule]
fn sqloxide(_py: Python, m: &PyModule) -> PyResult<()> {
m.add_function(wrap_pyfunction!(parse_sql, m)?)?;
m.add_function(wrap_pyfunction!(extract_relations, m)?)?;
Ok(())
}
24 changes: 19 additions & 5 deletions tests/test_sqloxide.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
import pytest
from sqloxide import parse_sql
from sqloxide import parse_sql, extract_relations


def test_parse_sql():

sql = """
SELECT employee.first_name, employee.last_name,
call.start_time, call.end_time, call_outcome.outcome_text
Expand All @@ -23,9 +22,24 @@ def test_parse_sql():


def test_throw_exception():

sql = """
SELECT $# as 1;
"""
with pytest.raises(ValueError, match=r"Query parsing failed.\n\tsql parser error: .+"):
ast = parse_sql(sql=sql, dialect="ansi")[0]
with pytest.raises(
ValueError, match=r"Query parsing failed.\n\tsql parser error: .+"
):
_ast = parse_sql(sql=sql, dialect="ansi")[0]


def test_extract_relations():
sql = """
SELECT employee.first_name, employee.last_name,
call.start_time, call.end_time, call_outcome.outcome_text
FROM employee
INNER JOIN call ON call.employee_id = employee.id
INNER JOIN call_outcome ON call.call_outcome_id = call_outcome.id
ORDER BY call.start_time ASC;
"""

ast = parse_sql(sql=sql, dialect="ansi")
print(extract_relations(parsed_query=ast))

0 comments on commit edb19fc

Please sign in to comment.