Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Support context recognition for injected languages #388

Merged
merged 7 commits into from
Feb 16, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
119 changes: 64 additions & 55 deletions lua/treesitter-context/context.lua
Original file line number Diff line number Diff line change
Expand Up @@ -9,23 +9,13 @@ local get_lang = vim.treesitter.language.get_lang or require('nvim-treesitter.pa
--- @diagnostic disable-next-line:deprecated
local get_query = vim.treesitter.query.get or vim.treesitter.query.get_query

--- Look up the named node covering the one-column range that starts at
--- (row, col) in the buffer's parser. Returns nothing when no parser is
--- available for the buffer.
---@param bufnr integer
---@param row integer
---@param col integer
---@return TSNode?
local function get_node(bufnr, row, col)
  local parser = vim.treesitter.get_parser(bufnr)
  if parser then
    -- Range is { start_row, start_col, end_row, end_col }, one column wide.
    local cursor_range = { row, col, row, col + 1 }
    return parser:named_node_for_range(cursor_range)
  end
end
--- @param langtree LanguageTree
--- @param range Range4
--- @return TSNode[]?
local function get_parent_nodes(langtree, range)
local tree = langtree:tree_for_range(range, { ignore_injections = true })
local n = tree:root():named_descendant_for_range(unpack(range))

--- @param node TSNode
--- @return TSNode[]
local function get_parent_nodes(node)
local n = node --- @type TSNode?
local ret = {} --- @type TSNode[]
while n do
ret[#ret + 1] = n
Expand Down Expand Up @@ -108,12 +98,9 @@ local context_range = cache.memoize(function(node, query)
end
end, hash_node)

---@param bufnr integer
---@param lang string
---@return Query?
local function get_context_query(bufnr)
--- @type string
local lang = assert(get_lang(vim.bo[bufnr].filetype))

local function get_context_query(lang)
local ok, query = pcall(get_query, lang, 'context')

if not ok then
Expand Down Expand Up @@ -182,6 +169,29 @@ end

local M = {}

---@param bufnr integer
---@param range Range4
---@return LanguageTree[]
local function get_parent_langtrees(bufnr, range)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does this work properly for nested injections? It looks to me like this will return at most two langtrees: the root and the innermost one.

Copy link
Contributor Author

@kwaszczuk kwaszczuk Feb 15, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You are right, I misunderstood the language_for_range implementation. I changed the code to manually traverse the child langtrees; please take a look.

To test nested injection, I came up with a Markdown test case to replace the HTML one. The context for both <html> and <script> displays properly when I run vim manually, but in the tests, language injection in Markdown seems to be broken (see the test job for reference). Do you have any idea what the problem could be here?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@lewis6991 It took me a while, but I finally got it to work.

It looks like the tests do not perform a full initialization of the nvim-treesitter plugin; they only call require'nvim-treesitter.configs'.setup { ... }. Because of that, some modules are not initialized, including nvim-treesitter.query_predicates, which contains the set-lang-from-info-string! directive configuration required by the Markdown queries.

My straightforward solution for that is to explicitly initialize the plugin via require'nvim-treesitter'.setup(). I don't think it is the greatest solution, as we end up invoking setup() twice (just on different modules), which seems weird, but I couldn't find any other way that would not induce setup duplication.

local root_tree = vim.treesitter.get_parser(bufnr)
if not root_tree then
return {}
end

local parent_langtrees = {root_tree}

while true do
child_langtree = parent_langtrees[#parent_langtrees]:language_for_range(range)
kwaszczuk marked this conversation as resolved.
Show resolved Hide resolved
if child_langtree == parent_langtrees[#parent_langtrees] then
break
end
parent_langtrees[#parent_langtrees + 1] = child_langtree
end

return parent_langtrees
end

--- @param bufnr integer
--- @param winid integer
--- @return Range4[]?, string[]?
Expand All @@ -196,12 +206,6 @@ function M.get(bufnr, winid)
return
end

local query = get_context_query(bufnr)

if not query then
return
end

local top_row = fn.line('w0', winid) - 1

--- @type integer, integer
Expand All @@ -220,40 +224,45 @@ function M.get(bufnr, winid)

for offset = 0, max_lines do
local node_row = row + offset

local node = get_node(bufnr, node_row, offset == 0 and col or 0)
if not node then
return
end

local parents = get_parent_nodes(node)
local col0 = offset == 0 and col or 0
local range = { node_row, col0, node_row, col0 + 1 }

context_ranges = {}
context_lines = {}
contexts_height = 0

for i = #parents, 1, -1 do
local parent = parents[i]
local parent_start_row = parent:range()

local contexts_end_row = top_row + math.min(max_lines, contexts_height)
-- Only process the parent if it is not in view.
if parent_start_row < contexts_end_row then
local range0 = context_range(parent, query)
if range0 then
local range, lines = get_text_for_range(range0)

local last_context = context_ranges[#context_ranges]
if last_context and parent_start_row == last_context[1] then
-- If there are multiple contexts on the same row, then prefer the inner
contexts_height = contexts_height - util.get_range_height(last_context)
context_ranges[#context_ranges] = nil
context_lines[#context_lines] = nil
end
local parent_trees = get_parent_langtrees(bufnr, range)
for i = 1, #parent_trees, 1 do
local langtree = parent_trees[i]
local query = get_context_query(langtree:lang())
if not query then
return
end

contexts_height = contexts_height + util.get_range_height(range)
context_ranges[#context_ranges + 1] = range
context_lines[#context_lines + 1] = lines
local parents = get_parent_nodes(langtree, range)
for i = #parents, 1, -1 do
local parent = parents[i]
local parent_start_row = parent:range()

local contexts_end_row = top_row + math.min(max_lines, contexts_height)
-- Only process the parent if it is not in view.
if parent_start_row < contexts_end_row then
local range0 = context_range(parent, query)
if range0 then
local range, lines = get_text_for_range(range0)

local last_context = context_ranges[#context_ranges]
if last_context and parent_start_row == last_context[1] then
-- If there are multiple contexts on the same row, then prefer the inner
contexts_height = contexts_height - util.get_range_height(last_context)
context_ranges[#context_ranges] = nil
context_lines[#context_lines] = nil
end

contexts_height = contexts_height + util.get_range_height(range)
context_ranges[#context_ranges + 1] = range
context_lines[#context_lines + 1] = lines
end
end
end
end
Expand Down
14 changes: 14 additions & 0 deletions test/test.html
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,20 @@
var b = 2;
function test() {
let test = "asdasd";



if test != "" {









}
}

var c = a + b;
Expand Down
50 changes: 49 additions & 1 deletion test/ts_context_spec.lua
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,9 @@ describe('ts_context', function()
"lua",
"rust",
"cpp",
"typescript"
"typescript",
"html",
"javascript",
},
sync_install = true,
}
Expand Down Expand Up @@ -329,6 +331,52 @@ describe('ts_context', function()
]]}
end)

-- Checks the context shown for test/test.html at three scroll positions,
-- including positions inside the <script> element where the expected context
-- mixes HTML tags with lines of the script's code.
it('html', function()
cmd('edit test/test.html')
-- Start treesitter for the buffer that was just opened.
exec_lua [[vim.treesitter.start()]]

-- Scroll with 100<C-e>; the expected context is the opening <html>, <body>,
-- <ul> and <li> tags.
feed'100<C-e>'
screen:expect{grid=[[
{14:<html}{2: }{14:lang}{1:=}{10:"en"}{14:>}{2: }|
{2: }{14:<body>}{2: }|
{2: }{14:<ul>}{2: }|
{2: }{14:<li>}{2: }|
|
^ |
|*5
{15:</li>} |
{15:<li></li>} |
{15:</ul>} |
{15:</body>} |
|
]]}

-- Scroll a further 31<C-e>; the expected context is now <html>, <script> and
-- the `function test() {` line from the script body.
feed'31<C-e>'
screen:expect{grid=[[
{14:<html}{2: }{14:lang}{1:=}{10:"en"}{14:>}{2: }|
{2: }{14:<script>}{2: }|
{2: }{1:function}{2: }{3:test}{14:()}{2: }{14:{}{2: }|
|
|
^ |
{4:if} {5:test} {4:!=} {11:""} {15:{} |
|*9
]]}

-- Scroll 4<C-e> more; the `if test != "" {` line, previously shown as buffer
-- text, is now expected as a fourth context line.
feed'4<C-e>'
screen:expect{grid=[[
{14:<html}{2: }{14:lang}{1:=}{10:"en"}{14:>}{2: }|
{2: }{14:<script>}{2: }|
{2: }{1:function}{2: }{3:test}{14:()}{2: }{14:{}{2: }|
{2: }{1:if}{2: }{3:test}{2: }{1:!=}{2: }{10:""}{2: }{14:{}{2: }|
|
^ |
|*6
{15:}} |
{15:}} |
|*2
]]}
end)
end)

end)
Loading