diff --git a/.github/workflows/publish-to-pypi.yml b/.github/workflows/publish-to-pypi.yml index d03aa7e..dba9243 100644 --- a/.github/workflows/publish-to-pypi.yml +++ b/.github/workflows/publish-to-pypi.yml @@ -14,7 +14,7 @@ jobs: - name: Set up Python uses: actions/setup-python@v4 with: - python-version: "3.10" + python-version: "3.11" - name: Install pypa/setuptools run: >- python -m diff --git a/.github/workflows/test_main.yml b/.github/workflows/test_main.yml new file mode 100644 index 0000000..6fe9af2 --- /dev/null +++ b/.github/workflows/test_main.yml @@ -0,0 +1,29 @@ +# This workflow will install Python dependencies and run tests with a single version of Python. + +name: Test main branch + +on: + push: + branches: [ "main" ] + pull_request: + branches: [ "main" ] + +permissions: + contents: read + +jobs: + build: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v3 + - name: Set up Python + uses: actions/setup-python@v4 + with: + python-version: "3.11" + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install .[dev] + - name: Test with pytest + run: | + pytest diff --git a/setup.py b/setup.py index bd00815..e580122 100644 --- a/setup.py +++ b/setup.py @@ -2,9 +2,9 @@ import sys from setuptools import setup, find_packages -if sys.version_info >= (3, 11, 0): +if sys.version_info >= (3, 12, 0): # due to current pytorch limitations - print('Required python version <= 3.11.0') + print('Required python version <= 3.12.0') sys.exit(-1) @@ -29,23 +29,25 @@ def read(fname): "License :: OSI Approved :: GNU General Public License v3 (GPLv3)", "Operating System :: OS Independent", ], - entry_points = { + entry_points={ 'console_scripts': ['textlsp=textLSP.cli:main'], }, install_requires=[ - 'pygls==1.0.0', - 'lsprotocol==2022.0.0a9', + 'pygls==1.1.2', + 'lsprotocol==2023.0.0b1', 'language-tool-python==2.7.1', - 'tree_sitter==0.20.1', - 'gitpython==3.1.29', + 'tree_sitter==0.20.4', + 'gitpython==3.1.40', 'appdirs==1.4.4', - 
'torch==1.13.1', - 'openai==0.26.4', - 'transformers==4.25.1', + 'torch==2.1.0', + 'openai==1.2.4', + 'transformers==4.35.1', + 'sortedcontainers==2.4.0', ], extras_require={ 'dev': [ - 'pytest', + 'pytest==7.4.3', + 'python-lsp-jsonrpc==1.1.2', ] }, ) diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/analysers/analyser_test.py b/tests/analysers/analyser_test.py new file mode 100644 index 0000000..aaa4c40 --- /dev/null +++ b/tests/analysers/analyser_test.py @@ -0,0 +1,508 @@ +import pytest + +from threading import Event +from lsprotocol.types import ( + DidOpenTextDocumentParams, + TextDocumentItem, + DidChangeTextDocumentParams, + VersionedTextDocumentIdentifier, + TextDocumentContentChangeEvent_Type1, + Range, + Position, + DidSaveTextDocumentParams, + TextDocumentIdentifier, + CodeActionParams, + CodeActionContext, + Diagnostic, +) + +from tests.lsp_test_client import session, utils + + +@pytest.mark.parametrize('text,edit,exp', [ + ( + 'This is a sentence.\n' + 'This is a sAntence with an error.\n' + 'And another sentence.', + ( + Range( + start=Position(line=2, character=0), + end=Position(line=2, character=0), + ), + '\n', + False + ), + Range( + start=Position(line=1, character=10), + end=Position(line=1, character=18), + ), + ), + ( + 'This is a sentence.\n' + 'This is a sAntence with an error.\n' + 'And another sentence.', + ( + Range( + start=Position(line=0, character=0), + end=Position(line=0, character=0), + ), + '\n\n\n', + True + ), + Range( + start=Position(line=4, character=10), + end=Position(line=4, character=18), + ), + ), + ( + 'This is a sentence.\n' + 'This is a sAntence with an error.\n' + 'And another sentence.', + ( + Range( + start=Position(line=0, character=0), + end=Position(line=1, character=0), + ), + '', + True + ), + Range( + start=Position(line=0, character=10), + end=Position(line=0, character=18), + ), + ), + ( + 'This is a sentence.\n' + 'This is a sAntence with an 
error.\n' + 'And another sentence.', + ( + Range( + start=Position(line=1, character=23), + end=Position(line=1, character=23), + ), + '\n', + False + ), + Range( + start=Position(line=1, character=10), + end=Position(line=1, character=18), + ), + ), + ( + 'This is a sentence.\n' + 'This is a sAntence with an error.\n' + 'And another sentence.', + ( + Range( + start=Position(line=1, character=33), + end=Position(line=1, character=33), + ), + ' too', + False + ), + Range( + start=Position(line=1, character=10), + end=Position(line=1, character=18), + ), + ), + ( + 'This is a sentence.\n' + 'This is a sAntence with an error.\n' + 'And another sentence.', + ( + Range( + start=Position(line=1, character=4), + end=Position(line=1, character=4), + ), + ' word', + True + ), + Range( + start=Position(line=1, character=15), + end=Position(line=1, character=23), + ), + ), + ( + 'This is a sentence.\n' + 'This is a sAntence with an error.\n' + 'And another sentence.', + ( + Range( + start=Position(line=1, character=4), + end=Position(line=1, character=4), + ), + '\n', + True + ), + Range( + start=Position(line=2, character=5), + end=Position(line=2, character=13), + ), + ), +]) +def test_line_shifts(text, edit, exp, json_converter, langtool_ls_onsave): + done = Event() + diag_lst = list() + + langtool_ls_onsave.set_notification_callback( + session.PUBLISH_DIAGNOSTICS, + utils.get_notification_handler( + event=done, + results=diag_lst + ), + ) + + open_params = DidOpenTextDocumentParams( + TextDocumentItem( + uri='dummy.txt', + language_id='txt', + version=1, + text=text, + ) + ) + + langtool_ls_onsave.notify_did_open( + json_converter.unstructure(open_params) + ) + assert done.wait(30) + done.clear() + + change_params = DidChangeTextDocumentParams( + text_document=VersionedTextDocumentIdentifier( + version=1, + uri='dummy.txt', + ), + content_changes=[ + TextDocumentContentChangeEvent_Type1( + edit[0], + edit[1], + ) + ] + ) + langtool_ls_onsave.notify_did_change( + 
json_converter.unstructure(change_params) + ) + ret = done.wait(1) + done.clear() + + # no diagnostics notification if none has changed + assert ret == edit[2] + if edit[2]: + assert len(diag_lst) == 2 + else: + assert len(diag_lst) == 1 + + res = diag_lst[-1]['diagnostics'][0]['range'] + assert res == json_converter.unstructure(exp) + + diag = diag_lst[-1]['diagnostics'][0] + diag = Diagnostic( + range=Range( + start=Position(**res['start']), + end=Position(**res['end']), + ), + message=diag['message'], + ) + code_action_params = CodeActionParams( + TextDocumentIdentifier('dummy.txt'), + exp, + CodeActionContext([diag]), + ) + actions_lst = langtool_ls_onsave.text_document_code_action( + json_converter.unstructure(code_action_params) + ) + assert len(actions_lst) == 1 + res = actions_lst[-1]['diagnostics'][0]['range'] + assert res == json_converter.unstructure(exp) + + +@pytest.mark.parametrize('text,edit,exp', [ + ( + 'Introduction\n' + '\n' + 'This is a sentence.\n' + 'This is another.\n' + '\n' + 'Thes is bold.', + ( + Range( + start=Position(line=1, character=0), + end=Position(line=1, character=0), + ), + '\n\n', + ), + Range( + start=Position(line=7, character=0), + end=Position(line=7, character=7), + ), + ), +]) +def test_diagnostics_bug1(text, edit, exp, json_converter, langtool_ls_onsave): + done = Event() + results = list() + + langtool_ls_onsave.set_notification_callback( + session.PUBLISH_DIAGNOSTICS, + utils.get_notification_handler( + event=done, + results=results + ), + ) + + open_params = DidOpenTextDocumentParams( + TextDocumentItem( + uri='dummy.txt', + language_id='txt', + version=1, + text=text, + ) + ) + + langtool_ls_onsave.notify_did_open( + json_converter.unstructure(open_params) + ) + assert done.wait(30) + done.clear() + + change_params = DidChangeTextDocumentParams( + text_document=VersionedTextDocumentIdentifier( + version=1, + uri='dummy.txt', + ), + content_changes=[ + TextDocumentContentChangeEvent_Type1( + edit[0], + edit[1], + ) + 
] + ) + langtool_ls_onsave.notify_did_change( + json_converter.unstructure(change_params) + ) + assert done.wait(30) + done.clear() + + save_params = DidSaveTextDocumentParams( + text_document=TextDocumentIdentifier( + 'dummy.txt' + ) + ) + langtool_ls_onsave.notify_did_save( + json_converter.unstructure(save_params) + ) + assert done.wait(30) + done.clear() + + res = results[-1]['diagnostics'][0]['range'] + assert res == json_converter.unstructure(exp) + + +def test_diagnostics_bug2(json_converter, langtool_ls_onsave): + text = ('\\documentclass[11pt]{article}\n' + + '\\begin{document}\n' + + 'o\n' + + '\\section{Thes}\n' + + '\n' + + 'This is a sentence.\n' + + '\n' + + '\\end{document}') + + done = Event() + results = list() + + langtool_ls_onsave.set_notification_callback( + session.PUBLISH_DIAGNOSTICS, + utils.get_notification_handler( + event=done, + results=results + ), + ) + + open_params = DidOpenTextDocumentParams( + TextDocumentItem( + uri='dummy.tex', + language_id='tex', + version=1, + text=text, + ) + ) + + langtool_ls_onsave.notify_did_open( + json_converter.unstructure(open_params) + ) + assert done.wait(30) + done.clear() + + change_params = DidChangeTextDocumentParams( + text_document=VersionedTextDocumentIdentifier( + version=1, + uri='dummy.tex', + ), + content_changes=[ + TextDocumentContentChangeEvent_Type1( + Range( + start=Position(line=2, character=0), + end=Position(line=3, character=0), + ), + '', + ) + ] + ) + langtool_ls_onsave.notify_did_change( + json_converter.unstructure(change_params) + ) + assert done.wait(30) + done.clear() + + save_params = DidSaveTextDocumentParams( + text_document=TextDocumentIdentifier( + 'dummy.tex' + ) + ) + langtool_ls_onsave.notify_did_save( + json_converter.unstructure(save_params) + ) + assert done.wait(30) + done.clear() + + change_params = DidChangeTextDocumentParams( + text_document=VersionedTextDocumentIdentifier( + version=2, + uri='dummy.tex', + ), + content_changes=[ + 
TextDocumentContentChangeEvent_Type1( + Range( + start=Position(line=1, character=16), + end=Position(line=2, character=0), + ), + '\no\n', + ) + ] + ) + langtool_ls_onsave.notify_did_change( + json_converter.unstructure(change_params) + ) + assert done.wait(30) + done.clear() + + save_params = DidSaveTextDocumentParams( + text_document=TextDocumentIdentifier( + 'dummy.tex' + ) + ) + langtool_ls_onsave.notify_did_save( + json_converter.unstructure(save_params) + ) + assert done.wait(30) + done.clear() + + exp_lst = [ + Range( + start=Position(line=2, character=0), + end=Position(line=2, character=1), + ), + Range( + start=Position(line=3, character=9), + end=Position(line=3, character=13), + ), + ] + res_lst = results[-1]['diagnostics'] + assert len(res_lst) == len(exp_lst) + for exp, res in zip(exp_lst, res_lst): + assert res['range'] == json_converter.unstructure(exp) + + +def test_diagnostics_bug3(json_converter, langtool_ls_onsave): + text = ('Thiiiis is paragraph one.\n' + '\n' + '\n' + '\n' + 'Sentence one. 
Sentence two.\n') + + done = Event() + results = list() + + langtool_ls_onsave.set_notification_callback( + session.PUBLISH_DIAGNOSTICS, + utils.get_notification_handler( + event=done, + results=results + ), + ) + + open_params = DidOpenTextDocumentParams( + TextDocumentItem( + uri='dummy.md', + language_id='md', + version=1, + text=text, + ) + ) + + langtool_ls_onsave.notify_did_open( + json_converter.unstructure(open_params) + ) + assert done.wait(30) + done.clear() + + change_params = DidChangeTextDocumentParams( + text_document=VersionedTextDocumentIdentifier( + version=1, + uri='dummy.md', + ), + content_changes=[ + TextDocumentContentChangeEvent_Type1( + range=Range( + start=Position(line=2, character=0), + end=Position(line=2, character=0) + ), + text='A' + ), + TextDocumentContentChangeEvent_Type1( + range=Range( + start=Position(line=2, character=1), + end=Position(line=2, character=1) + ), + text='s' + ), + TextDocumentContentChangeEvent_Type1( + range=Range( + start=Position(line=2, character=2), + end=Position(line=2, character=2) + ), + text='d' + ), + ] + ) + langtool_ls_onsave.notify_did_change( + json_converter.unstructure(change_params) + ) + assert not done.wait(10) + done.clear() + + save_params = DidSaveTextDocumentParams( + text_document=TextDocumentIdentifier( + 'dummy.md' + ) + ) + langtool_ls_onsave.notify_did_save( + json_converter.unstructure(save_params) + ) + assert done.wait(30) + done.clear() + + exp_lst = [ + Range( + start=Position(line=0, character=0), + end=Position(line=0, character=7), + ), + Range( + start=Position(line=2, character=0), + end=Position(line=2, character=3), + ), + ] + res_lst = results[-1]['diagnostics'] + assert len(res_lst) == len(exp_lst) + for exp, res in zip(exp_lst, res_lst): + assert res['range'] == json_converter.unstructure(exp) diff --git a/tests/analysers/languagetool_test.py b/tests/analysers/languagetool_test.py new file mode 100644 index 0000000..1525bcb --- /dev/null +++ 
b/tests/analysers/languagetool_test.py @@ -0,0 +1,180 @@ +from threading import Event +from lsprotocol.types import ( + DidOpenTextDocumentParams, + TextDocumentItem, + DidChangeTextDocumentParams, + VersionedTextDocumentIdentifier, + TextDocumentContentChangeEvent_Type1, + Range, + Position, + DidSaveTextDocumentParams, + TextDocumentIdentifier, +) + +from tests.lsp_test_client import session, utils + + +def test_bug1(json_converter, langtool_ls_onsave): + text = ('\\documentclass[11pt]{article}\n' + + '\\begin{document}\n' + + '\n' + + '\\section{Introduction}\n' + + '\n' + + 'This is a sentence.\n' + + '\n' + + '\\end{document}') + + done = Event() + results = list() + + langtool_ls_onsave.set_notification_callback( + session.WINDOW_SHOW_MESSAGE, + utils.get_notification_handler( + event=done, + results=results + ), + ) + + open_params = DidOpenTextDocumentParams( + TextDocumentItem( + uri='dummy.tex', + language_id='tex', + version=1, + text=text, + ) + ) + + langtool_ls_onsave.notify_did_open( + json_converter.unstructure(open_params) + ) + + change_params = DidChangeTextDocumentParams( + text_document=VersionedTextDocumentIdentifier( + version=1, + uri='dummy.tex', + ), + content_changes=[ + TextDocumentContentChangeEvent_Type1( + Range( + start=Position(line=5, character=19), + end=Position(line=6, character=0), + ), + '\nThis is a sentence.\n', + ) + ] + ) + langtool_ls_onsave.notify_did_change( + json_converter.unstructure(change_params) + ) + + change_params = DidChangeTextDocumentParams( + text_document=VersionedTextDocumentIdentifier( + version=2, + uri='dummy.tex', + ), + content_changes=[ + TextDocumentContentChangeEvent_Type1( + Range( + start=Position(line=6, character=19), + end=Position(line=7, character=0), + ), + '\nThis is a sentence.\n', + ) + ] + ) + langtool_ls_onsave.notify_did_change( + json_converter.unstructure(change_params) + ) + + save_params = DidSaveTextDocumentParams( + text_document=TextDocumentIdentifier( + 'dummy.tex' + ) + ) + 
langtool_ls_onsave.notify_did_save( + json_converter.unstructure(save_params) + ) + assert not done.wait(20) + done.clear() + + +def test_bug2(json_converter, langtool_ls_onsave): + text = ( + 'This is a sentence.\n' + + 'This is a sentence.\n' + + 'This is a sentence.\n' + ) + + done = Event() + results = list() + + langtool_ls_onsave.set_notification_callback( + session.WINDOW_SHOW_MESSAGE, + utils.get_notification_handler( + event=done, + results=results + ), + ) + + open_params = DidOpenTextDocumentParams( + TextDocumentItem( + uri='dummy.txt', + language_id='txt', + version=1, + text=text, + ) + ) + + langtool_ls_onsave.notify_did_open( + json_converter.unstructure(open_params) + ) + + for i, edit_range in enumerate([ + # Last two sentences deleted as done by nvim + Range( + start=Position(line=0, character=19), + end=Position(line=0, character=19), + ), + Range( + start=Position(line=1, character=0), + end=Position(line=2, character=0), + ), + Range( + start=Position(line=1, character=0), + end=Position(line=1, character=19), + ), + Range( + start=Position(line=0, character=19), + end=Position(line=0, character=19), + ), + Range( + start=Position(line=1, character=0), + end=Position(line=2, character=0), + ), + ], 1): + change_params = DidChangeTextDocumentParams( + text_document=VersionedTextDocumentIdentifier( + version=i, + uri='dummy.txt', + ), + content_changes=[ + TextDocumentContentChangeEvent_Type1( + edit_range, + '', + ) + ] + ) + langtool_ls_onsave.notify_did_change( + json_converter.unstructure(change_params) + ) + + save_params = DidSaveTextDocumentParams( + text_document=TextDocumentIdentifier( + 'dummy.txt' + ) + ) + langtool_ls_onsave.notify_did_save( + json_converter.unstructure(save_params) + ) + assert not done.wait(20) + done.clear() diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 0000000..834100a --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,66 @@ +import pytest +import copy + +from pygls.protocol import 
default_converter + +from tests.lsp_test_client import session, defaults + + +@pytest.fixture +def json_converter(): + return default_converter() + + +@pytest.fixture +def simple_server(): + with session.LspSession() as lsp_session: + lsp_session.initialize() + yield lsp_session + + +@pytest.fixture +def langtool_ls(): + init_params = copy.deepcopy(defaults.VSCODE_DEFAULT_INITIALIZE) + init_params["initializationOptions"] = { + 'textLSP': { + 'analysers': { + 'languagetool': { + 'enabled': True, + 'check_text': { + 'on_open': True, + 'on_save': True, + 'on_change': True, + } + } + } + } + } + + with session.LspSession() as lsp_session: + lsp_session.initialize(init_params) + + yield lsp_session + + +@pytest.fixture +def langtool_ls_onsave(): + init_params = copy.deepcopy(defaults.VSCODE_DEFAULT_INITIALIZE) + init_params["initializationOptions"] = { + 'textLSP': { + 'analysers': { + 'languagetool': { + 'enabled': True, + 'check_text': { + 'on_open': True, + 'on_save': True, + 'on_change': False, + } + } + } + } + } + + with session.LspSession() as lsp_session: + lsp_session.initialize(init_params) + + yield lsp_session diff --git a/tests/documents/document_test.py b/tests/documents/document_test.py index 927d70b..8ba05db 100644 --- a/tests/documents/document_test.py +++ b/tests/documents/document_test.py @@ -350,7 +350,7 @@ def test_get_sentence_at_offset(content, offset, length, exp): ), ], [ - Interval(169, 1), + Interval(168, 1), ], ), ( @@ -409,12 +409,124 @@ def test_get_sentence_at_offset(content, offset, length, exp): Interval(171, 1), ], ), + ( + 'This is a sentence.\n', + [ + TextDocumentContentChangeEvent_Type1( + range=Range( + start=Position( + line=0, + character=19, + ), + end=Position( + line=1, + character=0, + ), + ), + text='\nThis is a sentence.\n', + ), + TextDocumentContentChangeEvent_Type1( + range=Range( + start=Position( + line=1, + character=19, + ), + end=Position( + line=2, + character=0, + ), + ), + text='\nThis is a sentence.\n', + ), + 
], + [ + Interval(19, 20), + Interval(39, 21), + ], + ), + ( + 'This is a sentence.\n' + 'This is a sentence.\n' + 'This is a sentence.\n', + [ + # Last two sentences deleted as done by nvim + TextDocumentContentChangeEvent_Type1( + range=Range( + start=Position( + line=0, + character=19, + ), + end=Position( + line=0, + character=19, + ), + ), + text='', + ), + TextDocumentContentChangeEvent_Type1( + range=Range( + start=Position( + line=1, + character=0, + ), + end=Position( + line=2, + character=0, + ), + ), + text='', + ), + TextDocumentContentChangeEvent_Type1( + range=Range( + start=Position( + line=1, + character=0, + ), + end=Position( + line=1, + character=19, + ), + ), + text='', + ), + TextDocumentContentChangeEvent_Type1( + range=Range( + start=Position( + line=0, + character=19, + ), + end=Position( + line=0, + character=19, + ), + ), + text='', + ), + TextDocumentContentChangeEvent_Type1( + range=Range( + start=Position( + line=1, + character=0, + ), + end=Position( + line=2, + character=0, + ), + ), + text='', + ), + ], + [ + Interval(18, 1), + ], + ), ]) def test_updates(content, edits, exp): doc = BaseDocument('DUMMY_URL', content) tracker = ChangeTracker(doc, True) for edit in edits: - tracker.update_document(edit) + doc.apply_change(edit) + tracker.update_document(edit, doc) assert tracker.get_changes() == exp diff --git a/tests/documents/latex_test.py b/tests/documents/latex_test.py index bcd80c2..8c688ac 100644 --- a/tests/documents/latex_test.py +++ b/tests/documents/latex_test.py @@ -1,4 +1,6 @@ import pytest +import time +import logging from lsprotocol.types import ( Position, @@ -570,12 +572,503 @@ def test_get_paragraphs_at_range(content, range, exp): Interval(35, 1), ], ), + ( + '\\documentclass[11pt]{article}\n' + '\\begin{document}\n' + '\n' + '\\section{Introduction}\n' + '\n' + 'This is a sentence.\n' + '\n', + [ + TextDocumentContentChangeEvent_Type1( + range=Range( + start=Position( + line=6, + character=0, + ), + end=Position( + 
line=7, + character=0, + ), + ), + text='\n\\end{document}\n', + ), + ], + [ + Interval(33, 16), + ], + ), ]) -def test_updates(content, edits, exp): +def test_change_tracker(content, edits, exp): doc = LatexDocument('DUMMY_URL', content) tracker = ChangeTracker(doc, True) for edit in edits: - tracker.update_document(edit) + doc.apply_change(edit) + tracker.update_document(edit, doc) assert tracker.get_changes() == exp + + +@pytest.mark.parametrize('content,changes,exp,offset_test,position_test', [ + ( + '\\documentclass[11pt]{article}\n' + '\\begin{document}\n' + '\n' + '\\section{Introduction}\n' + '\n' + + 'This is a sentence.\n'*2 + + '\n' + '\\end{document}', + [ + TextDocumentContentChangeEvent_Type1( + # add 'o' to Introduction + range=Range( + start=Position( + line=3, + character=13, + ), + end=Position( + line=3, + character=13, + ), + ), + text='o', + ), + ], + 'Introoduction\n' + '\n' + + ' '.join(['This is a sentence.']*2) + + '\n', + None, + None, + ), + ( + '\\documentclass[11pt]{article}\n' + '\\begin{document}\n' + '\n' + '\\section{Introduction}\n' + '\n' + + 'This is a sentence.\n'*2 + + '\n' + '\\end{document}', + [ + TextDocumentContentChangeEvent_Type1( + # delete 'o' from Introduction + range=Range( + start=Position( + line=3, + character=13, + ), + end=Position( + line=3, + character=14, + ), + ), + text='', + ), + ], + 'Intrduction\n' + '\n' + + ' '.join(['This is a sentence.']*2) + + '\n', + None, + None, + ), + ( + '\\documentclass[11pt]{article}\n' + '\\begin{document}\n' + '\n' + '\\section{Introduction}\n' + '\n' + 'An initial sentence.\n' + '\n' + '\\section{Introduction}\n' + '\n' + + 'This is a sentence.\n'*2 + + '\n' + '\\section{Conclusions}\n' + '\n' + 'A final sentence.\n' + '\\end{document}', + [ + TextDocumentContentChangeEvent_Type1( + # replace the word initial + range=Range( + start=Position( + line=5, + character=3, + ), + end=Position( + line=5, + character=10, + ), + ), + text='\n\naaaaaaa', + ), + ], + 'Introduction\n' 
+ '\n' + 'An\n' + '\n' + 'aaaaaaa sentence.\n' + '\n' + 'Introduction\n' + '\n' + + ' '.join(['This is a sentence.']*2) + + '\n\n' + 'Conclusions\n' + '\n' + 'A final sentence.\n', + ( + -16, + 'final', + Range( + start=Position(16, 2), + end=Position(16, 6), + ), + ), + ( + Position( + line=16, + character=2, + ), + 'final', + ), + ), + ( + '\\documentclass[11pt]{article}\n' + '\\begin{document}\n' + '\n' + '\\section{Introduction}\n' + '\n' + 'This is a sentence. \\section{Inline} FooBar\n' + '\n' + '\\end{document}', + [ + TextDocumentContentChangeEvent_Type1( + range=Range( + start=Position( + line=5, + character=2, + ), + end=Position( + line=5, + character=2, + ), + ), + text='oooooo', + ), + ], + 'Introduction\n' + '\n' + 'Thoooooois is a sentence.\n' + '\n' + 'Inline\n' + '\n' + 'FooBar\n', + None, + ( + Position( + line=5, + character=43, + ), + 'FooBar', + ), + ), + ( + '\\documentclass[11pt]{article}\n' + '\\begin{document}\n' + '\n' + '\\section{Introduction}\n' + '\n' + 'This is a sentence.\n' + '\n' + '\\end{document}', + [ + TextDocumentContentChangeEvent_Type1( + range=Range( + start=Position( + line=6, + character=0, + ), + end=Position( + line=6, + character=0, + ), + ), + text='o', + ), + ], + 'Introduction\n' + '\n' + + 'This is a sentence. 
o\n', + None, + None, + ), + ( + '\\documentclass[11pt]{article}\n' + '\\begin{document}\n' + '\n' + '\n' + '\n' + '\n' + '\\section{Introduction}\n' + '\n' + 'This is a sentence.\n' + '\n' + '\\end{document}', + [ + TextDocumentContentChangeEvent_Type1( + range=Range( + start=Position( + line=2, + character=0, + ), + end=Position( + line=2, + character=0, + ), + ), + text='o', + ), + ], + 'o\n' + '\n' + 'Introduction\n' + '\n' + + 'This is a sentence.\n', + None, + None, + ), + ( + '\\documentclass[11pt]{article}\n' + '\\begin{document}\n' + 'o\n' + '\\section{Introduction}\n' + '\n' + 'This is a sentence.\n' + '\n' + '\\end{document}', + [ + TextDocumentContentChangeEvent_Type1( + range=Range( + start=Position( + line=2, + character=0, + ), + end=Position( + line=3, + character=0, + ), + ), + text='', + ), + ], + 'Introduction\n' + '\n' + + 'This is a sentence.\n', + ( + 0, + 'Introduction', + Range( + start=Position(2, 9), + end=Position(2, 20), + ), + ), + ( + Position( + line=2, + character=9, + ), + 'Introduction', + ), + ), + ( + '\\documentclass[11pt]{article}\n' + '\\begin{document}\n' + '\n' + '\\section{Introduction}\n' + '\n' + 'This is a sentence.\n' + '\n' + '\\end{document}', + [ + TextDocumentContentChangeEvent_Type1( + # delete last character: '.' + range=Range( + start=Position( + line=5, + character=18, + ), + end=Position( + line=5, + character=19, + ), + ), + text='', + ), + ], + 'Introduction\n' + '\n' + + 'This is a sentence\n', + None, + None, + ), + ( + '\\documentclass[11pt]{article}\n' + '\\begin{document}\n' + '\n' + '\\section{Introduction}\n' + '\n' + 'This is a sentence.\n' + '\n' + '\\end{document}\n' + '\n', + [ + TextDocumentContentChangeEvent_Type1( + # delete last character: '.' 
+ range=Range( + start=Position( + line=8, + character=0, + ), + end=Position( + line=9, + character=0, + ), + ), + text='', + ), + ], + 'Introduction\n' + '\n' + + 'This is a sentence.\n', + None, + None, + ), + ( + '\\documentclass[11pt]{article}\n' + '\\begin{document}\n' + '\n' + '\\section{Introduction}\n' + '\n' + 'This is a sentence.\n' + '\n', + [ + TextDocumentContentChangeEvent_Type1( + range=Range( + start=Position( + line=6, + character=0, + ), + end=Position( + line=7, + character=0, + ), + ), + text='\n\\end{document}\n', + ), + ], + 'Introduction\n' + '\n' + + 'This is a sentence.\n', + None, + None, + ), + ( + '\\documentclass[11pt]{article}\n' + '\\begin{document}\n' + '\\section{Introduction}\n' + '\n' + 'This is a sentence.\n' + '\n' + '\\end{document}', + [ + TextDocumentContentChangeEvent_Type1( + range=Range( + start=Position( + line=1, + character=16, + ), + end=Position( + line=2, + character=0, + ), + ), + text='\no\n', + ), + ], + 'o\n' + '\n' + 'Introduction\n' + '\n' + + 'This is a sentence.\n', + None, + None, + ), + ( + '\\documentclass[11pt]{article}\n' + '\\begin{document}\n' + 'A sentence.\n' + 'Introduction\n' + 'This is a sentence.\n' + '\n' + '\\end{document}', + [ + TextDocumentContentChangeEvent_Type1( + range=Range( + start=Position( + line=3, + character=0, + ), + end=Position( + line=3, + character=0, + ), + ), + text='\\section{', + ), + TextDocumentContentChangeEvent_Type1( + range=Range( + start=Position( + line=3, + character=21, + ), + end=Position( + line=3, + character=21, + ), + ), + text='}', + ), + ], + 'A sentence.\n' + '\n' + 'Introduction\n' + '\n' + + 'This is a sentence.\n', + None, + None, + ), +]) +def test_edits(content, changes, exp, offset_test, position_test): + doc = LatexDocument('DUMMY_URL', content) + doc.cleaned_source + start = time.time() + for change in changes: + doc.apply_change(change) + assert doc.cleaned_source == exp + logging.warning(time.time() - start) + + if offset_test is not None: + 
offset = offset_test[0] + if offset < 0: + offset = len(exp) + offset + assert doc.text_at_offset(offset, len(offset_test[1]), True) == offset_test[1] + if len(offset_test) > 2: + assert doc.range_at_offset(offset, len(offset_test[1]), True) == offset_test[2] + if position_test is not None: + offset = doc.offset_at_position(position_test[0], True) + assert doc.text_at_offset(offset, len(position_test[1]), True) == position_test[1] diff --git a/tests/documents/markdown_test.py b/tests/documents/markdown_test.py index 1b9d86a..1dd8216 100644 --- a/tests/documents/markdown_test.py +++ b/tests/documents/markdown_test.py @@ -1,6 +1,11 @@ import pytest from textLSP.documents.markdown.markdown import MarkDownDocument +from lsprotocol.types import ( + Position, + Range, + TextDocumentContentChangeEvent_Type1 +) @pytest.mark.parametrize('src,clean', [ @@ -103,3 +108,501 @@ def test_highlight(src, offset, exp): res += lines[pos_range.end.line][:pos_range.end.character+1] assert res == exp + + +@pytest.mark.parametrize('content,changes,exp,offset_test,position_test', [ + ( + 'This is a sentence.', + [ + TextDocumentContentChangeEvent_Type1( + range=Range( + start=Position( + line=0, + character=0, + ), + end=Position( + line=0, + character=4, + ), + ), + text='That', + ), + ], + 'That is a sentence.\n', + None, + None, + ), + ( + # Based on a bug, as done by in nvim + 'This is a sentence. 
This is another.\n' + '\n' + 'This is a new paragraph.\n', + [ + TextDocumentContentChangeEvent_Type1( + range=Range( + start=Position(line=0, character=19), + end=Position(line=0, character=36) + ), + text='', + ), + TextDocumentContentChangeEvent_Type1( + range=Range( + start=Position(line=1, character=0), + end=Position(line=2, character=0) + ), + text='', + ), + TextDocumentContentChangeEvent_Type1( + range=Range( + start=Position(line=1, character=0), + end=Position(line=1, character=24) + ), + text='', + ), + TextDocumentContentChangeEvent_Type1( + range=Range( + start=Position(line=0, character=19), + end=Position(line=0, character=19) + ), + text='', + ), + TextDocumentContentChangeEvent_Type1( + range=Range( + start=Position(line=1, character=0), + end=Position(line=2, character=0) + ), + text='', + ), + TextDocumentContentChangeEvent_Type1( + range=Range( + start=Position(line=0, character=19), + end=Position(line=1, character=0) + ), + text='\n\n', + ), + TextDocumentContentChangeEvent_Type1( + range=Range( + start=Position(line=1, character=0), + end=Position(line=1, character=0) + ), + text='\n', + ), + TextDocumentContentChangeEvent_Type1( + range=Range( + start=Position(line=2, character=0), + end=Position(line=2, character=0) + ), + text='A', + ), + TextDocumentContentChangeEvent_Type1( + range=Range( + start=Position(line=2, character=1), + end=Position(line=2, character=1) + ), + text='s', + ), + TextDocumentContentChangeEvent_Type1( + range=Range( + start=Position(line=2, character=2), + end=Position(line=2, character=2) + ), + text='d', + ), + TextDocumentContentChangeEvent_Type1( + range=Range( + start=Position(line=2, character=3), + end=Position(line=2, character=3) + ), + text='f', + ), + ], + 'This is a sentence.\n' + '\n' + 'Asdf\n', + None, + None, + ), + ( + # Based on a bug in nvim + 'This is paragraph one.\n' + '\n' + 'Sentence one. 
Sentence two.\n' + '\n' + 'Sentence three.\n' + '\n' + '# Header\n', + [ + TextDocumentContentChangeEvent_Type1( + range=Range( + start=Position(line=2, character=13), + end=Position(line=2, character=27), + ), + text='' + ), + TextDocumentContentChangeEvent_Type1( + range=Range( + start=Position(line=3, character=0), + end=Position(line=4, character=0), + ), + text='' + ), + TextDocumentContentChangeEvent_Type1( + range=Range( + start=Position(line=3, character=0), + end=Position(line=3, character=15), + ), + text='' + ), + TextDocumentContentChangeEvent_Type1( + range=Range( + start=Position(line=2, character=13), + end=Position(line=2, character=13), + ), + text='' + ), + TextDocumentContentChangeEvent_Type1( + range=Range( + start=Position(line=3, character=0), + end=Position(line=4, character=0), + ), + text='' + ), + TextDocumentContentChangeEvent_Type1( + range=Range( + start=Position(line=2, character=13), + end=Position(line=2, character=13), + ), + text=' Sentence two.\n\nSentence three.' + ), + ], + 'This is paragraph one.\n' + '\n' + 'Sentence one. Sentence two.\n' + '\n' + 'Sentence three.\n' + '\n' + 'Header\n', + None, + None, + ), + ( + 'This is paragraph one.\n' + '\n' + '\n' + 'Sentence one. Sentence two.\n', + [ + TextDocumentContentChangeEvent_Type1( + range=Range( + start=Position(line=1, character=0), + end=Position(line=2, character=0), + ), + text='\n\n', + ), + ], + 'This is paragraph one.\n' + '\n' + 'Sentence one. Sentence two.\n', + None, + None, + ), + ( + 'This is paragraph one.\n' + '\n' + '\n' + '\n' + 'Sentence one. 
Sentence two.\n', + [ + TextDocumentContentChangeEvent_Type1( + range=Range( + start=Position(line=2, character=0), + end=Position(line=2, character=0) + ), + text='A' + ), + TextDocumentContentChangeEvent_Type1( + range=Range( + start=Position(line=2, character=1), + end=Position(line=2, character=1) + ), + text='s' + ), + TextDocumentContentChangeEvent_Type1( + range=Range( + start=Position(line=2, character=2), + end=Position(line=2, character=2) + ), + text='d' + ), + ], + 'This is paragraph one.\n' + '\n' + 'Asd\n' + '\n' + 'Sentence one. Sentence two.\n', + None, + None, + ), + ( + 'This is paragraph one.\n' + '\n' + 'Sentence one. Sentence two.\n', + [ + TextDocumentContentChangeEvent_Type1( + range=Range( + start=Position(line=0, character=22), + end=Position(line=0, character=22) + ), + text=' ' + ), + TextDocumentContentChangeEvent_Type1( + range=Range( + start=Position(line=0, character=22), + end=Position(line=0, character=23) + ), + text='\n' + ), + TextDocumentContentChangeEvent_Type1( + range=Range( + start=Position(line=1, character=0), + end=Position(line=1, character=0) + ), + text='A' + ), + ], + 'This is paragraph one. A\n' + '\n' + 'Sentence one. 
Sentence two.\n', + None, + None, + ), + ( + 'This is a sentence.\n' + '\n' + 'Header\n' + 'This is a sentence.\n', + [ + TextDocumentContentChangeEvent_Type1( + range=Range( + start=Position(line=2, character=0), + end=Position(line=2, character=0) + ), + text='#' + ), + TextDocumentContentChangeEvent_Type1( + range=Range( + start=Position(line=2, character=1), + end=Position(line=2, character=1) + ), + text=' ' + ), + ], + 'This is a sentence.\n' + '\n' + 'Header\n' + '\n' + 'This is a sentence.\n', + None, + None, + ), + ( + 'Header\n' + 'This is a sentence.\n', + [ + TextDocumentContentChangeEvent_Type1( + range=Range( + start=Position(line=0, character=0), + end=Position(line=0, character=0) + ), + text='#' + ), + TextDocumentContentChangeEvent_Type1( + range=Range( + start=Position(line=0, character=1), + end=Position(line=0, character=1) + ), + text=' ' + ), + ], + 'Header\n' + '\n' + 'This is a sentence.\n', + None, + None, + ), + ( + 'This is a sentence.\n' + '\n' + '# Header\n' + 'This is a sentence.\n', + [ + TextDocumentContentChangeEvent_Type1( + range=Range( + start=Position(line=2, character=1), + end=Position(line=2, character=2) + ), + text='' + ), + TextDocumentContentChangeEvent_Type1( + range=Range( + start=Position(line=2, character=0), + end=Position(line=2, character=1) + ), + text='' + ), + ], + 'This is a sentence.\n' + '\n' + 'Header This is a sentence.\n', + None, + None, + ), + ( + 'This is a sentence.\n', + [ + TextDocumentContentChangeEvent_Type1( + range=Range( + start=Position(line=0, character=0), + end=Position(line=1, character=0) + ), + text='' + ), + TextDocumentContentChangeEvent_Type1( + range=Range( + start=Position(line=0, character=0), + end=Position(line=0, character=0) + ), + text='This is a sentence.' 
+ ), + ], + 'This is a sentence.\n', + None, + None, + ), + ( + '* This is point one.\n' + '* This is point two.\n', + [ + TextDocumentContentChangeEvent_Type1( + range=Range( + start=Position(line=0, character=0), + end=Position(line=0, character=0) + ), + text='* This is point one.\n' + ), + ], + 'This is point one.\n' + '\n' + 'This is point one.\n' + '\n' + 'This is point two.\n', + None, + None, + ), + ( + 'This is a sentence.\n' + 'A\n', + [ + TextDocumentContentChangeEvent_Type1( + range=Range( + start=Position(line=1, character=1), + end=Position(line=1, character=1) + ), + text='B' + ), + TextDocumentContentChangeEvent_Type1( + range=Range( + start=Position(line=1, character=2), + end=Position(line=1, character=2) + ), + text=' ' + ), + ], + 'This is a sentence. AB\n', + None, + None, + ), + ( + 'This is a sentence.\n' + 'A\n', + [ + TextDocumentContentChangeEvent_Type1( + range=Range( + start=Position(line=1, character=1), + end=Position(line=1, character=1) + ), + text=' ' + ), + TextDocumentContentChangeEvent_Type1( + range=Range( + start=Position(line=1, character=2), + end=Position(line=1, character=2) + ), + text=' ' + ), + ], + 'This is a sentence. 
A\n', + None, + None, + ), + ( + 'This is a sentence.\n' + '\n' + ' This will be an unparsed part.\n' + '\n' + 'This is a sentence.\n', + [ + TextDocumentContentChangeEvent_Type1( + range=Range( + start=Position(line=2, character=0), + end=Position(line=2, character=0) + ), + text=' ' + ), + ], + 'This is a sentence.\n' + '\n' + 'This is a sentence.\n', + None, + None, + ), + ( + 'This is a sentence.\n' + '\n' + ' This will be a parsed part.\n' + '\n' + 'This is a sentence.\n', + [ + TextDocumentContentChangeEvent_Type1( + range=Range( + start=Position(line=2, character=0), + end=Position(line=2, character=1) + ), + text='' + ), + ], + 'This is a sentence.\n' + '\n' + 'This will be a parsed part.\n' + '\n' + 'This is a sentence.\n', + None, + None, + ), +]) +def test_edits(content, changes, exp, offset_test, position_test): + doc = MarkDownDocument('DUMMY_URL', content) + doc.cleaned_source + for change in changes: + doc.apply_change(change) + assert doc.cleaned_source == exp + + if offset_test is not None: + offset = offset_test[0] + if offset < 0: + offset = len(exp) + offset + assert doc.text_at_offset(offset, len(offset_test[1]), True) == offset_test[1] + if len(offset_test) > 2: + assert doc.range_at_offset(offset, len(offset_test[1]), True) == offset_test[2] + if position_test is not None: + offset = doc.offset_at_position(position_test[0], True) + assert doc.text_at_offset(offset, len(position_test[1]), True) == position_test[1] diff --git a/tests/lsp_test_client/__init__.py b/tests/lsp_test_client/__init__.py new file mode 100644 index 0000000..3f3eae1 --- /dev/null +++ b/tests/lsp_test_client/__init__.py @@ -0,0 +1,10 @@ +# Taken from: https://github.com/pappasam/jedi-language-server +"""Test client main module.""" + +import py + +from .utils import as_uri + +TEST_ROOT = py.path.local(__file__) / ".." +PROJECT_ROOT = TEST_ROOT / ".." / ".." 
+PROJECT_URI = as_uri(PROJECT_ROOT) diff --git a/tests/lsp_test_client/defaults.py b/tests/lsp_test_client/defaults.py new file mode 100644 index 0000000..529e52a --- /dev/null +++ b/tests/lsp_test_client/defaults.py @@ -0,0 +1,214 @@ +"""Default values for lsp test client.""" +import os + +import tests.lsp_test_client as lsp_client + +VSCODE_DEFAULT_INITIALIZE = { + "processId": os.getpid(), # pylint: disable=no-member + "clientInfo": {"name": "vscode", "version": "1.45.0"}, + "rootPath": str(lsp_client.PROJECT_ROOT), + "rootUri": lsp_client.PROJECT_URI, + "capabilities": { + "workspace": { + "applyEdit": True, + "workspaceEdit": { + "documentChanges": True, + "resourceOperations": ["create", "rename", "delete"], + "failureHandling": "textOnlyTransactional", + }, + "didChangeConfiguration": {"dynamicRegistration": True}, + "didChangeWatchedFiles": {"dynamicRegistration": True}, + "symbol": { + "dynamicRegistration": True, + "symbolKind": { + "valueSet": [ + 1, + 2, + 3, + 4, + 5, + 6, + 7, + 8, + 9, + 10, + 11, + 12, + 13, + 14, + 15, + 16, + 17, + 18, + 19, + 20, + 21, + 22, + 23, + 24, + 25, + 26, + ] + }, + "tagSupport": {"valueSet": [1]}, + }, + "executeCommand": {"dynamicRegistration": True}, + "configuration": True, + "workspaceFolders": True, + }, + "textDocument": { + "publishDiagnostics": { + "relatedInformation": True, + "versionSupport": False, + "tagSupport": {"valueSet": [1, 2]}, + "complexDiagnosticCodeSupport": True, + }, + "synchronization": { + "dynamicRegistration": True, + "willSave": True, + "willSaveWaitUntil": True, + "didSave": True, + }, + "completion": { + "dynamicRegistration": True, + "contextSupport": True, + "completionItem": { + "snippetSupport": True, + "commitCharactersSupport": True, + "documentationFormat": ["markdown", "plaintext"], + "deprecatedSupport": True, + "preselectSupport": True, + "tagSupport": {"valueSet": [1]}, + "insertReplaceSupport": True, + }, + "completionItemKind": { + "valueSet": [ + 1, + 2, + 3, + 4, + 5, + 6, 
+ 7, + 8, + 9, + 10, + 11, + 12, + 13, + 14, + 15, + 16, + 17, + 18, + 19, + 20, + 21, + 22, + 23, + 24, + 25, + ] + }, + }, + "hover": { + "dynamicRegistration": True, + "contentFormat": ["markdown", "plaintext"], + }, + "signatureHelp": { + "dynamicRegistration": True, + "signatureInformation": { + "documentationFormat": ["markdown", "plaintext"], + "parameterInformation": {"labelOffsetSupport": True}, + }, + "contextSupport": True, + }, + "definition": {"dynamicRegistration": True, "linkSupport": True}, + "references": {"dynamicRegistration": True}, + "documentHighlight": {"dynamicRegistration": True}, + "documentSymbol": { + "dynamicRegistration": True, + "symbolKind": { + "valueSet": [ + 1, + 2, + 3, + 4, + 5, + 6, + 7, + 8, + 9, + 10, + 11, + 12, + 13, + 14, + 15, + 16, + 17, + 18, + 19, + 20, + 21, + 22, + 23, + 24, + 25, + 26, + ] + }, + "hierarchicalDocumentSymbolSupport": True, + "tagSupport": {"valueSet": [1]}, + }, + "codeAction": { + "dynamicRegistration": True, + "isPreferredSupport": True, + "codeActionLiteralSupport": { + "codeActionKind": { + "valueSet": [ + "", + "quickfix", + "refactor", + "refactor.extract", + "refactor.inline", + "refactor.rewrite", + "source", + "source.organizeImports", + ] + } + }, + }, + "codeLens": {"dynamicRegistration": True}, + "formatting": {"dynamicRegistration": True}, + "rangeFormatting": {"dynamicRegistration": True}, + "onTypeFormatting": {"dynamicRegistration": True}, + "rename": {"dynamicRegistration": True, "prepareSupport": True}, + "documentLink": { + "dynamicRegistration": True, + "tooltipSupport": True, + }, + "typeDefinition": { + "dynamicRegistration": True, + "linkSupport": True, + }, + "implementation": { + "dynamicRegistration": True, + "linkSupport": True, + }, + "colorProvider": {"dynamicRegistration": True}, + "foldingRange": { + "dynamicRegistration": True, + "rangeLimit": 5000, + "lineFoldingOnly": True, + }, + "declaration": {"dynamicRegistration": True, "linkSupport": True}, + "selectionRange": 
{"dynamicRegistration": True}, + }, + "window": {"workDoneProgress": True}, + }, + "trace": "verbose", + "workspaceFolders": [{"uri": lsp_client.PROJECT_URI, "name": "textLSP"}], + "initializationOptions": { + }, +} diff --git a/tests/lsp_test_client/lsp_run.py b/tests/lsp_test_client/lsp_run.py new file mode 100644 index 0000000..d3c1b91 --- /dev/null +++ b/tests/lsp_test_client/lsp_run.py @@ -0,0 +1,7 @@ +"""Run Language Server for Test.""" + +import sys + +from textLSP.cli import main + +sys.exit(main()) diff --git a/tests/lsp_test_client/session.py b/tests/lsp_test_client/session.py new file mode 100644 index 0000000..2d09421 --- /dev/null +++ b/tests/lsp_test_client/session.py @@ -0,0 +1,279 @@ +"""Provides LSP session helpers for testing.""" + +import os +import subprocess +import sys +from concurrent.futures import Future, ThreadPoolExecutor +from threading import Event + +from pylsp_jsonrpc.dispatchers import MethodDispatcher +from pylsp_jsonrpc.endpoint import Endpoint +from pylsp_jsonrpc.streams import JsonRpcStreamReader, JsonRpcStreamWriter + +from tests.lsp_test_client import defaults + +LSP_EXIT_TIMEOUT = 5000 + + +PUBLISH_DIAGNOSTICS = "textDocument/publishDiagnostics" +WINDOW_LOG_MESSAGE = "window/logMessage" +WINDOW_SHOW_MESSAGE = "window/showMessage" +WINDOW_WORK_DONE_PROGRESS_CREATE = "window/workDoneProgress/create" + +# pylint: disable=no-member + + +class LspSession(MethodDispatcher): + """Send and Receive messages over LSP as a test LS Client.""" + + def __init__(self, cwd=None): + self.cwd = cwd if cwd else os.getcwd() + # pylint: disable=consider-using-with + self._thread_pool = ThreadPoolExecutor() + self._sub = None + self._writer = None + self._reader = None + self._endpoint = None + self._notification_callbacks = {} + + def __enter__(self): + """Context manager entrypoint. + + shell=True needed for pytest-cov to work in subprocess. 
+ """ + # pylint: disable=consider-using-with + self._sub = subprocess.Popen( + [ + sys.executable, + os.path.join(os.path.dirname(__file__), "lsp_run.py"), + ], + stdout=subprocess.PIPE, + stdin=subprocess.PIPE, + bufsize=0, + cwd=self.cwd, + env=os.environ, + shell="WITH_COVERAGE" in os.environ, + ) + + self._writer = JsonRpcStreamWriter( + os.fdopen(self._sub.stdin.fileno(), "wb") + ) + self._reader = JsonRpcStreamReader( + os.fdopen(self._sub.stdout.fileno(), "rb") + ) + + dispatcher = { + PUBLISH_DIAGNOSTICS: self._publish_diagnostics, + WINDOW_SHOW_MESSAGE: self._window_show_message, + WINDOW_LOG_MESSAGE: self._window_log_message, + WINDOW_WORK_DONE_PROGRESS_CREATE: self._window_work_done_progress_create, + } + self._endpoint = Endpoint(dispatcher, self._writer.write) + self._thread_pool.submit(self._reader.listen, self._endpoint.consume) + return self + + def __exit__(self, typ, value, _tb): + self.shutdown(True) + try: + self._sub.terminate() + except Exception: # pylint:disable=broad-except + pass + self._endpoint.shutdown() + self._thread_pool.shutdown() + + def initialize( + self, + initialize_params=None, + process_server_capabilities=None, + ): + """Sends the initialize request to LSP server.""" + server_initialized = Event() + + def _after_initialize(fut): + if process_server_capabilities: + process_server_capabilities(fut.result()) + self.initialized() + server_initialized.set() + + self._send_request( + "initialize", + params=( + initialize_params + if initialize_params is not None + else defaults.VSCODE_DEFAULT_INITIALIZE + ), + handle_response=_after_initialize, + ) + + server_initialized.wait() + + def initialized(self, initialized_params=None): + """Sends the initialized notification to LSP server.""" + if initialized_params is None: + initialized_params = {} + self._endpoint.notify("initialized", initialized_params) + + def shutdown(self, should_exit, exit_timeout=LSP_EXIT_TIMEOUT): + """Sends the shutdown request to LSP server.""" + + def 
_after_shutdown(_): + if should_exit: + self.exit_lsp(exit_timeout) + + self._send_request("shutdown", handle_response=_after_shutdown) + + def exit_lsp(self, exit_timeout=LSP_EXIT_TIMEOUT): + """Handles LSP server process exit.""" + self._endpoint.notify("exit") + assert self._sub.wait(exit_timeout) == 0 + + def text_document_completion(self, completion_params): + """Sends text document completion request to LSP server.""" + fut = self._send_request( + "textDocument/completion", params=completion_params + ) + return fut.result() + + def text_document_rename(self, rename_params): + """Sends text document rename request to LSP server.""" + fut = self._send_request("textDocument/rename", params=rename_params) + return fut.result() + + def text_document_code_action(self, code_action_params): + """Sends text document code action request to LSP server.""" + fut = self._send_request( + "textDocument/codeAction", params=code_action_params + ) + return fut.result() + + def text_document_hover(self, hover_params): + """Sends text document hover request to LSP server.""" + fut = self._send_request("textDocument/hover", params=hover_params) + return fut.result() + + def text_document_signature_help(self, signature_help_params): + """Sends text document hover request to LSP server.""" + fut = self._send_request( + "textDocument/signatureHelp", params=signature_help_params + ) + return fut.result() + + def text_document_definition(self, definition_params): + """Sends text document defintion request to LSP server.""" + fut = self._send_request( + "textDocument/definition", params=definition_params + ) + return fut.result() + + def text_document_symbol(self, document_symbol_params): + """Sends text document symbol request to LSP server.""" + fut = self._send_request( + "textDocument/documentSymbol", params=document_symbol_params + ) + return fut.result() + + def text_document_highlight(self, document_highlight_params): + """Sends text document highlight request to LSP server.""" 
+ fut = self._send_request( + "textDocument/documentHighlight", params=document_highlight_params + ) + return fut.result() + + def text_document_references(self, references_params): + """Sends text document references request to LSP server.""" + fut = self._send_request( + "textDocument/references", params=references_params + ) + return fut.result() + + def workspace_symbol(self, workspace_symbol_params): + """Sends workspace symbol request to LSP server.""" + fut = self._send_request( + "workspace/symbol", params=workspace_symbol_params + ) + return fut.result() + + def completion_item_resolve(self, resolve_params): + """Sends completion item resolve request to LSP server.""" + fut = self._send_request( + "completionItem/resolve", params=resolve_params + ) + return fut.result() + + def notify_did_change(self, did_change_params): + """Sends did change notification to LSP Server.""" + self._send_notification( + "textDocument/didChange", params=did_change_params + ) + + def notify_did_save(self, did_save_params): + """Sends did save notification to LSP Server.""" + self._send_notification("textDocument/didSave", params=did_save_params) + + def notify_did_open(self, did_open_params): + """Sends did open notification to LSP Server.""" + self._send_notification("textDocument/didOpen", params=did_open_params) + + def set_notification_callback(self, notification_name, callback): + """Set custom LS notification handler.""" + self._notification_callbacks[notification_name] = callback + + def get_notification_callback(self, notification_name): + """Gets callback if set or default callback for a given LS + notification.""" + try: + return self._notification_callbacks[notification_name] + except KeyError: + + def _default_handler(_params): + """Default notification handler.""" + + return _default_handler + + def _publish_diagnostics(self, publish_diagnostics_params): + """Internal handler for text document publish diagnostics.""" + return self._handle_notification( + 
PUBLISH_DIAGNOSTICS, publish_diagnostics_params + ) + + def _window_log_message(self, window_log_message_params): + """Internal handler for window log message.""" + return self._handle_notification( + WINDOW_LOG_MESSAGE, window_log_message_params + ) + + def _window_show_message(self, window_show_message_params): + """Internal handler for window show message.""" + return self._handle_notification( + WINDOW_SHOW_MESSAGE, window_show_message_params + ) + + def _window_work_done_progress_create(self, window_progress_params): + """Internal handler for window/workDoneProgress/create""" + return self._handle_notification( + WINDOW_WORK_DONE_PROGRESS_CREATE, window_progress_params + ) + + def _handle_notification(self, notification_name, params): + """Internal handler for notifications.""" + fut = Future() + + def _handler(): + callback = self.get_notification_callback(notification_name) + callback(params) + fut.set_result(None) + + self._thread_pool.submit(_handler) + return fut + + def _send_request( + self, name, params=None, handle_response=lambda f: f.done() + ): + """Sends {name} request to the LSP server.""" + fut = self._endpoint.request(name, params) + fut.add_done_callback(handle_response) + return fut + + def _send_notification(self, name, params=None): + """Sends {name} notification to the LSP server.""" + self._endpoint.notify(name, params) diff --git a/tests/lsp_test_client/utils.py b/tests/lsp_test_client/utils.py new file mode 100644 index 0000000..a32c8a8 --- /dev/null +++ b/tests/lsp_test_client/utils.py @@ -0,0 +1,31 @@ +"""Provides LSP client side utilities for easier testing.""" + +import pathlib +import platform +import functools + +import py + +# pylint: disable=no-member + + +def normalizecase(path: str) -> str: + """Fixes 'file' uri or path case for easier testing in windows.""" + if platform.system() == "Windows": + return path.lower() + return path + + +def as_uri(path: py.path.local) -> str: + """Return 'file' uri as string.""" + return 
normalizecase(pathlib.Path(path).as_uri()) + + +def handle_notification(params, event, results=None): + if results is not None: + results.append(params) + event.set() + + +def get_notification_handler(*args, **kwargs): + return functools.partial(handle_notification, *args, **kwargs) diff --git a/textLSP/analysers/analyser.py b/textLSP/analysers/analyser.py index 1a6b0f3..77d1071 100644 --- a/textLSP/analysers/analyser.py +++ b/textLSP/analysers/analyser.py @@ -1,4 +1,4 @@ -import bisect +import logging from typing import List, Optional from pygls.server import LanguageServer @@ -21,14 +21,21 @@ TextEdit, Command, VersionedTextDocumentIdentifier, - MessageType, CompletionParams, CompletionList, ) from ..documents.document import BaseDocument, ChangeTracker from ..utils import merge_dicts -from ..types import Interval, TextLSPCodeActionKind, ProgressBar +from ..types import ( + Interval, + TextLSPCodeActionKind, + ProgressBar, + PositionDict, +) + + +logger = logging.getLogger(__name__) class Analyser(): @@ -75,98 +82,196 @@ def did_open(self, params: DidOpenTextDocumentParams): def _did_change(self, doc: Document, changes: List[Interval]): raise NotImplementedError() - def _get_line_shifts(self, params: DidChangeTextDocumentParams) -> List: - res = list() + def _handle_line_shifts(self, params: DidChangeTextDocumentParams): + # FIXME: this method is very complex, try to make it easier to read + should_update_diagnostics = False + doc = self.get_document(params) + + val = 0 + accumulative_shifts = list() + # handling inline shifts and building a list of line shifts for later for change in params.content_changes: if type(change) == TextDocumentContentChangeEvent_Type2: continue + if change.range.start != change.range.end: + tmp_range = Range( + start=Position( + line=change.range.start.line-val, + character=change.range.start.character, + ), + end=Position( + line=change.range.end.line-val, + character=change.range.start.character, + ), + ) + num = 
self.remove_code_items_at_range(doc, tmp_range, (True, False)) + should_update_diagnostics = should_update_diagnostics or num > 0 + + change_text_len = len(change.text) line_diff = change.range.end.line - change.range.start.line diff = change.text.count('\n') - line_diff - if diff != 0: - res.append((change.range.start.line, diff)) + if diff == 0: + in_line_diff = change.range.start.character - change.range.end.character + in_line_diff += change_text_len + if in_line_diff != 0: + # if only edits in a given line, let's shift the items + # in the line + next_pos = Position( + line=change.range.start.line+1, + character=0, + ) - return res + for diag in list( + self._diagnostics_dict[doc.uri].irange_values( + minimum=change.range.end, + maximum=next_pos, + inclusive=(True, False) + ) + ): + item_range = diag.range + diag.range = Range( + start=Position( + line=item_range.start.line, + character=item_range.start.character+in_line_diff + ), + end=Position( + line=item_range.end.line, + character=item_range.end.character + + (in_line_diff if item_range.start.line == + item_range.end.line else 0) + ) + ) + self._diagnostics_dict[doc.uri].update( + item_range.start, + diag.range.start, + diag + ) + should_update_diagnostics = True + + for action in list( + self._code_actions_dict[doc.uri].irange_values( + minimum=change.range.end, + maximum=next_pos, + inclusive=(True, False) + ) + ): + item_range = action.edit.document_changes[0].edits[0].range + action.edit.document_changes[0].edits[0].range = Range( + start=Position( + line=item_range.start.line, + character=item_range.start.character+in_line_diff + ), + end=Position( + line=item_range.end.line, + character=item_range.end.character + + (in_line_diff if item_range.start.line == + item_range.end.line else 0) + ) + ) + self._code_actions_dict[doc.uri].update( + item_range.start, + action.edit.document_changes[0].edits[0].range.start, + action + ) + else: + # There is a line shift: diff > 0 + val += diff + 
accumulative_shifts.append((change.range.start, val, change)) + pos = doc.last_position(True) + pos = Position( + line=pos.line - (accumulative_shifts[-1][1] if len(accumulative_shifts) else 0) + 1, + character=0 + ) + accumulative_shifts.append((pos, val)) - def _handle_line_shifts(self, doc: BaseDocument, line_shifts: List): - """ - params: line_shifts: List of tuples (line, shift) should be sorted - """ - if len(line_shifts) == 0: - return + if len(accumulative_shifts) == 0: + return should_update_diagnostics - val = 0 - bisect_lst = [line_shifts[0][0]] - accumulative_shifts = [(line_shifts[0][0], 0)] - for shift in line_shifts: - val += shift[1] - accumulative_shifts.append((shift[0]+1, val)) - bisect_lst.append(shift[0]+1) - num_shifts = len(accumulative_shifts) - - # TODO extract to function - # diagnostics - diagnostics = list() - for diag in self._diagnostics_dict[doc.uri]: - range = diag.range - idx = bisect.bisect_left(bisect_lst, range.start.line) - idx = min(idx, num_shifts-1) + # handling line shifts ############################################ + for idx in range(len(accumulative_shifts)-1): + pos = accumulative_shifts[idx][0] + next_pos = accumulative_shifts[idx+1][0] shift = accumulative_shifts[idx][1] - if shift != 0: - if range.start.line + shift < 0: - continue + for diag in list( + self._diagnostics_dict[doc.uri].irange_values( + minimum=pos, + maximum=next_pos, + inclusive=(True, False) + ) + ): + item_range = diag.range + char_shift = 0 + if item_range.start.line == pos.line: + char_shift = item_range.start.character - \ + (pos.character + len(accumulative_shifts[idx][2].text)) diag.range = Range( start=Position( - line=range.start.line + shift, - character=range.start.character + line=item_range.start.line + shift, + character=item_range.start.character - char_shift ), end=Position( - line=range.end.line + shift, - character=range.end.character + line=item_range.end.line + shift, + character=item_range.end.character - + (char_shift if 
item_range.start.line == + item_range.end.line else 0) ) ) - diagnostics.append(diag) - self._diagnostics_dict[doc.uri] = diagnostics - - # code actions - code_actions = list() - for action in self._code_actions_dict[doc.uri]: - range = action.edit.document_changes[0].edits[0].range - idx = bisect.bisect_left(bisect_lst, range.start.line) - idx = min(idx, num_shifts-1) - shift = accumulative_shifts[idx][1] + self._diagnostics_dict[doc.uri].update( + item_range.start, + diag.range.start, + diag + ) + should_update_diagnostics = True - if shift != 0: - if range.start.line + shift < 0: - continue + for action in list( + self._code_actions_dict[doc.uri].irange_values( + minimum=pos, + maximum=next_pos, + inclusive=(True, False) + ) + ): + item_range = action.edit.document_changes[0].edits[0].range + char_shift = 0 + if item_range.start.line == pos.line: + char_shift = item_range.start.character - \ + (pos.character + len(accumulative_shifts[idx][2].text)) action.edit.document_changes[0].edits[0].range = Range( start=Position( - line=range.start.line + shift, - character=range.start.character + line=item_range.start.line + shift, + character=item_range.start.character - char_shift ), end=Position( - line=range.end.line + shift, - character=range.end.character + line=item_range.end.line + shift, + character=item_range.end.character - + (char_shift if item_range.start.line == + item_range.end.line else 0) ) ) - code_actions.append(action) - self._code_actions_dict[doc.uri] = code_actions + self._code_actions_dict[doc.uri].update( + item_range.start, + action.edit.document_changes[0].edits[0].range.start, + action + ) + + return should_update_diagnostics def _remove_overflown_code_items(self, doc: BaseDocument): last_position = doc.last_position(True) - self._diagnostics_dict[doc.uri] = [ - diag - for diag in self._diagnostics_dict[doc.uri] - if diag.range.start <= last_position - ] + self._diagnostics_dict[doc.uri].remove_from(last_position, False) + 
self._code_actions_dict[doc.uri].remove_from(last_position, False) - self._code_actions_dict[doc.uri] = [ - action - for action in self._code_actions_dict[doc.uri] - if action.edit.document_changes[0].edits[0].range.start <= last_position - ] + def _handle_shifts(self, params: DidChangeTextDocumentParams): + """ + Handlines line shifts and position shifts within lines + """ + doc = self.get_document(params) + should_update_diagnostics = self._handle_line_shifts(params) + self._remove_overflown_code_items(doc) + + return should_update_diagnostics def _update_single_code_action(self, action: CodeAction, doc: BaseDocument): # update document version @@ -187,13 +292,9 @@ def _update_code_actions(self, doc: BaseDocument): doc, ) - def did_change(self, params: DidChangeTextDocumentParams): - # TODO handle shifts within lines - line_shifts = self._get_line_shifts(params) doc = self.get_document(params) - self._handle_line_shifts(doc, line_shifts) - self._remove_overflown_code_items(doc) + should_update_diagnostics = self._handle_shifts(params) self._update_code_actions(doc) if self.should_run_on(Analyser.CONFIGURATION_CHECK_ON_CHANGE): @@ -203,18 +304,18 @@ def did_change(self, params: DidChangeTextDocumentParams): ) else: changes = self._content_change_dict[doc.uri].get_changes() + self._content_change_dict[doc.uri] = ChangeTracker(doc, True) with ProgressBar( self.language_server, f'{self.name} checking', token=self._progressbar_token ): self._did_change(doc, changes) - self._content_change_dict[doc.uri] = ChangeTracker(doc, True) - elif len(line_shifts) > 0: + elif should_update_diagnostics: self.language_server.publish_stored_diagnostics(doc) def update_document(self, doc: Document, change: TextDocumentContentChangeEvent): - self._content_change_dict[doc.uri].update_document(change) + self._content_change_dict[doc.uri].update_document(change, doc) def did_save(self, params: DidSaveTextDocumentParams): if self.should_run_on(Analyser.CONFIGURATION_CHECK_ON_SAVE): @@ 
-227,13 +328,13 @@ def did_save(self, params: DidSaveTextDocumentParams): ) else: changes = self._content_change_dict[doc.uri].get_changes() + self._content_change_dict[doc.uri] = ChangeTracker(doc, True) with ProgressBar( self.language_server, f'{self.name} checking', token=self._progressbar_token ): self._did_change(doc, changes) - self._content_change_dict[doc.uri] = ChangeTracker(doc, True) def _did_close(self, doc: Document): pass @@ -273,31 +374,24 @@ def should_run_on(self, event: str) -> bool: ) def init_diagnostics(self, doc: Document): - self._diagnostics_dict[doc.uri] = list() + self._diagnostics_dict[doc.uri] = PositionDict() def get_diagnostics(self, doc: Document): - return self._diagnostics_dict.get(doc.uri, list()) + return self._diagnostics_dict.get(doc.uri, PositionDict()) def add_diagnostics(self, doc: Document, diagnostics: List[Diagnostic]): - self._diagnostics_dict[doc.uri] += diagnostics + for diag in diagnostics: + self._diagnostics_dict[doc.uri].add(diag.range.start, diag) self.language_server.publish_stored_diagnostics(doc) - def remove_code_items_at_rage(self, doc: Document, pos_range: Range): - diagnostics = list() - for diag in self.get_diagnostics(doc): - if diag.range.end < pos_range.start or diag.range.start > pos_range.end: - diagnostics.append(diag) - self._diagnostics_dict[doc.uri] = diagnostics - - code_actions = list() - for action in self._code_actions_dict[doc.uri]: - range = action.edit.document_changes[0].edits[0].range - if range.end < pos_range.start or range.start > pos_range.end: - code_actions.append(action) - self._code_actions_dict[doc.uri] = code_actions + def remove_code_items_at_range(self, doc: Document, pos_range: Range, inclusive=(True, True)): + num = 0 + num += self._diagnostics_dict[doc.uri].remove_between(pos_range, inclusive) + num += self._code_actions_dict[doc.uri].remove_between(pos_range, inclusive) + return num def init_code_actions(self, doc: Document): - self._code_actions_dict[doc.uri] = list() + 
self._code_actions_dict[doc.uri] = PositionDict() def get_code_actions(self, params: CodeActionParams) -> Optional[List[CodeAction]]: doc = self.get_document(params) @@ -306,11 +400,12 @@ def get_code_actions(self, params: CodeActionParams) -> Optional[List[CodeAction # TODO make this faster? res = [ action - for action in self._code_actions_dict[doc.uri] + for action in self._code_actions_dict[doc.uri].irange_values(maximum=range.start) if ( ( - action.edit.document_changes[0].edits[0].range.start <= range.start - and action.edit.document_changes[0].edits[0].range.end >= range.end + # action.edit.document_changes[0].edits[0].range.start <= range.start + # and + action.edit.document_changes[0].edits[0].range.end >= range.end ) # if it's not reachable by the cursor or ( @@ -379,7 +474,11 @@ def get_code_actions(self, params: CodeActionParams) -> Optional[List[CodeAction return res def add_code_actions(self, doc: Document, actions: List[CodeAction]): - self._code_actions_dict[doc.uri] += actions + for action in actions: + self._code_actions_dict[doc.uri].add( + action.edit.document_changes[0].edits[0].range.start, + action, + ) @staticmethod def build_single_suggestion_action( diff --git a/textLSP/analysers/gramformer/gramformer.py b/textLSP/analysers/gramformer/gramformer.py index 3b1354c..c28035b 100644 --- a/textLSP/analysers/gramformer/gramformer.py +++ b/textLSP/analysers/gramformer/gramformer.py @@ -138,7 +138,7 @@ def _did_change(self, doc: BaseDocument, changes: List[Interval]): for change in changes: paragraph = doc.paragraph_at_offset( change.start, - min_length=change.length, + min_offset=change.start + change.length-1, cleaned=True, ) if paragraph in checked: @@ -149,7 +149,7 @@ def _did_change(self, doc: BaseDocument, changes: List[Interval]): paragraph.length, True ) - self.remove_code_items_at_rage(doc, pos_range) + self.remove_code_items_at_range(doc, pos_range) diags, actions = self._analyse_sentences( doc.text_at_offset( diff --git 
a/textLSP/analysers/grammarbot/grammarbot.py b/textLSP/analysers/grammarbot/grammarbot.py index 335bc17..91a9389 100644 --- a/textLSP/analysers/grammarbot/grammarbot.py +++ b/textLSP/analysers/grammarbot/grammarbot.py @@ -107,7 +107,7 @@ def _did_change(self, doc: BaseDocument, changes: List[Interval]): for change in changes: paragraph = doc.paragraph_at_offset( change.start, - min_length=change.length, + min_offset=change.start + change.length-1, cleaned=True, ) if paragraph in checked: @@ -119,7 +119,7 @@ def _did_change(self, doc: BaseDocument, changes: List[Interval]): paragraph.length, True ) - self.remove_code_items_at_rage(doc, pos_range) + self.remove_code_items_at_range(doc, pos_range) paragraph_text = doc.text_at_offset(paragraph.start, paragraph.length) text += paragraph_text diff --git a/textLSP/analysers/handler.py b/textLSP/analysers/handler.py index ea37167..6622815 100644 --- a/textLSP/analysers/handler.py +++ b/textLSP/analysers/handler.py @@ -72,15 +72,37 @@ def update_settings(self, settings): if name not in self.analysers: analyser.close() + def shutdown(self): + for analyser in self.analysers.values(): + analyser.close() + def get_diagnostics(self, doc: Document): - return [analyser.get_diagnostics(doc) for analyser in self.analysers.values()] + try: + return [ + analyser.get_diagnostics(doc) + for analyser in self.analysers.values() + ] + except Exception as e: + self.language_server.show_message( + str('Server error. 
See log for details.'), + MessageType.Error, + ) + logger.exception(str(e)) + return [] def get_code_actions(self, params: CodeActionParams) -> Optional[List[CodeAction]]: res = list() - for analyser in self.analysers.values(): - tmp_lst = analyser.get_code_actions(params) - if tmp_lst is not None and len(tmp_lst) > 0: - res.extend(tmp_lst) + try: + for analyser in self.analysers.values(): + tmp_lst = analyser.get_code_actions(params) + if tmp_lst is not None and len(tmp_lst) > 0: + res.extend(tmp_lst) + except Exception as e: + self.language_server.show_message( + str('Server error. See log for details.'), + MessageType.Error, + ) + logger.exception(str(e)) return res if len(res) > 0 else None @@ -96,7 +118,16 @@ async def _submit_task(self, function, *args, **kwargs): if len(functions) == 0: return - await asyncio.wait(functions) + done, pending = await asyncio.wait(functions) + for task in done: + try: + task.result() + except Exception as e: + self.language_server.show_message( + str('Server error. See log for details.'), + MessageType.Error, + ) + logger.exception(str(e)) async def _did_open( self, @@ -206,6 +237,12 @@ async def command_analyse(self, *args): str(f'{analyser_name}: {e}'), MessageType.Error, ) + except Exception as e: + self.language_server.show_message( + str('Server error. See log for details.'), + MessageType.Error, + ) + logger.exception(str(e)) else: await self._submit_task(self._command_analyse, args) @@ -217,7 +254,14 @@ async def command_custom_command(self, *args): ext_command = f'command_{command}' if hasattr(analyser, ext_command): - getattr(analyser, ext_command)(**args) + try: + getattr(analyser, ext_command)(**args) + except Exception as e: + self.language_server.show_message( + str('Server error. 
See log for details.'), + MessageType.Error, + ) + logger.exception(str(e)) else: self.language_server.show_message( str(f'No custom command supported by {analyser}: {command}'), @@ -230,10 +274,17 @@ def update_document(self, doc: Document, change: TextDocumentContentChangeEvent) def get_completions(self, params: Optional[CompletionParams] = None) -> CompletionList: comp_lst = list() - for _, analyser in self.analysers.items(): - tmp = analyser.get_completions(params) - if tmp is not None and len(tmp) > 0: - comp_lst.extend(tmp) + try: + for _, analyser in self.analysers.items(): + tmp = analyser.get_completions(params) + if tmp is not None and len(tmp) > 0: + comp_lst.extend(tmp) + except Exception as e: + self.language_server.show_message( + str('Server error. See log for details.'), + MessageType.Error, + ) + logger.exception(str(e)) return CompletionList( is_incomplete=False, diff --git a/textLSP/analysers/hf_checker/hf_checker.py b/textLSP/analysers/hf_checker/hf_checker.py index 27d5197..9736e95 100644 --- a/textLSP/analysers/hf_checker/hf_checker.py +++ b/textLSP/analysers/hf_checker/hf_checker.py @@ -140,7 +140,7 @@ def _did_change(self, doc: BaseDocument, changes: List[Interval]): for change in changes: paragraph = doc.paragraph_at_offset( change.start, - min_length=change.length, + min_offset=change.start + change.length-1, cleaned=True, ) if paragraph in checked: @@ -151,7 +151,7 @@ def _did_change(self, doc: BaseDocument, changes: List[Interval]): paragraph.length, True ) - self.remove_code_items_at_rage(doc, pos_range) + self.remove_code_items_at_range(doc, pos_range) diags, actions = self._analyse_lines( doc.text_at_offset( diff --git a/textLSP/analysers/languagetool/languagetool.py b/textLSP/analysers/languagetool/languagetool.py index 7188d07..c83b0dc 100644 --- a/textLSP/analysers/languagetool/languagetool.py +++ b/textLSP/analysers/languagetool/languagetool.py @@ -82,7 +82,7 @@ def _did_change(self, doc: BaseDocument, changes: List[Interval]): 
for change in changes: paragraph = doc.paragraph_at_offset( change.start, - min_length=change.length, + min_offset=change.start + change.length-1, cleaned=True, ) if paragraph in checked: @@ -127,7 +127,7 @@ def _did_change(self, doc: BaseDocument, changes: List[Interval]): end_sent.start-paragraph.start-1 + end_sent.length, True ) - self.remove_code_items_at_rage(doc, pos_range) + self.remove_code_items_at_range(doc, pos_range) diags, actions = self._analyse( doc.text_at_offset( @@ -171,6 +171,9 @@ def close(self): tool.close() self.tool = dict() + def __del__(self): + self.close() + def _get_mapped_language(self, language): return LANGUAGE_MAP[language] diff --git a/textLSP/analysers/openai/openai.py b/textLSP/analysers/openai/openai.py index 0207c44..4c22656 100644 --- a/textLSP/analysers/openai/openai.py +++ b/textLSP/analysers/openai/openai.py @@ -1,6 +1,6 @@ import logging import openai -from openai.error import OpenAIError +from openai import OpenAI, APIError from typing import List, Tuple, Optional from lsprotocol.types import ( @@ -54,19 +54,20 @@ def __init__(self, language_server: LanguageServer, config: dict, name: str): super().__init__(language_server, config, name) if self.CONFIGURATION_API_KEY not in self.config: raise ConfigurationError(f'Reqired parameter: {name}.{self.CONFIGURATION_API_KEY}') - openai.api_key = self.config[self.CONFIGURATION_API_KEY] + self._client = OpenAI(api_key=self.config[self.CONFIGURATION_API_KEY]) def _edit(self, text) -> List[TokenDiff]: try: - res = openai.Edit.create( + # res = openai.Edit.create( + res = self._client.edits.create( model=self.config.get(self.CONFIGURATION_EDIT_MODEL, self.SETTINGS_DEFAULT_EDIT_MODEL), instruction=self.config.get(self.CONFIGURATION_EDIT_INSTRUCTION, self.SETTINGS_DEFAULT_EDIT_INSTRUCTION), input=text, temperature=self.config.get(self.CONFIGURATION_TEMPERATURE, self.SETTINGS_DEFAULT_TEMPERATURE), ) if len(res.choices) > 0: - return TokenDiff.token_level_diff(text, 
res.choices[0]['text'].strip()) - except OpenAIError as e: + return TokenDiff.token_level_diff(text, res.choices[0].text.strip()) + except APIError as e: self.language_server.show_message( str(e), MessageType.Error, @@ -76,7 +77,7 @@ def _edit(self, text) -> List[TokenDiff]: def _generate(self, text) -> Optional[str]: try: - res = openai.Completion.create( + res = self._client.completions.create( model=self.config.get(self.CONFIGURATION_MODEL, self.SETTINGS_DEFAULT_MODEL), prompt=text, temperature=self.config.get(self.CONFIGURATION_TEMPERATURE, self.SETTINGS_DEFAULT_TEMPERATURE), @@ -84,8 +85,8 @@ def _generate(self, text) -> Optional[str]: ) if len(res.choices) > 0: - return res.choices[0]['text'].strip() - except OpenAIError as e: + return res.choices[0].text.strip() + except APIError as e: self.language_server.show_message( str(e), MessageType.Error, @@ -150,7 +151,7 @@ def _did_open(self, doc: BaseDocument): diagnostics = list() code_actions = list() checked = set() - for paragraph in doc.paragraphs_at_offset(0, len(doc.cleaned_source), True): + for paragraph in doc.paragraphs_at_offset(0, len(doc.cleaned_source), cleaned=True): diags, actions = self._handle_paragraph(doc, paragraph) diagnostics.extend(diags) code_actions.extend(actions) @@ -166,7 +167,7 @@ def _did_change(self, doc: BaseDocument, changes: List[Interval]): for change in changes: paragraph = doc.paragraph_at_offset( change.start, - min_length=change.length, + min_offset=change.start + change.length-1, cleaned=True, ) if paragraph in checked: @@ -189,7 +190,7 @@ def _handle_paragraph(self, doc: BaseDocument, paragraph: Interval): paragraph.length, True ) - self.remove_code_items_at_rage(doc, pos_range) + self.remove_code_items_at_range(doc, pos_range) diags, actions = self._analyse( doc.text_at_offset( @@ -265,7 +266,10 @@ def get_code_actions(self, params: CodeActionParams) -> Optional[List[CodeAction if params.range.start != params.range.end: return res - line = 
doc.lines[params.range.start.line].strip() + if len(doc.lines) > 0: + line = doc.lines[params.range.start.line].strip() + else: + line = '' magic = self.config.get(self.CONFIGURATION_PROMPT_MAGIC, self.SETTINGS_DEFAULT_PROMPT_MAGIC) if magic in line: if res is None: diff --git a/textLSP/documents/document.py b/textLSP/documents/document.py index f238d89..c1f0433 100644 --- a/textLSP/documents/document.py +++ b/textLSP/documents/document.py @@ -1,29 +1,36 @@ import logging import tempfile +import sys +import copy from typing import Optional, Generator, List, Dict from dataclasses import dataclass +from itertools import chain from lsprotocol.types import ( Range, Position, TextDocumentContentChangeEvent, + TextDocumentContentChangeEvent_Type1, TextDocumentContentChangeEvent_Type2, ) -from pygls.workspace import Document, position_from_utf16 +from pygls.workspace import TextDocument +from pygls.workspace.position_codec import PositionCodec from tree_sitter import Language, Parser, Tree, Node from ..utils import get_class, synchronized, git_clone, get_user_cache from ..types import ( + OffsetPositionInterval, OffsetPositionIntervalList, Interval ) from .. 
import documents logger = logging.getLogger(__name__) +_codec = PositionCodec() -class BaseDocument(Document): +class BaseDocument(TextDocument): def __init__(self, *args, config: Dict = None, **kwargs): super().__init__(*args, **kwargs) if config is None: @@ -97,7 +104,7 @@ def range_at_offset(self, offset: int, length: int, cleaned=False) -> Range: def offset_at_position(self, position: Position, cleaned=False) -> int: # doesn't really matter lines = self.cleaned_lines if cleaned else self.lines - pos = position_from_utf16(lines, position) + pos = _codec.position_from_client_units(lines, position) row, col = pos.line, pos.character return col + sum(len(line) for line in lines[:row]) @@ -134,17 +141,19 @@ def sentence_at_offset(self, offset: int, min_length=0, cleaned=False) -> Interv return Interval(start_idx, end_idx-start_idx+1) - def paragraph_at_offset(self, offset: int, min_length=0, cleaned=False) -> Interval: + def paragraph_at_offset(self, offset: int, min_length=0, min_offset=0, cleaned=False) -> Interval: """ + Returns the last paragraph if offset is over the content length. 
returns (start_offset, length) """ - start_idx = offset - end_idx = offset source = self.cleaned_source if cleaned else self.source len_source = len(source) + start_idx = offset assert start_idx >= 0 - assert end_idx < len_source + if start_idx >= len_source: + start_idx = len_source - 1 + end_idx = start_idx while ( start_idx >= 0 @@ -165,7 +174,7 @@ def paragraph_at_offset(self, offset: int, min_length=0, cleaned=False) -> Inter ): end_idx += 1 - if end_idx < len_source-1 and end_idx-start_idx+1 < min_length: + if end_idx < len_source-1 and (end_idx-start_idx+1 < min_length or end_idx <= min_offset): end_idx += 1 else: break @@ -178,12 +187,12 @@ def paragraph_at_position(self, position: Position, cleaned=False) -> Interval: return None return self.paragraph_at_offset(offset, cleaned=cleaned) - def paragraphs_at_offset(self, offset: int, min_length=0, cleaned=False) -> List[Interval]: + def paragraphs_at_offset(self, offset: int, min_length=0, min_offset=0, cleaned=False) -> List[Interval]: res = list() - doc_lenght = len(self.cleaned_source if cleaned else self.source) + doc_length = len(self.cleaned_source if cleaned else self.source) length = 0 - while offset < doc_lenght and (length < min_length or length == 0): + while offset < doc_length and (length < min_length or offset <= min_offset or length == 0): paragraph = self.paragraph_at_offset(offset, cleaned=cleaned) res.append(paragraph) @@ -242,8 +251,8 @@ def _clean_source(self): raise NotImplementedError() def apply_change(self, change: TextDocumentContentChangeEvent) -> None: - super().apply_change(change) self._cleaned_source = None + super().apply_change(change) def position_at_offset(self, offset: int, cleaned=False) -> Position: if not cleaned: @@ -329,6 +338,8 @@ class TreeSitterDocument(CleanableDocument): def __init__(self, language_name, grammar_url, branch, *args, **kwargs): super().__init__(*args, **kwargs) + ####################################################################### + # Do not 
deepcopy these self._language = self.get_language(language_name, grammar_url, branch) self._parser = self.get_parser( language_name, @@ -336,8 +347,23 @@ def __init__(self, language_name, grammar_url, branch, *args, **kwargs): branch, self._language ) + self._tree = None + self._query = self._build_query() + ####################################################################### + self._text_intervals = None + def __deepcopy__(self, memo): + cls = self.__class__ + result = cls.__new__(cls) + memo[id(self)] = result + for k, v in self.__dict__.items(): + if k not in {'_language', '_parser', '_tree', '_query'}: + setattr(result, k, copy.deepcopy(v, memo)) + else: + setattr(result, k, v) + return result + @classmethod def build_library(cls, name, url, branch=None) -> None: with tempfile.TemporaryDirectory() as tmpdir: @@ -371,15 +397,25 @@ def get_parser(cls, name=None, url=None, branch=None, language=None) -> Parser: parser.set_language(language) return parser + def _build_query(self): + raise NotImplementedError() + def _parse_source(self): return self._parser.parse(bytes(self.source, 'utf-8')) - def _clean_source(self): - tree = self._parse_source() + @property + def tree(self) -> Tree: + if self._tree is None: + self._tree = self._parse_source() + return self._tree + + def _clean_source(self, change: TextDocumentContentChangeEvent_Type1 = None): self._text_intervals = OffsetPositionIntervalList() offset = 0 - for node in self._iterate_text_nodes(tree): + start_point = (0, 0) + end_point = (sys.maxsize, sys.maxsize) + for node in self._iterate_text_nodes(self.tree, start_point, end_point): node_len = len(node) self._text_intervals.add_interval_values( offset, @@ -394,9 +430,530 @@ def _clean_source(self): self._cleaned_source = ''.join(self._text_intervals.values) - def _iterate_text_nodes(self, tree: Tree) -> Generator[TextNode, None, None]: + def _iterate_text_nodes( + self, + tree: Tree, + start_point, + end_point, + ) -> Generator[TextNode, None, None]: raise 
NotImplementedError() + def _get_edit_positions(self, change): + lines = self.lines + change_range = change.range + change_range = _codec.range_from_client_units(lines, change_range) + start_line = change_range.start.line + start_col = change_range.start.character + end_line = change_range.end.line + end_col = change_range.end.character + len_lines = len(lines) + if len_lines == 0: + start_byte = 0 + end_byte = 0 + else: + if end_line >= len(lines): + # this could happen eg when the last line is deleted + end_line = len(lines) - 1 + end_col = len(lines[end_line]) - 1 + + start_byte = len(bytes( + ''.join( + lines[:start_line] + [lines[start_line][:start_col]] + ), + 'utf-8', + )) + end_byte = len(bytes( + ''.join( + lines[:end_line] + [lines[end_line][:end_col]] + ), + 'utf-8', + )) + text_bytes = len(bytes(change.text, 'utf-8')) + + if end_byte - start_byte == 0: + # INSERT + old_end_byte = start_byte + new_end_byte = start_byte + text_bytes + start_point = (start_line, start_col) + old_end_point = start_point + new_lines = change.text.count('\n') + new_end_point = ( + start_line + new_lines, + (start_col + text_bytes) if new_lines == 0 else len(bytes( + change.text.split('\n')[-1], + 'utf-8' + )), + ) + elif text_bytes == 0: + # DELETE + old_end_byte = end_byte + new_end_byte = start_byte + start_point = (start_line, start_col) + old_end_point = (end_line, end_col) + new_end_point = start_point + else: + # REPLACE + old_end_byte = end_byte + new_end_byte = start_byte + text_bytes + start_point = (start_line, start_col) + old_end_point = (end_line, end_col) + + new_lines = change.text.count('\n') + deleted_lines = end_line - start_line + if new_lines == 0 and deleted_lines == 0: + new_end_line = end_line + new_end_col = end_col + text_bytes - (end_col - start_col) + elif new_lines > 0 and deleted_lines == 0: + new_end_line = end_line + new_lines + new_end_col = len(bytes(change.text.split('\n')[-1], 'utf-8')) + elif new_lines == 0 and deleted_lines > 0: + 
new_end_line = end_line - deleted_lines + new_end_col = end_col + text_bytes - (end_col - start_col) + else: + new_end_line = end_line + new_lines - deleted_lines + new_end_col = len(bytes(change.text.split('\n')[-1], 'utf-8')) + + new_end_point = ( + new_end_line, + new_end_col, + ) + + return ( + start_line, + start_col, + end_line, + end_col, + start_byte, + old_end_byte, + new_end_byte, + text_bytes, + start_point, + old_end_point, + new_end_point, + ) + + def _get_node_and_iterator_for_edit( + self, + start_point, + old_end_point, + new_end_point, + last_changed_point, + old_tree_first_node_new_end_point, + old_tree_end_point, + ): + sp = start_point + if len(self._text_intervals) > 0: + old_first_interval_end_point = ( + self._text_intervals.get_interval(0).position_range.end.line, + self._text_intervals.get_interval(0).position_range.end.character + ) + else: + old_first_interval_end_point = (0, 0) + if start_point < old_first_interval_end_point: + # there's new content at the beginning, we need to parse the next + # subtree as well, since there are no necessary whitespace tokens in + # the current text_intervals + tmp_point = old_tree_first_node_new_end_point + else: + tmp_point = (0, 0) + # last_changed_point is needed to handle subtrees being broken into + # multiple ones + ep = max(tmp_point, new_end_point, last_changed_point) + + if start_point > old_tree_end_point: + # edit at the end of the file + # need to extend the range to include the last node since there + # might be relevant content (e.g. 
multiple newlines) that was + # ignored since it was at the end + if old_end_point[1] > 0: + sp = (old_tree_end_point[0], max(0, old_tree_end_point[1]-1)) + else: + sp = (max(0, old_tree_end_point[0]-1), 0) + + node_iter = self._iterate_text_nodes(self.tree, sp, ep) + node = next(node_iter) + while node.text == '\n' and node.start_point == (0, 1) and node.end_point == (0, 1): + # empty tree is selected + assert next(node_iter, None) is None + if sp > (0, 0): + sp = (max(0, sp[0]-1), 0) + else: + node.start_point = start_point + node.end_point = start_point + break + + node_iter = self._iterate_text_nodes(self.tree, sp, ep) + node = next(node_iter) + + return node, chain([node], node_iter) + + def _get_intervals_before_edit( + self, + node, + ): + # offset = 0 + for interval_idx in range(len(self._text_intervals)): + interval = self._text_intervals.get_interval(interval_idx) + if interval.value == '\n' and interval.position_range.start == interval.position_range.end: + # newline added by parser but not in source + interval_end = (interval.position_range.end.line+1, 0) + if interval_end >= node.start_point: + # FIXME This is very messy. Handling these dummy newlines + # should be refactored. 
+ interval.value = ' ' + # offset += len(interval.value) + # text_intervals.add_interval(interval) + yield interval + break + else: + interval_end = ( + interval.position_range.end.line, + interval.position_range.end.character, + ) + if interval_end >= node.start_point: + break + + # offset += len(interval.value) + # text_intervals.add_interval(interval) + yield interval + + def _get_edited_intervals_and_last_node( + self, + node_iter, + offset, + ): + tmp_intvals = list() + last_new_node = None + tmp_node = None + for node in node_iter: + node_len = len(node) + tmp_intvals.append( + OffsetPositionInterval( + offset_interval=Interval( + start=offset, + length=node_len + ), + position_range=Range( + start=Position( + line=node.start_point[0], + character=node.start_point[1], + ), + end=Position( + line=node.end_point[0], + character=node.end_point[1], + ), + ), + value=node.text, + ) + ) + offset += node_len + last_new_node = tmp_node + tmp_node = node + + return tmp_intvals, last_new_node + + def _get_idx_after_edited_tree( + self, + old_end_point, + new_end_point, + text_bytes, + last_new_end_point, + last_changed_point + ): + # we take the max since none parseable content could have been + # added at the end + last_new_end_point = max(last_changed_point, last_new_end_point) + + row_diff = new_end_point[0] - old_end_point[0] + if last_new_end_point[0] < new_end_point[0]: + # parse ended before the edit, happens when non parseable + # part is edited or all content was deleted + last_new_end_point = ( + max(old_end_point, new_end_point)[0], + max(old_end_point, new_end_point)[1] + 1 + ) + elif last_new_end_point[0] > new_end_point[0]: + # parse ended in a later line as the edit, i.e. 
its + # position is only affected by line shift + last_new_end_point = ( + last_new_end_point[0] - row_diff, + last_new_end_point[1] + 1 + ) + elif row_diff == 0: + # the parse ended in the line of the edit + last_new_end_point = ( + last_new_end_point[0], + last_new_end_point[1] - (new_end_point[1] - old_end_point[1]) + text_bytes + 1 + ) + elif row_diff > 0: + # the edit was in the line of the last node which is now + # shifted + last_new_end_point = ( + last_new_end_point[0] - row_diff, + old_end_point[1] + last_new_end_point[1] - new_end_point[1] + 1 + ) + else: + # the edit was in the line of the last node which is now + # shifted + last_new_end_point = ( + last_new_end_point[0] - row_diff, + new_end_point[1] + last_new_end_point[1] - old_end_point[1] + 1 + ) + + last_idx = self._text_intervals.get_idx_at_position( + Position( + line=max(0, last_new_end_point[0]), + character=max(0, last_new_end_point[1]) + ), + strict=False, + ) + return last_idx + + def _handle_intervals_after_edit_shifted( + self, + last_idx, + start_col, + end_line, + end_col, + old_end_point, + new_end_point, + text_bytes, + offset, + text_intervals, + ): + while last_idx > 1: + interval = self._text_intervals.get_interval(last_idx-1) + if (interval.value != '\n' or interval.position_range.start != + interval.position_range.end): + # not dummy newline + break + last_idx -= 1 + + row_diff = new_end_point[0] - old_end_point[0] + col_diff = text_bytes - (end_col - start_col) + for interval_idx in range(last_idx, len(self._text_intervals)): + interval = self._text_intervals.get_interval(interval_idx) + if ( + len(text_intervals) == 0 + and interval.value.count('\n') > 0 + and interval.value.strip() == '' + ): + continue + node_len = len(interval.value) + if interval.position_range.start.line > end_line: + start_line_offset = row_diff + start_char_offset = 0 + end_line_offset = row_diff + end_char_offset = 0 + elif (interval.position_range.start.line == end_line + and 
interval.position_range.start.character >= end_col): + start_line_offset = row_diff + start_char_offset = col_diff + end_line_offset = row_diff + if interval.position_range.end.line > interval.position_range.start.line: + end_char_offset = 0 + else: + end_char_offset = col_diff + else: + # These are the special newlines which are not in the source + # but added by the parser to separate paragraphs + assert (interval.value == '\n' and interval.position_range.start == + interval.position_range.end) + last_interval_range = text_intervals.get_interval(-1).position_range + interval_range = interval.position_range + # we need to set start and end position to the same value + # which is the same line as the last item in text_intervals + # and one column to the right + end_line_offset = last_interval_range.end.line - interval_range.end.line + start_line_offset = interval_range.end.line - interval_range.start.line + end_line_offset + end_char_offset = last_interval_range.end.character - interval_range.end.character + 1 + start_char_offset = interval_range.end.character - interval_range.start.character + end_char_offset + + text_intervals.add_interval_values( + offset, + offset+node_len-1, + interval.position_range.start.line + start_line_offset, + interval.position_range.start.character + start_char_offset, + interval.position_range.end.line + end_line_offset, + interval.position_range.end.character + end_char_offset, + interval.value, + ) + offset += node_len + + def _build_updated_text_intervals( + self, + start_line, + start_col, + end_line, + end_col, + start_byte, + old_end_byte, + new_end_byte, + start_point, + old_end_point, + new_end_point, + text_bytes, + last_changed_point, + old_tree_first_node_new_end_point, + old_tree_end_point, + ): + text_intervals = OffsetPositionIntervalList() + + # get first edited node and iterator for all edited nodes + node, node_iter = self._get_node_and_iterator_for_edit( + start_point, + old_end_point, + new_end_point, + 
last_changed_point, + old_tree_first_node_new_end_point, + old_tree_end_point, + ) + + # copy the text intervals up to the start of the change + for interval in self._get_intervals_before_edit(node): + text_intervals.add_interval(interval) + + if len(text_intervals) > 0: + offset = interval.offset_interval.start + interval.offset_interval.length + else: + offset = 0 + + # handle the nodes that were in the edited subtree + new_intervals, last_new_node = self._get_edited_intervals_and_last_node( + node_iter, + offset, + ) + if last_new_node is None: + return None + + for interval in new_intervals[:-1]: + # there's always a newline return at the end of the file which + # is not needed if we are not really at the end of the file yet + # text_intervals.add_interval_values(*interval) + text_intervals.add_interval(interval) + offset = interval.offset_interval.start + interval.offset_interval.length + + # add remaining intervals shifted + last_new_end_point = last_new_node.end_point + last_idx = self._get_idx_after_edited_tree( + old_end_point, + new_end_point, + text_bytes, + last_new_end_point, + last_changed_point + ) + if last_idx+1 >= len(self._text_intervals): + # we are actually at the end of the file so add the final newline + text_intervals.add_interval(new_intervals[-1]) + else: + self._handle_intervals_after_edit_shifted( + last_idx, + start_col, + end_line, + end_col, + old_end_point, + new_end_point, + text_bytes, + offset, + text_intervals, + ) + + return text_intervals + + def _apply_incremental_change(self, change: TextDocumentContentChangeEvent_Type1) -> None: + """Apply an ``Incremental`` text change to the document""" + if self._tree is None: + super()._apply_incremental_change(change) + return + + tree = self.tree + ( + start_line, + start_col, + end_line, + end_col, + start_byte, + old_end_byte, + new_end_byte, + text_bytes, + start_point, + old_end_point, + new_end_point, + ) = self._get_edit_positions(change) + + # bookkeeping for later source 
cleaning + capture = self._query.captures( + tree.root_node, + ) + if len(capture) == 0: + old_tree_first_node = None + old_tree_end_point = None + else: + old_tree_first_node = capture[0][0] + old_tree_end_point = capture[-1][0].end_point + + tree.edit( + start_byte=start_byte, + old_end_byte=old_end_byte, + new_end_byte=new_end_byte, + start_point=start_point, + old_end_point=old_end_point, + new_end_point=new_end_point, + ) + super()._apply_incremental_change(change) + new_source = bytes(self.source, 'utf-8') + self._tree = self._parser.parse( + new_source, + tree + ) + + if old_tree_first_node is not None: + old_tree_first_node.edit( + start_byte=start_byte, + old_end_byte=old_end_byte, + new_end_byte=new_end_byte, + start_point=start_point, + old_end_point=old_end_point, + new_end_point=new_end_point, + ) + old_tree_first_node_new_end_point = old_tree_first_node.end_point + else: + old_tree_first_node_new_end_point = None + + last_changed_point = (-1, -1) + for change in tree.changed_ranges(self.tree): + last_changed_point = max(last_changed_point, change.end_point) + + if old_tree_end_point is not None: + # rebuild the cleaned source + text_intervals = self._build_updated_text_intervals( + start_line, + start_col, + end_line, + end_col, + start_byte, + old_end_byte, + new_end_byte, + start_point, + old_end_point, + new_end_point, + text_bytes, + last_changed_point, + old_tree_first_node_new_end_point, + old_tree_end_point, + ) + + if text_intervals is not None: + self._text_intervals = text_intervals + self._cleaned_source = ''.join(self._text_intervals.values) + else: + self._clean_source() + + def _apply_full_change(self, change: TextDocumentContentChangeEvent) -> None: + """Apply a ``Full`` text change to the document.""" + super()._apply_full_change(change) + self._tree = None + def position_at_offset(self, offset: int, cleaned=False) -> Position: if not cleaned: return super().position_at_offset(offset, cleaned) @@ -517,7 +1074,7 @@ def get_document( 
version: Optional[int] = None, language_id: Optional[str] = None, sync_kind=None, - ) -> Document: + ) -> TextDocument: try: type = DocumentTypeFactory.get_file_type(language_id) cls = get_class( @@ -548,7 +1105,8 @@ def get_document( class ChangeTracker(): def __init__(self, doc: BaseDocument, cleaned=False): - self.document = doc + self.document = None + self._set_document(doc) self.cleaned = cleaned length = len(doc.cleaned_source) if cleaned else len(doc.source) # list of tuples (span_length, was_changed) @@ -556,7 +1114,15 @@ def __init__(self, doc: BaseDocument, cleaned=False): self._items = [(length, False)] self.full_document_change = False - def update_document(self, change: TextDocumentContentChangeEvent): + def _set_document(self, doc: BaseDocument): + # XXX not too memory efficient + self.document = copy.deepcopy(doc) + + def update_document( + self, + change: TextDocumentContentChangeEvent, + updated_doc: BaseDocument + ): if self.full_document_change: return @@ -578,24 +1144,50 @@ def update_document(self, change: TextDocumentContentChangeEvent): item_idx, item_offset = self._get_offset_idx(start_offset) change_length = len(change.text) range_length = end_offset-start_offset - start_offset = start_offset - item_offset + relative_start_offset = start_offset - item_offset + + if relative_start_offset > 0: + # add item from the beginning of the item to the start of the change + new_lst.append((relative_start_offset, self._items[item_idx][1])) - if start_offset > 0: - new_lst.append((start_offset, self._items[item_idx][1])) + if start_offset == end_offset and change_length == 0: + # nothing to do (I'm not sure what this is) + self._set_document(updated_doc) + return + + if change_length == 0: + # deletion + new_lst.append((0, True)) - if change_length >= range_length: - effective_change_length = change_length + tmp_item = ( + self._items[item_idx][0]-relative_start_offset-range_length, + self._items[item_idx][1] + ) + if tmp_item[0] != 0: + 
new_lst.append(tmp_item) + elif range_length == 0: + # insertion + new_lst.append((change_length, True)) + + tmp_item = ( + self._items[item_idx][0]-relative_start_offset, + self._items[item_idx][1] + ) + if tmp_item[0] > 0: + new_lst.append(tmp_item) else: - effective_change_length = change_length-range_length - effective_change_length = max(effective_change_length, -1*start_offset) - new_lst.append((effective_change_length, True)) + # replacement + new_lst.append((change_length, True)) - new_lst.append(( - self._items[item_idx][0]-start_offset-range_length, - self._items[item_idx][1] - )) + tmp_item = ( + self._items[item_idx][0]-relative_start_offset-(change_length-range_length), + self._items[item_idx][1] + ) + if tmp_item[0] > 0: + new_lst.append(tmp_item) self._replace_at(item_idx, new_lst) + self._set_document(updated_doc) def _get_offset_idx(self, offset): pos = 0 @@ -625,6 +1217,7 @@ def get_changes(self) -> List[Interval]: return [Interval(0, doc_length)] res = list() + seen = set() pos = 0 for item in self._items: if item[1]: @@ -635,7 +1228,20 @@ def get_changes(self) -> List[Interval]: length = min(length*-1, doc_length-pos) else: position = pos - res.append(Interval(position, length)) + + if position >= doc_length: + position = doc_length-1 + length = 0 + + if length == 0 and position > 0: + position -= 1 + length = 1 + + intv = Interval(position, length) + + if intv not in seen: + res.append(intv) + seen.add(intv) pos += max(0, item[0]) return res diff --git a/textLSP/documents/latex/latex.py b/textLSP/documents/latex/latex.py index ecca371..129c215 100644 --- a/textLSP/documents/latex/latex.py +++ b/textLSP/documents/latex/latex.py @@ -13,6 +13,7 @@ class LatexDocument(TreeSitterDocument): CURLY_GROUP = 'curly_group' ENUM_ITEM = 'enum_item' GENERIC_ENVIRONMENT = 'generic_environment' + ERROR = 'ERROR' # content in syntax error, e.g. 
missing closing environment NODE_CONTENT = 'content' NODE_NEWLINE_BEFORE_AFTER = 'newline_before_after' @@ -24,6 +25,7 @@ class LatexDocument(TreeSitterDocument): CURLY_GROUP, ENUM_ITEM, GENERIC_ENVIRONMENT, + ERROR, } NEWLINE_BEFORE_AFTER_CURLY_PARENT = { @@ -44,7 +46,6 @@ def __init__(self, *args, **kwargs): *args, **kwargs, ) - self._query = self._build_query() def _build_query(self): query_str = '' @@ -59,13 +60,18 @@ def _build_query(self): return self._language.query(query_str) - def _iterate_text_nodes(self, tree: Tree) -> Generator[TextNode, None, None]: + def _iterate_text_nodes( + self, + tree: Tree, + start_point, + end_point, + ) -> Generator[TextNode, None, None]: lines = tree.text.decode('utf-8').split('\n') last_sent = None new_lines_after = list() - for node in self._query.captures(tree.root_node): + for node in self._query.captures(tree.root_node, start_point=start_point, end_point=end_point): # Check if we need some newlines after previous elements while len(new_lines_after) > 0: if node[0].start_point > new_lines_after[0]: diff --git a/textLSP/documents/markdown/markdown.py b/textLSP/documents/markdown/markdown.py index a029174..16af8e9 100644 --- a/textLSP/documents/markdown/markdown.py +++ b/textLSP/documents/markdown/markdown.py @@ -55,7 +55,6 @@ def __init__(self, *args, **kwargs): *args, **kwargs, ) - self._query = self._build_query() def _build_query(self): query_str = '' @@ -71,13 +70,28 @@ def _build_query(self): return self._language.query(query_str) - def _iterate_text_nodes(self, tree: Tree) -> Generator[TextNode, None, None]: + def _iterate_text_nodes( + self, + tree: Tree, + start_point, + end_point, + ) -> Generator[TextNode, None, None]: lines = tree.text.decode('utf-8').split('\n') last_sent = None new_lines_after = list() - for node in self._query.captures(tree.root_node): + if start_point == end_point: + # FIXME This is a weird issue, it seems that in some cases nothing + # is selected if the interval is empty, but not in all 
cases. See + # markdown_text.py test_edits() where first two characters of + # '# Header' is removed + end_point = ( + end_point[0], + end_point[1] + 1 + ) + + for node in self._query.captures(tree.root_node, start_point=start_point, end_point=end_point): # Check if we need some newlines after previous elements while len(new_lines_after) > 0: if node[0].start_point > new_lines_after[0]: diff --git a/textLSP/documents/org/org.py b/textLSP/documents/org/org.py index 957562c..d52fa0c 100644 --- a/textLSP/documents/org/org.py +++ b/textLSP/documents/org/org.py @@ -45,7 +45,6 @@ def __init__(self, *args, **kwargs): *args, **kwargs, ) - self._query = self._build_query() keywords = self.config.setdefault( self.CONFIGURATION_TODO_KEYWORDS, self.DEFAULT_TODO_KEYWORDS, @@ -70,13 +69,18 @@ def _build_query(self): return self._language.query(query_str) - def _iterate_text_nodes(self, tree: Tree) -> Generator[TextNode, None, None]: + def _iterate_text_nodes( + self, + tree: Tree, + start_point, + end_point, + ) -> Generator[TextNode, None, None]: lines = tree.text.decode('utf-8').split('\n') last_sent = None new_lines_after = list() - for node in self._query.captures(tree.root_node): + for node in self._query.captures(tree.root_node, start_point=start_point, end_point=end_point): # Check if we need some newlines after previous elements while len(new_lines_after) > 0: if node[0].start_point > new_lines_after[0]: diff --git a/textLSP/server.py b/textLSP/server.py index 520ee89..506a2eb 100644 --- a/textLSP/server.py +++ b/textLSP/server.py @@ -13,6 +13,7 @@ WORKSPACE_DID_CHANGE_CONFIGURATION, INITIALIZE, TEXT_DOCUMENT_COMPLETION, + SHUTDOWN, ) from lsprotocol.types import ( DidOpenTextDocumentParams, @@ -29,6 +30,7 @@ CompletionList, CompletionOptions, CompletionParams, + ShutdownRequest, ) from .workspace import TextLSPWorkspace from .utils import merge_dicts, get_textlsp_version @@ -45,7 +47,7 @@ def __init__(self, *args, **kwargs): @lsp_method(INITIALIZE) def 
lsp_initialize(self, params: InitializeParams) -> InitializeResult: result = super().lsp_initialize(params) - self.workspace = TextLSPWorkspace.workspace2textlspworkspace( + self._workspace = TextLSPWorkspace.workspace2textlspworkspace( self.workspace, self._server.analyser_handler, self._server.settings, @@ -120,6 +122,11 @@ def publish_stored_diagnostics(self, doc: Document): diagnostics.extend(lst) self.publish_diagnostics(doc.uri, diagnostics) + def shutdown(self): + logger.warning('TextLSP shutting down!') + self.analyser_handler.shutdown() + super().shutdown() + SERVER = TextLSPLanguageServer( name='textLSP', @@ -148,6 +155,11 @@ async def did_close(ls: TextLSPLanguageServer, params: DidCloseTextDocumentParam await ls.analyser_handler.did_close(params) +@SERVER.feature(SHUTDOWN) +def shutdown(ls: TextLSPLanguageServer, params: ShutdownRequest): + ls.shutdown() + + @SERVER.feature(WORKSPACE_DID_CHANGE_CONFIGURATION) def did_change_configuration(ls: TextLSPLanguageServer, params: DidChangeConfigurationParams): ls.update_settings(params.settings) diff --git a/textLSP/types.py b/textLSP/types.py index fa99ea2..96da06e 100644 --- a/textLSP/types.py +++ b/textLSP/types.py @@ -6,6 +6,7 @@ from typing import Optional, Any, List from dataclasses import dataclass +from sortedcontainers import SortedDict from lsprotocol.types import ( Position, @@ -16,6 +17,8 @@ WorkDoneProgressEnd, ) +from .utils import position_to_tuple + TEXT_PASSAGE_PATTERN = re.compile('[.?!] 
|\\n') LINE_PATTERN = re.compile('\\n') @@ -83,7 +86,7 @@ def add_interval_values( def add_interval(self, interval: OffsetPositionInterval): self.add_interval_values( interval.offset_interval.start, - interval.offset_interval.end, + interval.offset_interval.start + interval.offset_interval.length - 1, interval.position_range.start.line, interval.position_range.start.character, interval.position_range.end.line, @@ -196,7 +199,10 @@ def get_idx_at_position(self, position: Position, strict=True) -> int: if self._position_start_character[idx] <= position.character <= self._position_end_character[idx]: return idx - if position.character < self._position_start_character[idx]: + if ( + position.line < self._position_start_line[idx] or + position.character < self._position_start_character[idx] + ): return None if strict else idx return None if strict else min(idx+1, length-1) @@ -211,6 +217,88 @@ def get_interval_at_position(self, position: Position, strict=True) -> OffsetPos return self.get_interval(idx) +class PositionDict(): + + def __init__(self): + self._positions = SortedDict() + + def add(self, position: Position, item): + position = position_to_tuple(position) + self._positions[position] = item + + def get(self, position: Position): + position = position_to_tuple(position) + return self._positions[position] + + def pop(self, position: Position): + position = position_to_tuple(position) + return self._positions.popitem(position) + + def update(self, old_position: Position, new_position: Position = None, + new_value=None): + assert new_position is not None or new_value is not None, ' new_position' + ' or new_value should be specified.' 
+ + old_position = position_to_tuple(old_position) + new_position = position_to_tuple(new_position) + if new_position is None: + self._positions[old_position] = new_value + return + + if new_value is None: + new_value = self._positions.popitem(old_position) + else: + del self._positions[old_position] + + self._positions[new_position] = new_value + + def remove(self, position: Position): + position = position_to_tuple(position) + del self._positions[position] + + def remove_from(self, position: Position, inclusive=True): + position = position_to_tuple(position) + num = 0 + for key in list(self._positions.irange( + minimum=position, + inclusive=(inclusive, False) + )): + del self._positions[key] + num += 1 + + return num + + def remove_between(self, range: Range, inclusive=(True, True)): + minimum = position_to_tuple(range.start) + maximum = position_to_tuple(range.end) + num = 0 + for key in list(self._positions.irange( + minimum=minimum, + maximum=maximum, + inclusive=inclusive, + )): + del self._positions[key] + num += 1 + + return num + + def irange(self, minimum: Position = None, maximum: Position = None, *args, + **kwargs): + if minimum is not None: + minimum = position_to_tuple(minimum) + if maximum is not None: + maximum = position_to_tuple(maximum) + + return self._positions.irange(minimum, maximum, *args, **kwargs) + + def irange_values(self, *args, **kwargs): + for key in self.irange(*args, **kwargs): + yield self._positions[key] + + def __iter__(self): + return iter(self._positions.values()) + + @enum.unique class TextLSPCodeActionKind(str, enum.Enum): AcceptSuggestion = CodeActionKind.QuickFix + '.accept_suggestion' diff --git a/textLSP/utils.py b/textLSP/utils.py index bca0186..63a4d93 100644 --- a/textLSP/utils.py +++ b/textLSP/utils.py @@ -1,13 +1,14 @@ import sys import importlib import inspect -import pkg_resources import re +from importlib.metadata import version from functools import wraps from threading import RLock from git import Repo from 
appdirs import user_cache_dir +from lsprotocol.types import Position def merge_dicts(dict1, dict2): @@ -71,7 +72,7 @@ def get_textlsp_name(): def get_textlsp_version(): - pkg_resources.require(get_textlsp_name())[0].version + return version(get_textlsp_name()) def get_user_cache(app_name=None): @@ -99,3 +100,30 @@ def batch_text(text: str, pattern: re.Pattern, max_size: int, min_size: int = 0) if sidx <= text_len: yield text[sidx:text_len] + + +def position_to_tuple(position: Position): + return (position.line, position.character) + + +def traverse_tree(tree): + cursor = tree.walk() + + reached_root = False + while reached_root: + yield cursor.node + + if cursor.goto_first_child(): + continue + + if cursor.goto_next_sibling(): + continue + + retracing = True + while retracing: + if not cursor.goto_parent(): + retracing = False + reached_root = True + + if cursor.goto_next_sibling(): + retracing = False diff --git a/textLSP/workspace.py b/textLSP/workspace.py index 8d5003a..9dddbf5 100644 --- a/textLSP/workspace.py +++ b/textLSP/workspace.py @@ -6,7 +6,7 @@ TextDocumentContentChangeEvent, VersionedTextDocumentIdentifier, ) -from pygls.workspace import Workspace, Document +from pygls.workspace import Workspace, TextDocument from .documents.document import DocumentTypeFactory from .analysers.handler import AnalyserHandler @@ -21,13 +21,13 @@ def __init__(self, analyser_handler: AnalyserHandler, settings: Dict, *args, **k self.analyser_handler = analyser_handler self.settings = settings - def _create_document( + def _create_text_document( self, doc_uri: str, source: Optional[str] = None, version: Optional[int] = None, language_id: Optional[str] = None, - ) -> Document: + ) -> TextDocument: return DocumentTypeFactory.get_document( doc_uri=doc_uri, config=self.settings, @@ -59,9 +59,11 @@ def update_settings(self, settings): self.settings = merge_dicts(self.settings, settings) - def update_document(self, - text_doc: VersionedTextDocumentIdentifier, - change: 
TextDocumentContentChangeEvent): - doc = self._docs[text_doc.uri] + def update_text_document( + self, + text_doc: VersionedTextDocumentIdentifier, + change: TextDocumentContentChangeEvent + ): + doc = self._text_documents[text_doc.uri] self.analyser_handler.update_document(doc, change) - super().update_document(text_doc, change) + super().update_text_document(text_doc, change)