Skip to content

Commit

Permalink
feat: add control implmentations and implemented requirements
Browse files Browse the repository at this point in the history
  • Loading branch information
qduanmu committed Jan 22, 2025
1 parent e5c75ea commit 3fe61d6
Show file tree
Hide file tree
Showing 2 changed files with 310 additions and 8 deletions.
313 changes: 308 additions & 5 deletions trestlebot/tasks/sync_cac_content_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,16 +6,32 @@
import logging
import os
import pathlib
from typing import List
import re
from typing import Dict, List, Optional, Pattern, Set

# from ssg.products import get_all
from ssg.profiles import get_profiles_from_products
from trestle.common.list_utils import none_if_empty
from ssg.controls import Control, ControlsManager
from ssg.products import load_product_yaml, product_yaml_path
from ssg.profiles import _load_yaml_profile_file, get_profiles_from_products
from trestle.common.common_types import TypeWithParts, TypeWithProps
from trestle.common.const import TRESTLE_HREF_HEADING
from trestle.common.list_utils import as_list, none_if_empty
from trestle.common.model_utils import ModelUtils
from trestle.core.catalog.catalog_interface import CatalogInterface
from trestle.core.control_interface import ControlInterface
from trestle.core.generators import generate_sample_model
from trestle.core.models.file_content_type import FileContentType
from trestle.core.profile_resolver import ProfileResolver
from trestle.oscal.catalog import Catalog
from trestle.oscal.common import Property
from trestle.oscal.component import ComponentDefinition, DefinedComponent
from trestle.oscal.component import (
ComponentDefinition,
ControlImplementation,
DefinedComponent,
ImplementedRequirement,
SetParameter,
Statement,
)

from trestlebot import const
from trestlebot.tasks.base_task import TaskBase
Expand All @@ -29,6 +45,65 @@

logger = logging.getLogger(__name__)

SECTION_PATTERN = r"Section ([a-z]):"


class OSCALProfileHelper:
"""Helper class to handle OSCAL profile."""

def __init__(self, trestle_root: pathlib.Path) -> None:
"""Initialize."""
self._root = trestle_root
self.profile_controls: Set[str] = set()
self.controls_by_label: Dict[str, str] = dict()

def load(self, profile_path: str) -> None:
"""Load the profile catalog."""
profile_resolver = ProfileResolver()
resolved_catalog: Catalog = profile_resolver.get_resolved_profile_catalog(
self._root,
profile_path,
block_params=False,
params_format="[.]",
show_value_warnings=True,
)

for control in CatalogInterface(resolved_catalog).get_all_controls_from_dict():
self.profile_controls.add(control.id)
label = ControlInterface.get_label(control)
if label:
self.controls_by_label[label] = control.id
self._handle_parts(control)

def _handle_parts(
self,
control: TypeWithParts,
) -> None:
"""Handle parts of a control."""
if control.parts:
for part in control.parts:
if not part.id:
continue
self.profile_controls.add(part.id)
label = ControlInterface.get_label(part)
# Avoiding key collision here. The higher level control object will take
# precedence.
if label and label not in self.controls_by_label.keys():
self.controls_by_label[label] = part.id
self._handle_parts(part)

def validate(self, control_id: str) -> Optional[str]:
"""Validate that the control id exists in the catalog and return the id"""
if control_id in self.controls_by_label.keys():
logger.debug(f"Found control {control_id} in control labels")
return self.controls_by_label.get(control_id)
elif control_id in self.profile_controls:
logger.debug(f"Found control {control_id} in profile control ids")
return control_id

logger.debug(f"Control {control_id} does not exist in the profile")
return None


class SyncCacContentTask(TaskBase):
"""Sync CaC content to OSCAL component definition task."""
Expand All @@ -47,7 +122,12 @@ def __init__(
self.cac_profile: str = cac_profile
self.cac_content_root: str = cac_content_root
self.compdef_type: str = compdef_type
self.oscal_profile: str = oscal_profile
self.rules: List[str] = []
self.controls: List[Control] = list()
self.profile_href: str = ""
self.profile_path: str = ""
self._rules_by_id: Dict[str, RuleInfo] = dict()

super().__init__(working_dir, None)

Expand All @@ -68,7 +148,8 @@ def _get_rules_properties(self) -> List[Property]:
self.cac_profile,
)
rules_transformer.add_rules(self.rules)
rules: List[RuleInfo] = rules_transformer.get_all_rules()
self.rules_by_id: Dict[str, RuleInfo] = rules_transformer.get_all_rule_objs()
rules: List[RuleInfo] = list(self.rules_by_id.values())
all_rule_properties: List[Property] = rules_transformer.transform(rules)
return all_rule_properties

