diff --git a/blockchain/binance/client.go b/blockchain/binance/client.go index 5cc327e..d0babf9 100644 --- a/blockchain/binance/client.go +++ b/blockchain/binance/client.go @@ -13,8 +13,8 @@ type Client struct { req client.Request } -func InitClient(url, apiKey string) Client { - request := client.InitJSONClient(url) +func InitClient(url, apiKey string, errorHandler client.HttpErrorHandler) Client { + request := client.InitJSONClient(url, errorHandler) request.AddHeader("apikey", apiKey) return Client{ req: request, diff --git a/client/client.go b/client/client.go index 2e79aea..ce5fab0 100644 --- a/client/client.go +++ b/client/client.go @@ -15,36 +15,37 @@ import ( ) type Request struct { - BaseUrl string - Headers map[string]string - HttpClient *http.Client - ErrorHandler func(res *http.Response, uri string) error + BaseUrl string + Headers map[string]string + HttpClient *http.Client + HttpErrorHandler HttpErrorHandler } +type HttpErrorHandler func(res *http.Response, uri string) error + func (r *Request) SetTimeout(seconds time.Duration) { r.HttpClient.Timeout = time.Second * seconds } -func InitClient(baseUrl string) Request { +func InitClient(baseUrl string, errorHandler HttpErrorHandler) Request { + if errorHandler == nil { + errorHandler = DefaultErrorHandler + } return Request{ - Headers: make(map[string]string), - HttpClient: DefaultClient, - ErrorHandler: DefaultErrorHandler, - BaseUrl: baseUrl, + Headers: make(map[string]string), + HttpClient: DefaultClient, + HttpErrorHandler: errorHandler, + BaseUrl: baseUrl, } } -func InitJSONClient(baseUrl string) Request { - headers := map[string]string{ +func InitJSONClient(baseUrl string, errorHandler HttpErrorHandler) Request { + client := InitClient(baseUrl, errorHandler) + client.Headers = map[string]string{ "Content-Type": "application/json", "Accept": "application/json", } - return Request{ - Headers: headers, - HttpClient: DefaultClient, - ErrorHandler: DefaultErrorHandler, - BaseUrl: baseUrl, - } + return client } var DefaultClient = &http.Client{ @@ -112,7 +113,7 @@ func (r *Request) Execute(method string, url string, body io.Reader, result inte return err } - err = r.ErrorHandler(res, url) + err = r.HttpErrorHandler(res, url) if err != nil { return err } diff --git a/client/client_test.go b/client/client_test.go index d390ade..0d24d9c 100644 --- a/client/client_test.go +++ b/client/client_test.go @@ -57,7 +57,7 @@ func TestRequest_GetBase(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - r := InitClient(tt.fields.baseUrl) + r := InitClient(tt.fields.baseUrl, nil) if got := r.GetBase(tt.path); got != tt.want { t.Errorf("Request.GetBase() = %v, want %v", got, tt.want) } diff --git a/mock/mock_test.go b/mock/mock_test.go index 4336a71..98c668d 100644 --- a/mock/mock_test.go +++ b/mock/mock_test.go @@ -26,7 +26,7 @@ func TestCreateMockedAPI(t *testing.T) { server := httptest.NewServer(CreateMockedAPI(data)) defer server.Close() - client := client.InitClient(server.URL) + client := client.InitClient(server.URL, nil) var resp response err := client.Get(&resp, "1", nil)