Skip to content

Commit

Permalink
check data integrity after getting replica
Browse files Browse the repository at this point in the history
  • Loading branch information
alogfans committed Jan 15, 2025
1 parent 683e75f commit f2a28a8
Show file tree
Hide file tree
Showing 3 changed files with 132 additions and 53 deletions.
56 changes: 39 additions & 17 deletions mooncake-p2p-store/src/example/p2p-store-example.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ var (
deviceName string
nicPriorityMatrixPath string
fileSize int
fileSizeMB int
)

func main() {
Expand All @@ -42,10 +43,11 @@ func main() {
flag.StringVar(&localServerName, "local_server_name", "", "Local server name")
flag.StringVar(&deviceName, "device_name", "mlx5_2", "RNIC device name")
flag.StringVar(&nicPriorityMatrixPath, "nic_priority_matrix", "", "Path to NIC priority matrix file (Advanced)")
flag.IntVar(&fileSize, "file_size_mb", 2048, "File size in MB")
flag.IntVar(&fileSizeMB, "file_size_mb", 2048, "File size in MB")
flag.Parse()

fileSize = fileSize * 1024 * 1024
fileSize = fileSizeMB * 1024 * 1024

if len(localServerName) == 0 {
var err error
localServerName, err = os.Hostname()
Expand All @@ -61,7 +63,7 @@ func main() {
case "inferencer":
inferencer()
default:
fmt.Printf("Invalid command: %s\n", command)
fmt.Printf("You must specify a command, either 'trainer' or 'inferencer'\n")
os.Exit(1)
}
}
Expand All @@ -73,18 +75,30 @@ func doTrainer(ctx context.Context, store *p2pstore.P2PStore, name string) {
os.Exit(1)
}

fmt.Println("After training, register new object:", name, "file size:", fileSize)
fmt.Printf("Object registration: name %s base address %x file size %d MB\n",
name,
uintptr(unsafe.Pointer(&addr[0])),
fileSizeMB)

startTimestamp := time.Now()
addrList := []uintptr{uintptr(unsafe.Pointer(&addr[0]))}
sizeList := []uint64{uint64(fileSize)}
err = store.Register(ctx, name, addrList, sizeList, 64*1024*1024, "cpu:0")

const MAX_SHARD_SIZE uint64 = 64 * 1024 * 1024
const MEMORY_LOCATION string = "cpu:0"

err = store.Register(ctx, name, addrList, sizeList, MAX_SHARD_SIZE, MEMORY_LOCATION, true)
if err != nil {
fmt.Fprintf(os.Stderr, "Register failed: %v\n", err)
fmt.Fprintf(os.Stderr, "Object registration failed: %v\n", err)
os.Exit(1)
}

phaseOneTimestamp := time.Now()
fmt.Println("Register done, duration (ms):", phaseOneTimestamp.Sub(startTimestamp).Milliseconds())
duration := phaseOneTimestamp.Sub(startTimestamp).Milliseconds()

fmt.Printf("Object registration done: duration (ms) %d throughput (GB/s) %.2f\n",
duration,
float64(fileSizeMB)/float64(duration))

checkpointInfoList, err := store.List(ctx, "foo")
if err != nil {
Expand All @@ -93,7 +107,7 @@ func doTrainer(ctx context.Context, store *p2pstore.P2PStore, name string) {
}

fmt.Println(checkpointInfoList)
fmt.Println("Idle for 100 seconds")
fmt.Println("Idle for 100 seconds, now you can start another terminal to simulate inference")
time.Sleep(100 * time.Second)

err = store.Unregister(ctx, name)
Expand All @@ -109,24 +123,25 @@ func doTrainer(ctx context.Context, store *p2pstore.P2PStore, name string) {
}

func trainer() {
fmt.Println("Simulated training process started")
ctx, cancel := context.WithTimeout(context.Background(), 1000*time.Second)
defer cancel()

store, err := p2pstore.NewP2PStore(metadataServer, localServerName, getPriorityMatrix())
if err != nil {
fmt.Fprintf(os.Stderr, "Error creating checkpoint engine: %v\n", err)
fmt.Fprintf(os.Stderr, "P2PStore: initialization failed: %v\n", err)
os.Exit(1)
}

doTrainer(ctx, store, "foo/bar")

err = store.Close()
if err != nil {
fmt.Fprintf(os.Stderr, "Shutdown failed: %v\n", err)
fmt.Fprintf(os.Stderr, "P2PStore: close failed: %v\n", err)
os.Exit(1)
}

fmt.Println("ALL DONE")
fmt.Println("Simulated training process stopped gracefully")
}

func getPriorityMatrix() string {
Expand All @@ -149,18 +164,22 @@ func doInferencer(ctx context.Context, store *p2pstore.P2PStore, name string) {
os.Exit(1)
}

fmt.Println("Expecting to retrieve from object", name)
fmt.Println("Object retrival started: name", name)
startTimestamp := time.Now()
addrList := []uintptr{uintptr(unsafe.Pointer(&addr[0]))}
sizeList := []uint64{uint64(fileSize)}
err = store.GetReplica(ctx, name, addrList, sizeList)
if err != nil {
fmt.Fprintf(os.Stderr, "GetLocalCheckpoint failed: %v\n", err)
fmt.Fprintf(os.Stderr, "Object retrival failed: %v\n", err)
os.Exit(1)
}

phaseOneTimestamp := time.Now()
fmt.Println("GetReplica done, duration (ms):", phaseOneTimestamp.Sub(startTimestamp).Milliseconds())
duration := phaseOneTimestamp.Sub(startTimestamp).Milliseconds()

fmt.Printf("Object retrival done: duration (ms) %d throughput (GB/s) %.2f\n",
duration,
float64(fileSizeMB)/float64(duration))

err = store.DeleteReplica(ctx, name)
if err != nil {
Expand All @@ -175,21 +194,24 @@ func doInferencer(ctx context.Context, store *p2pstore.P2PStore, name string) {
}

func inferencer() {
fmt.Println("Simulated inference process started")

ctx, cancel := context.WithTimeout(context.Background(), 1000*time.Second)
defer cancel()

store, err := p2pstore.NewP2PStore(metadataServer, localServerName, getPriorityMatrix())
if err != nil {
fmt.Fprintf(os.Stderr, "Error creating checkpoint engine: %v\n", err)
fmt.Fprintf(os.Stderr, "P2PStore: initialization failed: %v\n", err)
os.Exit(1)
}

doInferencer(ctx, store, "foo/bar")

err = store.Close()
if err != nil {
fmt.Fprintf(os.Stderr, "Shutdown failed: %v\n", err)
fmt.Fprintf(os.Stderr, "P2PStore: close failed: %v\n", err)
os.Exit(1)
}

fmt.Println("ALL DONE")
fmt.Println("Simulated inference process stopped gracefully")
}
123 changes: 88 additions & 35 deletions mooncake-p2p-store/src/p2pstore/core.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,9 @@ package p2pstore
import (
"context"
"log"
"sync"
"net"
"strconv"
"sync"
)

// When the data size larger than MAX_CHUNK_SIZE bytes, we split them into multiple buffers and registered seperately.
Expand All @@ -32,7 +32,7 @@ const METADATA_KEY_PREFIX string = "mooncake/checkpoint/"

type P2PStore struct {
metadataConnString string
localServerName string
localServerName string
catalog *Catalog
memory *RegisteredMemory
metadata *Metadata
Expand Down Expand Up @@ -68,9 +68,9 @@ func NewP2PStore(metadataConnString string, localServerName string, nicPriorityM
}

if len(nicPriorityMatrix) == 0 {
err = transfer.installTransport("tcp", nicPriorityMatrix);
err = transfer.installTransport("tcp", nicPriorityMatrix)
} else {
err = transfer.installTransport("rdma", nicPriorityMatrix);
err = transfer.installTransport("rdma", nicPriorityMatrix)
}
if err != nil {
metadata.Close()
Expand All @@ -79,7 +79,7 @@ func NewP2PStore(metadataConnString string, localServerName string, nicPriorityM

store := &P2PStore{
metadataConnString: metadataConnString,
localServerName: localServerName,
localServerName: localServerName,
catalog: NewCatalog(),
memory: NewRegisteredMemory(transfer, MAX_CHUNK_SIZE),
metadata: metadata,
Expand Down Expand Up @@ -112,7 +112,13 @@ func (store *P2PStore) unregisterBuffers(bufferList []Buffer, maxShardSize uint6
}
}

func (store *P2PStore) Register(ctx context.Context, name string, addrList []uintptr, sizeList []uint64, maxShardSize uint64, location string) error {
func (store *P2PStore) Register(ctx context.Context,
name string,
addrList []uintptr,
sizeList []uint64,
maxShardSize uint64,
location string,
forceCreate bool) error {
if len(addrList) != len(sizeList) || len(addrList) == 0 {
return ErrInvalidArgument
}
Expand Down Expand Up @@ -155,7 +161,13 @@ func (store *P2PStore) Register(ctx context.Context, name string, addrList []uin
}
}

err := store.metadata.Put(ctx, name, &payload)
var err error
if forceCreate {
err = store.metadata.Put(ctx, name, &payload)
} else {
err = store.metadata.Create(ctx, name, &payload)
}

if err != nil {
store.unregisterBuffers(bufferList, maxShardSize)
return err
Expand Down Expand Up @@ -210,10 +222,10 @@ func (store *P2PStore) Unregister(ctx context.Context, name string) error {
}

type PayloadInfo struct {
Name string // Full name of checkpoint file
MaxShardSize uint64 //
TotalSize uint64 //
SizeList []uint64 //
Name string
MaxShardSize uint64
TotalSize uint64
SizeList []uint64
}

func (store *P2PStore) List(ctx context.Context, namePrefix string) ([]PayloadInfo, error) {
Expand All @@ -234,34 +246,14 @@ func (store *P2PStore) List(ctx context.Context, namePrefix string) ([]PayloadIn
return result, nil
}

// Get replica for same name multiple times in one P2P store will return ErrPayloadOpened
func (store *P2PStore) GetReplica(ctx context.Context, name string, addrList []uintptr, sizeList []uint64) error {
if len(addrList) != len(sizeList) || len(addrList) == 0 {
return ErrInvalidArgument
}

if store.catalog.Contains(name) {
return ErrPayloadOpened
}

payload, revision, err := store.metadata.Get(ctx, name)
if err != nil {
return err
}

if payload == nil {
return ErrPayloadNotFound
}

func (store *P2PStore) doGetReplica(payload *Payload, addrList []uintptr, sizeList []uint64) error {
var wg sync.WaitGroup
errChan := make(chan error, 1)

var offset uint64 = 0
taskID := 0
maxShardSize := payload.MaxShardSize

_ = store.transfer.syncSegmentCache()

for i := 0; i < len(addrList); i++ {
addr, size := addrList[i], sizeList[i]
err := store.memory.Add(addr, size, maxShardSize, "cpu:0")
Expand Down Expand Up @@ -295,15 +287,76 @@ func (store *P2PStore) GetReplica(ctx context.Context, name string, addrList []u
}
default:
}
return nil
}

func contains(slice []Location, value Location) bool {
for _, item := range slice {
if item == value {
return true
}
}
return false
}

func isSubsetOf(old *Payload, new *Payload) bool {
if len(old.Shards) != len(new.Shards) {
return false
}
for i := 0; i < len(old.Shards); i += 1 {
for _, value := range old.Shards[i].Gold {
if !contains(new.Shards[i].Gold, value) {
return false
}
}
for _, value := range old.Shards[i].ReplicaList {
if !contains(new.Shards[i].ReplicaList, value) {
return false
}
}
}
return true
}

func (store *P2PStore) GetReplica(ctx context.Context, name string, addrList []uintptr, sizeList []uint64) error {
if len(addrList) != len(sizeList) || len(addrList) == 0 {
return ErrInvalidArgument
}

if store.catalog.Contains(name) {
return ErrPayloadOpened
}

payload, revision, err := store.metadata.Get(ctx, name)
if err != nil {
return err
}
if payload == nil {
return ErrPayloadNotFound
}
for {
_ = store.transfer.syncSegmentCache()
err = store.doGetReplica(payload, addrList, sizeList)
if err != nil {
return err
}
newPayload, recheckRevision, err := store.metadata.Get(ctx, name)
if err != nil {
return err
}
if revision == recheckRevision {
break
}
if isSubsetOf(payload, newPayload) {
break
}
}
return store.updatePayloadMetadata(ctx, name, addrList, sizeList, payload, revision)
}

func (store *P2PStore) performTransfer(source uintptr, shard Shard) error {
const MAX_RETRY_COUNT int = 8
retryCount := 0

for retryCount < MAX_RETRY_COUNT {
for retryCount < shard.Count() {
batchID, err := store.transfer.allocateBatchID(1)
if err != nil {
return err
Expand Down
6 changes: 5 additions & 1 deletion mooncake-p2p-store/src/p2pstore/metadata.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ import (
"math/rand"
"time"

"go.etcd.io/etcd/client/v3"
clientv3 "go.etcd.io/etcd/client/v3"
)

// key = payload_name
Expand Down Expand Up @@ -70,6 +70,10 @@ func (s *Shard) GetLocation(retryTimes int) *Location {
}
}

func (s *Shard) Count() int {
return len(s.ReplicaList) + len(s.Gold)
}

func (s *Shard) getRandomLocation() *Location {
r := rand.New(rand.NewSource(time.Now().UnixNano()))
if len(s.ReplicaList) > 0 {
Expand Down

0 comments on commit f2a28a8

Please sign in to comment.