Skip to content

Commit

Permalink
chore: return error if workspace config violates global constraints (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
amandavialva01 authored Oct 21, 2024
1 parent 912f91e commit 04861dd
Show file tree
Hide file tree
Showing 4 changed files with 358 additions and 31 deletions.
9 changes: 6 additions & 3 deletions master/internal/api_config_policies.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,9 +50,11 @@ func (a *apiServer) validatePoliciesAndWorkloadType(
_, priorityEnabledErr := a.m.rm.SmallerValueIsHigherPriority()
switch workloadType {
case model.ExperimentType:
return configpolicy.ValidateExperimentConfig(globalConfigPolicies, configPolicies, priorityEnabledErr)
return configpolicy.ValidateExperimentConfig(globalConfigPolicies, configPolicies,
priorityEnabledErr)
case model.NTSCType:
return configpolicy.ValidateNTSCConfig(globalConfigPolicies, configPolicies, priorityEnabledErr)
return configpolicy.ValidateNTSCConfig(globalConfigPolicies, configPolicies,
priorityEnabledErr)
default:
return status.Errorf(codes.InvalidArgument, fmt.Sprintf(invalidWorkloadTypeErr+": %s.", workloadType))
}
Expand Down Expand Up @@ -127,7 +129,8 @@ func (a *apiServer) PutWorkspaceConfigPolicies(
return nil, err
}

err = a.validatePoliciesAndWorkloadType(globalConfigPolicies, req.WorkloadType, req.ConfigPolicies)
err = a.validatePoliciesAndWorkloadType(globalConfigPolicies, req.WorkloadType,
req.ConfigPolicies)
if err != nil {
return nil, err
}
Expand Down
5 changes: 2 additions & 3 deletions master/internal/configpolicy/postgres_task_config_policy.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,9 +47,8 @@ func SetTaskConfigPoliciesTx(ctx context.Context, tx *bun.Tx,
) error {
q := db.Bun().NewInsert().Model(tcp)

q = q.Set("last_updated_by = ?, last_updated_time = ?", tcp.LastUpdatedBy, tcp.LastUpdatedTime)
q = q.Set("invariant_config = ?", tcp.InvariantConfig)
q = q.Set("constraints = ?", tcp.Constraints)
q = q.Set("last_updated_by = ?, last_updated_time = ?, invariant_config = ?, constraints = ?",
tcp.LastUpdatedBy, tcp.LastUpdatedTime, tcp.InvariantConfig, tcp.Constraints)

if tcp.WorkspaceID == nil {
q = q.On("CONFLICT (workload_type) WHERE workspace_id IS NULL DO UPDATE")
Expand Down
117 changes: 92 additions & 25 deletions master/internal/configpolicy/utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,41 @@ func ValidWorkloadType(val string) bool {
}
}

// UnmarshalConfigPolicies unmarshals optionally specified invariant config and constraint
// configurations presented as YAML or JSON strings.
func UnmarshalConfigPolicies[T any](errMsg string, constraintsStr,
configStr *string) (*model.Constraints, *T,
error,
) {
var constraints *model.Constraints
var config *T

if constraintsStr != nil {
unmarshaledConstraints, err := UnmarshalConfigPolicy[model.Constraints](
*constraintsStr,
errMsg,
)
if err != nil {
ConfigPolicyWarning(err.Error())
return nil, nil, err
}
constraints = unmarshaledConstraints
}

if configStr != nil {
unmarshaledConfig, err := UnmarshalConfigPolicy[T](
*configStr,
errMsg,
)
if err != nil {
ConfigPolicyWarning(err.Error())
return nil, nil, err
}
config = unmarshaledConfig
}
return constraints, config, nil
}

// UnmarshalConfigPolicy is a generic helper function to unmarshal both JSON and YAML strings.
func UnmarshalConfigPolicy[T any](str string, errString string) (*T, error) {
var configPolicy T
Expand Down Expand Up @@ -81,11 +116,22 @@ func ValidateExperimentConfig(
return err
}

// Warn the user when fields specified in workspace config policies overlap with global config
// policies (since these fields will be overridden by the respective fields in the global
// policies).
var globalConstraints *model.Constraints
var globalConfig *expconf.ExperimentConfig
if globalConfigPolicies != nil {
checkAgainstGlobalConfig[model.Constraints](globalConfigPolicies.Constraints, cp.Constraints, "invalid constraints")
checkAgainstGlobalConfig[expconf.ExperimentConfig](
globalConfigPolicies.InvariantConfig, cp.InvariantConfig, InvalidExperimentConfigPolicyErr,
)
globalConstraints, globalConfig, err = UnmarshalConfigPolicies[expconf.ExperimentConfig](
InvalidExperimentConfigPolicyErr,
globalConfigPolicies.Constraints,
globalConfigPolicies.InvariantConfig)
if err != nil {
return err
}

configPolicyOverlap(globalConstraints, cp.Constraints)
configPolicyOverlap(globalConfig, cp.InvariantConfig)
}

if cp.Constraints != nil {
Expand All @@ -95,10 +141,18 @@ func ValidateExperimentConfig(
if cp.InvariantConfig != nil {
if cp.InvariantConfig.RawResources != nil {
checkAgainstGlobalPriority(priorityEnabledErr, cp.InvariantConfig.RawResources.RawPriority)

// Verify the workspace invariant config doesn't conflict with workspace constraints.
if err := checkConstraintConflicts(cp.Constraints, cp.InvariantConfig.RawResources.RawMaxSlots,
cp.InvariantConfig.RawResources.RawSlotsPerTrial, cp.InvariantConfig.RawResources.RawPriority); err != nil {
return status.Errorf(codes.InvalidArgument, fmt.Sprintf(InvalidExperimentConfigPolicyErr+": %s.", err))
}

// Verify the workspace invariant config doesn't conflict with global constraints.
if err := checkConstraintConflicts(globalConstraints, cp.InvariantConfig.RawResources.RawMaxSlots,
cp.InvariantConfig.RawResources.RawSlotsPerTrial, cp.InvariantConfig.RawResources.RawPriority); err != nil {
return status.Errorf(codes.InvalidArgument, fmt.Sprintf(InvalidExperimentConfigPolicyErr+": %s.", err))
}
}
}

Expand All @@ -120,11 +174,25 @@ func ValidateNTSCConfig(
please remove "invariant_config" section and try again`
return status.Errorf(codes.InvalidArgument, fmt.Sprintf(NotSupportedConfigPolicyErr+": %s.", msg))
}

// Warn the user when fields specified in workspace config policies overlap with global config
// policies (since these fields will be overridden by the respective fields in the global
// policies).
var globalConstraints *model.Constraints
var globalConfig *model.CommandConfig
if globalConfigPolicies != nil {
checkAgainstGlobalConfig[model.Constraints](globalConfigPolicies.Constraints, cp.Constraints, "invalid constraints")
checkAgainstGlobalConfig[model.CommandConfig](
globalConfigPolicies.InvariantConfig, cp.InvariantConfig, InvalidNTSCConfigPolicyErr,
)
if globalConfigPolicies.Constraints != nil {
globalConstraints, globalConfig, err = UnmarshalConfigPolicies[model.CommandConfig](
InvalidNTSCConfigPolicyErr,
globalConfigPolicies.Constraints,
globalConfigPolicies.InvariantConfig)
if err != nil {
return err
}
}

configPolicyOverlap(globalConstraints, cp.Constraints)
configPolicyOverlap(globalConfig, cp.InvariantConfig)
}

if cp.Constraints != nil {
Expand All @@ -141,10 +209,18 @@ func ValidateNTSCConfig(
slots = &cp.InvariantConfig.Resources.Slots
}

// Verify the workspace invariant config doesn't conflict with workspace constraints.
if err := checkConstraintConflicts(cp.Constraints, cp.InvariantConfig.Resources.MaxSlots,
slots, cp.InvariantConfig.Resources.Priority); err != nil {
return status.Errorf(codes.InvalidArgument, fmt.Sprintf(InvalidNTSCConfigPolicyErr+": %s.", err))
}

// Verify the workspace invariant config conflict with global constraints.
if err := checkConstraintConflicts(globalConstraints,
cp.InvariantConfig.Resources.MaxSlots, slots,
cp.InvariantConfig.Resources.Priority); err != nil {
return status.Errorf(codes.InvalidArgument, fmt.Sprintf(InvalidNTSCConfigPolicyErr+": %s.", err))
}
}

return err
Expand Down Expand Up @@ -180,25 +256,16 @@ func checkConstraintConflicts(constraints *model.Constraints, maxSlots, slots, p
return nil
}

// checkAgainstGlobalConfig is a generic to check constraints & invariant configs against the global config.
func checkAgainstGlobalConfig[T any](
globalConfigPolicies *string,
config *T,
errorMsg string,
) {
if globalConfigPolicies != nil && config != nil {
global, err := UnmarshalConfigPolicy[T](*globalConfigPolicies, errorMsg)
if err != nil {
ConfigPolicyWarning(err.Error())
return
}
configPolicyConflict(global, config)
// configPolicyOverlap compares two different configurations and warns the user when both
// configurations define the same field.
func configPolicyOverlap(config1, config2 interface{}) {
if reflect.ValueOf(config1).Type() != reflect.ValueOf(config2).Type() &&
reflect.ValueOf(config1).Type() != reflect.ValueOf(&model.Constraints{}).Type() &&
reflect.ValueOf(config1).Type() != reflect.ValueOf(&model.CommandConfig{}).Type() &&
reflect.ValueOf(config1).Type() != reflect.ValueOf(&expconf.ExperimentConfig{}).Type() {
return
}
}

// configPolicyConflict compares two different configurations and
// returns an error if both try to define the same field.
func configPolicyConflict(config1, config2 interface{}) {
v1 := reflect.ValueOf(config1)
v2 := reflect.ValueOf(config2)

Expand Down
Loading

0 comments on commit 04861dd

Please sign in to comment.