Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 62 additions & 0 deletions src/malls/ts/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,14 @@
)


COMMENTS_QUERY = Query(
MAL_LANGUAGE,
"""
((comment) @comment_node)*
""",
)


def find_variable_query(variable_name: str):
query = Query(
MAL_LANGUAGE,
Expand Down Expand Up @@ -1335,3 +1343,57 @@ def find_meta_comment_function(
# terminate if there are no more parents
if node is None:
return []


def find_comments_function(
node: Node, symbol: str, document_uri: str = "", storage: dict = {}
) -> list:
"""
Given a node and a symbol, this function will find the comments
associated to that symbol. Comments are considered to be associated
with a symbol if they appear in consecutive lines above the symbol.

E.g.:
// but not this

// this as well
// this comment is connected
let myVar = ...
// technically this *could* be connected to above but its a potentially complex scenario
// so only count it towards the node below, if there is one


The easiest way to do this is to find all comments in the file and only keep those
which appear in consecutive lines above the symbol.
"""

start_row = node.start_point.row

# find comments
captures = run_query(storage[document_uri].tree.root_node, COMMENTS_QUERY)
if not captures:
return [] # there are no comments

# sort captures by row
sorted_comments = sorted(
filter(lambda item: item.start_point.row < start_row, captures["comment_node"]),
key=lambda item: item.start_point.row,
)

comments = [sorted_comments[0].text]
previous_row = sorted_comments[0].end_point.row

for comment_node in sorted_comments[1:]:
current_row = comment_node.start_point.row

# if the comment is in a consecutive row,
# we keep it
if current_row == previous_row + 1:
comments.append(comment_node.text)
previous_row = current_row # update row
else:
# otherwise, restart the count
comments = [comment_node.text]
previous_row = comment_node.end_point.row

return comments if previous_row == start_row - 1 else []
57 changes: 57 additions & 0 deletions tests/fixtures/mal/find_comments_for_symbol_function.mal
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
#id: "org.mal-lang.testAnalyzer"
#version:"0.0.0"

// should not appear
/*
* SHOULD NOT APPEAR
*/

// category comment
category Example {

// should not appear

// asset comment 1
// asset comment 2
asset Asset1
{
| compromise
-> b.compromise
}

/*
* SHOULD NOT APPEAR
*/

/*
* MULTI-LINE COMMENT
*/
// followed by single comment
asset Asset2
{
| compromise
// attack_step comment
& attack
}

/*
* SHOULD NOT APPEAR
*/
// should not appear

// asset3 comment
asset
Asset3 { }

asset
// asset4 comment
Asset4 { }
}
associations
{
// should not appear
// should not appear

// association comment
Asset1 [a] * <-- L --> * [c] Asset2
}
60 changes: 60 additions & 0 deletions tests/unit/test_find_comments_for_symbol_function.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
from pathlib import Path

import pytest
import tree_sitter_mal as ts_mal
from tree_sitter import Language, Parser

from malls.lsp.classes import Document
from malls.lsp.utils import recursive_parsing
from malls.ts.utils import INCLUDED_FILES_QUERY, find_comments_function, run_query

MAL_LANGUAGE = Language(ts_mal.language())
PARSER = Parser(MAL_LANGUAGE)
FILE_PATH = str(Path(__file__).parent.parent.resolve()) + "/fixtures/mal/"

parameters = [
((9, 13), [b"// category comment"]),
((15, 12), [b"// asset comment 1", b"// asset comment 2"]),
((29, 12), [b"/* \n * MULTI-LINE COMMENT\n */", b"// followed by single comment"]),
((33, 12), [b"// attack_step comment"]),
((43, 8), []),
((47, 6), [b"// asset4 comment"]),
((55, 21), [b"// association comment"]),
]


@pytest.mark.parametrize(
"point,comments",
parameters,
)
def test_find_comments_for_symbol_function(mal_find_comments_for_symbol_function, point, comments):
# build the storage (mimicks the file parsing in the server)
storage = {}

doc_uri = FILE_PATH + "find_comments_for_symbol_function.mal"
source_encoded = mal_find_comments_for_symbol_function.read()
tree = PARSER.parse(source_encoded)

storage[doc_uri] = Document(tree, source_encoded, doc_uri)

# obtain the included files
root_node = tree.root_node

captures = run_query(root_node, INCLUDED_FILES_QUERY)
if "file_name" in captures:
recursive_parsing(FILE_PATH, captures["file_name"], storage, doc_uri, [])

###################################

# get the node
cursor = tree.walk()
while cursor.goto_first_child_for_point(point) is not None:
continue

# confirm it's an identifier
assert cursor.node.type == "identifier"

# we use sets to ensure order does not matter
returned_comments = find_comments_function(cursor.node, cursor.node.text, doc_uri, storage)

assert set(returned_comments) == set(comments)
Loading