Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add options to output registered entity summary #3028

Open
wants to merge 8 commits into
base: master
Choose a base branch
from
23 changes: 23 additions & 0 deletions flytekit/clis/sdk_in_container/register.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,21 @@
help="Skip errors during registration. This is useful when registering multiple packages and you want to skip "
"errors for some packages.",
)
@click.option(
"--summary-format",
"-f",
required=False,
type=click.Choice(["json", "yaml"], case_sensitive=False),
default=None,
help="Set output format for registration summary. Lists registered workflows, tasks, and launch plans. 'json' and 'yaml' supported.",
)
@click.option(
"--summary-dir",
required=False,
type=click.Path(dir_okay=True, file_okay=False, writable=True, resolve_path=True),
default=None,
help="Directory to save registration summary. Uses current working directory if not specified.",
)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Consider validating summary format and directory

Consider adding validation for summary-format and summary-dir options. When summary-format is specified but summary-dir is not, the summary may not be saved correctly. Consider making summary-dir required when summary-format is provided.

Code suggestion
Check the AI-generated fix before applying
Suggested change
"--summary-format",
"-f",
required=False,
type=click.Choice(["json", "yaml"], case_sensitive=False),
default=None,
help="Set output format for registration summary. Lists registered workflows, tasks, and launch plans. 'json' and 'yaml' supported.",
)
@click.option(
"--summary-dir",
required=False,
type=click.Path(dir_okay=True, file_okay=False, writable=True, resolve_path=True),
default=None,
help="Directory to save registration summary. Uses current working directory if not specified.",
)
"--summary-format",
"-f",
required=False,
type=click.Choice(["json", "yaml"], case_sensitive=False),
default=None,
help="Set output format for registration summary. Lists registered workflows, tasks, and launch plans. 'json' and 'yaml' supported.",
callback=validate_summary_options,
)
"--summary-dir",
required=False,
type=click.Path(dir_okay=True, file_okay=False, writable=True, resolve_path=True),
default=None,
help="Directory to save registration summary. Uses current working directory if not specified.",
callback=validate_summary_options,
)

Code Review Run #9a3edb


Is this a valid issue, or was it incorrectly flagged by the Agent?

  • it was incorrectly flagged

