Skip to content

Commit 41c17d6

Browse files
oualibmail4umar
andauthored
Update sql_magic.py (#780)
* Update sql_magic.py - correcting sql_magic bugs * correcting problem with store procedures * fix black * Update sql_magic.py * black * Update sql_magic.py --------- Co-authored-by: umar <46414488+mail4umar@users.noreply.github.com>
1 parent b592d53 commit 41c17d6

File tree

2 files changed

+115
-15
lines changed

2 files changed

+115
-15
lines changed

verticapy/_utils/_sql/_check.py

Lines changed: 34 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,39 @@ def is_dql(query: str) -> bool:
3535
query = erase_comment(query)
3636
for idx, q in enumerate(query):
3737
if q not in (" ", "("):
38-
result = query[idx:].lower().startswith(("select ", "with "))
38+
result = (
39+
query[idx:]
40+
.lower()
41+
.startswith(
42+
(
43+
"select ",
44+
"with ",
45+
)
46+
)
47+
)
48+
break
49+
return result
50+
51+
52+
def is_procedure(query: str) -> bool:
53+
"""
54+
Returns True if the input SQL query
55+
is a procedure.
56+
"""
57+
result = False
58+
query = clean_query(query)
59+
query = erase_comment(query)
60+
for idx, q in enumerate(query):
61+
if q not in (" ", "("):
62+
result = (
63+
query[idx:]
64+
.lower()
65+
.startswith(
66+
(
67+
"create procedure ",
68+
"create or replace procedure ",
69+
)
70+
)
71+
)
3972
break
4073
return result

verticapy/jupyter/extensions/sql_magic.py

Lines changed: 81 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -34,12 +34,14 @@
3434
import verticapy._config.config as conf
3535
from verticapy._utils._object import create_new_vdf
3636
from verticapy._utils._sql._collect import save_verticapy_logs
37+
from verticapy._utils._sql._check import is_procedure
3738
from verticapy._utils._sql._dblink import replace_external_queries
3839
from verticapy._utils._sql._format import (
3940
clean_query,
4041
replace_vars_in_query,
4142
)
4243
from verticapy._utils._sql._sys import _executeSQL
44+
from verticapy.connection import current_cursor
4345
from verticapy.connection.global_connection import get_global_connection
4446
from verticapy.errors import QueryError
4547

@@ -48,6 +50,42 @@
4850
if TYPE_CHECKING:
4951
from verticapy.core.vdataframe.base import vDataFrame
5052

53+
SPECIAL_WORDS = (
54+
# ML Algos
55+
"ARIMA",
56+
"AUTOREGRESSOR",
57+
"BALANCE",
58+
"BISECTING_KMEANS",
59+
"CROSS_VALIDATE",
60+
"DETECT_OUTLIERS",
61+
"IFOREST",
62+
"IMPUTE",
63+
"KMEANS",
64+
"KPROTOTYPES",
65+
"LINEAR_REG",
66+
"LOGISTIC_REG",
67+
"MOVING_AVERAGE",
68+
"NAIVE_BAYES",
69+
"NORMALIZE",
70+
"NORMALIZE_FIT",
71+
"ONE_HOT_ENCODER_FIT",
72+
"PCA",
73+
"POISSON_REG",
74+
"RF_CLASSIFIER",
75+
"RF_REGRESSOR",
76+
"SVD",
77+
"SVM_CLASSIFIER",
78+
"SVM_REGRESSOR",
79+
"XGB_CLASSIFIER",
80+
"XGB_REGRESSOR",
81+
# ML Management
82+
"CHANGE_MODEL_STATUS",
83+
"EXPORT_MODELS",
84+
"IMPORT_MODELS",
85+
"REGISTER_MODEL",
86+
"UPGRADE_MODEL",
87+
)
88+
5189

5290
@save_verticapy_logs
5391
@needs_local_scope
@@ -743,6 +781,12 @@ def sql_magic(
743781
elif "-c" in options:
744782
queries = options["-c"]
745783

784+
# Case when it is a procedure
785+
if is_procedure(queries):
786+
current_cursor().execute(queries)
787+
print("CREATE")
788+
return
789+
746790
# Cleaning the Query
747791
queries = clean_query(queries)
748792
queries = replace_vars_in_query(queries, locals()["local_ns"])
@@ -816,11 +860,14 @@ def sql_magic(
816860
for i in range(n):
817861
query = queries[i]
818862

819-
if query.split(" ")[0]:
820-
query_type = query.split(" ")[0].upper().replace("(", "")
863+
query_words = query.split(" ")
821864

865+
idx = 0 if query_words[0] else 1
866+
query_type = query_words[idx].upper().replace("(", "")
867+
if len(query_words) > 1:
868+
query_subtype = query_words[idx + 1].upper()
822869
else:
823-
query_type = query.split(" ")[1].upper().replace("(", "")
870+
query_subtype = "UNDEFINED"
824871

825872
if len(query_type) > 1 and query_type.startswith(("/*", "--")):
826873
query_type = "undefined"
@@ -843,7 +890,7 @@ def sql_magic(
843890

844891
elif (i < n - 1) or (
845892
(i == n - 1)
846-
and (query_type.lower() not in ("select", "with", "undefined"))
893+
and (query_type.lower() not in ("select", "show", "with", "undefined"))
847894
):
848895
error = ""
849896

@@ -869,25 +916,45 @@ def sql_magic(
869916
else:
870917
error = ""
871918

872-
try:
919+
if query_type.lower() in ("show",):
920+
final_result = _executeSQL(
921+
query, method="fetchall", print_time_sql=False
922+
)
923+
columns = [d.name for d in current_cursor().description]
873924
result = create_new_vdf(
874-
query,
875-
_is_sql_magic=True,
925+
final_result,
926+
usecols=columns,
876927
)
877-
result._vars["sql_magic_result"] = True
878-
# Display parameters
879-
if "-nrows" in options:
880-
result._vars["max_rows"] = options["-nrows"]
881-
if "-ncols" in options:
882-
result._vars["max_columns"] = options["-ncols"]
928+
continue
883929

884-
except:
930+
is_vdf = False
931+
if not (query_subtype.upper().startswith(SPECIAL_WORDS)):
932+
try:
933+
result = create_new_vdf(
934+
query,
935+
_is_sql_magic=True,
936+
)
937+
result._vars["sql_magic_result"] = True
938+
# Display parameters
939+
if "-nrows" in options:
940+
result._vars["max_rows"] = options["-nrows"]
941+
if "-ncols" in options:
942+
result._vars["max_columns"] = options["-ncols"]
943+
is_vdf = True
944+
except:
945+
pass # we could not create a vDataFrame out of the query.
946+
947+
if not (is_vdf):
885948
try:
886949
final_result = _executeSQL(
887950
query, method="fetchfirstelem", print_time_sql=False
888951
)
889952
if final_result and conf.get_option("print_info"):
890953
print(final_result)
954+
elif (
955+
query_subtype.upper().startswith(SPECIAL_WORDS)
956+
) and conf.get_option("print_info"):
957+
print(query_subtype.upper())
891958
elif conf.get_option("print_info"):
892959
print(query_type)
893960

0 commit comments

Comments
 (0)