From f2a28a81cfb5b7110abe6f6df7e995eec3c3f1e9 Mon Sep 17 00:00:00 2001 From: Feng Ren Date: Wed, 15 Jan 2025 08:15:07 +0000 Subject: [PATCH] check data integrity after getting replica --- .../src/example/p2p-store-example.go | 56 +++++--- mooncake-p2p-store/src/p2pstore/core.go | 123 +++++++++++++----- mooncake-p2p-store/src/p2pstore/metadata.go | 6 +- 3 files changed, 132 insertions(+), 53 deletions(-) diff --git a/mooncake-p2p-store/src/example/p2p-store-example.go b/mooncake-p2p-store/src/example/p2p-store-example.go index eeb5cdb..122ad8f 100644 --- a/mooncake-p2p-store/src/example/p2p-store-example.go +++ b/mooncake-p2p-store/src/example/p2p-store-example.go @@ -34,6 +34,7 @@ var ( deviceName string nicPriorityMatrixPath string fileSize int + fileSizeMB int ) func main() { @@ -42,10 +43,11 @@ func main() { flag.StringVar(&localServerName, "local_server_name", "", "Local server name") flag.StringVar(&deviceName, "device_name", "mlx5_2", "RNIC device name") flag.StringVar(&nicPriorityMatrixPath, "nic_priority_matrix", "", "Path to NIC priority matrix file (Advanced)") - flag.IntVar(&fileSize, "file_size_mb", 2048, "File size in MB") + flag.IntVar(&fileSizeMB, "file_size_mb", 2048, "File size in MB") flag.Parse() - fileSize = fileSize * 1024 * 1024 + fileSize = fileSizeMB * 1024 * 1024 + if len(localServerName) == 0 { var err error localServerName, err = os.Hostname() @@ -61,7 +63,7 @@ func main() { case "inferencer": inferencer() default: - fmt.Printf("Invalid command: %s\n", command) + fmt.Printf("You must specify a command, either 'trainer' or 'inferencer'\n") os.Exit(1) } } @@ -73,18 +75,30 @@ func doTrainer(ctx context.Context, store *p2pstore.P2PStore, name string) { os.Exit(1) } - fmt.Println("After training, register new object:", name, "file size:", fileSize) + fmt.Printf("Object registration: name %s base address %x file size %d MB\n", + name, + uintptr(unsafe.Pointer(&addr[0])), + fileSizeMB) + startTimestamp := time.Now() addrList := []uintptr{uintptr(unsafe.Pointer(&addr[0]))} sizeList := []uint64{uint64(fileSize)} - err = store.Register(ctx, name, addrList, sizeList, 64*1024*1024, "cpu:0") + + const MAX_SHARD_SIZE uint64 = 64 * 1024 * 1024 + const MEMORY_LOCATION string = "cpu:0" + + err = store.Register(ctx, name, addrList, sizeList, MAX_SHARD_SIZE, MEMORY_LOCATION, true) if err != nil { - fmt.Fprintf(os.Stderr, "Register failed: %v\n", err) + fmt.Fprintf(os.Stderr, "Object registration failed: %v\n", err) os.Exit(1) } phaseOneTimestamp := time.Now() - fmt.Println("Register done, duration (ms):", phaseOneTimestamp.Sub(startTimestamp).Milliseconds()) + duration := phaseOneTimestamp.Sub(startTimestamp).Milliseconds() + + fmt.Printf("Object registration done: duration (ms) %d throughput (GB/s) %.2f\n", + duration, + float64(fileSizeMB)/float64(duration)) checkpointInfoList, err := store.List(ctx, "foo") if err != nil { @@ -93,7 +107,7 @@ func doTrainer(ctx context.Context, store *p2pstore.P2PStore, name string) { } fmt.Println(checkpointInfoList) - fmt.Println("Idle for 100 seconds") + fmt.Println("Idle for 100 seconds, now you can start another terminal to simulate inference") time.Sleep(100 * time.Second) err = store.Unregister(ctx, name) @@ -109,12 +123,13 @@ func doTrainer(ctx context.Context, store *p2pstore.P2PStore, name string) { } func trainer() { + fmt.Println("Simulated training process started") ctx, cancel := context.WithTimeout(context.Background(), 1000*time.Second) defer cancel() store, err := p2pstore.NewP2PStore(metadataServer, localServerName, getPriorityMatrix()) if err != nil { - fmt.Fprintf(os.Stderr, "Error creating checkpoint engine: %v\n", err) + fmt.Fprintf(os.Stderr, "P2PStore: initialization failed: %v\n", err) os.Exit(1) } @@ -122,11 +137,11 @@ func trainer() { err = store.Close() if err != nil { - fmt.Fprintf(os.Stderr, "Shutdown failed: %v\n", err) + fmt.Fprintf(os.Stderr, "P2PStore: close failed: %v\n", err) os.Exit(1) } - fmt.Println("ALL DONE") + fmt.Println("Simulated training process stopped gracefully") } func getPriorityMatrix() string { @@ -149,18 +164,22 @@ func doInferencer(ctx context.Context, store *p2pstore.P2PStore, name string) { os.Exit(1) } - fmt.Println("Expecting to retrieve from object", name) + fmt.Println("Object retrival started: name", name) startTimestamp := time.Now() addrList := []uintptr{uintptr(unsafe.Pointer(&addr[0]))} sizeList := []uint64{uint64(fileSize)} err = store.GetReplica(ctx, name, addrList, sizeList) if err != nil { - fmt.Fprintf(os.Stderr, "GetLocalCheckpoint failed: %v\n", err) + fmt.Fprintf(os.Stderr, "Object retrival failed: %v\n", err) os.Exit(1) } phaseOneTimestamp := time.Now() - fmt.Println("GetReplica done, duration (ms):", phaseOneTimestamp.Sub(startTimestamp).Milliseconds()) + duration := phaseOneTimestamp.Sub(startTimestamp).Milliseconds() + + fmt.Printf("Object retrival done: duration (ms) %d throughput (GB/s) %.2f\n", + duration, + float64(fileSizeMB)/float64(duration)) err = store.DeleteReplica(ctx, name) if err != nil { @@ -175,21 +194,24 @@ func doInferencer(ctx context.Context, store *p2pstore.P2PStore, name string) { } func inferencer() { + fmt.Println("Simulated inference process started") + ctx, cancel := context.WithTimeout(context.Background(), 1000*time.Second) defer cancel() store, err := p2pstore.NewP2PStore(metadataServer, localServerName, getPriorityMatrix()) if err != nil { - fmt.Fprintf(os.Stderr, "Error creating checkpoint engine: %v\n", err) + fmt.Fprintf(os.Stderr, "P2PStore: initialization failed: %v\n", err) os.Exit(1) } doInferencer(ctx, store, "foo/bar") + err = store.Close() if err != nil { - fmt.Fprintf(os.Stderr, "Shutdown failed: %v\n", err) + fmt.Fprintf(os.Stderr, "P2PStore: close failed: %v\n", err) os.Exit(1) } - fmt.Println("ALL DONE") + fmt.Println("Simulated inference process stopped gracefully") } diff --git a/mooncake-p2p-store/src/p2pstore/core.go b/mooncake-p2p-store/src/p2pstore/core.go index fae267e..e1ddf7e 100644 --- a/mooncake-p2p-store/src/p2pstore/core.go +++ b/mooncake-p2p-store/src/p2pstore/core.go @@ -17,9 +17,9 @@ package p2pstore import ( "context" "log" - "sync" "net" "strconv" + "sync" ) // When the data size larger than MAX_CHUNK_SIZE bytes, we split them into multiple buffers and registered seperately. @@ -32,7 +32,7 @@ const METADATA_KEY_PREFIX string = "mooncake/checkpoint/" type P2PStore struct { metadataConnString string - localServerName string + localServerName string catalog *Catalog memory *RegisteredMemory metadata *Metadata @@ -68,9 +68,9 @@ func NewP2PStore(metadataConnString string, localServerName string, nicPriorityM } if len(nicPriorityMatrix) == 0 { - err = transfer.installTransport("tcp", nicPriorityMatrix); + err = transfer.installTransport("tcp", nicPriorityMatrix) } else { - err = transfer.installTransport("rdma", nicPriorityMatrix); + err = transfer.installTransport("rdma", nicPriorityMatrix) } if err != nil { metadata.Close() @@ -79,7 +79,7 @@ func NewP2PStore(metadataConnString string, localServerName string, nicPriorityM store := &P2PStore{ metadataConnString: metadataConnString, - localServerName: localServerName, + localServerName: localServerName, catalog: NewCatalog(), memory: NewRegisteredMemory(transfer, MAX_CHUNK_SIZE), metadata: metadata, @@ -112,7 +112,13 @@ func (store *P2PStore) unregisterBuffers(bufferList []Buffer, maxShardSize uint6 } } -func (store *P2PStore) Register(ctx context.Context, name string, addrList []uintptr, sizeList []uint64, maxShardSize uint64, location string) error { +func (store *P2PStore) Register(ctx context.Context, + name string, + addrList []uintptr, + sizeList []uint64, + maxShardSize uint64, + location string, + forceCreate bool) error { if len(addrList) != len(sizeList) || len(addrList) == 0 { return ErrInvalidArgument } @@ -155,7 +161,13 @@ func (store *P2PStore) Register(ctx context.Context, name string, addrList []uin } } - err := store.metadata.Put(ctx, name, &payload) + var err error + if forceCreate { + err = store.metadata.Put(ctx, name, &payload) + } else { + err = store.metadata.Create(ctx, name, &payload) + } + if err != nil { store.unregisterBuffers(bufferList, maxShardSize) return err @@ -210,10 +222,10 @@ func (store *P2PStore) Unregister(ctx context.Context, name string) error { } type PayloadInfo struct { - Name string // Full name of checkpoint file - MaxShardSize uint64 // - TotalSize uint64 // - SizeList []uint64 // + Name string + MaxShardSize uint64 + TotalSize uint64 + SizeList []uint64 } func (store *P2PStore) List(ctx context.Context, namePrefix string) ([]PayloadInfo, error) { @@ -234,25 +246,7 @@ func (store *P2PStore) List(ctx context.Context, namePrefix string) ([]PayloadIn return result, nil } -// Get replica for same name multiple times in one P2P store will return ErrPayloadOpened -func (store *P2PStore) GetReplica(ctx context.Context, name string, addrList []uintptr, sizeList []uint64) error { - if len(addrList) != len(sizeList) || len(addrList) == 0 { - return ErrInvalidArgument - } - - if store.catalog.Contains(name) { - return ErrPayloadOpened - } - - payload, revision, err := store.metadata.Get(ctx, name) - if err != nil { - return err - } - - if payload == nil { - return ErrPayloadNotFound - } - +func (store *P2PStore) doGetReplica(payload *Payload, addrList []uintptr, sizeList []uint64) error { var wg sync.WaitGroup errChan := make(chan error, 1) @@ -260,8 +254,6 @@ func (store *P2PStore) GetReplica(ctx context.Context, name string, addrList []u taskID := 0 maxShardSize := payload.MaxShardSize - _ = store.transfer.syncSegmentCache() - for i := 0; i < len(addrList); i++ { addr, size := addrList[i], sizeList[i] err := store.memory.Add(addr, size, maxShardSize, "cpu:0") @@ -295,15 +287,76 @@ func (store *P2PStore) GetReplica(ctx context.Context, name string, addrList []u } default: } + return nil +} +func contains(slice []Location, value Location) bool { + for _, item := range slice { + if item == value { + return true + } + } + return false +} + +func isSubsetOf(old *Payload, new *Payload) bool { + if len(old.Shards) != len(new.Shards) { + return false + } + for i := 0; i < len(old.Shards); i += 1 { + for _, value := range old.Shards[i].Gold { + if !contains(new.Shards[i].Gold, value) { + return false + } + } + for _, value := range old.Shards[i].ReplicaList { + if !contains(new.Shards[i].ReplicaList, value) { + return false + } + } + } + return true +} + +func (store *P2PStore) GetReplica(ctx context.Context, name string, addrList []uintptr, sizeList []uint64) error { + if len(addrList) != len(sizeList) || len(addrList) == 0 { + return ErrInvalidArgument + } + + if store.catalog.Contains(name) { + return ErrPayloadOpened + } + + payload, revision, err := store.metadata.Get(ctx, name) + if err != nil { + return err + } + if payload == nil { + return ErrPayloadNotFound + } + for { + _ = store.transfer.syncSegmentCache() + err = store.doGetReplica(payload, addrList, sizeList) + if err != nil { + return err + } + newPayload, recheckRevision, err := store.metadata.Get(ctx, name) + if err != nil { + return err + } + if revision == recheckRevision { + break + } + if isSubsetOf(payload, newPayload) { + break + } + } return store.updatePayloadMetadata(ctx, name, addrList, sizeList, payload, revision) } func (store *P2PStore) performTransfer(source uintptr, shard Shard) error { - const MAX_RETRY_COUNT int = 8 retryCount := 0 - - for retryCount < MAX_RETRY_COUNT { + for retryCount < shard.Count() { batchID, err := store.transfer.allocateBatchID(1) if err != nil { return err diff --git a/mooncake-p2p-store/src/p2pstore/metadata.go b/mooncake-p2p-store/src/p2pstore/metadata.go index 6095664..2805773 100644 --- a/mooncake-p2p-store/src/p2pstore/metadata.go +++ b/mooncake-p2p-store/src/p2pstore/metadata.go @@ -21,7 +21,7 @@ import ( "math/rand" "time" - "go.etcd.io/etcd/client/v3" + clientv3 "go.etcd.io/etcd/client/v3" ) // key = payload_name @@ -70,6 +70,10 @@ func (s *Shard) GetLocation(retryTimes int) *Location { } } +func (s *Shard) Count() int { + return len(s.ReplicaList) + len(s.Gold) +} + func (s *Shard) getRandomLocation() *Location { r := rand.New(rand.NewSource(time.Now().UnixNano())) if len(s.ReplicaList) > 0 {