Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use of sdk and unittesting #6

Merged
merged 3 commits into from
Nov 27, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion .github/workflows/testing.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -22,4 +22,6 @@ jobs:
- uses: actions/setup-python@v4
- run: python3 -m pip install --upgrade pip
- run: pip install -e '.[test]'
- run: pytest
- env:
KAGGLE_JSON: ${{ secrets.KAGGLE_JSON }}
run: pytest tests/
62 changes: 51 additions & 11 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -54,11 +54,7 @@ This is the main operator that can be used to execute any kaggle cli command:
```python
from kaggle_provider.operators.kaggle import KaggleOperator

list_competitions_op = KaggleOperator(
command='competitions',
subcommand='list',
optional_arguments={'m': True},
)
list_competitions_op = KaggleOperator(task_id='foo', command='competitions_list', op_kwargs={'sort_by': 'prize'})
```

### Hooks
Expand All @@ -71,10 +67,54 @@ in your custom operator too.
```python
from kaggle_provider.hooks.kaggle import KaggleHook

hook = KaggleHook(kaggle_conn_id='kaggle_default')
hook.run(
command='datasets',
subcommand='list',
m=True,
)
hook = KaggleHook()
hook.run('datasets_list', sort_by="votes", user="sp1thas")
```


### Available commands

- `competitions_list`
- `competition_submit`
- `competition_submissions`
- `competition_list_files`
- `competition_download_file`
- `competition_download_files`
- `competition_leaderboard_download`
- `competition_leaderboard_view`
- `dataset_list`
- `dataset_metadata_prep`
- `dataset_metadata_update`
- `dataset_metadata`
- `dataset_list_files`
- `dataset_status`
- `dataset_download_file`
- `dataset_download_files`
- `dataset_create_version`
- `dataset_initialize`
- `dataset_create_new`
- `download_file`
- `kernels_list`
- `kernels_initialize`
- `kernels_push`
- `kernels_pull`
- `kernels_output`
- `kernels_status`
- `model_get`
- `model_list`
- `model_initialize`
- `model_create_new`
- `model_delete`
- `model_update`
- `model_instance_get`
- `model_instance_initialize`
- `model_instance_create`
- `model_instance_delete`
- `model_instance_update`
- `model_instance_version_create`
- `model_instance_version_download`
- `model_instance_version_delete`
- `download_needed`

