def remove_bom_if_exists(sql_string):
    """Strip a leading UTF-8 byte order mark (BOM) from a raw SQL payload.

    sql_string -- the undecoded bytes of a file or stream. The value must be
                  passed RAW, i.e. BEFORE any decoding takes place, otherwise
                  the BOM can no longer be recognised.

    Returns a ``(payload, encoding)`` tuple (NOT the bare payload): the
    payload without its BOM, and the name of the encoding the caller should
    use to decode it.
    """
    if sql_string.startswith(BOM_UTF8):
        # The BOM itself is conclusive, so the payload is labelled UTF-8
        # without consulting chardet. Depending on its version chardet
        # reports such input as 'UTF-8' or 'UTF-8-SIG'; once the mark has
        # been removed both names decode the payload identically.
        return sql_string[len(BOM_UTF8):], "UTF-8"
    # HACK: only a prefix is sampled to keep detection fast on large inputs.
    encoding = detect(sql_string[:10000])["encoding"]
    # NOTE(review): chardet is expected to report None when it cannot guess
    # (e.g. for empty input) -- confirm against the installed version.
    # Falling back to UTF-8 lets the caller's ``.decode(encoding)`` fail with
    # a decoding error rather than a TypeError about a None argument.
    return sql_string, encoding or "UTF-8"
# Every character sequence that can move the splitter from one state to
# another.
_SQL_TOKENS = ('$$', '*/', '/*', ';', "'", '"', '--', "\n")

# state -> {token -> next state}. Key 0 holds the fallback that is taken when
# the token has no explicit entry for that state.
_STATE_TRANSITIONS = {
    '_': {
        0: '_', '/*': '/*', '--': '--', '$$': '$$',
        "'": "'", '"': '"', ';': ';', "''": "'2", '""': '"2',
    },
    "'": {0: "'", "'": "'2"},
    "'2": {0: '_', "'": "'", ';': ';'},
    '"': {0: '"', '"': '"2'},
    '"2': {0: '_', '"': '"', ';': ';'},
    '--': {0: '--', '\n': '_'},
    '/*': {0: '/*', '*/': '_'},
    '$$': {0: '$$', '$$': '_'},
}

# States entered right after a quote character was seen inside a quoted
# string; they behave like the base state unless the same quote follows.
_PRE_EXIT_STATES = ("'2", '"2')


def prepare_sql(sql):
    """Prefix every statement found in *sql* with "EXEC SQL ".

    sql -- the SQL text to prepare.

    Returns the concatenation of all statements produced by split_sql(), each
    one preceded by "EXEC SQL ". split_sql() terminates every statement it
    yields with ';', so no further validation is done here.

    Raises ValueError when *sql* is empty (raised by split_sql()).
    """
    return ''.join("EXEC SQL " + statement for statement in split_sql(sql))


def get_processing_state(current_state, current_token):
    """Determine the next state of the SQL-splitting state machine.

    current_state -- see 'States' further down in this docstring.

    current_token -- any character or character pair which can prompt a
                     transition (quote marks, comment-starting symbols,
                     etc.). A quote may be passed on its own, or together
                     with the character that immediately follows it; in the
                     latter case the quote is applied first and the rest of
                     the token is applied to the resulting state.

    Returns the symbol of the next state.

    Raises ValueError if current_state has no transition table. This
    includes ';', which callers are expected to reset to '_' themselves.

    States:

        _   -- the base state, in which SQL tokens, commands and operators
               occur. This is the state the machine starts in.

        /*  -- block-comment state. Quotes, semicolons and so on have no
               effect; only '*/' leaves this state.

        $$  -- extended-string state. Everything is string data; only '$$'
               leaves this state.

        --  -- line-comment state. Only '\n' leaves this state.

        ;   -- final state: a complete SQL statement has just ended.

        '   -- single-quote state. Only "'" prompts a transition.

        '2  -- single-quote pre-exit state: a "'" was seen inside a
               single-quoted string. Another "'" returns to the single-quote
               state (doubled quote); any other token closes the string and
               is then handled exactly as it would be in the base state.

        "   -- double-quote state; like the single-quote state, but the
               transition is initiated by '"'.

        "2  -- double-quote pre-exit state; like '2, but for '"'.
    """
    if current_state not in _STATE_TRANSITIONS:
        raise ValueError("Received an invalid state '{}'".format(current_state))
    table = _STATE_TRANSITIONS[current_state]
    if current_token in table:
        return table[current_token]
    if len(current_token) > 1 and current_token[0] in table:
        # quote + following character: apply the quote, then the remainder
        return get_processing_state(table[current_token[0]], current_token[1:])
    if current_state in _PRE_EXIT_STATES:
        # The quoted string is closed, so the token is ordinary SQL again and
        # must keep its meaning (e.g. '--' has to start a comment).
        return get_processing_state('_', current_token)
    return table[0]


