diff --git a/pkg/tcpip/network/internal/ip/generic_multicast_protocol.go b/pkg/tcpip/network/internal/ip/generic_multicast_protocol.go index 884ea0562c..54c38b83be 100644 --- a/pkg/tcpip/network/internal/ip/generic_multicast_protocol.go +++ b/pkg/tcpip/network/internal/ip/generic_multicast_protocol.go @@ -390,7 +390,9 @@ func (g *GenericMulticastProtocolState) MakeAllNonMemberLocked() { switch g.mode { case protocolModeV2: v2ReportBuilder = g.opts.Protocol.NewReportV2Builder() - handler = func(groupAddress tcpip.Address, _ *multicastGroupState) { + handler = func(groupAddress tcpip.Address, info *multicastGroupState) { + info.cancelDelayedReportJob() + // Send a report immediately to announce us leaving the group. v2ReportBuilder.AddRecord( MulticastGroupProtocolV2ReportRecordChangeToIncludeMode, diff --git a/pkg/tcpip/network/internal/ip/generic_multicast_protocol_test.go b/pkg/tcpip/network/internal/ip/generic_multicast_protocol_test.go index 8f1a2e446f..6e1ff80b3a 100644 --- a/pkg/tcpip/network/internal/ip/generic_multicast_protocol_test.go +++ b/pkg/tcpip/network/internal/ip/generic_multicast_protocol_test.go @@ -1392,6 +1392,107 @@ func TestGroupStateNonMember(t *testing.T) { } } +// TestMakeAllNonMemberCancelsDelayedReportJob tests that the delayed report job +// is cancelled on MakeAllNonMember, otherwise the job will panic if the endpoint +// is disabled. +func TestMakeAllNonMemberCancelsDelayedReportJob(t *testing.T) { + const maxRespCode = 1 + + tests := []struct { + name string + v1 bool + v1Compatibility bool + checkFields func(tcpip.Address, bool) checkFields + }{ + { + name: "V1", + v1: true, + v1Compatibility: false, + checkFields: func(addr tcpip.Address, leave bool) checkFields { + if leave { + return checkFields{sendLeaveGroupAddresses: []tcpip.Address{addr}} + } + return checkFields{sendReportGroupAddresses: []tcpip.Address{addr}} + }, + }, + { + name: "V1 Compatibility", + v1: false, + v1Compatibility: true, + checkFields: func(addr tcpip.Address, leave bool) checkFields { + if leave { + return checkFields{sendLeaveGroupAddresses: []tcpip.Address{addr}} + } + return checkFields{sendReportGroupAddresses: []tcpip.Address{addr}} + }, + }, + { + name: "V2", + v1: false, + v1Compatibility: false, + checkFields: func(addr tcpip.Address, leave bool) checkFields { + recordType := ip.MulticastGroupProtocolV2ReportRecordChangeToExcludeMode + if leave { + recordType = ip.MulticastGroupProtocolV2ReportRecordChangeToIncludeMode + } + return checkFields{sentV2Reports: []mockReportV2{{records: []mockReportV2Record{mockReportV2Record{ + recordType: recordType, + groupAddress: addr, + }}}}} + }, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + mgp := mockMulticastGroupProtocol{t: t} + clock := faketime.NewManualClock() + + mgp.init(ip.GenericMulticastProtocolOptions{ + Rand: rand.New(rand.NewSource(3)), + Clock: clock, + MaxUnsolicitedReportDelay: maxUnsolicitedReportDelay, + }, test.v1) + + if test.v1Compatibility { + // V1 query targeting an unjoined group should drop us into V1 + // compatibility mode without sending any packets, affecting tests. + mgp.handleQuery(addr3, 0) + } + + mgp.joinGroup(addr1) + if diff := mgp.check(test.checkFields(addr1, false /* leave */)); diff != "" { + t.Fatalf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) + } + + // Handle a query so that the delayed report job is scheduled when operating + // in V2 mode. + mgp.handleQueryV2(addr1, maxRespCode, header.MakeAddressIterator(addr1.Len(), bytes.NewBuffer(nil)), 0, 0) + + mgp.makeAllNonMember() + if diff := mgp.check(test.checkFields(addr1, true /* leave */)); diff != "" { + t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) + } + + mgp.setEnabled(false) + + // Generic multicast protocol timers are expected to take the job mutex. + // + // Advance the clock to after the delayed report job is supposed to fire. + // If the delayed report job isn't cancelled by the MakeAllNonMember call, + // it will panic due to the expectation that the protocol is enabled. + if test.v1 || test.v1Compatibility { + clock.Advance(mgp.V2QueryMaxRespCodeToV1Delay(maxRespCode)) + } else { + clock.Advance(mgp.V2QueryMaxRespCodeToV2Delay(maxRespCode)) + } + if diff := mgp.check(checkFields{}); diff != "" { + t.Errorf("mockMulticastGroupProtocol mismatch (-want +got):\n%s", diff) + } + }) + } +} + func TestQueuedPackets(t *testing.T) { tests := []struct { name string