Expand All @@ -91,6 +172,217 @@ def _add_props(self, oscal_component: DefinedComponent) -> DefinedComponent:
oscal_component.props = props
return oscal_component

def _get_source(self, profile_name_or_href: str) -> None:
"""Get the href and source of the profile."""
profile_in_trestle_dir = "://" not in profile_name_or_href
self.profile_href = profile_name_or_href
if profile_in_trestle_dir:
local_path = f"profiles/{profile_name_or_href}/profile.json"
self.profile_href = TRESTLE_HREF_HEADING + local_path
self.profile_path = os.path.join(self.working_dir, local_path)
else:
self.profile_path = self.profile_href

def _load_controls_manager(self) -> ControlsManager:
"""
Loads and initializes a ControlsManager instance.
"""
product_yml_path = product_yaml_path(self.cac_content_root, self.product)
product_yaml = load_product_yaml(product_yml_path)
controls_dir = os.path.join(self.cac_content_root, "controls")
control_mgr = ControlsManager(controls_dir, product_yaml)
control_mgr.load()
return control_mgr

def _get_controls(self) -> None:
controls_manager = self._load_controls_manager()
policies = controls_manager.policies
profile_yaml = _load_yaml_profile_file(self.cac_profile)
selections = profile_yaml.get("selections", [])
for selected in selections:
if ":" in selected:
parts = selected.split(":")
if len(parts) == 3:
policy_id, level = parts[0], parts[2]
else:
policy_id, level = parts[0], "all"
policy = policies.get(policy_id)
if policy is not None:
self.controls.extend(
controls_manager.get_all_controls_of_level(policy_id, level)
)

@staticmethod
def _build_sections_dict(
control_response: str,
section_pattern: Pattern[str],
) -> Dict[str, List[str]]:
"""Find all sections in the control response and build a dictionary of them."""
lines = control_response.split("\n")

sections_dict: Dict[str, List[str]] = dict()
current_section_label = None

for line in lines:
match = section_pattern.match(line)

if match:
current_section_label = match.group(1)
sections_dict[current_section_label] = [line]
elif current_section_label is not None:
sections_dict[current_section_label].append(line)

return sections_dict

def _create_statement(self, statement_id: str, description: str = "") -> Statement:
"""Create a statement."""
statement = generate_sample_model(Statement)
statement.statement_id = statement_id
if description:
statement.description = description
return statement

def _handle_response(
self,
implemented_req: ImplementedRequirement,
control: Control,
profile: OSCALProfileHelper,
) -> None:
"""
Break down the response into parts.
Args:
implemented_req: The implemented requirement to add the response and statements to.
control_response: The control response to add to the implemented requirement.
"""
# If control notes is unavailable, consider to use other input as replacement
# or a generic information.
control_response = control.notes
pattern = re.compile(SECTION_PATTERN, re.IGNORECASE)

sections_dict = self._build_sections_dict(control_response, pattern)
# oscal_status = OscalStatus.from_string(control.status)

if sections_dict:
# self._add_response_by_status(implemented_req, oscal_status, REPLACE_ME)
implemented_req.statements = list()
for section_label, section_content in sections_dict.items():
statement_id = profile.validate(
f"{implemented_req.control_id}_smt.{section_label}"
)
if statement_id is None:
continue

section_content_str = "\n".join(section_content)
section_content_str = pattern.sub("", section_content_str)
statement = self._create_statement(
statement_id, section_content_str.strip()
)
implemented_req.statements.append(statement)
# else:
# self._add_response_by_status(
# implemented_req, oscal_status, control_response.strip()
# )

def _process_rule_ids(self, rule_ids: List[str]) -> List[str]:
"""
Process rule ids.
Notes: Rule ids with an "=" are parameters and should not be included
# when searching for rules.
"""
processed_rule_ids: List[str] = list()
for rule_id in rule_ids:
parts = rule_id.split("=")
if len(parts) == 1:
processed_rule_ids.append(rule_id)
return processed_rule_ids

