From 72877534298c9776021022190941afaf66bb1131 Mon Sep 17 00:00:00 2001 From: evgeny Date: Fri, 31 May 2024 17:41:00 +0100 Subject: [PATCH] [ECO-4813] fix: race condition in pending message processing If an ACK/NACK message arrives during `ConnectionManager#addPendingMessagesToQueuedMessages`, it breaks the internal pending message's `startSerial`. To avoid the race condition, a new thread-safe method, `PendingMessageQueue#popAll`, has been introduced, and all direct invocations of the `PendingMessageQueue#queue` field have been removed. --- .../ably/lib/transport/ConnectionManager.java | 40 ++++++++----------- 1 file changed, 17 insertions(+), 23 deletions(-) diff --git a/lib/src/main/java/io/ably/lib/transport/ConnectionManager.java b/lib/src/main/java/io/ably/lib/transport/ConnectionManager.java index dc6c753eb..8e8b7c4c7 100644 --- a/lib/src/main/java/io/ably/lib/transport/ConnectionManager.java +++ b/lib/src/main/java/io/ably/lib/transport/ConnectionManager.java @@ -1268,18 +1268,16 @@ private synchronized List extractConnectionQueuePresenceMessages( */ private void addPendingMessagesToQueuedMessages(boolean resetMessageSerial) { synchronized (this) { - // Add messages from pending messages to front of queuedMessages in order to retry them - queuedMessages.addAll(0, pendingMessages.queue); + List allPendingMessages = pendingMessages.popAll(); if (resetMessageSerial){ // failed resume, so all new published messages start with msgSerial = 0 msgSerial = 0; //msgSerial will increase in sendImpl when messages are sent, RTN15c7 - pendingMessages.resetStartSerial(0); - } else if(!pendingMessages.queue.isEmpty()) { // pendingMessages needs to expect next msgSerial to be the earliest previously unacknowledged message - msgSerial = pendingMessages.queue.get(0).msg.msgSerial; - pendingMessages.resetStartSerial((int) (msgSerial)); + } else if (!allPendingMessages.isEmpty()) { // pendingMessages needs to expect next msgSerial to be the earliest previously unacknowledged message + msgSerial = allPendingMessages.get(0).msg.msgSerial; } - pendingMessages.queue.clear(); + // Add messages from pending messages to front of queuedMessages in order to retry them + queuedMessages.addAll(0, allPendingMessages); } } @@ -1671,9 +1669,8 @@ private void failQueuedMessages(ErrorInfo reason) { /** * A class containing a queue of messages awaiting acknowledgement */ - private class PendingMessageQueue { - private long startSerial = 0L; - private ArrayList queue = new ArrayList(); + private static class PendingMessageQueue { + private final List queue = new ArrayList<>(); public synchronized void push(QueuedMessage msg) { queue.add(msg); @@ -1682,6 +1679,8 @@ public synchronized void push(QueuedMessage msg) { public void ack(long msgSerial, int count, ErrorInfo reason) { QueuedMessage[] ackMessages = null, nackMessages = null; synchronized(this) { + if (queue.isEmpty()) return; + long startSerial = queue.get(0).msg.msgSerial; if(msgSerial < startSerial) { /* this is an error condition and shouldn't happen but * we can handle it gracefully by only processing the @@ -1734,6 +1733,8 @@ public void ack(long msgSerial, int count, ErrorInfo reason) { public synchronized void nack(long serial, int count, ErrorInfo reason) { QueuedMessage[] nackMessages = null; synchronized(this) { + if (queue.isEmpty()) return; + long startSerial = queue.get(0).msg.msgSerial; if(serial != startSerial) { /* this is an error condition and shouldn't happen but * we can handle it gracefully by only processing the @@ -1761,22 +1762,15 @@ public synchronized void nack(long serial, int count, ErrorInfo reason) { } /** - * reset the pending message queue, failing any currently pending messages. - * Used when a resume fails and we get a different connection id. - * @param oldMsgSerial the next message serial number for the old - * connection, and thus one more than the highest message serial - * in the queue. + * @return all pending queued messages and clear the queue */ - public synchronized void reset(long oldMsgSerial, ErrorInfo err) { - nack(startSerial, (int)(oldMsgSerial - startSerial), err); - startSerial = 0; - } - - public void resetStartSerial(int from) { - startSerial = from; + synchronized List popAll() { + List allPendingMessages = new ArrayList<>(queue); + queue.clear(); + return allPendingMessages; } - //fail all pending queued emssages + //fail all pending queued messages synchronized void fail(ErrorInfo reason) { for (QueuedMessage queuedMessage: queue){ if (queuedMessage.listener != null) {