Skip to content

Commit

Permalink
fix: escape race condition
Browse files Browse the repository at this point in the history
  • Loading branch information
siyul-park committed Nov 5, 2024
1 parent 84fda34 commit a439458
Show file tree
Hide file tree
Showing 8 changed files with 71 additions and 22 deletions.
4 changes: 2 additions & 2 deletions ext/pkg/control/split.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,8 @@ func (n *SplitNode) action(_ *process.Process, inPck *packet.Packet) ([]*packet.
switch inPayload := inPck.Payload().(type) {
case types.Slice:
outPcks := make([]*packet.Packet, 0, inPayload.Len())
for i := 0; i < inPayload.Len(); i++ {
outPck := packet.New(inPayload.Get(i))
for _, v := range inPayload.Range() {
outPck := packet.New(v)
outPcks = append(outPcks, outPck)
}
return outPcks, nil
Expand Down
12 changes: 8 additions & 4 deletions ext/pkg/io/print.go
Original file line number Diff line number Diff line change
Expand Up @@ -92,8 +92,10 @@ func (n *PrintNode) action(_ *process.Process, inPck *packet.Packet) (*packet.Pa
if !ok {
return nil, packet.New(types.NewError(encoding.ErrUnsupportedType))
}
for i := 1; i < payload.Len(); i++ {
args = append(args, types.InterfaceOf(payload.Get(i)))
for i, v := range payload.Range() {
if i > 0 {
args = append(args, types.InterfaceOf(v))
}
}
}

Expand Down Expand Up @@ -124,8 +126,10 @@ func (n *DynPrintNode) action(_ *process.Process, inPck *packet.Packet) (*packet
}

var args []any
for i := 2; i < payload.Len(); i++ {
args = append(args, types.InterfaceOf(payload.Get(i)))
for i, v := range payload.Range() {
if i > 1 {
args = append(args, types.InterfaceOf(v))
}
}

writer, err := n.fs.Open(filename, os.O_WRONLY|os.O_APPEND|os.O_CREATE)
Expand Down
10 changes: 4 additions & 6 deletions ext/pkg/mime/encoding.go
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ func Encode(writer io.Writer, value types.Value, header textproto.MIMEHeader) er
elements = types.NewSlice(value)
}

for _, element := range elements.Values() {
for _, element := range elements.Range() {
h := textproto.MIMEHeader{}
h.Set(HeaderContentDisposition, fmt.Sprintf(`form-data; name="%s"`, quoteEscaper.Replace(key.String())))

Expand All @@ -121,7 +121,7 @@ func Encode(writer io.Writer, value types.Value, header textproto.MIMEHeader) er

writeFields := func(value types.Value) error {
if value, ok := value.(types.Map); ok {
for _, key := range value.Keys() {
for key := range value.Range() {
if err := writeField(value, key); err != nil {
return err
}
Expand All @@ -132,7 +132,7 @@ func Encode(writer io.Writer, value types.Value, header textproto.MIMEHeader) er

writeFiles := func(value types.Value) error {
if value, ok := value.(types.Map); ok {
for _, key := range value.Keys() {
for key := range value.Range() {
if key, ok := key.(types.String); ok {
value := value.GetOr(key, nil)

Expand Down Expand Up @@ -195,9 +195,7 @@ func Encode(writer io.Writer, value types.Value, header textproto.MIMEHeader) er
}

if v, ok := value.(types.Map); ok {
for _, key := range v.Keys() {
value := v.GetOr(key, nil)

for key, value := range v.Range() {
if key.Equal(keyValues) {
if err := writeFields(value); err != nil {
return err
Expand Down
3 changes: 3 additions & 0 deletions pkg/scheme/scheme.go
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,9 @@ func (s *Scheme) AddCodec(kind string, codec Codec) bool {

// RemoveCodec removes the codec associated with a kind.
func (s *Scheme) RemoveCodec(kind string) bool {
s.mu.Lock()
defer s.mu.Unlock()

if _, ok := s.codecs[kind]; ok {
delete(s.codecs, kind)
return true
Expand Down
26 changes: 18 additions & 8 deletions pkg/types/map.go
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,18 @@ func (m Map) Pairs() []Value {
return pairs
}

// Range returns a function that iterates over all key-value pairs in the map.
func (m Map) Range() func(func(key, value Value) bool) {
return func(yield func(key Value, value Value) bool) {
for itr := m.value.Iterator(); !itr.Done(); {
k, v, _ := itr.Next()
if !yield(k, v) {
return
}
}
}
}

// Len returns the number of key-value pairs in the map.
func (m Map) Len() int {
return m.value.Len()
Expand Down Expand Up @@ -234,6 +246,10 @@ func (m *mapProxy) Delete(key Value) {
m.Map = m.Map.Delete(key)
}

func (m *mapProxy) Close() {
m.Map = NewMap()
}

func (*comparer) Compare(x, y Value) int {
return Compare(x, y)
}
Expand Down Expand Up @@ -375,14 +391,7 @@ func newMapDecoder(decoder *encoding.DecodeAssembler[Value, any]) encoding.Decod
t.Set(reflect.MakeMapWithSize(t.Type(), proxy.Len()))
}

for _, key := range proxy.Keys() {
value, ok := proxy.Get(key)
if !ok {
continue
}

proxy.Delete(key)

for key, value := range proxy.Range() {
k := reflect.New(keyType)
v := reflect.New(valueType)

Expand All @@ -394,6 +403,7 @@ func newMapDecoder(decoder *encoding.DecodeAssembler[Value, any]) encoding.Decod
t.SetMapIndex(k.Elem(), v.Elem())
}
}
proxy.Close()
return nil
}), nil
} else if typ.Elem().Kind() == reflect.Struct {
Expand Down
12 changes: 12 additions & 0 deletions pkg/types/map_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,18 @@ func TestMap_Pairs(t *testing.T) {
assert.Contains(t, pairs, v1)
}

func TestMap_Range(t *testing.T) {
k1 := NewString(faker.UUIDHyphenated())
v1 := NewString(faker.UUIDHyphenated())

o := NewMap(k1, v1)

for k, v := range o.Range() {
assert.Equal(t, k1, k)
assert.Equal(t, v1, v)
}
}

func TestMap_Len(t *testing.T) {
k1 := NewString(faker.UUIDHyphenated())
v1 := NewString(faker.UUIDHyphenated())
Expand Down
16 changes: 14 additions & 2 deletions pkg/types/slice.go
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,18 @@ func (s Slice) Values() []Value {
return elements
}

// Range returns a function that iterates over all key-value pairs of the slice.
func (s Slice) Range() func(func(key int, value Value) bool) {
return func(yield func(key int, value Value) bool) {
for itr := s.value.Iterator(); !itr.Done(); {
i, v := itr.Next()
if !yield(i, v) {
return
}
}
}
}

// Len returns the length of the slice.
func (s Slice) Len() int {
return s.value.Len()
Expand Down Expand Up @@ -214,8 +226,8 @@ func newSliceDecoder(decoder *encoding.DecodeAssembler[Value, any]) encoding.Dec
return encoding.DecodeFunc(func(source Value, target unsafe.Pointer) error {
t := reflect.NewAt(typ.Elem(), target).Elem()
if s, ok := source.(Slice); ok {
for i := 0; i < s.Len(); i++ {
if err := setElement(s.Get(i), t, i); err != nil {
for i, v := range s.Range() {
if err := setElement(v, t, i); err != nil {
return err
}
}
Expand Down
10 changes: 10 additions & 0 deletions pkg/types/slice_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,16 @@ func TestSlice_Values(t *testing.T) {
assert.Equal(t, []Value{v1, v2}, o.Values())
}

func TestSlice_Range(t *testing.T) {
v1 := NewString(faker.UUIDHyphenated())

o := NewSlice(v1)

for _, v := range o.Range() {
assert.Equal(t, v1, v)
}
}

func TestSlice_Len(t *testing.T) {
v1 := NewString(faker.UUIDHyphenated())
v2 := NewString(faker.UUIDHyphenated())
Expand Down

0 comments on commit a439458

Please sign in to comment.