diff --git a/lib/src/client/http2_connection.dart b/lib/src/client/http2_connection.dart index 998410a1..509bd95a 100644 --- a/lib/src/client/http2_connection.dart +++ b/lib/src/client/http2_connection.dart @@ -115,8 +115,7 @@ class Http2ClientConnection implements connection.ClientConnection { }, onPingTimeout: () => shutdown(), ); - transport.onFrameReceived - .listen((_) => keepAliveManager?.onFrameReceived()); + transport.frameReceived = keepAliveManager?.onFrameReceived; } _connectionLifeTimer ..reset() diff --git a/lib/src/server/handler.dart b/lib/src/server/handler.dart index ecbebe1e..e94d55b4 100644 --- a/lib/src/server/handler.dart +++ b/lib/src/server/handler.dart @@ -66,8 +66,8 @@ class ServerHandler extends ServiceCall { final X509Certificate? _clientCertificate; final InternetAddress? _remoteAddress; - /// Emits a ping everytime data is received - final Sink? onDataReceived; + /// Callback everytime data is received + final void Function()? onDataReceived; final Completer _isCanceledCompleter = Completer(); @@ -148,7 +148,7 @@ class ServerHandler extends ServiceCall { // -- Idle state, incoming data -- void _onDataIdle(GrpcMessage headerMessage) async { - onDataReceived?.add(null); + onDataReceived?.call(); if (headerMessage is! GrpcMetadata) { _sendError(GrpcError.unimplemented('Expected header frame')); _sinkIncoming(); @@ -289,7 +289,7 @@ class ServerHandler extends ServiceCall { return; } - onDataReceived?.add(null); + onDataReceived?.call(); final data = message; Object? request; try { diff --git a/lib/src/server/server.dart b/lib/src/server/server.dart index bc40edb0..5c8c942d 100644 --- a/lib/src/server/server.dart +++ b/lib/src/server/server.dart @@ -117,25 +117,21 @@ class ConnectionServer { required ServerTransportConnection connection, X509Certificate? clientCertificate, InternetAddress? remoteAddress, + required ServerKeepAlive serverKeepAlive, }) async { _connections.add(connection); handlers[connection] = []; // TODO(jakobr): Set active state handlers, close connection after idle // timeout. - final onDataReceivedController = StreamController(); - ServerKeepAlive( - options: _keepAliveOptions, - tooManyBadPings: () async => - await connection.terminate(ErrorCode.ENHANCE_YOUR_CALM), - pingNotifier: connection.onPingReceived, - dataNotifier: onDataReceivedController.stream, - ).handle(); + + serverKeepAlive.tooManyBadPings = + () async => await connection.terminate(ErrorCode.ENHANCE_YOUR_CALM); connection.incomingStreams.listen((stream) { final handler = serveStream_( stream: stream, clientCertificate: clientCertificate, remoteAddress: remoteAddress, - onDataReceived: onDataReceivedController.sink, + onDataReceived: serverKeepAlive.onDataReceived, ); handler.onCanceled.then((_) => handlers[connection]?.remove(handler)); handlers[connection]!.add(handler); @@ -153,7 +149,6 @@ class ConnectionServer { } _connections.remove(connection); handlers.remove(connection); - await onDataReceivedController.close(); }); } @@ -162,7 +157,7 @@ class ConnectionServer { required ServerTransportStream stream, X509Certificate? clientCertificate, InternetAddress? remoteAddress, - Sink? onDataReceived, + void Function()? onDataReceived, }) { return ServerHandler( stream: stream, @@ -279,17 +274,22 @@ class Server extends ConnectionServer { clientCertificate = socket.peerCertificate; } + final serverKeepAlive = ServerKeepAlive(options: _keepAliveOptions); final connection = ServerTransportConnection.viaSocket( socket, settings: http2ServerSettings, + pingReceived: serverKeepAlive.onPingReceived, ); + connection.pingReceived = serverKeepAlive.onPingReceived; serveConnection( connection: connection, clientCertificate: clientCertificate, remoteAddress: socket.remoteAddressOrNull, + serverKeepAlive: serverKeepAlive, ); }, onError: (error, stackTrace) { + print('error'); if (error is Error) { Zone.current.handleUncaughtError(error, stackTrace); } @@ -302,7 +302,7 @@ class Server extends ConnectionServer { required ServerTransportStream stream, X509Certificate? clientCertificate, InternetAddress? remoteAddress, - Sink? onDataReceived, + void Function()? onDataReceived, }) { return ServerHandler( stream: stream, diff --git a/lib/src/server/server_keepalive.dart b/lib/src/server/server_keepalive.dart index 890e0fe1..1e7d5c3d 100644 --- a/lib/src/server/server_keepalive.dart +++ b/lib/src/server/server_keepalive.dart @@ -40,38 +40,21 @@ class ServerKeepAliveOptions { class ServerKeepAlive { /// What to do after receiving too many bad pings, probably shut down the /// connection to not be DDoSed. - final Future Function()? tooManyBadPings; + Future Function()? tooManyBadPings; final ServerKeepAliveOptions options; - /// A stream of events for every time the server gets pinged. - final Stream pingNotifier; - - /// A stream of events for every time the server receives data. - final Stream dataNotifier; - int _badPings = 0; Stopwatch? _timeOfLastReceivedPing; ServerKeepAlive({ this.tooManyBadPings, required this.options, - required this.pingNotifier, - required this.dataNotifier, }); - void handle() { - // If we don't care about bad pings, there is not point in listening to - // events. - if (_enforcesMaxBadPings) { - pingNotifier.listen((_) => _onPingReceived()); - dataNotifier.listen((_) => _onDataReceived()); - } - } - bool get _enforcesMaxBadPings => (options.maxBadPings ?? 0) > 0; - Future _onPingReceived() async { + Future onPingReceived(int _) async { if (_enforcesMaxBadPings) { if (_timeOfLastReceivedPing == null) { _timeOfLastReceivedPing = clock.stopwatch() @@ -82,12 +65,13 @@ class ServerKeepAlive { _badPings++; } if (_badPings > options.maxBadPings!) { + // print('Call too many bad pings'); await tooManyBadPings?.call(); } } } - void _onDataReceived() { + void onDataReceived() { if (_enforcesMaxBadPings) { _badPings = 0; _timeOfLastReceivedPing = null; diff --git a/pubspec_overrides.yaml b/pubspec_overrides.yaml new file mode 100644 index 00000000..22fa6b61 --- /dev/null +++ b/pubspec_overrides.yaml @@ -0,0 +1,3 @@ +dependency_overrides: + http2: + path: ../http2 diff --git a/test/client_tests/client_keepalive_manager_test.mocks.dart b/test/client_tests/client_keepalive_manager_test.mocks.dart index 941de6a6..b1ea39f9 100644 --- a/test/client_tests/client_keepalive_manager_test.mocks.dart +++ b/test/client_tests/client_keepalive_manager_test.mocks.dart @@ -1,9 +1,7 @@ -// Mocks generated by Mockito 5.4.1 from annotations +// Mocks generated by Mockito 5.4.2 from annotations // in grpc/test/client_tests/client_keepalive_manager_test.dart. // Do not manually edit this file. -// @dart=2.19 - // ignore_for_file: no_leading_underscores_for_library_prefixes import 'package:mockito/mockito.dart' as _i1; @@ -32,6 +30,7 @@ class MockPinger extends _i1.Mock implements _i2.Pinger { ), returnValueForMissingStub: null, ); + @override void onPingTimeout() => super.noSuchMethod( Invocation.method( diff --git a/test/keepalive_test.dart b/test/keepalive_test.dart index e4ddf2d1..734a61bd 100644 --- a/test/keepalive_test.dart +++ b/test/keepalive_test.dart @@ -173,7 +173,7 @@ class FakeEchoService extends EchoServiceBase { @override Stream serverStreamingEcho( ServiceCall call, ServerStreamingEchoRequest request) { - // TODO: implement serverStreamingEcho +// TODO: implement serverStreamingEcho throw UnimplementedError(); } } diff --git a/test/server_keepalive_manager_test.dart b/test/server_keepalive_manager_test.dart index 230fc3b7..f71ec9bb 100644 --- a/test/server_keepalive_manager_test.dart +++ b/test/server_keepalive_manager_test.dart @@ -1,57 +1,45 @@ -import 'dart:async'; - import 'package:fake_async/fake_async.dart'; import 'package:grpc/src/server/server_keepalive.dart'; import 'package:test/test.dart'; void main() { - late StreamController pingStream; - late StreamController dataStream; late int maxBadPings; var goAway = false; - void initServer([ServerKeepAliveOptions? options]) => ServerKeepAlive( + ServerKeepAlive initServer([ServerKeepAliveOptions? options]) => + ServerKeepAlive( options: options ?? ServerKeepAliveOptions( maxBadPings: maxBadPings, minIntervalBetweenPingsWithoutData: Duration(milliseconds: 5), ), - pingNotifier: pingStream.stream, - dataNotifier: dataStream.stream, tooManyBadPings: () async => goAway = true, - ).handle(); + ); setUp(() { - pingStream = StreamController(); - dataStream = StreamController(); maxBadPings = 10; goAway = false; }); - tearDown(() { - pingStream.close(); - dataStream.close(); - }); - final timeAfterPing = Duration(milliseconds: 10); test('Sending too many pings without data kills connection', () async { FakeAsync().run((async) { - initServer(); + final server = initServer(); // Send good ping - pingStream.sink.add(null); + server.onPingReceived(0); async.elapse(timeAfterPing); // Send [maxBadPings] bad pings, that's still ok for (var i = 0; i < maxBadPings; i++) { - pingStream.sink.add(null); + server.onPingReceived(0); } async.elapse(timeAfterPing); expect(goAway, false); // Send another bad ping; that's one too many! - pingStream.sink.add(null); + server.onPingReceived(0); async.elapse(timeAfterPing); expect(goAway, true); }); @@ -60,17 +48,17 @@ void main() { 'Sending too many pings without data doesn`t kill connection if the server doesn`t care', () async { FakeAsync().run((async) { - initServer(ServerKeepAliveOptions( + final server = initServer(ServerKeepAliveOptions( maxBadPings: null, minIntervalBetweenPingsWithoutData: Duration(milliseconds: 5), )); // Send good ping - pingStream.sink.add(null); + server.onPingReceived(0); async.elapse(timeAfterPing); // Send a lot of bad pings, that's still ok. for (var i = 0; i < 50; i++) { - pingStream.sink.add(null); + server.onPingReceived(0); } async.elapse(timeAfterPing); expect(goAway, false); @@ -79,36 +67,36 @@ void main() { test('Sending many pings with data doesn`t kill connection', () async { FakeAsync().run((async) { - initServer(); + final server = initServer(); // Send good ping - pingStream.sink.add(null); + server.onPingReceived(0); async.elapse(timeAfterPing); // Send [maxBadPings] bad pings, that's still ok for (var i = 0; i < maxBadPings; i++) { - pingStream.sink.add(null); + server.onPingReceived(0); } async.elapse(timeAfterPing); expect(goAway, false); // Sending data resets the bad ping count - dataStream.add(null); + server.onDataReceived(); async.elapse(timeAfterPing); // Send good ping - pingStream.sink.add(null); + server.onPingReceived(0); async.elapse(timeAfterPing); // Send [maxBadPings] bad pings, that's still ok for (var i = 0; i < maxBadPings; i++) { - pingStream.sink.add(null); + server.onPingReceived(0); } async.elapse(timeAfterPing); expect(goAway, false); // Send another bad ping; that's one too many! - pingStream.sink.add(null); + server.onPingReceived(0); async.elapse(timeAfterPing); expect(goAway, true); }); diff --git a/test/src/client_utils.mocks.dart b/test/src/client_utils.mocks.dart index 98702df4..f6cefcff 100644 --- a/test/src/client_utils.mocks.dart +++ b/test/src/client_utils.mocks.dart @@ -1,9 +1,7 @@ -// Mocks generated by Mockito 5.4.1 from annotations +// Mocks generated by Mockito 5.4.2 from annotations // in grpc/test/src/client_utils.dart. // Do not manually edit this file. -// @dart=2.19 - // ignore_for_file: no_leading_underscores_for_library_prefixes import 'dart:async' as _i3; @@ -57,6 +55,25 @@ class MockClientTransportConnection extends _i1.Mock Invocation.getter(#isOpen), returnValue: false, ) as bool); + + @override + set pingReceived(dynamic Function(int)? _pingReceived) => super.noSuchMethod( + Invocation.setter( + #pingReceived, + _pingReceived, + ), + returnValueForMissingStub: null, + ); + + @override + set frameReceived(dynamic Function()? _frameReceived) => super.noSuchMethod( + Invocation.setter( + #frameReceived, + _frameReceived, + ), + returnValueForMissingStub: null, + ); + @override set onActiveStateChanged(_i2.ActiveStateHandler? callback) => super.noSuchMethod( @@ -66,21 +83,13 @@ class MockClientTransportConnection extends _i1.Mock ), returnValueForMissingStub: null, ); + @override _i3.Future get onInitialPeerSettingsReceived => (super.noSuchMethod( Invocation.getter(#onInitialPeerSettingsReceived), returnValue: _i3.Future.value(), ) as _i3.Future); - @override - _i3.Stream get onPingReceived => (super.noSuchMethod( - Invocation.getter(#onPingReceived), - returnValue: _i3.Stream.empty(), - ) as _i3.Stream); - @override - _i3.Stream get onFrameReceived => (super.noSuchMethod( - Invocation.getter(#onFrameReceived), - returnValue: _i3.Stream.empty(), - ) as _i3.Stream); + @override _i2.ClientTransportStream makeRequest( List<_i4.Header>? headers, { @@ -101,6 +110,7 @@ class MockClientTransportConnection extends _i1.Mock ), ), ) as _i2.ClientTransportStream); + @override _i3.Future ping() => (super.noSuchMethod( Invocation.method( @@ -109,6 +119,7 @@ class MockClientTransportConnection extends _i1.Mock ), returnValue: _i3.Future.value(), ) as _i3.Future); + @override _i3.Future finish() => (super.noSuchMethod( Invocation.method( @@ -117,6 +128,7 @@ class MockClientTransportConnection extends _i1.Mock ), returnValue: _i3.Future.value(), ) as _i3.Future); + @override _i3.Future terminate([int? errorCode]) => (super.noSuchMethod( Invocation.method( @@ -141,16 +153,19 @@ class MockClientTransportStream extends _i1.Mock Invocation.getter(#peerPushes), returnValue: _i3.Stream<_i2.TransportStreamPush>.empty(), ) as _i3.Stream<_i2.TransportStreamPush>); + @override int get id => (super.noSuchMethod( Invocation.getter(#id), returnValue: 0, ) as int); + @override _i3.Stream<_i2.StreamMessage> get incomingMessages => (super.noSuchMethod( Invocation.getter(#incomingMessages), returnValue: _i3.Stream<_i2.StreamMessage>.empty(), ) as _i3.Stream<_i2.StreamMessage>); + @override _i3.StreamSink<_i2.StreamMessage> get outgoingMessages => (super.noSuchMethod( Invocation.getter(#outgoingMessages), @@ -159,6 +174,7 @@ class MockClientTransportStream extends _i1.Mock Invocation.getter(#outgoingMessages), ), ) as _i3.StreamSink<_i2.StreamMessage>); + @override set onTerminated(void Function(int?)? value) => super.noSuchMethod( Invocation.setter( @@ -167,6 +183,7 @@ class MockClientTransportStream extends _i1.Mock ), returnValueForMissingStub: null, ); + @override void terminate() => super.noSuchMethod( Invocation.method( @@ -175,6 +192,7 @@ class MockClientTransportStream extends _i1.Mock ), returnValueForMissingStub: null, ); + @override void sendHeaders( List<_i4.Header>? headers, { @@ -188,6 +206,7 @@ class MockClientTransportStream extends _i1.Mock ), returnValueForMissingStub: null, ); + @override void sendData( List? bytes, {