@click.argument("package-or-module", type=click.Path(exists=True, readable=True, resolve_path=True), nargs=-1)
@click.pass_context
def register(
Expand All @@ -162,12 +177,15 @@ def register(
activate_launchplans: bool,
env: typing.Optional[typing.Dict[str, str]],
skip_errors: bool,
summary_format: typing.Optional[str],
summary_dir: typing.Optional[str],
):
"""
see help
"""
# Set the relevant copy option if non_fast is set, this enables the individual file listing behavior
# that the copy flag uses.

if non_fast:
click.secho("The --non-fast flag is deprecated, please use --copy none instead", fg="yellow")
if "--copy" in sys.argv:
Expand Down Expand Up @@ -195,6 +213,9 @@ def register(
"Missing argument 'PACKAGE_OR_MODULE...', at least one PACKAGE_OR_MODULE is required but multiple can be passed",
)

if summary_dir is not None and summary_format is None:
raise click.UsageError("--summary-format is a required parameter when --summary-dir is specified")

# Use extra images in the config file if that file exists
config_file = ctx.obj.get(constants.CTX_CONFIG_FILE)
if config_file:
Expand Down Expand Up @@ -225,6 +246,8 @@ def register(
package_or_module=package_or_module,
remote=remote,
env=env,
summary_format=summary_format,
summary_dir=summary_dir,
dry_run=dry_run,
activate_launchplans=activate_launchplans,
skip_errors=skip_errors,
Expand Down
50 changes: 46 additions & 4 deletions flytekit/tools/repo.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
import asyncio
import functools
import json
import os
import tarfile
import tempfile
import typing
from pathlib import Path

import click
import yaml
from rich import print as rprint

from flytekit.configuration import FastSerializationSettings, ImageConfig, SerializationSettings
Expand Down Expand Up @@ -251,6 +253,8 @@ def register(
remote: FlyteRemote,
copy_style: CopyFileDetection,
env: typing.Optional[typing.Dict[str, str]],
summary_format: typing.Optional[str],
summary_dir: typing.Optional[str],
dry_run: bool = False,
activate_launchplans: bool = False,
skip_errors: bool = False,
Expand Down Expand Up @@ -333,6 +337,14 @@ def _raw_register(cp_entity: FlyteControlPlaneEntity):
is_lp = True
else:
og_id = cp_entity.template.id

result = {
"id": og_id.name,
"type": og_id.resource_type_name(),
"version": og_id.version,
"status": "skipped", # default status
}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Consider enhancing registration result information

Consider adding more detailed status information in the result dictionary. The current status field only captures high-level states ('skipped', 'success', 'failed'). Additional fields like error_message and timestamp could provide more context for debugging and monitoring.

Code suggestion
Check the AI-generated fix before applying
Suggested change
result = {
"id": og_id.name,
"type": og_id.resource_type_name(),
"version": og_id.version,
"status": "skipped", # default status
}
result = {
"id": og_id.name,
"type": og_id.resource_type_name(),
"version": og_id.version,
"status": "skipped", # default status
"timestamp": datetime.datetime.now().isoformat(),
"error_message": "",
"details": {}
}

Code Review Run #9a3edb


Is this a valid issue, or was it incorrectly flagged by the Agent?

  • it was incorrectly flagged


try:
if not dry_run:
try:
Expand All @@ -350,30 +362,60 @@ def _raw_register(cp_entity: FlyteControlPlaneEntity):
print_registration_status(
i, console_url=console_url, verbosity=verbosity, activation=print_activation_message
)
result["status"] = "success"

except Exception as e:
if not skip_errors:
raise e
print_registration_status(og_id, success=False)
result["status"] = "failed"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If it fails, what will the values of other keys be?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The values of the other keys(id, type, version) are pre-computed before registration. Thus, the values will not be empty even if the registration fails.

The values are from:

result = {
"id": og_id.name,
"type": og_id.resource_type_name(),
"version": og_id.version,
"status": "skipped", # default status
}
where og_id is the id of the entity's template / entity itself.

else:
print_registration_status(og_id, dry_run=True)
except RegistrationSkipped:
print_registration_status(og_id, success=False)
result["status"] = "skipped"

return result

async def _register(entities: typing.List[task.TaskSpec]):
loop = asyncio.get_running_loop()
tasks = []
for entity in entities:
tasks.append(loop.run_in_executor(None, functools.partial(_raw_register, entity)))
await asyncio.gather(*tasks)
return
results = await asyncio.gather(*tasks)
return results

# concurrent register
cp_task_entities = list(filter(lambda x: isinstance(x, task.TaskSpec), registrable_entities))
asyncio.run(_register(cp_task_entities))
task_results = asyncio.run(_register(cp_task_entities))
# serial register
cp_other_entities = list(filter(lambda x: not isinstance(x, task.TaskSpec), registrable_entities))
other_results = []
for entity in cp_other_entities:
_raw_register(entity)
other_results.append(_raw_register(entity))

all_results = task_results + other_results

click.secho(f"Successfully registered {len(registrable_entities)} entities", fg="green")

if summary_format is not None:
supported_format = ["json", "yaml"]
if summary_format not in supported_format:
raise ValueError(f"Unsupported file format: {summary_format}")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Consider using set for format checking

Consider using a set for supported_format instead of a list since we're only checking membership. This would provide O(1) lookup instead of O(n). Could be defined as supported_format = {'json', 'yaml'}.

Code suggestion
Check the AI-generated fix before applying
Suggested change
supported_format = ["json", "yaml"]
if summary_format not in supported_format:
raise ValueError(f"Unsupported file format: {summary_format}")
if summary_format not in {"json", "yaml"}:
raise ValueError(f"Unsupported file format: {summary_format}")

Code Review Run #9a3edb


Is this a valid issue, or was it incorrectly flagged by the Agent?

  • it was incorrectly flagged


if summary_dir is not None:
os.makedirs(summary_dir, exist_ok=True)
else:
summary_dir = os.getcwd()

summary_file = f"registration_summary.{summary_format}"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does this mean that the other registration may overwrite the summary? I think we could name the file with other unique names (ie. tmp file, py file name, wf name, version number)

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes! The summary will be overwritten if the registration is rerun.

summary_path = os.path.join(summary_dir, summary_file)

if summary_format == "json":
with open(summary_path, "w") as f:
json.dump(all_results, f)
elif summary_format == "yaml":
with open(summary_path, "w") as f:
yaml.dump(all_results, f)

click.secho(f"Registration summary written to: {summary_path}", fg="green")
73 changes: 73 additions & 0 deletions tests/flytekit/unit/cli/pyflyte/test_register.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import os
import shutil
import subprocess
import json

import mock
import pytest
Expand Down Expand Up @@ -163,3 +164,75 @@ def test_non_fast_register_require_version(mock_client, mock_remote):
result = runner.invoke(pyflyte.main, ["register", "--non-fast", "core3"])
assert result.exit_code == 1
shutil.rmtree("core3")


@mock.patch("flytekit.configuration.plugin.FlyteRemote", spec=FlyteRemote)
@mock.patch("flytekit.clients.friendly.SynchronousFlyteClient", spec=SynchronousFlyteClient)
def test_register_summary_dir_without_format(mock_client, mock_remote):
mock_remote._client = mock_client
mock_remote.return_value._version_from_hash.return_value = "dummy_version_from_hash"
mock_remote.return_value.fast_package.return_value = "dummy_md5_bytes", "dummy_native_url"

runner = CliRunner()
context_manager.FlyteEntities.entities.clear()

with runner.isolated_filesystem():
out = subprocess.run(["git", "init"], capture_output=True)
assert out.returncode == 0
os.makedirs("core4", exist_ok=True)
with open(os.path.join("core4", "sample.py"), "w") as f:
f.write(sample_file_contents)
f.close()
result = runner.invoke(pyflyte.main, ["register", "--summary-dir", "summaries", "core4"])
assert result.exit_code == 2
print(result.output)
paullongtan marked this conversation as resolved.
Show resolved Hide resolved
assert "--summary-format is a required parameter when --summary-dir is specified" in result.output

shutil.rmtree("core4")


@mock.patch("flytekit.configuration.plugin.FlyteRemote", spec=FlyteRemote)
@mock.patch("flytekit.clients.friendly.SynchronousFlyteClient", spec=SynchronousFlyteClient)
def test_register_registrated_summary_json(mock_client, mock_remote):
ctx = FlyteContextManager.current_context()
mock_remote._client = mock_client
mock_remote.return_value.context = ctx
mock_remote.return_value._version_from_hash.return_value = "dummy_version_from_hash"
mock_remote.return_value.fast_package.return_value = "dummy_md5_bytes", "dummy_native_url"
runner = CliRunner()
context_manager.FlyteEntities.entities.clear()

with runner.isolated_filesystem():
out = subprocess.run(["git", "init"], capture_output=True)
assert out.returncode == 0
os.makedirs("core5", exist_ok=True)
os.makedirs("summaries", exist_ok=True)
with open(os.path.join("core5", "sample.py"), "w") as f:
f.write(sample_file_contents)
f.close()

result = runner.invoke(
pyflyte.main,
["register", "--summary-format", "json", "--summary-dir", "summaries", "core5"]
)
assert result.exit_code == 0

summary_path = os.path.join("summaries", "registration_summary.json")
assert os.path.exists(summary_path)

with open(summary_path) as f:
summary_data = json.load(f)

assert isinstance(summary_data, list)
assert len(summary_data) > 0
for entry in summary_data:
assert "id" in entry
assert "type" in entry
assert "version" in entry
assert "status" in entry

# Ensure cleanup happens even if test fails
if os.path.exists("core5"):
shutil.rmtree("core5")
if os.path.exists("summaries"):
shutil.rmtree("summaries")
Loading