From 112aca392283473ec0b135f1db422a6e9c26d659 Mon Sep 17 00:00:00 2001 From: xfali Date: Mon, 23 Aug 2021 23:45:32 +0800 Subject: [PATCH 1/2] =?UTF-8?q?=E5=A2=9E=E5=8A=A0=E6=B5=8B=E8=AF=95?= =?UTF-8?q?=E7=94=A8=E4=BE=8B?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- media_type.go | 6 +- v2/default_restclient.go | 5 +- v2/media_type.go | 6 +- v2/test/restclient_test.go | 119 ++++++++++++++++++++++++++++++++++--- 4 files changed, 123 insertions(+), 13 deletions(-) diff --git a/media_type.go b/media_type.go index 18554c4..acc8c42 100644 --- a/media_type.go +++ b/media_type.go @@ -83,6 +83,10 @@ func (t *MediaType) isWildcardInnerSub() bool { return false } +func (t *MediaType) subEqual(o MediaType) bool { + return strings.Index(o.sub, t.sub) == 0 +} + func (t *MediaType) Includes(o MediaType) bool { if t.IsWildcard() { return true @@ -92,7 +96,7 @@ func (t *MediaType) Includes(o MediaType) bool { return true } - if t.sub == o.sub { + if t.subEqual(o) { return true } diff --git a/v2/default_restclient.go b/v2/default_restclient.go index 6e3b529..41e4ad2 100644 --- a/v2/default_restclient.go +++ b/v2/default_restclient.go @@ -167,10 +167,7 @@ func (c *defaultRestClient) Exchange(url string, opts ...request.Opt) error { } // 处理response - err = c.processResponse(response, param.result) - if err != nil { - return nil - } + return c.processResponse(response, param.result) } return nil } diff --git a/v2/media_type.go b/v2/media_type.go index 18554c4..acc8c42 100644 --- a/v2/media_type.go +++ b/v2/media_type.go @@ -83,6 +83,10 @@ func (t *MediaType) isWildcardInnerSub() bool { return false } +func (t *MediaType) subEqual(o MediaType) bool { + return strings.Index(o.sub, t.sub) == 0 +} + func (t *MediaType) Includes(o MediaType) bool { if t.IsWildcard() { return true @@ -92,7 +96,7 @@ func (t *MediaType) Includes(o MediaType) bool { return true } - if t.sub == o.sub { + if t.subEqual(o) { return true } diff --git a/v2/test/restclient_test.go b/v2/test/restclient_test.go index ce21d99..ae6a94b 100644 --- a/v2/test/restclient_test.go +++ b/v2/test/restclient_test.go @@ -7,6 +7,7 @@ package test import ( "context" + "encoding/json" "fmt" "github.com/xfali/restclient/v2" "github.com/xfali/restclient/v2/filter" @@ -24,6 +25,13 @@ func init() { time.Sleep(time.Second) } +type testStruct struct { + Id int64 + Name string + Value float64 + CreateTime time.Time +} + func startHttpServer(shutdown time.Duration) { http.HandleFunc("/auth", func(writer http.ResponseWriter, request *http.Request) { v := request.Header.Get(restutil.HeaderAuthorization) @@ -93,6 +101,33 @@ func startHttpServer(shutdown time.Duration) { } }) + http.HandleFunc("/struct", func(writer http.ResponseWriter, request *http.Request) { + switch request.Method { + case http.MethodGet: + body := request.Body + if body != nil { + defer body.Close() + } + d, _ := json.Marshal(testStruct{ + Id: 1, + Name: "test", + Value: 3.1415926, + CreateTime: time.Now(), + }) + writer.Header().Set(restutil.HeaderContentType, restclient.MediaTypeJson) + writer.Write(d) + break + case http.MethodPost: + body := request.Body + writer.Header().Set(restutil.HeaderContentType, restclient.MediaTypeJson) + if body != nil { + defer body.Close() + io.Copy(writer, body) + } + break + } + }) + server := &http.Server{Addr: ":8080", Handler: nil} go server.ListenAndServe() @@ -134,7 +169,7 @@ func TestRequest(t *testing.T) { if resp.StatusCode != http.StatusOK { t.Fatal("not 200") } else { - t.Log( resp.StatusCode) + t.Log(resp.StatusCode) } t.Log(ret) }) @@ -153,7 +188,7 @@ func TestRequest(t *testing.T) { if resp.StatusCode != http.StatusOK { t.Fatal("not 200") } else { - t.Log( resp.StatusCode) + t.Log(resp.StatusCode) } t.Log(ret) }) @@ -171,7 +206,7 @@ func TestRequest(t *testing.T) { if resp.StatusCode != http.StatusOK { t.Fatal("not 200") } else { - t.Log( resp.StatusCode) + t.Log(resp.StatusCode) } t.Log(ret) }) @@ -190,7 +225,7 @@ func TestRequest(t *testing.T) { if resp.StatusCode != http.StatusOK { t.Fatal("not 200") } else { - t.Log( resp.StatusCode) + t.Log(resp.StatusCode) } t.Log(ret) }) @@ -208,7 +243,7 @@ func TestRequest(t *testing.T) { if resp.StatusCode != http.StatusOK { t.Fatal("not 200") } else { - t.Log( resp.StatusCode) + t.Log(resp.StatusCode) } t.Log(ret) }) @@ -226,7 +261,7 @@ func TestRequest(t *testing.T) { if resp.StatusCode != http.StatusOK { t.Fatal("not 200") } else { - t.Log( resp.StatusCode) + t.Log(resp.StatusCode) } t.Log(ret) }) @@ -242,7 +277,77 @@ func TestRequest(t *testing.T) { if resp.StatusCode != http.StatusOK { t.Fatal("not 200 ", resp.StatusCode) } else { - t.Log( resp.StatusCode) + t.Log(resp.StatusCode) + } + t.Log(ret) + }) +} + +func TestStruct(t *testing.T) { + client := restclient.New(restclient.AddIFilter(filter.NewLog(xlog.GetLogger(), ""))) + t.Run("Get", func(t *testing.T) { + ret := testStruct{} + resp := new(http.Response) + err := client.Exchange("http://localhost:8080/struct", + request.WithResult(&ret), + request.WithResponse(resp, false)) + if err != nil { + t.Fatal(err) + } + if resp.StatusCode != http.StatusOK { + t.Fatal("not 200") + } else { + t.Log(resp.StatusCode) + } + if ret.Id != 1 { + t.Fatal("expect id 1 but get ", ret.Id) + } + t.Log(ret) + }) + + t.Run("Get func", func(t *testing.T) { + resp := new(http.Response) + err := client.Exchange("http://localhost:8080/struct", + request.WithResult(func(ret testStruct) { + if ret.Id != 1 { + t.Fatal("expect id 1 but get ", ret.Id) + } + t.Log(ret) + }), + request.WithResponse(resp, false)) + if err != nil { + t.Fatal(err) + } + if resp.StatusCode != http.StatusOK { + t.Fatal("not 200") + } else { + t.Log(resp.StatusCode) + } + }) + + t.Run("Post", func(t *testing.T) { + ret := testStruct{ + Id: 2, + Name: "test2", + Value: 1.0, + CreateTime: time.Now(), + } + resp := new(http.Response) + err := client.Exchange("http://localhost:8080/struct", + request.MethodPost(), + request.WithResult(&ret), + request.WithResponse(resp, false), + request.WithRequestBody(ret)) + if err != nil { + t.Fatal(err) + } + if resp.StatusCode != http.StatusOK { + t.Fatal("not 200") + } else { + t.Log(resp.StatusCode) + } + if ret.Id != 2 { + t.Fatal("expect id 2 but get ", ret.Id) } t.Log(ret) }) From 51d37ec9dedbc48fdb283bb4d7eb2910d9350535 Mon Sep 17 00:00:00 2001 From: xfali Date: Mon, 23 Aug 2021 23:46:10 +0800 Subject: [PATCH 2/2] =?UTF-8?q?=E8=BF=94=E5=9B=9EError=EF=BC=8C=E5=8C=85?= =?UTF-8?q?=E5=90=ABhttp=20status=E9=94=99=E8=AF=AF?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- v2/default_restclient.go | 118 +++++++++++++++++++++++------------- v2/errors.go | 58 ++++++++++++++++++ v2/init_opts.go | 9 ++- v2/restclient.go | 2 +- v2/test/restclient_test.go | 119 +++++++++++++++++++++++++++++++++++++ 5 files changed, 263 insertions(+), 43 deletions(-) create mode 100644 v2/errors.go diff --git a/v2/default_restclient.go b/v2/default_restclient.go index 41e4ad2..d534257 100644 --- a/v2/default_restclient.go +++ b/v2/default_restclient.go @@ -23,12 +23,22 @@ import ( ) type AcceptFlag int +type ResponseBodyFlag int const ( - DefaultTimeout = 0 - AcceptUserOnly AcceptFlag = 1 + DefaultTimeout = 0 + + // 仅接受用户指定的header + AcceptUserOnly AcceptFlag = 1 + // 根据Converter支持的类型,自动添加第一个支持的类型(默认) AcceptAutoFirst AcceptFlag = 1 << 1 - AcceptAutoAll AcceptFlag = 1 << 2 + // 根据Converter支持的类型,自动添加所有支持的类型 + AcceptAutoAll AcceptFlag = 1 << 2 + + // 处理所有的response body + ResponseBodyAll ResponseBodyFlag = 1 + // 不处理http status 400及以上的response的body + ResponseBodyIgnoreBad ResponseBodyFlag = 1 << 1 ) var ( @@ -91,7 +101,8 @@ type defaultRestClient struct { filterManager filter.FilterManager pool buffer.Pool - autoAccept AcceptFlag + acceptFlag AcceptFlag + respFlag ResponseBodyFlag transport http.RoundTripper timeout time.Duration } @@ -104,7 +115,8 @@ func New(opts ...Opt) *defaultRestClient { converters: defaultConverters, pool: buffer.NewPool(), timeout: DefaultTimeout, - autoAccept: AcceptAutoFirst, + acceptFlag: AcceptAutoFirst, + respFlag: ResponseBodyAll, } ret.filterManager.Add(ret.filter) for _, opt := range opts { @@ -114,25 +126,28 @@ func New(opts ...Opt) *defaultRestClient { return ret } -func (c *defaultRestClient) Exchange(url string, opts ...request.Opt) error { +func (c *defaultRestClient) Exchange(url string, opts ...request.Opt) Error { param := emptyParam() for _, opt := range opts { opt(param) } - r, err := c.processRequest(param.reqBody, param.header) + // 序列化request body + r, err := c.encodeRequest(param.reqBody, param.header) if r != nil { defer r.Close() } if err != nil { - return err + return withErr(DefaultErrorStatus, err) } nilResult := reflection.IsNil(param.result) if !nilResult { + // 根据反序列化response body的目的result类型添加header Accept param.header = c.addAccept(param.result, param.header) } + // 创建http.Request req := defaultRequestCreator(param.ctx, param.method, url, r, param.header) fm := c.filterManager if param.filterManager.Valid() { @@ -140,36 +155,10 @@ func (c *defaultRestClient) Exchange(url string, opts ...request.Opt) error { } response, err := fm.RunFilter(req) if err != nil { - return err + return withErr(DefaultErrorStatus, err) } - if response.Body != nil { - defer response.Body.Close() - // need response - if param.response != nil { - copyResponse(param.response, response) - // need response's body - if param.respFlag { - // get buffer form pool - buf := buffer.NewReadWriteCloser(c.pool) - // 封装reader,在读取response body数据时写入到buffer中 - reader := buffer.NewMergeReaderWriter(response.Body, buf) - // 替换response body - response.Body = ioutil.NopCloser(reader) - // 调用者response设置body为buffer - param.response.Body = buf - } - } - if nilResult { - // 如果用户没设置result,则直接读取body到discard - _, err = io.Copy(ioutil.Discard, response.Body) - return err - } - - // 处理response - return c.processResponse(response, param.result) - } - return nil + return c.processResponse(response, param, nilResult) } func (c *defaultRestClient) filter(request *http.Request, fc filter.FilterChain) (*http.Response, error) { @@ -206,7 +195,7 @@ func (c *defaultRestClient) newClient() *http.Client { } } -func (c *defaultRestClient) processRequest(requestBody interface{}, header http.Header) (io.ReadCloser, error) { +func (c *defaultRestClient) encodeRequest(requestBody interface{}, header http.Header) (io.ReadCloser, error) { if requestBody != nil { mtStr := getContentMediaType(header) mediaType := ParseMediaType(mtStr) @@ -232,7 +221,54 @@ func (c *defaultRestClient) processRequest(requestBody interface{}, header http. return nil, nil } -func (c *defaultRestClient) processResponse(resp *http.Response, result interface{}) error { +func (c *defaultRestClient) processResponse(response *http.Response, param *defaultParam, nilResult bool) Error { + errStatus := response.StatusCode + if response.StatusCode < http.StatusBadRequest { + errStatus = DefaultErrorStatus + } else if c.respFlag == ResponseBodyIgnoreBad { + return withStatus(response.StatusCode) + } + + if response.Body != nil { + defer response.Body.Close() + // need response + if param.response != nil { + copyResponse(param.response, response) + // need response's body + if param.respFlag { + // get buffer form pool + buf := buffer.NewReadWriteCloser(c.pool) + // 封装reader,在读取response body数据时写入到buffer中 + reader := buffer.NewMergeReaderWriter(response.Body, buf) + // 替换response body + response.Body = ioutil.NopCloser(reader) + // 调用者response设置body为buffer + param.response.Body = buf + } + } + if nilResult { + // 如果用户没设置result,则直接读取body到discard + _, err := io.Copy(ioutil.Discard, response.Body) + if err != nil { + return withErr(errStatus, err) + } + } else { + // 处理response + err := c.decodeResponse(response, param.result) + if err != nil { + return withErr(errStatus, err) + } + } + } + + if response.StatusCode >= http.StatusBadRequest { + return withStatus(response.StatusCode) + } + + return nil +} + +func (c *defaultRestClient) decodeResponse(resp *http.Response, result interface{}) error { mediaType := getResponseMediaType(resp) t := reflect.TypeOf(result) if t.Kind() != reflect.Func { @@ -281,7 +317,7 @@ func (c *defaultRestClient) addAccept(result interface{}, header http.Header) ht mt := ParseMediaType(userAccept) typeMap := map[string]bool{} var acceptList []string - if c.autoAccept != AcceptUserOnly { + if c.acceptFlag != AcceptUserOnly { index := len(c.converters) for index > 0 { index-- @@ -294,14 +330,14 @@ func (c *defaultRestClient) addAccept(result interface{}, header http.Header) ht if _, have := typeMap[mtStr]; !have { acceptList = append(acceptList, mtStr) typeMap[mtStr] = true - if c.autoAccept == AcceptAutoFirst { + if c.acceptFlag == AcceptAutoFirst { break } } } } } - if c.autoAccept == AcceptAutoFirst && len(typeMap) > 0 { + if c.acceptFlag == AcceptAutoFirst && len(typeMap) > 0 { break } } diff --git a/v2/errors.go b/v2/errors.go new file mode 100644 index 0000000..5a7a43c --- /dev/null +++ b/v2/errors.go @@ -0,0 +1,58 @@ +// Copyright (C) 2019-2021, Xiongfa Li. +// @author xiongfa.li +// @version V1.0 +// Description: + +package restclient + +import ( + "fmt" + "net/http" +) + +var DefaultErrorStatus = http.StatusBadRequest + +type Error interface { + error + + // 获得http status code + StatusCode() int + + // 获得原始error + Origin() error +} + +type defaultError struct { + status int + err error +} + +func withErr(status int, err error) defaultError { + return defaultError{ + status: status, + err: err, + } +} + +func withStatus(status int) defaultError { + return defaultError{ + status: status, + err: fmt.Errorf("restclient status: [%d] %s", status, http.StatusText(status)), + } +} + +func (e defaultError) Origin() error { + return e.err +} + +func (e defaultError) Error() string { + if e.err != nil { + return e.err.Error() + } else { + return "" + } +} + +func (e defaultError) StatusCode() int { + return e.status +} diff --git a/v2/init_opts.go b/v2/init_opts.go index 5151324..b68ab4e 100644 --- a/v2/init_opts.go +++ b/v2/init_opts.go @@ -43,7 +43,14 @@ func SetRoundTripper(tripper http.RoundTripper) func(client *defaultRestClient) // 配置是否自动添加accept func SetAutoAccept(v AcceptFlag) func(client *defaultRestClient) { return func(client *defaultRestClient) { - client.autoAccept = v + client.acceptFlag = v + } +} + +// 配置是否自动添加accept +func SetResponseBodyFlag(v ResponseBodyFlag) func(client *defaultRestClient) { + return func(client *defaultRestClient) { + client.respFlag = v } } diff --git a/v2/restclient.go b/v2/restclient.go index fce209b..d93a5c6 100644 --- a/v2/restclient.go +++ b/v2/restclient.go @@ -11,5 +11,5 @@ type RestClient interface { // 发起请求 // url:请求路径 // params:请求参数,见ex_params.go具体定义 - Exchange(url string, opts ...request.Opt) error + Exchange(url string, opts ...request.Opt) Error } diff --git a/v2/test/restclient_test.go b/v2/test/restclient_test.go index ae6a94b..ad01c5f 100644 --- a/v2/test/restclient_test.go +++ b/v2/test/restclient_test.go @@ -128,6 +128,35 @@ func startHttpServer(shutdown time.Duration) { } }) + http.HandleFunc("/error", func(writer http.ResponseWriter, request *http.Request) { + switch request.Method { + case http.MethodGet: + body := request.Body + if body != nil { + defer body.Close() + } + d, _ := json.Marshal(testStruct{ + Id: 1, + Name: "test", + Value: 3.1415926, + CreateTime: time.Now(), + }) + writer.Header().Set(restutil.HeaderContentType, restclient.MediaTypeJson) + writer.WriteHeader(http.StatusBadRequest) + writer.Write(d) + break + case http.MethodPost: + body := request.Body + writer.Header().Set(restutil.HeaderContentType, restclient.MediaTypeJson) + writer.WriteHeader(http.StatusBadRequest) + if body != nil { + defer body.Close() + io.Copy(writer, body) + } + break + } + }) + server := &http.Server{Addr: ":8080", Handler: nil} go server.ListenAndServe() @@ -352,3 +381,93 @@ func TestStruct(t *testing.T) { t.Log(ret) }) } + +func TestErrorStruct(t *testing.T) { + client := restclient.New(restclient.AddIFilter(filter.NewLog(xlog.GetLogger(), ""))) + t.Run("Get", func(t *testing.T) { + ret := testStruct{} + err := client.Exchange("http://localhost:8080/error", + request.WithResult(&ret)) + if err == nil { + t.Fatal(err) + } else { + t.Log(err) + } + if err.StatusCode() == http.StatusOK { + t.Fatal("not 200") + } else { + t.Log(err.StatusCode()) + } + if ret.Id != 1 { + t.Fatal("expect id 1 but get ", ret.Id) + } + t.Log(ret) + }) + + t.Run("Get not found", func(t *testing.T) { + ret := testStruct{} + err := client.Exchange("http://localhost:8080/404", + request.WithResult(&ret)) + if err == nil { + t.Fatal(err) + } else { + t.Log(err) + } + if err.StatusCode() != http.StatusNotFound { + t.Fatal("not 404") + } else { + t.Log(err.StatusCode()) + } + if ret.Id == 1 { + t.Fatal("expect id 0 but get ", ret.Id) + } + t.Log(ret) + }) + + t.Run("Get func", func(t *testing.T) { + err := client.Exchange("http://localhost:8080/error", + request.WithResult(func(ret testStruct) { + if ret.Id != 1 { + t.Fatal("expect id 1 but get ", ret.Id) + } + t.Log(ret) + })) + if err == nil { + t.Fatal(err) + } else { + t.Log(err) + } + if err.StatusCode() == http.StatusOK { + t.Fatal("not 200") + } else { + t.Log(err.StatusCode()) + } + }) + + t.Run("Post", func(t *testing.T) { + ret := testStruct{ + Id: 2, + Name: "test2", + Value: 1.0, + CreateTime: time.Now(), + } + err := client.Exchange("http://localhost:8080/error", + request.MethodPost(), + request.WithResult(&ret), + request.WithRequestBody(ret)) + if err == nil { + t.Fatal(err) + } else { + t.Log(err) + } + if err.StatusCode() == http.StatusOK { + t.Fatal("not 200") + } else { + t.Log(err.StatusCode()) + } + if ret.Id != 2 { + t.Fatal("expect id 2 but get ", ret.Id) + } + t.Log(ret) + }) +}