diff --git a/packages/kestrel_core/pyproject.toml b/packages/kestrel_core/pyproject.toml index 7af9ecd4..16769686 100644 --- a/packages/kestrel_core/pyproject.toml +++ b/packages/kestrel_core/pyproject.toml @@ -30,14 +30,14 @@ classifiers = [ ] dependencies = [ - "typeguard>=4.1.5", + "typeguard>=4.3.0", "pyyaml>=6.0.1", - "lark>=1.1.7", - "pandas>=2.0.3", - "pyarrow>=13.0.0", - "mashumaro>=3.10", - "networkx>=3.1", # networkx==3.2.1 only for Python>=3.9 - "SQLAlchemy>=2.0.23", + "lark>=1.1.9", + "pandas>=2.0.3", # any higher version drops Python 3.8 support + "pyarrow>=17.0.0", + "mashumaro>=3.13.1", + "networkx>=3.1", # any higher version drops Python 3.8 support + "SQLAlchemy>=2.0.31", ] [project.optional-dependencies] diff --git a/packages/kestrel_core/src/kestrel/analytics/interface.py b/packages/kestrel_core/src/kestrel/analytics/interface.py index 27bc3fa0..fe3cf578 100644 --- a/packages/kestrel_core/src/kestrel/analytics/interface.py +++ b/packages/kestrel_core/src/kestrel/analytics/interface.py @@ -107,7 +107,7 @@ def my_analytic(df: pd.DataFrame, x: int = 0, y: float = 0.5) import traceback from contextlib import AbstractContextManager from importlib.util import module_from_spec, spec_from_file_location -from typing import Any, Iterable, Mapping, MutableMapping, Optional +from typing import Any, Iterable, List, Mapping, MutableMapping, Optional from uuid import UUID from kestrel.analytics.config import get_profile, load_profiles @@ -119,7 +119,7 @@ def my_analytic(df: pd.DataFrame, x: int = 0, y: float = 0.5) InvalidAnalyticsInterfaceImplementation, InvalidAnalyticsOutput, ) -from kestrel.interface import AbstractInterface +from kestrel.interface import AnalyticsInterface from kestrel.ir.graph import IRGraphEvaluable from kestrel.ir.instructions import ( Analytic, @@ -157,7 +157,7 @@ def run(self, config: dict) -> DataFrame: return df -class PythonAnalyticsInterface(AbstractInterface): +class PythonAnalyticsInterface(AnalyticsInterface): def __init__( self, serialized_cache_catalog: Optional[str] = None, @@ -171,6 +171,9 @@ def __init__( def schemes() -> Iterable[str]: return ["python"] + def get_datasources(self) -> List[str]: + return list(self.config) + def get_storage_of_datasource(self, datasource: str) -> str: return "local" diff --git a/packages/kestrel_core/src/kestrel/cache/base.py b/packages/kestrel_core/src/kestrel/cache/base.py index ecb6dc31..9e818373 100644 --- a/packages/kestrel_core/src/kestrel/cache/base.py +++ b/packages/kestrel_core/src/kestrel/cache/base.py @@ -1,26 +1,24 @@ from __future__ import annotations from abc import abstractmethod -from typing import Iterable, MutableMapping +from typing import Iterable, List, MutableMapping from uuid import UUID from kestrel.config.internal import CACHE_INTERFACE_IDENTIFIER, CACHE_STORAGE_IDENTIFIER -from kestrel.interface import AbstractInterface +from kestrel.interface import DatasourceInterface from pandas import DataFrame -class AbstractCache(AbstractInterface, MutableMapping): - """Base class for Kestrel cache - - Additional @abstractmethod from AbstractInterface: - - - evaluate_graph() - """ +class AbstractCache(DatasourceInterface, MutableMapping): + """Base class for Kestrel cache""" @staticmethod def schemes() -> Iterable[str]: return [CACHE_INTERFACE_IDENTIFIER] + def get_datasources(self) -> List[str]: + return [] + def get_storage_of_datasource(self, datasource: str) -> str: return CACHE_STORAGE_IDENTIFIER diff --git a/packages/kestrel_core/src/kestrel/config/relations/event.csv b/packages/kestrel_core/src/kestrel/config/relations/event.csv index 854c2567..fd2de0ef 100644 --- a/packages/kestrel_core/src/kestrel/config/relations/event.csv +++ b/packages/kestrel_core/src/kestrel/config/relations/event.csv @@ -7,3 +7,4 @@ network_endpoint,RESPONDED,dst_endpoint reg_key,RESPONDED,reg_key reg_value,RESPONDED,reg_value user,ORIGINATED,actor.user +endpoint,RESPONDED,device diff --git a/packages/kestrel_core/src/kestrel/config/utils.py b/packages/kestrel_core/src/kestrel/config/utils.py index 28de1e1b..29477237 100644 --- a/packages/kestrel_core/src/kestrel/config/utils.py +++ b/packages/kestrel_core/src/kestrel/config/utils.py @@ -1,7 +1,7 @@ import logging import os from pathlib import Path -from typing import Mapping, Union +from typing import List, Mapping, Union import pandas import yaml @@ -17,6 +17,9 @@ CONFIG_PATH_DEFAULT = CONFIG_DIR_DEFAULT / "kestrel.yaml" CONFIG_PATH_ENV_VAR = "KESTREL_CONFIG" # override CONFIG_PATH_DEFAULT if provided +relations = [] +entity_types = [] + _logger = logging.getLogger(__name__) @@ -105,3 +108,33 @@ def load_relation_configs(table_name: str) -> pandas.DataFrame: except: raise InvalidKestrelRelationTable(filepaths[0]) return table + + +@typechecked +def get_all_relations() -> List[str]: + global relations + if not relations: + _relations = set() + for filepath in list_folder_files( + "kestrel.config", "relations", extension="csv" + ): + table = pandas.read_csv(filepath) + _relations |= set(table["Relation"].to_list()) + relations = list(_relations) + return relations + + +@typechecked +def get_all_entity_types() -> List[str]: + global entity_types + if not entity_types: + _entity_types = {"event"} + for filepath in list_folder_files( + "kestrel.config", "relations", extension="csv" + ): + table = pandas.read_csv(filepath) + for typecol in ("InputType", "OutputType"): + if typecol in table: + _entity_types |= set(table[typecol].to_list()) + entity_types = list(_entity_types) + return entity_types diff --git a/packages/kestrel_core/src/kestrel/frontend/completor.py b/packages/kestrel_core/src/kestrel/frontend/completor.py new file mode 100644 index 00000000..bfe575c2 --- /dev/null +++ b/packages/kestrel_core/src/kestrel/frontend/completor.py @@ -0,0 +1,229 @@ +import logging +import re +from datetime import datetime +from typing import Callable, Iterable, List, Tuple + +import lark +from kestrel.config.utils import get_all_entity_types, get_all_relations +from kestrel.frontend.parser import get_keywords, parse_without_transform +from kestrel.interface import InterfaceManager +from kestrel.interface.base import AnalyticsInterface, DatasourceInterface +from kestrel.utils import timefmt +from typeguard import typechecked + +_logger = logging.getLogger(__name__) + +ISO_TS_RE = re.compile(r"\d{4}(-\d{2}(-\d{2}(T\d{2}(:\d{2}(:\d{2}Z?)?)?)?)?)?") + + +@typechecked +def do_complete( + code: str, + cursor_pos: int, + itf_manager: InterfaceManager, + varnames: List[str], +) -> Iterable[str]: + _logger.debug("auto_complete function starts...") + + # do not care code after cursor position in the current version + line = code[:cursor_pos] + _logger.debug(f"line to auto-complete: {line}") + + # if the last char is a space, `line_to_parse = line` + # otherwise, exclude the last token in `line_to_parse` to prompt the expected token + last_word_prefix, line_to_parse = _split_last_token(line) + _logger.debug(f"last word prefix: {last_word_prefix}") + _logger.debug(f"line to parse: {line_to_parse}") + + try: + ast = parse_without_transform(line_to_parse) + + except lark.exceptions.UnexpectedCharacters as e: + suggestions = ["% illegal char in huntflow %"] + _logger.debug(f"illegal character in `line_to_parse`, err: {str(e)}") + + except lark.exceptions.UnexpectedEOF as e: + suggestions = ["% EOF auto-complete internal error, report to developers %"] + # https://github.com/lark-parser/lark/issues/791 + # Lark updates may break this, check if it is the case + # no need to use KestrelInternalError; not to break huntflow execution + _logger.debug(f"Lark with LALR should not give this error: {str(e)}") + + except lark.exceptions.UnexpectedToken as e: + error_token = e.token + expected_tokens = e.accepts or e.expected + expected_values = [] + keywords = set(get_keywords(False)) + relations = get_all_relations() + entity_types = get_all_entity_types() + for token in expected_tokens: + _logger.debug("token: %s", token) + if token == "VARIABLE": + expected_values.extend(varnames) + elif token == "ISOTIMESTAMP": + if last_word_prefix: + if last_word_prefix.startswith("t'"): + ts_prefix = last_word_prefix[2:] + ts_complete = _do_complete_timestamp(ts_prefix) + exp_value = "t'" + ts_complete + "'" + else: + exp_value = _do_complete_timestamp(last_word_prefix) + else: + exp_value = timefmt(datetime.now()) + expected_values.append(exp_value) + elif token == "DATASRC_SIMPLE": + _logger.debug("auto-complete data source") + expected_values.extend( + _do_complete_interface( + last_word_prefix, + itf_manager.schemes(DatasourceInterface), + itf_manager.list_datasources_from_scheme, + ) + ) + elif token == "ANALYTICS_SIMPLE": + _logger.debug("auto-complete analytics") + expected_values.extend( + _do_complete_interface( + last_word_prefix, + itf_manager.schemes(AnalyticsInterface), + itf_manager.list_datasources_from_scheme, + ) + ) + elif token == "ENTITY_TYPE": + expected_values.extend(entity_types) + elif token == "PROJECT_FIELD": # not precise + expected_values.extend(entity_types) + elif token == "RELATION": + expected_values.extend(relations) + elif token == "REVERSED": + expected_values.append("BY") + elif token == "EQUAL": + expected_values.append("=") + elif token == "ATTRIBUTE": + # TODO: attribute completion + # https://github.com/opencybersecurityalliance/kestrel-lang/issues/79 + _logger.debug(f"TODO: ATTRIBUTE COMPLETION") + elif token == "ENTITY_ATTRIBUTE_PATH": + # TODO: attribute completion + # https://github.com/opencybersecurityalliance/kestrel-lang/issues/79 + _logger.debug(f"TODO: ATTRIBUTE COMPLETION") + elif token == "COMMA": + expected_values.append(",") + elif token in keywords: + if last_word_prefix and last_word_prefix.islower(): + token = token.lower() + expected_values.append(token) + else: + # token not handled + continue + expected_values = sorted(expected_values) + _logger.debug(f"expected values: {expected_values}") + + # turn `expected_values` into `suggestions` + _p = last_word_prefix + _e = expected_values + suggestions = [t[len(_p) :] for t in _e if t.startswith(_p)] if _p else _e + suggestions = [x for x in set(suggestions) if x] + _logger.debug(f"suggestions: {suggestions}") + + else: + suggestions = [] + + # handle optional components + if ast: + stmt = ast.children[-1].children[0] + cmd = ( + stmt.children[1].data.value + if stmt.data.value == "assignment" + else stmt.data.value + ) + if cmd == "disp": + for clause in ("attr_clause", "limit_clause", "offset_clause"): + if not list(stmt.find_data(clause)): + suggestions.append("ATTR") + elif cmd in ("expression", "find") and not list( + stmt.find_data("where_clause") + ): + suggestions.append("WHERE") + elif cmd in ("get", "find") and not list(stmt.find_data("timerange")): + suggestions.append("START") + elif cmd == "apply" and not list(stmt.find_data("args")): + suggestions.append("WITH") + + suggestions = [x for x in set(suggestions) if x] + _p = last_word_prefix + suggestions = ( + [t[len(_p) :] for t in suggestions if t.startswith(_p)] + if _p + else suggestions + ) + _logger.debug(f"suggestions from optional components: {suggestions}") + + return suggestions + + +@typechecked +def _end_with_blank_or_comma(s: str) -> bool: + return s[-1] in [" ", "\t", "\n", "\r", "\f", "\v", ","] if s else True + + +@typechecked +def _split_last_token(s: str) -> Tuple[str, str]: + last = "" + if not _end_with_blank_or_comma(s): + while not _end_with_blank_or_comma(s): + last = s[-1] + last + s = s[:-1] + return last, s + + +@typechecked +def _do_complete_timestamp(ts_prefix: str) -> str: + valid_ts_formats = [ + "%Y", + "%Y-%m", + "%Y-%m-%d", + "%Y-%m-%dT%H", + "%Y-%m-%dT%H:%M", + "%Y-%m-%dT%H:%M:%S", + ] + matched = ISO_TS_RE.match(ts_prefix) + if matched: + for ts_format in valid_ts_formats: + _logger.debug(f"Match timestamp {ts_prefix} with format {ts_format}") + try: + ts = datetime.strptime(matched.group(), ts_format) + except: + _logger.debug(f"Timestamp match failed") + else: + ts_complete = timefmt(ts) + _logger.debug(f"Timestamp completed: {ts_complete}") + break + else: + ts_complete = "% TS auto-complete internal error, report to developers %" + # no need to use KestrelInternalError; not to break huntflow execution + _logger.debug( + f"TS auto-complete internal error: `valid_ts_formats` is incomplete" + ) + else: + ts_complete = "% illegal ISO 8601 timestamp prefix %" + _logger.debug(f"illegal ISO 8601 timestamp prefix: {ts_prefix}") + return ts_complete + + +@typechecked +def _do_complete_interface( + last_word_prefix: str, + schemes: Iterable[str], + list_names_from_scheme: Callable, +) -> Iterable[str]: + if last_word_prefix and "://" in last_word_prefix: + scheme, _ = last_word_prefix.split("://") + if scheme in schemes: + names = list_names_from_scheme(scheme) + paths = [scheme + "://" + name for name in names] + _logger.debug(f"auto-complete interface {scheme}: {paths}") + expected_values = paths + else: + expected_values = [scheme + "://" for scheme in schemes] + return expected_values diff --git a/packages/kestrel_core/src/kestrel/frontend/parser.py b/packages/kestrel_core/src/kestrel/frontend/parser.py index df829494..5cc0ce7c 100644 --- a/packages/kestrel_core/src/kestrel/frontend/parser.py +++ b/packages/kestrel_core/src/kestrel/frontend/parser.py @@ -4,14 +4,14 @@ from itertools import chain from typing import Iterable +import lark import yaml -from kestrel.config.utils import load_relation_configs +from kestrel.config.utils import get_all_relations, load_relation_configs from kestrel.frontend.compile import _KestrelT from kestrel.ir.graph import IRGraph from kestrel.ir.instructions import Return from kestrel.mapping.data_model import reverse_mapping from kestrel.utils import list_folder_files, load_data_file -from lark import Lark from pandas import DataFrame from typeguard import typechecked @@ -51,13 +51,12 @@ def get_frontend_mapping(submodule: str, do_reverse_mapping: bool = False) -> di @typechecked -def get_keywords(): - # TODO: this Kestrel1 code needs to be updated +def get_keywords(including_relations: bool = True): grammar = load_data_file("kestrel.frontend", "kestrel.lark") - parser = Lark(grammar, parser="lalr") + parser = lark.Lark(grammar, parser="lalr") alphabet_patterns = filter(lambda x: x.pattern.value.isalnum(), parser.terminals) - # keywords = [x.pattern.value for x in alphabet_patterns] + all_relations - keywords = [x.pattern.value for x in alphabet_patterns] + all_relations = get_all_relations() + keywords = [x.pattern.value for x in alphabet_patterns] + all_relations keywords_lower = map(lambda x: x.lower(), keywords) keywords_upper = map(lambda x: x.upper(), keywords) keywords_comprehensive = list(chain(keywords_lower, keywords_upper)) @@ -78,7 +77,7 @@ def parse_kestrel_and_update_irgraph( Returns: List of Return instructions in the current code block """ - lp = Lark( + lp = lark.Lark( load_data_file("kestrel.frontend", "kestrel.lark"), parser="lalr", transformer=_KestrelT( @@ -91,3 +90,15 @@ def parse_kestrel_and_update_irgraph( ), ) return lp.parse(stmts) + + +@typechecked +def parse_without_transform( + stmts: str, +) -> lark.tree.Tree: + """Parse Kestrel code block and not transform; for syntax error check""" + lp = lark.Lark( + load_data_file("kestrel.frontend", "kestrel.lark"), + parser="lalr", + ) + return lp.parse(stmts) diff --git a/packages/kestrel_core/src/kestrel/interface/__init__.py b/packages/kestrel_core/src/kestrel/interface/__init__.py index 3c4b25e5..39e6c843 100644 --- a/packages/kestrel_core/src/kestrel/interface/__init__.py +++ b/packages/kestrel_core/src/kestrel/interface/__init__.py @@ -1,2 +1,6 @@ -from kestrel.interface.base import AbstractInterface +from kestrel.interface.base import ( + AbstractInterface, + AnalyticsInterface, + DatasourceInterface, +) from kestrel.interface.manager import InterfaceManager diff --git a/packages/kestrel_core/src/kestrel/interface/base.py b/packages/kestrel_core/src/kestrel/interface/base.py index e6a15a31..3300ff96 100644 --- a/packages/kestrel_core/src/kestrel/interface/base.py +++ b/packages/kestrel_core/src/kestrel/interface/base.py @@ -1,6 +1,6 @@ import json from abc import ABC, abstractmethod -from typing import Any, Iterable, Mapping, MutableMapping, Optional +from typing import Any, Iterable, List, Mapping, MutableMapping, Optional from uuid import UUID from kestrel.display import GraphletExplanation @@ -55,6 +55,11 @@ def schemes() -> Iterable[str]: """ ... + @abstractmethod + def get_datasources(self) -> List[str]: + """Get the list of datasource names registered at this interface""" + ... + @abstractmethod def get_storage_of_datasource(self, datasource: str) -> str: """Get the storage name of a given datasource""" @@ -130,3 +135,9 @@ def explain_graph( def cache_catalog_to_json(self) -> str: """Serialize the cache catalog to a JSON string""" return json.dumps(self.cache_catalog) + + +class DatasourceInterface(AbstractInterface): ... + + +class AnalyticsInterface(AbstractInterface): ... diff --git a/packages/kestrel_core/src/kestrel/interface/manager.py b/packages/kestrel_core/src/kestrel/interface/manager.py index adce331c..ba991723 100644 --- a/packages/kestrel_core/src/kestrel/interface/manager.py +++ b/packages/kestrel_core/src/kestrel/interface/manager.py @@ -44,7 +44,10 @@ def __getitem__(self, scheme: str) -> AbstractInterface: raise InterfaceNotFound(f"no interface loaded for scheme {scheme}") def __iter__(self) -> Iterable[str]: - return itertools.chain(*[i.schemes() for i in self.interfaces]) + return filter( + lambda x: x != CACHE_INTERFACE_IDENTIFIER, + itertools.chain(*[i.schemes() for i in self.interfaces]), + ) def __len__(self) -> int: return sum(1 for _ in iter(self)) @@ -64,6 +67,17 @@ def del_cache(self): self.interfaces.remove(cache) del cache + def schemes(self, interface_type: type) -> Iterable[str]: + return filter( + lambda x: x != CACHE_INTERFACE_IDENTIFIER, + itertools.chain( + *[i.schemes() for i in self.interfaces if isinstance(i, interface_type)] + ), + ) + + def list_datasources_from_scheme(self, scheme: str) -> Iterable[str]: + return self[scheme].get_datasources() + def _load_interface_classes(): interface_clss = [] @@ -95,7 +109,7 @@ def _list_interface_pkg_names(): def _is_class(cls): - return lambda obj: inspect.isclass(obj) and obj.__bases__[0] == cls + return lambda obj: inspect.isclass(obj) and issubclass(obj, cls) @typechecked diff --git a/packages/kestrel_core/src/kestrel/session.py b/packages/kestrel_core/src/kestrel/session.py index 48b8c69e..8c11ce99 100644 --- a/packages/kestrel_core/src/kestrel/session.py +++ b/packages/kestrel_core/src/kestrel/session.py @@ -9,6 +9,7 @@ from kestrel.config.internal import CACHE_INTERFACE_IDENTIFIER from kestrel.display import Display, GraphExplanation from kestrel.exceptions import InstructionNotFound +from kestrel.frontend.completor import do_complete from kestrel.frontend.parser import parse_kestrel_and_update_irgraph from kestrel.interface import InterfaceManager from kestrel.ir.graph import IRGraph @@ -136,7 +137,7 @@ def evaluate_instruction(self, ins: Instruction) -> Display: if iid == ins.id: return display - def do_complete(self, huntflow_block: str, cursor_pos: int): + def do_complete(self, huntflow_block: str, cursor_pos: int) -> Iterable[str]: """Kestrel code auto-completion. Parameters: @@ -146,7 +147,12 @@ def do_complete(self, huntflow_block: str, cursor_pos: int): Returns: A list of suggested strings to complete the code """ - raise NotImplementedError() + return do_complete( + huntflow_block, + cursor_pos, + self.interface_manager, + [v.name for v in self.irgraph.get_variables()], + ) def close(self): """Explicitly close the session. diff --git a/packages/kestrel_core/src/kestrel/utils.py b/packages/kestrel_core/src/kestrel/utils.py index b6213e56..b9238d8f 100644 --- a/packages/kestrel_core/src/kestrel/utils.py +++ b/packages/kestrel_core/src/kestrel/utils.py @@ -1,5 +1,6 @@ import collections.abc import os +from datetime import datetime from importlib import resources from pathlib import Path from pkgutil import get_data @@ -8,6 +9,8 @@ from kestrel.__future__ import is_python_older_than_minor_version from typeguard import typechecked +TIME_FMT = "%Y-%m-%dT%H:%M:%S.%f" + @typechecked def load_data_file(package_name: str, file_name: str) -> str: @@ -76,3 +79,20 @@ def update_nested_dict(dict_old: Mapping, dict_new: Optional[Mapping]) -> Mappin else: dict_old[k] = v return dict_old + + +@typechecked +def timefmt(t: datetime, prec: int = 3) -> str: + """Format Python datetime `t` in RFC 3339-format + + Ported from firepit.timestamp + """ + val = t.strftime(TIME_FMT) + parts = val.split(".") + if len(parts) > 1: + l = len(parts[0]) + digits = parts[1] + num_digits = len(digits) + if num_digits: + l += min(num_digits, prec) + 1 + return val[:l] + "Z" diff --git a/packages/kestrel_interface_opensearch/pyproject.toml b/packages/kestrel_interface_opensearch/pyproject.toml index 6270f6d0..2c0fe636 100644 --- a/packages/kestrel_interface_opensearch/pyproject.toml +++ b/packages/kestrel_interface_opensearch/pyproject.toml @@ -27,7 +27,7 @@ classifiers = [ dependencies = [ "kestrel_core>=2.0.0", - "opensearch-py>=2.4.2", + "opensearch-py>=2.6.0", ] [project.urls] diff --git a/packages/kestrel_interface_opensearch/src/kestrel_interface_opensearch/interface.py b/packages/kestrel_interface_opensearch/src/kestrel_interface_opensearch/interface.py index 0c151c3e..6a68b5dd 100644 --- a/packages/kestrel_interface_opensearch/src/kestrel_interface_opensearch/interface.py +++ b/packages/kestrel_interface_opensearch/src/kestrel_interface_opensearch/interface.py @@ -1,10 +1,10 @@ import logging -from typing import Any, Iterable, Mapping, MutableMapping, Optional, Tuple +from typing import Any, Iterable, List, Mapping, MutableMapping, Optional, Tuple from uuid import UUID from kestrel.display import GraphletExplanation from kestrel.exceptions import DataSourceError, InvalidDataSource -from kestrel.interface import AbstractInterface +from kestrel.interface import DatasourceInterface from kestrel.ir.graph import IRGraphEvaluable from kestrel.ir.instructions import ( DataSource, @@ -73,7 +73,7 @@ def read_sql(sql: str, conn: OpenSearch, dmm: Optional[dict] = None) -> DataFram return concat(dfs) -class OpenSearchInterface(AbstractInterface): +class OpenSearchInterface(DatasourceInterface): def __init__( self, serialized_cache_catalog: Optional[str] = None, @@ -98,6 +98,9 @@ def __init__( def schemes() -> Iterable[str]: return ["opensearch"] + def get_datasources(self) -> List[str]: + return list(self.config.datasources) + def get_storage_of_datasource(datasource: str) -> str: """Get the storage name of a given datasource""" if datasource not in self.config.datasources: diff --git a/packages/kestrel_interface_sqlalchemy/src/kestrel_interface_sqlalchemy/interface.py b/packages/kestrel_interface_sqlalchemy/src/kestrel_interface_sqlalchemy/interface.py index f415688c..4ee2b574 100644 --- a/packages/kestrel_interface_sqlalchemy/src/kestrel_interface_sqlalchemy/interface.py +++ b/packages/kestrel_interface_sqlalchemy/src/kestrel_interface_sqlalchemy/interface.py @@ -1,12 +1,12 @@ import logging from functools import reduce -from typing import Any, Iterable, Mapping, MutableMapping, Optional +from typing import Any, Iterable, List, Mapping, MutableMapping, Optional from uuid import UUID import sqlalchemy from kestrel.display import GraphletExplanation, NativeQuery from kestrel.exceptions import InvalidDataSource, SourceNotFound -from kestrel.interface import AbstractInterface +from kestrel.interface import DatasourceInterface from kestrel.interface.codegen.sql import ingest_dataframe_to_temp_table from kestrel.interface.codegen.utils import variable_attributes_to_dataframe from kestrel.ir.graph import IRGraphEvaluable @@ -34,7 +34,7 @@ @typechecked -class SQLAlchemyInterface(AbstractInterface): +class SQLAlchemyInterface(DatasourceInterface): def __init__( self, serialized_cache_catalog: Optional[str] = None, @@ -60,6 +60,9 @@ def __init__( def schemes() -> Iterable[str]: return ["sqlalchemy"] + def get_datasources(self) -> List[str]: + return list(self.config.datasources) + def get_storage_of_datasource(self, datasource: str) -> str: """Get the storage name of a given datasource""" if datasource not in self.config.datasources: diff --git a/packages/kestrel_jupyter/pyproject.toml b/packages/kestrel_jupyter/pyproject.toml index 3cc31435..d97c115b 100644 --- a/packages/kestrel_jupyter/pyproject.toml +++ b/packages/kestrel_jupyter/pyproject.toml @@ -31,9 +31,9 @@ dependencies = [ "jupyterlab", "jupyter_client", "nbclassic", - "sqlparse==0.4.4", - "pygments==2.17.2", - "matplotlib==3.8.3", + "sqlparse==0.5.1", + "pygments==2.18.0", + "matplotlib==3.9.1", ] [project.optional-dependencies]