From 5a075bef0394f480b2d0cb02569fd9b6ed7bc714 Mon Sep 17 00:00:00 2001 From: siyul-park Date: Sat, 28 Dec 2024 09:06:32 +0900 Subject: [PATCH] fix: more minimal lock --- ext/pkg/io/sql.go | 4 +- ext/pkg/language/javascript/compiler.go | 4 +- ext/pkg/network/listener.go | 3 - ext/pkg/network/router.go | 6 +- pkg/node/manytoone.go | 25 ++-- pkg/node/onetomany.go | 24 ++-- pkg/node/onetoone.go | 15 ++- pkg/packet/reader.go | 64 ++++----- pkg/packet/writer.go | 75 +++++------ pkg/port/inport.go | 9 +- pkg/port/outport.go | 9 +- pkg/process/local.go | 90 +++++++++---- pkg/process/process.go | 6 +- pkg/symbol/cluster.go | 2 +- pkg/types/map.go | 166 +++++++++++++---------- pkg/types/map_test.go | 155 ++++++++++++++++++---- pkg/types/slice.go | 58 ++++---- pkg/types/slice_test.go | 169 +++++++++++++++++------- 18 files changed, 562 insertions(+), 322 deletions(-) diff --git a/ext/pkg/io/sql.go b/ext/pkg/io/sql.go index f0b37cee..dd7db272 100644 --- a/ext/pkg/io/sql.go +++ b/ext/pkg/io/sql.go @@ -109,9 +109,9 @@ func (n *SQLNode) action(proc *process.Process, inPck *packet.Packet) (*packet.P proc.AddExitHook(process.ExitFunc(func(err error) { if err != nil { - tx.Rollback() + _ = tx.Rollback() } else { - tx.Commit() + _ = tx.Commit() } })) diff --git a/ext/pkg/language/javascript/compiler.go b/ext/pkg/language/javascript/compiler.go index 3d5e6bc4..e9cc1c48 100644 --- a/ext/pkg/language/javascript/compiler.go +++ b/ext/pkg/language/javascript/compiler.go @@ -68,7 +68,9 @@ func NewCompiler(options ...api.TransformOptions) language.Compiler { }, } - return language.RunFunc(func(ctx context.Context, args []any) ([]any, error) { + return language.RunFunc(func(ctx context.Context, args []any) (_ []any, err error) { + defer func() { err, _ = recover().(error) }() + vm := vms.Get().(*goja.Runtime) defer vms.Put(vm) diff --git a/ext/pkg/network/listener.go b/ext/pkg/network/listener.go index 61f6f483..9a8a4684 100644 --- a/ext/pkg/network/listener.go +++ b/ext/pkg/network/listener.go @@ -181,9 +181,6 @@ func (n *HTTPListenNode) Shutdown() error { // ServeHTTP handles HTTP requests. func (n *HTTPListenNode) ServeHTTP(w http.ResponseWriter, r *http.Request) { - n.mu.RLock() - defer n.mu.RUnlock() - proc := process.New() proc.Store(KeyHTTPResponseWriter, w) diff --git a/ext/pkg/network/router.go b/ext/pkg/network/router.go index 71e45184..04e58ef4 100644 --- a/ext/pkg/network/router.go +++ b/ext/pkg/network/router.go @@ -170,7 +170,7 @@ func (n *RouteNode) action(_ *process.Process, inPck *packet.Packet) ([]*packet. method, _ := types.Get[string](inPayload, "method") path, _ := types.Get[string](inPayload, "path") - route, paramValues := n.find(method, path) + route, values := n.find(method, path) if route == nil { outPayload, _ := types.Marshal(NewHTTPPayload(http.StatusNotFound)) return nil, packet.New(outPayload) @@ -189,9 +189,9 @@ func (n *RouteNode) action(_ *process.Process, inPck *packet.Packet) ([]*packet. return nil, packet.New(outPayload) } - params := make([]types.Value, 0, len(paramValues)*2) + params := make([]types.Value, 0, len(values)*2) for i, name := range route.paramNames { - params = append(params, types.NewString(name), types.NewString(paramValues[i])) + params = append(params, types.NewString(name), types.NewString(values[i])) } outPayload := inPayload.Set(types.NewString("params"), types.NewMap(params...)) diff --git a/pkg/node/manytoone.go b/pkg/node/manytoone.go index 2b5bbd33..7ef8ab53 100644 --- a/pkg/node/manytoone.go +++ b/pkg/node/manytoone.go @@ -91,15 +91,17 @@ func (n *ManyToOneNode) Close() error { } func (n *ManyToOneNode) forward(index int) port.Listener { - return port.ListenFunc(func(proc *process.Process) { - n.mu.RLock() - defer n.mu.RUnlock() + inPort := n.inPorts[index] - inReader := n.inPorts[index].Open(proc) + return port.ListenFunc(func(proc *process.Process) { + inReader := inPort.Open(proc) var outWriter *packet.Writer var errWriter *packet.Writer readGroup, _ := n.readGroups.LoadOrStore(proc, func() (*packet.ReadGroup, error) { + n.mu.RLock() + defer n.mu.RUnlock() + inReaders := make([]*packet.Reader, len(n.inPorts)) for i, inPort := range n.inPorts { inReaders[i] = inPort.Open(proc) @@ -110,19 +112,20 @@ func (n *ManyToOneNode) forward(index int) port.Listener { for inPck := range inReader.Read() { n.tracer.Read(inReader, inPck) - if outWriter == nil { - outWriter = n.outPort.Open(proc) - } - if errWriter == nil { - errWriter = n.errPort.Open(proc) - } - if inPcks := readGroup.Read(inReader, inPck); len(inPcks) < len(n.inPorts) { n.tracer.Reduce(inPck) } else if outPck, errPck := n.action(proc, inPcks); errPck != nil { + if errWriter == nil { + errWriter = n.errPort.Open(proc) + } + n.tracer.Transform(inPck, errPck) n.tracer.Write(errWriter, errPck) } else if outPck != nil { + if outWriter == nil { + outWriter = n.outPort.Open(proc) + } + n.tracer.Transform(inPck, outPck) n.tracer.Write(outWriter, outPck) } else { diff --git a/pkg/node/onetomany.go b/pkg/node/onetomany.go index 8e4d1cd5..b8a7b27f 100644 --- a/pkg/node/onetomany.go +++ b/pkg/node/onetomany.go @@ -94,22 +94,17 @@ func (n *OneToManyNode) forward(proc *process.Process) { defer n.mu.RUnlock() inReader := n.inPort.Open(proc) - outWriters := make([]*packet.Writer, 0, len(n.outPorts)) + outWriters := make([]*packet.Writer, len(n.outPorts)) var errWriter *packet.Writer for inPck := range inReader.Read() { n.tracer.Read(inReader, inPck) - if len(outWriters) == 0 { - for _, outPort := range n.outPorts { - outWriters = append(outWriters, outPort.Open(proc)) + if outPcks, errPck := n.action(proc, inPck); errPck != nil { + if errWriter == nil { + errWriter = n.errPort.Open(proc) } - } - if errWriter == nil { - errWriter = n.errPort.Open(proc) - } - if outPcks, errPck := n.action(proc, inPck); errPck != nil { n.tracer.Transform(inPck, errPck) n.tracer.Write(errWriter, errPck) } else { @@ -122,6 +117,10 @@ func (n *OneToManyNode) forward(proc *process.Process) { count := 0 for i, outPck := range outPcks { if i < len(outWriters) && outPck != nil { + if outWriters[i] == nil { + outWriters[i] = n.outPorts[i].Open(proc) + } + n.tracer.Write(outWriters[i], outPck) count++ } @@ -135,12 +134,9 @@ func (n *OneToManyNode) forward(proc *process.Process) { } func (n *OneToManyNode) backward(index int) port.Listener { - return port.ListenFunc(func(proc *process.Process) { - n.mu.RLock() - defer n.mu.RUnlock() - - outPort := n.outPorts[index] + outPort := n.outPorts[index] + return port.ListenFunc(func(proc *process.Process) { outWriter := outPort.Open(proc) for backPck := range outWriter.Receive() { diff --git a/pkg/node/onetoone.go b/pkg/node/onetoone.go index a78f933b..5736e74c 100644 --- a/pkg/node/onetoone.go +++ b/pkg/node/onetoone.go @@ -75,17 +75,18 @@ func (n *OneToOneNode) forward(proc *process.Process) { for inPck := range inReader.Read() { n.tracer.Read(inReader, inPck) - if outWriter == nil { - outWriter = n.outPort.Open(proc) - } - if errWriter == nil { - errWriter = n.errPort.Open(proc) - } - if outPck, errPck := n.action(proc, inPck); errPck != nil { + if errWriter == nil { + errWriter = n.errPort.Open(proc) + } + n.tracer.Transform(inPck, errPck) n.tracer.Write(errWriter, errPck) } else { + if outWriter == nil { + outWriter = n.outPort.Open(proc) + } + n.tracer.Transform(inPck, outPck) n.tracer.Write(outWriter, outPck) } diff --git a/pkg/packet/reader.go b/pkg/packet/reader.go index a1a3129a..d316b662 100644 --- a/pkg/packet/reader.go +++ b/pkg/packet/reader.go @@ -11,7 +11,7 @@ type Reader struct { writers []*Writer in chan *Packet out chan *Packet - done chan struct{} + done bool inbounds Hooks outbounds Hooks mu sync.Mutex @@ -20,21 +20,18 @@ type Reader struct { // NewReader creates a new Reader instance and starts its processing loop. func NewReader() *Reader { r := &Reader{ - in: make(chan *Packet), - out: make(chan *Packet), - done: make(chan struct{}), + in: make(chan *Packet), + out: make(chan *Packet), } go func() { defer close(r.out) - defer close(r.in) buffer := make([]*Packet, 0, 2) for { var pck *Packet - select { - case pck = <-r.in: - case <-r.done: + var ok bool + if pck, ok = <-r.in; !ok { return } @@ -42,10 +39,12 @@ func NewReader() *Reader { case r.out <- pck: default: buffer = append(buffer, pck) - for len(buffer) > 0 { select { - case pck = <-r.in: + case pck, ok = <-r.in: + if !ok { + return + } buffer = append(buffer, pck) case r.out <- buffer[0]: buffer = buffer[1:] @@ -63,18 +62,17 @@ func (r *Reader) AddInboundHook(hook Hook) bool { r.mu.Lock() defer r.mu.Unlock() - select { - case <-r.done: + if r.done { return false - default: - for _, h := range r.inbounds { - if h == hook { - return false - } + } + + for _, h := range r.inbounds { + if h == hook { + return false } - r.inbounds = append(r.inbounds, hook) - return true } + r.inbounds = append(r.inbounds, hook) + return true } // AddOutboundHook adds a handler to process outbound packets. @@ -82,18 +80,17 @@ func (r *Reader) AddOutboundHook(hook Hook) bool { r.mu.Lock() defer r.mu.Unlock() - select { - case <-r.done: + if r.done { return false - default: - for _, h := range r.outbounds { - if h == hook { - return false - } + } + + for _, h := range r.outbounds { + if h == hook { + return false } - r.outbounds = append(r.outbounds, hook) - return true } + r.outbounds = append(r.outbounds, hook) + return true } // Read returns the channel for reading packets from the reader. @@ -125,10 +122,8 @@ func (r *Reader) Close() { r.mu.Lock() defer r.mu.Unlock() - select { - case <-r.done: + if r.done { return - default: } pck := New(types.NewError(ErrDroppedPacket)) @@ -137,8 +132,9 @@ func (r *Reader) Close() { go w.receive(pck, r) } - close(r.done) + close(r.in) + r.done = true r.writers = nil r.inbounds = nil r.outbounds = nil @@ -148,10 +144,8 @@ func (r *Reader) write(pck *Packet, writer *Writer) bool { r.mu.Lock() defer r.mu.Unlock() - select { - case <-r.done: + if r.done { return false - default: } r.writers = append(r.writers, writer) diff --git a/pkg/packet/writer.go b/pkg/packet/writer.go index 43b27329..e4ec199a 100644 --- a/pkg/packet/writer.go +++ b/pkg/packet/writer.go @@ -12,7 +12,7 @@ type Writer struct { receives [][]*Packet in chan *Packet out chan *Packet - done chan struct{} + done bool inbounds Hooks outbounds Hooks mu sync.Mutex @@ -34,21 +34,18 @@ func SendOrFallback(writer *Writer, outPck *Packet, backPck *Packet) *Packet { // NewWriter creates a new Writer instance and starts its processing loop. func NewWriter() *Writer { w := &Writer{ - in: make(chan *Packet), - out: make(chan *Packet), - done: make(chan struct{}), + in: make(chan *Packet), + out: make(chan *Packet), } go func() { defer close(w.out) - defer close(w.in) buffer := make([]*Packet, 0, 2) for { var pck *Packet - select { - case pck = <-w.in: - case <-w.done: + var ok bool + if pck, ok = <-w.in; !ok { return } @@ -56,10 +53,12 @@ func NewWriter() *Writer { case w.out <- pck: default: buffer = append(buffer, pck) - for len(buffer) > 0 { select { - case pck = <-w.in: + case pck, ok = <-w.in: + if !ok { + return + } buffer = append(buffer, pck) case w.out <- buffer[0]: buffer = buffer[1:] @@ -77,18 +76,17 @@ func (w *Writer) AddInboundHook(hook Hook) bool { w.mu.Lock() defer w.mu.Unlock() - select { - case <-w.done: + if w.done { return false - default: - for _, h := range w.inbounds { - if h == hook { - return false - } + } + + for _, h := range w.inbounds { + if h == hook { + return false } - w.inbounds = append(w.inbounds, hook) - return true } + w.inbounds = append(w.inbounds, hook) + return true } // AddOutboundHook adds a handler to process outbound packets. @@ -96,18 +94,17 @@ func (w *Writer) AddOutboundHook(hook Hook) bool { w.mu.Lock() defer w.mu.Unlock() - select { - case <-w.done: + if w.done { return false - default: - for _, h := range w.outbounds { - if h == hook { - return false - } + } + + for _, h := range w.outbounds { + if h == hook { + return false } - w.outbounds = append(w.outbounds, hook) - return true } + w.outbounds = append(w.outbounds, hook) + return true } // Link connects a reader to the writer. @@ -115,12 +112,11 @@ func (w *Writer) Link(reader *Reader) { w.mu.Lock() defer w.mu.Unlock() - select { - case <-w.done: + if w.done { return - default: - w.readers = append(w.readers, reader) } + + w.readers = append(w.readers, reader) } // Write writes a packet to all linked readers and returns the count of successful writes. @@ -128,10 +124,8 @@ func (w *Writer) Write(pck *Packet) int { w.mu.Lock() defer w.mu.Unlock() - select { - case <-w.done: + if w.done { return 0 - default: } if len(w.readers) == 0 { @@ -171,10 +165,8 @@ func (w *Writer) Close() { w.mu.Lock() defer w.mu.Unlock() - select { - case <-w.done: + if w.done { return - default: } pck := New(types.NewError(ErrDroppedPacket)) @@ -183,8 +175,9 @@ func (w *Writer) Close() { w.in <- pck } - close(w.done) + close(w.in) + w.done = true w.readers = nil w.receives = nil w.inbounds = nil @@ -195,10 +188,8 @@ func (w *Writer) receive(pck *Packet, reader *Reader) bool { w.mu.Lock() defer w.mu.Unlock() - select { - case <-w.done: + if w.done { return false - default: } index := w.indexOfReader(reader) diff --git a/pkg/port/inport.go b/pkg/port/inport.go index 4cd6a4ff..86e4fa47 100644 --- a/pkg/port/inport.go +++ b/pkg/port/inport.go @@ -95,9 +95,16 @@ func (p *InPort) AddListener(listener Listener) bool { // Open prepares the input port for a given process and returns a reader. func (p *InPort) Open(proc *process.Process) *packet.Reader { + p.mu.RLock() + reader, ok := p.readers[proc] + p.mu.RUnlock() + if ok { + return reader + } + p.mu.Lock() - reader, ok := p.readers[proc] + reader, ok = p.readers[proc] if ok { p.mu.Unlock() return reader diff --git a/pkg/port/outport.go b/pkg/port/outport.go index b195079c..f7d814a0 100644 --- a/pkg/port/outport.go +++ b/pkg/port/outport.go @@ -129,9 +129,16 @@ func (p *OutPort) Unlink(in *InPort) { // Open opens the output port for the given process and returns a writer. func (p *OutPort) Open(proc *process.Process) *packet.Writer { + p.mu.RLock() + writer, ok := p.writers[proc] + p.mu.RUnlock() + if ok { + return writer + } + p.mu.Lock() - writer, ok := p.writers[proc] + writer, ok = p.writers[proc] if ok { p.mu.Unlock() return writer diff --git a/pkg/process/local.go b/pkg/process/local.go index 186b0306..124b6d40 100644 --- a/pkg/process/local.go +++ b/pkg/process/local.go @@ -1,18 +1,31 @@ package process -import "sync" +import ( + "sync" + "sync/atomic" +) -// Local provides a concurrent cache for process-specific data. +// Local provides a concurrent cache for process-specific eager. type Local[T any] struct { - data map[*Process]T + eager map[*Process]T + lazy map[*Process]*lazy[T] storeHooks map[*Process]StoreHooks[T] mu sync.RWMutex } +type lazy[T any] struct { + fn func() (T, error) + value T + error error + done atomic.Uint32 + mu sync.Mutex +} + // NewLocal creates a new Local cache instance. func NewLocal[T any]() *Local[T] { return &Local[T]{ - data: make(map[*Process]T), + eager: make(map[*Process]T), + lazy: make(map[*Process]*lazy[T]), storeHooks: make(map[*Process]StoreHooks[T]), } } @@ -22,7 +35,7 @@ func (l *Local[T]) AddStoreHook(proc *Process, hook StoreHook[T]) bool { l.mu.Lock() defer l.mu.Unlock() - if val, ok := l.data[proc]; ok { + if val, ok := l.eager[proc]; ok { l.mu.Unlock() defer l.mu.Lock() @@ -63,8 +76,8 @@ func (l *Local[T]) Keys() []*Process { l.mu.RLock() defer l.mu.RUnlock() - keys := make([]*Process, 0, len(l.data)) - for proc := range l.data { + keys := make([]*Process, 0, len(l.eager)) + for proc := range l.eager { keys = append(keys, proc) } return keys @@ -75,7 +88,7 @@ func (l *Local[T]) Load(proc *Process) (T, bool) { l.mu.RLock() defer l.mu.RUnlock() - val, ok := l.data[proc] + val, ok := l.eager[proc] return val, ok } @@ -83,9 +96,9 @@ func (l *Local[T]) Load(proc *Process) (T, bool) { func (l *Local[T]) Store(proc *Process, val T) { l.mu.Lock() - _, ok := l.data[proc] + _, ok := l.eager[proc] - l.data[proc] = val + l.eager[proc] = val if !ok { proc.AddExitHook(ExitFunc(func(err error) { l.Delete(proc) @@ -100,14 +113,14 @@ func (l *Local[T]) Store(proc *Process, val T) { storeHooks.Store(val) } -// Delete removes the process and its data from the cache. +// Delete removes the process and its eager from the cache. func (l *Local[T]) Delete(proc *Process) bool { l.mu.Lock() defer l.mu.Unlock() - _, ok := l.data[proc] + _, ok := l.eager[proc] - delete(l.data, proc) + delete(l.eager, proc) delete(l.storeHooks, proc) return ok @@ -115,40 +128,69 @@ func (l *Local[T]) Delete(proc *Process) bool { // LoadOrStore retrieves or stores a value for the given process. func (l *Local[T]) LoadOrStore(proc *Process, val func() (T, error)) (T, error) { + l.mu.RLock() + v, ok := l.eager[proc] + l.mu.RUnlock() + if ok { + return v, nil + } + l.mu.Lock() - defer l.mu.Unlock() - if v, ok := l.data[proc]; ok { + if v, ok := l.eager[proc]; ok { + l.mu.Unlock() return v, nil } - v, err := val() + fn, ok := l.lazy[proc] + if !ok { + fn = &lazy[T]{fn: val} + l.lazy[proc] = fn + } + + l.mu.Unlock() + + v, err := fn.Do() if err != nil { return v, err } - l.data[proc] = v - proc.AddExitHook(ExitFunc(func(err error) { - l.Delete(proc) - })) + l.mu.Lock() + + l.eager[proc] = v + delete(l.lazy, proc) storeHooks := l.storeHooks[proc] delete(l.storeHooks, proc) l.mu.Unlock() - storeHooks.Store(v) + proc.AddExitHook(ExitFunc(func(err error) { + l.Delete(proc) + })) - l.mu.Lock() + storeHooks.Store(v) return v, nil } -// Close clears all cached data and hooks. +// Close clears all cached eager and hooks. func (l *Local[T]) Close() { l.mu.Lock() defer l.mu.Unlock() - l.data = make(map[*Process]T) + l.eager = make(map[*Process]T) + l.lazy = make(map[*Process]*lazy[T]) l.storeHooks = make(map[*Process]StoreHooks[T]) } + +func (o *lazy[T]) Do() (T, error) { + o.mu.Lock() + defer o.mu.Unlock() + + if o.done.Load() == 0 { + defer o.done.Store(1) + o.value, o.error = o.fn() + } + return o.value, o.error +} diff --git a/pkg/process/process.go b/pkg/process/process.go index b5fcfa5d..f29825c8 100644 --- a/pkg/process/process.go +++ b/pkg/process/process.go @@ -7,7 +7,7 @@ import ( "github.com/gofrs/uuid" ) -// Process represents a unit of execution with data, status, and lifecycle management. +// Process represents a unit of execution with eager, status, and lifecycle management. type Process struct { parent *Process id uuid.UUID @@ -44,7 +44,7 @@ func (p *Process) ID() uuid.UUID { return p.id } -// Keys returns all data keys in the process. +// Keys returns all eager keys in the process. func (p *Process) Keys() []string { p.mu.RLock() defer p.mu.RUnlock() @@ -137,7 +137,7 @@ func (p *Process) Join() { p.wait.Wait() } -// Fork creates a child process with inherited data and context. +// Fork creates a child process with inherited eager and context. func (p *Process) Fork() *Process { p.wait.Add(1) diff --git a/pkg/symbol/cluster.go b/pkg/symbol/cluster.go index b0fb9c8c..5c35b871 100644 --- a/pkg/symbol/cluster.go +++ b/pkg/symbol/cluster.go @@ -1,10 +1,10 @@ package symbol import ( - "github.com/siyul-park/uniflow/pkg/packet" "sync" "github.com/siyul-park/uniflow/pkg/node" + "github.com/siyul-park/uniflow/pkg/packet" "github.com/siyul-park/uniflow/pkg/port" "github.com/siyul-park/uniflow/pkg/process" "github.com/siyul-park/uniflow/pkg/spec" diff --git a/pkg/types/map.go b/pkg/types/map.go index 4c74f5b2..cee31f23 100644 --- a/pkg/types/map.go +++ b/pkg/types/map.go @@ -18,6 +18,8 @@ import ( type Map interface { Value + // Has checks if the map contains the specified key. + Has(key Value) bool // Get retrieves the value associated with the given key. Get(key Value) Value // Set adds or updates a key-value pair in the map. @@ -72,41 +74,46 @@ var _ Map = (*mutableMap)(nil) // NewMap creates a new Map with key-value pairs. func NewMap(pairs ...Value) Map { - value := make(map[uint64][][2]Value, len(pairs)/2) + m := &mutableMap{value: make(map[uint64][][2]Value, len(pairs)/2)} for i := 0; i < len(pairs)/2; i++ { k, v := pairs[i*2], pairs[i*2+1] - - hash := HashOf(k) - exists := false - if elements, ok := value[hash]; ok { - for _, pair := range elements { - if Equal(pair[0], k) { - pair[1] = v - exists = true - break - } + m.Set(k, v) + } + return m.Immutable() +} + +// Has checks if the map contains the specified key. +func (m *immutableMap) Has(key Value) bool { + if bucket, ok := m.value[HashOf(key)]; ok { + low, high := 0, len(bucket)-1 + for low <= high { + mid := low + (high-low)/2 + diff := Compare(bucket[mid][0], key) + if diff == 0 { + return true + } else if diff < 0 { + low = mid + 1 + } else { + high = mid - 1 } } - if !exists { - value[hash] = append(value[hash], [2]Value{k, v}) - } } - - for _, elements := range value { - sort.Slice(elements, func(i, j int) bool { - return Compare(elements[i][0], elements[j][0]) < 0 - }) - } - - return &immutableMap{value: value} + return false } // Get retrieves the value associated with the given key. func (m *immutableMap) Get(key Value) Value { - if elements, ok := m.value[HashOf(key)]; ok { - for _, pair := range elements { - if Equal(pair[0], key) { - return pair[1] + if bucket, ok := m.value[HashOf(key)]; ok { + low, high := 0, len(bucket)-1 + for low <= high { + mid := low + (high-low)/2 + diff := Compare(bucket[mid][0], key) + if diff == 0 { + return bucket[mid][1] + } else if diff < 0 { + low = mid + 1 + } else { + high = mid - 1 } } } @@ -115,11 +122,17 @@ func (m *immutableMap) Get(key Value) Value { // Set adds or updates a key-value pair in the map. func (m *immutableMap) Set(key, val Value) Map { + if Equal(m.Get(key), val) { + return m + } return m.Mutable().Set(key, val).Immutable() } // Delete removes a key-value pair from the map by key. func (m *immutableMap) Delete(key Value) Map { + if !m.Has(key) { + return m + } return m.Mutable().Delete(key).Immutable() } @@ -131,8 +144,8 @@ func (m *immutableMap) Clear() Map { // Keys returns all keys in the map. func (m *immutableMap) Keys() []Value { keys := make([]Value, 0, len(m.value)) - for _, elements := range m.value { - for _, pair := range elements { + for _, bucket := range m.value { + for _, pair := range bucket { keys = append(keys, pair[0]) } } @@ -142,8 +155,8 @@ func (m *immutableMap) Keys() []Value { // Values returns all values in the map. func (m *immutableMap) Values() []Value { values := make([]Value, 0, len(m.value)) - for _, elements := range m.value { - for _, pair := range elements { + for _, bucket := range m.value { + for _, pair := range bucket { values = append(values, pair[1]) } } @@ -153,8 +166,8 @@ func (m *immutableMap) Values() []Value { // Pairs returns all key-value pairs in the map. func (m *immutableMap) Pairs() []Value { pairs := make([]Value, 0, len(m.value)*2) - for _, elements := range m.value { - for _, pair := range elements { + for _, bucket := range m.value { + for _, pair := range bucket { pairs = append(pairs, pair[0], pair[1]) } } @@ -164,8 +177,8 @@ func (m *immutableMap) Pairs() []Value { // Len returns the number of key-value pairs in the map. func (m *immutableMap) Len() int { length := 0 - for _, elements := range m.value { - length += len(elements) + for _, bucket := range m.value { + length += len(bucket) } return length } @@ -200,8 +213,8 @@ func (m *immutableMap) Immutable() Map { // Mutable returns a mutable version of the map. func (m *immutableMap) Mutable() Map { value := make(map[uint64][][2]Value, len(m.value)) - for hash, elements := range m.value { - value[hash] = elements + for hash, bucket := range m.value { + value[hash] = bucket } return &mutableMap{value: value} } @@ -213,8 +226,8 @@ func (m *immutableMap) Map() map[any]any { } values := make(map[any]any, len(m.value)) - for _, elements := range m.value { - for _, pair := range elements { + for _, bucket := range m.value { + for _, pair := range bucket { k, v := pair[0], pair[1] values[InterfaceOf(k)] = InterfaceOf(v) } @@ -269,8 +282,8 @@ func (m *immutableMap) Interface() any { values := make([]any, 0, len(m.value)) hashable := true - for _, elements := range m.value { - for _, pair := range elements { + for _, bucket := range m.value { + for _, pair := range bucket { k, v := pair[0], pair[1] keys = append(keys, InterfaceOf(k)) @@ -366,6 +379,11 @@ func (m *immutableMap) Compare(other Value) int { return compare(m.Kind(), KindOf(other)) } +// Has checks if the map contains the specified key. +func (m *mutableMap) Has(key Value) bool { + return m.Immutable().Has(key) +} + // Get retrieves the value associated with the given key. func (m *mutableMap) Get(key Value) Value { return m.Immutable().Get(key) @@ -374,44 +392,57 @@ func (m *mutableMap) Get(key Value) Value { // Set adds or updates a key-value pair in the map. func (m *mutableMap) Set(key, val Value) Map { hash := HashOf(key) - exists := false - if elements, ok := m.value[hash]; ok { - modify := make([][2]Value, len(elements)) - copy(modify, elements) - - for _, pair := range modify { - if Equal(pair[0], key) { - pair[1] = val - exists = true - break - } + bucket := m.value[hash] + + modify := make([][2]Value, len(bucket)) + copy(modify, bucket) + + ok := false + + low, high := 0, len(modify)-1 + for low <= high { + mid := low + (high-low)/2 + diff := Compare(modify[mid][0], key) + if diff == 0 { + modify[mid][1] = val + ok = true + break + } else if diff < 0 { + low = mid + 1 + } else { + high = mid - 1 } - - m.value[hash] = modify - } - if !exists { - m.value[hash] = append(m.value[hash], [2]Value{key, val}) } - elements := m.value[hash] - sort.Slice(elements, func(i, j int) bool { - return Compare(elements[i][0], elements[j][0]) < 0 - }) + if !ok { + idx := low + modify = append(modify, [2]Value{}) + copy(modify[idx+1:], modify[idx:]) + modify[idx] = [2]Value{key, val} + } + m.value[hash] = modify return m } // Delete removes a key-value pair from the map by key. func (m *mutableMap) Delete(key Value) Map { hash := HashOf(key) - if elements, ok := m.value[hash]; ok { - modify := make([][2]Value, len(elements)) - copy(modify, elements) - - for i, pair := range modify { - if Equal(pair[0], key) { - modify = append(modify[:i], modify[i+1:]...) + if bucket, ok := m.value[hash]; ok { + modify := make([][2]Value, len(bucket)) + copy(modify, bucket) + + low, high := 0, len(modify)-1 + for low <= high { + mid := low + (high-low)/2 + diff := Compare(modify[mid][0], key) + if diff == 0 { + modify = append(modify[:mid], modify[mid+1:]...) break + } else if diff < 0 { + low = mid + 1 + } else { + high = mid - 1 } } @@ -421,7 +452,6 @@ func (m *mutableMap) Delete(key Value) Map { delete(m.value, hash) } } - return m } diff --git a/pkg/types/map_test.go b/pkg/types/map_test.go index 280ce13b..eabd1aa5 100644 --- a/pkg/types/map_test.go +++ b/pkg/types/map_test.go @@ -20,20 +20,47 @@ func TestNewMap(t *testing.T) { assert.Equal(t, map[any]any{k1.Interface(): v1.Interface()}, o.Map()) } -func TestMap_GetAndSetAndDelete(t *testing.T) { +func TestMap_Has(t *testing.T) { + k1 := NewString(faker.UUIDHyphenated()) + v1 := NewString(faker.UUIDHyphenated()) + + o := NewMap(k1, v1) + + ok := o.Has(k1) + assert.True(t, ok) +} + +func TestMap_Get(t *testing.T) { + k1 := NewString(faker.UUIDHyphenated()) + v1 := NewString(faker.UUIDHyphenated()) + + o := NewMap(k1, v1) + + r := o.Get(k1) + assert.Equal(t, v1, r) +} + +func TestMap_Set(t *testing.T) { k1 := NewString(faker.UUIDHyphenated()) v1 := NewString(faker.UUIDHyphenated()) o := NewMap() o = o.Set(k1, v1) - r1 := o.Get(k1) - assert.Equal(t, v1, r1) + r := o.Get(k1) + assert.Equal(t, v1, r) +} + +func TestMap_Delete(t *testing.T) { + k1 := NewString(faker.UUIDHyphenated()) + v1 := NewString(faker.UUIDHyphenated()) + + o := NewMap(k1, v1) o = o.Delete(k1) - r2 := o.Get(k1) - assert.Nil(t, r2) + ok := o.Has(k1) + assert.False(t, ok) } func TestMap_Keys(t *testing.T) { @@ -244,23 +271,62 @@ func TestMap_Decode(t *testing.T) { }) } +func BenchmarkMap_Has(b *testing.B) { + key := NewString(faker.UUIDHyphenated()) + value := NewString(faker.UUIDHyphenated()) + + m := NewMap(key, value) + for i := 0; i < 100; i++ { + m = m.Set(NewString(faker.UUIDHyphenated()), NewString(faker.UUIDHyphenated())) + } + + b.ResetTimer() + for i := 0; i < b.N; i++ { + m.Has(key) + } +} + func BenchmarkMap_Set(b *testing.B) { m := NewMap() + for i := 0; i < 100; i++ { + m = m.Set(NewString(faker.UUIDHyphenated()), NewString(faker.UUIDHyphenated())) + } + + key := NewString(faker.UUIDHyphenated()) + value := NewString(faker.UUIDHyphenated()) for i := 0; i < b.N; i++ { - m = m.Set(NewString(faker.UUIDHyphenated()), NewString(faker.UUIDHyphenated())) + m.Set(key, value) } } func BenchmarkMap_Get(b *testing.B) { - m := NewMap() - for i := 0; i < 1000; i++ { + key := NewString(faker.UUIDHyphenated()) + value := NewString(faker.UUIDHyphenated()) + + m := NewMap(key, value) + for i := 0; i < 100; i++ { m = m.Set(NewString(faker.UUIDHyphenated()), NewString(faker.UUIDHyphenated())) } b.ResetTimer() for i := 0; i < b.N; i++ { - _ = m.Get(NewString(faker.UUIDHyphenated())) + m.Get(key) + } +} + +func BenchmarkMap_Delete(b *testing.B) { + key := NewString(faker.UUIDHyphenated()) + value := NewString(faker.UUIDHyphenated()) + + m := NewMap(key, value) + for i := 0; i < 100; i++ { + m = m.Set(NewString(faker.UUIDHyphenated()), NewString(faker.UUIDHyphenated())) + } + + b.ResetTimer() + for i := 0; i < b.N; i++ { + m.Delete(key) } } @@ -337,30 +403,61 @@ func BenchmarkMap_Decode(b *testing.B) { dec.Add(newStringDecoder()) dec.Add(newMapDecoder(dec)) - b.Run("map", func(b *testing.B) { - v := NewMap( - NewString("foo"), NewString("foo"), - NewString("bar"), NewString("bar"), - ) + b.Run("static", func(b *testing.B) { + b.Run("map", func(b *testing.B) { + v := NewMap( + NewString("foo"), NewString("foo"), + NewString("bar"), NewString("bar"), + ) - for i := 0; i < b.N; i++ { - var decoded map[string]string - _ = dec.Decode(v, &decoded) - } + for i := 0; i < b.N; i++ { + var decoded map[string]string + _ = dec.Decode(v, &decoded) + } + }) + + b.Run("struct", func(b *testing.B) { + v := NewMap( + NewString("foo"), NewString("foo"), + NewString("bar"), NewString("bar"), + ) + + for i := 0; i < b.N; i++ { + var decoded struct { + Foo string `map:"foo"` + Bar string `map:"bar"` + } + _ = dec.Decode(v, &decoded) + } + }) }) - b.Run("struct", func(b *testing.B) { - v := NewMap( - NewString("foo"), NewString("foo"), - NewString("bar"), NewString("bar"), - ) + b.Run("dynamic", func(b *testing.B) { + b.Run("map", func(b *testing.B) { + v := NewMap( + NewString("foo"), NewString("foo"), + NewString("bar"), NewString("bar"), + ) - for i := 0; i < b.N; i++ { - var decoded struct { - Foo string `map:"foo"` - Bar string `map:"bar"` + for i := 0; i < b.N; i++ { + var decoded map[any]any + _ = dec.Decode(v, &decoded) } - _ = dec.Decode(v, &decoded) - } + }) + + b.Run("struct", func(b *testing.B) { + v := NewMap( + NewString("foo"), NewString("foo"), + NewString("bar"), NewString("bar"), + ) + + for i := 0; i < b.N; i++ { + var decoded struct { + Foo any `map:"foo"` + Bar any `map:"bar"` + } + _ = dec.Decode(v, &decoded) + } + }) }) } diff --git a/pkg/types/slice.go b/pkg/types/slice.go index 1cf80620..6316af92 100644 --- a/pkg/types/slice.go +++ b/pkg/types/slice.go @@ -28,17 +28,17 @@ func NewSlice(elements ...Value) Slice { } // Prepend adds a value to the beginning of the slice. -func (s Slice) Prepend(value Value) Slice { - return &_slice{value: append([]Value{value}, s.value...)} +func (s Slice) Prepend(elements ...Value) Slice { + return &_slice{value: append(elements, s.value...)} } // Append adds a value to the end of the slice. -func (s Slice) Append(value Value) Slice { - elements := make([]Value, len(s.value), len(s.value)+1) - copy(elements, s.value) - elements = append(elements, value) +func (s Slice) Append(elements ...Value) Slice { + value := make([]Value, len(s.value), len(s.value)+len(elements)) + copy(value, s.value) + value = append(value, elements...) - return &_slice{value: elements} + return &_slice{value: value} } // Sub returns a new slice that is a sub-slice of the original slice. @@ -210,7 +210,6 @@ func newSliceEncoder(encoder *encoding.EncodeAssembler[any, Value]) encoding.Enc values := make([]Value, 0, s.Len()) for i := 0; i < s.Len(); i++ { v := s.Index(i) - if value, err := valueEncoder.Encode(v.Interface()); err != nil { return nil, err } else { @@ -225,38 +224,41 @@ func newSliceEncoder(encoder *encoding.EncodeAssembler[any, Value]) encoding.Enc } func newSliceDecoder(decoder *encoding.DecodeAssembler[Value, any]) encoding.DecodeCompiler[Value] { - setElement := func(source Value, target reflect.Value, i int) error { - v := reflect.New(target.Type().Elem()) - if err := decoder.Decode(source, v.Interface()); err != nil { - return err - } - - if target.Len() < i+1 { - if target.Kind() != reflect.Slice { - return errors.WithStack(encoding.ErrUnsupportedValue) - } else { - target.Set(reflect.Append(target, v.Elem()).Convert(target.Type())) - } - } else { - target.Index(i).Set(v.Elem().Convert(target.Type().Elem())) - } - return nil - } - return encoding.DecodeCompilerFunc[Value](func(typ reflect.Type) (encoding.Decoder[Value, unsafe.Pointer], error) { if typ != nil && typ.Kind() == reflect.Pointer { if typ.Elem().Kind() == reflect.Array || typ.Elem().Kind() == reflect.Slice { + valueDecoder, err := decoder.Compile(reflect.PointerTo(typ.Elem().Elem())) + if err != nil { + return nil, err + } + return encoding.DecodeFunc(func(source Value, target unsafe.Pointer) error { t := reflect.NewAt(typ.Elem(), target).Elem() if s, ok := source.(Slice); ok { + for t.Len() < s.Len() { + if t.Kind() != reflect.Slice { + return errors.WithStack(encoding.ErrUnsupportedValue) + } else { + t.Set(reflect.Append(t, reflect.Zero(t.Type().Elem()))) + } + } + for i, v := range s.Range() { - if err := setElement(v, t, i); err != nil { + if err := valueDecoder.Decode(v, t.Index(i).Addr().UnsafePointer()); err != nil { return err } } return nil } - return setElement(source, t, 0) + + if t.Len() == 0 { + if t.Kind() != reflect.Slice { + return errors.WithStack(encoding.ErrUnsupportedValue) + } else { + t.Set(reflect.Append(t, reflect.Zero(t.Type().Elem()))) + } + } + return valueDecoder.Decode(source, t.Index(0).Addr().UnsafePointer()) }), nil } else if typ.Elem().Kind() == reflect.Interface { return encoding.DecodeFunc(func(source Value, target unsafe.Pointer) error { diff --git a/pkg/types/slice_test.go b/pkg/types/slice_test.go index 64acbad3..fe6ae96c 100644 --- a/pkg/types/slice_test.go +++ b/pkg/types/slice_test.go @@ -20,22 +20,28 @@ func TestNewSlice(t *testing.T) { assert.Equal(t, []any{v1.Interface()}, o.Slice()) } -func TestSlice_GetAndSet(t *testing.T) { +func TestSlice_Get(t *testing.T) { v1 := NewString(faker.UUIDHyphenated()) - v2 := NewString(faker.UUIDHyphenated()) o := NewSlice(v1) - r1 := o.Get(0) - assert.Equal(t, v1, r1) + r := o.Get(0) + assert.Equal(t, v1, r) + + r = o.Get(1) + assert.Nil(t, r) +} - r2 := o.Get(1) - assert.Nil(t, r2) +func TestSlice_Set(t *testing.T) { + v1 := NewString(faker.UUIDHyphenated()) + v2 := NewString(faker.UUIDHyphenated()) + + o := NewSlice(v1) o = o.Set(0, v2) - r3 := o.Get(0) - assert.Equal(t, v2, r3) + r := o.Get(0) + assert.Equal(t, v2, r) } func TestSlice_Prepend(t *testing.T) { @@ -164,34 +170,68 @@ func TestSlice_Decode(t *testing.T) { dec.Add(newStringDecoder()) dec.Add(newSliceDecoder(dec)) - t.Run("slice", func(t *testing.T) { - source := []string{"foo", "bar"} - v := NewSlice(NewString("foo"), NewString("bar")) + t.Run("static", func(t *testing.T) { + t.Run("slice", func(t *testing.T) { + source := []string{"foo", "bar"} + v := NewSlice(NewString("foo"), NewString("bar")) - var decoded []string - err := dec.Decode(v, &decoded) - assert.NoError(t, err) - assert.Equal(t, source, decoded) - }) + var decoded []string + err := dec.Decode(v, &decoded) + assert.NoError(t, err) + assert.Equal(t, source, decoded) + }) - t.Run("array", func(t *testing.T) { - source := []string{"foo", "bar"} - v := NewSlice(NewString("foo"), NewString("bar")) + t.Run("array", func(t *testing.T) { + source := []string{"foo", "bar"} + v := NewSlice(NewString("foo"), NewString("bar")) - var decoded [2]string - err := dec.Decode(v, &decoded) - assert.NoError(t, err) - assert.EqualValues(t, source, decoded) + var decoded [2]string + err := dec.Decode(v, &decoded) + assert.NoError(t, err) + assert.EqualValues(t, source, decoded) + }) + + t.Run("element", func(t *testing.T) { + source := []string{"foo"} + v := NewString("foo") + + var decoded []string + err := dec.Decode(v, &decoded) + assert.NoError(t, err) + assert.Equal(t, source, decoded) + }) }) - t.Run("element", func(t *testing.T) { - source := []string{"foo"} - v := NewString("foo") + t.Run("dynamic", func(t *testing.T) { + t.Run("slice", func(t *testing.T) { + source := []any{"foo", "bar"} + v := NewSlice(NewString("foo"), NewString("bar")) + + var decoded []any + err := dec.Decode(v, &decoded) + assert.NoError(t, err) + assert.Equal(t, source, decoded) + }) + + t.Run("array", func(t *testing.T) { + source := []any{"foo", "bar"} + v := NewSlice(NewString("foo"), NewString("bar")) + + var decoded [2]any + err := dec.Decode(v, &decoded) + assert.NoError(t, err) + assert.EqualValues(t, source, decoded) + }) + + t.Run("element", func(t *testing.T) { + source := []any{"foo"} + v := NewString("foo") - var decoded []string - err := dec.Decode(v, &decoded) - assert.NoError(t, err) - assert.Equal(t, source, decoded) + var decoded []any + err := dec.Decode(v, &decoded) + assert.NoError(t, err) + assert.Equal(t, source, decoded) + }) }) } @@ -285,30 +325,61 @@ func BenchmarkSlice_Decode(b *testing.B) { dec.Add(newStringDecoder()) dec.Add(newSliceDecoder(dec)) - b.Run("slice", func(b *testing.B) { - v := NewSlice(NewString("foo"), NewString("bar")) + b.Run("static", func(b *testing.B) { + b.Run("slice", func(b *testing.B) { + v := NewSlice(NewString("foo"), NewString("bar")) - for i := 0; i < b.N; i++ { - var decoded []string - _ = dec.Decode(v, &decoded) - } - }) + for i := 0; i < b.N; i++ { + var decoded []string + _ = dec.Decode(v, &decoded) + } + }) - b.Run("array", func(b *testing.B) { - v := NewSlice(NewString("foo"), NewString("bar")) + b.Run("array", func(b *testing.B) { + v := NewSlice(NewString("foo"), NewString("bar")) - for i := 0; i < b.N; i++ { - var decoded [2]string - _ = dec.Decode(v, &decoded) - } + for i := 0; i < b.N; i++ { + var decoded [2]string + _ = dec.Decode(v, &decoded) + } + }) + + b.Run("element", func(b *testing.B) { + v := NewString("foo") + + for i := 0; i < b.N; i++ { + var decoded []string + _ = dec.Decode(v, &decoded) + } + }) }) - b.Run("element", func(b *testing.B) { - v := NewString("foo") + b.Run("dynamic", func(b *testing.B) { + b.Run("slice", func(b *testing.B) { + v := NewSlice(NewString("foo"), NewString("bar")) - for i := 0; i < b.N; i++ { - var decoded []string - _ = dec.Decode(v, &decoded) - } + for i := 0; i < b.N; i++ { + var decoded []any + _ = dec.Decode(v, &decoded) + } + }) + + b.Run("array", func(b *testing.B) { + v := NewSlice(NewString("foo"), NewString("bar")) + + for i := 0; i < b.N; i++ { + var decoded [2]any + _ = dec.Decode(v, &decoded) + } + }) + + b.Run("element", func(b *testing.B) { + v := NewString("foo") + + for i := 0; i < b.N; i++ { + var decoded []any + _ = dec.Decode(v, &decoded) + } + }) }) }