diff --git a/lib/src/message_window.dart b/lib/src/message_window.dart index 6b2436b..0a21f4b 100644 --- a/lib/src/message_window.dart +++ b/lib/src/message_window.dart @@ -1,5 +1,4 @@ import 'dart:async'; -import 'dart:collection'; import 'dart:typed_data'; import 'package:buffer/buffer.dart'; @@ -35,76 +34,95 @@ Map _messageTypeMap = { $N: NoticeMessage.parse, }; -class MessageFramer { - final CodecContext _codecContext; - late final _reader = PgByteDataReader(codecContext: _codecContext); - final messageQueue = Queue(); +class _BytesFrame { + final int type; + final int length; + final Uint8List bytes; - MessageFramer(this._codecContext); + _BytesFrame(this.type, this.length, this.bytes); +} - int? _type; - int _expectedLength = 0; +StreamTransformer bytesToMessageParser() { + return StreamTransformer.fromHandlers( + handleData: (data, sink) {}, + ); +} - bool get _hasReadHeader => _type != null; - bool get _canReadHeader => _reader.remainingLength >= _headerByteSize; +final _emptyData = Uint8List(0); - bool get _isComplete => - _expectedLength == 0 || _expectedLength <= _reader.remainingLength; +class _BytesToFrameParser + extends StreamTransformerBase { + final CodecContext _codecContext; - Future addBytes(Uint8List bytes) async { - _reader.add(bytes); + _BytesToFrameParser(this._codecContext); - while (true) { - if (!_hasReadHeader && _canReadHeader) { - _type = _reader.readUint8(); - _expectedLength = _reader.readUint32() - 4; - } + @override + Stream<_BytesFrame> bind(Stream stream) async* { + final reader = PgByteDataReader(codecContext: _codecContext); - // special case - if (_type == SharedMessageId.copyDone) { - // unlike other messages, CopyDoneMessage only takes the length as an - // argument (must be the full length including the length bytes) - final msg = CopyDoneMessage(_expectedLength + 4); - _addMsg(msg); - continue; - } + int? type; + int expectedLength = 0; - if (_hasReadHeader && _isComplete) { - final msgMaker = _messageTypeMap[_type]; - if (msgMaker == null) { - _addMsg(UnknownMessage(_type!, _reader.read(_expectedLength))); - continue; + await for (final bytes in stream) { + reader.add(bytes); + + while (true) { + if (type == null && reader.remainingLength >= _headerByteSize) { + type = reader.readUint8(); + expectedLength = reader.readUint32() - 4; } - final targetRemainingLength = _reader.remainingLength - _expectedLength; - final msg = await msgMaker(_reader, _expectedLength); - if (_reader.remainingLength > targetRemainingLength) { - throw StateError( - 'Message parser consumed more bytes than expected. type=$_type expectedLength=$_expectedLength'); + // special case + if (type == SharedMessageId.copyDone) { + // unlike other messages, CopyDoneMessage only takes the length as an + // argument (must be the full length including the length bytes) + yield _BytesFrame(type!, expectedLength, _emptyData); + type = null; + expectedLength = 0; + continue; } - // consume the rest of the message - if (_reader.remainingLength < targetRemainingLength) { - _reader.read(targetRemainingLength - _reader.remainingLength); + + if (type != null && expectedLength <= reader.remainingLength) { + final data = reader.read(expectedLength); + yield _BytesFrame(type, expectedLength, data); + type = null; + expectedLength = 0; + continue; } - _addMsg(msg); - continue; + break; } - - break; } } +} - void _addMsg(ServerMessage msg) { - messageQueue.add(msg); - _type = null; - _expectedLength = 0; - } +class BytesToMessageParser + extends StreamTransformerBase { + final CodecContext _codecContext; + + BytesToMessageParser(this._codecContext); - bool get hasMessage => messageQueue.isNotEmpty; + @override + Stream bind(Stream stream) { + return stream + .transform(_BytesToFrameParser(_codecContext)) + .asyncMap((frame) async { + // special case + if (frame.type == SharedMessageId.copyDone) { + // unlike other messages, CopyDoneMessage only takes the length as an + // argument (must be the full length including the length bytes) + return CopyDoneMessage(frame.length + 4); + } + + final msgMaker = _messageTypeMap[frame.type]; + if (msgMaker == null) { + return UnknownMessage(frame.type, frame.bytes); + } - ServerMessage popMessage() { - return messageQueue.removeFirst(); + return await msgMaker( + PgByteDataReader(codecContext: _codecContext)..add(frame.bytes), + frame.bytes.length); + }); } } diff --git a/lib/src/v3/protocol.dart b/lib/src/v3/protocol.dart index 85a5a36..c7d406b 100644 --- a/lib/src/v3/protocol.dart +++ b/lib/src/v3/protocol.dart @@ -1,6 +1,3 @@ -import 'dart:async'; -import 'dart:typed_data'; - import 'package:async/async.dart'; import 'package:postgres/src/types/codec.dart'; import 'package:stream_channel/stream_channel.dart'; @@ -8,7 +5,6 @@ import 'package:stream_channel/stream_channel.dart'; import '../buffer.dart'; import '../message_window.dart'; import '../messages/client_messages.dart'; -import '../messages/server_messages.dart'; import '../messages/shared_messages.dart'; export '../messages/client_messages.dart'; @@ -36,7 +32,7 @@ class AggregatedClientMessage extends ClientMessage { StreamChannelTransformer> messageTransformer( CodecContext codecContext) { return StreamChannelTransformer( - _readMessages(codecContext), + BytesToMessageParser(codecContext), StreamSinkTransformer.fromHandlers( handleData: (message, out) { if (message is! ClientMessage) { @@ -52,59 +48,3 @@ StreamChannelTransformer> messageTransformer( ), ); } - -StreamTransformer _readMessages( - CodecContext codecContext) { - return StreamTransformer.fromBind((rawStream) { - return Stream.multi((listener) { - final framer = MessageFramer(codecContext); - - var paused = false; - - void emitFinishedMessages() { - while (framer.hasMessage) { - listener.addSync(framer.popMessage()); - - if (paused) break; - } - } - - Future handleChunk() async { - try { - // await framer.addBytes(bytes); - emitFinishedMessages(); - } catch (e, st) { - listener.addErrorSync(e, st); - } - } - - // Don't cancel this subscription on error! If the listener wants that, - // they'll unsubscribe in time after we forward it synchronously. - final rawSubscription = rawStream - // TODO: figure out a better way to handle multiple callbacks to framer - .asyncMap(framer.addBytes) - .listen((_) => handleChunk(), cancelOnError: false) - ..onError(listener.addErrorSync) - ..onDone(listener.closeSync); - - listener.onPause = () { - paused = true; - rawSubscription.pause(); - }; - - listener.onResume = () { - paused = false; - emitFinishedMessages(); - - if (!paused) { - rawSubscription.resume(); - } - }; - - listener.onCancel = () { - paused = true; - rawSubscription.cancel(); - }; - }); - }); -} diff --git a/test/framer_test.dart b/test/framer_test.dart index c488226..19c9963 100644 --- a/test/framer_test.dart +++ b/test/framer_test.dart @@ -1,3 +1,4 @@ +import 'dart:async'; import 'dart:typed_data'; import 'package:buffer/buffer.dart'; @@ -9,175 +10,63 @@ import 'package:postgres/src/types/codec.dart'; import 'package:test/test.dart'; void main() { - late MessageFramer framer; - setUp(() { - framer = MessageFramer(CodecContext.withDefaults()); - }); - - tearDown(() async { - await flush(framer); - }); + Future parse(Uint8List buffer, messages) async { + expect( + await Stream.fromIterable([buffer]) + .transform(BytesToMessageParser(CodecContext.withDefaults())) + .toList(), + messages, + ); + + expect( + await Stream.fromIterable(buffer.expand((b) => [ + Uint8List.fromList([b]) + ])) + .transform(BytesToMessageParser(CodecContext.withDefaults())) + .toList(), + messages, + ); + + for (var i = 1; i < buffer.length - 1; i++) { + final splitBuffers = fragmentedMessageBuffer(buffer, i); + expect( + await Stream.fromIterable(splitBuffers) + .transform(BytesToMessageParser(CodecContext.withDefaults())) + .toList(), + messages, + ); + } + } test('Perfectly sized message in one buffer', () async { - await framer.addBytes(bufferWithMessages([ - messageWithBytes([1, 2, 3], 1) - ])); - - final messages = framer.messageQueue.toList(); - expect(messages, [ - UnknownMessage(1, Uint8List.fromList([1, 2, 3])), - ]); + await parse( + bufferWithMessages([ + messageWithBytes([1, 2, 3], 1), + ]), + [ + UnknownMessage(1, Uint8List.fromList([1, 2, 3])), + ]); }); test('Two perfectly sized messages in one buffer', () async { - await framer.addBytes(bufferWithMessages([ - messageWithBytes([1, 2, 3], 1), - messageWithBytes([1, 2, 3, 4], 2) - ])); - - final messages = framer.messageQueue.toList(); - expect(messages, [ - UnknownMessage(1, Uint8List.fromList([1, 2, 3])), - UnknownMessage(2, Uint8List.fromList([1, 2, 3, 4])), - ]); + await parse( + bufferWithMessages([ + messageWithBytes([1, 2, 3], 1), + messageWithBytes([1, 2, 3, 4], 2), + ]), + [ + UnknownMessage(1, Uint8List.fromList([1, 2, 3])), + UnknownMessage(2, Uint8List.fromList([1, 2, 3, 4])), + ]); }); test('Header fragment', () async { - final message = messageWithBytes([1, 2, 3], 1); - final fragments = fragmentedMessageBuffer(message, 2); - await framer.addBytes(fragments.first); - expect(framer.messageQueue, isEmpty); - - await framer.addBytes(fragments.last); - - final messages = framer.messageQueue.toList(); - expect(messages, [ - UnknownMessage(1, Uint8List.fromList([1, 2, 3])) - ]); - }); - - test('Two header fragments', () async { - final message = messageWithBytes([1, 2, 3], 1); - final fragments = fragmentedMessageBuffer(message, 2); - final moreFragments = fragmentedMessageBuffer(fragments.first, 1); - - await framer.addBytes(moreFragments.first); - expect(framer.messageQueue, isEmpty); - - await framer.addBytes(moreFragments.last); - expect(framer.messageQueue, isEmpty); - - await framer.addBytes(fragments.last); - - final messages = framer.messageQueue.toList(); - expect(messages, [ - UnknownMessage(1, Uint8List.fromList([1, 2, 3])), - ]); - }); - - test('One message + header fragment', () async { - final message1 = messageWithBytes([1, 2, 3], 1); - final message2 = messageWithBytes([2, 2, 3], 2); - final message2Fragments = fragmentedMessageBuffer(message2, 3); - - await framer - .addBytes(bufferWithMessages([message1, message2Fragments.first])); - - expect(framer.messageQueue.length, 1); - - await framer.addBytes(message2Fragments.last); - - final messages = framer.messageQueue.toList(); - expect(messages, [ - UnknownMessage(1, Uint8List.fromList([1, 2, 3])), - UnknownMessage(2, Uint8List.fromList([2, 2, 3])), - ]); - }); - - test('Message + header, missing rest of buffer', () async { - final message1 = messageWithBytes([1, 2, 3], 1); - final message2 = messageWithBytes([2, 2, 3], 2); - final message2Fragments = fragmentedMessageBuffer(message2, 5); - - await framer - .addBytes(bufferWithMessages([message1, message2Fragments.first])); - - expect(framer.messageQueue.length, 1); - - await framer.addBytes(message2Fragments.last); - - final messages = framer.messageQueue.toList(); - expect(messages, [ - UnknownMessage(1, Uint8List.fromList([1, 2, 3])), - UnknownMessage(2, Uint8List.fromList([2, 2, 3])), - ]); - }); - - test('Message body spans two packets', () async { - final message = messageWithBytes([1, 2, 3, 4, 5, 6, 7], 1); - final fragments = fragmentedMessageBuffer(message, 8); - await framer.addBytes(fragments.first); - expect(framer.messageQueue, isEmpty); - - await framer.addBytes(fragments.last); - - final messages = framer.messageQueue.toList(); - expect(messages, [ - UnknownMessage(1, Uint8List.fromList([1, 2, 3, 4, 5, 6, 7])), - ]); - }); - - test( - 'Message spans two packets, started in a packet that contained another message', - () async { - final earlierMessage = messageWithBytes([1, 2], 0); - final message = messageWithBytes([1, 2, 3, 4, 5, 6, 7], 1); - - await framer.addBytes(bufferWithMessages( - [earlierMessage, fragmentedMessageBuffer(message, 8).first])); - expect(framer.messageQueue, hasLength(1)); - - await framer.addBytes(fragmentedMessageBuffer(message, 8).last); - - final messages = framer.messageQueue.toList(); - expect(messages, [ - UnknownMessage(0, Uint8List.fromList([1, 2])), - UnknownMessage(1, Uint8List.fromList([1, 2, 3, 4, 5, 6, 7])) - ]); - }); - - test('Message spans three packets, only part of header in the first', - () async { - final earlierMessage = messageWithBytes([1, 2], 0); - final message = - messageWithBytes([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13], 1); - - await framer.addBytes(bufferWithMessages( - [earlierMessage, fragmentedMessageBuffer(message, 3).first])); - expect(framer.messageQueue, hasLength(1)); - - await framer.addBytes( - fragmentedMessageBuffer(fragmentedMessageBuffer(message, 3).last, 6) - .first); - expect(framer.messageQueue, hasLength(1)); - - await framer.addBytes( - fragmentedMessageBuffer(fragmentedMessageBuffer(message, 3).last, 6) - .last); - - final messages = framer.messageQueue.toList(); - expect(messages, [ - UnknownMessage(0, Uint8List.fromList([1, 2])), - UnknownMessage( - 1, Uint8List.fromList([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13])), - ]); - }); - - test('Frame with no data', () async { - await framer.addBytes(bufferWithMessages([messageWithBytes([], 10)])); - - final messages = framer.messageQueue.toList(); - expect(messages, [UnknownMessage(10, Uint8List(0))]); + await parse( + bufferWithMessages([ + messageWithBytes([], 1), // frame with no data + [1], // only a header fragment + ]), + [UnknownMessage(1, Uint8List.fromList([]))]); }); test('Identify CopyDoneMessage with length equals size length (min)', @@ -188,11 +77,8 @@ void main() { SharedMessageId.copyDone, ...length, ]); - await framer.addBytes(bytes); - - final message = framer.messageQueue.toList().first; - expect(message, isA()); - expect((message as CopyDoneMessage).length, 4); + await parse( + bytes, [isA().having((m) => m.length, 'length', 4)]); }); test('Identify CopyDoneMessage when length larger than size length', @@ -202,11 +88,9 @@ void main() { SharedMessageId.copyDone, ...length, ]); - await framer.addBytes(bytes); - final message = framer.messageQueue.toList().first; - expect(message, isA()); - expect((message as CopyDoneMessage).length, 42); + await parse( + bytes, [isA().having((m) => m.length, 'length', 42)]); }); test('Adds XLogDataMessage to queue', () async { @@ -232,10 +116,12 @@ void main() { ...xlogDataMessage, ]; - await framer.addBytes(Uint8List.fromList(copyDataBytes)); - final message = framer.messageQueue.toList().first; - expect(message, isA()); - expect(message, isNot(isA())); + await parse(Uint8List.fromList(copyDataBytes), [ + allOf( + isA(), + isNot(isA()), + ), + ]); }); test('Adds XLogDataLogicalMessage with JsonMessage to queue', () async { @@ -264,10 +150,10 @@ void main() { ...xlogDataMessage, ]; - await framer.addBytes(Uint8List.fromList(copyDataMessage)); - final message = framer.messageQueue.toList().first; - expect(message, isA()); - expect((message as XLogDataLogicalMessage).message, isA()); + await parse(Uint8List.fromList(copyDataMessage), [ + isA() + .having((x) => x.message, 'message', isA()), + ]); }); test('Adds PrimaryKeepAliveMessage to queue', () async { @@ -290,9 +176,8 @@ void main() { ...xlogDataMessage, ]; - await framer.addBytes(Uint8List.fromList(copyDataMessage)); - final message = framer.messageQueue.toList().first; - expect(message, isA()); + await parse( + Uint8List.fromList(copyDataMessage), [isA()]); }); test('Adds raw CopyDataMessage for unknown stream message', () async { @@ -310,9 +195,7 @@ void main() { ...xlogDataBytes, ]; - await framer.addBytes(Uint8List.fromList(copyDataMessage)); - final message = framer.messageQueue.toList().first; - expect(message, isA()); + await parse(Uint8List.fromList(copyDataMessage), [isA()]); }); } @@ -335,15 +218,3 @@ List fragmentedMessageBuffer(List message, int pivotPoint) { Uint8List bufferWithMessages(List> messages) { return Uint8List.fromList(messages.expand((l) => l).toList()); } - -Future flush(MessageFramer framer) async { - framer.messageQueue.clear(); - await framer.addBytes(bufferWithMessages([ - messageWithBytes([1, 2, 3], 1) - ])); - - final messages = framer.messageQueue.toList(); - expect(messages, [ - UnknownMessage(1, Uint8List.fromList([1, 2, 3])), - ]); -}