22import decimal
33import importlib
44import logging
5+ import os
56import re
67import sqlalchemy
78import sqlparse
89import sys
10+ import termcolor
911import warnings
1012
1113
@@ -22,12 +24,52 @@ def __init__(self, url, **kwargs):
2224 http://docs.sqlalchemy.org/en/latest/dialects/index.html
2325 """
2426
25- # log statements to standard error
27+ # Require that file already exist for SQLite
28+ matches = re .search (r"^sqlite:///(.+)$" , url )
29+ if matches :
30+ if not os .path .exists (matches .group (1 )):
31+ raise RuntimeError ("does not exist: {}" .format (matches .group (1 )))
32+ if not os .path .isfile (matches .group (1 )):
33+ raise RuntimeError ("not a file: {}" .format (matches .group (1 )))
34+
35+ # Create engine, raising exception if back end's module not installed
36+ self .engine = sqlalchemy .create_engine (url , ** kwargs )
37+
38+ # Log statements to standard error
2639 logging .basicConfig (level = logging .DEBUG )
2740 self .logger = logging .getLogger ("cs50" )
2841
29- # create engine, raising exception if back end's module not installed
30- self .engine = sqlalchemy .create_engine (url , ** kwargs )
42+ # Test database
43+ try :
44+ self .logger .disabled = True
45+ self .execute ("SELECT 1" )
46+ except sqlalchemy .exc .OperationalError as e :
47+ e = RuntimeError (self ._parse (e ))
48+ e .__cause__ = None
49+ raise e
50+ else :
51+ self .logger .disabled = False
52+
53+ def _parse (self , e ):
54+ """Parses an exception, returns its message."""
55+
56+ # MySQL
57+ matches = re .search (r"^\(_mysql_exceptions\.OperationalError\) \(\d+, \"(.+)\"\)$" , str (e ))
58+ if matches :
59+ return matches .group (1 )
60+
61+ # PostgreSQL
62+ matches = re .search (r"^\(psycopg2\.OperationalError\) (.+)$" , str (e ))
63+ if matches :
64+ return matches .group (1 )
65+
66+ # SQLite
67+ matches = re .search (r"^\(sqlite3\.OperationalError\) (.+)$" , str (e ))
68+ if matches :
69+ return matches .group (1 )
70+
71+ # Default
72+ return str (e )
3173
3274 def execute (self , text , ** params ):
3375 """
@@ -81,77 +123,91 @@ def process(value):
81123 elif isinstance (value , sqlalchemy .sql .elements .Null ):
82124 return sqlalchemy .types .NullType ().literal_processor (dialect )(value )
83125
84- # unsupported value
126+ # Unsupported value
85127 raise RuntimeError ("unsupported value" )
86128
87- # process value(s), separating with commas as needed
129+ # Process value(s), separating with commas as needed
88130 if type (value ) is list :
89131 return ", " .join ([process (v ) for v in value ])
90132 else :
91133 return process (value )
92134
93- # allow only one statement at a time
135+ # Allow only one statement at a time
94136 if len (sqlparse .split (text )) > 1 :
95137 raise RuntimeError ("too many statements at once" )
96138
97- # raise exceptions for warnings
139+ # Raise exceptions for warnings
98140 warnings .filterwarnings ("error" )
99141
100- # prepare , execute statement
142+ # Prepare , execute statement
101143 try :
102144
103- # construct a new TextClause clause
145+ # Construct a new TextClause clause
104146 statement = sqlalchemy .text (text )
105147
106- # iterate over parameters
148+ # Iterate over parameters
107149 for key , value in params .items ():
108150
109- # translate None to NULL
151+ # Translate None to NULL
110152 if value is None :
111153 value = sqlalchemy .sql .null ()
112154
113- # bind parameters before statement reaches database, so that bound parameters appear in exceptions
155+ # Bind parameters before statement reaches database, so that bound parameters appear in exceptions
114156 # http://docs.sqlalchemy.org/en/latest/core/sqlelement.html#sqlalchemy.sql.expression.text
115157 statement = statement .bindparams (sqlalchemy .bindparam (
116158 key , value = value , type_ = UserDefinedType ()))
117159
118- # stringify bound parameters
160+ # Stringify bound parameters
119161 # http://docs.sqlalchemy.org/en/latest/faq/sqlexpressions.html#how-do-i-render-sql-expressions-as-strings-possibly-with-bound-parameters-inlined
120162 statement = str (statement .compile (compile_kwargs = {"literal_binds" : True }))
121163
122- # execute statement
123- result = self . engine . execute (statement )
164+ # Statement for logging
165+ log = re . sub ( r"\n\s*" , " " , sqlparse . format (statement , reindent = True ) )
124166
125- # log statement
126- self .logger . debug ( re . sub ( r"\n\s*" , " " , sqlparse . format ( statement , reindent = True )) )
167+ # Execute statement
168+ result = self .engine . execute ( statement )
127169
128- # if SELECT (or INSERT with RETURNING), return result set as list of dict objects
170+ # If SELECT (or INSERT with RETURNING), return result set as list of dict objects
129171 if re .search (r"^\s*SELECT" , statement , re .I ):
130172
131- # coerce any decimal.Decimal objects to float objects
173+ # Coerce any decimal.Decimal objects to float objects
132174 # https://groups.google.com/d/msg/sqlalchemy/0qXMYJvq8SA/oqtvMD9Uw-kJ
133175 rows = [dict (row ) for row in result .fetchall ()]
134176 for row in rows :
135177 for column in row :
136178 if isinstance (row [column ], decimal .Decimal ):
137179 row [column ] = float (row [column ])
138- return rows
180+ ret = rows
139181
140- # if INSERT, return primary key value for a newly inserted row
182+ # If INSERT, return primary key value for a newly inserted row
141183 elif re .search (r"^\s*INSERT" , statement , re .I ):
142184 if self .engine .url .get_backend_name () in ["postgres" , "postgresql" ]:
143185 result = self .engine .execute (sqlalchemy .text ("SELECT LASTVAL()" ))
144- return result .first ()[0 ]
186+ ret = result .first ()[0 ]
145187 else :
146- return result .lastrowid
188+ ret = result .lastrowid
147189
148- # if DELETE or UPDATE, return number of rows matched
190+ # If DELETE or UPDATE, return number of rows matched
149191 elif re .search (r"^\s*(?:DELETE|UPDATE)" , statement , re .I ):
150- return result .rowcount
192+ ret = result .rowcount
151193
152- # if some other statement, return True unless exception
153- return True
194+ # If some other statement, return True unless exception
195+ else :
196+ ret = True
154197
155- # if constraint violated, return None
198+ # If constraint violated, return None
156199 except sqlalchemy .exc .IntegrityError :
200+ self .logger .debug (termcolor .colored (log , "yellow" ))
157201 return None
202+
203+ # If user errror
204+ except sqlalchemy .exc .OperationalError as e :
205+ self .logger .debug (termcolor .colored (log , "red" ))
206+ e = RuntimeError (self ._parse (e ))
207+ e .__cause__ = None
208+ raise e
209+
210+ # Return value
211+ else :
212+ self .logger .debug (termcolor .colored (log , "green" ))
213+ return ret
0 commit comments