Skip to content

Commit

Permalink
fix tests
Browse files Browse the repository at this point in the history
  • Loading branch information
BalzaniEdoardo committed Dec 11, 2024
1 parent 17374e5 commit 9758aab
Showing 1 changed file with 57 additions and 29 deletions.
86 changes: 57 additions & 29 deletions tests/test_call_invocation.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ def is_function_or_wrapped_function(obj):
return inspect.isfunction(unwrapped)


def class_class_invocations(cls):
def class_method_invocations(cls, method_name):
class_results = []
for name, method in inspect.getmembers(cls, predicate=inspect.isfunction):
try:
Expand All @@ -87,7 +87,7 @@ def class_class_invocations(cls):
# Walk the AST to check for `__call__` invocations
for node in ast.walk(tree):
if isinstance(node, ast.Call) and isinstance(node.func, ast.Attribute):
if node.func.attr == "__class__":
if node.func.attr == method_name:
class_results.append(name)
break
except Exception as e:
Expand All @@ -98,44 +98,52 @@ def class_class_invocations(cls):
class_results.remove("_define_instance")
except ValueError:
pass
try:
class_results.remove("_initialize_tsd_output")
except ValueError:
pass
return class_results


def subclass_class_invocations(base_class):
def subclass_method_invocations(base_class, method_name):
"""
Finds methods in subclasses of a base class where the `__class__` method is invoked.
Finds methods in subclasses of a base class where the `method_name` method is invoked.
Args:
base_class (type): The base class to inspect.
method_name: string
Returns:
dict: A dictionary with subclass names as keys and a list of method names invoking `__class__`.
dict: A dictionary with subclass names as keys and a list of method names invoking `method_name`.
"""
results = {}

cls_results = class_class_invocations(base_class)
cls_results = class_method_invocations(base_class, method_name)

if cls_results:
results[base_class.__name__] = cls_results

for subclass in base_class.__subclasses__():

subclass_results = class_class_invocations(subclass)
subclass_results = class_method_invocations(subclass, method_name)
if subclass_results:
results[subclass.__name__] = subclass_results

return results


def find_class_invocations_in_function(func):
def find_method_invocations_in_function(func, method_name):
"""
Checks if a function contains a call to `__class__`.
Checks if a function contains a call to `method_name`.
Args:
Parameters
----------
func (callable): The function to analyze.
method_name: the name of the method
Returns:
bool: True if `__class__` is invoked in the function, False otherwise.
Returns
-------
bool: True if `method_name` is invoked in the function, False otherwise.
"""
try:
# Get the source code of the function
Expand All @@ -144,41 +152,45 @@ def find_class_invocations_in_function(func):
tree = ast.parse(source)
# Walk the AST to check for `__call__` invocations
for node in ast.walk(tree):
if isinstance(node, ast.Call) and isinstance(node.func, ast.Attribute):
if node.func.attr == "__class__":
if isinstance(node, ast.Call) and isinstance(node.func, (ast.Attribute, ast.Name)):
name = getattr(node.func, "attr", getattr(node.func, "id", None))
if name == method_name:
return True
except Exception as e:
# Log the function that couldn't be analyzed
print(f"Could not analyze function {func}: {e}")
return False


def find_class_invocations_in_module_functions(module):
def find_method_invocations_in_module_functions(module, method_name):
"""
Recursively find all functions in a module that invoke `__class__`.
Recursively find all functions in a module that invoke `method_name`.
Args:
module (module): The module to inspect.
Parameters
----------
module (module): The module to inspect.
method_name (str): The name of the method to inspect.
Returns:
dict: A dictionary with module, class, or function names as keys and
a list of function/method names invoking `__class__`.
Returns
-------
dict: A dictionary with module, class, or function names as keys and
a list of function/method names invoking `method_name`.
"""
results_func = {}

# Inspect functions directly defined in the module
for name, func in inspect.getmembers(
module, predicate=is_function_or_wrapped_function
):
if find_class_invocations_in_function(func):
if find_method_invocations_in_function(func, method_name):
results_func[module.__name__ + f".{name}"] = name

# Recursively inspect submodules
if hasattr(module, "__path__"): # Only packages have a __path__
for submodule_info in pkgutil.iter_modules(module.__path__):
submodule_name = f"{module.__name__}.{submodule_info.name}"
submodule = importlib.import_module(submodule_name)
submodule_results = find_class_invocations_in_module_functions(submodule)
submodule_results = find_method_invocations_in_module_functions(submodule, method_name)
if submodule_results:
results_func.update(submodule_results)

Expand All @@ -190,7 +202,7 @@ def test_find_func():
current_module = sys.modules[__name__]

# Run the detection function
results = find_class_invocations_in_module_functions(current_module)
results = find_method_invocations_in_module_functions(current_module, "__class__")
expected_results = {
"tests.test_call_invocation.invalid_func": "invalid_func",
"tests.test_call_invocation.invalid_func_decorated": "invalid_func_decorated",
Expand All @@ -200,22 +212,38 @@ def test_find_func():

def test_find_class():
# Run the detection function
results = subclass_class_invocations(BaseClass)
results = subclass_method_invocations(BaseClass, "__class__")
expected_results = {"InvalidClass": ["method"]}
assert results == expected_results


def test_no_direct__class__invocation_in_base_subclasses():
results_func = find_class_invocations_in_module_functions(nap)
results_cls = subclass_class_invocations(nap.core.base_class._Base)
results_func = find_method_invocations_in_module_functions(nap, "__class__")
results_cls = subclass_method_invocations(nap.core.base_class._Base, "__class__")
if results_cls != {}:
raise ValueError(
f"Direct use of __class__ found in the following _Base objects and methods: {results_cls}. \n"
"Please, replace them with `_define_instance`."
"Please, replace them with `_define_instance` or `_initialize_tsd_output`."
)

if results_cls != {}:
raise ValueError(
f"Direct use of __class__ found in the following modules and functions: {results_func}. \n"
"Please, replace them with `_define_instance`."
"Please, replace them with `_define_instance` or `_initialize_tsd_output`."
)


def test_no_direct_get_cls_invocation_in_base_subclasses():
results_func = find_method_invocations_in_module_functions(nap, "_get_class")
results_cls = subclass_method_invocations(nap.core.base_class._Base, "_get_class")
if results_cls != {}:
raise ValueError(
f"Direct use of `_get_cls` found in the following _Base objects and methods: {results_cls}. \n"
"Please, replace them with `_initialize_tsd_output`."
)

if results_func != {'pynapple.core.time_series._initialize_tsd_output': '_initialize_tsd_output'}:
raise ValueError(
f"Direct use of _get_cls found in the following modules and functions: {results_func}. \n"
"Please, replace them with `_initialize_tsd_output`."
)

0 comments on commit 9758aab

Please sign in to comment.