Skip to content

Commit

Permalink
fix typo
Browse files Browse the repository at this point in the history
  • Loading branch information
sp1thas committed Nov 26, 2023
1 parent d4c48e0 commit 123dd2f
Show file tree
Hide file tree
Showing 4 changed files with 80 additions and 65 deletions.
6 changes: 4 additions & 2 deletions kaggle_provider/example_dags/kaggle.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,16 +26,18 @@ def kaggle_workflow():

# $ kaggle c list --sort-by prize -v
competitions_list_op = KaggleOperator(
task_id="competition_list",
command="c",
subcommand="list",
optional_arguments={"sort-by": "prize", "v": True},
optional_arguments={"sort-by": "prize"},
)

# $ kaggle d list --sort-by votes -m
datasets_list_op = KaggleOperator(
task_id="dataset_list",
command="d",
subcommand="list",
optional_arguments={"sort-by": "votes", "m": True},
optional_arguments={"sort-by": "votes"},
)

competitions_list_op >> datasets_list_op
Expand Down
85 changes: 39 additions & 46 deletions kaggle_provider/hooks/kaggle.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,14 @@
from __future__ import annotations

from typing import Any, Tuple, Dict, Optional, Union
import os.path
import traceback
from typing import Any, Tuple, Optional, Union

import sh
from airflow.hooks.base import BaseHook
from airflow.models.connection import Connection

from kaggle_provider.utils.credentials import CredentialsTemporaryFile


class KaggleHook(BaseHook):
Expand All @@ -11,63 +17,52 @@ class KaggleHook(BaseHook):
:param kaggle_conn_id: connection that has the kaggle authentication credentials.
:type kaggle_conn_id: str
:param kaggle_bin_path: Kaggle binary path.
:type kaggle_bin_path: str
"""

conn_name_attr = "kaggle_conn_id"
default_conn_name = "kaggle_default"
conn_type = "kaggle"
hook_name = "Kaggle"
cred: Dict[str, str]

@staticmethod
def get_connection_form_widgets() -> dict[str, Any]:
    """Build the extra widgets added to the connection form.

    :return: mapping of field name to form field: a plain text field under
        ``"user"`` and a password field under ``"key"``.
    """
    # The UI libraries are imported inside the function body rather than at
    # module level.
    from flask_appbuilder.fieldwidgets import (
        BS3PasswordFieldWidget,
        BS3TextFieldWidget,
    )
    from flask_babel import lazy_gettext
    from wtforms import PasswordField, StringField

    user_field = StringField(lazy_gettext("User"), widget=BS3TextFieldWidget())
    key_field = PasswordField(lazy_gettext("Key"), widget=BS3PasswordFieldWidget())
    return {"user": user_field, "key": key_field}

@staticmethod
def get_ui_field_behaviour() -> dict:
@classmethod
def get_ui_field_behaviour(cls) -> dict[str, Any]:
"""Returns custom field behaviour"""
import json

return {
"hidden_fields": [],
"relabeling": {},
"placeholders": {
"extra": json.dumps(
{
"example_parameter": "parameter",
},
indent=4,
),
"user": "HeirFlough",
"key": "mY53cr3tk3y!",
},
"hidden_fields": ["host", "schema", "port", "login", "password"],
}

def __init__(
self,
kaggle_conn_id: str = default_conn_name,
kaggle_bin_path: str | None = None,
) -> None:
super().__init__()
self.kaggle_conn_id = kaggle_conn_id
self.kaggle_bin_path = kaggle_bin_path or self._get_kaggle_bin()

def get_conn(self) -> Dict[str, str]:
if self.kaggle_bin_path and not os.path.exists(self.kaggle_bin_path):
raise RuntimeError(f"{self.kaggle_bin_path} does not exist")

self.command = sh.Command(self.kaggle_bin_path)

@staticmethod
def _get_kaggle_bin() -> str:
    """Locate the kaggle executable.

    The per-user install location ``$HOME/.local/bin/kaggle`` is tried first.
    If nothing exists there, the executable search path (``PATH``) is
    consulted so that binaries installed into a virtualenv or a system
    directory are found as well.

    :return: path to the kaggle binary.
    :raises RuntimeError: if no kaggle binary can be located.
    """
    # Local import: shutil is not among the module-level imports of this file.
    import shutil

    # Preferred location; with HOME unset this degrades to the relative
    # path ".local/bin/kaggle".
    home_bin = os.path.join(os.getenv("HOME", ""), ".local", "bin", "kaggle")
    if os.path.exists(home_bin):
        return home_bin

    # Fallback: whatever ``kaggle`` resolves to on PATH (None when absent).
    path_bin = shutil.which("kaggle")
    if path_bin:
        return path_bin

    raise RuntimeError("kaggle binary can not be found")

def get_conn(self) -> Connection:
"""
Returns kaggle credentials.
Returns kaggle connection.
"""
conn = self.get_connection(self.kaggle_conn_id)

return {"KAGGLE_USER": conn.user, "KAGGLE_KEY": conn.key}
return self.get_connection(self.kaggle_conn_id)

def run(
self,
Expand All @@ -85,26 +80,24 @@ def run(
:param optional_arguments: additional kaggle command optional arguments
:type optional_arguments: dict
"""
from sh import kaggle

self.creds = self.get_conn()