def _attach_rules(
self,
type_with_props: TypeWithProps,
rule_ids: List[str],
) -> None:
"""Add rules to a type with props."""
all_props: List[Property] = as_list(type_with_props.props)
# Get a subset from self.rules_by_id according to rule_ids
rules_by_id = {k: v for k, v in self.rules_by_id.items() if k in rule_ids}
rules: List[RuleInfo] = list(rules_by_id.values())
rules_transformer = RulesTransformer(
self.cac_content_root,
self.product,
self.cac_profile,
)
rule_properties: List[Property] = rules_transformer.transform(rules)
all_props.extend(rule_properties)
type_with_props.props = none_if_empty(all_props)

def _add_set_parameters(
self, control_implementation: ControlImplementation
) -> None:
"""Add set parameters to a control implementation."""
rules: List[RuleInfo] = list(self.rules_by_id.values())
params = []
for rule in rules:
params.extend(rule._parameters)
param_selections = {param.id: param.selected_value for param in params}

if param_selections:
all_set_params: List[SetParameter] = as_list(
control_implementation.set_parameters
)
for param_id, value in param_selections.items():
set_param = generate_sample_model(SetParameter)
set_param.param_id = param_id
set_param.values = [value]
all_set_params.append(set_param)
control_implementation.set_parameters = none_if_empty(all_set_params)

def _create_implemented_requirement(
self, control: Control
) -> Optional[ImplementedRequirement]:
"""Create implemented requirement from a control object"""

logger.info(f"Creating implemented requirement for {control.id}")
profile = OSCALProfileHelper(pathlib.Path(self.working_dir))
profile.load(self.profile_path)

control_id = profile.validate(control.id)
if control_id:
implemented_req = generate_sample_model(ImplementedRequirement)
implemented_req.control_id = control_id
self._handle_response(implemented_req, control, profile)
rule_ids = self._process_rule_ids(control.rules)
self._attach_rules(implemented_req, rule_ids)
return implemented_req
return None

def _create_control_implementation(self) -> ControlImplementation:
"""Create control implementation for a component."""
ci = generate_sample_model(ControlImplementation)
ci.source = self.profile_href
all_implement_reqs = list()
self._get_controls()

# Get all profile related controls here:
for control in self.controls:
implemented_req = self._create_implemented_requirement(control)
if implemented_req:
all_implement_reqs.append(implemented_req)
ci.implemented_requirements = all_implement_reqs
self._add_set_parameters(ci)
return ci

def _add_control_implementations(
self, oscal_component: DefinedComponent
) -> DefinedComponent:
"""Add control implementations to OSCAL component."""
self._get_source(self.oscal_profile)
control_implementation: ControlImplementation = (
self._create_control_implementation()
)
oscal_component.control_implementations = [control_implementation]
return oscal_component

def _update_compdef(
self, cd_json: pathlib.Path, oscal_component: DefinedComponent
) -> None:
Expand All @@ -106,6 +398,16 @@ def _update_compdef(
logger.info(f"Start to update props of {component.title}")
compdef.components[index].props = oscal_component.props
updated = True
if (
component.control_implementations
!= oscal_component.control_implementations
):
logger.info(f"Start to update props of {component.title}")
compdef.components[index].control_implementations = (
oscal_component.control_implementations
)
updated = True
if updated:
compdef.oscal_write(cd_json)
break

Expand Down Expand Up @@ -140,6 +442,7 @@ def _create_or_update_compdef(self, compdef_type: str = "service") -> None:
"""Create or update component definition for specified CaC profile."""
oscal_component = generate_sample_model(DefinedComponent)
oscal_component = self._add_props(oscal_component)
oscal_component = self._add_control_implementations(oscal_component)

repo_path = pathlib.Path(self.working_dir)
cd_json: pathlib.Path = ModelUtils.get_model_path_for_name_and_class(
Expand Down
5 changes: 2 additions & 3 deletions trestlebot/transformers/cac_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -332,9 +332,8 @@ def _get_rule_properties(self, ruleset: str, rule_obj: RuleInfo) -> List[Propert

return rule_properties

def get_all_rules(self) -> List[RuleInfo]:
"""Get all rules that have been loaded"""
return list(self._rules_by_id.values())
def get_all_rule_objs(self) -> Dict[str, RuleInfo]:
return self._rules_by_id

def transform(self, rule_objs: List[RuleInfo]) -> List[Property]:
"""Get the rules properties for a set of rule ids."""
Expand Down

0 comments on commit 3fe61d6

Please sign in to comment.