From dca33f0f797bc5540fca377ac25f4f651781afb5 Mon Sep 17 00:00:00 2001 From: BruceAko Date: Thu, 15 Aug 2024 08:47:12 +0800 Subject: [PATCH] fix: add unit tests for AnnouncePeers Signed-off-by: BruceAko --- scheduler/service/service_v2.go | 5 +- scheduler/service/service_v2_test.go | 263 ++++++++++++++++++++++++++- 2 files changed, 266 insertions(+), 2 deletions(-) diff --git a/scheduler/service/service_v2.go b/scheduler/service/service_v2.go index 93158aded72..77a63402e20 100644 --- a/scheduler/service/service_v2.go +++ b/scheduler/service/service_v2.go @@ -1490,10 +1490,13 @@ func (v *V2) handleAnnouncePeersRequest(ctx context.Context, request *schedulerv for _, p := range request.Peers { hostID := p.GetHost().GetId() peerTask := p.GetTask() + if peerTask == nil { + return nil, status.Error(codes.InvalidArgument, "request is invalid and doesn't contain a task") + } taskID := peerTask.GetId() peerID := p.GetId() download := &commonv2.Download{ - PieceLength: peerTask.GetPieceLength(), + PieceLength: &peerTask.PieceLength, Digest: peerTask.Digest, Url: peerTask.GetUrl(), Tag: peerTask.Tag, diff --git a/scheduler/service/service_v2_test.go b/scheduler/service/service_v2_test.go index eeef14cc30a..6327f0026c2 100644 --- a/scheduler/service/service_v2_test.go +++ b/scheduler/service/service_v2_test.go @@ -44,6 +44,7 @@ import ( schedulerv2mocks "d7y.io/api/v2/pkg/apis/scheduler/v2/mocks" managertypes "d7y.io/dragonfly/v2/manager/types" + "d7y.io/dragonfly/v2/pkg/idgen" nethttp "d7y.io/dragonfly/v2/pkg/net/http" pkgtypes "d7y.io/dragonfly/v2/pkg/types" "d7y.io/dragonfly/v2/scheduler/config" @@ -3098,7 +3099,7 @@ func TestServiceV2_handleResource(t *testing.T) { assert.Equal(peer.Priority, download.Priority) assert.Equal(peer.Range.Start, int64(download.Range.Start)) assert.Equal(peer.Range.Length, int64(download.Range.Length)) - assert.NotNil(peer.AnnouncePeerStream) + assert.NotNil(peer.AnnouncePeerStream.Load()) assert.EqualValues(peer.Host, mockHost) assert.EqualValues(peer.Task, mockTask) }, @@ -3407,3 +3408,263 @@ func TestServiceV2_downloadTaskBySeedPeer(t *testing.T) { }) } } + +func TestServiceV2_handleAnnouncePeersRequest(t *testing.T) { + tests := []struct { + name string + request *schedulerv2.AnnouncePeersRequest + run func(t *testing.T, svc *V2, request *schedulerv2.AnnouncePeersRequest, mockHost *resource.Host, mockTask *resource.Task, mockPeer *resource.Peer, + hostManager resource.HostManager, taskManager resource.TaskManager, peerManager resource.PeerManager, mr *resource.MockResourceMockRecorder, mh *resource.MockHostManagerMockRecorder, + mt *resource.MockTaskManagerMockRecorder, mp *resource.MockPeerManagerMockRecorder) + }{ + { + name: "task and host exist in scheduler, peer does not", + request: &schedulerv2.AnnouncePeersRequest{ + Peers: []*commonv2.Peer{ + { + Id: mockPeerID, + Pieces: []*commonv2.Piece{ + { + Number: uint32(mockPiece.Number), + ParentId: &mockPiece.ParentID, + Offset: mockPiece.Offset, + Length: mockPiece.Length, + Digest: mockPiece.Digest.String(), + TrafficType: &mockPiece.TrafficType, + Cost: durationpb.New(mockPiece.Cost), + CreatedAt: timestamppb.New(mockPiece.CreatedAt), + }, + }, + Task: &commonv2.Task{ + Id: mockTaskID, + PieceLength: uint64(mockTaskPieceLength), + ContentLength: uint64(1024), + }, + Host: &commonv2.Host{ + Id: mockHostID, + }, + }, + }, + }, + run: func(t *testing.T, svc *V2, request *schedulerv2.AnnouncePeersRequest, mockHost *resource.Host, mockTask *resource.Task, mockPeer *resource.Peer, + hostManager resource.HostManager, taskManager resource.TaskManager, peerManager resource.PeerManager, mr *resource.MockResourceMockRecorder, mh *resource.MockHostManagerMockRecorder, + mt *resource.MockTaskManagerMockRecorder, mp *resource.MockPeerManagerMockRecorder) { + gomock.InOrder( + mr.HostManager().Return(hostManager).Times(1), + mh.Load(gomock.Eq(mockHost.ID)).Return(mockHost, true).Times(1), + mr.TaskManager().Return(taskManager).Times(1), + mt.Load(gomock.Eq(mockTask.ID)).Return(mockTask, true).Times(1), + mr.PeerManager().Return(peerManager).Times(1), + mp.Load(gomock.Eq(mockPeer.ID)).Return(nil, false).Times(1), + mr.PeerManager().Return(peerManager).Times(1), + mp.Store(gomock.Any()).Return().Times(1), + ) + + assert := assert.New(t) + peers, err := svc.handleAnnouncePeersRequest(context.Background(), request) + assert.NoError(err) + peer := peers[0] + assert.Equal(peer.ID, mockPeer.ID) + assert.Nil(peer.AnnouncePeerStream.Load()) + assert.True(peer.FSM.Is(resource.PeerStateSucceeded)) + assert.True(peer.Task.FSM.Is(resource.PeerStateSucceeded)) + assert.EqualValues(peer.Host, mockHost) + assert.EqualValues(peer.Task, mockTask) + piece, _ := peer.Pieces.Load(mockPiece.Number) + assert.EqualValues(piece.(*resource.Piece).Digest, mockPiece.Digest) + }, + }, + { + name: "invalid request with no task", + request: &schedulerv2.AnnouncePeersRequest{ + Peers: []*commonv2.Peer{ + { + Id: mockPeerID, + Pieces: []*commonv2.Piece{ + { + Number: uint32(mockPiece.Number), + ParentId: &mockPiece.ParentID, + Offset: mockPiece.Offset, + Length: mockPiece.Length, + Digest: mockPiece.Digest.String(), + TrafficType: &mockPiece.TrafficType, + Cost: durationpb.New(mockPiece.Cost), + CreatedAt: timestamppb.New(mockPiece.CreatedAt), + }, + }, + Host: &commonv2.Host{ + Id: mockHostID, + }, + }, + }, + }, + run: func(t *testing.T, svc *V2, request *schedulerv2.AnnouncePeersRequest, mockHost *resource.Host, mockTask *resource.Task, mockPeer *resource.Peer, + hostManager resource.HostManager, taskManager resource.TaskManager, peerManager resource.PeerManager, mr *resource.MockResourceMockRecorder, mh *resource.MockHostManagerMockRecorder, + mt *resource.MockTaskManagerMockRecorder, mp *resource.MockPeerManagerMockRecorder) { + + assert := assert.New(t) + _, err := svc.handleAnnouncePeersRequest(context.Background(), request) + assert.ErrorIs(err, status.Error(codes.InvalidArgument, "request is invalid and doesn't contain a task")) + }, + }, + { + name: "host does not exist in scheduler", + request: &schedulerv2.AnnouncePeersRequest{ + Peers: []*commonv2.Peer{ + { + Id: mockPeerID, + Task: &commonv2.Task{ + Id: mockTaskID, + PieceLength: uint64(mockTaskPieceLength), + ContentLength: uint64(1024), + }, + Host: &commonv2.Host{ + Id: mockHostID, + }, + }, + }, + }, + run: func(t *testing.T, svc *V2, request *schedulerv2.AnnouncePeersRequest, mockHost *resource.Host, mockTask *resource.Task, mockPeer *resource.Peer, + hostManager resource.HostManager, taskManager resource.TaskManager, peerManager resource.PeerManager, mr *resource.MockResourceMockRecorder, mh *resource.MockHostManagerMockRecorder, + mt *resource.MockTaskManagerMockRecorder, mp *resource.MockPeerManagerMockRecorder) { + gomock.InOrder( + mr.HostManager().Return(hostManager).Times(1), + mh.Load(gomock.Eq(mockHost.ID)).Return(nil, false).Times(1), + ) + + assert := assert.New(t) + _, err := svc.handleAnnouncePeersRequest(context.Background(), request) + assert.ErrorIs(err, status.Errorf(codes.NotFound, "host %s not found", mockHost.ID)) + }, + }, + { + name: "task dag size exceeds the limit", + request: &schedulerv2.AnnouncePeersRequest{ + Peers: []*commonv2.Peer{ + { + Id: mockPeerID, + Pieces: []*commonv2.Piece{ + { + Number: uint32(mockPiece.Number), + ParentId: &mockPiece.ParentID, + Offset: mockPiece.Offset, + Length: mockPiece.Length, + Digest: mockPiece.Digest.String(), + TrafficType: &mockPiece.TrafficType, + Cost: durationpb.New(mockPiece.Cost), + CreatedAt: timestamppb.New(mockPiece.CreatedAt), + }, + }, + Task: &commonv2.Task{ + Id: mockTaskID, + PieceLength: uint64(mockTaskPieceLength), + ContentLength: uint64(1024), + }, + Host: &commonv2.Host{ + Id: mockHostID, + }, + }, + }, + }, + run: func(t *testing.T, svc *V2, request *schedulerv2.AnnouncePeersRequest, mockHost *resource.Host, mockTask *resource.Task, mockPeer *resource.Peer, + hostManager resource.HostManager, taskManager resource.TaskManager, peerManager resource.PeerManager, mr *resource.MockResourceMockRecorder, mh *resource.MockHostManagerMockRecorder, + mt *resource.MockTaskManagerMockRecorder, mp *resource.MockPeerManagerMockRecorder) { + gomock.InOrder( + mr.HostManager().Return(hostManager).Times(1), + mh.Load(gomock.Eq(mockHost.ID)).Return(mockHost, true).Times(1), + mr.TaskManager().Return(taskManager).Times(1), + mt.Load(gomock.Eq(mockTask.ID)).Return(mockTask, true).Times(1), + mr.PeerManager().Return(peerManager).Times(1), + mp.Load(gomock.Eq(mockPeer.ID)).Return(nil, false).Times(1), + mr.PeerManager().Return(peerManager).Times(1), + mp.Store(gomock.Any()).Return().Times(1), + mr.PeerManager().Return(peerManager).Times(1), + mp.Delete(gomock.Eq(mockPeer.ID)).Return().Times(1), + mp.Load(gomock.Eq(mockPeer.ID)).Return(nil, false).Times(1), + ) + for i := 0; i < resource.PeerCountLimitForTask+1; i++ { + peer := resource.NewPeer(idgen.PeerIDV1("127.0.0.1"), mockResourceConfig, mockTask, mockHost) + mockTask.StorePeer(peer) + } + + assert := assert.New(t) + _, err := svc.handleAnnouncePeersRequest(context.Background(), request) + assert.NoError(err) + _, loaded := peerManager.Load(mockPeer.ID) + assert.Equal(loaded, false) + }, + }, + { + name: "construct piece fails due to invalid digest", + request: &schedulerv2.AnnouncePeersRequest{ + Peers: []*commonv2.Peer{ + { + Id: mockPeerID, + Pieces: []*commonv2.Piece{ + { + Number: uint32(mockPiece.Number), + ParentId: &mockPiece.ParentID, + Offset: mockPiece.Offset, + Length: mockPiece.Length, + Digest: mockPiece.Digest.String() + ":", + TrafficType: &mockPiece.TrafficType, + Cost: durationpb.New(mockPiece.Cost), + CreatedAt: timestamppb.New(mockPiece.CreatedAt), + }, + }, + Task: &commonv2.Task{ + Id: mockTaskID, + PieceLength: uint64(mockTaskPieceLength), + ContentLength: uint64(1024), + }, + Host: &commonv2.Host{ + Id: mockHostID, + }, + }, + }, + }, + run: func(t *testing.T, svc *V2, request *schedulerv2.AnnouncePeersRequest, mockHost *resource.Host, mockTask *resource.Task, mockPeer *resource.Peer, + hostManager resource.HostManager, taskManager resource.TaskManager, peerManager resource.PeerManager, mr *resource.MockResourceMockRecorder, mh *resource.MockHostManagerMockRecorder, + mt *resource.MockTaskManagerMockRecorder, mp *resource.MockPeerManagerMockRecorder) { + gomock.InOrder( + mr.HostManager().Return(hostManager).Times(1), + mh.Load(gomock.Eq(mockHost.ID)).Return(mockHost, true).Times(1), + mr.TaskManager().Return(taskManager).Times(1), + mt.Load(gomock.Eq(mockTask.ID)).Return(mockTask, true).Times(1), + mr.PeerManager().Return(peerManager).Times(1), + mp.Load(gomock.Eq(mockPeer.ID)).Return(nil, false).Times(1), + mr.PeerManager().Return(peerManager).Times(1), + mp.Store(gomock.Any()).Return().Times(1), + ) + + assert := assert.New(t) + _, err := svc.handleAnnouncePeersRequest(context.Background(), request) + assert.ErrorIs(err, status.Errorf(codes.InvalidArgument, "invalid digest")) + }, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + ctl := gomock.NewController(t) + defer ctl.Finish() + scheduling := schedulingmocks.NewMockScheduling(ctl) + res := resource.NewMockResource(ctl) + dynconfig := configmocks.NewMockDynconfigInterface(ctl) + storage := storagemocks.NewMockStorage(ctl) + networkTopology := networktopologymocks.NewMockNetworkTopology(ctl) + hostManager := resource.NewMockHostManager(ctl) + taskManager := resource.NewMockTaskManager(ctl) + peerManager := resource.NewMockPeerManager(ctl) + + mockHost := resource.NewHost( + mockRawHost.ID, mockRawHost.IP, mockRawHost.Hostname, + mockRawHost.Port, mockRawHost.DownloadPort, mockRawHost.Type) + mockTask := resource.NewTask(mockTaskID, mockTaskURL, mockTaskTag, mockTaskApplication, commonv2.TaskType_DFDAEMON, mockTaskFilteredQueryParams, mockTaskHeader, mockTaskBackToSourceLimit, resource.WithDigest(mockTaskDigest), resource.WithPieceLength(mockTaskPieceLength)) + mockPeer := resource.NewPeer(mockPeerID, mockResourceConfig, mockTask, mockHost) + svc := NewV2(&config.Config{Scheduler: mockSchedulerConfig}, res, scheduling, dynconfig, storage, networkTopology) + + tc.run(t, svc, tc.request, mockHost, mockTask, mockPeer, hostManager, taskManager, peerManager, res.EXPECT(), hostManager.EXPECT(), taskManager.EXPECT(), peerManager.EXPECT()) + }) + } +}