Skip to content

Commit

Permalink
fix: incorrect newline handling in text diff
Browse files Browse the repository at this point in the history
  • Loading branch information
hangyav committed May 3, 2024
1 parent 9baf7ed commit 37a0f43
Show file tree
Hide file tree
Showing 3 changed files with 64 additions and 17 deletions.
52 changes: 43 additions & 9 deletions tests/utils_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,8 +116,8 @@ def test_batch_text(src, exp, max, min):
types.TokenDiff(
types.TokenDiff.INSERT,
'',
'example',
5,
' example',
4,
0
),
],
Expand All @@ -128,10 +128,10 @@ def test_batch_text(src, exp, max, min):
[
types.TokenDiff(
types.TokenDiff.DELETE,
'example',
' example',
'',
5,
7
4,
8
),
],
),
Expand All @@ -141,15 +141,15 @@ def test_batch_text(src, exp, max, min):
[
types.TokenDiff(
types.TokenDiff.DELETE,
'example',
' example',
'',
5,
7
4,
8
),
types.TokenDiff(
types.TokenDiff.INSERT,
'',
'good',
'good ', # XXX: the position of space seems to be a bit inconsistent, before or after
18,
0
),
Expand All @@ -167,6 +167,40 @@ def test_batch_text(src, exp, max, min):
'This is a sentence of 47 characters. ',
[],
),
(
'This is a sentence.\n'
'\n'
'This is a new paragraph.\n',
'This is a sentence.\n'
'\n'
'This is the new paragraph.\n',
[
types.TokenDiff(
types.TokenDiff.REPLACE,
'a',
'the',
29,
1
),
],
),
(
'This is a sentence.\n'
'\n'
'This is a new paragraph.\n',
'This is a sentence.\n'
'\n'
'That this is a new paragraph.\n',
[
types.TokenDiff(
types.TokenDiff.REPLACE,
'This',
'That this',
21,
4
),
],
),
])
def test_token_diff(s1, s2, exp):
res = types.TokenDiff.token_level_diff(s1, s2)
Expand Down
13 changes: 11 additions & 2 deletions textLSP/analysers/openai/openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,8 +85,12 @@ def _edit(self, text) -> List[TokenDiff]:
model=self.config.get(self.CONFIGURATION_MODEL, self.SETTINGS_DEFAULT_MODEL),
temperature=self.config.get(self.CONFIGURATION_TEMPERATURE, self.SETTINGS_DEFAULT_TEMPERATURE),
)
logger.debug(f"Response: {res}")

if len(res.choices) > 0:
return TokenDiff.token_level_diff(text, res.choices[0].message.content.strip())
# the API escapes special characters such as newlines
res_text = res.choices[0].message.content.strip().encode().decode("unicode_escape")
return TokenDiff.token_level_diff(text, res_text)

return []

Expand All @@ -98,16 +102,21 @@ def _generate(self, text) -> Optional[str]:
temperature=self.config.get(self.CONFIGURATION_TEMPERATURE, self.SETTINGS_DEFAULT_TEMPERATURE),
max_tokens=self.config.get(self.CONFIGURATION_MAX_TOKEN, self.SETTINGS_DEFAULT_MAX_TOKEN),
)
logger.debug(f"Response: {res}")

if len(res.choices) > 0:
return res.choices[0].message.content.strip()
# the API escapes special characters such as newlines
return res.choices[0].message.content.strip().encode().decode("unicode_escape")

return None

def _analyse(self, text, doc, offset=0) -> Tuple[List[Diagnostic], List[CodeAction]]:
diagnostics = list()
code_actions = list()

# we do not want trailing whitespace
text = text.rstrip()

try:
edits = self._edit(text)
except APIError as e:
Expand Down
16 changes: 10 additions & 6 deletions textLSP/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -317,19 +317,23 @@ class TokenDiff():
offset: int
length: int

@staticmethod
def _split(text):
    """Split *text* into tokens while keeping every whitespace character.

    The capturing group in the pattern makes ``re.split`` return each
    whitespace character (space, newline, tab, ...) as its own token, so
    ``''.join(result) == text`` always holds. Empty strings, which
    ``re.split`` emits between adjacent separators and at the edges of
    the text, are dropped.
    """
    # Raw string: "\s" in a normal literal is an invalid escape sequence.
    return [item for item in re.split(r"(\s)", text) if item != ""]

@staticmethod
def token_level_diff(text1, text2) -> List:
tokens1 = text1.split()
tokens2 = text2.split()
tokens1 = TokenDiff._split(text1)
tokens2 = TokenDiff._split(text2)
diff = difflib.SequenceMatcher(None, tokens1, tokens2)

return [
TokenDiff(
type=item[0],
old_token=' '.join(tokens1[item[1]:item[2]]),
new_token=' '.join(tokens2[item[3]:item[4]]),
offset=0 if item[1] == 0 else len(' '.join(tokens1[:item[1]]))+1,
length=len(' '.join(tokens1[item[1]:item[2]])),
old_token=''.join(tokens1[item[1]:item[2]]),
new_token=''.join(tokens2[item[3]:item[4]]),
offset=0 if item[1] == 0 else len(''.join(tokens1[:item[1]])),
length=len(''.join(tokens1[item[1]:item[2]])),
)
for item in diff.get_opcodes()
if item[0] != 'equal'
Expand Down

0 comments on commit 37a0f43

Please sign in to comment.