diff --git a/src/pyastgrep/color.py b/src/pyastgrep/color.py index 7581751..32db769 100644 --- a/src/pyastgrep/color.py +++ b/src/pyastgrep/color.py @@ -83,11 +83,11 @@ def color_match(self, match: Match) -> str: # So, we only color matches if they start and end on the same line. ast_node = match.ast_node - if match.position.lineno == ast_node.lineno == ast_node.end_lineno: + if match.position.lineno == ast_node.lineno == ast_node.end_lineno: # type: ignore [attr-defined] raw_line = match.matching_line - before = raw_line[0 : ast_node.col_offset] - matched = raw_line[ast_node.col_offset : ast_node.end_col_offset] - after = raw_line[ast_node.end_col_offset :] + before = raw_line[0 : ast_node.col_offset] # type: ignore [attr-defined] + matched = raw_line[ast_node.col_offset : ast_node.end_col_offset] # type: ignore [attr-defined] + after = raw_line[ast_node.end_col_offset :] # type: ignore [attr-defined] return f"{before}{self.match_color}{matched}{Styles.END}{after}" else: diff --git a/src/pyastgrep/context.py b/src/pyastgrep/context.py index 88742d8..3a0e6af 100644 --- a/src/pyastgrep/context.py +++ b/src/pyastgrep/context.py @@ -31,12 +31,13 @@ class StatementContext: def get_context_lines_for_result(self, result: Match) -> tuple[int, int]: result_node = result.ast_node statement_node = ast_utils.get_ast_statement_node(result_node) - first_line = statement_node.lineno + first_line = statement_node.lineno # type: ignore [attr-defined] if hasattr(statement_node, "decorator_list"): - first_line = min((first_line, *(n.lineno for n in statement_node.decorator_list))) - before_context = result_node.lineno - first_line - if isinstance(statement_node.end_lineno, int): - after_context = statement_node.end_lineno - result_node.lineno + decorator_list = statement_node.decorator_list + first_line = min((first_line, *(n.lineno for n in decorator_list))) + before_context = result_node.lineno - first_line # type: ignore [attr-defined] + if isinstance(statement_node.end_lineno, int): # type: ignore [attr-defined] + after_context = statement_node.end_lineno - result_node.lineno # type: ignore [attr-defined] else: after_context = 0 return (before_context, after_context)