55import os
66import re
77import sqlalchemy
8+ import sqlite3
89import sqlparse
910import sys
1011import termcolor
@@ -32,12 +33,25 @@ def __init__(self, url, **kwargs):
3233 if not os .path .isfile (matches .group (1 )):
3334 raise RuntimeError ("not a file: {}" .format (matches .group (1 )))
3435
35- # Create engine, raising exception if back end's module not installed
36- self .engine = sqlalchemy .create_engine (url , ** kwargs )
36+ # Remember foreign_keys and remove it from kwargs
37+ foreign_keys = kwargs .pop ("foreign_keys" , False )
38+
39+ # Create engine, raising exception if back end's module not installed
40+ self .engine = sqlalchemy .create_engine (url , ** kwargs )
41+
42+ # Enable foreign key constraints
43+ if foreign_keys :
44+ sqlalchemy .event .listen (self .engine , "connect" , _connect )
45+ else :
46+
47+ # Create engine, raising exception if back end's module not installed
48+ self .engine = sqlalchemy .create_engine (url , ** kwargs )
49+
3750
3851 # Log statements to standard error
3952 logging .basicConfig (level = logging .DEBUG )
4053 self .logger = logging .getLogger ("cs50" )
54+ disabled = self .logger .disabled
4155
4256 # Test database
4357 try :
@@ -48,7 +62,7 @@ def __init__(self, url, **kwargs):
4862 e .__cause__ = None
4963 raise e
5064 else :
51- self .logger .disabled = False
65+ self .logger .disabled = disabled
5266
5367 def _parse (self , e ):
5468 """Parses an exception, returns its message."""
@@ -133,6 +147,8 @@ def process(value):
133147 return process (value )
134148
135149 # Allow only one statement at a time
150+ # SQLite does not support executing many statements
151+ # https://docs.python.org/3/library/sqlite3.html#sqlite3.Cursor.execute
136152 if len (sqlparse .split (text )) > 1 :
137153 raise RuntimeError ("too many statements at once" )
138154
@@ -211,3 +227,16 @@ def process(value):
211227 else :
212228 self .logger .debug (termcolor .colored (log , "green" ))
213229 return ret
230+
231+
232+ # http://docs.sqlalchemy.org/en/latest/dialects/sqlite.html#foreign-key-support
233+ def _connect (dbapi_connection , connection_record ):
234+ """Enables foreign key support."""
235+
236+ # Ensure backend is sqlite
237+ if type (dbapi_connection ) is sqlite3 .Connection :
238+ cursor = dbapi_connection .cursor ()
239+
240+ # Respect foreign key constraints by default
241+ cursor .execute ("PRAGMA foreign_keys=ON" )
242+ cursor .close ()
0 commit comments