From 1d1a1ffe5ebeb69d4b6e53b8e957089f5e83ea9e Mon Sep 17 00:00:00 2001 From: Mike Johanson Date: Wed, 5 Mar 2025 14:19:06 -0700 Subject: [PATCH] feat: enable cancelling of requests --- internal/controller/http/v1/error.go | 6 ++++++ internal/entity/dto/v1/canceled.go | 20 ++++++++++++++++++++ internal/usecase/devices/alarms.go | 10 ++++++---- internal/usecase/devices/certificates.go | 4 ++-- internal/usecase/devices/connections.go | 2 +- internal/usecase/devices/consent.go | 6 +++--- internal/usecase/devices/features.go | 7 +++++-- internal/usecase/devices/info.go | 18 ++++++++++++------ internal/usecase/devices/interfaces.go | 2 +- internal/usecase/devices/network.go | 2 +- internal/usecase/devices/power.go | 8 ++++---- internal/usecase/devices/repo.go | 1 + internal/usecase/devices/wsman/message.go | 10 ++++++++-- 13 files changed, 70 insertions(+), 26 deletions(-) create mode 100644 internal/entity/dto/v1/canceled.go diff --git a/internal/controller/http/v1/error.go b/internal/controller/http/v1/error.go index 049131cc..3bb6be4a 100644 --- a/internal/controller/http/v1/error.go +++ b/internal/controller/http/v1/error.go @@ -22,6 +22,7 @@ type response struct { func ErrorResponse(c *gin.Context, err error) { var ( validatorErr validator.ValidationErrors + cancelledError dto.CanceledError nfErr sqldb.NotFoundError notValidErr dto.NotValidError dbErr sqldb.DatabaseError @@ -35,6 +36,8 @@ func ErrorResponse(c *gin.Context, err error) { switch { case errors.As(err, &netErr): netErrorHandle(c, netErr) + case errors.As(err, &cancelledError): + cancelledErrorHandle(c, cancelledError) case errors.As(err, ¬ValidErr): notValidErrorHandle(c, notValidErr) case errors.As(err, &validatorErr): @@ -59,6 +62,9 @@ func ErrorResponse(c *gin.Context, err error) { func netErrorHandle(c *gin.Context, netErr net.Error) { c.AbortWithStatusJSON(http.StatusGatewayTimeout, response{netErr.Error()}) } +func cancelledErrorHandle(c *gin.Context, cancelError dto.CanceledError) { + c.AbortWithStatusJSON(http.StatusBadRequest, response{cancelError.Error()}) +} func notValidErrorHandle(c *gin.Context, err dto.NotValidError) { c.AbortWithStatusJSON(http.StatusBadRequest, response{err.Console.FriendlyMessage()}) diff --git a/internal/entity/dto/v1/canceled.go b/internal/entity/dto/v1/canceled.go new file mode 100644 index 00000000..11a116f0 --- /dev/null +++ b/internal/entity/dto/v1/canceled.go @@ -0,0 +1,20 @@ +package dto + +import ( + "github.com/open-amt-cloud-toolkit/console/pkg/consoleerrors" +) + +type CanceledError struct { + Console consoleerrors.InternalError +} + +func (e CanceledError) Error() string { + return e.Console.Error() +} + +func (e CanceledError) Wrap(function, call string, err error) error { + _ = e.Console.Wrap(function, call, err) + e.Console.Message = "Canceled: " + err.Error() + + return e +} diff --git a/internal/usecase/devices/alarms.go b/internal/usecase/devices/alarms.go index 527efc0e..9f862233 100644 --- a/internal/usecase/devices/alarms.go +++ b/internal/usecase/devices/alarms.go @@ -22,12 +22,14 @@ func (uc *UseCase) GetAlarmOccurrences(c context.Context, guid string) ([]dto.Al if err != nil { return nil, err } - if item == nil || item.GUID == "" { return nil, ErrNotFound } - device := uc.device.SetupWsmanClient(*item, false, true) + device := uc.device.SetupWsmanClient(c, *item, false, true) + if device == nil { + return nil, ErrCancelled + } alarms, err := device.GetAlarmOccurrences() if err != nil { @@ -61,7 +63,7 @@ func (uc *UseCase) CreateAlarmOccurrences(c context.Context, guid string, alarm alarm.InstanceID = alarm.ElementName - device := uc.device.SetupWsmanClient(*item, false, true) + device := uc.device.SetupWsmanClient(c, *item, false, true) alarmReference, err := device.CreateAlarmOccurrences(alarm.InstanceID, alarm.StartTime, alarm.Interval, alarm.DeleteOnCompletion) if err != nil { @@ -83,7 +85,7 @@ func (uc *UseCase) DeleteAlarmOccurrences(c context.Context, guid, instanceID st return ErrNotFound } - device := uc.device.SetupWsmanClient(*item, false, true) + device := uc.device.SetupWsmanClient(c, *item, false, true) err = device.DeleteAlarmOccurrences(instanceID) if err != nil { diff --git a/internal/usecase/devices/certificates.go b/internal/usecase/devices/certificates.go index 50e207bb..9179639b 100644 --- a/internal/usecase/devices/certificates.go +++ b/internal/usecase/devices/certificates.go @@ -159,7 +159,7 @@ func (uc *UseCase) GetCertificates(c context.Context, guid string) (dto.Security return dto.SecuritySettings{}, ErrNotFound } - device := uc.device.SetupWsmanClient(*item, false, true) + device := uc.device.SetupWsmanClient(c, *item, false, true) response, err := device.GetCertificates() if err != nil { @@ -253,7 +253,7 @@ func (uc *UseCase) GetDeviceCertificate(c context.Context, guid string) (dto.Cer return dto.Certificate{}, ErrNotFound } - device := uc.device.SetupWsmanClient(*item, false, true) + device := uc.device.SetupWsmanClient(c, *item, false, true) cert1, err := device.GetDeviceCertificate() if err != nil { diff --git a/internal/usecase/devices/connections.go b/internal/usecase/devices/connections.go index c4192b3c..1084431a 100644 --- a/internal/usecase/devices/connections.go +++ b/internal/usecase/devices/connections.go @@ -18,7 +18,7 @@ func (uc *UseCase) GetTLSSettingData(c context.Context, guid string) ([]dto.Sett return nil, ErrNotFound } - device := uc.device.SetupWsmanClient(*item, false, true) + device := uc.device.SetupWsmanClient(c, *item, false, true) response, err := device.GetTLSSettingData() if err != nil { diff --git a/internal/usecase/devices/consent.go b/internal/usecase/devices/consent.go index b5cd30e7..a25b10b3 100644 --- a/internal/usecase/devices/consent.go +++ b/internal/usecase/devices/consent.go @@ -17,7 +17,7 @@ func (uc *UseCase) CancelUserConsent(c context.Context, guid string) (dto.UserCo return dto.UserConsentMessage{}, ErrNotFound } - device := uc.device.SetupWsmanClient(*item, false, true) + device := uc.device.SetupWsmanClient(c, *item, false, true) response, err := device.CancelUserConsentRequest() if err != nil { @@ -37,7 +37,7 @@ func (uc *UseCase) GetUserConsentCode(c context.Context, guid string) (dto.GetUs return dto.GetUserConsentMessage{}, ErrNotFound } - device := uc.device.SetupWsmanClient(*item, false, true) + device := uc.device.SetupWsmanClient(c, *item, false, true) code, err := device.GetUserConsentCode() if err != nil { @@ -64,7 +64,7 @@ func (uc *UseCase) SendConsentCode(c context.Context, userConsent dto.UserConsen return dto.UserConsentMessage{}, ErrNotFound } - device := uc.device.SetupWsmanClient(*item, false, true) + device := uc.device.SetupWsmanClient(c, *item, false, true) consentCode, _ := strconv.Atoi(userConsent.ConsentCode) diff --git a/internal/usecase/devices/features.go b/internal/usecase/devices/features.go index 4f99a6ce..00a1667e 100644 --- a/internal/usecase/devices/features.go +++ b/internal/usecase/devices/features.go @@ -42,7 +42,10 @@ func (uc *UseCase) GetFeatures(c context.Context, guid string) (settingsResults return settingsResults, settingsResultsV2, ErrNotFound } - device := uc.device.SetupWsmanClient(*item, false, true) + device := uc.device.SetupWsmanClient(c, *item, false, true) + if device == nil { + return dto.Features{}, dtov2.Features{}, ErrCancelled + } // Get redirection settings from AMT err = getRedirectionService(&settingsResultsV2, device) @@ -192,7 +195,7 @@ func (uc *UseCase) SetFeatures(c context.Context, guid string, features dto.Feat return settingsResults, settingsResultsV2, ErrNotFound } - device := uc.device.SetupWsmanClient(*item, false, true) + device := uc.device.SetupWsmanClient(c, *item, false, true) // redirection state, listenerEnabled, err := redirectionRequestStateChange(features.EnableSOL, features.EnableIDER, &settingsResultsV2, device) diff --git a/internal/usecase/devices/info.go b/internal/usecase/devices/info.go index bdd9cccc..10419e0d 100644 --- a/internal/usecase/devices/info.go +++ b/internal/usecase/devices/info.go @@ -21,7 +21,10 @@ func (uc *UseCase) GetVersion(c context.Context, guid string) (v1 dto.Version, v return v1, v2, ErrNotFound } - device := uc.device.SetupWsmanClient(*item, false, true) + device := uc.device.SetupWsmanClient(c, *item, false, true) + if device == nil { + return v1, v2, ErrCancelled + } softwareIdentity, err := device.GetAMTVersion() if err != nil { @@ -69,7 +72,7 @@ func (uc *UseCase) GetHardwareInfo(c context.Context, guid string) (interface{}, return nil, ErrNotFound } - device := uc.device.SetupWsmanClient(*item, false, true) + device := uc.device.SetupWsmanClient(c, *item, false, true) hwInfo, err := device.GetHardwareInfo() if err != nil { @@ -89,7 +92,7 @@ func (uc *UseCase) GetDiskInfo(c context.Context, guid string) (interface{}, err return nil, ErrNotFound } - device := uc.device.SetupWsmanClient(*item, false, true) + device := uc.device.SetupWsmanClient(c, *item, false, true) diskInfo, err := device.GetDiskInfo() if err != nil { @@ -109,7 +112,7 @@ func (uc *UseCase) GetAuditLog(c context.Context, startIndex int, guid string) ( return dto.AuditLog{}, ErrNotFound } - device := uc.device.SetupWsmanClient(*item, false, true) + device := uc.device.SetupWsmanClient(c, *item, false, true) response, err := device.GetAuditLog(startIndex) if err != nil { @@ -133,7 +136,7 @@ func (uc *UseCase) GetEventLog(c context.Context, startIndex, maxReadRecords int return dto.EventLogs{}, ErrNotFound } - device := uc.device.SetupWsmanClient(*item, false, true) + device := uc.device.SetupWsmanClient(c, *item, false, true) eventLogs, err := device.GetEventLog(startIndex, maxReadRecords) if err != nil { @@ -184,7 +187,10 @@ func (uc *UseCase) GetGeneralSettings(c context.Context, guid string) (interface return nil, ErrNotFound } - device := uc.device.SetupWsmanClient(*item, false, true) + device := uc.device.SetupWsmanClient(c, *item, false, true) + if device == nil { + return nil, ErrCancelled + } generalSettings, err := device.GetGeneralSettings() if err != nil { diff --git a/internal/usecase/devices/interfaces.go b/internal/usecase/devices/interfaces.go index b4b07b00..aed47a9f 100644 --- a/internal/usecase/devices/interfaces.go +++ b/internal/usecase/devices/interfaces.go @@ -15,7 +15,7 @@ import ( type ( WSMAN interface { - SetupWsmanClient(device entity.Device, isRedirection, logMessages bool) wsmanAPI.Management + SetupWsmanClient(ctx context.Context, device entity.Device, isRedirection, logMessages bool) wsmanAPI.Management DestroyWsmanClient(device dto.Device) Worker() } diff --git a/internal/usecase/devices/network.go b/internal/usecase/devices/network.go index f2ca48c7..8e6e1cdd 100644 --- a/internal/usecase/devices/network.go +++ b/internal/usecase/devices/network.go @@ -20,7 +20,7 @@ func (uc *UseCase) GetNetworkSettings(c context.Context, guid string) (dto.Netwo return dto.NetworkSettings{}, ErrNotFound } - device := uc.device.SetupWsmanClient(*item, false, true) + device := uc.device.SetupWsmanClient(c, *item, false, true) response, err := device.GetNetworkSettings() if err != nil { diff --git a/internal/usecase/devices/power.go b/internal/usecase/devices/power.go index e26fa61a..d479473b 100644 --- a/internal/usecase/devices/power.go +++ b/internal/usecase/devices/power.go @@ -23,7 +23,7 @@ func (uc *UseCase) SendPowerAction(c context.Context, guid string, action int) ( return power.PowerActionResponse{}, ErrNotFound } - device := uc.device.SetupWsmanClient(*item, false, true) + device := uc.device.SetupWsmanClient(c, *item, false, true) response, err := device.SendPowerAction(action) if err != nil { @@ -43,7 +43,7 @@ func (uc *UseCase) GetPowerState(c context.Context, guid string) (dto.PowerState return dto.PowerState{}, ErrNotFound } - device := uc.device.SetupWsmanClient(*item, false, true) + device := uc.device.SetupWsmanClient(c, *item, false, true) state, err := device.GetPowerState() if err != nil { @@ -65,7 +65,7 @@ func (uc *UseCase) GetPowerCapabilities(c context.Context, guid string) (dto.Pow return dto.PowerCapabilities{}, ErrNotFound } - device := uc.device.SetupWsmanClient(*item, false, true) + device := uc.device.SetupWsmanClient(c, *item, false, true) version, err := device.GetAMTVersion() if err != nil { @@ -137,7 +137,7 @@ func (uc *UseCase) SetBootOptions(c context.Context, guid string, bootSetting dt return power.PowerActionResponse{}, ErrNotFound } - device := uc.device.SetupWsmanClient(*item, false, true) + device := uc.device.SetupWsmanClient(c, *item, false, true) bootData, err := device.GetBootData() if err != nil { diff --git a/internal/usecase/devices/repo.go b/internal/usecase/devices/repo.go index 76a1702e..909071fc 100644 --- a/internal/usecase/devices/repo.go +++ b/internal/usecase/devices/repo.go @@ -16,6 +16,7 @@ var ( ErrDomainsUseCase = consoleerrors.CreateConsoleError("DevicesUseCase") ErrDatabase = sqldb.DatabaseError{Console: consoleerrors.CreateConsoleError("DevicesUseCase")} ErrNotFound = sqldb.NotFoundError{Console: consoleerrors.CreateConsoleError("DevicesUseCase")} + ErrCancelled = dto.CanceledError{Console: consoleerrors.CreateConsoleError("DevicesUseCase")} ) // History - getting translate history from store. diff --git a/internal/usecase/devices/wsman/message.go b/internal/usecase/devices/wsman/message.go index d897a451..3091ee34 100644 --- a/internal/usecase/devices/wsman/message.go +++ b/internal/usecase/devices/wsman/message.go @@ -1,6 +1,7 @@ package wsman import ( + "context" gotls "crypto/tls" "sync" "time" @@ -110,7 +111,7 @@ func (g GoWSMANMessages) Worker() { } } -func (g GoWSMANMessages) SetupWsmanClient(device entity.Device, isRedirection, logAMTMessages bool) Management { +func (g GoWSMANMessages) SetupWsmanClient(ctx context.Context, device entity.Device, isRedirection, logAMTMessages bool) Management { resultChan := make(chan *ConnectionEntry) // Queue the request requestQueue <- func() { @@ -118,7 +119,12 @@ func (g GoWSMANMessages) SetupWsmanClient(device entity.Device, isRedirection, l resultChan <- g.setupWsmanClientInternal(device, isRedirection, logAMTMessages) } - return <-resultChan + select { + case entry := <-resultChan: + return entry + case <-ctx.Done(): + return nil + } } func (g GoWSMANMessages) setupWsmanClientInternal(device entity.Device, isRedirection, logAMTMessages bool) *ConnectionEntry {