diff --git a/examples/loopback.yaml b/examples/loopback.yaml new file mode 100644 index 00000000..bbdc9253 --- /dev/null +++ b/examples/loopback.yaml @@ -0,0 +1,13 @@ +- kind: listener + name: listener + protocol: http + port: '{{ .PORT }}' + ports: + out: + - name: loopback + port: in + +- kind: snippet + name: loopback + language: cel + code: self diff --git a/examples/system.yaml b/examples/system.yaml index 11dd85d8..a8659c29 100644 --- a/examples/system.yaml +++ b/examples/system.yaml @@ -154,7 +154,7 @@ - kind: if name: specs_read_or_watch - when: '!has(self.query.watch) || !self.query.watch.all(x, x == "true") || !has(self.header.Connection) || !has(self.header.Upgrade)' + when: '!has(self.header.Connection) || !has(self.header.Upgrade)' ports: out[0]: - name: specs_read_with_query @@ -357,7 +357,7 @@ - kind: if name: values_read_or_watch - when: '!has(self.query.watch) || !self.query.watch.all(x, x == "true") || !has(self.header.Connection) || !has(self.header.Upgrade)' + when: '!has(self.header.Connection) || !has(self.header.Upgrade)' ports: out[0]: - name: values_read_with_query @@ -557,7 +557,7 @@ - kind: if name: charts_read_or_watch - when: '!has(self.query.watch) || !self.query.watch.all(x, x == "true") || !has(self.header.Connection) || !has(self.header.Upgrade)' + when: '!has(self.header.Connection) || !has(self.header.Upgrade)' ports: out[0]: - name: charts_read_with_query diff --git a/ext/pkg/mime/encoding.go b/ext/pkg/mime/encoding.go index 7b621838..5a1b150e 100644 --- a/ext/pkg/mime/encoding.go +++ b/ext/pkg/mime/encoding.go @@ -47,30 +47,29 @@ func Encode(writer io.Writer, value types.Value, header textproto.MIMEHeader) er } count := 0 - var cwriter io.Writer = WriterFunc(func(p []byte) (n int, err error) { + var counter io.Writer = WriterFunc(func(p []byte) (n int, err error) { n, err = writer.Write(p) count += n return }) - w, err := Compress(cwriter, encode) + w, err := Compress(counter, encode) if err != nil { return err } - flush := func() { - if c, ok := w.(io.Closer); ok && w != cwriter { - c.Close() + defer func() { + if c, ok := w.(io.Closer); ok { + _ = c.Close() } header.Set(HeaderContentLength, strconv.Itoa(count)) - } + }() switch typ { case ApplicationJSON: if err := json.NewEncoder(w).Encode(types.InterfaceOf(value)); err != nil { return err } - flush() return nil case ApplicationFormURLEncoded: urlValues := url.Values{} @@ -80,7 +79,6 @@ func Encode(writer io.Writer, value types.Value, header textproto.MIMEHeader) er if _, err := w.Write([]byte(urlValues.Encode())); err != nil { return err } - flush() return nil case MultipartFormData: boundary := params["boundary"] @@ -214,22 +212,24 @@ func Encode(writer io.Writer, value types.Value, header textproto.MIMEHeader) er if err := mw.Close(); err != nil { return err } - flush() return nil } switch v := value.(type) { + case types.Buffer: + if _, err := io.Copy(w, v); err != nil { + return err + } + return nil case types.Binary: if _, err := w.Write(v.Bytes()); err != nil { return err } - flush() return nil case types.String: if _, err := w.Write([]byte(v.String())); err != nil { return err } - flush() return nil default: return errors.WithStack(encoding.ErrUnsupportedType) @@ -327,11 +327,7 @@ func Decode(reader io.Reader, header textproto.MIMEHeader) (types.Value, error) }) } - data, err := io.ReadAll(r) - if err != nil { - return nil, err - } - return types.NewBinary(data), nil + return types.NewBuffer(r), nil } func randomMultipartBoundary() string { diff --git a/ext/pkg/mime/encoding_test.go b/ext/pkg/mime/encoding_test.go index 10ab93a7..b2e49c44 100644 --- a/ext/pkg/mime/encoding_test.go +++ b/ext/pkg/mime/encoding_test.go @@ -171,7 +171,7 @@ func TestDecode(t *testing.T) { { whenValue: []byte("--MyBoundary\r\n" + "Content-Disposition: form-data; name=\"test\"; filename=\"test\"\r\n" + - "Content-Type: application/octet-stream\r\n" + + "Content-Type: text/plain\r\n" + "\r\n" + "test\r\n" + "--MyBoundary\r\n" + @@ -186,11 +186,11 @@ func TestDecode(t *testing.T) { ), types.NewString("files"), types.NewMap( types.NewString("test"), types.NewSlice(types.NewMap( - types.NewString("data"), types.NewBinary([]byte("test")), + types.NewString("data"), types.NewString("test"), types.NewString("filename"), types.NewString("test"), types.NewString("header"), types.NewMap( types.NewString("Content-Disposition"), types.NewSlice(types.NewString("form-data; name=\"test\"; filename=\"test\"")), - types.NewString("Content-Type"), types.NewSlice(types.NewString("application/octet-stream")), + types.NewString("Content-Type"), types.NewSlice(types.NewString("text/plain")), ), types.NewString("size"), types.NewInt64(4), )), @@ -200,7 +200,7 @@ func TestDecode(t *testing.T) { { whenValue: []byte("testtesttest"), whenType: ApplicationOctetStream, - expect: types.NewBinary([]byte("testtesttest")), + expect: types.NewBuffer(bytes.NewBuffer([]byte("testtesttest"))), }, } @@ -210,7 +210,12 @@ func TestDecode(t *testing.T) { HeaderContentType: []string{tt.whenType}, }) assert.NoError(t, err) - assert.Equal(t, tt.expect.Interface(), decode.Interface()) + + var expect any + var actual any + _ = types.Unmarshal(tt.expect, &expect) + _ = types.Unmarshal(decode, &actual) + assert.Equal(t, expect, actual) }) } } diff --git a/ext/pkg/mime/type.go b/ext/pkg/mime/type.go index 0314d857..bdb319e4 100644 --- a/ext/pkg/mime/type.go +++ b/ext/pkg/mime/type.go @@ -55,16 +55,16 @@ func DetectTypesFromBytes(value []byte) []string { // DetectTypesFromValue determines the content types based on the type of types passed. func DetectTypesFromValue(value types.Value) []string { switch value.(type) { - case types.Binary: + case types.Binary, types.Buffer: return []string{ApplicationOctetStream} case types.String: - return []string{TextPlainCharsetUTF8, ApplicationJSONCharsetUTF8} + return []string{TextPlainCharsetUTF8, ApplicationOctetStream, ApplicationJSONCharsetUTF8, ApplicationXMLCharsetUTF8, ApplicationFormURLEncoded, MultipartFormData} case types.Slice: - return []string{ApplicationJSONCharsetUTF8} + return []string{ApplicationJSONCharsetUTF8, ApplicationXMLCharsetUTF8, ApplicationFormURLEncoded} case types.Map, types.Error: - return []string{ApplicationJSONCharsetUTF8, ApplicationFormURLEncoded, MultipartFormData} + return []string{ApplicationJSONCharsetUTF8, ApplicationXMLCharsetUTF8, ApplicationFormURLEncoded, MultipartFormData} default: - return []string{ApplicationJSONCharsetUTF8} + return []string{ApplicationJSONCharsetUTF8, ApplicationXMLCharsetUTF8, TextPlainCharsetUTF8, ApplicationOctetStream} } } diff --git a/ext/pkg/mime/type_test.go b/ext/pkg/mime/type_test.go index 6dcacf2c..137a0777 100644 --- a/ext/pkg/mime/type_test.go +++ b/ext/pkg/mime/type_test.go @@ -53,21 +53,25 @@ func TestDetectTypesFromValue(t *testing.T) { when: types.NewBinary(nil), expect: []string{ApplicationOctetStream}, }, + { + when: types.NewBuffer(nil), + expect: []string{ApplicationOctetStream}, + }, { when: types.NewString(""), - expect: []string{TextPlainCharsetUTF8, ApplicationJSONCharsetUTF8}, + expect: []string{TextPlainCharsetUTF8, ApplicationOctetStream, ApplicationJSONCharsetUTF8, ApplicationXMLCharsetUTF8, ApplicationFormURLEncoded, MultipartFormData}, }, { when: types.NewSlice(), - expect: []string{ApplicationJSONCharsetUTF8}, + expect: []string{ApplicationJSONCharsetUTF8, ApplicationXMLCharsetUTF8, ApplicationFormURLEncoded}, }, { when: types.NewMap(), - expect: []string{ApplicationJSONCharsetUTF8, ApplicationFormURLEncoded, MultipartFormData}, + expect: []string{ApplicationJSONCharsetUTF8, ApplicationXMLCharsetUTF8, ApplicationFormURLEncoded, MultipartFormData}, }, { when: types.NewError(nil), - expect: []string{ApplicationJSONCharsetUTF8, ApplicationFormURLEncoded, MultipartFormData}, + expect: []string{ApplicationJSONCharsetUTF8, ApplicationXMLCharsetUTF8, ApplicationFormURLEncoded, MultipartFormData}, }, } diff --git a/ext/pkg/network/http.go b/ext/pkg/network/http.go index 9d273a3b..998d41f2 100644 --- a/ext/pkg/network/http.go +++ b/ext/pkg/network/http.go @@ -149,7 +149,10 @@ func (n *HTTPNode) action(proc *process.Process, inPck *packet.Packet) (*packet. if err != nil { return nil, packet.New(types.NewError(err)) } - defer w.Body.Close() + + proc.AddExitHook(process.ExitFunc(func(err error) { + _ = w.Body.Close() + })) body, err := mime.Decode(w.Body, textproto.MIMEHeader(w.Header)) if err != nil { diff --git a/ext/pkg/network/listener.go b/ext/pkg/network/listener.go index 9a8a4684..5d49688f 100644 --- a/ext/pkg/network/listener.go +++ b/ext/pkg/network/listener.go @@ -1,10 +1,8 @@ package network import ( - "bytes" "crypto/tls" "fmt" - "io" "net" "net/http" "net/textproto" @@ -261,6 +259,9 @@ func (n *HTTPListenNode) negotiate(req *HTTPPayload, res *HTTPPayload) { accept := req.Header.Get(mime.HeaderAccept) offers := mime.DetectTypesFromValue(res.Body) contentType := mime.Negotiate(accept, offers) + if contentType == "" && len(offers) > 0 { + contentType = offers[0] + } if contentType != "" { res.Header.Set(mime.HeaderContentType, contentType) } @@ -302,17 +303,11 @@ func (n *HTTPListenNode) write(w http.ResponseWriter, res *HTTPPayload) error { } } - buf := bytes.NewBuffer(nil) - if err := mime.Encode(buf, res.Body, textproto.MIMEHeader(h)); err != nil { - return err - } - status := res.Status if status == 0 { status = http.StatusOK } - w.WriteHeader(status) - _, err := io.Copy(w, buf) - return err + w.WriteHeader(status) + return mime.Encode(w, res.Body, textproto.MIMEHeader(h)) } diff --git a/ext/pkg/network/websocket.go b/ext/pkg/network/websocket.go index 5214dec1..d48adeea 100644 --- a/ext/pkg/network/websocket.go +++ b/ext/pkg/network/websocket.go @@ -1,7 +1,6 @@ package network import ( - "bytes" "context" "net/http" "net/textproto" @@ -202,31 +201,32 @@ func (n *WebSocketConnNode) consume(proc *process.Process) { } for inPck := range inReader.Read() { - var inPayload *WebSocketPayload - if err := types.Unmarshal(inPck.Payload(), &inPayload); err != nil { + var inPayload WebSocketPayload + _ = types.Unmarshal(inPck.Payload(), &inPayload) + if inPayload.Type == 0 && inPayload.Data == nil { inPayload.Data = inPck.Payload() - if _, ok := inPayload.Data.(types.Binary); !ok { + if kind := types.KindOf(inPayload.Data); kind != types.KindBuffer && kind != types.KindBinary { inPayload.Type = websocket.TextMessage } else { inPayload.Type = websocket.BinaryMessage } } - w := mime.WriterFunc(func(b []byte) (int, error) { - if err := conn.WriteMessage(inPayload.Type, b); err != nil { - return 0, err - } - return len(b), nil - }) - - if err := mime.Encode(w, inPayload.Data, textproto.MIMEHeader{}); err != nil { + w, err := conn.NextWriter(inPayload.Type) + if err != nil { errPck := packet.New(types.NewError(err)) - if errWriter.Write(errPck) > 0 { - <-errWriter.Receive() - } + inReader.Receive(packet.Send(errWriter, errPck)) + continue } - inReader.Receive(packet.None) + if err = mime.Encode(w, inPayload.Data, textproto.MIMEHeader{}); err != nil { + _ = w.Close() + errPck := packet.New(types.NewError(err)) + inReader.Receive(packet.Send(errWriter, errPck)) + } else { + _ = w.Close() + inReader.Receive(packet.None) + } } } @@ -237,7 +237,7 @@ func (n *WebSocketConnNode) produce(proc *process.Process) { } for { - typ, p, err := conn.ReadMessage() + typ, reader, err := conn.NextReader() if err != nil || typ == websocket.CloseMessage { outWriter := n.outPort.Open(proc) @@ -262,7 +262,7 @@ func (n *WebSocketConnNode) produce(proc *process.Process) { child := proc.Fork() outWriter := n.outPort.Open(child) - data, err := mime.Decode(bytes.NewReader(p), textproto.MIMEHeader{}) + data, err := mime.Decode(reader, textproto.MIMEHeader{}) if err != nil { data = types.NewString(err.Error()) } diff --git a/ext/pkg/network/websocket_test.go b/ext/pkg/network/websocket_test.go index a009e337..c5f081ae 100644 --- a/ext/pkg/network/websocket_test.go +++ b/ext/pkg/network/websocket_test.go @@ -95,6 +95,7 @@ func TestWebSocketNode_SendAndReceive(t *testing.T) { } outReader.Receive(packet.None) + select { case <-done: default: diff --git a/pkg/resource/store.go b/pkg/resource/store.go index 84b8bbe1..05c43297 100644 --- a/pkg/resource/store.go +++ b/pkg/resource/store.go @@ -42,8 +42,8 @@ type Stream interface { // Event represents a change event for a Resource. type Event struct { - OP EventOP // Operation type (Store, Swap, Delete) - ID uuid.UUID // ID of the changed Resource + ID uuid.UUID `json:"id" map:"id"` + OP EventOP `json:"op" map:"op"` } // EventOP represents the type of operation that triggered an Event. diff --git a/pkg/types/binary.go b/pkg/types/binary.go index 2ffe9de0..b1476c08 100644 --- a/pkg/types/binary.go +++ b/pkg/types/binary.go @@ -179,7 +179,7 @@ func newBinaryDecoder() encoding2.DecodeCompiler[Value] { } return errors.WithStack(encoding2.ErrUnsupportedType) }), nil - } else if typ.Elem().Kind() == reflect.Interface { + } else if typ.Elem() == types[KindUnknown] { return encoding2.DecodeFunc(func(source Value, target unsafe.Pointer) error { if s, ok := source.(Binary); ok { *(*any)(target) = s.Interface() diff --git a/pkg/types/binary_test.go b/pkg/types/binary_test.go index 014b64d6..5fa897a4 100644 --- a/pkg/types/binary_test.go +++ b/pkg/types/binary_test.go @@ -142,7 +142,6 @@ func TestBinary_Decode(t *testing.T) { d, err := base64.StdEncoding.DecodeString(decoded) assert.NoError(t, err) - assert.Equal(t, source, d) }) diff --git a/pkg/types/boolean.go b/pkg/types/boolean.go index f1925724..79360e91 100644 --- a/pkg/types/boolean.go +++ b/pkg/types/boolean.go @@ -113,7 +113,7 @@ func newBooleanDecoder() encoding.DecodeCompiler[Value] { } return errors.WithStack(encoding.ErrUnsupportedType) }), nil - } else if typ.Elem().Kind() == reflect.Interface { + } else if typ.Elem() == types[KindUnknown] { return encoding.DecodeFunc(func(source Value, target unsafe.Pointer) error { if s, ok := source.(Boolean); ok { *(*any)(target) = s.Interface() diff --git a/pkg/types/buffer.go b/pkg/types/buffer.go new file mode 100644 index 00000000..de631e08 --- /dev/null +++ b/pkg/types/buffer.go @@ -0,0 +1,161 @@ +package types + +import ( + "encoding/base64" + "io" + "reflect" + "unsafe" + + "github.com/pkg/errors" + "github.com/siyul-park/uniflow/pkg/encoding" +) + +// Buffer is a representation of a io.Reader value. +type Buffer = *_buffer + +type _buffer struct { + value io.Reader +} + +var _ Value = (Buffer)(nil) +var _ io.Reader = (Buffer)(nil) + +// NewBuffer creates a new Buffer instance. +func NewBuffer(value io.Reader) Buffer { + return &_buffer{value: value} +} + +// Read reads data from the buffer into p. +func (b Buffer) Read(p []byte) (n int, err error) { + return b.value.Read(p) +} + +// Bytes returns the raw byte slice. +func (b Buffer) Bytes() ([]byte, error) { + bytes, err := io.ReadAll(b.value) + if err != nil { + return nil, err + } + closer, ok := b.value.(io.Closer) + if ok { + if err := closer.Close(); err != nil { + return nil, err + } + } + return bytes, nil +} + +// Kind returns the kind of the buffer. +func (b Buffer) Kind() Kind { + return KindBuffer +} + +// Hash returns a hash value for the buffer. +func (b Buffer) Hash() uint64 { + return uint64(uintptr(unsafe.Pointer(b))) +} + +// Interface returns the underlying io.Reader. +func (b Buffer) Interface() any { + return b.value +} + +// Equal checks if the buffer is equal to another Value. +func (b Buffer) Equal(other Value) bool { + if o, ok := other.(Buffer); ok { + return b == o + } + return false +} + +// Compare compares the buffer with another Value. +func (b Buffer) Compare(other Value) int { + if o, ok := other.(Buffer); ok { + return compare(b.Hash(), o.Hash()) + } + return compare(b.Kind(), KindOf(other)) +} + +func newBufferEncoder() encoding.EncodeCompiler[any, Value] { + typeReader := reflect.TypeOf((*io.Reader)(nil)).Elem() + + return encoding.EncodeCompilerFunc[any, Value](func(typ reflect.Type) (encoding.Encoder[any, Value], error) { + if typ == nil { + return nil, errors.WithStack(encoding.ErrUnsupportedType) + } else if typ.ConvertibleTo(typeReader) { + return encoding.EncodeFunc(func(source any) (Value, error) { + s := source.(io.Reader) + return NewBuffer(s), nil + }), nil + } + return nil, errors.WithStack(encoding.ErrUnsupportedType) + }) +} + +func newBufferDecoder() encoding.DecodeCompiler[Value] { + typeReader := reflect.TypeOf((*io.Reader)(nil)).Elem() + + return encoding.DecodeCompilerFunc[Value](func(typ reflect.Type) (encoding.Decoder[Value, unsafe.Pointer], error) { + if typ == nil { + return nil, errors.WithStack(encoding.ErrUnsupportedType) + } else if typ.Kind() == reflect.Pointer { + if typ.Elem().ConvertibleTo(typeReader) { + return encoding.DecodeFunc(func(source Value, target unsafe.Pointer) error { + if s, ok := source.(Buffer); ok { + t := reflect.NewAt(typ.Elem(), target) + t.Elem().Set(reflect.ValueOf(s.Interface())) + return nil + } + return errors.WithStack(encoding.ErrUnsupportedType) + }), nil + } else if typ.Elem().Kind() == reflect.Slice && typ.Elem().Elem().Kind() == reflect.Uint8 { + return encoding.DecodeFunc(func(source Value, target unsafe.Pointer) error { + if s, ok := source.(Buffer); ok { + bytes, err := s.Bytes() + if err != nil { + return err + } + t := reflect.NewAt(typ.Elem(), target).Elem() + t.Set(reflect.AppendSlice(t, reflect.ValueOf(bytes).Convert(t.Type()))) + return nil + } + return errors.WithStack(encoding.ErrUnsupportedType) + }), nil + } else if typ.Elem().Kind() == reflect.Array && typ.Elem().Elem().Kind() == reflect.Uint8 { + return encoding.DecodeFunc(func(source Value, target unsafe.Pointer) error { + if s, ok := source.(Buffer); ok { + bytes, err := s.Bytes() + if err != nil { + return err + } + t := reflect.NewAt(typ.Elem(), target).Elem() + reflect.Copy(t, reflect.ValueOf(bytes).Convert(t.Type())) + return nil + } + return errors.WithStack(encoding.ErrUnsupportedType) + }), nil + } else if typ.Elem().Kind() == reflect.String { + return encoding.DecodeFunc(func(source Value, target unsafe.Pointer) error { + if s, ok := source.(Buffer); ok { + bytes, err := io.ReadAll(s) + if err != nil { + return err + } + *(*string)(target) = base64.StdEncoding.EncodeToString(bytes) + return nil + } + return errors.WithStack(encoding.ErrUnsupportedType) + }), nil + } else if typ.Elem() == types[KindUnknown] { + return encoding.DecodeFunc(func(source Value, target unsafe.Pointer) error { + if s, ok := source.(Buffer); ok { + *(*any)(target) = s.Interface() + return nil + } + return errors.WithStack(encoding.ErrUnsupportedType) + }), nil + } + } + return nil, errors.WithStack(encoding.ErrUnsupportedType) + }) +} diff --git a/pkg/types/buffer_test.go b/pkg/types/buffer_test.go new file mode 100644 index 00000000..a32d6dea --- /dev/null +++ b/pkg/types/buffer_test.go @@ -0,0 +1,131 @@ +package types + +import ( + "encoding/base64" + "io" + "strings" + "testing" + + "github.com/siyul-park/uniflow/pkg/encoding" + "github.com/stretchr/testify/assert" +) + +func TestBuffer_Read(t *testing.T) { + r := strings.NewReader("test") + b := NewBuffer(r) + p := make([]byte, 4) + n, err := b.Read(p) + assert.NoError(t, err) + assert.Equal(t, 4, n) + assert.Equal(t, "test", string(p)) +} + +func TestBuffer_Kind(t *testing.T) { + r := strings.NewReader("test") + b := NewBuffer(r) + assert.Equal(t, KindBuffer, b.Kind()) +} + +func TestBuffer_Hash(t *testing.T) { + r1 := strings.NewReader("test1") + r2 := strings.NewReader("test2") + b1 := NewBuffer(r1) + b2 := NewBuffer(r2) + assert.NotEqual(t, b1.Hash(), b2.Hash()) +} + +func TestBuffer_Interface(t *testing.T) { + r := strings.NewReader("test") + b := NewBuffer(r) + assert.Equal(t, r, b.Interface()) +} + +func TestBuffer_Equal(t *testing.T) { + r1 := strings.NewReader("test1") + r2 := strings.NewReader("test2") + b1 := NewBuffer(r1) + b2 := NewBuffer(r2) + assert.True(t, b1.Equal(b1)) + assert.False(t, b1.Equal(b2)) +} + +func TestBuffer_Compare(t *testing.T) { + r1 := strings.NewReader("test1") + r2 := strings.NewReader("test2") + b1 := NewBuffer(r1) + b2 := NewBuffer(r2) + assert.Equal(t, 0, b1.Compare(b1)) + assert.NotEqual(t, 0, b1.Compare(b2)) +} + +func TestBuffer_Encode(t *testing.T) { + enc := encoding.NewEncodeAssembler[any, Value]() + enc.Add(newBufferEncoder()) + + t.Run("io.Reader", func(t *testing.T) { + source := strings.NewReader("test") + v := NewBuffer(source) + + decoded, err := enc.Encode(source) + assert.NoError(t, err) + assert.Equal(t, v, decoded) + }) +} + +func TestBuffer_Decode(t *testing.T) { + dec := encoding.NewDecodeAssembler[Value, any]() + dec.Add(newBufferDecoder()) + + t.Run("io.Reader", func(t *testing.T) { + source := strings.NewReader("test") + v := NewBuffer(source) + + var decoded io.Reader + err := dec.Decode(v, &decoded) + assert.NoError(t, err) + assert.Equal(t, source, decoded) + }) + + t.Run("slice", func(t *testing.T) { + source := strings.NewReader("test") + v := NewBuffer(source) + + var decoded []byte + err := dec.Decode(v, &decoded) + assert.NoError(t, err) + assert.Equal(t, []byte("test"), decoded) + }) + + t.Run("array", func(t *testing.T) { + source := strings.NewReader("test") + v := NewBuffer(source) + + var decoded [3]byte + err := dec.Decode(v, &decoded) + assert.NoError(t, err) + assert.EqualValues(t, []byte("test"), decoded) + }) + + t.Run("string", func(t *testing.T) { + source := strings.NewReader("test") + v := NewBuffer(source) + + var decoded string + err := dec.Decode(v, &decoded) + assert.NoError(t, err) + + d, err := base64.StdEncoding.DecodeString(decoded) + assert.NoError(t, err) + assert.Equal(t, []byte("test"), d) + }) + + t.Run("any", func(t *testing.T) { + source := strings.NewReader("test") + v := NewBuffer(source) + + var decoded any + err := dec.Decode(v, &decoded) + assert.NoError(t, err) + assert.Equal(t, source, decoded) + }) +} diff --git a/pkg/types/compare.go b/pkg/types/compare.go deleted file mode 100644 index 3f76411a..00000000 --- a/pkg/types/compare.go +++ /dev/null @@ -1,18 +0,0 @@ -package types - -type ordered interface { - ~int | ~int8 | ~int16 | ~int32 | ~int64 | ~uint | ~uint8 | ~uint16 | ~uint32 | ~uint64 | ~uintptr | ~float32 | ~float64 | ~string -} - -func compare[T ordered](x, y T) int { - if x == y { - return 0 - } - if x > y { - return 1 - } - if x < y { - return -1 - } - return 0 -} diff --git a/pkg/types/encoding.go b/pkg/types/encoding.go index aabcc440..c1094e08 100644 --- a/pkg/types/encoding.go +++ b/pkg/types/encoding.go @@ -33,6 +33,7 @@ func init() { Encoder.Add(newIntegerEncoder()) Encoder.Add(newFloatEncoder()) Encoder.Add(newBooleanEncoder()) + Encoder.Add(newBufferEncoder()) Encoder.Add(newBinaryEncoder()) Encoder.Add(newStringEncoder()) Encoder.Add(newErrorEncoder()) @@ -48,6 +49,7 @@ func init() { Decoder.Add(newIntegerDecoder()) Decoder.Add(newFloatDecoder()) Decoder.Add(newBooleanDecoder()) + Decoder.Add(newBufferDecoder()) Decoder.Add(newBinaryDecoder()) Decoder.Add(newStringDecoder()) Decoder.Add(newErrorDecoder()) diff --git a/pkg/types/encoding_test.go b/pkg/types/encoding_test.go index 9585d5c4..d21bd634 100644 --- a/pkg/types/encoding_test.go +++ b/pkg/types/encoding_test.go @@ -4,6 +4,7 @@ import ( "errors" "fmt" "reflect" + "strings" "testing" "time" @@ -26,6 +27,10 @@ func TestMarshal(t *testing.T) { when: []byte{0}, expect: NewBinary([]byte{0}), }, + { + when: strings.NewReader("test"), + expect: NewBuffer(strings.NewReader("test")), + }, { when: true, expect: True, @@ -83,8 +88,12 @@ func TestUnmarshal(t *testing.T) { expect any }{ { - expect: []byte{0}, when: NewBinary([]byte{0}), + expect: []byte{0}, + }, + { + when: NewBuffer(strings.NewReader("test")), + expect: strings.NewReader("test"), }, { when: True, diff --git a/pkg/types/error.go b/pkg/types/error.go index 28289ac2..4a7556a1 100644 --- a/pkg/types/error.go +++ b/pkg/types/error.go @@ -95,7 +95,7 @@ func newErrorDecoder() encoding.DecodeCompiler[Value] { } return errors.WithStack(encoding.ErrUnsupportedType) }), nil - } else if typ.Elem().Kind() == reflect.Interface { + } else if typ.Elem() == types[KindUnknown] { return encoding.DecodeFunc(func(source Value, target unsafe.Pointer) error { if s, ok := source.(Error); ok { *(*any)(target) = s.Interface() diff --git a/pkg/types/float.go b/pkg/types/float.go index 930abfc7..e6c2a577 100644 --- a/pkg/types/float.go +++ b/pkg/types/float.go @@ -176,7 +176,7 @@ func newFloatDecoder() encoding.DecodeCompiler[Value] { } return errors.WithStack(encoding.ErrUnsupportedType) }), nil - } else if typ.Elem().Kind() == reflect.Interface { + } else if typ.Elem() == types[KindUnknown] { return encoding.DecodeFunc(func(source Value, target unsafe.Pointer) error { if s, ok := source.(Float); ok { *(*any)(target) = s.Interface() diff --git a/pkg/types/getter.go b/pkg/types/getter.go deleted file mode 100644 index ef8d3916..00000000 --- a/pkg/types/getter.go +++ /dev/null @@ -1,41 +0,0 @@ -package types - -// Get extracts a value from a nested structure using the provided paths. -func Get[T any](obj Value, paths ...any) (T, bool) { - var val T - cur := obj - for _, path := range paths { - p, err := Marshal(path) - if err != nil { - return val, false - } - - switch p := p.(type) { - case String: - if v, ok := cur.(Map); ok { - child := v.Get(p) - if child == nil { - return val, false - } - cur = child - } - case Integer: - if v, ok := cur.(Slice); ok { - if int(p.Int()) >= v.Len() { - return val, false - } - cur = v.Get(int(p.Int())) - } - default: - return val, false - } - } - - if cur == nil { - return val, false - } - if v, ok := cur.(T); ok { - return v, true - } - return val, Unmarshal(cur, &val) == nil -} diff --git a/pkg/types/integer.go b/pkg/types/integer.go index d8673bc7..04a1eebd 100644 --- a/pkg/types/integer.go +++ b/pkg/types/integer.go @@ -347,7 +347,7 @@ func newIntegerDecoder() encoding.DecodeCompiler[Value] { } return errors.WithStack(encoding.ErrUnsupportedType) }), nil - } else if typ.Elem().Kind() == reflect.Interface { + } else if typ.Elem() == types[KindUnknown] { return encoding.DecodeFunc(func(source Value, target unsafe.Pointer) error { if s, ok := source.(Integer); ok { *(*any)(target) = s.Interface() diff --git a/pkg/types/map.go b/pkg/types/map.go index 2d006eab..f30cdfca 100644 --- a/pkg/types/map.go +++ b/pkg/types/map.go @@ -288,10 +288,10 @@ func (m *immutableMap) Interface() any { } } if keyType == nil { - keyType = types[KindInvalid] + keyType = types[KindUnknown] } if valueType == nil { - valueType = types[KindInvalid] + valueType = types[KindUnknown] } if keyType.Kind() == reflect.Interface || keyType.Kind() == reflect.Map || keyType.Kind() == reflect.Slice { @@ -725,7 +725,7 @@ func newMapDecoder(decoder *encoding.DecodeAssembler[Value, any]) encoding.Decod } return nil }), nil - } else if typ.Elem().Kind() == reflect.Interface { + } else if typ.Elem() == types[KindUnknown] { return encoding.DecodeFunc(func(source Value, target unsafe.Pointer) error { if s, ok := source.(Map); ok { *(*any)(target) = s.Interface() diff --git a/pkg/types/slice.go b/pkg/types/slice.go index 70ee997e..aa61d747 100644 --- a/pkg/types/slice.go +++ b/pkg/types/slice.go @@ -150,7 +150,7 @@ func (s Slice) Interface() any { elementType = unionType(elementType, TypeOf(KindOf(element))) } if elementType == nil { - elementType = types[KindInvalid] + elementType = types[KindUnknown] } t := reflect.MakeSlice(reflect.SliceOf(elementType), len(s.value), len(s.value)) @@ -263,7 +263,7 @@ func newSliceDecoder(decoder *encoding.DecodeAssembler[Value, any]) encoding.Dec } return valueDecoder.Decode(source, t.Index(0).Addr().UnsafePointer()) }), nil - } else if typ.Elem().Kind() == reflect.Interface { + } else if typ.Elem() == types[KindUnknown] { return encoding.DecodeFunc(func(source Value, target unsafe.Pointer) error { if s, ok := source.(Slice); ok { *(*any)(target) = s.Interface() diff --git a/pkg/types/string.go b/pkg/types/string.go index 691ca99b..14dd9440 100644 --- a/pkg/types/string.go +++ b/pkg/types/string.go @@ -297,7 +297,7 @@ func newStringDecoder() encoding2.DecodeCompiler[Value] { } return errors.WithStack(encoding2.ErrUnsupportedType) }), nil - } else if typ.Elem().Kind() == reflect.Interface { + } else if typ.Elem() == types[KindUnknown] { return encoding2.DecodeFunc(func(source Value, target unsafe.Pointer) error { if s, ok := source.(String); ok { *(*any)(target) = s.Interface() diff --git a/pkg/types/uinteger.go b/pkg/types/uinteger.go index 4d58fc92..665d6b0d 100644 --- a/pkg/types/uinteger.go +++ b/pkg/types/uinteger.go @@ -347,7 +347,7 @@ func newUintegerDecoder() encoding.DecodeCompiler[Value] { } return errors.WithStack(encoding.ErrUnsupportedType) }), nil - } else if typ.Elem().Kind() == reflect.Interface { + } else if typ.Elem() == types[KindUnknown] { return encoding.DecodeFunc(func(source Value, target unsafe.Pointer) error { if s, ok := source.(Uinteger); ok { *(*any)(target) = s.Interface() diff --git a/pkg/types/value.go b/pkg/types/value.go index e11be2b8..1b36b664 100644 --- a/pkg/types/value.go +++ b/pkg/types/value.go @@ -1,6 +1,9 @@ package types -import "reflect" +import ( + "io" + "reflect" +) // Value is an interface representing atomic data types. type Value interface { @@ -14,10 +17,15 @@ type Value interface { // Kind represents enumerated data types. type Kind byte +type ordered interface { + ~int | ~int8 | ~int16 | ~int32 | ~int64 | ~uint | ~uint8 | ~uint16 | ~uint32 | ~uint64 | ~uintptr | ~float32 | ~float64 | ~string +} + // Constants representing various data types. const ( - KindInvalid Kind = iota + KindUnknown Kind = iota KindBinary + KindBuffer KindBoolean KindError KindInt @@ -38,8 +46,9 @@ const ( ) var types = map[Kind]reflect.Type{ - KindInvalid: reflect.TypeOf((*any)(nil)).Elem(), + KindUnknown: reflect.TypeOf((*any)(nil)).Elem(), KindBinary: reflect.TypeOf([]byte(nil)), + KindBuffer: reflect.TypeOf((*io.Reader)(nil)).Elem(), KindBoolean: reflect.TypeOf(false), KindError: reflect.TypeOf((*error)(nil)).Elem(), KindInt: reflect.TypeOf(0), @@ -62,7 +71,7 @@ var types = map[Kind]reflect.Type{ // KindOf returns the kind of the provided Value. func KindOf(v Value) Kind { if v == nil { - return KindInvalid + return KindUnknown } return v.Kind() } @@ -113,6 +122,59 @@ func Compare(x, y Value) int { return x.Compare(y) } +// Get extracts a value from a nested structure using the provided paths. +func Get[T any](obj Value, paths ...any) (T, bool) { + var val T + cur := obj + for _, path := range paths { + p, err := Marshal(path) + if err != nil { + return val, false + } + + switch p := p.(type) { + case String: + if v, ok := cur.(Map); ok { + child := v.Get(p) + if child == nil { + return val, false + } + cur = child + } + case Integer: + if v, ok := cur.(Slice); ok { + if int(p.Int()) >= v.Len() { + return val, false + } + cur = v.Get(int(p.Int())) + } + default: + return val, false + } + } + + if cur == nil { + return val, false + } + if v, ok := cur.(T); ok { + return v, true + } + return val, Unmarshal(cur, &val) == nil +} + +func compare[T ordered](x, y T) int { + if x == y { + return 0 + } + if x > y { + return 1 + } + if x < y { + return -1 + } + return 0 +} + func unionType(x, y reflect.Type) reflect.Type { if x == nil { return y @@ -121,5 +183,5 @@ func unionType(x, y reflect.Type) reflect.Type { } else if x == y { return x } - return types[KindInvalid] + return types[KindUnknown] }