From 3762ca917de73972ff42448be2db84125163225c Mon Sep 17 00:00:00 2001 From: siyul-park Date: Sat, 4 Jan 2025 10:43:07 +0900 Subject: [PATCH] feat: support buffer type --- examples/httpproxy.yaml | 4 +- examples/loopback.yaml | 13 +++ examples/system.yaml | 6 +- ext/README.md | 1 - ext/README_kr.md | 1 - ext/docs/proxy_node.md | 52 --------- ext/docs/proxy_node_kr.md | 52 --------- ext/pkg/mime/compression.go | 81 ++++++++++++-- ext/pkg/mime/encoding.go | 130 +++++++++++++---------- ext/pkg/mime/encoding_test.go | 15 ++- ext/pkg/mime/type.go | 10 +- ext/pkg/mime/type_test.go | 12 ++- ext/pkg/mime/writer.go | 11 -- ext/pkg/network/builder.go | 1 - ext/pkg/network/builder_test.go | 2 +- ext/pkg/network/http.go | 76 +++++++++----- ext/pkg/network/http_test.go | 112 +++++++++++++++----- ext/pkg/network/listener.go | 15 +-- ext/pkg/network/proxy.go | 133 ----------------------- ext/pkg/network/proxy_test.go | 128 ---------------------- ext/pkg/network/websocket.go | 34 +++--- ext/pkg/network/websocket_test.go | 1 + pkg/resource/store.go | 4 +- pkg/types/binary.go | 2 +- pkg/types/binary_test.go | 1 - pkg/types/boolean.go | 2 +- pkg/types/buffer.go | 169 ++++++++++++++++++++++++++++++ pkg/types/buffer_test.go | 145 +++++++++++++++++++++++++ pkg/types/compare.go | 18 ---- pkg/types/encoding.go | 2 + pkg/types/encoding_test.go | 11 +- pkg/types/error.go | 2 +- pkg/types/float.go | 2 +- pkg/types/getter.go | 41 -------- pkg/types/integer.go | 2 +- pkg/types/map.go | 6 +- pkg/types/slice.go | 4 +- pkg/types/string.go | 2 +- pkg/types/uinteger.go | 2 +- pkg/types/value.go | 72 ++++++++++++- 40 files changed, 759 insertions(+), 618 deletions(-) create mode 100644 examples/loopback.yaml delete mode 100644 ext/docs/proxy_node.md delete mode 100644 ext/docs/proxy_node_kr.md delete mode 100644 ext/pkg/mime/writer.go delete mode 100644 ext/pkg/network/proxy.go delete mode 100644 ext/pkg/network/proxy_test.go create mode 100644 pkg/types/buffer.go create mode 100644 pkg/types/buffer_test.go delete mode 100644 pkg/types/compare.go delete mode 100644 pkg/types/getter.go diff --git a/examples/httpproxy.yaml b/examples/httpproxy.yaml index af3ab61a..c1f323d3 100644 --- a/examples/httpproxy.yaml +++ b/examples/httpproxy.yaml @@ -7,6 +7,6 @@ - name: proxy port: in -- kind: proxy +- kind: http name: proxy - urls: [https://www.google.com/] + url: https://echo.free.beeceptor.com/ diff --git a/examples/loopback.yaml b/examples/loopback.yaml new file mode 100644 index 00000000..bbdc9253 --- /dev/null +++ b/examples/loopback.yaml @@ -0,0 +1,13 @@ +- kind: listener + name: listener + protocol: http + port: '{{ .PORT }}' + ports: + out: + - name: loopback + port: in + +- kind: snippet + name: loopback + language: cel + code: self diff --git a/examples/system.yaml b/examples/system.yaml index 11dd85d8..a8659c29 100644 --- a/examples/system.yaml +++ b/examples/system.yaml @@ -154,7 +154,7 @@ - kind: if name: specs_read_or_watch - when: '!has(self.query.watch) || !self.query.watch.all(x, x == "true") || !has(self.header.Connection) || !has(self.header.Upgrade)' + when: '!has(self.header.Connection) || !has(self.header.Upgrade)' ports: out[0]: - name: specs_read_with_query @@ -357,7 +357,7 @@ - kind: if name: values_read_or_watch - when: '!has(self.query.watch) || !self.query.watch.all(x, x == "true") || !has(self.header.Connection) || !has(self.header.Upgrade)' + when: '!has(self.header.Connection) || !has(self.header.Upgrade)' ports: out[0]: - name: values_read_with_query @@ -557,7 +557,7 @@ - kind: if name: charts_read_or_watch - when: '!has(self.query.watch) || !self.query.watch.all(x, x == "true") || !has(self.header.Connection) || !has(self.header.Upgrade)' + when: '!has(self.header.Connection) || !has(self.header.Upgrade)' ports: out[0]: - name: charts_read_with_query diff --git a/ext/README.md b/ext/README.md index e00cd276..301242d5 100644 --- a/ext/README.md +++ b/ext/README.md @@ -45,7 +45,6 @@ Facilitates smooth execution of network-related tasks across various protocols. - **[WebSocket Node](./docs/websocket_node.md)**: Establishes WebSocket connections and handles message sending and receiving. - **[Gateway Node](./docs/gateway_node.md)**: Upgrades HTTP connections to WebSocket for real-time data communication. - **[Listener Node](./docs/listener_node.md)**: Receives network requests on specified protocols and ports. -- **[Proxy Node](./docs/proxy_node.md)**: Proxies HTTP requests to other servers and returns their responses. - **[Router Node](./docs/router_node.md)**: Routes input packets to multiple output ports based on conditions. ### **System** diff --git a/ext/README_kr.md b/ext/README_kr.md index c8bee290..1711bf36 100644 --- a/ext/README_kr.md +++ b/ext/README_kr.md @@ -42,7 +42,6 @@ - **[WebSocket 노드](./docs/websocket_node_kr.md)**: WebSocket 연결을 설정하고 메시지를 송수신합니다. - **[Gateway 노드](./docs/gateway_node_kr.md)**: HTTP 연결을 WebSocket으로 업그레이드하여 실시간 데이터 통신을 지원합니다. - **[Listener 노드](./docs/listener_node_kr.md)**: 지정된 프로토콜과 포트에서 네트워크 요청을 수신합니다. -- **[Proxy 노드](./docs/proxy_node_kr.md)**: HTTP 요청을 다른 서버로 프록시하여 응답을 반환합니다. - **[Router 노드](./docs/router_node_kr.md)**: 입력 패킷을 조건에 따라 여러 출력 포트로 라우팅합니다. ### **System** diff --git a/ext/docs/proxy_node.md b/ext/docs/proxy_node.md deleted file mode 100644 index f682882a..00000000 --- a/ext/docs/proxy_node.md +++ /dev/null @@ -1,52 +0,0 @@ -# Proxy Node - -**The Proxy Node** acts as an HTTP proxy, forwarding requests to other servers and returning their responses. It can be used for load balancing, API gateway functionality, or caching. - -## Specification - -- **urls**: A list of target server URLs to proxy requests to. Requests are forwarded to these URLs using a round-robin approach. (Required) - -## Ports - -- **in**: The port that receives HTTP requests. The following fields are included: - - **method**: The HTTP method (e.g., `GET`, `POST`) - - **scheme**: The URL scheme (e.g., `http`, `https`) - - **host**: The request's host (e.g., `example.com`) - - **path**: The request's path (e.g., `/api/v1/resource`) - - **query**: URL query string parameters (e.g., `?key=value`) - - **protocol**: The HTTP protocol version (e.g., `HTTP/1.1`) - - **header**: HTTP headers (e.g., `Content-Type: application/json`) - - **body**: The request body (e.g., JSON, XML, text) - - **status**: The HTTP status code - -- **out**: The port that returns the response from the proxied server. The following fields are included: - - **method**: The HTTP method (e.g., `GET`, `POST`) - - **scheme**: The URL scheme (e.g., `http`, `https`) - - **host**: The request's host - - **path**: The request's path - - **query**: URL query string parameters - - **protocol**: The HTTP protocol version (e.g., `HTTP/1.1`) - - **header**: The HTTP headers of the response - - **body**: The response body - - **status**: The HTTP status code of the response - -- **error**: The port that returns any errors encountered during the request (e.g., network failure, invalid URL). - -## Example - -```yaml -- kind: listener - name: listener - protocol: http - port: 8000 - ports: - out: - - name: proxy - port: in - -- kind: proxy - name: proxy - urls: - - https://backend1.com/ - - https://backend2.com/ -``` diff --git a/ext/docs/proxy_node_kr.md b/ext/docs/proxy_node_kr.md deleted file mode 100644 index e962465e..00000000 --- a/ext/docs/proxy_node_kr.md +++ /dev/null @@ -1,52 +0,0 @@ -# Proxy 노드 - -**Proxy 노드**는 HTTP 요청을 다른 서버로 프록시하여 중계하고, 그 응답을 반환하는 기능을 제공합니다. - -## 명세 - -- **urls**: 프록시할 대상 서버의 URL 목록을 지정합니다. 요청은 이 목록에서 라운드 로빈 방식으로 선택된 서버로 전달됩니다. - -## 포트 - -- **in**: HTTP 요청을 수신하는 포트입니다. 다음 필드를 포함합니다: - - **method**: HTTP 메서드 (예: `GET`, `POST`) - - **scheme**: URL의 스킴 (예: `http`, `https`) - - **host**: 요청의 호스트 - - **path**: 요청의 경로 - - **query**: URL 쿼리 문자열 파라미터 - - **protocol**: HTTP 프로토콜 버전 (예: `HTTP/1.1`) - - **header**: HTTP 헤더 - - **body**: 요청 본문 - - **status**: HTTP 상태 코드 - -- **out**: 프록시된 서버의 응답을 반환하는 포트입니다. 다음 필드를 포함합니다: - - **method**: HTTP 메서드 (예: `GET`, `POST`) - - **scheme**: URL의 스킴 (예: `http`, `https`) - - **host**: 요청의 호스트 - - **path**: 요청의 경로 - - **query**: URL 쿼리 문자열 파라미터 - - **protocol**: HTTP 프로토콜 버전 (예: `HTTP/1.1`) - - **header**: HTTP 헤더 - - **body**: 요청 본문 - - **status**: HTTP 상태 코드 - -- **error**: 오류가 발생했을 때 에러를 반환하는 포트입니다. (예: 네트워크 장애, 잘못된 URL) - -## 예시 - -```yaml -- kind: listener - name: listener - protocol: http - port: 8000 - ports: - out: - - name: proxy - port: in - -- kind: proxy - name: proxy - urls: - - https://backend1.com/ - - https://backend2.com/ -``` diff --git a/ext/pkg/mime/compression.go b/ext/pkg/mime/compression.go index 7a169cd1..08f91fa0 100644 --- a/ext/pkg/mime/compression.go +++ b/ext/pkg/mime/compression.go @@ -8,6 +8,14 @@ import ( "github.com/andybalholm/brotli" ) +type multiWriter struct { + pipe []io.Writer +} + +type multiReader struct { + pipe []io.Reader +} + const ( EncodingGzip = "gzip" EncodingDeflate = "deflate" @@ -15,15 +23,20 @@ const ( EncodingIdentity = "identity" ) +var _ io.Writer = (*multiWriter)(nil) +var _ io.Closer = (*multiWriter)(nil) +var _ io.Reader = (*multiReader)(nil) +var _ io.Closer = (*multiReader)(nil) + // Compress compresses input data using the specified encoding, returns original if unsupported. func Compress(writer io.Writer, encoding string) (io.Writer, error) { switch encoding { case EncodingGzip: - return gzip.NewWriter(writer), nil + return newMultiWriter(writer, gzip.NewWriter(writer)), nil case EncodingDeflate: - return zlib.NewWriter(writer), nil + return newMultiWriter(writer, zlib.NewWriter(writer)), nil case EncodingBr: - return brotli.NewWriter(writer), nil + return newMultiWriter(writer, brotli.NewWriter(writer)), nil default: return writer, nil } @@ -33,12 +46,68 @@ func Compress(writer io.Writer, encoding string) (io.Writer, error) { func Decompress(reader io.Reader, encoding string) (io.Reader, error) { switch encoding { case EncodingGzip: - return gzip.NewReader(reader) + r, err := gzip.NewReader(reader) + if err != nil { + return nil, err + } + return newMultiReader(reader, r), nil case EncodingDeflate: - return zlib.NewReader(reader) + r, err := zlib.NewReader(reader) + if err != nil { + return nil, err + } + return newMultiReader(reader, r), nil case EncodingBr: - return brotli.NewReader(reader), nil + return newMultiReader(reader, brotli.NewReader(reader)), nil default: return reader, nil } } + +// newMultiWriter creates a writer that writes to multiple writers in sequence. +func newMultiWriter(pipe ...io.Writer) io.Writer { + return &multiWriter{pipe: pipe} +} + +// newMultiReader creates a reader that reads from multiple readers in sequence. +func newMultiReader(pipe ...io.Reader) io.Reader { + return &multiReader{pipe: pipe} +} + +func (w *multiWriter) Write(p []byte) (n int, err error) { + if len(w.pipe) == 0 { + return 0, io.ErrClosedPipe + } + return w.pipe[len(w.pipe)-1].Write(p) +} + +func (w *multiWriter) Close() error { + for i := len(w.pipe) - 1; i >= 0; i-- { + if c, ok := w.pipe[i].(io.Closer); ok { + if err := c.Close(); err != nil { + return err + } + } + } + w.pipe = nil + return nil +} + +func (r *multiReader) Read(p []byte) (n int, err error) { + if len(r.pipe) == 0 { + return 0, io.ErrClosedPipe + } + return r.pipe[len(r.pipe)-1].Read(p) +} + +func (r *multiReader) Close() error { + for i := len(r.pipe) - 1; i >= 0; i-- { + if c, ok := r.pipe[i].(io.Closer); ok { + if err := c.Close(); err != nil { + return err + } + } + } + r.pipe = nil + return nil +} diff --git a/ext/pkg/mime/encoding.go b/ext/pkg/mime/encoding.go index 7b621838..75a3caf9 100644 --- a/ext/pkg/mime/encoding.go +++ b/ext/pkg/mime/encoding.go @@ -18,12 +18,20 @@ import ( "github.com/siyul-park/uniflow/pkg/types" ) +type byteCounter struct { + io.Writer + count int +} + var ( keyValues = types.NewString("values") keyFiles = types.NewString("files") ) -var quoteEscaper = strings.NewReplacer("\\", "\\\\", `"`, "\\\"") +var escapeQuotes = strings.NewReplacer("\\", "\\\\", `"`, "\\\"") + +var _ io.Writer = (*byteCounter)(nil) +var _ io.Closer = (*byteCounter)(nil) // Encode encodes the given types into the writer with the specified MIME headers. func Encode(writer io.Writer, value types.Value, header textproto.MIMEHeader) error { @@ -41,36 +49,39 @@ func Encode(writer io.Writer, value types.Value, header textproto.MIMEHeader) er } } - typ, params, err := mime.ParseMediaType(typ) + counter := &byteCounter{Writer: writer} + defer header.Set(HeaderContentLength, strconv.Itoa(counter.Count())) + + w, err := Compress(counter, encode) if err != nil { return err } + if c, ok := w.(io.Closer); ok { + defer c.Close() + } - count := 0 - var cwriter io.Writer = WriterFunc(func(p []byte) (n int, err error) { - n, err = writer.Write(p) - count += n - return - }) + if v, ok := value.(types.Buffer); ok { + if _, err := io.Copy(w, v); err != nil { + return err + } + return nil + } else if v, ok := value.(types.Binary); ok { + if _, err := w.Write(v.Bytes()); err != nil { + return err + } + return nil + } - w, err := Compress(cwriter, encode) + typ, params, err := mime.ParseMediaType(typ) if err != nil { return err } - flush := func() { - if c, ok := w.(io.Closer); ok && w != cwriter { - c.Close() - } - header.Set(HeaderContentLength, strconv.Itoa(count)) - } - switch typ { case ApplicationJSON: if err := json.NewEncoder(w).Encode(types.InterfaceOf(value)); err != nil { return err } - flush() return nil case ApplicationFormURLEncoded: urlValues := url.Values{} @@ -80,7 +91,14 @@ func Encode(writer io.Writer, value types.Value, header textproto.MIMEHeader) er if _, err := w.Write([]byte(urlValues.Encode())); err != nil { return err } - flush() + return nil + case TextPlain: + if v, ok := value.(types.String); ok { + if _, err := w.Write([]byte(v.String())); err != nil { + return err + } + return nil + } return nil case MultipartFormData: boundary := params["boundary"] @@ -95,7 +113,7 @@ func Encode(writer io.Writer, value types.Value, header textproto.MIMEHeader) er return err } - writeField := func(obj types.Map, key types.Value) error { + writeFormField := func(obj types.Map, key types.Value) error { if key, ok := key.(types.String); ok { value := obj.Get(key) @@ -108,7 +126,7 @@ func Encode(writer io.Writer, value types.Value, header textproto.MIMEHeader) er for _, element := range elements.Range() { h := textproto.MIMEHeader{} - h.Set(HeaderContentDisposition, fmt.Sprintf(`form-data; name="%s"`, quoteEscaper.Replace(key.String()))) + h.Set(HeaderContentDisposition, fmt.Sprintf(`form-data; name="%s"`, escapeQuotes.Replace(key.String()))) if w, err := mw.CreatePart(h); err != nil { return err @@ -120,10 +138,10 @@ func Encode(writer io.Writer, value types.Value, header textproto.MIMEHeader) er return nil } - writeFields := func(value types.Value) error { + writeFormFields := func(value types.Value) error { if value, ok := value.(types.Map); ok { for key := range value.Range() { - if err := writeField(value, key); err != nil { + if err := writeFormField(value, key); err != nil { return err } } @@ -131,7 +149,7 @@ func Encode(writer io.Writer, value types.Value, header textproto.MIMEHeader) er return nil } - writeFiles := func(value types.Value) error { + writeFormFiles := func(value types.Value) error { if value, ok := value.(types.Map); ok { for key := range value.Range() { if key, ok := key.(types.String); ok { @@ -181,7 +199,7 @@ func Encode(writer io.Writer, value types.Value, header textproto.MIMEHeader) er } } - h.Set(HeaderContentDisposition, fmt.Sprintf(`form-data; name="%s"; filename="%s"`, quoteEscaper.Replace(key.String()), quoteEscaper.Replace(filename))) + h.Set(HeaderContentDisposition, fmt.Sprintf(`form-data; name="%s"; filename="%s"`, escapeQuotes.Replace(key.String()), escapeQuotes.Replace(filename))) if writer, err := mw.CreatePart(h); err != nil { return err @@ -198,14 +216,14 @@ func Encode(writer io.Writer, value types.Value, header textproto.MIMEHeader) er if v, ok := value.(types.Map); ok { for key, value := range v.Range() { if key.Equal(keyValues) { - if err := writeFields(value); err != nil { + if err := writeFormFields(value); err != nil { return err } } else if key.Equal(keyFiles) { - if err := writeFiles(value); err != nil { + if err := writeFormFiles(value); err != nil { return err } - } else if err := writeField(v, key); err != nil { + } else if err := writeFormField(v, key); err != nil { return err } } @@ -214,29 +232,13 @@ func Encode(writer io.Writer, value types.Value, header textproto.MIMEHeader) er if err := mw.Close(); err != nil { return err } - flush() return nil } - switch v := value.(type) { - case types.Binary: - if _, err := w.Write(v.Bytes()); err != nil { - return err - } - flush() - return nil - case types.String: - if _, err := w.Write([]byte(v.String())); err != nil { - return err - } - flush() - return nil - default: - return errors.WithStack(encoding.ErrUnsupportedType) - } + return errors.WithStack(encoding.ErrUnsupportedType) } -// Decode decodes the given reader with the specified MIME headers into an types. +// Decode decodes the given reader with the specified MIME headers into types. func Decode(reader io.Reader, header textproto.MIMEHeader) (types.Value, error) { typ := header.Get(HeaderContentType) encode := header.Get(HeaderContentEncoding) @@ -245,9 +247,12 @@ func Decode(reader io.Reader, header textproto.MIMEHeader) (types.Value, error) if err != nil { return nil, err } - if c, ok := r.(io.Closer); ok && r != reader { - defer c.Close() - } + + defer func() { + if c, ok := r.(io.Closer); ok { + _ = c.Close() + } + }() if typ == "" { data, err := io.ReadAll(r) @@ -290,9 +295,9 @@ func Decode(reader io.Reader, header textproto.MIMEHeader) (types.Value, error) } return types.NewString(string(data)), nil case MultipartFormData: - reader := multipart.NewReader(r, params["boundary"]) + parts := multipart.NewReader(r, params["boundary"]) - form, err := reader.ReadForm(0) + form, err := parts.ReadForm(0) if err != nil { return nil, err } @@ -327,11 +332,9 @@ func Decode(reader io.Reader, header textproto.MIMEHeader) (types.Value, error) }) } - data, err := io.ReadAll(r) - if err != nil { - return nil, err - } - return types.NewBinary(data), nil + buf := types.NewBuffer(r) + r = io.NopCloser(r) + return buf, nil } func randomMultipartBoundary() string { @@ -342,3 +345,20 @@ func randomMultipartBoundary() string { } return fmt.Sprintf("%x", buf[:]) } + +func (w *byteCounter) Write(p []byte) (n int, err error) { + n, err = w.Writer.Write(p) + w.count += n + return n, err +} + +func (w *byteCounter) Close() error { + if c, ok := w.Writer.(io.Closer); ok { + return c.Close() + } + return nil +} + +func (w *byteCounter) Count() int { + return w.count +} diff --git a/ext/pkg/mime/encoding_test.go b/ext/pkg/mime/encoding_test.go index 10ab93a7..b2e49c44 100644 --- a/ext/pkg/mime/encoding_test.go +++ b/ext/pkg/mime/encoding_test.go @@ -171,7 +171,7 @@ func TestDecode(t *testing.T) { { whenValue: []byte("--MyBoundary\r\n" + "Content-Disposition: form-data; name=\"test\"; filename=\"test\"\r\n" + - "Content-Type: application/octet-stream\r\n" + + "Content-Type: text/plain\r\n" + "\r\n" + "test\r\n" + "--MyBoundary\r\n" + @@ -186,11 +186,11 @@ func TestDecode(t *testing.T) { ), types.NewString("files"), types.NewMap( types.NewString("test"), types.NewSlice(types.NewMap( - types.NewString("data"), types.NewBinary([]byte("test")), + types.NewString("data"), types.NewString("test"), types.NewString("filename"), types.NewString("test"), types.NewString("header"), types.NewMap( types.NewString("Content-Disposition"), types.NewSlice(types.NewString("form-data; name=\"test\"; filename=\"test\"")), - types.NewString("Content-Type"), types.NewSlice(types.NewString("application/octet-stream")), + types.NewString("Content-Type"), types.NewSlice(types.NewString("text/plain")), ), types.NewString("size"), types.NewInt64(4), )), @@ -200,7 +200,7 @@ func TestDecode(t *testing.T) { { whenValue: []byte("testtesttest"), whenType: ApplicationOctetStream, - expect: types.NewBinary([]byte("testtesttest")), + expect: types.NewBuffer(bytes.NewBuffer([]byte("testtesttest"))), }, } @@ -210,7 +210,12 @@ func TestDecode(t *testing.T) { HeaderContentType: []string{tt.whenType}, }) assert.NoError(t, err) - assert.Equal(t, tt.expect.Interface(), decode.Interface()) + + var expect any + var actual any + _ = types.Unmarshal(tt.expect, &expect) + _ = types.Unmarshal(decode, &actual) + assert.Equal(t, expect, actual) }) } } diff --git a/ext/pkg/mime/type.go b/ext/pkg/mime/type.go index 0314d857..bdb319e4 100644 --- a/ext/pkg/mime/type.go +++ b/ext/pkg/mime/type.go @@ -55,16 +55,16 @@ func DetectTypesFromBytes(value []byte) []string { // DetectTypesFromValue determines the content types based on the type of types passed. func DetectTypesFromValue(value types.Value) []string { switch value.(type) { - case types.Binary: + case types.Binary, types.Buffer: return []string{ApplicationOctetStream} case types.String: - return []string{TextPlainCharsetUTF8, ApplicationJSONCharsetUTF8} + return []string{TextPlainCharsetUTF8, ApplicationOctetStream, ApplicationJSONCharsetUTF8, ApplicationXMLCharsetUTF8, ApplicationFormURLEncoded, MultipartFormData} case types.Slice: - return []string{ApplicationJSONCharsetUTF8} + return []string{ApplicationJSONCharsetUTF8, ApplicationXMLCharsetUTF8, ApplicationFormURLEncoded} case types.Map, types.Error: - return []string{ApplicationJSONCharsetUTF8, ApplicationFormURLEncoded, MultipartFormData} + return []string{ApplicationJSONCharsetUTF8, ApplicationXMLCharsetUTF8, ApplicationFormURLEncoded, MultipartFormData} default: - return []string{ApplicationJSONCharsetUTF8} + return []string{ApplicationJSONCharsetUTF8, ApplicationXMLCharsetUTF8, TextPlainCharsetUTF8, ApplicationOctetStream} } } diff --git a/ext/pkg/mime/type_test.go b/ext/pkg/mime/type_test.go index 6dcacf2c..137a0777 100644 --- a/ext/pkg/mime/type_test.go +++ b/ext/pkg/mime/type_test.go @@ -53,21 +53,25 @@ func TestDetectTypesFromValue(t *testing.T) { when: types.NewBinary(nil), expect: []string{ApplicationOctetStream}, }, + { + when: types.NewBuffer(nil), + expect: []string{ApplicationOctetStream}, + }, { when: types.NewString(""), - expect: []string{TextPlainCharsetUTF8, ApplicationJSONCharsetUTF8}, + expect: []string{TextPlainCharsetUTF8, ApplicationOctetStream, ApplicationJSONCharsetUTF8, ApplicationXMLCharsetUTF8, ApplicationFormURLEncoded, MultipartFormData}, }, { when: types.NewSlice(), - expect: []string{ApplicationJSONCharsetUTF8}, + expect: []string{ApplicationJSONCharsetUTF8, ApplicationXMLCharsetUTF8, ApplicationFormURLEncoded}, }, { when: types.NewMap(), - expect: []string{ApplicationJSONCharsetUTF8, ApplicationFormURLEncoded, MultipartFormData}, + expect: []string{ApplicationJSONCharsetUTF8, ApplicationXMLCharsetUTF8, ApplicationFormURLEncoded, MultipartFormData}, }, { when: types.NewError(nil), - expect: []string{ApplicationJSONCharsetUTF8, ApplicationFormURLEncoded, MultipartFormData}, + expect: []string{ApplicationJSONCharsetUTF8, ApplicationXMLCharsetUTF8, ApplicationFormURLEncoded, MultipartFormData}, }, } diff --git a/ext/pkg/mime/writer.go b/ext/pkg/mime/writer.go deleted file mode 100644 index e9ccc82c..00000000 --- a/ext/pkg/mime/writer.go +++ /dev/null @@ -1,11 +0,0 @@ -package mime - -import "io" - -type WriterFunc func([]byte) (int, error) - -var _ io.Writer = WriterFunc(nil) - -func (f WriterFunc) Write(p []byte) (int, error) { - return f(p) -} diff --git a/ext/pkg/network/builder.go b/ext/pkg/network/builder.go index 328594ea..83052ca2 100644 --- a/ext/pkg/network/builder.go +++ b/ext/pkg/network/builder.go @@ -39,7 +39,6 @@ func AddToScheme() scheme.Register { }{ {KindHTTP, NewHTTPNodeCodec(), &HTTPNodeSpec{}}, {KindListener, NewListenNodeCodec(), &ListenNodeSpec{}}, - {KindProxy, NewProxyNodeCodec(), &ProxyNodeSpec{}}, {KindRouter, NewRouteNodeCodec(), &RouteNodeSpec{}}, {KindWebSocket, NewWebSocketNodeCodec(), &WebSocketNodeSpec{}}, {KindGateway, NewGatewayNodeCodec(), &GatewayNodeSpec{}}, diff --git a/ext/pkg/network/builder_test.go b/ext/pkg/network/builder_test.go index ff77d6dc..dc3be7e7 100644 --- a/ext/pkg/network/builder_test.go +++ b/ext/pkg/network/builder_test.go @@ -42,7 +42,7 @@ func TestAddToScheme(t *testing.T) { err := AddToScheme().AddToScheme(s) assert.NoError(t, err) - tests := []string{KindHTTP, KindListener, KindProxy, KindRouter, KindWebSocket, KindGateway} + tests := []string{KindHTTP, KindListener, KindRouter, KindWebSocket, KindGateway} for _, tt := range tests { t.Run(tt, func(t *testing.T) { diff --git a/ext/pkg/network/http.go b/ext/pkg/network/http.go index 9d273a3b..d0f3c612 100644 --- a/ext/pkg/network/http.go +++ b/ext/pkg/network/http.go @@ -1,8 +1,8 @@ package network import ( - "bytes" "context" + "io" "net/http" "net/textproto" "net/url" @@ -58,26 +58,37 @@ func NewHTTPNodeCodec() scheme.Codec { return nil, err } - n := NewHTTPNode(parse) + transport := &http.Transport{} + if err := http2.ConfigureTransport(transport); err != nil { + return nil, err + } + client := &http.Client{Transport: transport} + + n := NewHTTPNode(client) + n.SetURL(parse) n.SetTimeout(spec.Timeout) return n, nil }) } // NewHTTPNode creates a new HTTPNode instance. -func NewHTTPNode(url *url.URL) *HTTPNode { - transport := &http.Transport{} - _ = http2.ConfigureTransport(transport) - - client := &http.Client{ - Transport: transport, +func NewHTTPNode(client *http.Client) *HTTPNode { + if client == nil { + client = http.DefaultClient } - - n := &HTTPNode{client: client, url: url} + n := &HTTPNode{client: client, url: &url.URL{}} n.OneToOneNode = node.NewOneToOneNode(n.action) return n } +// SetURL sets the URL for the HTTP request. +func (n *HTTPNode) SetURL(url *url.URL) { + n.mu.Lock() + defer n.mu.Unlock() + + n.url = url +} + // SetTimeout sets the timeout duration for the HTTP request. func (n *HTTPNode) SetTimeout(timeout time.Duration) { n.mu.Lock() @@ -127,35 +138,52 @@ func (n *HTTPNode) action(proc *process.Process, inPck *packet.Packet) (*packet. } } - buf := bytes.NewBuffer(nil) - if err := mime.Encode(buf, req.Body, textproto.MIMEHeader(req.Header)); err != nil { - return nil, packet.New(types.NewError(err)) + header := textproto.MIMEHeader{} + for k, v := range req.Header { + header[k] = v } - u := &url.URL{ - Scheme: req.Scheme, - Host: req.Host, - Path: req.Path, - RawQuery: req.Query.Encode(), + pr, pw := io.Pipe() + + errors := make(chan error) + go func() { + defer close(errors) + if err := mime.Encode(pw, req.Body, header); err != nil { + errors <- err + } + }() + + r := &http.Request{ + Method: req.Method, + URL: &url.URL{ + Scheme: req.Scheme, + Host: req.Host, + Path: req.Path, + RawQuery: req.Query.Encode(), + }, + Header: req.Header, + Body: pr, } - r, err := http.NewRequest(req.Method, u.String(), buf) + w, err := n.client.Do(r.WithContext(ctx)) if err != nil { return nil, packet.New(types.NewError(err)) } - r = r.WithContext(ctx) - - w, err := n.client.Do(r) - if err != nil { + if err := <-errors; err != nil { return nil, packet.New(types.NewError(err)) } - defer w.Body.Close() body, err := mime.Decode(w.Body, textproto.MIMEHeader(w.Header)) if err != nil { return nil, packet.New(types.NewError(err)) } + if b, ok := body.(types.Buffer); ok { + proc.AddExitHook(process.ExitFunc(func(err error) { + _ = b.Close() + })) + } + res := &HTTPPayload{ Method: req.Method, Scheme: req.Scheme, diff --git a/ext/pkg/network/http_test.go b/ext/pkg/network/http_test.go index 45ad1f58..b46aa42c 100644 --- a/ext/pkg/network/http_test.go +++ b/ext/pkg/network/http_test.go @@ -2,12 +2,15 @@ package network import ( "context" + "crypto/tls" "net/http" "net/http/httptest" "net/url" "testing" "time" + "golang.org/x/net/http2" + "github.com/siyul-park/uniflow/pkg/node" "github.com/siyul-park/uniflow/pkg/packet" "github.com/siyul-park/uniflow/pkg/port" @@ -34,46 +37,102 @@ func TestHTTPNodeCodec_Compile(t *testing.T) { } func TestNewHTTPNode(t *testing.T) { - n := NewHTTPNode(&url.URL{}) + n := NewHTTPNode(nil) assert.NotNil(t, n) assert.NoError(t, n.Close()) } func TestHTTPNode_SendAndReceive(t *testing.T) { - ctx, cancel := context.WithTimeout(context.TODO(), time.Second) - defer cancel() + t.Run("HTTP/1.1", func(t *testing.T) { + ctx, cancel := context.WithTimeout(context.TODO(), time.Second) + defer cancel() - s := httptest.NewServer(http.HandlerFunc(func(writer http.ResponseWriter, request *http.Request) { - })) - defer s.Close() + s := httptest.NewServer(http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) { + })) + defer s.Close() - u, _ := url.Parse(s.URL) + u, _ := url.Parse(s.URL) - n := NewHTTPNode(u) - defer n.Close() + n := NewHTTPNode(nil) + defer n.Close() - n.SetTimeout(time.Second) + n.SetURL(u) + n.SetTimeout(time.Second) - in := port.NewOut() - in.Link(n.In(node.PortIn)) + in := port.NewOut() + in.Link(n.In(node.PortIn)) - proc := process.New() - defer proc.Exit(nil) + proc := process.New() + defer proc.Exit(nil) - inWriter := in.Open(proc) + inWriter := in.Open(proc) - var inPayload types.Value - inPck := packet.New(inPayload) + var inPayload types.Value + inPck := packet.New(inPayload) - inWriter.Write(inPck) + inWriter.Write(inPck) - select { - case outPck := <-inWriter.Receive(): - _, ok := outPck.Payload().(types.Error) - assert.False(t, ok) - case <-ctx.Done(): - assert.Fail(t, ctx.Err().Error()) - } + select { + case outPck := <-inWriter.Receive(): + _, ok := outPck.Payload().(types.Error) + assert.False(t, ok) + case <-ctx.Done(): + assert.Fail(t, ctx.Err().Error()) + } + }) + + t.Run("HTTP/2", func(t *testing.T) { + ctx, cancel := context.WithTimeout(context.TODO(), time.Second) + defer cancel() + + s := httptest.NewUnstartedServer(http.HandlerFunc(func(_ http.ResponseWriter, req *http.Request) { + assert.Equal(t, "HTTP/2.0", req.Proto) + })) + _ = http2.ConfigureServer(s.Config, nil) + + s.TLS = &tls.Config{ + NextProtos: []string{"h2"}, + } + + s.StartTLS() + defer s.Close() + + client := &http.Client{ + Transport: &http2.Transport{ + TLSClientConfig: &tls.Config{ + InsecureSkipVerify: true, + }, + }, + } + u, _ := url.Parse(s.URL) + + n := NewHTTPNode(client) + defer n.Close() + + n.SetURL(u) + n.SetTimeout(time.Second) + + in := port.NewOut() + in.Link(n.In(node.PortIn)) + + proc := process.New() + defer proc.Exit(nil) + + inWriter := in.Open(proc) + + var inPayload types.Value + inPck := packet.New(inPayload) + + inWriter.Write(inPck) + + select { + case outPck := <-inWriter.Receive(): + _, ok := outPck.Payload().(types.Error) + assert.False(t, ok) + case <-ctx.Done(): + assert.Fail(t, ctx.Err().Error()) + } + }) } func BenchmarkHTTPNode_SendAndReceive(b *testing.B) { @@ -83,9 +142,10 @@ func BenchmarkHTTPNode_SendAndReceive(b *testing.B) { u, _ := url.Parse(s.URL) - n := NewHTTPNode(u) + n := NewHTTPNode(nil) defer n.Close() + n.SetURL(u) n.SetTimeout(time.Second) in := port.NewOut() diff --git a/ext/pkg/network/listener.go b/ext/pkg/network/listener.go index 9a8a4684..5d49688f 100644 --- a/ext/pkg/network/listener.go +++ b/ext/pkg/network/listener.go @@ -1,10 +1,8 @@ package network import ( - "bytes" "crypto/tls" "fmt" - "io" "net" "net/http" "net/textproto" @@ -261,6 +259,9 @@ func (n *HTTPListenNode) negotiate(req *HTTPPayload, res *HTTPPayload) { accept := req.Header.Get(mime.HeaderAccept) offers := mime.DetectTypesFromValue(res.Body) contentType := mime.Negotiate(accept, offers) + if contentType == "" && len(offers) > 0 { + contentType = offers[0] + } if contentType != "" { res.Header.Set(mime.HeaderContentType, contentType) } @@ -302,17 +303,11 @@ func (n *HTTPListenNode) write(w http.ResponseWriter, res *HTTPPayload) error { } } - buf := bytes.NewBuffer(nil) - if err := mime.Encode(buf, res.Body, textproto.MIMEHeader(h)); err != nil { - return err - } - status := res.Status if status == 0 { status = http.StatusOK } - w.WriteHeader(status) - _, err := io.Copy(w, buf) - return err + w.WriteHeader(status) + return mime.Encode(w, res.Body, textproto.MIMEHeader(h)) } diff --git a/ext/pkg/network/proxy.go b/ext/pkg/network/proxy.go deleted file mode 100644 index b10d8712..00000000 --- a/ext/pkg/network/proxy.go +++ /dev/null @@ -1,133 +0,0 @@ -package network - -import ( - "bytes" - "io" - "net/http" - "net/http/httptest" - "net/http/httputil" - "net/textproto" - "net/url" - "sync" - - "github.com/pkg/errors" - "github.com/siyul-park/uniflow/ext/pkg/mime" - "github.com/siyul-park/uniflow/pkg/encoding" - "github.com/siyul-park/uniflow/pkg/node" - "github.com/siyul-park/uniflow/pkg/packet" - "github.com/siyul-park/uniflow/pkg/process" - "github.com/siyul-park/uniflow/pkg/scheme" - "github.com/siyul-park/uniflow/pkg/spec" - "github.com/siyul-park/uniflow/pkg/types" - "golang.org/x/net/http2" -) - -// ProxyNodeSpec defines the specifications for creating a ProxyNode. -type ProxyNodeSpec struct { - spec.Meta `map:",inline"` - URLs []string `map:"urls" validate:"required,dive,url"` -} - -// ProxyNode represents a Node for handling HTTP proxy. -type ProxyNode struct { - *node.OneToOneNode - proxy *httputil.ReverseProxy -} - -const KindProxy = "proxy" - -// NewProxyNodeCodec creates a new codec for ProxyNode. -func NewProxyNodeCodec() scheme.Codec { - return scheme.CodecWithType(func(spec *ProxyNodeSpec) (node.Node, error) { - urls := make([]*url.URL, 0, len(spec.URLs)) - for _, u := range spec.URLs { - parsed, err := url.Parse(u) - if err != nil { - return nil, err - } - urls = append(urls, parsed) - } - if len(urls) == 0 { - return nil, errors.WithStack(encoding.ErrUnsupportedValue) - } - return NewProxyNode(urls), nil - }) -} - -// NewProxyNode creates a new ProxyNode instance. -func NewProxyNode(urls []*url.URL) *ProxyNode { - var index int - var mu sync.Mutex - - transport := &http.Transport{} - http2.ConfigureTransport(transport) - - proxy := &httputil.ReverseProxy{ - Transport: transport, - Rewrite: func(r *httputil.ProxyRequest) { - mu.Lock() - defer mu.Unlock() - - index = (index + 1) % len(urls) - - r.SetURL(urls[index]) - r.SetXForwarded() - }, - } - - n := &ProxyNode{proxy: proxy} - n.OneToOneNode = node.NewOneToOneNode(n.action) - return n -} - -// action handles the HTTP proxy request and response. -func (n *ProxyNode) action(proc *process.Process, inPck *packet.Packet) (*packet.Packet, *packet.Packet) { - req := &HTTPPayload{} - if err := types.Unmarshal(inPck.Payload(), req); err != nil { - return nil, packet.New(types.NewError(err)) - } - - buf := bytes.NewBuffer(nil) - if err := mime.Encode(buf, req.Body, textproto.MIMEHeader(req.Header)); err != nil { - return nil, packet.New(types.NewError(err)) - } - - r := &http.Request{ - Method: req.Method, - URL: &url.URL{ - Scheme: req.Scheme, - Host: req.Host, - Path: req.Path, - RawQuery: req.Query.Encode(), - }, - Proto: req.Protocol, - Header: req.Header, - Body: io.NopCloser(buf), - } - w := httptest.NewRecorder() - - n.proxy.ServeHTTP(w, r) - - body, err := mime.Decode(w.Body, textproto.MIMEHeader(w.Header())) - if err != nil { - return nil, packet.New(types.NewError(err)) - } - - res := &HTTPPayload{ - Method: req.Method, - Scheme: req.Scheme, - Host: req.Host, - Path: req.Path, - Query: req.Query, - Protocol: req.Protocol, - Header: w.Header(), - Body: body, - Status: w.Code, - } - - outPayload, err := types.Marshal(res) - if err != nil { - return nil, packet.New(types.NewError(err)) - } - return packet.New(outPayload), nil -} diff --git a/ext/pkg/network/proxy_test.go b/ext/pkg/network/proxy_test.go deleted file mode 100644 index 64e9ab76..00000000 --- a/ext/pkg/network/proxy_test.go +++ /dev/null @@ -1,128 +0,0 @@ -package network - -import ( - "context" - "net/http" - "net/http/httptest" - "net/url" - "testing" - "time" - - "github.com/siyul-park/uniflow/pkg/node" - "github.com/siyul-park/uniflow/pkg/packet" - "github.com/siyul-park/uniflow/pkg/port" - "github.com/siyul-park/uniflow/pkg/process" - "github.com/siyul-park/uniflow/pkg/types" - "github.com/stretchr/testify/assert" -) - -func TestProxyNodeCodec_Compile(t *testing.T) { - codec := NewProxyNodeCodec() - - spec := &ProxyNodeSpec{ - URLs: []string{"http://localhost"}, - } - - n, err := codec.Compile(spec) - assert.NoError(t, err) - assert.NotNil(t, n) - assert.NoError(t, n.Close()) -} - -func TestNewProxyNode(t *testing.T) { - u, _ := url.Parse("http://localhost") - n := NewProxyNode([]*url.URL{u}) - assert.NotNil(t, n) - assert.NoError(t, n.Close()) -} - -func TestProxyNode_SendAndReceive(t *testing.T) { - ctx, cancel := context.WithTimeout(context.TODO(), time.Second) - defer cancel() - - s1 := httptest.NewServer(http.HandlerFunc(func(writer http.ResponseWriter, request *http.Request) { - writer.WriteHeader(http.StatusOK) - writer.Write([]byte("Backend 1")) - })) - defer s1.Close() - - s2 := httptest.NewServer(http.HandlerFunc(func(writer http.ResponseWriter, request *http.Request) { - writer.WriteHeader(http.StatusOK) - writer.Write([]byte("Backend 2")) - })) - defer s2.Close() - - u1, _ := url.Parse(s1.URL) - u2, _ := url.Parse(s2.URL) - - n := NewProxyNode([]*url.URL{u1, u2}) - defer n.Close() - - in := port.NewOut() - in.Link(n.In(node.PortIn)) - - proc := process.New() - defer proc.Exit(nil) - - inWriter := in.Open(proc) - - inPayload := types.NewMap( - types.NewString("method"), types.NewString(http.MethodGet), - types.NewString("scheme"), types.NewString("http"), - types.NewString("host"), types.NewString("test"), - types.NewString("path"), types.NewString("/"), - types.NewString("protocol"), types.NewString("HTTP/1.1"), - types.NewString("status"), types.NewInt(0), - ) - inPck := packet.New(inPayload) - - inWriter.Write(inPck) - - select { - case outPck := <-inWriter.Receive(): - payload := &HTTPPayload{} - err := types.Unmarshal(outPck.Payload(), payload) - assert.NoError(t, err) - assert.Contains(t, payload.Body.Interface(), "Backend") - case <-ctx.Done(): - assert.Fail(t, ctx.Err().Error()) - } -} - -func BenchmarkProxyNode_SendAndReceive(b *testing.B) { - s := httptest.NewServer(http.HandlerFunc(func(writer http.ResponseWriter, request *http.Request) { - writer.WriteHeader(http.StatusOK) - writer.Write([]byte("OK")) - })) - defer s.Close() - - u, _ := url.Parse(s.URL) - - n := NewProxyNode([]*url.URL{u}) - defer n.Close() - - in := port.NewOut() - in.Link(n.In(node.PortIn)) - - proc := process.New() - defer proc.Exit(nil) - - inWriter := in.Open(proc) - - inPayload := types.NewMap( - types.NewString("method"), types.NewString(http.MethodGet), - types.NewString("scheme"), types.NewString("http"), - types.NewString("host"), types.NewString("test"), - types.NewString("path"), types.NewString("/"), - types.NewString("protocol"), types.NewString("HTTP/1.1"), - types.NewString("status"), types.NewInt(0), - ) - inPck := packet.New(inPayload) - - b.ResetTimer() - - for i := 0; i < b.N; i++ { - inWriter.Write(inPck) - <-inWriter.Receive() - } -} diff --git a/ext/pkg/network/websocket.go b/ext/pkg/network/websocket.go index 5214dec1..a0a89b85 100644 --- a/ext/pkg/network/websocket.go +++ b/ext/pkg/network/websocket.go @@ -1,7 +1,6 @@ package network import ( - "bytes" "context" "net/http" "net/textproto" @@ -202,31 +201,30 @@ func (n *WebSocketConnNode) consume(proc *process.Process) { } for inPck := range inReader.Read() { - var inPayload *WebSocketPayload - if err := types.Unmarshal(inPck.Payload(), &inPayload); err != nil { + var inPayload WebSocketPayload + _ = types.Unmarshal(inPck.Payload(), &inPayload) + if inPayload.Type == 0 && inPayload.Data == nil { inPayload.Data = inPck.Payload() - if _, ok := inPayload.Data.(types.Binary); !ok { + if kind := types.KindOf(inPayload.Data); kind != types.KindBuffer && kind != types.KindBinary { inPayload.Type = websocket.TextMessage } else { inPayload.Type = websocket.BinaryMessage } } - w := mime.WriterFunc(func(b []byte) (int, error) { - if err := conn.WriteMessage(inPayload.Type, b); err != nil { - return 0, err - } - return len(b), nil - }) - - if err := mime.Encode(w, inPayload.Data, textproto.MIMEHeader{}); err != nil { + w, err := conn.NextWriter(inPayload.Type) + if err != nil { errPck := packet.New(types.NewError(err)) - if errWriter.Write(errPck) > 0 { - <-errWriter.Receive() - } + inReader.Receive(packet.Send(errWriter, errPck)) + continue } - inReader.Receive(packet.None) + if err = mime.Encode(w, inPayload.Data, textproto.MIMEHeader{}); err != nil { + errPck := packet.New(types.NewError(err)) + inReader.Receive(packet.Send(errWriter, errPck)) + } else { + inReader.Receive(packet.None) + } } } @@ -237,7 +235,7 @@ func (n *WebSocketConnNode) produce(proc *process.Process) { } for { - typ, p, err := conn.ReadMessage() + typ, reader, err := conn.NextReader() if err != nil || typ == websocket.CloseMessage { outWriter := n.outPort.Open(proc) @@ -262,7 +260,7 @@ func (n *WebSocketConnNode) produce(proc *process.Process) { child := proc.Fork() outWriter := n.outPort.Open(child) - data, err := mime.Decode(bytes.NewReader(p), textproto.MIMEHeader{}) + data, err := mime.Decode(reader, textproto.MIMEHeader{}) if err != nil { data = types.NewString(err.Error()) } diff --git a/ext/pkg/network/websocket_test.go b/ext/pkg/network/websocket_test.go index a009e337..c5f081ae 100644 --- a/ext/pkg/network/websocket_test.go +++ b/ext/pkg/network/websocket_test.go @@ -95,6 +95,7 @@ func TestWebSocketNode_SendAndReceive(t *testing.T) { } outReader.Receive(packet.None) + select { case <-done: default: diff --git a/pkg/resource/store.go b/pkg/resource/store.go index 84b8bbe1..05c43297 100644 --- a/pkg/resource/store.go +++ b/pkg/resource/store.go @@ -42,8 +42,8 @@ type Stream interface { // Event represents a change event for a Resource. type Event struct { - OP EventOP // Operation type (Store, Swap, Delete) - ID uuid.UUID // ID of the changed Resource + ID uuid.UUID `json:"id" map:"id"` + OP EventOP `json:"op" map:"op"` } // EventOP represents the type of operation that triggered an Event. diff --git a/pkg/types/binary.go b/pkg/types/binary.go index 2ffe9de0..b1476c08 100644 --- a/pkg/types/binary.go +++ b/pkg/types/binary.go @@ -179,7 +179,7 @@ func newBinaryDecoder() encoding2.DecodeCompiler[Value] { } return errors.WithStack(encoding2.ErrUnsupportedType) }), nil - } else if typ.Elem().Kind() == reflect.Interface { + } else if typ.Elem() == types[KindUnknown] { return encoding2.DecodeFunc(func(source Value, target unsafe.Pointer) error { if s, ok := source.(Binary); ok { *(*any)(target) = s.Interface() diff --git a/pkg/types/binary_test.go b/pkg/types/binary_test.go index 014b64d6..5fa897a4 100644 --- a/pkg/types/binary_test.go +++ b/pkg/types/binary_test.go @@ -142,7 +142,6 @@ func TestBinary_Decode(t *testing.T) { d, err := base64.StdEncoding.DecodeString(decoded) assert.NoError(t, err) - assert.Equal(t, source, d) }) diff --git a/pkg/types/boolean.go b/pkg/types/boolean.go index f1925724..79360e91 100644 --- a/pkg/types/boolean.go +++ b/pkg/types/boolean.go @@ -113,7 +113,7 @@ func newBooleanDecoder() encoding.DecodeCompiler[Value] { } return errors.WithStack(encoding.ErrUnsupportedType) }), nil - } else if typ.Elem().Kind() == reflect.Interface { + } else if typ.Elem() == types[KindUnknown] { return encoding.DecodeFunc(func(source Value, target unsafe.Pointer) error { if s, ok := source.(Boolean); ok { *(*any)(target) = s.Interface() diff --git a/pkg/types/buffer.go b/pkg/types/buffer.go new file mode 100644 index 00000000..421c14b1 --- /dev/null +++ b/pkg/types/buffer.go @@ -0,0 +1,169 @@ +package types + +import ( + "encoding/base64" + "io" + "reflect" + "unsafe" + + "github.com/pkg/errors" + "github.com/siyul-park/uniflow/pkg/encoding" +) + +// Buffer is a representation of a io.Reader value. +type Buffer = *_buffer + +type _buffer struct { + value io.Reader +} + +var _ Value = (Buffer)(nil) +var _ io.Reader = (Buffer)(nil) + +// NewBuffer creates a new Buffer instance. +func NewBuffer(value io.Reader) Buffer { + return &_buffer{value: value} +} + +// Read reads data from the buffer into p. +func (b Buffer) Read(p []byte) (n int, err error) { + return b.value.Read(p) +} + +// Bytes returns the raw byte slice. +func (b Buffer) Bytes() ([]byte, error) { + bytes, err := io.ReadAll(b.value) + if err != nil { + return nil, err + } + closer, ok := b.value.(io.Closer) + if ok { + if err := closer.Close(); err != nil { + return nil, err + } + } + return bytes, nil +} + +// Close closes the buffer. +func (b Buffer) Close() error { + if closer, ok := b.value.(io.Closer); ok { + return closer.Close() + } + return nil +} + +// Kind returns the kind of the buffer. +func (b Buffer) Kind() Kind { + return KindBuffer +} + +// Hash returns a hash value for the buffer. +func (b Buffer) Hash() uint64 { + return uint64(uintptr(unsafe.Pointer(b))) +} + +// Interface returns the underlying io.Reader. +func (b Buffer) Interface() any { + return b.value +} + +// Equal checks if the buffer is equal to another Value. +func (b Buffer) Equal(other Value) bool { + if o, ok := other.(Buffer); ok { + return b == o + } + return false +} + +// Compare compares the buffer with another Value. +func (b Buffer) Compare(other Value) int { + if o, ok := other.(Buffer); ok { + return compare(b.Hash(), o.Hash()) + } + return compare(b.Kind(), KindOf(other)) +} + +func newBufferEncoder() encoding.EncodeCompiler[any, Value] { + typeReader := reflect.TypeOf((*io.Reader)(nil)).Elem() + + return encoding.EncodeCompilerFunc[any, Value](func(typ reflect.Type) (encoding.Encoder[any, Value], error) { + if typ == nil { + return nil, errors.WithStack(encoding.ErrUnsupportedType) + } else if typ.ConvertibleTo(typeReader) { + return encoding.EncodeFunc(func(source any) (Value, error) { + s := source.(io.Reader) + return NewBuffer(s), nil + }), nil + } + return nil, errors.WithStack(encoding.ErrUnsupportedType) + }) +} + +func newBufferDecoder() encoding.DecodeCompiler[Value] { + typeReader := reflect.TypeOf((*io.Reader)(nil)).Elem() + + return encoding.DecodeCompilerFunc[Value](func(typ reflect.Type) (encoding.Decoder[Value, unsafe.Pointer], error) { + if typ == nil { + return nil, errors.WithStack(encoding.ErrUnsupportedType) + } else if typ.Kind() == reflect.Pointer { + if typ.Elem().ConvertibleTo(typeReader) { + return encoding.DecodeFunc(func(source Value, target unsafe.Pointer) error { + if s, ok := source.(Buffer); ok { + t := reflect.NewAt(typ.Elem(), target) + t.Elem().Set(reflect.ValueOf(s.Interface())) + return nil + } + return errors.WithStack(encoding.ErrUnsupportedType) + }), nil + } else if typ.Elem().Kind() == reflect.Slice && typ.Elem().Elem().Kind() == reflect.Uint8 { + return encoding.DecodeFunc(func(source Value, target unsafe.Pointer) error { + if s, ok := source.(Buffer); ok { + bytes, err := s.Bytes() + if err != nil { + return err + } + t := reflect.NewAt(typ.Elem(), target).Elem() + t.Set(reflect.AppendSlice(t, reflect.ValueOf(bytes).Convert(t.Type()))) + return nil + } + return errors.WithStack(encoding.ErrUnsupportedType) + }), nil + } else if typ.Elem().Kind() == reflect.Array && typ.Elem().Elem().Kind() == reflect.Uint8 { + return encoding.DecodeFunc(func(source Value, target unsafe.Pointer) error { + if s, ok := source.(Buffer); ok { + bytes, err := s.Bytes() + if err != nil { + return err + } + t := reflect.NewAt(typ.Elem(), target).Elem() + reflect.Copy(t, reflect.ValueOf(bytes).Convert(t.Type())) + return nil + } + return errors.WithStack(encoding.ErrUnsupportedType) + }), nil + } else if typ.Elem().Kind() == reflect.String { + return encoding.DecodeFunc(func(source Value, target unsafe.Pointer) error { + if s, ok := source.(Buffer); ok { + bytes, err := io.ReadAll(s) + if err != nil { + return err + } + *(*string)(target) = base64.StdEncoding.EncodeToString(bytes) + return nil + } + return errors.WithStack(encoding.ErrUnsupportedType) + }), nil + } else if typ.Elem() == types[KindUnknown] { + return encoding.DecodeFunc(func(source Value, target unsafe.Pointer) error { + if s, ok := source.(Buffer); ok { + *(*any)(target) = s.Interface() + return nil + } + return errors.WithStack(encoding.ErrUnsupportedType) + }), nil + } + } + return nil, errors.WithStack(encoding.ErrUnsupportedType) + }) +} diff --git a/pkg/types/buffer_test.go b/pkg/types/buffer_test.go new file mode 100644 index 00000000..346a9879 --- /dev/null +++ b/pkg/types/buffer_test.go @@ -0,0 +1,145 @@ +package types + +import ( + "encoding/base64" + "io" + "strings" + "testing" + + "github.com/siyul-park/uniflow/pkg/encoding" + "github.com/stretchr/testify/assert" +) + +func TestBuffer_Read(t *testing.T) { + r := strings.NewReader("test") + b := NewBuffer(r) + p := make([]byte, 4) + n, err := b.Read(p) + assert.NoError(t, err) + assert.Equal(t, 4, n) + assert.Equal(t, "test", string(p)) +} + +func TestBuffer_Bytes(t *testing.T) { + r := strings.NewReader("test") + b := NewBuffer(r) + p, err := b.Bytes() + assert.NoError(t, err) + assert.Equal(t, "test", string(p)) +} + +func TestBuffer_Close(t *testing.T) { + r := strings.NewReader("test") + b := NewBuffer(r) + assert.NoError(t, b.Close()) +} + +func TestBuffer_Kind(t *testing.T) { + r := strings.NewReader("test") + b := NewBuffer(r) + assert.Equal(t, KindBuffer, b.Kind()) +} + +func TestBuffer_Hash(t *testing.T) { + r1 := strings.NewReader("test1") + r2 := strings.NewReader("test2") + b1 := NewBuffer(r1) + b2 := NewBuffer(r2) + assert.NotEqual(t, b1.Hash(), b2.Hash()) +} + +func TestBuffer_Interface(t *testing.T) { + r := strings.NewReader("test") + b := NewBuffer(r) + assert.Equal(t, r, b.Interface()) +} + +func TestBuffer_Equal(t *testing.T) { + r1 := strings.NewReader("test1") + r2 := strings.NewReader("test2") + b1 := NewBuffer(r1) + b2 := NewBuffer(r2) + assert.True(t, b1.Equal(b1)) + assert.False(t, b1.Equal(b2)) +} + +func TestBuffer_Compare(t *testing.T) { + r1 := strings.NewReader("test1") + r2 := strings.NewReader("test2") + b1 := NewBuffer(r1) + b2 := NewBuffer(r2) + assert.Equal(t, 0, b1.Compare(b1)) + assert.NotEqual(t, 0, b1.Compare(b2)) +} + +func TestBuffer_Encode(t *testing.T) { + enc := encoding.NewEncodeAssembler[any, Value]() + enc.Add(newBufferEncoder()) + + t.Run("io.Reader", func(t *testing.T) { + source := strings.NewReader("test") + v := NewBuffer(source) + + decoded, err := enc.Encode(source) + assert.NoError(t, err) + assert.Equal(t, v, decoded) + }) +} + +func TestBuffer_Decode(t *testing.T) { + dec := encoding.NewDecodeAssembler[Value, any]() + dec.Add(newBufferDecoder()) + + t.Run("io.Reader", func(t *testing.T) { + source := strings.NewReader("test") + v := NewBuffer(source) + + var decoded io.Reader + err := dec.Decode(v, &decoded) + assert.NoError(t, err) + assert.Equal(t, source, decoded) + }) + + t.Run("slice", func(t *testing.T) { + source := strings.NewReader("test") + v := NewBuffer(source) + + var decoded []byte + err := dec.Decode(v, &decoded) + assert.NoError(t, err) + assert.Equal(t, []byte("test"), decoded) + }) + + t.Run("array", func(t *testing.T) { + source := strings.NewReader("test") + v := NewBuffer(source) + + var decoded [3]byte + err := dec.Decode(v, &decoded) + assert.NoError(t, err) + assert.EqualValues(t, []byte("test"), decoded) + }) + + t.Run("string", func(t *testing.T) { + source := strings.NewReader("test") + v := NewBuffer(source) + + var decoded string + err := dec.Decode(v, &decoded) + assert.NoError(t, err) + + d, err := base64.StdEncoding.DecodeString(decoded) + assert.NoError(t, err) + assert.Equal(t, []byte("test"), d) + }) + + t.Run("any", func(t *testing.T) { + source := strings.NewReader("test") + v := NewBuffer(source) + + var decoded any + err := dec.Decode(v, &decoded) + assert.NoError(t, err) + assert.Equal(t, source, decoded) + }) +} diff --git a/pkg/types/compare.go b/pkg/types/compare.go deleted file mode 100644 index 3f76411a..00000000 --- a/pkg/types/compare.go +++ /dev/null @@ -1,18 +0,0 @@ -package types - -type ordered interface { - ~int | ~int8 | ~int16 | ~int32 | ~int64 | ~uint | ~uint8 | ~uint16 | ~uint32 | ~uint64 | ~uintptr | ~float32 | ~float64 | ~string -} - -func compare[T ordered](x, y T) int { - if x == y { - return 0 - } - if x > y { - return 1 - } - if x < y { - return -1 - } - return 0 -} diff --git a/pkg/types/encoding.go b/pkg/types/encoding.go index aabcc440..c1094e08 100644 --- a/pkg/types/encoding.go +++ b/pkg/types/encoding.go @@ -33,6 +33,7 @@ func init() { Encoder.Add(newIntegerEncoder()) Encoder.Add(newFloatEncoder()) Encoder.Add(newBooleanEncoder()) + Encoder.Add(newBufferEncoder()) Encoder.Add(newBinaryEncoder()) Encoder.Add(newStringEncoder()) Encoder.Add(newErrorEncoder()) @@ -48,6 +49,7 @@ func init() { Decoder.Add(newIntegerDecoder()) Decoder.Add(newFloatDecoder()) Decoder.Add(newBooleanDecoder()) + Decoder.Add(newBufferDecoder()) Decoder.Add(newBinaryDecoder()) Decoder.Add(newStringDecoder()) Decoder.Add(newErrorDecoder()) diff --git a/pkg/types/encoding_test.go b/pkg/types/encoding_test.go index 9585d5c4..d21bd634 100644 --- a/pkg/types/encoding_test.go +++ b/pkg/types/encoding_test.go @@ -4,6 +4,7 @@ import ( "errors" "fmt" "reflect" + "strings" "testing" "time" @@ -26,6 +27,10 @@ func TestMarshal(t *testing.T) { when: []byte{0}, expect: NewBinary([]byte{0}), }, + { + when: strings.NewReader("test"), + expect: NewBuffer(strings.NewReader("test")), + }, { when: true, expect: True, @@ -83,8 +88,12 @@ func TestUnmarshal(t *testing.T) { expect any }{ { - expect: []byte{0}, when: NewBinary([]byte{0}), + expect: []byte{0}, + }, + { + when: NewBuffer(strings.NewReader("test")), + expect: strings.NewReader("test"), }, { when: True, diff --git a/pkg/types/error.go b/pkg/types/error.go index 28289ac2..4a7556a1 100644 --- a/pkg/types/error.go +++ b/pkg/types/error.go @@ -95,7 +95,7 @@ func newErrorDecoder() encoding.DecodeCompiler[Value] { } return errors.WithStack(encoding.ErrUnsupportedType) }), nil - } else if typ.Elem().Kind() == reflect.Interface { + } else if typ.Elem() == types[KindUnknown] { return encoding.DecodeFunc(func(source Value, target unsafe.Pointer) error { if s, ok := source.(Error); ok { *(*any)(target) = s.Interface() diff --git a/pkg/types/float.go b/pkg/types/float.go index 930abfc7..e6c2a577 100644 --- a/pkg/types/float.go +++ b/pkg/types/float.go @@ -176,7 +176,7 @@ func newFloatDecoder() encoding.DecodeCompiler[Value] { } return errors.WithStack(encoding.ErrUnsupportedType) }), nil - } else if typ.Elem().Kind() == reflect.Interface { + } else if typ.Elem() == types[KindUnknown] { return encoding.DecodeFunc(func(source Value, target unsafe.Pointer) error { if s, ok := source.(Float); ok { *(*any)(target) = s.Interface() diff --git a/pkg/types/getter.go b/pkg/types/getter.go deleted file mode 100644 index ef8d3916..00000000 --- a/pkg/types/getter.go +++ /dev/null @@ -1,41 +0,0 @@ -package types - -// Get extracts a value from a nested structure using the provided paths. -func Get[T any](obj Value, paths ...any) (T, bool) { - var val T - cur := obj - for _, path := range paths { - p, err := Marshal(path) - if err != nil { - return val, false - } - - switch p := p.(type) { - case String: - if v, ok := cur.(Map); ok { - child := v.Get(p) - if child == nil { - return val, false - } - cur = child - } - case Integer: - if v, ok := cur.(Slice); ok { - if int(p.Int()) >= v.Len() { - return val, false - } - cur = v.Get(int(p.Int())) - } - default: - return val, false - } - } - - if cur == nil { - return val, false - } - if v, ok := cur.(T); ok { - return v, true - } - return val, Unmarshal(cur, &val) == nil -} diff --git a/pkg/types/integer.go b/pkg/types/integer.go index d8673bc7..04a1eebd 100644 --- a/pkg/types/integer.go +++ b/pkg/types/integer.go @@ -347,7 +347,7 @@ func newIntegerDecoder() encoding.DecodeCompiler[Value] { } return errors.WithStack(encoding.ErrUnsupportedType) }), nil - } else if typ.Elem().Kind() == reflect.Interface { + } else if typ.Elem() == types[KindUnknown] { return encoding.DecodeFunc(func(source Value, target unsafe.Pointer) error { if s, ok := source.(Integer); ok { *(*any)(target) = s.Interface() diff --git a/pkg/types/map.go b/pkg/types/map.go index 2d006eab..f30cdfca 100644 --- a/pkg/types/map.go +++ b/pkg/types/map.go @@ -288,10 +288,10 @@ func (m *immutableMap) Interface() any { } } if keyType == nil { - keyType = types[KindInvalid] + keyType = types[KindUnknown] } if valueType == nil { - valueType = types[KindInvalid] + valueType = types[KindUnknown] } if keyType.Kind() == reflect.Interface || keyType.Kind() == reflect.Map || keyType.Kind() == reflect.Slice { @@ -725,7 +725,7 @@ func newMapDecoder(decoder *encoding.DecodeAssembler[Value, any]) encoding.Decod } return nil }), nil - } else if typ.Elem().Kind() == reflect.Interface { + } else if typ.Elem() == types[KindUnknown] { return encoding.DecodeFunc(func(source Value, target unsafe.Pointer) error { if s, ok := source.(Map); ok { *(*any)(target) = s.Interface() diff --git a/pkg/types/slice.go b/pkg/types/slice.go index 70ee997e..aa61d747 100644 --- a/pkg/types/slice.go +++ b/pkg/types/slice.go @@ -150,7 +150,7 @@ func (s Slice) Interface() any { elementType = unionType(elementType, TypeOf(KindOf(element))) } if elementType == nil { - elementType = types[KindInvalid] + elementType = types[KindUnknown] } t := reflect.MakeSlice(reflect.SliceOf(elementType), len(s.value), len(s.value)) @@ -263,7 +263,7 @@ func newSliceDecoder(decoder *encoding.DecodeAssembler[Value, any]) encoding.Dec } return valueDecoder.Decode(source, t.Index(0).Addr().UnsafePointer()) }), nil - } else if typ.Elem().Kind() == reflect.Interface { + } else if typ.Elem() == types[KindUnknown] { return encoding.DecodeFunc(func(source Value, target unsafe.Pointer) error { if s, ok := source.(Slice); ok { *(*any)(target) = s.Interface() diff --git a/pkg/types/string.go b/pkg/types/string.go index 691ca99b..14dd9440 100644 --- a/pkg/types/string.go +++ b/pkg/types/string.go @@ -297,7 +297,7 @@ func newStringDecoder() encoding2.DecodeCompiler[Value] { } return errors.WithStack(encoding2.ErrUnsupportedType) }), nil - } else if typ.Elem().Kind() == reflect.Interface { + } else if typ.Elem() == types[KindUnknown] { return encoding2.DecodeFunc(func(source Value, target unsafe.Pointer) error { if s, ok := source.(String); ok { *(*any)(target) = s.Interface() diff --git a/pkg/types/uinteger.go b/pkg/types/uinteger.go index 4d58fc92..665d6b0d 100644 --- a/pkg/types/uinteger.go +++ b/pkg/types/uinteger.go @@ -347,7 +347,7 @@ func newUintegerDecoder() encoding.DecodeCompiler[Value] { } return errors.WithStack(encoding.ErrUnsupportedType) }), nil - } else if typ.Elem().Kind() == reflect.Interface { + } else if typ.Elem() == types[KindUnknown] { return encoding.DecodeFunc(func(source Value, target unsafe.Pointer) error { if s, ok := source.(Uinteger); ok { *(*any)(target) = s.Interface() diff --git a/pkg/types/value.go b/pkg/types/value.go index e11be2b8..1b36b664 100644 --- a/pkg/types/value.go +++ b/pkg/types/value.go @@ -1,6 +1,9 @@ package types -import "reflect" +import ( + "io" + "reflect" +) // Value is an interface representing atomic data types. type Value interface { @@ -14,10 +17,15 @@ type Value interface { // Kind represents enumerated data types. type Kind byte +type ordered interface { + ~int | ~int8 | ~int16 | ~int32 | ~int64 | ~uint | ~uint8 | ~uint16 | ~uint32 | ~uint64 | ~uintptr | ~float32 | ~float64 | ~string +} + // Constants representing various data types. const ( - KindInvalid Kind = iota + KindUnknown Kind = iota KindBinary + KindBuffer KindBoolean KindError KindInt @@ -38,8 +46,9 @@ const ( ) var types = map[Kind]reflect.Type{ - KindInvalid: reflect.TypeOf((*any)(nil)).Elem(), + KindUnknown: reflect.TypeOf((*any)(nil)).Elem(), KindBinary: reflect.TypeOf([]byte(nil)), + KindBuffer: reflect.TypeOf((*io.Reader)(nil)).Elem(), KindBoolean: reflect.TypeOf(false), KindError: reflect.TypeOf((*error)(nil)).Elem(), KindInt: reflect.TypeOf(0), @@ -62,7 +71,7 @@ var types = map[Kind]reflect.Type{ // KindOf returns the kind of the provided Value. func KindOf(v Value) Kind { if v == nil { - return KindInvalid + return KindUnknown } return v.Kind() } @@ -113,6 +122,59 @@ func Compare(x, y Value) int { return x.Compare(y) } +// Get extracts a value from a nested structure using the provided paths. +func Get[T any](obj Value, paths ...any) (T, bool) { + var val T + cur := obj + for _, path := range paths { + p, err := Marshal(path) + if err != nil { + return val, false + } + + switch p := p.(type) { + case String: + if v, ok := cur.(Map); ok { + child := v.Get(p) + if child == nil { + return val, false + } + cur = child + } + case Integer: + if v, ok := cur.(Slice); ok { + if int(p.Int()) >= v.Len() { + return val, false + } + cur = v.Get(int(p.Int())) + } + default: + return val, false + } + } + + if cur == nil { + return val, false + } + if v, ok := cur.(T); ok { + return v, true + } + return val, Unmarshal(cur, &val) == nil +} + +func compare[T ordered](x, y T) int { + if x == y { + return 0 + } + if x > y { + return 1 + } + if x < y { + return -1 + } + return 0 +} + func unionType(x, y reflect.Type) reflect.Type { if x == nil { return y @@ -121,5 +183,5 @@ func unionType(x, y reflect.Type) reflect.Type { } else if x == y { return x } - return types[KindInvalid] + return types[KindUnknown] }