Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .travis.yml
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ notifications:
before_install:
- pip install codecov
- mkdir bad_tests
- wget https://hg.python.org/cpython/archive/tip.tar.bz2/Lib/test/ -O bad_tests/test.tar.bz2
- wget https://hg.python.org/cpython/archive/3.6.tar.bz2/Lib/test/ -O bad_tests/test.tar.bz2
- tar -xjf bad_tests/test.tar.bz2 -C bad_tests/
- mv bad_tests/cpython-*/Lib/test/badsyntax_future* . -v
- rm -r bad_tests/
Expand Down
83 changes: 70 additions & 13 deletions flake8_future_import.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,29 +13,61 @@
except ImportError as e:
argparse = e

from ast import NodeVisitor, Str, Module, parse
import ast

__version__ = '0.4.3'


class FutureImportVisitor(NodeVisitor):
class FutureImportVisitor(ast.NodeVisitor):

def __init__(self):
super(FutureImportVisitor, self).__init__()
self.future_imports = []

self._uses_code = False
self._uses_print = False
self._uses_division = False
self._uses_import = False
self._uses_str_literals = False
self._uses_generators = False
self._uses_with = False

def _is_print(self, node):
# python 2
if hasattr(ast, 'Print') and isinstance(node, ast.Print):
return True

# python 3
if isinstance(node, ast.Call) and \
isinstance(node.func, ast.Name) and \
node.func.id == 'print':
return True

return False

def visit_ImportFrom(self, node):
if node.module == '__future__':
self.future_imports += [node]

def visit_Expr(self, node):
if not isinstance(node.value, Str) or node.value.col_offset != 0:
self._uses_code = True
else:
self._uses_import = True

def generic_visit(self, node):
if not isinstance(node, Module):
if not isinstance(node, ast.Module):
self._uses_code = True

if isinstance(node, ast.Str):
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

for absolute correctness, unicode_literals should only be required when there are string literals that are not prefixed ... i.e. if a module uses b'..' throughout, unicode_literals has no effect.

This may also need ast.FormattedValue and ast.JoinedStr, on Python 3.6 -- i havent tested these.

Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This should be possible using tokenize. Iterate over each string token and store it in a dictionary using the position and then you could here get that token from the dictionary (key = (node.lineno, node.col_offset)) and check the prefixes.

For a similar a script (flake8-string-format) I have already implemented a system like this to detect if a string is a raw string. And the main issue (and why I haven't released the other script already) is that it might be that there is no filename but that it's read from stdin.

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note the b'..' is taken care of by py3 as it has ast.Bytes. iirc, tokenize isnt necessary here for py2; just need to check the type of node.s .

Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There is no difference between '…' and b'…' in Python 2 (unless unicode_literals is used):

>>> ast.parse("b'bytes'\n'string'").body[0].value.s
'bytes'
>>> type(ast.parse("b'bytes'\n'string'").body[0].value.s)
<type 'str'>
>>> ast.parse("b'bytes'\n'string'").body[1].value.s
'string'
>>> type(ast.parse("b'bytes'\n'string'").body[1].value.s)
<type 'str'>

Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

But with Flake8 3.x plugins can directly get the tokens, so it'd be possible to avoid any special “stdin” handling or such.

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice. Maybe that part of the problem should be a separate patch after this one.

self._uses_str_literals = True
elif self._is_print(node):
self._uses_print = True
elif isinstance(node, ast.Div):
self._uses_division = True
elif isinstance(node, ast.Import):
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

also needs ast.ImportFrom

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That is handled above, in visit_ImportFrom

self._uses_import = True
elif isinstance(node, ast.With):
self._uses_with = True
elif isinstance(node, ast.Yield):
self._uses_generators = True

super(FutureImportVisitor, self).generic_visit(node)

@property
Expand Down Expand Up @@ -94,6 +126,7 @@ class FutureImportChecker(Flake8Argparse):
name = 'flake8-future-import'
require_code = True
min_version = False
require_used = False

def __init__(self, tree, filename):
self.tree = tree
Expand All @@ -106,6 +139,8 @@ def add_arguments(cls, parser):
parser.add_argument('--min-version', default=False,
help='The minimum version supported so that it can '
'ignore mandatory and non-existent features')
parser.add_argument('--require-used', action='store_true',
help='Only alert when relevant features are used')

@classmethod
def parse_options(cls, options):
Expand All @@ -122,6 +157,7 @@ def parse_options(cls, options):
'like "A.B.C"'.format(options.min_version))
min_version += (0, ) * (max(3 - len(min_version), 0))
cls.min_version = min_version
cls.require_used = options.require_used

def _generate_error(self, future_import, lineno, present):
feature = FEATURES.get(future_import)
Expand Down Expand Up @@ -156,10 +192,31 @@ def run(self):
yield err
present.add(alias.name)
for name in FEATURES:
if name not in present:
err = self._generate_error(name, 1, False)
if err:
yield err
if name in present:
continue

if self.require_used:
if name == 'print_function' and not visitor._uses_print:
continue

if name == 'division' and not visitor._uses_division:
continue

if name == 'absolute_import' and not visitor._uses_import:
continue

if name == 'unicode_literals' and not visitor._uses_str_literals:
continue

if name == 'generators' and not visitor._uses_generators:
continue

if name == 'with_statement' and not visitor._uses_with:
continue

err = self._generate_error(name, 1, False)
if err:
yield err