Details regarding the command arguments can be found in the corresponding method docstring of this
[module](https://github.com/Kaggle/kaggle-api/blob/main/kaggle/api/kaggle_api_extended.py)
32 changes: 32 additions & 0 deletions kaggle_provider/_utils/credentials.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
import os

from airflow.models.connection import Connection

KAGGLE_USERNAME = "KAGGLE_USERNAME"
KAGGLE_KEY = "KAGGLE_KEY"


class TemporaryCredentials:
def __init__(self, connection: Connection):
self.connection = connection
self.extra = self.connection.extra_dejson
if not self.extra.get("username"):
raise ValueError("username is missing")
if not self.extra.get("key"):
raise ValueError("key is missing")
self.user_before = os.getenv(KAGGLE_USERNAME)
self.key_before = os.getenv(KAGGLE_KEY)

def __enter__(self):
os.environ[KAGGLE_USERNAME] = self.extra["username"]
os.environ[KAGGLE_KEY] = self.extra["key"]

def __exit__(self, exc_type, exc_val, exc_tb):
if self.user_before:
os.environ[KAGGLE_USERNAME] = self.user_before
else:
os.environ.pop(KAGGLE_USERNAME, None)
if self.key_before:
os.environ[KAGGLE_KEY] = self.key_before
else:
os.environ.pop(KAGGLE_KEY, None)
11 changes: 11 additions & 0 deletions kaggle_provider/_utils/encoder.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
import datetime
import json


class DefaultEncoder(json.JSONEncoder):
def default(self, o):
if isinstance(o, (datetime.datetime, datetime.date)):
return o.isoformat()
if hasattr(o, "__dict__"):
return o.__dict__
return super(DefaultEncoder, self).default(o)
30 changes: 7 additions & 23 deletions kaggle_provider/example_dags/kaggle.py
Original file line number Diff line number Diff line change
@@ -1,41 +1,25 @@
from pendulum import datetime

from airflow.decorators import dag
from pendulum import datetime

from kaggle_provider.operators.kaggle import KaggleOperator


@dag(
start_date=datetime(2023, 1, 1),
schedule=None,
default_args={"kaggle_conn_id": "kaggle_default"},
tags=["kaggle"],
)
def kaggle_workflow():
"""
### Kaggle DAG

Showcases the kaggle provider package's operator.

To run this example, create a kaggle connection with:
- id: kaggle_default
- type: kaggle
"""

# $ kaggle c list --sort-by prize -v
competitions_list_op = KaggleOperator(
task_id="competition_list",
command="c",
subcommand="list",
optional_arguments={"sort-by": "prize"},
task_id="competitions_list",
command="competitions_list",
op_kwargs={"sort_by": "prize"},
)

# $ kaggle d list --sort-by votes -m
datasets_list_op = KaggleOperator(
task_id="dataset_list",
command="d",
subcommand="list",
optional_arguments={"sort-by": "votes"},
task_id="datasets_list",
command="datasets_list",
op_kwargs={"sort_by": "votes", "user": "sp1thas"},
)

competitions_list_op >> datasets_list_op
Expand Down
71 changes: 26 additions & 45 deletions kaggle_provider/hooks/kaggle.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
from __future__ import annotations

import os.path
import json
import traceback
from typing import Any, Tuple, Optional, Union
from typing import Any, Tuple

import sh
from airflow.hooks.base import BaseHook
from airflow.models.connection import Connection

from kaggle_provider.utils.credentials import CredentialsTemporaryFile
from kaggle_provider._utils.credentials import TemporaryCredentials
from kaggle_provider._utils.encoder import DefaultEncoder


class KaggleHook(BaseHook):
Expand All @@ -17,8 +17,6 @@ class KaggleHook(BaseHook):

:param kaggle_conn_id: connection that has the kaggle authentication credentials.
:type kaggle_conn_id: str
:param kaggle_bin_path: Kaggle binary path.
:type kaggle_bin_path: str
"""

conn_name_attr = "kaggle_conn_id"
Expand All @@ -31,32 +29,17 @@ def get_ui_field_behaviour(cls) -> dict[str, Any]:
"""Returns custom field behaviour"""

return {
"hidden_fields": ["host", "schema", "port", "login", "password"],
"hidden_fields": ["port", "password", "login", "schema", "host"],
"relabeling": {},
"placeholders": {},
}

def __init__(
self,
kaggle_conn_id: str = default_conn_name,
kaggle_bin_path: str | None = None,
) -> None:
super().__init__()
self.kaggle_conn_id = kaggle_conn_id
self.kaggle_bin_path = kaggle_bin_path or self._get_kaggle_bin()

if self.kaggle_bin_path and not os.path.exists(self.kaggle_bin_path):
raise RuntimeError(f"{self.kaggle_bin_path} does not exist")

self.command = sh.Command(self.kaggle_bin_path)

@staticmethod
def _get_kaggle_bin() -> Optional[str]:
potential_paths = (
os.path.join(os.getenv("HOME", ""), ".local", "bin", "kaggle"),
)
for p_path in potential_paths:
if os.path.exists(p_path):
return p_path
raise RuntimeError("kaggle binary can not be found")

def get_conn(self) -> Connection:
"""
Expand All @@ -66,38 +49,36 @@ def get_conn(self) -> Connection:

def run(
self,
command: Optional[str] = None,
subcommand: Optional[str] = None,
**optional_arguments: Union[str, bool],
) -> str:
command: str,
*args: Any,
**kwargs: Any,
) -> Any:
"""
Performs the kaggle command

:param command: kaggle command
:type command: str
:param subcommand: kaggle subcommand
:type subcommand: str
:param optional_arguments: additional kaggle command optional arguments
:type optional_arguments: dict
:param args: required positional kaggle command arguments
:type kwargs: optional keyword kaggle command arguments
"""
command_base = []
if command:
command_base.append(command)
if subcommand:
command_base.append(subcommand)
with TemporaryCredentials(connection=self.get_conn()):
import kaggle

