Skip to content

Commit

Permalink
Class statement support (#22)
Browse files Browse the repository at this point in the history
fix: #18
  • Loading branch information
zhongjiajie authored Oct 12, 2023
1 parent 44fa7ad commit 143e9ba
Show file tree
Hide file tree
Showing 7 changed files with 61 additions and 12 deletions.
5 changes: 3 additions & 2 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,11 @@
# under the License.

"""The script for setting up stmdency."""
from __future__ import annotations

import logging
import os
from distutils.dir_util import remove_tree
from typing import List

from setuptools import Command, setup

Expand All @@ -30,7 +31,7 @@ class CleanCommand(Command):
"""Command to clean up python api before setup by running `python setup.py pre_clean`."""

description = "Clean up project root"
user_options: List[str] = []
user_options: list[str] = []
clean_list = [
"build",
"htmlcov",
Expand Down
17 changes: 15 additions & 2 deletions src/stmdency/visitors/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

import libcst as cst
import libcst.matchers as m
from libcst import Assign, FunctionDef, Import, ImportFrom
from libcst import Assign, ClassDef, FunctionDef, Import, ImportFrom

from stmdency.models.node import StmdencyNode
from stmdency.visitors.assign import AssignVisitor
Expand All @@ -22,7 +22,7 @@ class BaseVisitor(cst.CSTVisitor):

stack: dict[str, StmdencyNode] = field(default_factory=dict)
# Add scope to determine if the node is in the same scope
scope: dict[cst.CSTNode] = field(default_factory=set)
scope: set[cst.CSTNode] = field(default_factory=set)

def handle_import(self, node: Import | ImportFrom) -> None:
"""Handle `import` / `from xx import xxx` statement and parse/add to stack."""
Expand All @@ -44,6 +44,19 @@ def visit_ImportFrom(self, node: ImportFrom) -> bool | None:
self.handle_import(node)
return True

def visit_ClassDef(self, node: ClassDef) -> bool | None:
"""Handle class definition, pass to ClassDefVisitor and add scope.
the reason add scope is to skip the visit_Assign in current class
"""
self.scope.add(node)
self.stack.update([(node.name.value, StmdencyNode(node=node))])
return True

def leave_ClassDef(self, original_node: ClassDef) -> None:
"""Remove class definition in scope."""
self.scope.remove(original_node)

def visit_FunctionDef(self, node: FunctionDef) -> bool | None:
"""Handle function definition, pass to FunctionVisitor and add scope.
Expand Down
4 changes: 2 additions & 2 deletions tests/extractor/test_assign.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Dict
from __future__ import annotations

import pytest

Expand Down Expand Up @@ -81,5 +81,5 @@


@pytest.mark.parametrize("name, source, expects", assign_cases)
def test_assign(name: str, source: str, expects: Dict[str, str]) -> None:
def test_assign(name: str, source: str, expects: dict[str, str]) -> None:
assert_extract(name, source, expects)
4 changes: 2 additions & 2 deletions tests/extractor/test_function.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Dict
from __future__ import annotations

import pytest

Expand Down Expand Up @@ -296,5 +296,5 @@ def foo():


@pytest.mark.parametrize("name, source, expects", func_cases)
def test_func(name: str, source: str, expects: Dict[str, str]) -> None:
def test_func(name: str, source: str, expects: dict[str, str]) -> None:
assert_extract(name, source, expects)
4 changes: 2 additions & 2 deletions tests/extractor/test_import.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Dict
from __future__ import annotations

import pytest

Expand Down Expand Up @@ -124,5 +124,5 @@ def bar():


@pytest.mark.parametrize("name, source, expects", import_cases)
def test_import(name: str, source: str, expects: Dict[str, str]) -> None:
def test_import(name: str, source: str, expects: dict[str, str]) -> None:
assert_extract(name, source, expects)
34 changes: 34 additions & 0 deletions tests/extractor/test_module_class.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
from __future__ import annotations

import pytest

from tests.testing import assert_extract

import_cases = [
(
"simple class with init",
"""
class Foo:
def __init__(self, arg1):
self.arg1 = arg1
def foo():
f = Foo(arg1=1)
print(f)
""",
{
"foo": """\
class Foo:
def __init__(self, arg1):
self.arg1 = arg1\n\n
def foo():
f = Foo(arg1=1)
print(f)
""",
},
),
]


@pytest.mark.parametrize("name, source, expects", import_cases)
def test_import(name: str, source: str, expects: dict[str, str]) -> None:
assert_extract(name, source, expects)
5 changes: 3 additions & 2 deletions tests/testing.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
from __future__ import annotations

import textwrap
from typing import Dict

from stmdency.extractor import Extractor


def assert_extract(name: str, source: str, expects: Dict[str, str]) -> None:
def assert_extract(name: str, source: str, expects: dict[str, str]) -> None:
wrap_source = textwrap.dedent(source)
extractor = Extractor(source=wrap_source)
for expect in expects:
Expand Down

0 comments on commit 143e9ba

Please sign in to comment.