Skip to content

Commit

Permalink
Merge pull request #2 from xfali/v2
Browse files Browse the repository at this point in the history
V2
  • Loading branch information
xfali authored Aug 23, 2021
2 parents 825a8b5 + 51d37ec commit 831bbb3
Show file tree
Hide file tree
Showing 7 changed files with 385 additions and 55 deletions.
6 changes: 5 additions & 1 deletion media_type.go
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,10 @@ func (t *MediaType) isWildcardInnerSub() bool {
return false
}

func (t *MediaType) subEqual(o MediaType) bool {
return strings.Index(o.sub, t.sub) == 0
}

func (t *MediaType) Includes(o MediaType) bool {
if t.IsWildcard() {
return true
Expand All @@ -92,7 +96,7 @@ func (t *MediaType) Includes(o MediaType) bool {
return true
}

if t.sub == o.sub {
if t.subEqual(o) {
return true
}

Expand Down
121 changes: 77 additions & 44 deletions v2/default_restclient.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,22 @@ import (
)

type AcceptFlag int
type ResponseBodyFlag int

const (
DefaultTimeout = 0
AcceptUserOnly AcceptFlag = 1
DefaultTimeout = 0

// 仅接受用户指定的header
AcceptUserOnly AcceptFlag = 1
// 根据Converter支持的类型,自动添加第一个支持的类型(默认)
AcceptAutoFirst AcceptFlag = 1 << 1
AcceptAutoAll AcceptFlag = 1 << 2
// 根据Converter支持的类型,自动添加所有支持的类型
AcceptAutoAll AcceptFlag = 1 << 2

// 处理所有的response body
ResponseBodyAll ResponseBodyFlag = 1
// 不处理http status 400及以上的response的body
ResponseBodyIgnoreBad ResponseBodyFlag = 1 << 1
)

var (
Expand Down Expand Up @@ -91,7 +101,8 @@ type defaultRestClient struct {
filterManager filter.FilterManager
pool buffer.Pool

autoAccept AcceptFlag
acceptFlag AcceptFlag
respFlag ResponseBodyFlag
transport http.RoundTripper
timeout time.Duration
}
Expand All @@ -104,7 +115,8 @@ func New(opts ...Opt) *defaultRestClient {
converters: defaultConverters,
pool: buffer.NewPool(),
timeout: DefaultTimeout,
autoAccept: AcceptAutoFirst,
acceptFlag: AcceptAutoFirst,
respFlag: ResponseBodyAll,
}
ret.filterManager.Add(ret.filter)
for _, opt := range opts {
Expand All @@ -114,65 +126,39 @@ func New(opts ...Opt) *defaultRestClient {
return ret
}

func (c *defaultRestClient) Exchange(url string, opts ...request.Opt) error {
func (c *defaultRestClient) Exchange(url string, opts ...request.Opt) Error {
param := emptyParam()
for _, opt := range opts {
opt(param)
}

r, err := c.processRequest(param.reqBody, param.header)
// 序列化request body
r, err := c.encodeRequest(param.reqBody, param.header)
if r != nil {
defer r.Close()
}
if err != nil {
return err
return withErr(DefaultErrorStatus, err)
}

nilResult := reflection.IsNil(param.result)
if !nilResult {
// 根据反序列化response body的目的result类型添加header Accept
param.header = c.addAccept(param.result, param.header)
}

// 创建http.Request
req := defaultRequestCreator(param.ctx, param.method, url, r, param.header)
fm := c.filterManager
if param.filterManager.Valid() {
fm = filter.MergeFilterManager(c.filterManager, param.filterManager)
}
response, err := fm.RunFilter(req)
if err != nil {
return err
return withErr(DefaultErrorStatus, err)
}

if response.Body != nil {
defer response.Body.Close()
// need response
if param.response != nil {
copyResponse(param.response, response)
// need response's body
if param.respFlag {
// get buffer form pool
buf := buffer.NewReadWriteCloser(c.pool)
// 封装reader,在读取response body数据时写入到buffer中
reader := buffer.NewMergeReaderWriter(response.Body, buf)
// 替换response body
response.Body = ioutil.NopCloser(reader)
// 调用者response设置body为buffer
param.response.Body = buf
}
}
if nilResult {
// 如果用户没设置result,则直接读取body到discard
_, err = io.Copy(ioutil.Discard, response.Body)
return err
}

// 处理response
err = c.processResponse(response, param.result)
if err != nil {
return nil
}
}
return nil
return c.processResponse(response, param, nilResult)
}

func (c *defaultRestClient) filter(request *http.Request, fc filter.FilterChain) (*http.Response, error) {
Expand Down Expand Up @@ -209,7 +195,7 @@ func (c *defaultRestClient) newClient() *http.Client {
}
}

func (c *defaultRestClient) processRequest(requestBody interface{}, header http.Header) (io.ReadCloser, error) {
func (c *defaultRestClient) encodeRequest(requestBody interface{}, header http.Header) (io.ReadCloser, error) {
if requestBody != nil {
mtStr := getContentMediaType(header)
mediaType := ParseMediaType(mtStr)
Expand All @@ -235,7 +221,54 @@ func (c *defaultRestClient) processRequest(requestBody interface{}, header http.
return nil, nil
}

func (c *defaultRestClient) processResponse(resp *http.Response, result interface{}) error {
func (c *defaultRestClient) processResponse(response *http.Response, param *defaultParam, nilResult bool) Error {
errStatus := response.StatusCode
if response.StatusCode < http.StatusBadRequest {
errStatus = DefaultErrorStatus
} else if c.respFlag == ResponseBodyIgnoreBad {
return withStatus(response.StatusCode)
}

if response.Body != nil {
defer response.Body.Close()
// need response
if param.response != nil {
copyResponse(param.response, response)
// need response's body
if param.respFlag {
// get buffer form pool
buf := buffer.NewReadWriteCloser(c.pool)
// 封装reader,在读取response body数据时写入到buffer中
reader := buffer.NewMergeReaderWriter(response.Body, buf)
// 替换response body
response.Body = ioutil.NopCloser(reader)
// 调用者response设置body为buffer
param.response.Body = buf
}
}
if nilResult {
// 如果用户没设置result,则直接读取body到discard
_, err := io.Copy(ioutil.Discard, response.Body)
if err != nil {
return withErr(errStatus, err)
}
} else {
// 处理response
err := c.decodeResponse(response, param.result)
if err != nil {
return withErr(errStatus, err)
}
}
}

if response.StatusCode >= http.StatusBadRequest {
return withStatus(response.StatusCode)
}

return nil
}

func (c *defaultRestClient) decodeResponse(resp *http.Response, result interface{}) error {
mediaType := getResponseMediaType(resp)
t := reflect.TypeOf(result)
if t.Kind() != reflect.Func {
Expand Down Expand Up @@ -284,7 +317,7 @@ func (c *defaultRestClient) addAccept(result interface{}, header http.Header) ht
mt := ParseMediaType(userAccept)
typeMap := map[string]bool{}
var acceptList []string
if c.autoAccept != AcceptUserOnly {
if c.acceptFlag != AcceptUserOnly {
index := len(c.converters)
for index > 0 {
index--
Expand All @@ -297,14 +330,14 @@ func (c *defaultRestClient) addAccept(result interface{}, header http.Header) ht
if _, have := typeMap[mtStr]; !have {
acceptList = append(acceptList, mtStr)
typeMap[mtStr] = true
if c.autoAccept == AcceptAutoFirst {
if c.acceptFlag == AcceptAutoFirst {
break
}
}
}
}
}
if c.autoAccept == AcceptAutoFirst && len(typeMap) > 0 {
if c.acceptFlag == AcceptAutoFirst && len(typeMap) > 0 {
break
}
}
Expand Down
58 changes: 58 additions & 0 deletions v2/errors.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
// Copyright (C) 2019-2021, Xiongfa Li.
// @author xiongfa.li
// @version V1.0
// Description:

package restclient

import (
"fmt"
"net/http"
)

var DefaultErrorStatus = http.StatusBadRequest

type Error interface {
error

// 获得http status code
StatusCode() int

// 获得原始error
Origin() error
}

type defaultError struct {
status int
err error
}

func withErr(status int, err error) defaultError {
return defaultError{
status: status,
err: err,
}
}

func withStatus(status int) defaultError {
return defaultError{
status: status,
err: fmt.Errorf("restclient status: [%d] %s", status, http.StatusText(status)),
}
}

func (e defaultError) Origin() error {
return e.err
}

func (e defaultError) Error() string {
if e.err != nil {
return e.err.Error()
} else {
return ""
}
}

func (e defaultError) StatusCode() int {
return e.status
}
9 changes: 8 additions & 1 deletion v2/init_opts.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,14 @@ func SetRoundTripper(tripper http.RoundTripper) func(client *defaultRestClient)
// 配置是否自动添加accept
func SetAutoAccept(v AcceptFlag) func(client *defaultRestClient) {
return func(client *defaultRestClient) {
client.autoAccept = v
client.acceptFlag = v
}
}

// 配置是否自动添加accept
func SetResponseBodyFlag(v ResponseBodyFlag) func(client *defaultRestClient) {
return func(client *defaultRestClient) {
client.respFlag = v
}
}

Expand Down
6 changes: 5 additions & 1 deletion v2/media_type.go
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,10 @@ func (t *MediaType) isWildcardInnerSub() bool {
return false
}

func (t *MediaType) subEqual(o MediaType) bool {
return strings.Index(o.sub, t.sub) == 0
}

func (t *MediaType) Includes(o MediaType) bool {
if t.IsWildcard() {
return true
Expand All @@ -92,7 +96,7 @@ func (t *MediaType) Includes(o MediaType) bool {
return true
}

if t.sub == o.sub {
if t.subEqual(o) {
return true
}

Expand Down
2 changes: 1 addition & 1 deletion v2/restclient.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,5 +11,5 @@ type RestClient interface {
// 发起请求
// url:请求路径
// params:请求参数,见ex_params.go具体定义
Exchange(url string, opts ...request.Opt) error
Exchange(url string, opts ...request.Opt) Error
}
Loading

0 comments on commit 831bbb3

Please sign in to comment.