diff --git a/tests/utils_test.py b/tests/utils_test.py
index 4ee5d03..58a58e7 100644
--- a/tests/utils_test.py
+++ b/tests/utils_test.py
@@ -116,8 +116,8 @@ def test_batch_text(src, exp, max, min):
             types.TokenDiff(
                 types.TokenDiff.INSERT,
                 '',
-                'example',
-                5,
+                ' example',
+                4,
                 0
             ),
         ],
@@ -128,10 +128,10 @@ def test_batch_text(src, exp, max, min):
         [
             types.TokenDiff(
                 types.TokenDiff.DELETE,
-                'example',
+                ' example',
                 '',
-                5,
-                7
+                4,
+                8
             ),
         ],
     ),
@@ -141,15 +141,15 @@ def test_batch_text(src, exp, max, min):
         [
             types.TokenDiff(
                 types.TokenDiff.DELETE,
-                'example',
+                ' example',
                 '',
-                5,
-                7
+                4,
+                8
             ),
             types.TokenDiff(
                 types.TokenDiff.INSERT,
                 '',
-                'good',
+                'good ',  # XXX: the position of the space seems a bit inconsistent, before or after
                 18,
                 0
             ),
@@ -167,6 +167,40 @@ def test_batch_text(src, exp, max, min):
         'This is a sentence of 47 characters. ',
         [],
     ),
+    (
+        'This is a sentence.\n'
+        '\n'
+        'This is a new paragraph.\n',
+        'This is a sentence.\n'
+        '\n'
+        'This is the new paragraph.\n',
+        [
+            types.TokenDiff(
+                types.TokenDiff.REPLACE,
+                'a',
+                'the',
+                29,
+                1
+            ),
+        ],
+    ),
+    (
+        'This is a sentence.\n'
+        '\n'
+        'This is a new paragraph.\n',
+        'This is a sentence.\n'
+        '\n'
+        'That this is a new paragraph.\n',
+        [
+            types.TokenDiff(
+                types.TokenDiff.REPLACE,
+                'This',
+                'That this',
+                21,
+                4
+            ),
+        ],
+    ),
 ])
 def test_token_diff(s1, s2, exp):
     res = types.TokenDiff.token_level_diff(s1, s2)
diff --git a/textLSP/analysers/openai/openai.py b/textLSP/analysers/openai/openai.py
index 4f11126..93fbf10 100644
--- a/textLSP/analysers/openai/openai.py
+++ b/textLSP/analysers/openai/openai.py
@@ -85,8 +85,12 @@ def _edit(self, text) -> List[TokenDiff]:
             model=self.config.get(self.CONFIGURATION_MODEL, self.SETTINGS_DEFAULT_MODEL),
             temperature=self.config.get(self.CONFIGURATION_TEMPERATURE, self.SETTINGS_DEFAULT_TEMPERATURE),
         )
+        logger.debug(f"Response: {res}")
+
         if len(res.choices) > 0:
-            return TokenDiff.token_level_diff(text, res.choices[0].message.content.strip())
+            # the API escapes special characters such as newlines
+            res_text = res.choices[0].message.content.strip().encode().decode("unicode_escape")
+            return TokenDiff.token_level_diff(text, res_text)
 
         return []
 
@@ -98,9 +102,11 @@ def _generate(self, text) -> Optional[str]:
             temperature=self.config.get(self.CONFIGURATION_TEMPERATURE, self.SETTINGS_DEFAULT_TEMPERATURE),
             max_tokens=self.config.get(self.CONFIGURATION_MAX_TOKEN, self.SETTINGS_DEFAULT_MAX_TOKEN),
         )
+        logger.debug(f"Response: {res}")
 
         if len(res.choices) > 0:
-            return res.choices[0].message.content.strip()
+            # the API escapes special characters such as newlines
+            return res.choices[0].message.content.strip().encode().decode("unicode_escape")
 
         return None
 
@@ -108,6 +114,9 @@ def _analyse(self, text, doc, offset=0) -> Tuple[List[Diagnostic], List[CodeActi
         diagnostics = list()
         code_actions = list()
 
+        # we do not want trailing whitespace
+        text = text.rstrip()
+
         try:
             edits = self._edit(text)
         except APIError as e:
diff --git a/textLSP/types.py b/textLSP/types.py
index 96da06e..5e5e98a 100644
--- a/textLSP/types.py
+++ b/textLSP/types.py
@@ -317,19 +317,23 @@ class TokenDiff():
     offset: int
     length: int
 
+    @staticmethod
+    def _split(text):
+        return [item for item in re.split(r"(\s)", text) if item != ""]
+
     @staticmethod
     def token_level_diff(text1, text2) -> List:
-        tokens1 = text1.split()
-        tokens2 = text2.split()
+        tokens1 = TokenDiff._split(text1)
+        tokens2 = TokenDiff._split(text2)
 
         diff = difflib.SequenceMatcher(None, tokens1, tokens2)
         return [
             TokenDiff(
                 type=item[0],
-                old_token=' '.join(tokens1[item[1]:item[2]]),
-                new_token=' '.join(tokens2[item[3]:item[4]]),
-                offset=0 if item[1] == 0 else len(' '.join(tokens1[:item[1]]))+1,
-                length=len(' '.join(tokens1[item[1]:item[2]])),
+                old_token=''.join(tokens1[item[1]:item[2]]),
+                new_token=''.join(tokens2[item[3]:item[4]]),
+                offset=0 if item[1] == 0 else len(''.join(tokens1[:item[1]])),
+                length=len(''.join(tokens1[item[1]:item[2]])),
             )
             for item in diff.get_opcodes()
             if item[0] != 'equal'
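
For reference, the change to textLSP/types.py can be reproduced with a minimal standalone sketch of the new whitespace-preserving tokenization; TokenDiff is reduced here to a bare dataclass (the real class also defines the INSERT/DELETE/REPLACE constants used by the tests, which this sketch omits). Because the capturing group in the split keeps the separators as tokens, joining slices back together yields character-accurate offsets, which is where the 29/1 values in the added test case come from:

import difflib
import re
from dataclasses import dataclass


@dataclass
class TokenDiff:
    type: str
    old_token: str
    new_token: str
    offset: int
    length: int

    @staticmethod
    def _split(text):
        # the capturing group keeps the whitespace separators as tokens
        return [item for item in re.split(r"(\s)", text) if item != ""]

    @staticmethod
    def token_level_diff(text1, text2):
        tokens1 = TokenDiff._split(text1)
        tokens2 = TokenDiff._split(text2)
        diff = difflib.SequenceMatcher(None, tokens1, tokens2)
        return [
            TokenDiff(
                type=item[0],
                old_token=''.join(tokens1[item[1]:item[2]]),
                new_token=''.join(tokens2[item[3]:item[4]]),
                # joining the untouched prefix gives the character offset directly,
                # so the old "+1 for the dropped space" correction is no longer needed
                offset=0 if item[1] == 0 else len(''.join(tokens1[:item[1]])),
                length=len(''.join(tokens1[item[1]:item[2]])),
            )
            for item in diff.get_opcodes()
            if item[0] != 'equal'
        ]


print(TokenDiff.token_level_diff(
    'This is a sentence.\n\nThis is a new paragraph.\n',
    'This is a sentence.\n\nThis is the new paragraph.\n',
))
# [TokenDiff(type='replace', old_token='a', new_token='the', offset=29, length=1)]

The trade-off, flagged by the XXX note in the updated tests, is that the adjacent whitespace now ends up inside the reported token (' example', 'good '), so whether the space lands before or after the word depends on which side difflib attaches it to.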