diff --git a/src/aap_eda/api/serializers/event_stream.py b/src/aap_eda/api/serializers/event_stream.py index 8b7fc9136..39c0dca18 100644 --- a/src/aap_eda/api/serializers/event_stream.py +++ b/src/aap_eda/api/serializers/event_stream.py @@ -19,7 +19,6 @@ from django.conf import settings from django.core.validators import RegexValidator from rest_framework import serializers -from rest_framework.exceptions import ValidationError from aap_eda.api.constants import ( PG_NOTIFY_TEMPLATE_RULEBOOK_DATA, @@ -31,7 +30,11 @@ MissingEventStreamRulebookSource, ) from aap_eda.api.serializers.credential import CredentialSerializer -from aap_eda.api.serializers.utils import substitute_extra_vars, swap_sources +from aap_eda.api.serializers.utils import ( + YAMLSerializerField, + substitute_extra_vars, + swap_sources, +) from aap_eda.core import models, validators logger = logging.getLogger(__name__) @@ -96,29 +99,6 @@ def _updated_listener_ruleset(validated_data): return swap_sources(validated_data["rulebook_rulesets"], sources_info) -class YAMLSerializerField(serializers.Field): - """Serializer for YAML a superset of JSON.""" - - def to_internal_value(self, data) -> dict: - if data: - try: - parsed_args = yaml.safe_load(data) - except yaml.YAMLError: - raise ValidationError("Invalid YAML format for 'source_args'") - - if not isinstance(parsed_args, dict): - raise ValidationError( - "The 'source_args' field must be a YAML " - "object (dictionary)" - ) - - return parsed_args - return data - - def to_representation(self, value) -> str: - return yaml.dump(value) - - class EventStreamSerializer(serializers.ModelSerializer): decision_environment_id = serializers.IntegerField( validators=[validators.check_if_de_exists] diff --git a/src/aap_eda/api/serializers/rulebook.py b/src/aap_eda/api/serializers/rulebook.py index 6892867ed..14b635aa9 100644 --- a/src/aap_eda/api/serializers/rulebook.py +++ b/src/aap_eda/api/serializers/rulebook.py @@ -15,6 +15,7 @@ from drf_spectacular.utils import extend_schema_field from rest_framework import serializers +from aap_eda.api.serializers.utils import YAMLSerializerField from aap_eda.core import models @@ -254,8 +255,9 @@ class AuditEventSerializer(serializers.ModelSerializer): help_text="The received timestamp of the event", ) - payload = serializers.JSONField( + payload = YAMLSerializerField( required=False, + allow_null=True, help_text="The payload in the event", ) diff --git a/src/aap_eda/api/serializers/utils.py b/src/aap_eda/api/serializers/utils.py index 3d209d920..058e912c0 100644 --- a/src/aap_eda/api/serializers/utils.py +++ b/src/aap_eda/api/serializers/utils.py @@ -19,10 +19,34 @@ import yaml from django.conf import settings from jinja2.nativetypes import NativeTemplate +from rest_framework import serializers +from rest_framework.exceptions import ValidationError LOGGER = logging.getLogger(__name__) +class YAMLSerializerField(serializers.Field): + """Serializer for YAML a superset of JSON.""" + + def to_internal_value(self, data) -> dict: + if data: + try: + parsed_args = yaml.safe_load(data) + except yaml.YAMLError: + raise ValidationError("Invalid YAML format for input data") + + if not isinstance(parsed_args, dict): + raise ValidationError( + "The input field must be a YAML object (dictionary)" + ) + + return parsed_args + return data + + def to_representation(self, value) -> str: + return yaml.dump(value) + + def _render_string(value: str, context: dict) -> str: if "{{" in value and "}}" in value: return NativeTemplate(value, undefined=jinja2.StrictUndefined).render( diff --git a/tests/integration/api/test_event_stream.py b/tests/integration/api/test_event_stream.py index 05b48ed32..d2c2905d0 100644 --- a/tests/integration/api/test_event_stream.py +++ b/tests/integration/api/test_event_stream.py @@ -393,7 +393,7 @@ def test_create_event_stream_bad_args( result = response.data assert ( str(result["source_args"][0]) - == "The 'source_args' field must be a YAML object (dictionary)" + == "The input field must be a YAML object (dictionary)" ) diff --git a/tests/integration/api/test_rulebook.py b/tests/integration/api/test_rulebook.py index 01c63bb59..30db34dc9 100644 --- a/tests/integration/api/test_rulebook.py +++ b/tests/integration/api/test_rulebook.py @@ -404,6 +404,8 @@ def test_list_events_from_audit_rule(client: APIClient, init_db): events = response.data["results"] assert len(events) == 2 assert events[0]["received_at"] > events[1]["received_at"] + assert events[0]["payload"] == events[1]["payload"] + assert events[0]["payload"] == "key: value\n" @pytest.mark.parametrize( @@ -585,6 +587,7 @@ def init_db(): source_type="ansible.eda.range", rule_fired_at="2023-12-14T15:19:02.313122Z", received_at="2023-12-14T15:19:02.289549Z", + payload={"key": "value"}, ) audit_event_2 = models.AuditEvent.objects.create( id=str(uuid.uuid4()), @@ -592,6 +595,7 @@ def init_db(): source_type="ansible.eda.range", rule_fired_at="2023-12-14T15:19:02.323704Z", received_at="2023-12-14T15:19:02.313063Z", + payload={"key": "value"}, ) audit_event_3 = models.AuditEvent.objects.create( id=str(uuid.uuid4()), @@ -599,6 +603,7 @@ def init_db(): source_type="ansible.eda.range", rule_fired_at="2023-12-14T15:19:02.323704Z", received_at="2023-12-14T15:19:02.321472Z", + payload={"key": "value"}, ) audit_event_1.audit_actions.add(action_1) audit_event_2.audit_actions.add(action_2)