def get_token_gen(sql, tokens):
    """Generate every occurrence of the given tokens in *sql*, left to right.

    sql    -- the string to search.
    tokens -- iterable of the strings to look for; empty strings are ignored.

    Yields (token, position) tuples ordered by position, where position is
    the index in *sql* at which the token starts. Occurrences of the same
    token never overlap each other, but occurrences of different tokens may
    (e.g. '/*' and '*/' both occur in '/*/'); callers decide which one wins.
    Quotes are yielded on their own, without the character that follows them.
    """
    upcoming = {}
    for token in tokens:
        if not token:
            continue  # str.find('') matches everywhere and would never end
        first = sql.find(token)
        if first != -1:
            upcoming[token] = first
    while upcoming:
        token = min(upcoming, key=upcoming.get)
        position = upcoming[token]
        yield (token, position)
        following = sql.find(token, position + len(token))
        if following == -1:
            del upcoming[token]
        else:
            upcoming[token] = following


def split_sql(sql):
    """Isolate the complete SQL statements contained in *sql*.

    sql -- the SQL text to split; must not be empty.

    Yields one string per statement, each ending in ';'. The text of line
    comments ('--' up to the end of the line) is discarded, but the newline
    ending the comment is kept. Block comments, quoted strings and
    $$-strings are copied verbatim, and a ';' inside them does not end a
    statement. Text left over after the last ';' is yielded with a ';'
    appended, unless it consists of whitespace and semicolons only.

    Raises ValueError (on first iteration) if *sql* is empty.
    """
    if len(sql) == 0:
        raise ValueError("Input appears to be empty.")
    state = '_'
    pieces = []        # chunks of the statement currently being assembled
    copied_until = 0   # index of the first character not yet copied/dropped
    claimed_until = 0  # end of the last token that actually changed state
    for token, position in get_token_gen(sql, _SQL_TOKENS):
        if position < claimed_until:
            # Starts inside a token that already took effect, e.g. the '*/'
            # hidden in '/*/' once its '/*' has opened a comment.
            continue
        token_end = position + len(token)
        in_line_comment = state == '--'
        resting_state = '_' if state in _PRE_EXIT_STATES else state
        state = get_processing_state(state, token)
        if state != resting_state:
            claimed_until = token_end
        if in_line_comment:
            # discard everything in a line comment except its final newline
            if token == "\n":
                pieces.append(token)
        elif state == '--':
            # a line comment just started: keep only what precedes it
            pieces.append(sql[copied_until:position])
        else:
            pieces.append(sql[copied_until:token_end])
        copied_until = token_end
        if state == ';':
            yield ''.join(pieces)
            pieces = []
            state = '_'
    if state != '--':
        # Text after the last token. When the input ends inside a line
        # comment, that text is comment text and must not leak into the SQL.
        pieces.append(sql[copied_until:])
    remainder = ''.join(pieces)
    if remainder.replace(';', '').strip():
        # unless only whitespace and semicolons are left, return the
        # remaining characters between the last ';' and EOF
        yield remainder + ';'
