Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix dialect selector #391

Merged
merged 2 commits into from
Sep 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/app/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def inject_vars():
"""Inject arbitrary data into all templates."""
return dict(
all_rules=config.VALID_RULES,
all_dialects=config.VALID_DIALECTS,
all_dialects=list(config.VALID_DIALECTS.values()),
sqlfluff_version=config.SQLFLUFF_VERSION,
)

Expand Down
2 changes: 1 addition & 1 deletion src/app/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

SQLFLUFF_VERSION = sqlfluff.__version__

VALID_DIALECTS = tuple(d.name for d in sqlfluff.list_dialects())
VALID_DIALECTS = {d.label: d.name for d in sqlfluff.list_dialects()}

# dict mapping string rule names to descriptions
VALID_RULES = {r.code: r.description for r in sqlfluff.list_rules()}
22 changes: 19 additions & 3 deletions src/app/routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

from flask import Blueprint, redirect, render_template, request, url_for
from sqlfluff.api import fix, lint
from .config import VALID_DIALECTS

bp = Blueprint("routes", __name__)

Expand Down Expand Up @@ -43,10 +44,25 @@ def fluff_results():
sql = sql_decode(request.args["sql"]).strip()
sql = "\n".join(sql.splitlines()) + "\n"

# dialect must be a dialect label for `load_raw_dialect`. VALID_DIALECTS is a
# dictionary of dialect labels to dialect names. If we have a name, we need to
# get the label.
#
# However, the frontend logic runs on dialect names, so we need to convert the
# label back to a name for the frontend.
dialect = request.args["dialect"]
if dialect in VALID_DIALECTS.values():
dialect_name = dialect
dialect_label = next(
label for label, name in VALID_DIALECTS.items() if name == dialect
)
else:
dialect_label = dialect
dialect_name = VALID_DIALECTS[dialect]

try:
linted = lint(sql, dialect=dialect)
fixed_sql = fix(sql, dialect=dialect)
linted = lint(sql, dialect=dialect_label)
fixed_sql = fix(sql, dialect=dialect_label)
except RuntimeError as e:
linted = [
{
Expand All @@ -61,7 +77,7 @@ def fluff_results():
"index.html",
results=True,
sql=sql,
dialect=dialect,
dialect=dialect_name,
lint_errors=linted,
fixed_sql=fixed_sql,
)
19 changes: 16 additions & 3 deletions test/test_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,15 +36,28 @@ def test_post_redirect(client):
assert rv.status_code == 302 and "/fluffed?sql" in rv.headers["location"]


def test_results_no_errors(client):
"""Test that the results is good to go when there is no error."""
@pytest.mark.parametrize("dialect", ["sparksql", "Apache Spark SQL"])
def test_results_no_errors(client, dialect):
"""Test that the results is good to go when there is no error.

Parameterized dialect asserts that either the formatted name or label can be used
as the dialect parameter.
"""
sql_encoded = sql_encode("select * from table")
rv = client.get("/fluffed", query_string=f"""dialect=ansi&sql={sql_encoded}""")
rv = client.get("/fluffed", query_string=f"""dialect={dialect}&sql={sql_encoded}""")
html = rv.data.decode().lower()
assert "sqlfluff online" in html
assert "fixed sql" in html
assert "select * from table" in html

# Test that the dialect is correctly selected in the results page.
selected_dialect = (
BeautifulSoup(html, "html.parser")
.find("select", {"id": "sql_dialect"})
.find("option", {"selected": "selected"})
)
assert selected_dialect.text.strip() == "apache spark sql"


def test_results_some_errors(client):
"""Test that the results is good to go with one obvious error."""
Expand Down