diff --git a/backend/src/apiserver/list/list.go b/backend/src/apiserver/list/list.go index e38be8f7339e..174eff961d8b 100644 --- a/backend/src/apiserver/list/list.go +++ b/backend/src/apiserver/list/list.go @@ -22,6 +22,7 @@ import ( "encoding/base64" "encoding/json" "fmt" + "math" "reflect" "strings" @@ -97,6 +98,13 @@ type Options struct { *token } +func EmptyOptions() *Options { + return &Options{ + math.MaxInt32, + &token{}, + } +} + // Matches returns trues if the sorting and filtering criteria in o matches that // of the one supplied in opts. func (o *Options) Matches(opts *Options) bool { @@ -213,9 +221,14 @@ func (o *Options) AddSortingToSelect(sqlBuilder sq.SelectBuilder) sq.SelectBuild if o.IsDesc { order = "DESC" } - sqlBuilder = sqlBuilder. - OrderBy(fmt.Sprintf("%v %v", o.SortByFieldPrefix+o.SortByFieldName, order)). - OrderBy(fmt.Sprintf("%v %v", o.KeyFieldPrefix+o.KeyFieldName, order)) + + if o.SortByFieldName != "" { + sqlBuilder = sqlBuilder.OrderBy(fmt.Sprintf("%v %v", o.SortByFieldPrefix+o.SortByFieldName, order)) + } + + if o.KeyFieldName != "" { + sqlBuilder = sqlBuilder.OrderBy(fmt.Sprintf("%v %v", o.KeyFieldPrefix+o.KeyFieldName, order)) + } return sqlBuilder } diff --git a/backend/src/apiserver/list/list_test.go b/backend/src/apiserver/list/list_test.go index 1806e158eece..e207cd900ab4 100644 --- a/backend/src/apiserver/list/list_test.go +++ b/backend/src/apiserver/list/list_test.go @@ -15,6 +15,8 @@ package list import ( + "fmt" + "math" "reflect" "strings" "testing" @@ -645,6 +647,11 @@ func TestAddPaginationAndFilterToSelect(t *testing.T) { wantSQL: "SELECT * FROM MyTable ORDER BY SortField DESC, KeyField DESC LIMIT 124", wantArgs: nil, }, + { + in: EmptyOptions(), + wantSQL: fmt.Sprintf("SELECT * FROM MyTable LIMIT %d", math.MaxInt32+1), + wantArgs: nil, + }, { in: &Options{ PageSize: 123, diff --git a/backend/src/apiserver/main.go b/backend/src/apiserver/main.go index 5430674897b5..4e302c48423f 100644 --- a/backend/src/apiserver/main.go +++ b/backend/src/apiserver/main.go @@ -106,6 +106,13 @@ func main() { } log.SetLevel(level) + ctx, cancel := context.WithTimeout(context.Background(), 3*time.Minute) + defer cancel() + err = resourceManager.SyncSwfCrs(ctx) + if err != nil { + log.Errorf("Could not refresh the ScheduledWorkflow Kubernetes resources: %v", err) + } + go startRpcServer(resourceManager) startHttpProxy(resourceManager) diff --git a/backend/src/apiserver/resource/resource_manager.go b/backend/src/apiserver/resource/resource_manager.go index 82c13aa369fa..e915f32424bb 100644 --- a/backend/src/apiserver/resource/resource_manager.go +++ b/backend/src/apiserver/resource/resource_manager.go @@ -567,6 +567,79 @@ func (r *ResourceManager) CreateRun(ctx context.Context, run *model.Run) (*model return newRun, nil } +// SyncSwfCrs synchronizes/updates the existing ScheduledWorkflow CRs with the existing jobs. +func (r *ResourceManager) SyncSwfCrs(ctx context.Context) error { + filterContext := &model.FilterContext{ + ReferenceKey: &model.ReferenceKey{Type: model.NamespaceResourceType, ID: "kubeflow"}, + } + + opts := list.EmptyOptions() + + jobs, _, _, err := r.jobStore.ListJobs(filterContext, opts) + + if err != nil { + return util.Wrap(err, "Failed to refresh ScheduledWorkflow Kubernetes resources") + } + + for i := 0; i < len(jobs); i++ { + tmpl, _, err := r.fetchTemplateFromPipelineSpec(&jobs[i].PipelineSpec) + if err != nil { + return failedToSyncSwfCrsError(err) + } + + scheduledWorkflow, err := tmpl.ScheduledWorkflow(jobs[i]) + if err != nil { + return failedToSyncSwfCrsError(err) + } + + err = r.patchSwfCrSpec(ctx, jobs[i].Namespace, jobs[i].K8SName, scheduledWorkflow.Spec) + if err != nil { + if util.IsNotFound(errors.Cause(err)) { + continue + } + return failedToSyncSwfCrsError(err) + } + } + + return nil +} + +func failedToSyncSwfCrsError(err error) error { + return util.Wrap(err, "Failed to refresh ScheduledWorkflow Kubernetes resources") +} + +func (r *ResourceManager) patchSwfCrSpec(ctx context.Context, k8sNamespace string, crdName string, newSpec interface{}) error { + if k8sNamespace == "" { + k8sNamespace = common.GetPodNamespace() + } + if k8sNamespace == "" { + return errors.New("Namespace cannot be empty when deleting a ScheduledWorkflow Kubernetes resource.") + } + + patchPayload := map[string]interface{}{ + "spec": newSpec, + } + + patchBytes, err := json.Marshal(patchPayload) + if err != nil { + return util.NewInternalServerError(err, + "Failed to marshal patch spec") + } + + _, err = r.getScheduledWorkflowClient(k8sNamespace).Patch( + ctx, + crdName, + types.MergePatchType, + patchBytes, + ) + if err != nil { + return util.NewInternalServerError(err, + "Failed to patch ScheduledWorkflow") + } + + return nil +} + // Fetches a run with a given id. func (r *ResourceManager) GetRun(runId string) (*model.Run, error) { run, err := r.runStore.GetRun(runId)