diff --git a/go.mod b/go.mod index 047ac7e..8f3d416 100644 --- a/go.mod +++ b/go.mod @@ -16,7 +16,6 @@ require ( github.com/gofiber/swagger v1.1.1 github.com/google/go-cmp v0.7.0 github.com/google/uuid v1.6.0 - github.com/gorilla/sessions v1.4.0 github.com/gorilla/websocket v1.5.3 github.com/ledongthuc/pdf v0.0.0-20250511090121-5959a4027728 github.com/lib/pq v1.10.9 @@ -58,6 +57,7 @@ require ( github.com/docker/go-units v0.5.0 // indirect github.com/dustin/go-humanize v1.0.1 // indirect github.com/ebitengine/purego v0.8.4 // indirect + github.com/fasthttp/websocket v1.5.3 // indirect github.com/felixge/httpsnoop v1.0.4 // indirect github.com/fsnotify/fsnotify v1.9.0 // indirect github.com/gabriel-vasile/mimetype v1.4.12 // indirect @@ -72,10 +72,10 @@ require ( github.com/go-playground/universal-translator v0.18.1 // indirect github.com/go-viper/mapstructure/v2 v2.4.0 // indirect github.com/gobuffalo/pop/v6 v6.1.1 // indirect + github.com/gofiber/websocket/v2 v2.2.1 // indirect github.com/gogo/protobuf v1.3.2 // indirect github.com/golang/mock v1.6.0 // indirect github.com/golang/protobuf v1.5.4 // indirect - github.com/gorilla/securecookie v1.1.2 // indirect github.com/grpc-ecosystem/grpc-gateway/v2 v2.18.1 // indirect github.com/hashicorp/go-cleanhttp v0.5.2 // indirect github.com/hashicorp/go-retryablehttp v0.7.7 // indirect @@ -118,6 +118,7 @@ require ( github.com/power-devops/perfstat v0.0.0-20210106213030-5aafc221ea8c // indirect github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect github.com/sagikazarmark/locafero v0.11.0 // indirect + github.com/savsgio/gotils v0.0.0-20230208104028-c358bd845dee // indirect github.com/seatgeek/logrus-gelf-formatter v0.0.0-20210414080842-5b05eb8ff761 // indirect github.com/shirou/gopsutil/v4 v4.25.6 // indirect github.com/sirupsen/logrus v1.9.3 // indirect diff --git a/go.sum b/go.sum index bcd298f..1472026 100644 --- a/go.sum +++ b/go.sum @@ -69,6 +69,8 @@ github.com/ebitengine/purego v0.8.4 h1:CF7LEKg5FFOsASUj0+QwaXf8Ht6TlFxg09+S9wz0o github.com/ebitengine/purego v0.8.4/go.mod h1:iIjxzd6CiRiOG0UyXP+V1+jWqUXVjPKLAI0mRfJZTmQ= github.com/evanw/esbuild v0.27.2 h1:3xBEws9y/JosfewXMM2qIyHAi+xRo8hVx475hVkJfNg= github.com/evanw/esbuild v0.27.2/go.mod h1:D2vIQZqV/vIf/VRHtViaUtViZmG7o+kKmlBfVQuRi48= +github.com/fasthttp/websocket v1.5.3 h1:TPpQuLwJYfd4LJPXvHDYPMFWbLjsT91n3GpWtCQtdek= +github.com/fasthttp/websocket v1.5.3/go.mod h1:46gg/UBmTU1kUaTcwQXpUxtRwG2PvIZYeA8oL6vF3Fs= github.com/fatih/color v1.13.0/go.mod h1:kLAiJbzzSOZDVNGyDpeOxJ47H46qBXwg5ILebYFFOfk= github.com/fatih/color v1.16.0 h1:zmkK9Ngbjj+K0yRhTVONQh1p/HknKYSlNT+vZCzyokM= github.com/fatih/color v1.16.0/go.mod h1:fL2Sau1YI5c0pdGEVCbKQbLXB6edEj1ZgiY4NijnWvE= @@ -138,6 +140,8 @@ github.com/gofiber/fiber/v2 v2.52.10 h1:jRHROi2BuNti6NYXmZ6gbNSfT3zj/8c0xy94GOU5 github.com/gofiber/fiber/v2 v2.52.10/go.mod h1:YEcBbO/FB+5M1IZNBP9FO3J9281zgPAreiI1oqg8nDw= github.com/gofiber/swagger v1.1.1 h1:FZVhVQQ9s1ZKLHL/O0loLh49bYB5l1HEAgxDlcTtkRA= github.com/gofiber/swagger v1.1.1/go.mod h1:vtvY/sQAMc/lGTUCg0lqmBL7Ht9O7uzChpbvJeJQINw= +github.com/gofiber/websocket/v2 v2.2.1 h1:C9cjxvloojayOp9AovmpQrk8VqvVnT8Oao3+IUygH7w= +github.com/gofiber/websocket/v2 v2.2.1/go.mod h1:Ao/+nyNnX5u/hIFPuHl28a+NIkrqK7PRimyKaj4JxVU= github.com/gofrs/uuid v4.0.0+incompatible/go.mod h1:b2aQJv3Z4Fp6yNu3cdSllBxTCLRxnplIgP/c0N/04lM= github.com/gofrs/uuid v4.2.0+incompatible/go.mod h1:b2aQJv3Z4Fp6yNu3cdSllBxTCLRxnplIgP/c0N/04lM= github.com/gofrs/uuid v4.3.1+incompatible/go.mod h1:b2aQJv3Z4Fp6yNu3cdSllBxTCLRxnplIgP/c0N/04lM= @@ -151,18 +155,12 @@ github.com/google/go-cmp v0.5.6/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/ github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU= -github.com/google/gofuzz v1.2.0 h1:xRy4A+RhZaiKjJ1bPfwQ8sedCA+YS2YcCHW6ec7JMi0= -github.com/google/gofuzz v1.2.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= github.com/google/pprof v0.0.0-20250317173921-a4b03ec1a45e h1:ijClszYn+mADRFY17kjQEVQ1XRhq2/JR1M3sGqeJoxs= github.com/google/pprof v0.0.0-20250317173921-a4b03ec1a45e/go.mod h1:boTsfXsheKC2y+lKOCMpSfarhxDeIzfZG1jqGcPl3cA= github.com/google/renameio v0.1.0/go.mod h1:KWCgfxg9yswjAJkECMjeO8J8rahYeXnNhOm40UhjYkI= github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/gorilla/css v1.0.0/go.mod h1:Dn721qIggHpt4+EFCcTLTU/vk5ySda2ReITrtgBl60c= -github.com/gorilla/securecookie v1.1.2 h1:YCIWL56dvtr73r6715mJs5ZvhtnY73hBvEF8kXD8ePA= -github.com/gorilla/securecookie v1.1.2/go.mod h1:NfCASbcHqRSY+3a8tlWJwsQap2VX5pwzwo4h3eOamfo= -github.com/gorilla/sessions v1.4.0 h1:kpIYOp/oi6MG/p5PgxApU8srsSw9tuFbt46Lt7auzqQ= -github.com/gorilla/sessions v1.4.0/go.mod h1:FLWm50oby91+hl7p/wRxDth9bWSuk0qVL2emc7lT5ik= github.com/gorilla/websocket v1.5.3 h1:saDtZ6Pbx/0u+bgYQ3q96pZgCzfhKXGPqt7kZ72aNNg= github.com/gorilla/websocket v1.5.3/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= github.com/grpc-ecosystem/grpc-gateway/v2 v2.18.1 h1:6UKoz5ujsI55KNpsJH3UwCq3T8kKbZwNZBNPuTTje8U= @@ -370,6 +368,8 @@ github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQD github.com/sagikazarmark/locafero v0.11.0 h1:1iurJgmM9G3PA/I+wWYIOw/5SyBtxapeHDcg+AAIFXc= github.com/sagikazarmark/locafero v0.11.0/go.mod h1:nVIGvgyzw595SUSUE6tvCp3YYTeHs15MvlmU87WwIik= github.com/satori/go.uuid v1.2.0/go.mod h1:dA0hQrYB0VpLJoorglMZABFdXlWrHn1NEOzdhQKdks0= +github.com/savsgio/gotils v0.0.0-20230208104028-c358bd845dee h1:8Iv5m6xEo1NR1AvpV+7XmhI4r39LGNzwUL4YpMuL5vk= +github.com/savsgio/gotils v0.0.0-20230208104028-c358bd845dee/go.mod h1:qwtSXrKuJh/zsFQ12yEE89xfCrGKK63Rr7ctU/uCo4g= github.com/seatgeek/logrus-gelf-formatter v0.0.0-20210414080842-5b05eb8ff761 h1:0b8DF5kR0PhRoRXDiEEdzrgBc8UqVY4JWLkQJCRsLME= github.com/seatgeek/logrus-gelf-formatter v0.0.0-20210414080842-5b05eb8ff761/go.mod h1:/THDZYi7F/BsVEcYzYPqdcWFQ+1C2InkawTKfLOAnzg= github.com/sergi/go-diff v1.2.0/go.mod h1:STckp+ISIX8hZLjrqAeVduY0gWCT9IjLuqbuNXdaHfM= diff --git a/lib/changeset/changeset.go b/lib/changeset/changeset.go index 20f2b0b..f387924 100644 --- a/lib/changeset/changeset.go +++ b/lib/changeset/changeset.go @@ -134,6 +134,10 @@ func MakeSplice(orig string, start int, ndel int, ins string, attribs *string, p stringAttribs: &emptyStringAttribs, } + if attribs == nil { + attribs = &emptyStringAttribs + } + var equalOps = OpsFromText("=", utils.RuneSlice(orig, 0, start), &keepArgsToUse, nil) var deletedOps = OpsFromText("-", deleted, &keepArgsToUse, nil) var insertedOps = OpsFromText("+", ins, &KeepArgs{ diff --git a/lib/cli/cli.go b/lib/cli/cli.go new file mode 100644 index 0000000..6c32a43 --- /dev/null +++ b/lib/cli/cli.go @@ -0,0 +1,626 @@ +package cli + +import ( + "bytes" + "encoding/json" + "fmt" + "net/http" + "net/url" + "os" + "strconv" + "strings" + "sync" + "unicode/utf8" + + "flag" + "io" + "time" + + "github.com/ether/etherpad-go/lib/apool" + "github.com/ether/etherpad-go/lib/changeset" + "github.com/ether/etherpad-go/lib/models/ws" + "github.com/ether/etherpad-go/lib/utils" + "github.com/gorilla/websocket" + "go.uber.org/zap" +) + +type Pad struct { + logger *zap.SugaredLogger + host string + padId string + apool *apool.APool + baseRev int + atext *apool.AText + conn *websocket.Conn + connWrite sync.Mutex + poolLock sync.RWMutex + events map[string][]func(interface{}) + closeChan chan struct{} + closeOnce sync.Once + inFlight *PadChangeset + outgoing *PadChangeset +} + +func NewPad(host, padId string, conn *websocket.Conn, logger *zap.SugaredLogger) *Pad { + return &Pad{ + logger: logger, + host: host, + padId: padId, + conn: conn, + connWrite: sync.Mutex{}, + events: make(map[string][]func(interface{})), + closeChan: make(chan struct{}), + } +} + +func (p *Pad) On(event string, handler func(interface{})) { + p.events[event] = append(p.events[event], handler) +} + +func (p *Pad) emit(event string, data interface{}) { + for _, handler := range p.events[event] { + go handler(data) + } +} + +func (p *Pad) Close() { + p.closeOnce.Do(func() { + close(p.closeChan) + if p.conn != nil { + _ = p.conn.Close() + } + p.emit("disconnect", nil) + }) +} + +func (p *Pad) Append(text string) { + // Acquire lock while we read/modify shared pad state (atext, apool, baseRev) + p.poolLock.Lock() + if p.atext == nil || p.apool == nil { + p.poolLock.Unlock() + p.logger.Errorf("Pad is not initialized") + return + } + + if len(text) == 0 { + p.logger.Warnf("No text to append - changeset will not be created.") + p.poolLock.Unlock() + return + } + + if text == "\n" && strings.HasSuffix(p.atext.Text, "\n") { + p.logger.Infof("Pad already ends with newline - not appending another.") + p.poolLock.Unlock() + return + } + + if text[len(text)-1] != '\n' { + text += "\n" + } + + start := utf8.RuneCountInString(p.atext.Text) + emptyAttribs := "" + newChangeset, err := changeset.MakeSplice(p.atext.Text, start, 0, text, &emptyAttribs, p.apool) + if err != nil { + p.poolLock.Unlock() + p.logger.Errorf("Error creating changeset: %v", err) + return + } + + // Unpack and repack to ensure canonical form + unpacked, err := changeset.Unpack(newChangeset) + if err != nil { + p.poolLock.Unlock() + p.logger.Errorf("Error unpacking changeset: %v", err) + return + } + newChangeset = changeset.Pack(unpacked.OldLen, unpacked.NewLen, unpacked.Ops, unpacked.CharBank) + + // Validate generated changeset header: oldLen should equal current local text length + if unpacked.OldLen != start { + p.poolLock.Unlock() + p.logger.Errorf("Generated changeset oldLen mismatch: expected %d got %d; not sending", start, unpacked.OldLen) + // emit an error event so callers/tests can react + p.emit("append_error", map[string]interface{}{"error": "oldLen_mismatch", "expected": start, "got": unpacked.OldLen}) + return + } + + newAText, err := changeset.ApplyToAText(newChangeset, *p.atext, *p.apool) + if err != nil { + p.poolLock.Unlock() + p.logger.Errorf("Error applying changeset: %v", err) + return + } + + p.atext = newAText + baseRev := p.baseRev + p.poolLock.Unlock() + + // Queue the changeset for sending + pc := &PadChangeset{changeset: newChangeset, baseRev: baseRev} + p.sendMessage(pc) +} + +type PadState struct { + Host string + Path string + PadId string +} + +func Connect(host string, logger *zap.SugaredLogger) *Pad { + return connect(host, logger) +} + +func connect(host string, logger *zap.SugaredLogger) *Pad { + + padState := PadState{} + + if host == "" { + padState.Host = "http://127.0.0.1:9001" + padState.Path = "/p/test" + padState.PadId = utils.RandomString(10) + } else { + parsedUrl, err := url.Parse(host) + if err != nil { + logger.Warnf("Invalid host URL: %v", err) + os.Exit(1) + } + padState.Host = fmt.Sprintf("%s://%s", parsedUrl.Scheme, parsedUrl.Host) + const padIdParam = "/p/" + indexOfPadId := strings.Index(parsedUrl.Path, padIdParam) + if indexOfPadId == -1 { + padState.Path = "" + padState.PadId = utils.RandomString(10) + } else { + padState.Path = parsedUrl.Path[0:indexOfPadId] + padState.PadId = parsedUrl.Path[indexOfPadId+len(padIdParam):] + } + } + + httpClient := &http.Client{} + fullUrl := fmt.Sprintf("%s%s/p/%s", padState.Host, padState.Path, padState.PadId) + logger.Infof("Getting Pad at %s", fullUrl) + resp, err := httpClient.Get(fullUrl) + if err != nil { + logger.Errorf("Failed to connect to pad at %s: %v", fullUrl, err) + os.Exit(1) + } + if resp.StatusCode != http.StatusOK { + body, _ := io.ReadAll(resp.Body) + logger.Errorf("Failed to connect to pad at %s, status: %s, body: %s", fullUrl, resp.Status, string(body)) + os.Exit(1) + } + defer func() { + _ = resp.Body.Close() + }() + + wsUrl := fmt.Sprintf("%s/%ssocket.io", strings.Replace(padState.Host, "http", "ws", 1), padState.Path) + logger.Infof("Connecting to WebSocket at %s", wsUrl) + connection, resp, err := websocket.DefaultDialer.Dial(wsUrl, nil) + if err != nil { + logger.Errorf("WebSocket connection failed: %v", err) + if resp != nil { + logger.Warnf("Response Status: %s", resp.Status) + } + os.Exit(1) + } + + var authorToken = "t." + utils.RandomString(20) + + pad := NewPad(padState.Host, padState.PadId, connection, logger) + + go func() { + // Recover to avoid crashing the whole process on unexpected panics in the reader loop + defer func() { + if r := recover(); r != nil { + logger.Errorf("panic in recv goroutine: %v", r) + pad.emit("disconnect", r) + _ = connection.Close() + } + pad.Close() + }() + + var ( + newline = []byte{'\n'} + space = []byte{' '} + ) + for { + select { + case <-pad.closeChan: + return + default: + _, message, err := connection.ReadMessage() + if err != nil { + if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseAbnormalClosure) { + logger.Errorf("error: %v", err) + } + pad.emit("disconnect", err) + return + } + message = bytes.TrimSpace(bytes.Replace(message, newline, space, -1)) + logger.Debugf("Received: %s", message) + + var arr []interface{} + isArrayFormat := false + if err := json.Unmarshal(message, &arr); err == nil && len(arr) == 2 { + isArrayFormat = true + } + + if isArrayFormat { + msgType, _ := arr[0].(string) + if msgType != "message" { + continue + + } + msgObj := arr[1] + msgMap, ok := msgObj.(map[string]interface{}) + if !ok { + logger.Errorf("Fehler beim Casten der Nachricht zu map[string]interface{} (Array-Format)") + continue + } + typeStr, _ := msgMap["type"].(string) + switch typeStr { + case "CLIENT_VARS": + data, ok := msgMap["data"].(map[string]interface{}) + if !ok { + logger.Errorf("CLIENT_VARS: Data fehlt oder hat falschen Typ (Array-Format)") + continue + } + collabVars, ok := data["collab_client_vars"].(map[string]interface{}) + if !ok { + logger.Errorf("CLIENT_VARS: collab_client_vars fehlt oder hat falschen Typ (Array-Format)") + continue + } + initText, _ := collabVars["initialAttributedText"].(map[string]interface{}) + atext := apool.AText{ + Text: initText["text"].(string), + Attribs: initText["attribs"].(string), + } + pad.emit("numConnectedUsers", collabVars["numConnectedUsers"]) + apoolMap, _ := collabVars["apool"].(map[string]interface{}) + pool := apool.NewAPool() + if numToAttrib, ok := apoolMap["numToAttrib"].(map[string]interface{}); ok { + for k, v := range numToAttrib { + idx, err := strconv.Atoi(k) + if err != nil { + continue + } + if arr, ok := v.([]interface{}); ok && len(arr) == 2 { + attr := apool.Attribute{ + Key: arr[0].(string), + Value: arr[1].(string), + } + pool.NumToAttrib[idx] = attr + } + } + } + if nextNum, ok := apoolMap["nextNum"].(float64); ok { + pool.NextNum = int(nextNum) + } + // protect setting shared fields + pad.poolLock.Lock() + pad.apool = &pool + if rev, ok := collabVars["rev"].(float64); ok { + pad.baseRev = int(rev) + } + pad.atext = &atext + pad.poolLock.Unlock() + pad.emit("connected", nil) + case "COLLABROOM": + data, ok := msgMap["data"].(map[string]interface{}) + if !ok { + continue + } + if data["type"] == "NEW_CHANGES" { + // Ensure we have initial state + pad.poolLock.RLock() + havePool := pad.apool != nil + haveAText := pad.atext != nil + pad.poolLock.RUnlock() + if !havePool || !haveAText { + logger.Errorf("received NEW_CHANGES but pad.apool or pad.atext is nil - skipping") + continue + } + if newRev, ok := data["newRev"].(float64); ok && int(newRev) <= pad.baseRev { + continue + } + if newRev, ok := data["newRev"].(float64); ok { + if int(newRev)-1 != pad.baseRev { + logger.Errorf("wrong incoming revision :%v/%v", int(newRev), pad.baseRev) + continue + } + } + wireApool := apool.NewAPool() + if apoolMap, ok := data["apool"].(map[string]interface{}); ok { + if numToAttrib, ok := apoolMap["numToAttrib"].(map[string]interface{}); ok { + for k, v := range numToAttrib { + idx, err := strconv.Atoi(k) + if err != nil { + continue + } + if arr, ok := v.([]interface{}); ok && len(arr) == 2 { + attr := apool.Attribute{ + Key: arr[0].(string), + Value: arr[1].(string), + } + wireApool.NumToAttrib[idx] = attr + } + } + } + if nextNum, ok := apoolMap["nextNum"].(float64); ok { + wireApool.NextNum = int(nextNum) + } + } + changesetStr, _ := data["changeset"].(string) + // Re-read pad.apool and pad.atext under read lock to ensure stability + pad.poolLock.RLock() + localPool := pad.apool + localAText := pad.atext + pad.poolLock.RUnlock() + if localPool == nil || localAText == nil { + logger.Errorf("pad.apool or pad.atext became nil while processing NEW_CHANGES - skipping") + continue + } + serverChangeset := changeset.MoveOpsToNewPool(changesetStr, &wireApool, localPool) + server := &PadChangeset{changeset: serverChangeset} + // Validate server changeset header before attempting to apply it + if unpacked, err := changeset.Unpack(server.changeset); err != nil { + logger.Errorf("cannot unpack server changeset: %v - skipping", err) + continue + } else if utf8.RuneCountInString(localAText.Text) != unpacked.OldLen { + logger.Errorf("server changeset oldLen %d does not match local text length %d - skipping", unpacked.OldLen, utf8.RuneCountInString(localAText.Text)) + continue + } + if pad.inFlight != nil { + transformX(pad.inFlight, server, localPool) + } + if pad.outgoing != nil { + transformX(pad.outgoing, server, localPool) + if newRev, ok := data["newRev"].(float64); ok { + pad.outgoing.baseRev = int(newRev) + } + } + atext, err := changeset.ApplyToAText(server.changeset, *localAText, *localPool) + if err != nil { + logger.Errorf("Fehler beim Anwenden des Changesets: %v", err) + continue + } + // write back updated atext under lock + pad.poolLock.Lock() + pad.atext = atext + pad.poolLock.Unlock() + if newRev, ok := data["newRev"].(float64); ok { + pad.baseRev = int(newRev) + } + pad.emit("newContents", atext) + } + if data["type"] == "ACCEPT_COMMIT" { + if newRev, ok := data["newRev"].(float64); ok && int(newRev) <= pad.baseRev { + continue + } + if newRev, ok := data["newRev"].(float64); ok { + if int(newRev)-1 != pad.baseRev { + logger.Errorf("wrong incoming revision :%v/%v", int(newRev), pad.baseRev) + continue + } + pad.baseRev = int(newRev) + pad.inFlight = nil + if pad.outgoing != nil { + pad.outgoing.baseRev = int(newRev) + } + pad.sendMessage(nil) + } + } + } + } + var obj map[string]interface{} + if err := json.Unmarshal(message, &obj); err == nil { + event, _ := obj["event"].(string) + if event == "message" { + data, ok := obj["data"].(map[string]interface{}) + if ok { + typeStr, _ := data["type"].(string) + if typeStr == "CLIENT_READY" { + pad.emit("connected", nil) + } + pad.emit("message", data) + } + } + } + } + } + }() + + if err := connection.WriteJSON(ws.ClientReady{ + Event: "message", + Data: ws.ClientReadyData{ + Component: "pad", + Type: "CLIENT_READY", + PadID: padState.PadId, + Token: authorToken, + UserInfo: ws.ClientReadyUserInfo{ + ColorId: nil, + Name: nil, + }, + }, + }); err != nil { + logger.Errorf("Fehler beim Senden von CLIENT_READY: %v", err) + } + + return pad +} + +func transformX(client, server *PadChangeset, pool *apool.APool) { + if cs, err := changeset.Follow(server.changeset, client.changeset, false, pool); err == nil && cs != nil { + client.changeset = *cs + } + if cs, err := changeset.Follow(client.changeset, server.changeset, true, pool); err == nil && cs != nil { + server.changeset = *cs + } +} + +type PadChangeset struct { + changeset string + baseRev int +} + +func (p *Pad) sendMessage(optMsg *PadChangeset) { + if optMsg != nil { + if p.outgoing != nil { + if optMsg.baseRev != p.outgoing.baseRev { + p.logger.Warnf("Dropping outgoing changeset due to baseRev mismatch") + return + } + tempStr, err := changeset.Compose(p.outgoing.changeset, optMsg.changeset, p.apool) + if err != nil { + p.logger.Errorf("Error composing outgoing changesets: %v", err) + return + } + p.outgoing.changeset = *tempStr + } else { + p.outgoing = optMsg + } + } + if p.inFlight == nil && p.outgoing != nil { + p.inFlight = p.outgoing + p.outgoing = nil + apoolCreated := apool.NewAPool() + changeset.MoveOpsToNewPool(p.inFlight.changeset, p.apool, &apoolCreated) + wirePool := apoolCreated.ToJsonable() + p.logger.Debugf("Sending changeset: %s", p.inFlight.changeset) + msg := ws.UserChange{ + Event: "message", + Data: ws.UserChangeData{ + Type: "COLLABROOM", + Component: "pad", + Data: ws.UserChangeDataData{ + Type: "USER_CHANGES", + BaseRev: p.inFlight.baseRev, + Changeset: p.inFlight.changeset, + Apool: ws.UserChangeDataDataApool{ + NumToAttrib: wirePool.NumToAttribRaw, + NextNum: wirePool.NextNum, + }, + }, + }, + } + p.connWrite.Lock() + defer p.connWrite.Unlock() + _ = p.conn.WriteJSON(msg) + } +} + +func (p *Pad) OnConnected(callback func(padState *Pad)) { + p.On("connected", func(data interface{}) { + callback(p) + }) +} + +func (p *Pad) OnNumConnectedUsers(callback func(count int)) { + p.On("numConnectedUsers", func(data interface{}) { + if count, ok := data.(float64); ok { + callback(int(count)) + } + }) +} + +func (p *Pad) OnDisconnect(callback func(err interface{})) { + p.On("disconnect", func(data interface{}) { + callback(data) + }) +} + +func (p *Pad) OnMessage(callback func(msg map[string]interface{})) { + p.On("message", func(data interface{}) { + if msg, ok := data.(map[string]interface{}); ok { + callback(msg) + } + }) +} + +func (p *Pad) OnNewContents(callback func(atext apool.AText)) { + p.On("newContents", func(data interface{}) { + if atext, ok := data.(*apool.AText); ok { + if atext != nil { + callback(*atext) + } + } else { + println("OnNewContents: invalid data type received") + } + }) +} + +func RunFromCLI(logger *zap.SugaredLogger, args []string) { + host, appendStr, err := parseCLIArgs(args) + if err != nil { + return + } + + if host == "" { + logger.Warnf("No host specified..") + return + } + + if appendStr != "" { + pad := connect(host, logger) + pad.OnConnected(func(_ *Pad) { + logger.Infof("CLI Connected, appending...") + pad.Append(appendStr) + logger.Infof("Appended %q to %s", appendStr, host) + if os.Getenv("GO_TEST_MODE") == "true" { + pad.emit("append_done", nil) + } else { + os.Exit(0) + } + }) + if os.Getenv("GO_TEST_MODE") == "true" { + done := make(chan struct{}) + pad.On("append_done", func(_ interface{}) { + close(done) + }) + select { + case <-done: + pad.Close() + return + case <-time.After(10 * time.Second): + logger.Warnf("Append timeout") + pad.Close() + return + } + } else { + select {} + } + } else { + pad := connect(host, logger) + pad.OnConnected(func(padState *Pad) { + logger.Infof("Connected to %s with padId %s", padState.host, padState.padId) + logger.Debugf("Pad Contents: \n%s", padState.atext.Text) + }) + pad.OnNewContents(func(atext apool.AText) { + logger.Debugf("Pad Contents: \n%s", atext.Text) + }) + + done := make(chan struct{}) + pad.On("disconnect", func(_ interface{}) { + close(done) + }) + <-done + } + + logger.Infof("Stopping CLI") +} + +func parseCLIArgs(args []string) (string, string, error) { + fs := flag.NewFlagSet("cli", flag.ContinueOnError) + host := fs.String("host", "", "The host of the pad (e.g. http://127.0.0.1:9001/p/test)") + appendStr := fs.String("append", "", "Append contents to pad") + fs.StringVar(appendStr, "a", "", "Append contents to pad (shorthand)") + + if len(args) > 0 && !strings.HasPrefix(args[0], "-") { + *host = args[0] + args = args[1:] + } + + err := fs.Parse(args) + return *host, *appendStr, err +} diff --git a/lib/cli/cli_test.go b/lib/cli/cli_test.go new file mode 100644 index 0000000..05adaf3 --- /dev/null +++ b/lib/cli/cli_test.go @@ -0,0 +1,55 @@ +package cli + +import ( + "testing" +) + +func TestParseCLIArgs(t *testing.T) { + tests := []struct { + name string + args []string + wantHost string + wantAppend string + }{ + { + name: "no arguments", + args: []string{}, + wantHost: "", + wantAppend: "", + }, + { + name: "positional host", + args: []string{"http://test.com"}, + wantHost: "http://test.com", + wantAppend: "", + }, + { + name: "explicit flags", + args: []string{"-host", "http://test.com", "-append", "hello"}, + wantHost: "http://test.com", + wantAppend: "hello", + }, + { + name: "shorthand append", + args: []string{"http://test.com", "-a", "world"}, + wantHost: "http://test.com", + wantAppend: "world", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + host, appendStr, err := parseCLIArgs(tt.args) + if err != nil { + t.Errorf("parseCLIArgs() error = %v", err) + return + } + if host != tt.wantHost { + t.Errorf("host = %v, want %v", host, tt.wantHost) + } + if appendStr != tt.wantAppend { + t.Errorf("appendStr = %v, want %v", appendStr, tt.wantAppend) + } + }) + } +} diff --git a/lib/db/MySQLDB.go b/lib/db/MySQLDB.go index 907c0df..08ad602 100644 --- a/lib/db/MySQLDB.go +++ b/lib/db/MySQLDB.go @@ -726,10 +726,14 @@ func (d MysqlDB) SetAuthorByToken(token, authorId string) error { * @param {String} author The id of the author */ func (d MysqlDB) GetAuthor(author string) (*db.AuthorDB, error) { - - var resultedSQL, args, err = mysql.Select("*"). + var resultedSQL, args, err = mysql.Select("globalAuthor.*, padRev.id"). From("globalAuthor"). - Where(sq.Eq{"id": author}).ToSql() + LeftJoin("padRev ON globalAuthor.id = padRev.authorId"). + Where(sq.Eq{"globalAuthor.id": author}).ToSql() + + if err != nil { + return nil, err + } query, err := d.sqlDB.Query(resultedSQL, args...) if err != nil { @@ -738,14 +742,37 @@ func (d MysqlDB) GetAuthor(author string) (*db.AuthorDB, error) { defer query.Close() var authorDB *db.AuthorDB + for query.Next() { - var authorCopy db.AuthorDB - query.Scan(&authorCopy.ID, &authorCopy.ColorId, &authorCopy.Name, &authorCopy.Timestamp) - authorDB = &authorCopy - return authorDB, nil + var padID sql.NullString + + if authorDB == nil { + authorDB = &db.AuthorDB{ + PadIDs: make(map[string]struct{}), + } + err = query.Scan(&authorDB.ID, &authorDB.ColorId, &authorDB.Name, + &authorDB.Timestamp, &padID) + if err != nil { + return nil, err + } + } else { + var dummy1, dummy2, dummy3, dummy4 interface{} + err = query.Scan(&dummy1, &dummy2, &dummy3, &dummy4, &padID) + if err != nil { + return nil, err + } + } + + if padID.Valid { + authorDB.PadIDs[padID.String] = struct{}{} + } } - return nil, errors.New(AuthorNotFoundError) + if authorDB == nil { + return nil, errors.New(AuthorNotFoundError) + } + + return authorDB, nil } func (d MysqlDB) GetAuthorByToken(token string) (*string, error) { diff --git a/lib/db/PostgresDB.go b/lib/db/PostgresDB.go index b90daab..ac7a495 100644 --- a/lib/db/PostgresDB.go +++ b/lib/db/PostgresDB.go @@ -733,10 +733,14 @@ func (d PostgresDB) SetAuthorByToken(token, authorId string) error { * @param {String} author The id of the author */ func (d PostgresDB) GetAuthor(author string) (*db.AuthorDB, error) { - - var resultedSQL, args, err = psql.Select("*"). + var resultedSQL, args, err = psql.Select("globalAuthor.*, padRev.id"). From("globalAuthor"). - Where(sq.Eq{"id": author}).ToSql() + LeftJoin("padRev ON globalAuthor.id = padRev.authorId"). + Where(sq.Eq{"globalAuthor.id": author}).ToSql() + + if err != nil { + return nil, err + } query, err := d.sqlDB.Query(resultedSQL, args...) if err != nil { @@ -745,14 +749,37 @@ func (d PostgresDB) GetAuthor(author string) (*db.AuthorDB, error) { defer query.Close() var authorDB *db.AuthorDB + for query.Next() { - var authorCopy db.AuthorDB - query.Scan(&authorCopy.ID, &authorCopy.ColorId, &authorCopy.Name, &authorCopy.Timestamp) - authorDB = &authorCopy - return authorDB, nil + var padID sql.NullString + + if authorDB == nil { + authorDB = &db.AuthorDB{ + PadIDs: make(map[string]struct{}), + } + err = query.Scan(&authorDB.ID, &authorDB.ColorId, &authorDB.Name, + &authorDB.Timestamp, &padID) + if err != nil { + return nil, err + } + } else { + var dummy1, dummy2, dummy3, dummy4 interface{} + err = query.Scan(&dummy1, &dummy2, &dummy3, &dummy4, &padID) + if err != nil { + return nil, err + } + } + + if padID.Valid { + authorDB.PadIDs[padID.String] = struct{}{} + } } - return nil, errors.New(AuthorNotFoundError) + if authorDB == nil { + return nil, errors.New(AuthorNotFoundError) + } + + return authorDB, nil } func (d PostgresDB) GetAuthorByToken(token string) (*string, error) { diff --git a/lib/db/SQLiteDB.go b/lib/db/SQLiteDB.go index 80a6c54..6eb33cf 100644 --- a/lib/db/SQLiteDB.go +++ b/lib/db/SQLiteDB.go @@ -730,25 +730,53 @@ func (d SQLiteDB) SetAuthorByToken(token, authorId string) error { * @param {String} author The id of the author */ func (d SQLiteDB) GetAuthor(author string) (*db.AuthorDB, error) { - - var resultedSQL, args, err = sq.Select("*"). + var resultedSQL, args, err = sq.Select("globalAuthor.*, padRev.id"). From("globalAuthor"). - Where(sq.Eq{"id": author}).ToSql() + LeftJoin("padRev ON globalAuthor.id = padRev.authorId"). + Where(sq.Eq{"globalAuthor.id": author}).ToSql() + + if err != nil { + return nil, err + } query, err := d.sqlDB.Query(resultedSQL, args...) if err != nil { return nil, err } defer query.Close() + + var authorDB *db.AuthorDB + for query.Next() { - var authorDB *db.AuthorDB - var authorCopy db.AuthorDB - query.Scan(&authorCopy.ID, &authorCopy.ColorId, &authorCopy.Name, &authorCopy.Timestamp) - authorDB = &authorCopy - return authorDB, nil + var padID sql.NullString + + if authorDB == nil { + authorDB = &db.AuthorDB{ + PadIDs: make(map[string]struct{}), + } + err = query.Scan(&authorDB.ID, &authorDB.ColorId, &authorDB.Name, + &authorDB.Timestamp, &padID) + if err != nil { + return nil, err + } + } else { + var dummy1, dummy2, dummy3, dummy4 interface{} + err = query.Scan(&dummy1, &dummy2, &dummy3, &dummy4, &padID) + if err != nil { + return nil, err + } + } + + if padID.Valid { + authorDB.PadIDs[padID.String] = struct{}{} + } } - return nil, errors.New(AuthorNotFoundError) + if authorDB == nil { + return nil, errors.New(AuthorNotFoundError) + } + + return authorDB, nil } func (d SQLiteDB) GetAuthorByToken(token string) (*string, error) { diff --git a/lib/loadtest/app.go b/lib/loadtest/app.go new file mode 100644 index 0000000..aa6e65e --- /dev/null +++ b/lib/loadtest/app.go @@ -0,0 +1,296 @@ +package loadtest + +import ( + "fmt" + "math/rand" + "net/url" + "os" + "strings" + "sync" + "sync/atomic" + "time" + + "flag" + + "github.com/ether/etherpad-go/lib/apool" + "github.com/ether/etherpad-go/lib/cli" + "github.com/ether/etherpad-go/lib/utils" + "go.uber.org/zap" +) + +func RunFromCLI(logger *zap.SugaredLogger, args []string) { + host, authors, lurkers, duration, untilFail, err := parseRunArgs(args) + if err != nil { + return + } + StartLoadTest(logger, host, authors, lurkers, duration, untilFail) +} + +func parseRunArgs(args []string) (string, int, int, int, bool, error) { + fs := flag.NewFlagSet("loadtest", flag.ContinueOnError) + host := fs.String("host", "http://127.0.0.1:9001", "The host to test") + authors := fs.Int("authors", 0, "Number of authors") + lurkers := fs.Int("lurkers", 0, "Number of lurkers") + duration := fs.Int("duration", 0, "Duration of the test in seconds") + untilFail := fs.Bool("loadUntilFail", false, "Load until the server fails") + + if len(args) > 0 && !strings.HasPrefix(args[0], "-") { + *host = args[0] + args = args[1:] + } + + err := fs.Parse(args) + return *host, *authors, *lurkers, *duration, *untilFail, err +} + +func RunMultiFromCLI(logger *zap.SugaredLogger, args []string) { + host, maxPads, err := parseMultiRunArgs(args) + if err != nil { + return + } + StartMultiLoadTest(logger, host, maxPads) +} + +func parseMultiRunArgs(args []string) (string, int, error) { + fs := flag.NewFlagSet("multiload", flag.ContinueOnError) + host := fs.String("host", "http://127.0.0.1:9001", "The host to test") + maxPads := fs.Int("maxPads", 10, "Maximum number of pads") + + if len(args) > 0 && !strings.HasPrefix(args[0], "-") { + *host = args[0] + args = args[1:] + } + + err := fs.Parse(args) + return *host, *maxPads, err +} + +type Metrics struct { + ClientsConnected int64 + AuthorsConnected int64 + LurkersConnected int64 + AppendSent int64 + ErrorCount int64 + AcceptedCommit int64 + ChangeFromServer int64 + NumConnectedUsers int64 // From server + StartTime time.Time +} + +var stats Metrics +var maxPS float64 +var statsLock sync.Mutex + +func randomPadName() string { + const chars = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz" + const strLen = 10 + var b strings.Builder + for i := 0; i < strLen; i++ { + b.WriteByte(chars[rand.Intn(len(chars))]) + } + return b.String() +} + +func updateMetricsUI(host string) { + if os.Getenv("SILENT_METRICS") == "true" { + return + } + statsLock.Lock() + defer statsLock.Unlock() + + testDuration := time.Since(stats.StartTime) + + // Clear screen and move cursor to top-left + fmt.Print("\033[2J\033[0;0H") + fmt.Printf("Load Test Metrics -- Target Pad %s\n\n", host) + + if atomic.LoadInt64(&stats.NumConnectedUsers) > 0 { + fmt.Printf("Total Clients Connected: %d\n", atomic.LoadInt64(&stats.NumConnectedUsers)) + } + fmt.Printf("Local Clients Connected: %d\n", atomic.LoadInt64(&stats.ClientsConnected)) + fmt.Printf("Authors Connected: %d\n", atomic.LoadInt64(&stats.AuthorsConnected)) + fmt.Printf("Lurkers Connected: %d\n", atomic.LoadInt64(&stats.LurkersConnected)) + fmt.Printf("Sent Append messages: %d\n", atomic.LoadInt64(&stats.AppendSent)) + fmt.Printf("Errors: %d\n", atomic.LoadInt64(&stats.ErrorCount)) + fmt.Printf("Commits accepted by server: %d\n", atomic.LoadInt64(&stats.AcceptedCommit)) + + changesFromServer := atomic.LoadInt64(&stats.ChangeFromServer) + fmt.Printf("Commits sent from Server to Client: %d\n", changesFromServer) + + durationSec := testDuration.Seconds() + if durationSec > 0 { + currentRate := float64(changesFromServer) / durationSec // This is mean rate actually in this simple impl + fmt.Printf("Current rate per second of Commits sent from Server to Client: %.0f\n", currentRate) + fmt.Printf("Mean(per second) of # of Commits sent from Server to Client: %.0f\n", currentRate) + + if currentRate > maxPS { + maxPS = currentRate + } + fmt.Printf("Max(per second) of # of Messages (SocketIO has cap of 10k): %.0f\n", maxPS) + } + + diff := atomic.LoadInt64(&stats.AppendSent) - atomic.LoadInt64(&stats.AcceptedCommit) + if diff > 5 { + fmt.Printf("Number of commits not yet replied as ACCEPT_COMMIT from server: %d\n", diff) + } + + fmt.Printf("Seconds test has been running for: %d\n", int(durationSec)) +} + +func newAuthor(host string, logger *zap.SugaredLogger) { + pad := cli.Connect(host, logger) + + pad.OnDisconnect(func(err interface{}) { + fmt.Printf("connection error connecting to pad: %v\n", err) + os.Exit(1) + }) + + pad.OnConnected(func(p *cli.Pad) { + atomic.AddInt64(&stats.ClientsConnected, 1) + atomic.AddInt64(&stats.AuthorsConnected, 1) + updateMetricsUI(host) + + ticker := time.NewTicker(400 * time.Millisecond) + go func() { + for range ticker.C { + atomic.AddInt64(&stats.AppendSent, 1) + updateMetricsUI(host) + p.Append(utils.RandomString(10)) + } + }() + }) + + pad.OnNumConnectedUsers(func(count int) { + atomic.StoreInt64(&stats.NumConnectedUsers, int64(count)) + updateMetricsUI(host) + }) + + pad.OnMessage(func(msg map[string]interface{}) { + if msg["type"] == "COLLABROOM" { + if data, ok := msg["data"].(map[string]interface{}); ok { + if data["type"] == "ACCEPT_COMMIT" { + atomic.AddInt64(&stats.AcceptedCommit, 1) + } + } + } + }) + + pad.OnNewContents(func(atext apool.AText) { + atomic.AddInt64(&stats.ChangeFromServer, 1) + }) +} + +func newLurker(host string, logger *zap.SugaredLogger) { + pad := cli.Connect(host, logger) + + pad.OnDisconnect(func(err interface{}) { + fmt.Printf("connection error connecting to pad: %v\n", err) + os.Exit(1) + }) + + pad.OnConnected(func(p *cli.Pad) { + atomic.AddInt64(&stats.ClientsConnected, 1) + atomic.AddInt64(&stats.LurkersConnected, 1) + updateMetricsUI(host) + }) + + pad.OnNumConnectedUsers(func(count int) { + atomic.StoreInt64(&stats.NumConnectedUsers, int64(count)) + updateMetricsUI(host) + }) + + pad.OnNewContents(func(atext apool.AText) { + atomic.AddInt64(&stats.ChangeFromServer, 1) + }) +} + +func StartLoadTest(logger *zap.SugaredLogger, host string, numAuthors, numLurkers int, duration int, loadUntilFail bool) { + stats.StartTime = time.Now() + + if host == "" { + host = "http://127.0.0.1:9001" + } + + if !strings.Contains(host, "/p/") { + host = fmt.Sprintf("%s/p/%s", strings.TrimSuffix(host, "/"), randomPadName()) + } else { + // Ensure it's a valid URL + _, err := url.Parse(host) + if err != nil { + fmt.Printf("Invalid host: %v\n", err) + os.Exit(1) + } + } + + var endTime time.Time + if duration > 0 { + endTime = stats.StartTime.Add(time.Duration(duration) * time.Second) + } + + if numAuthors > 0 || numLurkers > 0 { + var users []string + for i := 0; i < numLurkers; i++ { + users = append(users, "l") + } + for i := 0; i < numAuthors; i++ { + users = append(users, "a") + } + + go func() { + for _, t := range users { + if t == "l" { + newLurker(host, logger) + } else { + newAuthor(host, logger) + } + time.Sleep(200 * time.Millisecond / time.Duration(len(users))) + } + }() + } else { + if duration > 0 { + fmt.Printf("Creating load for %d seconds\n", duration) + } else { + fmt.Println("Creating load until the pad server stops responding in a timely fashion") + } + + go func() { + // Loads at ratio of 3(lurkers):1(author), every 1 second it adds more. + users := []string{"a", "l", "l", "l"} + ticker := time.NewTicker(1 * time.Second) + for range ticker.C { + for _, t := range users { + if t == "l" { + newLurker(host, logger) + } else { + newAuthor(host, logger) + } + time.Sleep(200 * time.Millisecond / time.Duration(len(users))) + } + } + }() + } + + ticker := time.NewTicker(100 * time.Millisecond) + for range ticker.C { + if !endTime.IsZero() && time.Now().After(endTime) { + fmt.Println("Test duration complete and Load Tests PASS") + // Print final stats + fmt.Printf("%+v\n", stats) + if os.Getenv("GO_TEST_MODE") == "true" { + return + } + os.Exit(0) + } + + if loadUntilFail { + diff := atomic.LoadInt64(&stats.AppendSent) - atomic.LoadInt64(&stats.AcceptedCommit) + if diff > 100 { + fmt.Printf("Load test failed: too many pending commits (%d)\n", diff) + if os.Getenv("GO_TEST_MODE") == "true" { + return + } + os.Exit(1) + } + } + } +} diff --git a/lib/loadtest/app_test.go b/lib/loadtest/app_test.go new file mode 100644 index 0000000..fe7cbf1 --- /dev/null +++ b/lib/loadtest/app_test.go @@ -0,0 +1,120 @@ +package loadtest + +import ( + "testing" +) + +func TestParseRunArgs(t *testing.T) { + tests := []struct { + name string + args []string + wantHost string + wantAuthors int + wantLurkers int + wantDuration int + wantUntilFail bool + }{ + { + name: "default values", + args: []string{}, + wantHost: "http://127.0.0.1:9001", + wantAuthors: 0, + wantLurkers: 0, + wantDuration: 0, + }, + { + name: "positional host", + args: []string{"http://test.com"}, + wantHost: "http://test.com", + wantAuthors: 0, + wantLurkers: 0, + wantDuration: 0, + }, + { + name: "explicit flags", + args: []string{"-host", "http://test.com", "-authors", "5", "-lurkers", "10", "-duration", "60", "-loadUntilFail"}, + wantHost: "http://test.com", + wantAuthors: 5, + wantLurkers: 10, + wantDuration: 60, + wantUntilFail: true, + }, + { + name: "positional host and flags", + args: []string{"http://pos.com", "-authors", "3"}, + wantHost: "http://pos.com", + wantAuthors: 3, + wantLurkers: 0, + wantDuration: 0, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + host, authors, lurkers, duration, untilFail, err := parseRunArgs(tt.args) + if err != nil { + t.Errorf("parseRunArgs() error = %v", err) + return + } + if host != tt.wantHost { + t.Errorf("host = %v, want %v", host, tt.wantHost) + } + if authors != tt.wantAuthors { + t.Errorf("authors = %v, want %v", authors, tt.wantAuthors) + } + if lurkers != tt.wantLurkers { + t.Errorf("lurkers = %v, want %v", lurkers, tt.wantLurkers) + } + if duration != tt.wantDuration { + t.Errorf("duration = %v, want %v", duration, tt.wantDuration) + } + if untilFail != tt.wantUntilFail { + t.Errorf("untilFail = %v, want %v", untilFail, tt.wantUntilFail) + } + }) + } +} + +func TestParseMultiRunArgs(t *testing.T) { + tests := []struct { + name string + args []string + wantHost string + wantMaxPads int + }{ + { + name: "default values", + args: []string{}, + wantHost: "http://127.0.0.1:9001", + wantMaxPads: 10, + }, + { + name: "explicit flags", + args: []string{"-host", "http://test.com", "-maxPads", "20"}, + wantHost: "http://test.com", + wantMaxPads: 20, + }, + { + name: "positional host", + args: []string{"http://pos.com", "-maxPads", "5"}, + wantHost: "http://pos.com", + wantMaxPads: 5, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + host, maxPads, err := parseMultiRunArgs(tt.args) + if err != nil { + t.Errorf("parseMultiRunArgs() error = %v", err) + return + } + if host != tt.wantHost { + t.Errorf("host = %v, want %v", host, tt.wantHost) + } + if maxPads != tt.wantMaxPads { + t.Errorf("maxPads = %v, want %v", maxPads, tt.wantMaxPads) + } + }) + } +} diff --git a/lib/loadtest/multi.go b/lib/loadtest/multi.go new file mode 100644 index 0000000..daea556 --- /dev/null +++ b/lib/loadtest/multi.go @@ -0,0 +1,51 @@ +package loadtest + +import ( + "fmt" + "os" + "os/exec" + "sync" + "time" + + "go.uber.org/zap" +) + +func StartMultiLoadTest(logger *zap.SugaredLogger, host string, maxPads int) { + if maxPads <= 0 { + maxPads = 10 + } + + fmt.Printf("Starting multi-pad load test: %d pads for 30 seconds each\n", maxPads) + + executable, err := os.Executable() + if err != nil { + logger.Errorf("Failed to get executable path: %v", err) + os.Exit(1) + } + + var wg sync.WaitGroup + messageCount := 0 + + for i := 0; i < maxPads; i++ { + wg.Add(1) + go func(id int) { + defer wg.Done() + cmd := exec.Command(executable, "loadtest", "-host", host, "-authors", "3", "-duration", "30") + cmd.Env = append(os.Environ(), "SILENT_METRICS=true") + + output, err := cmd.CombinedOutput() + if err != nil { + fmt.Printf("Child process %d exited with error: %v\n", id, err) + fmt.Printf("Output: %s\n", string(output)) + fmt.Println("total pads made:", id) // Approximation + fmt.Println("total messages", messageCount) + os.Exit(1) + } + }(i) + + time.Sleep(100 * time.Millisecond) + } + + wg.Wait() + fmt.Println("Multi-pad load test completed successfully") +} diff --git a/lib/locales/locales.go b/lib/locales/locales.go index cf922e1..17574e5 100644 --- a/lib/locales/locales.go +++ b/lib/locales/locales.go @@ -22,10 +22,17 @@ func Init(initStore *lib.InitStore) { } fileName := file.Name() Locales[strings.Replace(fileName, ".json", "", -1)] = `locales/` + fileName - content, _ := fs.ReadFile(initStore.UiAssets, "./assets/locales/en.json") + content, err := fs.ReadFile(initStore.UiAssets, "assets/locales/en.json") + if err != nil { + initStore.Logger.Warnf("Could not read en.json: %v", err) + continue + } var enMap = make(map[string]string) - json.Unmarshal(content, &enMap) + if err := json.Unmarshal(content, &enMap); err != nil { + initStore.Logger.Warnf("Could not unmarshal en.json: %v", err) + continue + } Locales["en"] = enMap } } diff --git a/lib/models/ws/clientReady.go b/lib/models/ws/clientReady.go index 4ca57d7..47337bd 100644 --- a/lib/models/ws/clientReady.go +++ b/lib/models/ws/clientReady.go @@ -1,17 +1,21 @@ package ws type ClientReady struct { - Event string `json:"event"` - Data struct { - Component string `json:"component"` - Type string `json:"type"` - PadID string `json:"padId"` - Token string `json:"token"` - UserInfo struct { - ColorId *string `json:"colorId"` - Name *string `json:"name"` - } `json:"userInfo"` - Reconnect *bool `json:"reconnect"` - ClientRev *int `json:"client_rev"` - } `json:"data"` + Event string `json:"event"` + Data ClientReadyData `json:"data"` +} + +type ClientReadyData struct { + Component string `json:"component"` + Type string `json:"type"` + PadID string `json:"padId"` + Token string `json:"token"` + UserInfo ClientReadyUserInfo `json:"userInfo"` + Reconnect *bool `json:"reconnect"` + ClientRev *int `json:"client_rev"` +} + +type ClientReadyUserInfo struct { + ColorId *string `json:"colorId"` + Name *string `json:"name"` } diff --git a/lib/models/ws/userChange.go b/lib/models/ws/userChange.go index 95c1bb9..f4f84ca 100644 --- a/lib/models/ws/userChange.go +++ b/lib/models/ws/userChange.go @@ -1,17 +1,24 @@ package ws type UserChange struct { - Event string `json:"event"` - Data struct { - Component string `json:"component"` - Data struct { - Apool struct { - NumToAttrib map[int][]string `json:"numToAttrib"` - NextNum int `json:"nextNum"` - } `json:"apool"` - BaseRev int `json:"baseRev"` - Changeset string `json:"changeset"` - } `json:"data"` - Type string `json:"type"` - } `json:"data"` + Event string `json:"event"` + Data UserChangeData `json:"data"` +} + +type UserChangeData struct { + Component string `json:"component"` + Data UserChangeDataData `json:"data"` + Type string `json:"type"` +} + +type UserChangeDataDataApool struct { + NumToAttrib map[int][]string `json:"numToAttrib"` + NextNum int `json:"nextNum"` +} + +type UserChangeDataData struct { + Type string `json:"type"` + Apool UserChangeDataDataApool `json:"apool"` + BaseRev int `json:"baseRev"` + Changeset string `json:"changeset"` } diff --git a/lib/session/SessionDatabase.go b/lib/session/SessionDatabase.go index 7113802..22eb730 100644 --- a/lib/session/SessionDatabase.go +++ b/lib/session/SessionDatabase.go @@ -1,20 +1,19 @@ package session import ( - "github.com/ether/etherpad-go/lib/db" "time" + + "github.com/ether/etherpad-go/lib/db" ) type Database struct { } func (s Database) Get(key string) ([]byte, error) { - println(key) return nil, nil } func (s Database) Set(key string, val []byte, exp time.Duration) error { - println(key, val, exp) //TODO implement me return nil } diff --git a/lib/test/api/author/author_test.go b/lib/test/api/author/author_test.go index 830edd6..7a1f978 100644 --- a/lib/test/api/author/author_test.go +++ b/lib/test/api/author/author_test.go @@ -9,7 +9,6 @@ import ( "github.com/ether/etherpad-go/lib/api/author" "github.com/ether/etherpad-go/lib/test/testutils" - "github.com/gofiber/fiber/v2" "github.com/stretchr/testify/assert" ) @@ -108,15 +107,18 @@ func testGetExistingAuthor(t *testing.T, tsStore testutils.TestDataStore) { } func testGetAuthorPadIDS(t *testing.T, tsStore testutils.TestDataStore) { - t.Skip() - // Skip because we cannot yet map pads to authors - app := fiber.New() author.Init(tsStore.ToInitStore()) dbAuthorToSave := testutils.GenerateDBAuthor() assert.NoError(t, tsStore.DS.SaveAuthor(dbAuthorToSave)) + padText := "Hallo123\n" + _, err := tsStore.PadManager.GetPad("pad123", &padText, &dbAuthorToSave.ID) + assert.NoError(t, err) req := httptest.NewRequest("GET", "/author/"+dbAuthorToSave.ID+"/pads", nil) - resp, _ := app.Test(req, 10) + resp, err := tsStore.App.Test(req, 100) + if err != nil { + t.Errorf("error getting author pads: %v", err) + } if resp.StatusCode != 200 { t.Errorf("should return 200 for existing author pads, got %d", resp.StatusCode) } diff --git a/lib/test/testutils/general/stringUtils.go b/lib/test/testutils/general/stringUtils.go index e1c5299..c45badf 100644 --- a/lib/test/testutils/general/stringUtils.go +++ b/lib/test/testutils/general/stringUtils.go @@ -29,7 +29,7 @@ func RandomMultiline(approxMaxLines, approxMaxCols int) string { } func RandomInlineString(length int) string { - const chars = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789 !@#$%^&*()_+-=[]{}|;:,.<>?" + const chars = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789" var result strings.Builder result.Grow(length) diff --git a/lib/test/ws/pad_message_handler_test.go b/lib/test/ws/pad_message_handler_test.go index 509e0b3..32d60ba 100644 --- a/lib/test/ws/pad_message_handler_test.go +++ b/lib/test/ws/pad_message_handler_test.go @@ -1620,32 +1620,11 @@ func testHandleMessageUserChangeReadonly(t *testing.T, ds testutils.TestDataStor // Create USER_CHANGES message userChange := ws.UserChange{ Event: "message", - Data: struct { - Component string `json:"component"` - Data struct { - Apool struct { - NumToAttrib map[int][]string `json:"numToAttrib"` - NextNum int `json:"nextNum"` - } `json:"apool"` - BaseRev int `json:"baseRev"` - Changeset string `json:"changeset"` - } `json:"data"` - Type string `json:"type"` - }{ + Data: ws.UserChangeData{ Component: "pad", Type: "USER_CHANGES", - Data: struct { - Apool struct { - NumToAttrib map[int][]string `json:"numToAttrib"` - NextNum int `json:"nextNum"` - } `json:"apool"` - BaseRev int `json:"baseRev"` - Changeset string `json:"changeset"` - }{ - Apool: struct { - NumToAttrib map[int][]string `json:"numToAttrib"` - NextNum int `json:"nextNum"` - }{ + Data: ws.UserChangeDataData{ + Apool: ws.UserChangeDataDataApool{ NumToAttrib: map[int][]string{}, NextNum: 0, }, diff --git a/lib/utils/logger.go b/lib/utils/logger.go index d8f1376..99b0762 100644 --- a/lib/utils/logger.go +++ b/lib/utils/logger.go @@ -1,10 +1,14 @@ package utils -import "go.uber.org/zap" +import ( + "go.uber.org/zap" + "go.uber.org/zap/zapcore" +) func SetupLogger() *zap.SugaredLogger { - logger, _ := zap.NewProduction() - logger = zap.Must(zap.NewDevelopment()) + cfg := zap.NewDevelopmentConfig() + cfg.Level = zap.NewAtomicLevelAt(zapcore.InfoLevel) + logger := zap.Must(cfg.Build()) sugar := logger.Sugar() return sugar diff --git a/lib/utils/stringUtils.go b/lib/utils/stringUtils.go index ac3f533..cef4a90 100644 --- a/lib/utils/stringUtils.go +++ b/lib/utils/stringUtils.go @@ -1,17 +1,15 @@ package utils import ( - randc "crypto/rand" - "encoding/hex" "math/big" "strconv" "strings" + + "github.com/ether/etherpad-go/lib/test/testutils/general" ) func RandomString(length int) string { - bytes := make([]byte, length) - randc.Read(bytes) - return hex.EncodeToString(bytes) + return general.RandomInlineString(length) } func NumToString(num int) string { diff --git a/lib/ws/AdminMessageHandler.go b/lib/ws/AdminMessageHandler.go index 3dd3293..33fc6e1 100644 --- a/lib/ws/AdminMessageHandler.go +++ b/lib/ws/AdminMessageHandler.go @@ -188,9 +188,11 @@ func (h AdminMessageHandler) HandleMessage(message admin.EventMessage, retrieved return } + h.hub.ClientsRWMutex.RLock() for key := range h.hub.Clients { key.SafeSend(responseBytes) } + h.hub.ClientsRWMutex.RUnlock() } case "deletePad": diff --git a/lib/ws/PadMessageHandler.go b/lib/ws/PadMessageHandler.go index 7cccd25..a07950e 100644 --- a/lib/ws/PadMessageHandler.go +++ b/lib/ws/PadMessageHandler.go @@ -8,6 +8,7 @@ import ( "regexp" "slices" "strconv" + "sync" "time" "unicode/utf8" @@ -56,6 +57,7 @@ type Task struct { type ChannelOperator struct { channels map[string]chan Task handler *PadMessageHandler + mu sync.Mutex } func NewChannelOperator(p *PadMessageHandler) ChannelOperator { @@ -66,19 +68,21 @@ func NewChannelOperator(p *PadMessageHandler) ChannelOperator { } func (c *ChannelOperator) AddToQueue(ch string, t Task) { - var _, ok = c.channels[ch] - + c.mu.Lock() + chChan, ok := c.channels[ch] if !ok { - c.channels[ch] = make(chan Task) - go func() { - for { - var incomingTask = <-c.channels[ch] + // small buffer to decouple producer from goroutine scheduling + chChan = make(chan Task, 1) + c.channels[ch] = chChan + go func(localCh chan Task) { + for incomingTask := range localCh { c.handler.handleUserChanges(incomingTask) } - }() + }(chChan) } + c.mu.Unlock() - c.channels[ch] <- t + chChan <- t } type PadMessageHandler struct { @@ -199,7 +203,7 @@ func (p *PadMessageHandler) handleUserChanges(task Task) { // and can be applied after "c". optRebasedChangeset, err := changeset.Follow(revisionPad.Changeset, rebasedChangeset, false, &retrievedPad.Pool) if err != nil { - p.Logger.Warnf("Error rebasing changeset at rev %d: %v", r, err) + p.Logger.Warnf("Error rebasing changeset at rev %d: %v for %s", r, err, retrievedPad.Id) return } rebasedChangeset = *optRebasedChangeset @@ -218,7 +222,6 @@ func (p *PadMessageHandler) handleUserChanges(task Task) { if *oldLen != utf8.RuneCountInString(prevText) { p.Logger.Warnf("Can't apply changeset to pad text: oldLen=%d, prevTextLen=%d, baseRev=%d, headRev=%d", *oldLen, utf8.RuneCountInString(prevText), r, retrievedPad.Head) - // Don't panic - just return and let the client retry or reconnect return } @@ -1347,16 +1350,19 @@ func (p *PadMessageHandler) UpdatePadClients(pad *pad2.Pad) { func (p *PadMessageHandler) GetRoomSockets(padID string) []Client { var sockets = make([]Client, 0) + p.hub.ClientsRWMutex.RLock() for k := range p.hub.Clients { sessId := p.SessionStore.getSession(k.SessionId) if sessId != nil && sessId.PadId == padID { sockets = append(sockets, *k) } } + p.hub.ClientsRWMutex.RUnlock() return sockets } func (p *PadMessageHandler) KickSessionsFromPad(padID string) { + p.hub.ClientsRWMutex.RLock() for k := range p.hub.Clients { if k == nil || k.SessionId == "" { continue @@ -1370,4 +1376,5 @@ func (p *PadMessageHandler) KickSessionsFromPad(padID string) { k.SendPadDelete() } } + p.hub.ClientsRWMutex.RUnlock() } diff --git a/lib/ws/client.go b/lib/ws/client.go index fe754ea..5fa5b3e 100644 --- a/lib/ws/client.go +++ b/lib/ws/client.go @@ -26,6 +26,9 @@ import ( var upgrader = websocket.Upgrader{ ReadBufferSize: 1024, WriteBufferSize: 1024, + CheckOrigin: func(r *http.Request) bool { + return true + }, } var ( @@ -114,6 +117,7 @@ func (c *Client) writePump() { // ensures that there is at most one reader on a connection by executing all // reads from this goroutine. func (c *Client) readPump(retrievedSettings *settings.Settings, logger *zap.SugaredLogger) { + c.Hub.Register <- c defer func() { c.Hub.Unregister <- c c.Conn.Close() diff --git a/lib/ws/hub.go b/lib/ws/hub.go index df63ef4..0e84bdd 100644 --- a/lib/ws/hub.go +++ b/lib/ws/hub.go @@ -1,10 +1,13 @@ package ws +import "sync" + // Hub maintains the set of active Clients and broadcasts messages to the // Clients. type Hub struct { // Registered Clients. - Clients map[*Client]bool + Clients map[*Client]bool + ClientsRWMutex sync.RWMutex // Inbound messages from the Clients. Broadcast chan []byte @@ -29,16 +32,21 @@ func (h *Hub) Run() { for { select { case client := <-h.Register: + h.ClientsRWMutex.Lock() h.Clients[client] = true + h.ClientsRWMutex.Unlock() case client := <-h.Unregister: if client == nil { continue } + h.ClientsRWMutex.Lock() if _, ok := h.Clients[client]; ok { delete(h.Clients, client) close(client.Send) } + h.ClientsRWMutex.Unlock() case message := <-h.Broadcast: + h.ClientsRWMutex.RLock() for client := range h.Clients { if client == nil { continue @@ -50,6 +58,7 @@ func (h *Hub) Run() { delete(h.Clients, client) } } + h.ClientsRWMutex.RUnlock() } } } diff --git a/main.go b/main.go index fb64432..f27785a 100644 --- a/main.go +++ b/main.go @@ -12,7 +12,9 @@ import ( "github.com/ether/etherpad-go/lib" api2 "github.com/ether/etherpad-go/lib/api" "github.com/ether/etherpad-go/lib/author" + "github.com/ether/etherpad-go/lib/cli" "github.com/ether/etherpad-go/lib/hooks" + "github.com/ether/etherpad-go/lib/loadtest" "github.com/ether/etherpad-go/lib/pad" "github.com/ether/etherpad-go/lib/plugins" session2 "github.com/ether/etherpad-go/lib/session" @@ -41,6 +43,29 @@ var uiAssets embed.FS func main() { setupLogger := utils.SetupLogger() defer setupLogger.Sync() + + if len(os.Args) > 1 { + switch os.Args[1] { + case "cli": + cli.RunFromCLI(setupLogger, os.Args[2:]) + return + case "loadtest": + loadtest.RunFromCLI(setupLogger, os.Args[2:]) + return + case "multiload": + loadtest.RunMultiFromCLI(setupLogger, os.Args[2:]) + return + case "-h", "--help", "help": + fmt.Println("Usage: etherpad [command] [options]") + fmt.Println("Commands:") + fmt.Println(" cli Interactive CLI for pads") + fmt.Println(" loadtest Run a load test on a single pad") + fmt.Println(" multiload Run a multi-pad load test") + fmt.Println(" (none) Start the Etherpad server") + return + } + } + settings2.InitSettings(setupLogger) var settings = settings2.Displayed