def main(args):
Expand Down Expand Up @@ -199,12 +256,12 @@ def main(args):
has_errors = False
for filename in args.files:
with open(filename, 'rb') as f:
tree = parse(f.read(), filename=filename, mode='exec')
tree = ast.parse(f.read(), filename=filename, mode='exec')
for line, char, msg, checker in FutureImportChecker(tree,
filename).run():
if msg[:4] not in ignored:
has_errors = True
print('{0}:{1}:{2}: {3}'.format(filename, line, char, msg))
print('{0}:{1}:{2}: {3}'.format(filename, line, char + 1, msg))
return has_errors


Expand Down
85 changes: 65 additions & 20 deletions test_flake8_future_import.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,6 @@
import sys
import tempfile

from distutils.version import StrictVersion

if sys.version_info < (2, 7):
import unittest2 as unittest
else:
Expand Down Expand Up @@ -87,13 +85,12 @@ def iterator(self, checker):
self.assertEqual(char, 0)
self.assertIs(origin, flake8_future_import.FutureImportChecker)

def reverse_parse(self, lines, expected_offset, tmp_file=None):
def reverse_parse(self, lines, tmp_file=None):
for line in lines:
match = re.match(r'([^:]+):(\d+):(\d+): (.*)', line)
yield int(match.group(2)), match.group(4)
match = re.match(r'([^:]+):(\d+):1: (.*)', line)
yield int(match.group(2)), match.group(3)
if tmp_file is not None:
self.assertEqual(match.group(1), tmp_file)
self.assertEqual(int(match.group(3)), expected_offset)


class SimpleImportTestCase(TestCaseBase):
Expand Down Expand Up @@ -173,7 +170,7 @@ def run_main(self, *imported):
flake8_future_import.main([tmp_file])
finally:
os.remove(tmp_file)
self.run_test(self.reverse_parse(self.messages, 0), imported)
self.run_test(self.reverse_parse(self.messages), imported)

def test_main(self):
self.run_main()
Expand Down Expand Up @@ -264,14 +261,6 @@ def setUpClass(cls):
else:
raise unittest.SkipTest('The plugin is not installed and '
'TEST_FLAKE8_INSTALL not set')
# Determine version of installed flake8 package
for dist in pip.utils.get_installed_distributions(False):
if dist.key == 'flake8':
version = StrictVersion(dist.version)
cls.expected_offset = 0 if version.version[0] >= 3 else 1
break
else:
raise ValueError('Unable to find Flake8 installation')
super(Flake8TestCase, cls).setUpClass()

@classmethod
Expand Down Expand Up @@ -300,11 +289,8 @@ def run_flake8(self, *imported):
os.close(handle)
os.remove(tmp_file)
self.assertFalse(data_err)
self.run_test(
self.reverse_parse(data_out.decode('utf8').splitlines(),
self.expected_offset,
tmp_file),
imported)
self.run_test(self.reverse_parse(data_out.decode('utf8').splitlines(), tmp_file),
imported)
self.assertEqual(p.returncode, 1)

def test_flake8(self):
Expand Down Expand Up @@ -363,5 +349,64 @@ class TestFeatures(TestCaseBase):
"""Verify that the features are up to date."""


class FeatureDetectionTestCase(TestCaseBase):

ALWAYS_MISSING = frozenset(('generator_stop', 'nested_scopes'))

def check_code(self, code):
tree = ast.parse(code)
checker = flake8_future_import.FutureImportChecker(tree, 'fn')
checker.require_used = True
iterator = self.iterator(checker)
return self.check_result(iterator)

def assert_errors(self, code, missing=None, forbidden=None):
missing = missing or set()
forbidden = forbidden or set()

found_missing, found_forbidden, _ = self.check_code(code)

self.assertEqual(missing, found_missing)
self.assertEqual(forbidden, found_forbidden)

def test_no_code(self):
self.assert_errors('')
self.assert_errors('# comment only')

def test_simple_statement(self):
self.assert_errors('1+1', missing=self.ALWAYS_MISSING)

def test_print_function(self):
self.assert_errors('print(foo)', self.ALWAYS_MISSING | set(['print_function']))

def test_unicode_literals(self):
expected_missing = self.ALWAYS_MISSING | set(['unicode_literals'])
self.assert_errors('"foo"', expected_missing)
self.assert_errors('u"foo"', expected_missing)
self.assert_errors('r"foo"', expected_missing)
self.assert_errors('fn("foo")', expected_missing)

def test_division(self):
# not division
self.assert_errors('a % b', self.ALWAYS_MISSING)

expected_missing = self.ALWAYS_MISSING | set(['division'])
self.assert_errors('1 / 0', expected_missing)
self.assert_errors('1 / 2 / 1', expected_missing)
self.assert_errors('a /= b', expected_missing)
self.assert_errors('fn(3 / 2)', expected_missing)

def test_absolute_import(self):
expected_missing = self.ALWAYS_MISSING | set(['absolute_import'])
self.assert_errors('import foo\npass', expected_missing)
self.assert_errors('from foo import bar\npass', expected_missing)

def test_with_statement(self):
self.assert_errors('with foo: foo()', self.ALWAYS_MISSING | set(['with_statement']))

def test_generators(self):
self.assert_errors('def foo(): yield', self.ALWAYS_MISSING | set(['generators']))


if __name__ == '__main__':
unittest.main()