command = self.command.bake(*command_base, **optional_arguments) # type: ignore
try:
clb = getattr(kaggle.api, command)
except AttributeError as e:
raise ValueError(f"Unknown command: {command}") from e

with CredentialsTemporaryFile(connection=self.get_conn()):
stdout = command() # type: ignore
response = clb(*args or (), **kwargs or {})

self.log.info(f"\n{stdout}\n") # type: ignore
return stdout
return json.loads(json.dumps(response, cls=DefaultEncoder))

def test_connection(self) -> Tuple[bool, str]:
"""Test a connection"""
try:
self.run(command="config", subcommand="view")
return True, "Connection successfully tested"
with TemporaryCredentials(connection=self.get_conn()):
import kaggle

return True, "Connection successfully tested"
except Exception as e:
return False, "\n".join(traceback.format_exception(e))
return False, "\n".join(traceback.format_exception(e)) # type: ignore
51 changes: 20 additions & 31 deletions kaggle_provider/operators/kaggle.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

from typing import TYPE_CHECKING, Dict, Union
import pprint
from typing import TYPE_CHECKING, Any, Collection, Mapping

from airflow.models import BaseOperator

Expand All @@ -12,48 +13,36 @@

class KaggleOperator(BaseOperator):
"""
Calls Kaggle CLI.

:param command: The Kaggle command. (templated)
:type command: str
:param subcommand: The Kaggle subcommand. (templated)
:type subcommand: str
:param optional_arguments: The Kaggle optional arguments. (templated)
:type optional_arguments: a dictionary of key/value pairs
:param command: kaggle command.
:param op_args: Required positional arguments. (templated)
:param op_kwargs: Optional keyword arguments. (templated)
:param kaggle_conn_id: connection to run the operator with
:type kaggle_conn_id: str
:param kaggle_bin_path: kaggle binary path
:type kaggle_bin_path: str
"""

# Specify the arguments that are allowed to parse with jinja templating
template_fields = ["command", "subcommand", "optional_arguments", "kaggle_bin_path"]
template_fields_renderers = {"optional_arguments": "py"}
template_fields = ["command", "op_args", "op_kwargs"]
template_fields_renderers = {"op_args": "py", "op_kwargs": "py"}
template_ext = ()
ui_color = "#20beff"

def __init__(
self,
*,
command: str | None = None,
subcommand: str | None = None,
optional_arguments: Dict[str, Union[str, bool]] | None = None,
command: str,
op_args: Collection[Any] | None = None,
op_kwargs: Mapping[str, Any] | None = None,
kaggle_conn_id: str = KaggleHook.default_conn_name,
kaggle_bin_path: str | None = None,
**kwargs,
) -> None:
super().__init__(**kwargs)
self.kaggle_conn_id = kaggle_conn_id
self.command = command
self.subcommand = subcommand
self.optional_arguments = optional_arguments or {}
self.kaggle_bin_path = kaggle_bin_path

def execute(self, context: Context) -> str:
hook = KaggleHook(
kaggle_conn_id=self.kaggle_conn_id, kaggle_bin_path=self.kaggle_bin_path
)

return hook.run(
command=self.command, subcommand=self.subcommand, **self.optional_arguments
)
self.op_args = op_args
self.op_kwargs = op_kwargs
self.kaggle_conn_id = kaggle_conn_id

def execute(self, context: Context) -> Any:
hook = KaggleHook(kaggle_conn_id=self.kaggle_conn_id)

response = hook.run(self.command, *self.op_args or (), **self.op_kwargs or {})
self.log.info(pprint.pformat(response))
return response
Loading
Loading