Skip to content

Commit 95c7ea6

Browse files
authored
Merge pull request #55 from Azure/aaqib-m/rate-limited-client
feat: client side rate limiting with IMDS
2 parents 3c37f33 + 9cbb334 commit 95c7ea6

File tree

6 files changed

+72
-17
lines changed

6 files changed

+72
-17
lines changed

go.mod

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
module github.com/Azure/msi-acrpull
22

3-
go 1.20
3+
go 1.21
44

55
require (
66
github.com/go-logr/logr v1.2.4

pkg/authorizer/client.go

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
package authorizer
2+
3+
import (
4+
"context"
5+
"fmt"
6+
"net/http"
7+
8+
"golang.org/x/time/rate"
9+
)
10+
11+
const (
12+
defaultRPS = 1
13+
defaultBurst = 5
14+
)
15+
16+
type rateLimitedClient struct {
17+
httpClient *http.Client
18+
rateLimiter *rate.Limiter
19+
}
20+
21+
func newRateLimitedClient() *rateLimitedClient {
22+
return newRateLimitedClientWithRPS(defaultRPS, defaultBurst)
23+
}
24+
25+
func newRateLimitedClientWithRPS(rps float64, burst int) *rateLimitedClient {
26+
client := &rateLimitedClient{
27+
httpClient: http.DefaultClient,
28+
rateLimiter: rate.NewLimiter(rate.Limit(rps), burst),
29+
}
30+
return client
31+
}
32+
33+
func (client *rateLimitedClient) Do(req *http.Request) (*http.Response, error) {
34+
ctx := context.Background()
35+
err := client.rateLimiter.Wait(ctx)
36+
if err != nil {
37+
return nil, fmt.Errorf("failed to wait for rate limit token: %w", err)
38+
}
39+
40+
resp, err := client.httpClient.Do(req)
41+
if err != nil {
42+
return nil, err
43+
}
44+
return resp, nil
45+
}

pkg/authorizer/token_exchanger.go

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,12 +15,14 @@ import (
1515
// TokenExchanger is an instance of ACRTokenExchanger
1616
type TokenExchanger struct {
1717
acrServerScheme string
18+
client *rateLimitedClient
1819
}
1920

2021
// NewTokenExchanger returns a new token exchanger
2122
func NewTokenExchanger() *TokenExchanger {
2223
return &TokenExchanger{
2324
acrServerScheme: "https",
25+
client: newRateLimitedClient(),
2426
}
2527
}
2628

@@ -55,11 +57,10 @@ func (te *TokenExchanger) ExchangeACRAccessToken(armToken types.AccessToken, acr
5557
req.Header.Add("Content-Type", "application/x-www-form-urlencoded")
5658
req.Header.Add("Content-Length", strconv.Itoa(len(parameters.Encode())))
5759

58-
client := &http.Client{}
5960
var resp *http.Response
6061
defer closeResponse(resp)
6162

62-
resp, err = client.Do(req)
63+
resp, err = te.client.Do(req)
6364
if err != nil {
6465
return "", fmt.Errorf("failed to send token exchange request: %w", err)
6566
}

pkg/authorizer/token_exchanger_test.go

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ var _ = Describe("Token Exchanger Tests", func() {
4747
ghttp.RespondWithJSONEncoded(200, tokenResp),
4848
))
4949

50-
te := newTestTokenExchanger()
50+
te := newTestTokenExchanger(server)
5151
token, err := te.ExchangeACRAccessToken(armToken, ul.Host)
5252

5353
Expect(err).To(BeNil())
@@ -73,7 +73,7 @@ var _ = Describe("Token Exchanger Tests", func() {
7373
ghttp.RespondWith(403, "Unauthorized"),
7474
))
7575

76-
te := newTestTokenExchanger()
76+
te := newTestTokenExchanger(server)
7777
token, err := te.ExchangeACRAccessToken(armToken, ul.Host)
7878

7979
Expect(err).NotTo(BeNil())
@@ -85,8 +85,12 @@ var _ = Describe("Token Exchanger Tests", func() {
8585
})
8686
})
8787

88-
func newTestTokenExchanger() *TokenExchanger {
88+
func newTestTokenExchanger(server *ghttp.Server) *TokenExchanger {
89+
client := newRateLimitedClient()
90+
client.httpClient = server.HTTPTestServer.Client()
91+
8992
return &TokenExchanger{
9093
acrServerScheme: "http",
94+
client: client,
9195
}
9296
}

pkg/authorizer/token_retriever.go

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ type TokenRetriever struct {
2626
metadataEndpoint string
2727
cache sync.Map
2828
cacheExpiration time.Duration
29+
client *rateLimitedClient
2930
}
3031

3132
type cachedToken struct {
@@ -39,6 +40,7 @@ func NewTokenRetriever() *TokenRetriever {
3940
metadataEndpoint: msiMetadataEndpoint,
4041
cache: sync.Map{},
4142
cacheExpiration: time.Duration(defaultCacheExpirationInSeconds) * time.Second,
43+
client: newRateLimitedClient(),
4244
}
4345
}
4446

@@ -98,11 +100,10 @@ func (tr *TokenRetriever) refreshToken(clientID, resourceID string) (types.Acces
98100
}
99101
req.Header.Add("Metadata", "true")
100102

101-
client := &http.Client{}
102103
var resp *http.Response
103104
defer closeResponse(resp)
104105

105-
resp, err = client.Do(req)
106+
resp, err = tr.client.Do(req)
106107
if err != nil {
107108
return "", fmt.Errorf("failed to send metadata endpoint request: %w", err)
108109
}

pkg/authorizer/token_retriever_test.go

Lines changed: 13 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ var _ = Describe("Token Retriever Tests", func() {
3838
ghttp.RespondWithJSONEncoded(200, tokenResp),
3939
))
4040

41-
tr := newTestTokenRetriever(server.URL(), defaultCacheExpirationInSeconds)
41+
tr := newTestTokenRetriever(server, defaultCacheExpirationInSeconds)
4242
token, err := tr.AcquireARMToken("", testResourceID)
4343

4444
Expect(err).To(BeNil())
@@ -60,7 +60,7 @@ var _ = Describe("Token Retriever Tests", func() {
6060
ghttp.RespondWithJSONEncoded(200, tokenResp),
6161
))
6262

63-
tr := newTestTokenRetriever(server.URL(), defaultCacheExpirationInSeconds)
63+
tr := newTestTokenRetriever(server, defaultCacheExpirationInSeconds)
6464
token, err := tr.AcquireARMToken("", testResourceID)
6565

6666
os.Unsetenv(customARMResourceEnvVar)
@@ -82,7 +82,7 @@ var _ = Describe("Token Retriever Tests", func() {
8282
ghttp.RespondWithJSONEncoded(200, tokenResp),
8383
))
8484

85-
tr := newTestTokenRetriever(server.URL(), defaultCacheExpirationInSeconds)
85+
tr := newTestTokenRetriever(server, defaultCacheExpirationInSeconds)
8686
token, err := tr.AcquireARMToken(testClientID, "")
8787

8888
Expect(err).To(BeNil())
@@ -97,7 +97,7 @@ var _ = Describe("Token Retriever Tests", func() {
9797
ghttp.RespondWith(404, ""),
9898
))
9999

100-
tr := newTestTokenRetriever(server.URL(), defaultCacheExpirationInSeconds)
100+
tr := newTestTokenRetriever(server, defaultCacheExpirationInSeconds)
101101
token, err := tr.AcquireARMToken(testClientID, "")
102102

103103
Expect(err).NotTo(BeNil())
@@ -118,7 +118,7 @@ var _ = Describe("Token Retriever Tests", func() {
118118
ghttp.RespondWithJSONEncoded(200, tokenResp),
119119
))
120120

121-
tr := newTestTokenRetriever(server.URL(), defaultCacheExpirationInSeconds*1000)
121+
tr := newTestTokenRetriever(server, defaultCacheExpirationInSeconds*1000)
122122
token, err := tr.AcquireARMToken(testClientID, "")
123123
Expect(err).To(BeNil())
124124
Expect(token).To(Equal(armToken))
@@ -142,7 +142,7 @@ var _ = Describe("Token Retriever Tests", func() {
142142
ghttp.RespondWithJSONEncoded(200, tokenResp),
143143
))
144144

145-
tr := newTestTokenRetriever(server.URL(), defaultCacheExpirationInSeconds*1000)
145+
tr := newTestTokenRetriever(server, defaultCacheExpirationInSeconds*1000)
146146
token, err := tr.AcquireARMToken("", testResourceID)
147147
Expect(err).To(BeNil())
148148
Expect(token).To(Equal(armToken))
@@ -171,7 +171,7 @@ var _ = Describe("Token Retriever Tests", func() {
171171
))
172172

173173
// set cache expire immediately
174-
tr := newTestTokenRetriever(server.URL(), 0)
174+
tr := newTestTokenRetriever(server, 0)
175175
token, err := tr.AcquireARMToken(testClientID, "")
176176
Expect(err).To(BeNil())
177177
Expect(token).To(Equal(armToken))
@@ -185,10 +185,14 @@ var _ = Describe("Token Retriever Tests", func() {
185185
})
186186
})
187187

188-
func newTestTokenRetriever(metadataEndpoint string, cacheExpirationInMilliSeconds int) *TokenRetriever {
188+
func newTestTokenRetriever(server *ghttp.Server, cacheExpirationInMilliSeconds int) *TokenRetriever {
189+
client := newRateLimitedClient()
190+
client.httpClient = server.HTTPTestServer.Client()
191+
189192
return &TokenRetriever{
190-
metadataEndpoint: metadataEndpoint,
193+
metadataEndpoint: server.URL(),
191194
cache: sync.Map{},
192195
cacheExpiration: time.Duration(cacheExpirationInMilliSeconds) * time.Millisecond,
196+
client: client,
193197
}
194198
}

0 commit comments

Comments
 (0)