Skip to content

Commit

Permalink
fix: enhance rulebook validation during project import (#702)
Browse files Browse the repository at this point in the history
  • Loading branch information
bzwei authored Mar 6, 2024
1 parent 126bdfe commit 84dc11c
Show file tree
Hide file tree
Showing 16 changed files with 245 additions and 19 deletions.
4 changes: 2 additions & 2 deletions src/aap_eda/services/project/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from .imports import ProjectImportService
from .imports import ProjectImportError, ProjectImportService

__all__ = ("ProjectImportService",)
__all__ = ("ProjectImportService", "ProjectImportError")
57 changes: 43 additions & 14 deletions src/aap_eda/services/project/imports.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@

import yaml
from django.core import exceptions
from django.db import utils

from aap_eda.core import models
from aap_eda.core.types import StrPath
Expand All @@ -46,27 +45,39 @@ class ProjectImportError(Exception):
pass


class MalformedError(Exception):
pass


def _project_import_wrapper(
func: Callable[[ProjectImportService, models.Project], None]
):
@wraps(func)
def wrapper(self: ProjectImportService, project: models.Project):
project.import_state = models.Project.ImportState.RUNNING
project.save()
project.save(update_fields=["import_state"])
error = None
try:
func(self, project)
project.import_state = models.Project.ImportState.COMPLETED
project.save()
except (utils.IntegrityError, exceptions.ObjectDoesNotExist):
logger.exception("Object may have been deleted")
except Exception as e:
project.import_state = models.Project.ImportState.FAILED
project.import_error = str(e)
error = e
finally:
try:
project.import_state = models.Project.ImportState.FAILED
project.import_error = str(e)
project.save()
project.save(update_fields=["import_state", "import_error"])
except exceptions.ObjectDoesNotExist:
logger.exception("Project may have been deleted")
raise
raise ProjectImportError(
"Project may have been deleted"
) from error
else:
if error and isinstance(error, ProjectImportError):
raise error
elif error:
raise ProjectImportError(
f"Failed to import the project: {str(error)}"
) from error

return wrapper

Expand Down Expand Up @@ -224,7 +235,10 @@ def _try_load_rulebook(
logger.warning("Invalid YAML file %s: %s", rulebook_path, exc)
return None

if not self._is_rulebook_file(content):
try:
self._validate_rulebook_file(content)
except MalformedError as exc:
logger.warning("Malformed rulebook %s: %s", rulebook_path, exc)
return None

relpath = os.path.relpath(rulebook_path, rulebooks_dir)
Expand All @@ -234,7 +248,22 @@ def _try_load_rulebook(
content=content,
)

def _is_rulebook_file(self, data: Any) -> bool:
def _validate_rulebook_file(self, data: Any) -> None:
if not isinstance(data, list):
return False
return all("rules" in entry for entry in data)
raise MalformedError("rulebook must contain a list of rulesets")
required_keys = ["name", "condition", "action"]
for ruleset in data:
if "rules" not in ruleset:
raise MalformedError("no rules in a ruleset")
rules = ruleset["rules"]
if not isinstance(rules, list):
raise MalformedError("ruleset must contain a list of rules")
for rule in rules:
if not all(key in rule for key in required_keys):
raise MalformedError(
f"ruleset must contain {required_keys}"
)
if not all(rule.get(key) is not None for key in required_keys):
raise MalformedError(
f"rule's {required_keys} must have non empty values"
)
12 changes: 9 additions & 3 deletions src/aap_eda/tasks/project.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

from aap_eda.core import models
from aap_eda.core.tasking import get_queue, job, unique_enqueue
from aap_eda.services.project import ProjectImportService
from aap_eda.services.project import ProjectImportError, ProjectImportService

logger = logging.getLogger(__name__)
PROJECT_TASKS_QUEUE = "default"
Expand All @@ -27,7 +27,10 @@ def import_project(project_id: int):
logger.info(f"Task started: Import project ( {project_id=} )")

project = models.Project.objects.get(pk=project_id)
ProjectImportService().import_project(project)
try:
ProjectImportService().import_project(project)
except ProjectImportError as e:
logger.exception(e)

logger.info(f"Task complete: Import project ( project_id={project.id} )")

Expand All @@ -37,7 +40,10 @@ def sync_project(project_id: int):
logger.info(f"Task started: Sync project ( {project_id=} )")

project = models.Project.objects.get(pk=project_id)
ProjectImportService().sync_project(project)
try:
ProjectImportService().sync_project(project)
except ProjectImportError as e:
logger.exception(e)

logger.info(f"Task complete: Sync project ( project_id={project.id} )")

Expand Down
20 changes: 20 additions & 0 deletions tests/integration/services/data/project-05-import.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
[
{
"name": "good.yml",
"rulesets": [
{
"name": "Hello Events",
"rules": [
{
"name": "Say Hello",
"action": {
"run_playbook": {
"name": "ansible.eda.hello"
}
}
}
]
}
]
}
]
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
---
name: Hello Events
hosts: all
sources:
- ansible.eda.range:
limit: 5
rules:
- name: Say Hello
condition: event.i == 1
action:
run_playbook:
name: ansible.eda.hello
...
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
---
- name: Hello Events
hosts: all
sources:
- ansible.eda.range:
limit: 5
rules:
- name: Say Hello
condition: event.i == 1
action:
...
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
---
- name: Hello Events
hosts: all
sources:
- ansible.eda.range:
limit: 5
rules: bad
...
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
---
- name: Hello Events
hosts: all
sources:
- ansible.eda.range:
limit: 5
rules:
...
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
---
- name: Hello Events
hosts: all
sources:
- ansible.eda.range:
limit: 5
rules:
name: Say Hello
condition: event.i == 1
action:
run_playbook:
name: ansible.eda.hello
...
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
---
- name: Hello Events
hosts: all
sources:
- ansible.eda.range:
limit: 5
rules:
- condition: event.i == 1
action:
run_playbook:
name: ansible.eda.hello
...
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
---
- name: Hello Events
hosts: all
sources:
- ansible.eda.range:
limit: 5
rules:
- name:
condition: event.i == 1
action:
run_playbook:
name: ansible.eda.hello
...
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
---
- name: Hello Events
hosts: all
sources:
- ansible.eda.range:
limit: 5
rules:
- name: Say Hello
action:
run_playbook:
name: ansible.eda.hello
...
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
---
- name: Hello Events
hosts: all
sources:
- ansible.eda.range:
limit: 5
rules:
- name: Say Hello
condition:
action:
run_playbook:
name: ansible.eda.hello
...
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
---
- name: Hello Events
hosts: all
sources:
- ansible.eda.range:
limit: 5
rules:
- name: Say Hello
condition: event.i == 1
...
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
---
- name: Hello Events
hosts: all
sources:
- ansible.eda.range:
limit: 5
rules:
- name: Say Hello
condition: event.i == 1
action:
run_playbook:
name: ansible.eda.hello
...
45 changes: 45 additions & 0 deletions tests/integration/services/test_project_import.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import logging
import os
import re
import shutil
Expand Down Expand Up @@ -295,6 +296,50 @@ def clone_project(_url, path, *_args, **_kwargs):
storage_save_patch.assert_not_called()


@pytest.mark.django_db
def test_project_import_with_invalid_rulebooks(
storage_save_patch, service_tempdir_patch, caplog
):
def clone_project(_url, path, *_args, **_kwargs):
src = DATA_DIR / "project-05"
shutil.copytree(src, path, symlinks=False)
return repo_mock

repo_mock = mock.Mock(name="GitRepository()")
repo_mock.rev_parse.return_value = (
"adc83b19e793491b1c6ea0fd8b46cd9f32e592fc"
)

git_mock = mock.Mock(name="GitRepository", spec=GitRepository)
git_mock.clone.side_effect = clone_project

project = models.Project.objects.create(
name="test-project-04", url="https://git.example.com/repo.git"
)

logger = logging.getLogger("aap_eda")
propagate = logger.propagate
logger.propagate = True
caplog.set_level(logging.WARNING)
try:
service = ProjectImportService(git_cls=git_mock)
service.import_project(project)
finally:
logger.propagate = propagate

assert project.git_hash == "adc83b19e793491b1c6ea0fd8b46cd9f32e592fc"
assert project.import_state == models.Project.ImportState.COMPLETED
assert caplog.text.count("WARNING") == 10

rulebooks = list(project.rulebook_set.order_by("name"))
assert len(rulebooks) == 1

with open(DATA_DIR / "project-05-import.json") as fp:
expected_rulebooks = json.load(fp)
for rulebook, expected in zip(rulebooks, expected_rulebooks):
assert_rulebook_is_valid(rulebook, expected)


def assert_rulebook_is_valid(rulebook: models.Rulebook, expected: dict):
assert rulebook.name == expected["name"]

Expand Down

0 comments on commit 84dc11c

Please sign in to comment.