diff --git a/src/malls/ts/utils.py b/src/malls/ts/utils.py index a1f5aea..5969101 100644 --- a/src/malls/ts/utils.py +++ b/src/malls/ts/utils.py @@ -105,6 +105,14 @@ ) +COMMENTS_QUERY = Query( + MAL_LANGUAGE, + """ + ((comment) @comment_node)* + """, +) + + def find_variable_query(variable_name: str): query = Query( MAL_LANGUAGE, @@ -1335,3 +1343,57 @@ def find_meta_comment_function( # terminate if there are no more parents if node is None: return [] + + +def find_comments_function( + node: Node, symbol: str, document_uri: str = "", storage: dict = {} +) -> list: + """ + Given a node and a symbol, this function will find the comments + associated to that symbol. Comments are considered to be associated + with a symbol if they appear in consecutive lines above the symbol. + + E.g.: + // but not this + + // this as well + // this comment is connected + let myVar = ... + // technically this *could* be connected to above but its a potentially complex scenario + // so only count it towards the node below, if there is one + + + The easiest way to do this is to find all comments in the file and only keep those + which appear in consecutive lines above the symbol. + """ + + start_row = node.start_point.row + + # find comments + captures = run_query(storage[document_uri].tree.root_node, COMMENTS_QUERY) + if not captures: + return [] # there are no comments + + # sort captures by row + sorted_comments = sorted( + filter(lambda item: item.start_point.row < start_row, captures["comment_node"]), + key=lambda item: item.start_point.row, + ) + + comments = [sorted_comments[0].text] + previous_row = sorted_comments[0].end_point.row + + for comment_node in sorted_comments[1:]: + current_row = comment_node.start_point.row + + # if the comment is in a consecutive row, + # we keep it + if current_row == previous_row + 1: + comments.append(comment_node.text) + previous_row = current_row # update row + else: + # otherwise, restart the count + comments = [comment_node.text] + previous_row = comment_node.end_point.row + + return comments if previous_row == start_row - 1 else [] diff --git a/tests/fixtures/mal/find_comments_for_symbol_function.mal b/tests/fixtures/mal/find_comments_for_symbol_function.mal new file mode 100644 index 0000000..f1c4fea --- /dev/null +++ b/tests/fixtures/mal/find_comments_for_symbol_function.mal @@ -0,0 +1,57 @@ +#id: "org.mal-lang.testAnalyzer" +#version:"0.0.0" + +// should not appear +/* + * SHOULD NOT APPEAR + */ + +// category comment +category Example { + + // should not appear + + // asset comment 1 + // asset comment 2 + asset Asset1 + { + | compromise + -> b.compromise + } + + /* + * SHOULD NOT APPEAR + */ + + /* + * MULTI-LINE COMMENT + */ + // followed by single comment + asset Asset2 + { + | compromise + // attack_step comment + & attack + } + + /* + * SHOULD NOT APPEAR + */ + // should not appear + + // asset3 comment + asset + Asset3 { } + + asset + // asset4 comment + Asset4 { } +} +associations +{ + // should not appear + // should not appear + + // association comment + Asset1 [a] * <-- L --> * [c] Asset2 +} diff --git a/tests/unit/test_find_comments_for_symbol_function.py b/tests/unit/test_find_comments_for_symbol_function.py new file mode 100644 index 0000000..f5b86fb --- /dev/null +++ b/tests/unit/test_find_comments_for_symbol_function.py @@ -0,0 +1,60 @@ +from pathlib import Path + +import pytest +import tree_sitter_mal as ts_mal +from tree_sitter import Language, Parser + +from malls.lsp.classes import Document +from malls.lsp.utils import recursive_parsing +from malls.ts.utils import INCLUDED_FILES_QUERY, find_comments_function, run_query + +MAL_LANGUAGE = Language(ts_mal.language()) +PARSER = Parser(MAL_LANGUAGE) +FILE_PATH = str(Path(__file__).parent.parent.resolve()) + "/fixtures/mal/" + +parameters = [ + ((9, 13), [b"// category comment"]), + ((15, 12), [b"// asset comment 1", b"// asset comment 2"]), + ((29, 12), [b"/* \n * MULTI-LINE COMMENT\n */", b"// followed by single comment"]), + ((33, 12), [b"// attack_step comment"]), + ((43, 8), []), + ((47, 6), [b"// asset4 comment"]), + ((55, 21), [b"// association comment"]), +] + + +@pytest.mark.parametrize( + "point,comments", + parameters, +) +def test_find_comments_for_symbol_function(mal_find_comments_for_symbol_function, point, comments): + # build the storage (mimicks the file parsing in the server) + storage = {} + + doc_uri = FILE_PATH + "find_comments_for_symbol_function.mal" + source_encoded = mal_find_comments_for_symbol_function.read() + tree = PARSER.parse(source_encoded) + + storage[doc_uri] = Document(tree, source_encoded, doc_uri) + + # obtain the included files + root_node = tree.root_node + + captures = run_query(root_node, INCLUDED_FILES_QUERY) + if "file_name" in captures: + recursive_parsing(FILE_PATH, captures["file_name"], storage, doc_uri, []) + + ################################### + + # get the node + cursor = tree.walk() + while cursor.goto_first_child_for_point(point) is not None: + continue + + # confirm it's an identifier + assert cursor.node.type == "identifier" + + # we use sets to ensure order does not matter + returned_comments = find_comments_function(cursor.node, cursor.node.text, doc_uri, storage) + + assert set(returned_comments) == set(comments)