diff --git a/changelog.md b/changelog.md index 1ce607da..c3309963 100644 --- a/changelog.md +++ b/changelog.md @@ -8,6 +8,7 @@ Features * Place exact-leading completions first. * Allow history file location to be configured. * Make destructive-warning keywords configurable. +* Smarter fuzzy completion matches. Bug Fixes diff --git a/mycli/sqlcompleter.py b/mycli/sqlcompleter.py index 1a051cd3..264331ac 100644 --- a/mycli/sqlcompleter.py +++ b/mycli/sqlcompleter.py @@ -931,7 +931,7 @@ def reset_completions(self) -> None: @staticmethod def find_matches( - text: str, + orig_text: str, collection: Collection, start_only: bool = False, fuzzy: bool = True, @@ -950,24 +950,53 @@ def find_matches( yields prompt_toolkit Completion instances for any matches found in the collection of available completions. """ - last = last_word(text, include="most_punctuations") + last = last_word(orig_text, include="most_punctuations") text = last.lower() + # unicode support not possible without adding the regex dependency + case_change_pat = re.compile("(?<=[a-z])(?=[A-Z])|(?<=[A-Z])(?=[A-Z][a-z])") completions = [] if fuzzy: - regex = ".*?".join(map(re.escape, text)) + regex = ".{0,3}?".join(map(re.escape, text)) pat = re.compile(f'({regex})') + under_words_text = [x for x in text.split('_') if x] + case_words_text = re.split(case_change_pat, text) + for item in collection: r = pat.search(item.lower()) if r: - completions.append((len(r.group()), r.start(), item)) + completions.append(item) + continue + + under_words_item = [x for x in item.lower().split('_') if x] + occurrences = 0 + for elt_word in under_words_text: + for elt_item in under_words_item: + if elt_item.startswith(elt_word): + occurrences += 1 + break + if occurrences >= len(under_words_text): + completions.append(item) + continue + + case_words_item = re.split(case_change_pat, item.lower()) + occurrences = 0 + for elt_word in case_words_text: + for elt_item in case_words_item: + if elt_item.startswith(elt_word): + occurrences += 1 + break + if occurrences >= len(case_words_text): + completions.append(item) + continue + else: match_end_limit = len(text) if start_only else None for item in collection: match_point = item.lower().find(text, 0, match_end_limit) if match_point >= 0: - completions.append((len(text), match_point, item)) + completions.append(item) if casing == "auto": casing = "lower" if last and last[-1].islower() else "upper" @@ -977,14 +1006,14 @@ def apply_case(kw: str) -> str: return kw.upper() return kw.lower() - def exact_leading_key(item: tuple[int, int, str], text): - if text and item[2].lower().startswith(text): - return -1000 + len(item[2]) + def exact_leading_key(item: str, text: str): + if text and item.lower().startswith(text): + return -1000 + len(item) return 0 completions = sorted(completions, key=lambda item: exact_leading_key(item, text)) - return (Completion(z if casing is None else apply_case(z), -len(text)) for x, y, z in completions) + return (Completion(x if casing is None else apply_case(x), -len(text)) for x in completions) def get_completions( self, diff --git a/test/test_smart_completion_public_schema_only.py b/test/test_smart_completion_public_schema_only.py index 7e213e70..c6b0953c 100644 --- a/test/test_smart_completion_public_schema_only.py +++ b/test/test_smart_completion_public_schema_only.py @@ -13,6 +13,11 @@ "orders": ["id", "ordered_date", "status"], "select": ["id", "insert", "ABC"], "réveillé": ["id", "insert", "ABC"], + "time_zone": ["Time_zone_id"], + "time_zone_leap_second": ["Time_zone_id"], + "time_zone_name": ["Time_zone_id"], + "time_zone_transition": ["Time_zone_id"], + "time_zone_transition_type": ["Time_zone_id"], } @@ -66,51 +71,12 @@ def test_select_keyword_completion(completer, complete_event): assert list(result) == [ Completion(text='SELECT', start_position=-3), Completion(text='SERIAL', start_position=-3), - Completion(text='GET_MASTER_PUBLIC_KEY', start_position=-3), - Completion(text='GET_SOURCE_PUBLIC_KEY', start_position=-3), - Completion(text='MASTER_COMPRESSION_ALGORITHMS', start_position=-3), - Completion(text='MASTER_DELAY', start_position=-3), Completion(text='MASTER_LOG_FILE', start_position=-3), Completion(text='MASTER_LOG_POS', start_position=-3), - Completion(text='MASTER_PUBLIC_KEY_PATH', start_position=-3), - Completion(text='MASTER_SSL', start_position=-3), - Completion(text='MASTER_SSL_CA', start_position=-3), - Completion(text='MASTER_SSL_CAPATH', start_position=-3), - Completion(text='MASTER_SSL_CERT', start_position=-3), - Completion(text='MASTER_SSL_CIPHER', start_position=-3), - Completion(text='MASTER_SSL_CRL', start_position=-3), - Completion(text='MASTER_SSL_CRLPATH', start_position=-3), - Completion(text='MASTER_SSL_KEY', start_position=-3), - Completion(text='MASTER_SSL_VERIFY_SERVER_CERT', start_position=-3), Completion(text='MASTER_TLS_CIPHERSUITES', start_position=-3), Completion(text='MASTER_TLS_VERSION', start_position=-3), - Completion(text='MASTER_ZSTD_COMPRESSION_LEVEL', start_position=-3), Completion(text='SCHEDULE', start_position=-3), - Completion(text='SECONDARY_LOAD', start_position=-3), - Completion(text='SECONDARY_UNLOAD', start_position=-3), Completion(text='SERIALIZABLE', start_position=-3), - Completion(text='SOURCE_COMPRESSION_ALGORITHMS', start_position=-3), - Completion(text='SOURCE_CONNECTION_AUTO_FAILOVER', start_position=-3), - Completion(text='SOURCE_DELAY', start_position=-3), - Completion(text='SOURCE_LOG_FILE', start_position=-3), - Completion(text='SOURCE_LOG_POS', start_position=-3), - Completion(text='SOURCE_PUBLIC_KEY_PATH', start_position=-3), - Completion(text='SOURCE_SSL', start_position=-3), - Completion(text='SOURCE_SSL_CA', start_position=-3), - Completion(text='SOURCE_SSL_CAPATH', start_position=-3), - Completion(text='SOURCE_SSL_CERT', start_position=-3), - Completion(text='SOURCE_SSL_CIPHER', start_position=-3), - Completion(text='SOURCE_SSL_CRL', start_position=-3), - Completion(text='SOURCE_SSL_CRLPATH', start_position=-3), - Completion(text='SOURCE_SSL_KEY', start_position=-3), - Completion(text='SOURCE_SSL_VERIFY_SERVER_CERT', start_position=-3), - Completion(text='SOURCE_TLS_CIPHERSUITES', start_position=-3), - Completion(text='SOURCE_TLS_VERSION', start_position=-3), - Completion(text='SOURCE_ZSTD_COMPRESSION_LEVEL', start_position=-3), - Completion(text='SQL_BIG_RESULT', start_position=-3), - Completion(text='SQL_BUFFER_RESULT', start_position=-3), - Completion(text='SQL_SMALL_RESULT', start_position=-3), - Completion(text='STATS_AUTO_RECALC', start_position=-3), ] @@ -130,6 +96,11 @@ def test_table_completion(completer, complete_event): Completion(text="orders", start_position=0), Completion(text="`select`", start_position=0), Completion(text="`réveillé`", start_position=0), + Completion(text="time_zone", start_position=0), + Completion(text="time_zone_leap_second", start_position=0), + Completion(text="time_zone_name", start_position=0), + Completion(text="time_zone_transition", start_position=0), + Completion(text="time_zone_transition_type", start_position=0), ] @@ -191,7 +162,6 @@ def test_function_name_completion(completer, complete_event): Completion(text='DECIMAL', start_position=-2), Completion(text='SMALLINT', start_position=-2), Completion(text='TIMESTAMP', start_position=-2), - Completion(text='ASSIGN_GTIDS_TO_ANONYMOUS_TRANSACTIONS', start_position=-2), Completion(text='COLUMN_FORMAT', start_position=-2), Completion(text='COLUMN_NAME', start_position=-2), Completion(text='COMPACT', start_position=-2), @@ -211,10 +181,7 @@ def test_function_name_completion(completer, complete_event): Completion(text='SCHEMA', start_position=-2), Completion(text='SCHEMA_NAME', start_position=-2), Completion(text='SCHEMAS', start_position=-2), - Completion(text='SOURCE_COMPRESSION_ALGORITHMS', start_position=-2), - Completion(text='SQL_AFTER_MTS_GAPS', start_position=-2), Completion(text='SQL_SMALL_RESULT', start_position=-2), - Completion(text='STATS_SAMPLE_PAGES', start_position=-2), Completion(text='TEMPORARY', start_position=-2), Completion(text='TEMPTABLE', start_position=-2), Completion(text='TERMINATED', start_position=-2), @@ -428,6 +395,33 @@ def test_table_names_after_from(completer, complete_event): Completion(text="orders", start_position=0), Completion(text="`select`", start_position=0), Completion(text="`réveillé`", start_position=0), + Completion(text="time_zone", start_position=0), + Completion(text="time_zone_leap_second", start_position=0), + Completion(text="time_zone_name", start_position=0), + Completion(text="time_zone_transition", start_position=0), + Completion(text="time_zone_transition_type", start_position=0), + ] + + +def test_table_names_leading_partial(completer, complete_event): + text = "SELECT * FROM time_zone" + position = len("SELECT * FROM time_zone") + result = list(completer.get_completions(Document(text=text, cursor_position=position), complete_event)) + assert result == [ + Completion(text="time_zone", start_position=-9), + Completion(text="time_zone_name", start_position=-9), + Completion(text="time_zone_transition", start_position=-9), + Completion(text="time_zone_leap_second", start_position=-9), + Completion(text="time_zone_transition_type", start_position=-9), + ] + + +def test_table_names_inter_partial(completer, complete_event): + text = "SELECT * FROM time_leap" + position = len("SELECT * FROM time_leap") + result = list(completer.get_completions(Document(text=text, cursor_position=position), complete_event)) + assert result == [ + Completion(text="time_zone_leap_second", start_position=-9), ]