Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
101 changes: 96 additions & 5 deletions astropy/utils/introspection.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
import inspect
import types
import importlib
import re
import warnings
from distutils.version import LooseVersion


Expand Down Expand Up @@ -88,6 +90,68 @@ def resolve_name(name, *additional_parts):
return ret


_TAG_RULES = (
('post', 2),
('dev', 1),
('rc', -1),
('beta', -2),
('b', -2),
('alpha', -3),
('a', -3),
)

_TAG_PATTERNS = {
name: re.compile(rf'[-_.]*{name}(\d*)$') for name, _ in _TAG_RULES
}

_TAG_SENTINEL = 0


def _normalize_for_loose_version(version_string):
"""Normalize a version string for ``LooseVersion`` comparisons."""
text = str(version_string or '').strip()

if '!' in text:
text = text.split('!', 1)[1]

if '+' in text:
text = text.split('+', 1)[0]

match = re.match(r'^(\d+(?:\.\d+)*)', text)
release_str = match.group(1) if match else ''
remainder = text[len(release_str):]

release_parts = [int(part) for part in release_str.split('.') if part]
if not release_parts:
release_parts = [0]
while len(release_parts) > 1 and release_parts[-1] == 0:
release_parts.pop()
release_parts = [value + 1 for value in release_parts]

tag_code = 0
tag_number = 0

tail = remainder.lower()
for name, code in _TAG_RULES:
match = _TAG_PATTERNS[name].search(tail)
if match:
digits = match.group(1)
tag_number = int(digits) if digits else 0
tag_code = code
break

normalized_code = tag_code + 3
components = release_parts + [_TAG_SENTINEL, normalized_code, tag_number]
return '.'.join(str(value) for value in components)


def _requires_normalization(value):
lower = str(value or '').lower()
if '!' in lower or '+' in lower:
return True
return any(name in lower for name, _ in _TAG_RULES)


def minversion(module, version, inclusive=True, version_path='__version__'):
"""
Returns `True` if the specified Python module satisfies a minimum version
Expand Down Expand Up @@ -139,10 +203,34 @@ def minversion(module, version, inclusive=True, version_path='__version__'):
else:
have_version = resolve_name(module.__name__, version_path)

if inclusive:
return LooseVersion(have_version) >= LooseVersion(version)
else:
return LooseVersion(have_version) > LooseVersion(version)
with warnings.catch_warnings():
warnings.simplefilter('ignore', DeprecationWarning)
left = LooseVersion(have_version)
right = LooseVersion(version)

comparator = (lambda a, b: a >= b) if inclusive else (lambda a, b: a > b)

needs_normalized = (
_requires_normalization(have_version) or
_requires_normalization(version)
)

try:
direct_result = comparator(left, right)
except TypeError:
needs_normalized = True
direct_result = None

if needs_normalized:
with warnings.catch_warnings():
warnings.simplefilter('ignore', DeprecationWarning)
normalized_have = LooseVersion(
_normalize_for_loose_version(have_version))
normalized_need = LooseVersion(
_normalize_for_loose_version(version))
return comparator(normalized_have, normalized_need)

return direct_result


def find_current_module(depth=1, finddiff=False):
Expand Down Expand Up @@ -313,7 +401,10 @@ def find_mod_objs(modname, onlylocals=False):
if onlylocals:
if onlylocals is True:
onlylocals = [modname]
valids = [any(fqn.startswith(nm) for nm in onlylocals) for fqn in fqnames]
valids = [
any(fqn.startswith(nm) for nm in onlylocals)
for fqn in fqnames
]
localnames = [e for i, e in enumerate(localnames) if valids[i]]
fqnames = [e for i, e in enumerate(fqnames) if valids[i]]
objs = [e for i, e in enumerate(objs) if valids[i]]
Expand Down
51 changes: 46 additions & 5 deletions astropy/utils/tests/test_introspection.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,13 @@

# namedtuple is needed for find_mod_objs so it can have a non-local module
from collections import namedtuple
from types import ModuleType

import pytest

from .. import introspection
from ..introspection import (find_current_module, find_mod_objs,
isinstancemethod, minversion)
minversion)


def test_pkg_finder():
Expand All @@ -34,7 +35,8 @@ def test_find_current_mod():

assert find_current_module(0, True).__name__ == thismodnm
assert find_current_module(0, [introspection]).__name__ == thismodnm
assert find_current_module(0, ['astropy.utils.introspection']).__name__ == thismodnm
assert find_current_module(
0, ['astropy.utils.introspection']).__name__ == thismodnm

with pytest.raises(ImportError):
find_current_module(0, ['faddfdsasewrweriopunjlfiurrhujnkflgwhu'])
Expand Down Expand Up @@ -63,13 +65,52 @@ def test_find_mod_objs():
assert namedtuple not in objs


def _module_with_version(version):
module = ModuleType(str("test_module"))
module.__version__ = version
return module


def test_minversion():
from types import ModuleType
test_module = ModuleType(str("test_module"))
test_module.__version__ = '0.12.2'
test_module = _module_with_version('0.12.2')
good_versions = ['0.12', '0.12.1', '0.12.0.dev']
bad_versions = ['1', '1.2rc1']
for version in good_versions:
assert minversion(test_module, version)
for version in bad_versions:
assert not minversion(test_module, version)


def test_minversion_dev_comparisons():
assert minversion(_module_with_version('1.14.3'), '1.14dev')
assert not minversion(_module_with_version('1.14'), '1.14dev')


def test_minversion_multi_component_dev_ordering():
lower = _module_with_version('1.2.3.4.dev1')
higher_version = '1.2.3.5.dev1'
assert not minversion(lower, higher_version)
assert minversion(_module_with_version(higher_version), '1.2.3.4.dev1')


@pytest.mark.parametrize(
'lower,higher',
[
('1.14a1', '1.14b1'),
('1.14b1', '1.14rc1'),
('1.14rc1', '1.14'),
]
)
def test_minversion_prerelease_order(lower, higher):
assert not minversion(_module_with_version(lower), higher)
assert minversion(_module_with_version(higher), lower)


def test_minversion_post_release_order():
assert not minversion(_module_with_version('1.14'), '1.14.post1')
assert minversion(_module_with_version('1.14.1'), '1.14.post1')


def test_minversion_ignores_epoch_and_local():
assert minversion(_module_with_version('2!1.14.0+g123'), '1.14')
assert minversion(_module_with_version('1.14'), '2!1.14.0+local')