error = '/tmp/tmpLBKZo5.pgc:1: ERROR: unrecognized data type name "garbage"' expected = 'line 1: ERROR: unrecognized data type name "garbage"' diff --git a/test/test_pgsanity.py b/test/test_pgsanity.py index 8d06d43..45a252c 100644 --- a/test/test_pgsanity.py +++ b/test/test_pgsanity.py @@ -1,6 +1,7 @@ import unittest import tempfile import os +from codecs import BOM_UTF8 from pgsanity import pgsanity @@ -26,6 +27,25 @@ def test_check_invalid_string(self): self.assertFalse(success) self.assertEqual('line 1: ERROR: unrecognized data type name "garbage"', msg) + def test_check_invalid_string_2(self): + text = "SELECT '\n" + text += "-- this is not really a comment' AS c;\n" + text += "SELECT '\n" + text += "-- neither is this' AS c spam;" + + (success,msg) = pgsanity.check_string(text) + self.assertFalse(success) + self.assertEqual('line 4: ERROR: syntax error at or near "spam"', msg) + + def test_bom_gets_stripped(self): + bomless = "SELECT 'pining for the fjords';".encode('utf-8') + bomful = BOM_UTF8 + bomless + self.assertEqual(pgsanity.remove_bom_if_exists(bomful), bomless) + + def test_bom_removal_idempotence(self): + bomless = "SELET current_setting('parrot.status);".encode('utf-8') + self.assertEqual(bomless, pgsanity.remove_bom_if_exists(bomless)) + class TestPgSanityFiles(unittest.TestCase): def setUp(self): diff --git a/test/test_sqlprep.py b/test/test_sqlprep.py index 0fe5092..efd6df1 100644 --- a/test/test_sqlprep.py +++ b/test/test_sqlprep.py @@ -5,12 +5,12 @@ class TestSqlPrep(unittest.TestCase): def test_split_sql_nothing_interesting(self): text = "abcd123" - expected = [(None, None, "abcd123")] + expected = ["abcd123;"] self.assertEqual(expected, list(sqlprep.split_sql(text))) def test_split_sql_trailing_semicolon(self): text = "abcd123;" - expected = [(None, ";", "abcd123"), (";", None, '')] + expected = [text] self.assertEqual(expected, list(sqlprep.split_sql(text))) def test_split_sql_comment_between_statements(self): @@ -18,23 +18,14 @@ def 
test_split_sql_comment_between_statements(self): text += "--comment here\n" text += "select a from b;" - expected = [(None, ";", "select a from b"), - (";", "\n", ''), - ("\n", "--", ''), - ("--", "\n", 'comment here'), - ("\n", ";", 'select a from b'), - (";", None, '')] + expected = ["select a from b;","\n\nselect a from b;"] self.assertEqual(expected, list(sqlprep.split_sql(text))) def test_split_sql_inline_comment(self): text = "select a from b; --comment here\n" text += "select a from b;" - expected = [(None, ";", "select a from b"), - (";", "--", ' '), - ("--", "\n", 'comment here'), - ("\n", ";", 'select a from b'), - (";", None, '')] + expected = ["select a from b;", " \nselect a from b;"] self.assertEqual(expected, list(sqlprep.split_sql(text))) def test_handles_first_column_comment_between_statements(self): @@ -42,9 +33,8 @@ def test_handles_first_column_comment_between_statements(self): text += "--comment here\n" text += "blah blah;" - expected = "EXEC SQL blah blah;\n" - expected += "//comment here\n" - expected += "EXEC SQL blah blah;" + expected = "EXEC SQL blah blah;" + expected += "EXEC SQL \n\nblah blah;" self.assertEqual(expected, sqlprep.prepare_sql(text)) @@ -52,8 +42,8 @@ def test_handles_inline_comment_between_statements(self): text = "blah blah; --comment here\n" text += "blah blah;" - expected = "EXEC SQL blah blah; //comment here\n" - expected += "EXEC SQL blah blah;" + expected = "EXEC SQL blah blah;" + expected += "EXEC SQL \nblah blah;" self.assertEqual(expected, sqlprep.prepare_sql(text)) @@ -61,7 +51,8 @@ def test_does_not_mangle_inline_comment_within_statement(self): text = "blah blah--comment here\n" text += "blah blah" - expected = "EXEC SQL " + text + expected = "EXEC SQL blah blah\n" + expected += "blah blah;" self.assertEqual(expected, sqlprep.prepare_sql(text)) @@ -70,7 +61,10 @@ def test_does_not_mangle_first_column_comment_within_statement(self): text += "--comment here\n" text += "where c=3" - expected = "EXEC SQL " + text + 
expected = "select a from b\n" + expected += "\n" + expected += "where c=3;" + expected = "EXEC SQL " + expected self.assertEqual(expected, sqlprep.prepare_sql(text)) @@ -80,8 +74,8 @@ def test_prepend_exec_sql_to_simple_statements(self): self.assertEqual(expected, sqlprep.prepare_sql(text)) def test_prepend_exec_sql_multiple_lines(self): - text1 = "create table control.myfavoritetable (id bigint);\n" - text2 = "create table control.myfavoritetable (id bigint);" + text1 = "create table control.myfavoritetable (id bigint);" + text2 = "\ncreate table control.myfavoritetable (id bigint);" expected = "EXEC SQL " + text1 + "EXEC SQL " + text2 self.assertEqual(expected, sqlprep.prepare_sql(text1 + text2)) @@ -112,32 +106,37 @@ def test_prepend_exec_sql_wrapped_trailing_sql(self): def test_comment_start_found_within_comment_within_statement(self): text = "select a from b --comment in comment --here\nwhere c=1;" - expected = "EXEC SQL select a from b --comment in comment --here\nwhere c=1;" + expected = "EXEC SQL select a from b \nwhere c=1;" self.assertEqual(expected, sqlprep.prepare_sql(text)) def test_comment_start_found_within_comment_between_statements(self): text = "select a from b; --comment in comment --here\nselect c from d;" - expected = "EXEC SQL select a from b; //comment in comment //here\nEXEC SQL select c from d;" + expected = "EXEC SQL select a from b;EXEC SQL \nselect c from d;" self.assertEqual(expected, sqlprep.prepare_sql(text)) def test_double_semicolon(self): text = "select a from b;;" - expected = "EXEC SQL select a from b;;" + expected = "EXEC SQL select a from b;EXEC SQL ;" + self.assertEqual(expected, sqlprep.prepare_sql(text)) + + def test_triple_semicolon(self): + text = "select a from b;;;" + expected = "EXEC SQL select a from b;EXEC SQL ;EXEC SQL ;" self.assertEqual(expected, sqlprep.prepare_sql(text)) def test_semi_found_in_comment_at_end_of_line(self): text = "select a\nfrom b --semi in comment;\nwhere c=1;" - expected = "EXEC SQL select 
a\nfrom b --semi in comment;\nwhere c=1;" + expected = "EXEC SQL select a\nfrom b \nwhere c=1;" self.assertEqual(expected, sqlprep.prepare_sql(text)) def test_handles_first_line_comment(self): text = "--comment on line 1\nselect a from b;" - expected = "//comment on line 1\nEXEC SQL select a from b;" + expected = "EXEC SQL \nselect a from b;" self.assertEqual(expected, sqlprep.prepare_sql(text)) def test_handles_block_comment_on_last_line(self): text = "select a from b;\n/*\nselect c from d;\n*/" - expected = "EXEC SQL select a from b;\n/*\nselect c from d;\n*/" + expected = "EXEC SQL select a from b;EXEC SQL \n/*\nselect c from d;\n*/;" self.assertEqual(expected, sqlprep.prepare_sql(text)) def test_semi_found_in_block_comment(self): @@ -155,7 +154,12 @@ def test_opening_two_block_comments_only_requries_one_close(self): expected = "EXEC SQL select a\n/*\n/*\ncomment\n*/from b;EXEC SQL select c from d;" self.assertEqual(expected, sqlprep.prepare_sql(text)) -# TODO: -# semicolon followed by only whitespace / comments -# multiple semicolons in a row (legal?) -# line starts with semi and then has a statement + def test_trailing_whitespace_after_semicolon(self): + text = "select a from b; " + expected = "EXEC SQL select a from b;" + self.assertEqual(expected, sqlprep.prepare_sql(text)) + + def test_line_starts_with_semicolon(self): + text = ";select a from b;" + expected = "EXEC SQL ;EXEC SQL select a from b;" + self.assertEqual(expected, sqlprep.prepare_sql(text))