diff --git a/scheduler/service/service_v2.go b/scheduler/service/service_v2.go
index 93158aded72..2a2f098de3d 100644
--- a/scheduler/service/service_v2.go
+++ b/scheduler/service/service_v2.go
@@ -93,7 +93,7 @@ func (v *V2) AnnouncePeer(stream schedulerv2.Scheduler_AnnouncePeerServer) error
 	for {
 		select {
 		case <-ctx.Done():
-			logger.Info("context was done")
+			logger.Info("announce peer context was done")
 			return ctx.Err()
 		default:
 		}
@@ -141,7 +141,7 @@ func (v *V2) AnnouncePeer(stream schedulerv2.Scheduler_AnnouncePeerServer) error
 		case *schedulerv2.AnnouncePeerRequest_DownloadPeerFinishedRequest:
 			downloadPeerFinishedRequest := announcePeerRequest.DownloadPeerFinishedRequest
 			log.Infof("receive DownloadPeerFinishedRequest, content length: %d, piece count: %d", downloadPeerFinishedRequest.GetContentLength(), downloadPeerFinishedRequest.GetPieceCount())
-			// Notice: Handler uses context.Background() to avoid stream cancel by dfdameon.
+			// Notice: Handler uses context.Background() to avoid the stream being canceled by dfdaemon.
 			if err := v.handleDownloadPeerFinishedRequest(context.Background(), req.GetPeerId()); err != nil {
 				log.Error(err)
 				return err
@@ -149,21 +149,21 @@ func (v *V2) AnnouncePeer(stream schedulerv2.Scheduler_AnnouncePeerServer) error
 		case *schedulerv2.AnnouncePeerRequest_DownloadPeerBackToSourceFinishedRequest:
 			downloadPeerBackToSourceFinishedRequest := announcePeerRequest.DownloadPeerBackToSourceFinishedRequest
 			log.Infof("receive DownloadPeerBackToSourceFinishedRequest, content length: %d, piece count: %d", downloadPeerBackToSourceFinishedRequest.GetContentLength(), downloadPeerBackToSourceFinishedRequest.GetPieceCount())
-			// Notice: Handler uses context.Background() to avoid stream cancel by dfdameon.
+			// Notice: Handler uses context.Background() to avoid the stream being canceled by dfdaemon.
 			if err := v.handleDownloadPeerBackToSourceFinishedRequest(context.Background(), req.GetPeerId(), downloadPeerBackToSourceFinishedRequest); err != nil {
 				log.Error(err)
 				return err
 			}
 		case *schedulerv2.AnnouncePeerRequest_DownloadPeerFailedRequest:
 			log.Infof("receive DownloadPeerFailedRequest, description: %s", announcePeerRequest.DownloadPeerFailedRequest.GetDescription())
-			// Notice: Handler uses context.Background() to avoid stream cancel by dfdameon.
+			// Notice: Handler uses context.Background() to avoid the stream being canceled by dfdaemon.
 			if err := v.handleDownloadPeerFailedRequest(context.Background(), req.GetPeerId()); err != nil {
 				log.Error(err)
 				return err
 			}
 		case *schedulerv2.AnnouncePeerRequest_DownloadPeerBackToSourceFailedRequest:
 			log.Infof("receive DownloadPeerBackToSourceFailedRequest, description: %s", announcePeerRequest.DownloadPeerBackToSourceFailedRequest.GetDescription())
-			// Notice: Handler uses context.Background() to avoid stream cancel by dfdameon.
+			// Notice: Handler uses context.Background() to avoid the stream being canceled by dfdaemon.
 			if err := v.handleDownloadPeerBackToSourceFailedRequest(context.Background(), req.GetPeerId()); err != nil {
 				log.Error(err)
 				return err
@@ -867,10 +867,12 @@ func (v *V2) AnnouncePeers(stream schedulerv2.Scheduler_AnnouncePeersServer) err
 	ctx, cancel := context.WithCancel(stream.Context())
 	defer cancel()
 
+	announcePeersCount := 0
+
 	for {
 		select {
 		case <-ctx.Done():
-			logger.Info("context was done")
+			logger.Info("announce peers context was done")
 			return ctx.Err()
 		default:
 		}
@@ -878,6 +880,7 @@ func (v *V2) AnnouncePeers(stream schedulerv2.Scheduler_AnnouncePeersServer) err
 		request, err := stream.Recv()
 		if err != nil {
 			if err == io.EOF {
+				logger.Infof("announce %d peers", announcePeersCount)
 				return nil
 			}
 
@@ -885,10 +888,12 @@ func (v *V2) AnnouncePeers(stream schedulerv2.Scheduler_AnnouncePeersServer) err
 			return err
 		}
 
-		if _, err := v.handleAnnouncePeersRequest(ctx, request); err != nil {
+		peers, err := v.handleAnnouncePeersRequest(ctx, request)
+		if err != nil {
 			logger.Error(err)
 			return err
 		}
+		announcePeersCount += len(peers)
 	}
 }
 
@@ -1484,16 +1489,17 @@ func (v *V2) downloadTaskBySeedPeer(ctx context.Context, taskID string, download
 }
 
 // handleAnnouncePeersRequest handles AnnouncePeersRequest.
-func (v *V2) handleAnnouncePeersRequest(ctx context.Context, request *schedulerv2.AnnouncePeersRequest) ([]*resource.Peer, error) {
-	var peers []*resource.Peer
-
+func (v *V2) handleAnnouncePeersRequest(ctx context.Context, request *schedulerv2.AnnouncePeersRequest) (peers []*resource.Peer, err error) {
 	for _, p := range request.Peers {
 		hostID := p.GetHost().GetId()
 		peerTask := p.GetTask()
+		if peerTask == nil {
+			return nil, status.Error(codes.InvalidArgument, "request is invalid and doesn't contain a task")
+		}
 		taskID := peerTask.GetId()
 		peerID := p.GetId()
 		download := &commonv2.Download{
-			PieceLength: peerTask.GetPieceLength(),
+			PieceLength: &peerTask.PieceLength,
 			Digest:      peerTask.Digest,
 			Url:         peerTask.GetUrl(),
 			Tag:         peerTask.Tag,
@@ -1528,11 +1534,7 @@ func (v *V2) handleAnnouncePeersRequest(ctx context.Context, request *schedulerv
 		// advance the task state to TaskStateSucceeded.
 		if !task.FSM.Is(resource.TaskStateSucceeded) {
 			if task.FSM.Can(resource.TaskEventDownload) {
-				if err := task.FSM.Event(ctx, resource.TaskEventDownload); err != nil {
-					msg := fmt.Sprintf("task fsm event failed: %s", err.Error())
-					peer.Log.Error(msg)
-					return nil, status.Error(codes.Internal, err.Error())
-				}
+				task.FSM.SetState(resource.TaskStateRunning)
 			}
 
 			// Construct piece.
@@ -1571,9 +1573,7 @@ func (v *V2) handleAnnouncePeersRequest(ctx context.Context, request *schedulerv
 		if peer.Range == nil && !peer.Task.FSM.Is(resource.TaskStateSucceeded) {
 			peer.Task.ContentLength.Store(int64(len(p.Pieces)))
 			peer.Task.TotalPieceCount.Store(int32(task.ContentLength.Load()))
-			if err := peer.Task.FSM.Event(ctx, resource.TaskEventDownloadSucceeded); err != nil {
-				return nil, status.Error(codes.Internal, err.Error())
-			}
+			peer.Task.FSM.SetState(resource.TaskStateSucceeded)
 		}
 	}
 
diff --git a/scheduler/service/service_v2_test.go b/scheduler/service/service_v2_test.go
index eeef14cc30a..6574c19c249 100644
--- a/scheduler/service/service_v2_test.go
+++ b/scheduler/service/service_v2_test.go
@@ -44,6 +44,7 @@ import (
 	schedulerv2mocks "d7y.io/api/v2/pkg/apis/scheduler/v2/mocks"
 
 	managertypes "d7y.io/dragonfly/v2/manager/types"
+	"d7y.io/dragonfly/v2/pkg/idgen"
 	nethttp "d7y.io/dragonfly/v2/pkg/net/http"
 	pkgtypes "d7y.io/dragonfly/v2/pkg/types"
 	"d7y.io/dragonfly/v2/scheduler/config"
@@ -3098,7 +3099,7 @@ func TestServiceV2_handleResource(t *testing.T) {
 				assert.Equal(peer.Priority, download.Priority)
 				assert.Equal(peer.Range.Start, int64(download.Range.Start))
 				assert.Equal(peer.Range.Length, int64(download.Range.Length))
-				assert.NotNil(peer.AnnouncePeerStream)
+				assert.NotNil(peer.AnnouncePeerStream.Load())
 				assert.EqualValues(peer.Host, mockHost)
 				assert.EqualValues(peer.Task, mockTask)
 			},
@@ -3407,3 +3408,263 @@ func TestServiceV2_downloadTaskBySeedPeer(t *testing.T) {
 		})
 	}
 }
+
+func TestServiceV2_handleAnnouncePeersRequest(t *testing.T) {
+	tests := []struct {
+		name    string
+		request *schedulerv2.AnnouncePeersRequest
+		run     func(t *testing.T, svc *V2, request *schedulerv2.AnnouncePeersRequest, mockHost *resource.Host, mockTask *resource.Task, mockPeer *resource.Peer,
+			hostManager resource.HostManager, taskManager resource.TaskManager, peerManager resource.PeerManager, mr *resource.MockResourceMockRecorder, mh *resource.MockHostManagerMockRecorder,
+			mt *resource.MockTaskManagerMockRecorder, mp *resource.MockPeerManagerMockRecorder)
+	}{
+		{
+			name: "task and host exist in scheduler, peer does not",
+			request: &schedulerv2.AnnouncePeersRequest{
+				Peers: []*commonv2.Peer{
+					{
+						Id: mockPeerID,
+						Pieces: []*commonv2.Piece{
+							{
+								Number:      uint32(mockPiece.Number),
+								ParentId:    &mockPiece.ParentID,
+								Offset:      mockPiece.Offset,
+								Length:      mockPiece.Length,
+								Digest:      mockPiece.Digest.String(),
+								TrafficType: &mockPiece.TrafficType,
+								Cost:        durationpb.New(mockPiece.Cost),
+								CreatedAt:   timestamppb.New(mockPiece.CreatedAt),
+							},
+						},
+						Task: &commonv2.Task{
+							Id:            mockTaskID,
+							PieceLength:   uint64(mockTaskPieceLength),
+							ContentLength: uint64(1024),
+						},
+						Host: &commonv2.Host{
+							Id: mockHostID,
+						},
+					},
+				},
+			},
+			run: func(t *testing.T, svc *V2, request *schedulerv2.AnnouncePeersRequest, mockHost *resource.Host, mockTask *resource.Task, mockPeer *resource.Peer,
+				hostManager resource.HostManager, taskManager resource.TaskManager, peerManager resource.PeerManager, mr *resource.MockResourceMockRecorder, mh *resource.MockHostManagerMockRecorder,
+				mt *resource.MockTaskManagerMockRecorder, mp *resource.MockPeerManagerMockRecorder) {
+				gomock.InOrder(
+					mr.HostManager().Return(hostManager).Times(1),
+					mh.Load(gomock.Eq(mockHost.ID)).Return(mockHost, true).Times(1),
+					mr.TaskManager().Return(taskManager).Times(1),
+					mt.Load(gomock.Eq(mockTask.ID)).Return(mockTask, true).Times(1),
+					mr.PeerManager().Return(peerManager).Times(1),
+					mp.Load(gomock.Eq(mockPeer.ID)).Return(nil, false).Times(1),
+					mr.PeerManager().Return(peerManager).Times(1),
+					mp.Store(gomock.Any()).Return().Times(1),
+				)
+
+				assert := assert.New(t)
+				peers, err := svc.handleAnnouncePeersRequest(context.Background(), request)
+				assert.NoError(err)
+				peer := peers[0]
+				assert.Equal(peer.ID, mockPeer.ID)
+				assert.Nil(peer.AnnouncePeerStream.Load())
+				assert.True(peer.FSM.Is(resource.PeerStateSucceeded))
+				assert.True(peer.Task.FSM.Is(resource.TaskStateSucceeded))
+				assert.EqualValues(peer.Host, mockHost)
+				assert.EqualValues(peer.Task, mockTask)
+				piece, _ := peer.Pieces.Load(mockPiece.Number)
+				assert.EqualValues(piece.(*resource.Piece).Digest, mockPiece.Digest)
+			},
+		},
+		{
+			name: "invalid request with no task",
+			request: &schedulerv2.AnnouncePeersRequest{
+				Peers: []*commonv2.Peer{
+					{
+						Id: mockPeerID,
+						Pieces: []*commonv2.Piece{
+							{
+								Number:      uint32(mockPiece.Number),
+								ParentId:    &mockPiece.ParentID,
+								Offset:      mockPiece.Offset,
+								Length:      mockPiece.Length,
+								Digest:      mockPiece.Digest.String(),
+								TrafficType: &mockPiece.TrafficType,
+								Cost:        durationpb.New(mockPiece.Cost),
+								CreatedAt:   timestamppb.New(mockPiece.CreatedAt),
+							},
+						},
+						Host: &commonv2.Host{
+							Id: mockHostID,
+						},
+					},
+				},
+			},
+			run: func(t *testing.T, svc *V2, request *schedulerv2.AnnouncePeersRequest, mockHost *resource.Host, mockTask *resource.Task, mockPeer *resource.Peer,
+				hostManager resource.HostManager, taskManager resource.TaskManager, peerManager resource.PeerManager, mr *resource.MockResourceMockRecorder, mh *resource.MockHostManagerMockRecorder,
+				mt *resource.MockTaskManagerMockRecorder, mp *resource.MockPeerManagerMockRecorder) {
+
+				assert := assert.New(t)
+				_, err := svc.handleAnnouncePeersRequest(context.Background(), request)
+				assert.ErrorIs(err, status.Error(codes.InvalidArgument, "request is invalid and doesn't contain a task"))
+			},
+		},
+		{
+			name: "host does not exist in scheduler",
+			request: &schedulerv2.AnnouncePeersRequest{
+				Peers: []*commonv2.Peer{
+					{
+						Id: mockPeerID,
+						Task: &commonv2.Task{
+							Id:            mockTaskID,
+							PieceLength:   uint64(mockTaskPieceLength),
+							ContentLength: uint64(1024),
+						},
+						Host: &commonv2.Host{
+							Id: mockHostID,
+						},
+					},
+				},
+			},
+			run: func(t *testing.T, svc *V2, request *schedulerv2.AnnouncePeersRequest, mockHost *resource.Host, mockTask *resource.Task, mockPeer *resource.Peer,
+				hostManager resource.HostManager, taskManager resource.TaskManager, peerManager resource.PeerManager, mr *resource.MockResourceMockRecorder, mh *resource.MockHostManagerMockRecorder,
+				mt *resource.MockTaskManagerMockRecorder, mp *resource.MockPeerManagerMockRecorder) {
+				gomock.InOrder(
+					mr.HostManager().Return(hostManager).Times(1),
+					mh.Load(gomock.Eq(mockHost.ID)).Return(nil, false).Times(1),
+				)
+
+				assert := assert.New(t)
+				_, err := svc.handleAnnouncePeersRequest(context.Background(), request)
+				assert.ErrorIs(err, status.Errorf(codes.NotFound, "host %s not found", mockHost.ID))
+			},
+		},
+		{
+			name: "task dag size exceeds PeerCountLimitForTask",
+			request: &schedulerv2.AnnouncePeersRequest{
+				Peers: []*commonv2.Peer{
+					{
+						Id: mockPeerID,
+						Pieces: []*commonv2.Piece{
+							{
+								Number:      uint32(mockPiece.Number),
+								ParentId:    &mockPiece.ParentID,
+								Offset:      mockPiece.Offset,
+								Length:      mockPiece.Length,
+								Digest:      mockPiece.Digest.String(),
+								TrafficType: &mockPiece.TrafficType,
+								Cost:        durationpb.New(mockPiece.Cost),
+								CreatedAt:   timestamppb.New(mockPiece.CreatedAt),
+							},
+						},
+						Task: &commonv2.Task{
+							Id:            mockTaskID,
+							PieceLength:   uint64(mockTaskPieceLength),
+							ContentLength: uint64(1024),
+						},
+						Host: &commonv2.Host{
+							Id: mockHostID,
+						},
+					},
+				},
+			},
+			run: func(t *testing.T, svc *V2, request *schedulerv2.AnnouncePeersRequest, mockHost *resource.Host, mockTask *resource.Task, mockPeer *resource.Peer,
+				hostManager resource.HostManager, taskManager resource.TaskManager, peerManager resource.PeerManager, mr *resource.MockResourceMockRecorder, mh *resource.MockHostManagerMockRecorder,
+				mt *resource.MockTaskManagerMockRecorder, mp *resource.MockPeerManagerMockRecorder) {
+				gomock.InOrder(
+					mr.HostManager().Return(hostManager).Times(1),
+					mh.Load(gomock.Eq(mockHost.ID)).Return(mockHost, true).Times(1),
+					mr.TaskManager().Return(taskManager).Times(1),
+					mt.Load(gomock.Eq(mockTask.ID)).Return(mockTask, true).Times(1),
+					mr.PeerManager().Return(peerManager).Times(1),
+					mp.Load(gomock.Eq(mockPeer.ID)).Return(nil, false).Times(1),
+					mr.PeerManager().Return(peerManager).Times(1),
+					mp.Store(gomock.Any()).Return().Times(1),
+					mr.PeerManager().Return(peerManager).Times(1),
+					mp.Delete(gomock.Eq(mockPeer.ID)).Return().Times(1),
+					mp.Load(gomock.Eq(mockPeer.ID)).Return(nil, false).Times(1),
+				)
+				for i := 0; i < resource.PeerCountLimitForTask+1; i++ {
+					peer := resource.NewPeer(idgen.PeerIDV1("127.0.0.1"), mockResourceConfig, mockTask, mockHost)
+					mockTask.StorePeer(peer)
+				}
+
+				assert := assert.New(t)
+				_, err := svc.handleAnnouncePeersRequest(context.Background(), request)
+				assert.NoError(err)
+				_, loaded := peerManager.Load(mockPeer.ID)
+				assert.False(loaded)
+			},
+		},
+		{
+			name: "construct piece fails due to invalid digest",
+			request: &schedulerv2.AnnouncePeersRequest{
+				Peers: []*commonv2.Peer{
+					{
+						Id: mockPeerID,
+						Pieces: []*commonv2.Piece{
+							{
+								Number:      uint32(mockPiece.Number),
+								ParentId:    &mockPiece.ParentID,
+								Offset:      mockPiece.Offset,
+								Length:      mockPiece.Length,
+								Digest:      mockPiece.Digest.String() + ":",
+								TrafficType: &mockPiece.TrafficType,
+								Cost:        durationpb.New(mockPiece.Cost),
+								CreatedAt:   timestamppb.New(mockPiece.CreatedAt),
+							},
+						},
+						Task: &commonv2.Task{
+							Id:            mockTaskID,
+							PieceLength:   uint64(mockTaskPieceLength),
+							ContentLength: uint64(1024),
+						},
+						Host: &commonv2.Host{
+							Id: mockHostID,
+						},
+					},
+				},
+			},
+			run: func(t *testing.T, svc *V2, request *schedulerv2.AnnouncePeersRequest, mockHost *resource.Host, mockTask *resource.Task, mockPeer *resource.Peer,
+				hostManager resource.HostManager, taskManager resource.TaskManager, peerManager resource.PeerManager, mr *resource.MockResourceMockRecorder, mh *resource.MockHostManagerMockRecorder,
+				mt *resource.MockTaskManagerMockRecorder, mp *resource.MockPeerManagerMockRecorder) {
+				gomock.InOrder(
+					mr.HostManager().Return(hostManager).Times(1),
+					mh.Load(gomock.Eq(mockHost.ID)).Return(mockHost, true).Times(1),
+					mr.TaskManager().Return(taskManager).Times(1),
+					mt.Load(gomock.Eq(mockTask.ID)).Return(mockTask, true).Times(1),
+					mr.PeerManager().Return(peerManager).Times(1),
+					mp.Load(gomock.Eq(mockPeer.ID)).Return(nil, false).Times(1),
+					mr.PeerManager().Return(peerManager).Times(1),
+					mp.Store(gomock.Any()).Return().Times(1),
+				)
+
+				assert := assert.New(t)
+				_, err := svc.handleAnnouncePeersRequest(context.Background(), request)
+				assert.ErrorIs(err, status.Errorf(codes.InvalidArgument, "invalid digest"))
+			},
+		},
+	}
+
+	for _, tc := range tests {
+		t.Run(tc.name, func(t *testing.T) {
+			ctl := gomock.NewController(t)
+			defer ctl.Finish()
+			scheduling := schedulingmocks.NewMockScheduling(ctl)
+			res := resource.NewMockResource(ctl)
+			dynconfig := configmocks.NewMockDynconfigInterface(ctl)
+			storage := storagemocks.NewMockStorage(ctl)
+			networkTopology := networktopologymocks.NewMockNetworkTopology(ctl)
+			hostManager := resource.NewMockHostManager(ctl)
+			taskManager := resource.NewMockTaskManager(ctl)
+			peerManager := resource.NewMockPeerManager(ctl)
+
+			mockHost := resource.NewHost(
+				mockRawHost.ID, mockRawHost.IP, mockRawHost.Hostname,
+				mockRawHost.Port, mockRawHost.DownloadPort, mockRawHost.Type)
+			mockTask := resource.NewTask(mockTaskID, mockTaskURL, mockTaskTag, mockTaskApplication, commonv2.TaskType_DFDAEMON, mockTaskFilteredQueryParams, mockTaskHeader, mockTaskBackToSourceLimit, resource.WithDigest(mockTaskDigest), resource.WithPieceLength(mockTaskPieceLength))
+			mockPeer := resource.NewPeer(mockPeerID, mockResourceConfig, mockTask, mockHost)
+			svc := NewV2(&config.Config{Scheduler: mockSchedulerConfig}, res, scheduling, dynconfig, storage, networkTopology)
+
+			tc.run(t, svc, tc.request, mockHost, mockTask, mockPeer, hostManager, taskManager, peerManager, res.EXPECT(), hostManager.EXPECT(), taskManager.EXPECT(), peerManager.EXPECT())
+		})
+	}
+}
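
A note on the `SetState` replacements in `handleAnnouncePeersRequest`: `FSM.SetState` jumps unconditionally to a *state*, while the removed `FSM.Event` calls fired *events*, so the new calls must receive state constants (`resource.TaskStateRunning`, `resource.TaskStateSucceeded`) rather than the event names (`TaskEventDownload`, `TaskEventDownloadSucceeded`) the original patch passed; the hunks above are corrected accordingly. Below is a minimal standalone sketch of the distinction, assuming the `github.com/looplab/fsm` package (whose `Event(ctx, …)` signature matches the removed calls); the state and event strings are illustrative stand-ins for the `resource` constants, not the real definitions:

```go
package main

import (
	"context"
	"fmt"

	"github.com/looplab/fsm"
)

// Illustrative stand-ins for the resource.TaskState*/TaskEvent* constants
// referenced in the diff.
const (
	TaskStatePending   = "Pending"
	TaskStateRunning   = "Running"
	TaskStateSucceeded = "Succeeded"

	TaskEventDownload          = "Download"
	TaskEventDownloadSucceeded = "DownloadSucceeded"
)

func main() {
	f := fsm.NewFSM(
		TaskStatePending,
		fsm.Events{
			{Name: TaskEventDownload, Src: []string{TaskStatePending}, Dst: TaskStateRunning},
			{Name: TaskEventDownloadSucceeded, Src: []string{TaskStateRunning}, Dst: TaskStateSucceeded},
		},
		fsm.Callbacks{},
	)

	// Event fires a named transition; it fails if the event is not legal
	// from the current state, which is what the removed code checked for.
	if err := f.Event(context.Background(), TaskEventDownload); err != nil {
		fmt.Println("event failed:", err)
	}
	fmt.Println(f.Current()) // "Running"

	// SetState jumps straight to a state and cannot fail. Passing an event
	// name here would silently park the FSM in a state no transition knows
	// about, which is why the SetState calls above take TaskState* constants.
	f.SetState(TaskStateSucceeded)
	fmt.Println(f.Is(TaskStateSucceeded)) // true
}
```

With `Event`, an illegal transition surfaces as an error; `SetState` trades that safety for simplicity, which is presumably why the guarding `task.FSM.Can(resource.TaskEventDownload)` check is kept in place.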
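The new tests compare errors with `assert.ErrorIs` against freshly constructed `status.Error` values. That works because grpc-go's status error type implements `Is` by comparing the underlying status protos (code and message), not by pointer identity. A small standalone illustration:

```go
package main

import (
	"errors"
	"fmt"

	"google.golang.org/grpc/codes"
	"google.golang.org/grpc/status"
)

func main() {
	err := status.Error(codes.InvalidArgument, "request is invalid and doesn't contain a task")

	// Two independently constructed status errors with the same code and
	// message satisfy errors.Is (and therefore testify's assert.ErrorIs),
	// because grpc-go's error type implements Is via status-proto equality.
	same := status.Error(codes.InvalidArgument, "request is invalid and doesn't contain a task")
	fmt.Println(errors.Is(err, same)) // true

	// A different code or message does not match.
	other := status.Error(codes.InvalidArgument, "some other message")
	fmt.Println(errors.Is(err, other)) // false
}
```

This also means the asserted message must match the service's error string exactly, so the message in the handler and in the test should be kept in sync.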
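One design note on `PieceLength: &peerTask.PieceLength`: taking the address of a field in the incoming request makes the new `Download` alias the request message, so a later mutation of the request would show through. A defensive copy keeps the two independent; whether that matters depends on how long the `Download` outlives the request. A sketch with hypothetical stand-in types (not the real `commonv2` messages):

```go
package main

import "fmt"

// Hypothetical stand-ins for the proto messages; only the PieceLength
// fields matter for this point.
type task struct{ PieceLength uint64 }
type download struct{ PieceLength *uint64 }

func main() {
	t := &task{PieceLength: 4 * 1024 * 1024}

	// As in the diff: the download aliases the request's field.
	aliased := &download{PieceLength: &t.PieceLength}

	// Defensive copy: the download owns its own value.
	pieceLength := t.PieceLength
	copied := &download{PieceLength: &pieceLength}

	t.PieceLength = 0                 // a later mutation of the request
	fmt.Println(*aliased.PieceLength) // 0: follows the request
	fmt.Println(*copied.PieceLength)  // 4194304: unchanged
}
```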