diff --git a/pkg/server/server.go b/pkg/server/server.go index e7ef3e2cf..bfa8023d1 100644 --- a/pkg/server/server.go +++ b/pkg/server/server.go @@ -91,3 +91,12 @@ func WriteResponse( ) { c.JSON(statusCode, responseBody) } + +//AbortWithError aborts the request and returns the error in the response body +func AbortWithError(c *gin.Context, + statusCode int, + err error) { + c.Error(err) + c.JSON(statusCode, err.Error()) + c.Abort() +} diff --git a/presidio-api/cmd/presidio-api/actions.go b/presidio-api/cmd/presidio-api/actions.go index 1e8534847..5df18855e 100644 --- a/presidio-api/cmd/presidio-api/actions.go +++ b/presidio-api/cmd/presidio-api/actions.go @@ -29,7 +29,7 @@ func (api *API) analyze(c *gin.Context) { res, err := api.Services.AnalyzeItem(c, analyzeAPIRequest.Text, analyzeTemplate) if err != nil { - c.AbortWithError(http.StatusInternalServerError, err) + server.AbortWithError(c, http.StatusInternalServerError, err) return } if res == nil { @@ -57,19 +57,17 @@ func (api *API) anonymize(c *gin.Context) { analyzeRes, err := api.Services.AnalyzeItem(c, anonymizeAPIRequest.Text, analyzeTemplate) if err != nil { - c.AbortWithError(http.StatusInternalServerError, err) + server.AbortWithError(c, http.StatusInternalServerError, err) return - } - if analyzeRes == nil { + } else if analyzeRes == nil { return } anonymizeRes, err := api.Services.AnonymizeItem(c, analyzeRes, anonymizeAPIRequest.Text, anonymizeTemplate) if err != nil { - c.AbortWithError(http.StatusInternalServerError, err) + server.AbortWithError(c, http.StatusInternalServerError, err) return - } - if anonymizeRes == nil { + } else if anonymizeRes == nil { return } server.WriteResponse(c, http.StatusOK, anonymizeRes) @@ -82,6 +80,9 @@ func (api *API) scheduleScannerCronJob(c *gin.Context) { if c.Bind(&cronAPIJobRequest) == nil { project := c.Param("project") scannerCronJobRequest := api.getScannerCronJobRequest(&cronAPIJobRequest, project, c) + if scannerCronJobRequest == nil { + return + } scheulderResponse := api.invokeScannerCronJobScheduler(scannerCronJobRequest, c) if scheulderResponse == nil { return @@ -95,7 +96,7 @@ func (api *API) invokeScannerCronJobScheduler(scannerCronJobRequest *types.Scann res, err := api.Services.ApplyScan(c, scannerCronJobRequest) if err != nil { - c.AbortWithError(http.StatusInternalServerError, err) + server.AbortWithError(c, http.StatusInternalServerError, err) return nil } return res @@ -137,7 +138,7 @@ func (api *API) getScannerCronJobRequest(cronJobAPIRequest *types.ScannerCronJob trigger = cronJobAPIRequest.ScannerCronJobRequest.GetTrigger() name = cronJobAPIRequest.ScannerCronJobRequest.GetName() } else { - c.AbortWithError(http.StatusBadRequest, fmt.Errorf("ScannerCronJobTemplateId or ScannerCronJobRequest must be supplied")) + server.AbortWithError(c, http.StatusBadRequest, fmt.Errorf("ScannerCronJobTemplateId or ScannerCronJobRequest must be supplied")) return nil } @@ -154,6 +155,9 @@ func (api *API) scheduleStreamsJob(c *gin.Context) { if c.Bind(&streamsJobRequest) == nil { project := c.Param("project") streamsJobRequest := api.getStreamsJobRequest(&streamsJobRequest, project, c) + if streamsJobRequest == nil { + return + } scheulderResponse := api.invokeStreamsJobScheduler(streamsJobRequest, c) if scheulderResponse == nil { return @@ -166,7 +170,7 @@ func (api *API) scheduleStreamsJob(c *gin.Context) { func (api *API) invokeStreamsJobScheduler(streamsJobRequest *types.StreamsJobRequest, c *gin.Context) *types.StreamsJobResponse { res, err := api.Services.ApplyStream(c, streamsJobRequest) if err != nil { - c.AbortWithError(http.StatusInternalServerError, err) + server.AbortWithError(c, http.StatusInternalServerError, err) return nil } return res @@ -179,6 +183,9 @@ func (api *API) getStreamsJobRequest(jobAPIRequest *types.StreamsJobApiRequest, jobTemplate := &types.StreamsJobTemplate{} api.getTemplate(project, scheduleStreamsJob, jobAPIRequest.StreamsJobTemplateId, jobTemplate, c) + if jobTemplate == nil { + return nil + } streamID := jobTemplate.GetStreamsTemplateId() streamTemplate := &types.StreamTemplate{} api.getTemplate(project, stream, streamID, streamTemplate, c) @@ -193,6 +200,10 @@ func (api *API) getStreamsJobRequest(jobAPIRequest *types.StreamsJobApiRequest, if jobTemplate.AnonymizeTemplateId != "" { api.getTemplate(project, anonymize, jobTemplate.GetAnonymizeTemplateId(), anonymizeTemplate, c) } + + if streamTemplate == nil || datasinkTemplate == nil || analyzeTemplate == nil || anonymizeTemplate == nil { + return nil + } streamsJobRequest = &types.StreamsJobRequest{ Name: streamTemplate.GetName(), StreamsRequest: &types.StreamRequest{ @@ -205,7 +216,7 @@ func (api *API) getStreamsJobRequest(jobAPIRequest *types.StreamsJobApiRequest, } else if jobAPIRequest.GetStreamsJobRequest() != nil { streamsJobRequest = jobAPIRequest.GetStreamsJobRequest() } else { - c.AbortWithError(http.StatusBadRequest, fmt.Errorf("StreamsJobTemplateId or StreamsRequest must be supplied")) + server.AbortWithError(c, http.StatusBadRequest, fmt.Errorf("StreamsJobTemplateId or StreamsRequest must be supplied")) return nil } @@ -215,11 +226,12 @@ func (api *API) getStreamsJobRequest(jobAPIRequest *types.StreamsJobApiRequest, func (api *API) getTemplate(project string, action string, id string, obj interface{}, c *gin.Context) { template, err := api.Templates.GetTemplate(project, action, id) if err != nil { - c.AbortWithError(http.StatusBadRequest, err) + server.AbortWithError(c, http.StatusBadRequest, err) + return } err = presidio.ConvertJSONToInterface(template, obj) if err != nil { - c.AbortWithError(http.StatusBadRequest, err) + server.AbortWithError(c, http.StatusBadRequest, err) } } @@ -230,7 +242,7 @@ func (api *API) getAnalyzeTemplate(analyzeTemplateID string, analyzeTemplate *ty analyzeTemplate = &types.AnalyzeTemplate{} api.getTemplate(project, analyze, analyzeTemplateID, analyzeTemplate, c) } else if analyzeTemplate == nil { - c.AbortWithError(http.StatusBadRequest, fmt.Errorf("AnalyzeTemplate or AnalyzeTemplateId must be supplied")) + server.AbortWithError(c, http.StatusBadRequest, fmt.Errorf("AnalyzeTemplate or AnalyzeTemplateId must be supplied")) return nil } @@ -244,7 +256,7 @@ func (api *API) getAnonymizeTemplate(anonymizeTemplateID string, anonymizeTempla anonymizeTemplate = &types.AnonymizeTemplate{} api.getTemplate(project, anonymize, anonymizeTemplateID, anonymizeTemplate, c) } else if anonymizeTemplate == nil { - c.AbortWithError(http.StatusBadRequest, fmt.Errorf("AnalyzeTemplate or AnalyzeTemplateId must be supplied")) + server.AbortWithError(c, http.StatusBadRequest, fmt.Errorf("AnalyzeTemplate or AnalyzeTemplateId must be supplied")) return nil } diff --git a/presidio-api/cmd/presidio-api/templates.go b/presidio-api/cmd/presidio-api/templates.go index b5fc37cf5..7582d94f8 100644 --- a/presidio-api/cmd/presidio-api/templates.go +++ b/presidio-api/cmd/presidio-api/templates.go @@ -25,7 +25,7 @@ func (api *API) getActionTemplate(c *gin.Context) { id := c.Param("id") result, err := api.Templates.GetTemplate(project, action, id) if err != nil { - c.AbortWithError(http.StatusBadRequest, err) + server.AbortWithError(c, http.StatusBadRequest, err) return } server.WriteResponse(c, http.StatusOK, result) @@ -37,12 +37,12 @@ func (api *API) postActionTemplate(c *gin.Context) { id := c.Param("id") value, err := validateTemplate(action, c) if err != nil { - c.AbortWithError(http.StatusBadRequest, err) + server.AbortWithError(c, http.StatusBadRequest, err) return } err = api.Templates.InsertTemplate(project, action, id, value) if err != nil { - c.AbortWithError(http.StatusBadRequest, err) + server.AbortWithError(c, http.StatusBadRequest, err) return } server.WriteResponse(c, http.StatusCreated, "Template added successfully ") @@ -54,12 +54,13 @@ func (api *API) putActionTemplate(c *gin.Context) { id := c.Param("id") value, err := validateTemplate(action, c) if err != nil { - c.AbortWithError(http.StatusBadRequest, err) + server.AbortWithError(c, http.StatusBadRequest, err) return } err = api.Templates.UpdateTemplate(project, action, id, value) if err != nil { - c.AbortWithError(http.StatusBadRequest, err) + server.AbortWithError(c, http.StatusBadRequest, err) + return } server.WriteResponse(c, http.StatusOK, "Template updated successfully") @@ -71,7 +72,7 @@ func (api *API) deleteActionTemplate(c *gin.Context) { id := c.Param("id") err := api.Templates.DeleteTemplate(project, action, id) if err != nil { - c.AbortWithError(http.StatusBadRequest, err) + server.AbortWithError(c, http.StatusBadRequest, err) return } server.WriteResponse(c, http.StatusNoContent, "")