From 686e0cac2d4f307b59dbb7ae7da5c26c4fdcdf53 Mon Sep 17 00:00:00 2001 From: siyual-park Date: Tue, 5 Nov 2024 13:33:42 +0900 Subject: [PATCH] fix: escape race condition --- ext/go.mod | 4 ++-- ext/go.sum | 8 ++++---- ext/pkg/control/split.go | 4 ++-- ext/pkg/io/print.go | 12 ++++++++---- ext/pkg/mime/encoding.go | 10 ++++------ ext/pkg/network/listener.go | 7 ------- ext/pkg/network/listener_test.go | 3 --- ext/pkg/network/websocket.go | 8 +++++++- ext/pkg/network/websocket_test.go | 17 ++++++++++++++--- pkg/process/local.go | 2 +- pkg/process/process_test.go | 2 +- pkg/scheme/scheme.go | 3 +++ pkg/symbol/table_test.go | 4 ++-- pkg/types/map.go | 26 ++++++++++++++++++-------- pkg/types/map_test.go | 12 ++++++++++++ pkg/types/slice.go | 16 ++++++++++++++-- pkg/types/slice_test.go | 10 ++++++++++ 17 files changed, 102 insertions(+), 46 deletions(-) diff --git a/ext/go.mod b/ext/go.mod index c33ce02d..9b1a21db 100644 --- a/ext/go.mod +++ b/ext/go.mod @@ -34,7 +34,7 @@ require ( golang.org/x/net v0.30.0 golang.org/x/sys v0.26.0 // indirect golang.org/x/text v0.19.0 // indirect - google.golang.org/genproto/googleapis/api v0.0.0-20241021214115-324edc3d5d38 // indirect - google.golang.org/genproto/googleapis/rpc v0.0.0-20241021214115-324edc3d5d38 // indirect + google.golang.org/genproto/googleapis/api v0.0.0-20241104194629-dd2ea8efbc28 // indirect + google.golang.org/genproto/googleapis/rpc v0.0.0-20241104194629-dd2ea8efbc28 // indirect google.golang.org/protobuf v1.35.1 // indirect ) diff --git a/ext/go.sum b/ext/go.sum index 73250205..c3ecc4d4 100644 --- a/ext/go.sum +++ b/ext/go.sum @@ -82,10 +82,10 @@ golang.org/x/sys v0.26.0 h1:KHjCJyddX0LoSTb3J+vWpupP9p0oznkqVk/IfjymZbo= golang.org/x/sys v0.26.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/text v0.19.0 h1:kTxAhCbGbxhK0IwgSKiMO5awPoDQ0RpfiVYBfK860YM= golang.org/x/text v0.19.0/go.mod h1:BuEKDfySbSR4drPmRPG/7iBdf8hvFMuRexcpahXilzY= -google.golang.org/genproto/googleapis/api v0.0.0-20241021214115-324edc3d5d38 h1:2oV8dfuIkM1Ti7DwXc0BJfnwr9csz4TDXI9EmiI+Rbw= -google.golang.org/genproto/googleapis/api v0.0.0-20241021214115-324edc3d5d38/go.mod h1:vuAjtvlwkDKF6L1GQ0SokiRLCGFfeBUXWr/aFFkHACc= -google.golang.org/genproto/googleapis/rpc v0.0.0-20241021214115-324edc3d5d38 h1:zciRKQ4kBpFgpfC5QQCVtnnNAcLIqweL7plyZRQHVpI= -google.golang.org/genproto/googleapis/rpc v0.0.0-20241021214115-324edc3d5d38/go.mod h1:GX3210XPVPUjJbTUbvwI8f2IpZDMZuPJWDzDuebbviI= +google.golang.org/genproto/googleapis/api v0.0.0-20241104194629-dd2ea8efbc28 h1:M0KvPgPmDZHPlbRbaNU1APr28TvwvvdUPlSv7PUvy8g= +google.golang.org/genproto/googleapis/api v0.0.0-20241104194629-dd2ea8efbc28/go.mod h1:dguCy7UOdZhTvLzDyt15+rOrawrpM4q7DD9dQ1P11P4= +google.golang.org/genproto/googleapis/rpc v0.0.0-20241104194629-dd2ea8efbc28 h1:XVhgTWWV3kGQlwJHR3upFWZeTsei6Oks1apkZSeonIE= +google.golang.org/genproto/googleapis/rpc v0.0.0-20241104194629-dd2ea8efbc28/go.mod h1:GX3210XPVPUjJbTUbvwI8f2IpZDMZuPJWDzDuebbviI= google.golang.org/protobuf v1.35.1 h1:m3LfL6/Ca+fqnjnlqQXNpFPABW1UD7mjh8KO2mKFytA= google.golang.org/protobuf v1.35.1/go.mod h1:9fA7Ob0pmnwhb644+1+CVWFRbNajQ6iRojtC/QF5bRE= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= diff --git a/ext/pkg/control/split.go b/ext/pkg/control/split.go index d3bcd9e9..e2ad4a1d 100644 --- a/ext/pkg/control/split.go +++ b/ext/pkg/control/split.go @@ -39,8 +39,8 @@ func (n *SplitNode) action(_ *process.Process, inPck *packet.Packet) ([]*packet. switch inPayload := inPck.Payload().(type) { case types.Slice: outPcks := make([]*packet.Packet, 0, inPayload.Len()) - for i := 0; i < inPayload.Len(); i++ { - outPck := packet.New(inPayload.Get(i)) + for _, v := range inPayload.Range() { + outPck := packet.New(v) outPcks = append(outPcks, outPck) } return outPcks, nil diff --git a/ext/pkg/io/print.go b/ext/pkg/io/print.go index 97e37503..5a6ad558 100644 --- a/ext/pkg/io/print.go +++ b/ext/pkg/io/print.go @@ -92,8 +92,10 @@ func (n *PrintNode) action(_ *process.Process, inPck *packet.Packet) (*packet.Pa if !ok { return nil, packet.New(types.NewError(encoding.ErrUnsupportedType)) } - for i := 1; i < payload.Len(); i++ { - args = append(args, types.InterfaceOf(payload.Get(i))) + for i, v := range payload.Range() { + if i > 0 { + args = append(args, types.InterfaceOf(v)) + } } } @@ -124,8 +126,10 @@ func (n *DynPrintNode) action(_ *process.Process, inPck *packet.Packet) (*packet } var args []any - for i := 2; i < payload.Len(); i++ { - args = append(args, types.InterfaceOf(payload.Get(i))) + for i, v := range payload.Range() { + if i > 1 { + args = append(args, types.InterfaceOf(v)) + } } writer, err := n.fs.Open(filename, os.O_WRONLY|os.O_APPEND|os.O_CREATE) diff --git a/ext/pkg/mime/encoding.go b/ext/pkg/mime/encoding.go index 834ac4fe..2c598672 100644 --- a/ext/pkg/mime/encoding.go +++ b/ext/pkg/mime/encoding.go @@ -105,7 +105,7 @@ func Encode(writer io.Writer, value types.Value, header textproto.MIMEHeader) er elements = types.NewSlice(value) } - for _, element := range elements.Values() { + for _, element := range elements.Range() { h := textproto.MIMEHeader{} h.Set(HeaderContentDisposition, fmt.Sprintf(`form-data; name="%s"`, quoteEscaper.Replace(key.String()))) @@ -121,7 +121,7 @@ func Encode(writer io.Writer, value types.Value, header textproto.MIMEHeader) er writeFields := func(value types.Value) error { if value, ok := value.(types.Map); ok { - for _, key := range value.Keys() { + for key := range value.Range() { if err := writeField(value, key); err != nil { return err } @@ -132,7 +132,7 @@ func Encode(writer io.Writer, value types.Value, header textproto.MIMEHeader) er writeFiles := func(value types.Value) error { if value, ok := value.(types.Map); ok { - for _, key := range value.Keys() { + for key := range value.Range() { if key, ok := key.(types.String); ok { value := value.GetOr(key, nil) @@ -195,9 +195,7 @@ func Encode(writer io.Writer, value types.Value, header textproto.MIMEHeader) er } if v, ok := value.(types.Map); ok { - for _, key := range v.Keys() { - value := v.GetOr(key, nil) - + for key, value := range v.Range() { if key.Equal(keyValues) { if err := writeFields(value); err != nil { return err diff --git a/ext/pkg/network/listener.go b/ext/pkg/network/listener.go index 78ea3975..06ac309c 100644 --- a/ext/pkg/network/listener.go +++ b/ext/pkg/network/listener.go @@ -185,13 +185,6 @@ func (n *HTTPListenNode) ServeHTTP(w http.ResponseWriter, r *http.Request) { defer n.mu.RUnlock() proc := process.New() - ctx := r.Context() - - go func() { - <-ctx.Done() - proc.Join() - proc.Exit(ctx.Err()) - }() proc.Store(KeyHTTPResponseWriter, w) proc.Store(KeyHTTPRequest, r) diff --git a/ext/pkg/network/listener_test.go b/ext/pkg/network/listener_test.go index c038203e..93b2d774 100644 --- a/ext/pkg/network/listener_test.go +++ b/ext/pkg/network/listener_test.go @@ -339,9 +339,6 @@ func BenchmarkHTTPListenNode_ServeHTTP(b *testing.B) { n := NewHTTPListenNode("") defer n.Close() - in := port.NewOut() - in.Link(n.In(node.PortIn)) - out := port.NewIn() n.Out(node.PortOut).Link(out) diff --git a/ext/pkg/network/websocket.go b/ext/pkg/network/websocket.go index 30d223ae..7fda36c4 100644 --- a/ext/pkg/network/websocket.go +++ b/ext/pkg/network/websocket.go @@ -286,8 +286,14 @@ func (n *WebSocketConnNode) connection(proc *process.Process) (*websocket.Conn, conns := make(chan *websocket.Conn) defer close(conns) + done := make(chan struct{}) + defer close(done) + hook := process.StoreFunc(func(conn *websocket.Conn) { - conns <- conn + select { + case conns <- conn: + case <-done: + } }) for p := proc; p != nil; p = p.Parent() { diff --git a/ext/pkg/network/websocket_test.go b/ext/pkg/network/websocket_test.go index 6451975c..7fe52e02 100644 --- a/ext/pkg/network/websocket_test.go +++ b/ext/pkg/network/websocket_test.go @@ -175,7 +175,19 @@ func BenchmarkWebSocketNode_SendAndReceive(b *testing.B) { ioWriter := io.Open(proc) inWriter := in.Open(proc) - outReader := out.Open(proc) + + out.AddListener(port.ListenFunc(func(proc *process.Process) { + outReader := out.Open(proc) + + for { + _, ok := <-outReader.Read() + if !ok { + return + } + + outReader.Receive(packet.None) + } + })) var inPayload types.Value inPck := packet.New(inPayload) @@ -192,8 +204,7 @@ func BenchmarkWebSocketNode_SendAndReceive(b *testing.B) { for i := 0; i < b.N; i++ { inWriter.Write(inPck) - outPck := <-outReader.Read() - outReader.Receive(outPck) + <-inWriter.Receive() } inPayload, _ = types.Marshal(&WebSocketPayload{ diff --git a/pkg/process/local.go b/pkg/process/local.go index 992d4390..bce6a323 100644 --- a/pkg/process/local.go +++ b/pkg/process/local.go @@ -24,8 +24,8 @@ func (l *Local[T]) AddStoreHook(proc *Process, hook StoreHook[T]) bool { if _, ok := l.data[proc]; ok { l.mu.Unlock() + defer l.mu.Lock() hook.Store(l.data[proc]) - l.mu.Lock() return true } diff --git a/pkg/process/process_test.go b/pkg/process/process_test.go index b7e6f99c..7a159ba6 100644 --- a/pkg/process/process_test.go +++ b/pkg/process/process_test.go @@ -9,7 +9,7 @@ import ( "github.com/stretchr/testify/assert" ) -func TestNewProcess(t *testing.T) { +func TestNew(t *testing.T) { proc := New() defer proc.Exit(nil) diff --git a/pkg/scheme/scheme.go b/pkg/scheme/scheme.go index 75012b5f..a9b8100a 100644 --- a/pkg/scheme/scheme.go +++ b/pkg/scheme/scheme.go @@ -92,6 +92,9 @@ func (s *Scheme) AddCodec(kind string, codec Codec) bool { // RemoveCodec removes the codec associated with a kind. func (s *Scheme) RemoveCodec(kind string) bool { + s.mu.Lock() + defer s.mu.Unlock() + if _, ok := s.codecs[kind]; ok { delete(s.codecs, kind) return true diff --git a/pkg/symbol/table_test.go b/pkg/symbol/table_test.go index 88e43c53..bc31e61e 100644 --- a/pkg/symbol/table_test.go +++ b/pkg/symbol/table_test.go @@ -320,7 +320,7 @@ func BenchmarkTable_Insert(b *testing.B) { Namespace: resource.DefaultNamespace, } - sb := &Symbol{Spec: meta} + sb := &Symbol{Spec: meta, Node: node.NewOneToOneNode(nil)} _ = tb.Insert(sb) } } @@ -340,7 +340,7 @@ func BenchmarkTable_Free(b *testing.B) { Namespace: resource.DefaultNamespace, } - sb := &Symbol{Spec: meta} + sb := &Symbol{Spec: meta, Node: node.NewOneToOneNode(nil)} _ = tb.Insert(sb) b.StartTimer() diff --git a/pkg/types/map.go b/pkg/types/map.go index b1519083..0627e262 100644 --- a/pkg/types/map.go +++ b/pkg/types/map.go @@ -103,6 +103,18 @@ func (m Map) Pairs() []Value { return pairs } +// Range returns a function that iterates over all key-value pairs in the map. +func (m Map) Range() func(func(key, value Value) bool) { + return func(yield func(key Value, value Value) bool) { + for itr := m.value.Iterator(); !itr.Done(); { + k, v, _ := itr.Next() + if !yield(k, v) { + return + } + } + } +} + // Len returns the number of key-value pairs in the map. func (m Map) Len() int { return m.value.Len() @@ -234,6 +246,10 @@ func (m *mapProxy) Delete(key Value) { m.Map = m.Map.Delete(key) } +func (m *mapProxy) Close() { + m.Map = NewMap() +} + func (*comparer) Compare(x, y Value) int { return Compare(x, y) } @@ -375,14 +391,7 @@ func newMapDecoder(decoder *encoding.DecodeAssembler[Value, any]) encoding.Decod t.Set(reflect.MakeMapWithSize(t.Type(), proxy.Len())) } - for _, key := range proxy.Keys() { - value, ok := proxy.Get(key) - if !ok { - continue - } - - proxy.Delete(key) - + for key, value := range proxy.Range() { k := reflect.New(keyType) v := reflect.New(valueType) @@ -394,6 +403,7 @@ func newMapDecoder(decoder *encoding.DecodeAssembler[Value, any]) encoding.Decod t.SetMapIndex(k.Elem(), v.Elem()) } } + proxy.Close() return nil }), nil } else if typ.Elem().Kind() == reflect.Struct { diff --git a/pkg/types/map_test.go b/pkg/types/map_test.go index 02406cb6..14e457f3 100644 --- a/pkg/types/map_test.go +++ b/pkg/types/map_test.go @@ -72,6 +72,18 @@ func TestMap_Pairs(t *testing.T) { assert.Contains(t, pairs, v1) } +func TestMap_Range(t *testing.T) { + k1 := NewString(faker.UUIDHyphenated()) + v1 := NewString(faker.UUIDHyphenated()) + + o := NewMap(k1, v1) + + for k, v := range o.Range() { + assert.Equal(t, k1, k) + assert.Equal(t, v1, v) + } +} + func TestMap_Len(t *testing.T) { k1 := NewString(faker.UUIDHyphenated()) v1 := NewString(faker.UUIDHyphenated()) diff --git a/pkg/types/slice.go b/pkg/types/slice.go index 5816effd..dd20ac7e 100644 --- a/pkg/types/slice.go +++ b/pkg/types/slice.go @@ -66,6 +66,18 @@ func (s Slice) Values() []Value { return elements } +// Range returns a function that iterates over all key-value pairs of the slice. +func (s Slice) Range() func(func(key int, value Value) bool) { + return func(yield func(key int, value Value) bool) { + for itr := s.value.Iterator(); !itr.Done(); { + i, v := itr.Next() + if !yield(i, v) { + return + } + } + } +} + // Len returns the length of the slice. func (s Slice) Len() int { return s.value.Len() @@ -214,8 +226,8 @@ func newSliceDecoder(decoder *encoding.DecodeAssembler[Value, any]) encoding.Dec return encoding.DecodeFunc(func(source Value, target unsafe.Pointer) error { t := reflect.NewAt(typ.Elem(), target).Elem() if s, ok := source.(Slice); ok { - for i := 0; i < s.Len(); i++ { - if err := setElement(s.Get(i), t, i); err != nil { + for i, v := range s.Range() { + if err := setElement(v, t, i); err != nil { return err } } diff --git a/pkg/types/slice_test.go b/pkg/types/slice_test.go index 7e794b8a..64acbad3 100644 --- a/pkg/types/slice_test.go +++ b/pkg/types/slice_test.go @@ -75,6 +75,16 @@ func TestSlice_Values(t *testing.T) { assert.Equal(t, []Value{v1, v2}, o.Values()) } +func TestSlice_Range(t *testing.T) { + v1 := NewString(faker.UUIDHyphenated()) + + o := NewSlice(v1) + + for _, v := range o.Range() { + assert.Equal(t, v1, v) + } +} + func TestSlice_Len(t *testing.T) { v1 := NewString(faker.UUIDHyphenated()) v2 := NewString(faker.UUIDHyphenated())