diff --git a/queue_manager.go b/queue_manager.go index 61d5d99..2f78513 100644 --- a/queue_manager.go +++ b/queue_manager.go @@ -5,7 +5,6 @@ import ( "encoding/xml" "errors" "io/ioutil" - "net/http" "strings" "time" @@ -262,7 +261,7 @@ func (qm *QueueManager) Delete(ctx context.Context, name string) error { res, err := qm.entityManager.Delete(ctx, "/"+name) defer closeRes(ctx, res) - return err + return checkForError(ctx, err, res) } // Put creates or updates a Service Bus Queue @@ -309,8 +308,7 @@ func (qm *QueueManager) Put(ctx context.Context, name string, opts ...QueueManag res, err := qm.entityManager.Put(ctx, "/"+name, reqBytes, mw...) defer closeRes(ctx, res) - if err != nil { - tab.For(ctx).Error(err) + if err := checkForError(ctx, err, res); err != nil { return nil, err } @@ -346,8 +344,7 @@ func (qm *QueueManager) List(ctx context.Context, options ...ListQueuesOption) ( res, err := qm.entityManager.Get(ctx, basePath) defer closeRes(ctx, res) - if err != nil { - tab.For(ctx).Error(err) + if err := checkForError(ctx, err, res); err != nil { return nil, err } @@ -378,15 +375,10 @@ func (qm *QueueManager) Get(ctx context.Context, name string) (*QueueEntity, err res, err := qm.entityManager.Get(ctx, name) defer closeRes(ctx, res) - if err != nil { - tab.For(ctx).Error(err) + if err := checkForError(ctx, err, res); err != nil { return nil, err } - if res.StatusCode == http.StatusNotFound { - return nil, ErrNotFound{EntityPath: res.Request.URL.Path} - } - b, err := ioutil.ReadAll(res.Body) if err != nil { tab.For(ctx).Error(err) diff --git a/subscription_manager.go b/subscription_manager.go index 2a293a6..60d0382 100644 --- a/subscription_manager.go +++ b/subscription_manager.go @@ -213,7 +213,7 @@ func (sm *SubscriptionManager) Delete(ctx context.Context, name string) error { res, err := sm.entityManager.Delete(ctx, sm.getResourceURI(name)) defer closeRes(ctx, res) - return err + return checkForError(ctx, err, res) } // Put creates or updates a Service Bus Topic @@ -260,7 +260,7 @@ func (sm *SubscriptionManager) Put(ctx context.Context, name string, opts ...Sub res, err := sm.entityManager.Put(ctx, sm.getResourceURI(name), reqBytes, mw...) defer closeRes(ctx, res) - if err != nil { + if err := checkForError(ctx, err, res); err != nil { return nil, err } @@ -295,7 +295,7 @@ func (sm *SubscriptionManager) List(ctx context.Context, options ...ListSubscrip res, err := sm.entityManager.Get(ctx, basePath) defer closeRes(ctx, res) - if err != nil { + if err := checkForError(ctx, err, res); err != nil { return nil, err } @@ -325,14 +325,10 @@ func (sm *SubscriptionManager) Get(ctx context.Context, name string) (*Subscript res, err := sm.entityManager.Get(ctx, sm.getResourceURI(name)) defer closeRes(ctx, res) - if err != nil { + if err := checkForError(ctx, err, res); err != nil { return nil, err } - if res.StatusCode == http.StatusNotFound { - return nil, ErrNotFound{EntityPath: res.Request.URL.Path} - } - b, err := ioutil.ReadAll(res.Body) if err != nil { return nil, err @@ -361,14 +357,10 @@ func (sm *SubscriptionManager) ListRules(ctx context.Context, subscriptionName s res, err := sm.entityManager.Get(ctx, sm.getRulesResourceURI(subscriptionName)) defer closeRes(ctx, res) - if err != nil { + if err := checkForError(ctx, err, res); err != nil { return nil, err } - if res.StatusCode == http.StatusNotFound { - return nil, ErrNotFound{EntityPath: res.Request.URL.Path} - } - b, err := ioutil.ReadAll(res.Body) if err != nil { return nil, err @@ -473,7 +465,7 @@ func (sm *SubscriptionManager) DeleteRule(ctx context.Context, subscriptionName, res, err := sm.entityManager.Delete(ctx, sm.getRuleResourceURI(subscriptionName, ruleName)) defer closeRes(ctx, res) - return err + return checkForError(ctx, err, res) } func ruleEntryToEntity(entry *ruleEntry) *RuleEntity { diff --git a/subscription_test.go b/subscription_test.go index 055be91..5932715 100644 --- a/subscription_test.go +++ b/subscription_test.go @@ -26,6 +26,7 @@ import ( "context" "encoding/xml" "fmt" + "os" "strings" "sync" "testing" @@ -33,6 +34,7 @@ import ( "github.com/Azure/azure-amqp-common-go/v3/uuid" "github.com/Azure/azure-sdk-for-go/services/servicebus/mgmt/2015-08-01/servicebus" + "github.com/joho/godotenv" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -458,7 +460,12 @@ func (suite *serviceBusSuite) testSubscriptionManager(tests map[string]func(cont defer func(sName string) { ctx, cancel := context.WithTimeout(context.Background(), defaultTimeout) defer cancel() - if !suite.NoError(sm.Delete(ctx, sName)) { + + err = sm.Delete(ctx, sName) + + if !IsErrNotFound(err) && !suite.NoError(err) { + // not all tests actually create a subscription (some of these tests are + // basically unittests) suite.Fail(err.Error()) } }(subName) @@ -785,3 +792,130 @@ func checkZeroSubscriptionMessages(ctx context.Context, t *testing.T, topic *Top assert.Fail(t, "message count never reached zero") } + +func TestErrorMessagesWithMissingPrivileges(t *testing.T) { + _ = godotenv.Load() + + // we're obscuring the HTTP errors coming back from the service and + // we shouldn't. Just testing some of the common scenarios with a + // connection string that lacks Manage privileges. + lowPrivCS := os.Getenv("SERVICEBUS_CONNECTION_STRING_NO_MANAGE") + normalCS := os.Getenv("SERVICEBUS_CONNECTION_STRING") + + if lowPrivCS == "" || normalCS == "" { + t.Skip("Need both SERVICEBUS_CONNECTION_STRING_NO_MANAGE and SERVICEBUS_CONNECTION_STRING") + } + + nanoSeconds := time.Now().UnixNano() + + topicName := fmt.Sprintf("topic-%d", nanoSeconds) + queueName := fmt.Sprintf("queue-%d", nanoSeconds) + subName := "subscription1" + ruleName := "rule" + + // create some entities that we need (there's a diff between something not being + // found and something failing because of lack of authorization) + cleanup := func() func() { + ns, err := NewNamespace(NamespaceWithConnectionString(normalCS)) + qm := ns.NewQueueManager() + + _, err = qm.Put(context.Background(), queueName) + require.NoError(t, err) + + tm := ns.NewTopicManager() + _, err = tm.Put(context.Background(), topicName) + require.NoError(t, err) + + sm, err := ns.NewSubscriptionManager(topicName) + require.NoError(t, err) + + _, err = sm.Put(context.Background(), subName) + require.NoError(t, err) + + _, err = sm.PutRule(context.Background(), subName, ruleName, TrueFilter{}) + require.NoError(t, err) + + return func() { + require.NoError(t, tm.Delete(context.Background(), topicName)) // should delete the subscription + require.NoError(t, qm.Delete(context.Background(), queueName)) + } + }() + defer cleanup() + + ns, err := NewNamespace(NamespaceWithConnectionString(lowPrivCS)) + require.NoError(t, err) + + ctx := context.Background() + wg := sync.WaitGroup{} + wg.Add(3) + + go func() { + defer wg.Done() + + qm := ns.NewQueueManager() + + _, err = qm.Get(ctx, "not-found-queue") + require.True(t, IsErrNotFound(err)) + + _, err = qm.Get(ctx, queueName) + require.EqualError(t, err, "request failed: 401 Unauthorized") + + _, err = qm.List(ctx) + require.EqualError(t, err, "request failed: 401 Unauthorized") + + _, err = qm.Put(ctx, "canneverbecreated") + require.EqualError(t, err, "request failed: 401 Unauthorized") + + err = qm.Delete(ctx, queueName) + require.EqualError(t, err, "request failed: 401 Unauthorized") + }() + + go func() { + defer wg.Done() + + tm := ns.NewTopicManager() + + _, err = tm.Get(ctx, "not-found-topic") + require.True(t, IsErrNotFound(err)) + + _, err = tm.Get(ctx, topicName) + require.EqualError(t, err, "request failed: 401 Unauthorized") + + _, err = tm.Put(ctx, "canneverbecreated") + require.Contains(t, err.Error(), "error code: 401, Details: Authorization failed for specified action") + + _, err = tm.List(ctx) + require.Contains(t, err.Error(), "error code: 401, Details: Manage,EntityRead claims required for this operation") + + err = tm.Delete(ctx, topicName) + require.Contains(t, err.Error(), "request failed: 401 Unauthorized") + }() + + go func() { + defer wg.Done() + + sm, err := ns.NewSubscriptionManager(topicName) + require.NoError(t, err) + + _, err = sm.Get(ctx, "not-found-subscription") + require.Contains(t, err.Error(), "request failed: 401 SubCode=40100: Unauthorized") + + _, err = sm.Get(ctx, subName) + require.Contains(t, err.Error(), "request failed: 401 SubCode=40100: Unauthorized") + + _, err = sm.Put(ctx, subName) + require.Contains(t, err.Error(), "request failed: 401 SubCode=40100: Unauthorized") + + err = sm.Delete(ctx, subName) + require.Contains(t, err.Error(), "request failed: 401 SubCode=40100: Unauthorized") + + _, err = sm.List(ctx) + require.Contains(t, err.Error(), "request failed: 401 SubCode=40100: Unauthorized") + + _, err = sm.ListRules(ctx, subName) + require.Contains(t, err.Error(), "request failed: 401 SubCode=40100: Unauthorized") + + }() + + wg.Wait() +} diff --git a/topic_manager.go b/topic_manager.go index da86f0f..567edde 100644 --- a/topic_manager.go +++ b/topic_manager.go @@ -4,6 +4,7 @@ import ( "context" "encoding/xml" "errors" + "fmt" "io/ioutil" "net/http" "time" @@ -93,7 +94,7 @@ func (tm *TopicManager) Delete(ctx context.Context, name string) error { res, err := tm.entityManager.Delete(ctx, "/"+name) defer closeRes(ctx, res) - return err + return checkForError(ctx, err, res) } // Put creates or updates a Service Bus Topic @@ -200,15 +201,10 @@ func (tm *TopicManager) Get(ctx context.Context, name string) (*TopicEntity, err res, err := tm.entityManager.Get(ctx, name) defer closeRes(ctx, res) - if err != nil { - tab.For(ctx).Error(err) + if err := checkForError(ctx, err, res); err != nil { return nil, err } - if res.StatusCode == http.StatusNotFound { - return nil, ErrNotFound{EntityPath: res.Request.URL.Path} - } - b, err := ioutil.ReadAll(res.Body) if err != nil { tab.For(ctx).Error(err) @@ -322,3 +318,25 @@ func TopicWithMessageTimeToLive(window *time.Duration) TopicManagementOption { return nil } } + +func checkForError(ctxForLogging context.Context, err error, res *http.Response) error { + if err != nil { + tab.For(ctxForLogging).Error(err) + return err + } + + // check the response as well + if res.StatusCode == http.StatusNotFound { + err := ErrNotFound{EntityPath: res.Request.URL.Path} + tab.For(ctxForLogging).Error(err) + return err + } + + if res.StatusCode >= 400 { + err := fmt.Errorf("request failed: %s", res.Status) + tab.For(ctxForLogging).Error(err) + return err + } + + return nil +}