command_base = []
if command:
command_base.append(command)
if subcommand:
command_base.append(subcommand)

command = kaggle.bake(*command_base, **optional_arguments, _env=self.creds)
command = self.command.bake(*command_base, **optional_arguments) # type: ignore

self.log.info(f"Running: f{str(command)}") # type: ignore
with CredentialsTemporaryFile(connection=self.get_conn()):
stdout = command() # type: ignore

return command() # type: ignore
self.log.info(f"\n{stdout}\n") # type: ignore
return stdout

def test_connection(self) -> Tuple[bool, str]:
"""Test a connection"""
try:
self.run(command=None, subcommand=None, v=True)
self.run(command="config", subcommand="view")
return True, "Connection successfully tested"
except Exception as e:
return False, str(e)
return False, "\n".join(traceback.format_exception(e))
29 changes: 12 additions & 17 deletions kaggle_provider/operators/kaggle.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,22 +15,20 @@ class KaggleOperator(BaseOperator):
"""
Calls Kaggle CLI.
:param kaggle_conn_id: connection to run the operator with
:type kaggle_conn_id: str
:param command: The Kaggle command. (templated)
:type command: str
:param subcommand: The Kaggle subcommand. (templated)
:type subcommand: str
:param optional_arguments: The Kaggle optional arguments. (templated)
:type optional_arguments: a dictionary of key/value pairs
:param kaggle_conn_id: connection to run the operator with
:type kaggle_conn_id: str
:param kaggle_bin_path: kaggle binary path
:type kaggle_bin_path: str
"""

# Specify the arguments that are allowed to parse with jinja templating
template_fields = [
"command",
"subcommand",
"optional_arguments",
]
template_fields = ["command", "subcommand", "optional_arguments", "kaggle_bin_path"]
template_fields_renderers = {"optional_arguments": "py"}
template_ext = ()
ui_color = "#f4a460"
Expand All @@ -42,24 +40,21 @@ def __init__(
subcommand: str | None = None,
optional_arguments: Dict[str, Union[str, bool]] | None = None,
kaggle_conn_id: str = KaggleHook.default_conn_name,
kaggle_bin_path: str | None = None,
**kwargs,
) -> None:
super().__init__(**kwargs)
self.kaggle_conn_id = kaggle_conn_id
self.command = command
self.subcommand = subcommand
self.optional_arguments = optional_arguments or {}
if kwargs.get("xcom_push") is not None:
raise AirflowException(
"'xcom_push' was deprecated, use 'BaseOperator.do_xcom_push' instead"
)
self.kaggle_bin_path = kaggle_bin_path

def execute(self, context: Context) -> Any:
hook = KaggleHook(kaggle_conn_id=self.kaggle_conn_id)
def execute(self, context: Context) -> str:
hook = KaggleHook(
kaggle_conn_id=self.kaggle_conn_id, kaggle_bin_path=self.kaggle_bin_path
)

self.log.info("Call Kaggle CLI")
output = hook.run(
return hook.run(
command=self.command, subcommand=self.subcommand, **self.optional_arguments
)

return output
25 changes: 25 additions & 0 deletions kaggle_provider/utils/credentials.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
import os

from airflow.models.connection import Connection


class CredentialsTemporaryFile:
    """Context manager that materialises a connection's credentials on disk.

    On entry the raw ``extra`` JSON of the connection is written to
    ``$HOME/.kaggle/kaggle.json``; on exit that file is removed again.

    NOTE(review): any ``kaggle.json`` already present at that path is
    overwritten on entry and deleted on exit -- confirm this is acceptable
    for workers that share a home directory with a real user.

    :param connection: connection whose ``extra`` holds the credentials; its
        decoded form must contain non-empty ``username`` and ``key`` entries.
    :raises ValueError: if ``username`` or ``key`` is missing or empty.
    """

    def __init__(self, connection: Connection):
        self.connection = connection
        # With HOME unset this degrades to a relative ".kaggle" directory.
        self.folder_path = os.path.join(os.environ.get("HOME", ""), ".kaggle")
        os.makedirs(self.folder_path, exist_ok=True)
        self.file_path = os.path.join(self.folder_path, "kaggle.json")
        self.extra = self.connection.extra_dejson
        # Fail early, before anything is written to disk.
        if not self.extra.get("username"):
            raise ValueError("username is missing")
        if not self.extra.get("key"):
            raise ValueError("key is missing")

    def __enter__(self):
        """Write the credentials file with owner-only permissions.

        :return: this context manager instance.
        """
        # Create the file as 0o600 from the start so the secret is never
        # readable by other users, not even for the duration of the write.
        flags = os.O_WRONLY | os.O_CREAT | os.O_TRUNC
        fd = os.open(self.file_path, flags, 0o600)
        with os.fdopen(fd, "w") as f:
            # The mode passed to os.open only applies when the file is newly
            # created; tighten a pre-existing file before writing the secret.
            os.chmod(self.file_path, 0o600)
            f.write(self.connection.extra)
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Remove the credentials file; exceptions from the body propagate."""
        try:
            os.remove(self.file_path)
        except FileNotFoundError:
            # Already gone: nothing to clean up, and raising here would mask
            # any exception coming out of the ``with`` body.
            pass

0 comments on commit 123dd2f

Please sign in to comment.