From 123dd2fa7e9f57e9235946f6818c95ac7f9965d8 Mon Sep 17 00:00:00 2001 From: Panagiotis Simakis Date: Mon, 27 Nov 2023 00:16:28 +0200 Subject: [PATCH] fix typo --- kaggle_provider/example_dags/kaggle.py | 6 +- kaggle_provider/hooks/kaggle.py | 85 ++++++++++++-------------- kaggle_provider/operators/kaggle.py | 29 ++++----- kaggle_provider/utils/credentials.py | 25 ++++++++ 4 files changed, 80 insertions(+), 65 deletions(-) create mode 100644 kaggle_provider/utils/credentials.py diff --git a/kaggle_provider/example_dags/kaggle.py b/kaggle_provider/example_dags/kaggle.py index 50c9233..b7a8a0b 100644 --- a/kaggle_provider/example_dags/kaggle.py +++ b/kaggle_provider/example_dags/kaggle.py @@ -26,16 +26,18 @@ def kaggle_workflow(): # $ kaggle c list --sort-by prize -v competitions_list_op = KaggleOperator( + task_id="competition_list", command="c", subcommand="list", - optional_arguments={"sort-by": "prize", "v": True}, + optional_arguments={"sort-by": "prize"}, ) # $ kaggle d list --sort-by votes -m datasets_list_op = KaggleOperator( + task_id="dataset_list", command="d", subcommand="list", - optional_arguments={"sort-by": "votes", "m": True}, + optional_arguments={"sort-by": "votes"}, ) competitions_list_op >> datasets_list_op diff --git a/kaggle_provider/hooks/kaggle.py b/kaggle_provider/hooks/kaggle.py index 7b0e7ae..fa6bf26 100644 --- a/kaggle_provider/hooks/kaggle.py +++ b/kaggle_provider/hooks/kaggle.py @@ -1,8 +1,14 @@ from __future__ import annotations -from typing import Any, Tuple, Dict, Optional, Union +import os.path +import traceback +from typing import Any, Tuple, Optional, Union +import sh from airflow.hooks.base import BaseHook +from airflow.models.connection import Connection + +from kaggle_provider.utils.credentials import CredentialsTemporaryFile class KaggleHook(BaseHook): @@ -11,63 +17,52 @@ class KaggleHook(BaseHook): :param kaggle_conn_id: connection that has the kaggle authentication credentials. :type kaggle_conn_id: str + :param kaggle_bin_path: Kaggle binary path. + :type kaggle_bin_path: str """ conn_name_attr = "kaggle_conn_id" default_conn_name = "kaggle_default" conn_type = "kaggle" hook_name = "Kaggle" - cred: Dict[str, str] - - @staticmethod - def get_connection_form_widgets() -> dict[str, Any]: - """Returns connection widgets to add to connection form""" - from flask_appbuilder.fieldwidgets import ( - BS3PasswordFieldWidget, - BS3TextFieldWidget, - ) - from flask_babel import lazy_gettext - from wtforms import PasswordField, StringField - return { - "user": StringField(lazy_gettext("User"), widget=BS3TextFieldWidget()), - "key": PasswordField(lazy_gettext("Key"), widget=BS3PasswordFieldWidget()), - } - - @staticmethod - def get_ui_field_behaviour() -> dict: + @classmethod + def get_ui_field_behaviour(cls) -> dict[str, Any]: """Returns custom field behaviour""" - import json return { - "hidden_fields": [], - "relabeling": {}, - "placeholders": { - "extra": json.dumps( - { - "example_parameter": "parameter", - }, - indent=4, - ), - "user": "HeirFlough", - "key": "mY53cr3tk3y!", - }, + "hidden_fields": ["host", "schema", "port", "login", "password"], } def __init__( self, kaggle_conn_id: str = default_conn_name, + kaggle_bin_path: str | None = None, ) -> None: super().__init__() self.kaggle_conn_id = kaggle_conn_id + self.kaggle_bin_path = kaggle_bin_path or self._get_kaggle_bin() - def get_conn(self) -> Dict[str, str]: + if self.kaggle_bin_path and not os.path.exists(self.kaggle_bin_path): + raise RuntimeError(f"{self.kaggle_bin_path} does not exist") + + self.command = sh.Command(self.kaggle_bin_path) + + @staticmethod + def _get_kaggle_bin() -> Optional[str]: + potential_paths = ( + os.path.join(os.getenv("HOME", ""), ".local", "bin", "kaggle"), + ) + for p_path in potential_paths: + if os.path.exists(p_path): + return p_path + raise RuntimeError("kaggle binary can not be found") + + def get_conn(self) -> Connection: """ - Returns kaggle credentials. + Returns kaggle connection. """ - conn = self.get_connection(self.kaggle_conn_id) - - return {"KAGGLE_USER": conn.user, "KAGGLE_KEY": conn.key} + return self.get_connection(self.kaggle_conn_id) def run( self, @@ -85,26 +80,24 @@ def run( :param optional_arguments: additional kaggle command optional arguments :type optional_arguments: dict """ - from sh import kaggle - - self.creds = self.get_conn() - command_base = [] if command: command_base.append(command) if subcommand: command_base.append(subcommand) - command = kaggle.bake(*command_base, **optional_arguments, _env=self.creds) + command = self.command.bake(*command_base, **optional_arguments) # type: ignore - self.log.info(f"Running: f{str(command)}") # type: ignore + with CredentialsTemporaryFile(connection=self.get_conn()): + stdout = command() # type: ignore - return command() # type: ignore + self.log.info(f"\n{stdout}\n") # type: ignore + return stdout def test_connection(self) -> Tuple[bool, str]: """Test a connection""" try: - self.run(command=None, subcommand=None, v=True) + self.run(command="config", subcommand="view") return True, "Connection successfully tested" except Exception as e: - return False, str(e) + return False, "\n".join(traceback.format_exception(e)) diff --git a/kaggle_provider/operators/kaggle.py b/kaggle_provider/operators/kaggle.py index c7886c7..06f082c 100644 --- a/kaggle_provider/operators/kaggle.py +++ b/kaggle_provider/operators/kaggle.py @@ -15,22 +15,20 @@ class KaggleOperator(BaseOperator): """ Calls Kaggle CLI. - :param kaggle_conn_id: connection to run the operator with - :type kaggle_conn_id: str :param command: The Kaggle command. (templated) :type command: str :param subcommand: The Kaggle subcommand. (templated) :type subcommand: str :param optional_arguments: The Kaggle optional arguments. (templated) :type optional_arguments: a dictionary of key/value pairs + :param kaggle_conn_id: connection to run the operator with + :type kaggle_conn_id: str + :param kaggle_bin_path: kaggle binary path + :type kaggle_bin_path: str """ # Specify the arguments that are allowed to parse with jinja templating - template_fields = [ - "command", - "subcommand", - "optional_arguments", - ] + template_fields = ["command", "subcommand", "optional_arguments", "kaggle_bin_path"] template_fields_renderers = {"optional_arguments": "py"} template_ext = () ui_color = "#f4a460" @@ -42,6 +40,7 @@ def __init__( subcommand: str | None = None, optional_arguments: Dict[str, Union[str, bool]] | None = None, kaggle_conn_id: str = KaggleHook.default_conn_name, + kaggle_bin_path: str | None = None, **kwargs, ) -> None: super().__init__(**kwargs) @@ -49,17 +48,13 @@ def __init__( self.command = command self.subcommand = subcommand self.optional_arguments = optional_arguments or {} - if kwargs.get("xcom_push") is not None: - raise AirflowException( - "'xcom_push' was deprecated, use 'BaseOperator.do_xcom_push' instead" - ) + self.kaggle_bin_path = kaggle_bin_path - def execute(self, context: Context) -> Any: - hook = KaggleHook(kaggle_conn_id=self.kaggle_conn_id) + def execute(self, context: Context) -> str: + hook = KaggleHook( + kaggle_conn_id=self.kaggle_conn_id, kaggle_bin_path=self.kaggle_bin_path + ) - self.log.info("Call Kaggle CLI") - output = hook.run( + return hook.run( command=self.command, subcommand=self.subcommand, **self.optional_arguments ) - - return output diff --git a/kaggle_provider/utils/credentials.py b/kaggle_provider/utils/credentials.py new file mode 100644 index 0000000..b45d5f0 --- /dev/null +++ b/kaggle_provider/utils/credentials.py @@ -0,0 +1,25 @@ +import os + +from airflow.models.connection import Connection + + +class CredentialsTemporaryFile: + def __init__(self, connection: Connection): + self.connection = connection + self.folder_path = os.path.join(os.environ.get("HOME", ""), ".kaggle") + os.makedirs(self.folder_path, exist_ok=True) + self.file_path = os.path.join(self.folder_path, "kaggle.json") + self.extra = self.connection.extra_dejson + if not self.extra.get("username"): + raise ValueError("username is missing") + if not self.extra.get("key"): + raise ValueError("key is missing") + + def __enter__(self): + self.f = open(self.file_path, "w") + self.f.write(self.connection.extra) + self.f.close() + os.chmod(self.file_path, 0o600) + + def __exit__(self, exc_type, exc_val, exc_tb): + os.remove(self.file_path)