72 changes: 37 additions & 35 deletions redis/cluster.py
@@ -2425,7 +2425,9 @@ def __init__(self, args, options=None, position=None):
class NodeCommands:
""" """

def __init__(self, parse_response, connection_pool, connection):
def __init__(
self, parse_response, connection_pool: ConnectionPool, connection: Connection
):
""" """
self.parse_response = parse_response
self.connection_pool = connection_pool
@@ -2772,13 +2774,15 @@ def _send_cluster_commands(
attempt = sorted(stack, key=lambda x: x.position)
is_default_node = False
# build a list of node objects based on node names we need to
nodes = {}
nodes: dict[str, NodeCommands] = {}
nodes_written = 0
nodes_read = 0

# as we move through each command that still needs to be processed,
# we figure out the slot number that command maps to, then from
# the slot determine the node.
for c in attempt:
while True:
Author (inline review comment):
Unless I've completely lost the plot, this used to be a while True loop which always breaks on the first iteration, so this does ~nothing.
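To illustrate the reviewer's point, here is a hedged, standalone sketch (not the redis-py source; the names are made up): a `while True` whose body always reaches `break` on the first pass behaves exactly like running the body once, so dropping the loop changes nothing.

# Illustrative only -- a loop that unconditionally breaks on its first
# iteration is equivalent to straight-line execution of its body.
def with_loop(items):
    out = []
    while True:          # always exits after the first pass
        for item in items:
            out.append(item.upper())
        break
    return out

def without_loop(items):
    out = []
    for item in items:
        out.append(item.upper())
    return out

assert with_loop(["get", "set"]) == without_loop(["get", "set"])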

try:
# as we move through each command that still needs to be processed,
# we figure out the slot number that command maps to, then from
# the slot determine the node.
for c in attempt:
# refer to our internal node -> slot table that
# tells us where a given command should route to.
# (it might be possible we have a cached node that no longer
@@ -2819,50 +2823,48 @@
self._nodes_manager.initialize()
if is_default_node:
self._pipe.replace_default_node()
nodes = {}
raise
nodes[node_name] = NodeCommands(
redis_node.parse_response,
redis_node.connection_pool,
connection,
)
nodes[node_name].append(c)
break

# send the commands in sequence.
# we write to all the open sockets for each node first,
# before reading anything
# this allows us to flush all the requests out across the
# network
# so that we can read them from different sockets as they come back.
# we don't multiplex on the sockets as they come available,
# but that shouldn't make too much difference.
try:
# send the commands in sequence.
# we write to all the open sockets for each node first,
# before reading anything
# this allows us to flush all the requests out across the
# network
# so that we can read them from different sockets as they come back.
# we don't multiplex on the sockets as they come available,
# but that shouldn't make too much difference.
node_commands = nodes.values()
for n in node_commands:
nodes_written += 1
n.write()

for n in node_commands:
n.read()
nodes_read += 1
finally:
# release all of the redis connections we allocated earlier
# release all the redis connections we allocated earlier
# back into the connection pool.
# we used to do this step as part of a try/finally block,
# but it is really dangerous to
# release connections back into the pool if for some
# reason the socket has data still left in it
# from a previous operation. The write and
# read operations already have try/catch around them for
# all known types of errors including connection
# and socket level errors.
# So if we hit an exception, something really bad
# happened and putting any of
# these connections back into the pool is a very bad idea.
# the socket might have unread buffer still sitting in it,
# and then the next time we read from it we pass the
# buffered result back from a previous command and
# every single request after to that connection will always get
# a mismatched result.
for n in nodes.values():
# if we the connection is dirty (that is: we've written
# commands to it, but haven't read the responses), we need
# to close the connection before returning it to the pool.
# otherwise, the next caller to use this connection will
# read the response from _this_ request, not its own request.
# disconnecting discards the dirty state & forces the next
# caller to reconnect.
# NOTE: dicts have a consistent ordering; we're iterating
# through nodes.values() in the same order as we are when
# reading / writing to the connections above, which is critical
# for how we're using the nodes_written/nodes_read offsets.
for i, n in enumerate(nodes.values()):
if i < nodes_written and i >= nodes_read:
n.connection.disconnect()
n.connection_pool.release(n.connection)

# if the response isn't an exception it is a
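As a rough, self-contained sketch of the bookkeeping the new release loop relies on (FakeConnection and FakePool below are stand-ins, not redis-py types): connections are written to and then read from in the same dict iteration order, so any connection whose index is at least nodes_read but below nodes_written was written to without its reply being read, and must be disconnected before it goes back to the pool.

# Simplified sketch of the dirty-connection bookkeeping; not the actual
# redis-py classes or release path.
class FakeConnection:
    def __init__(self, name):
        self.name = name
        self.disconnected = False

    def disconnect(self):
        self.disconnected = True


class FakePool:
    def __init__(self):
        self.released = []

    def release(self, conn):
        self.released.append(conn)


def release_all(nodes, nodes_written, nodes_read, pool):
    # dicts preserve insertion order, so index i matches the order the
    # connections were written to / read from earlier in the pipeline.
    for i, conn in enumerate(nodes.values()):
        if nodes_read <= i < nodes_written:
            # written but not read: the socket may still hold an unread
            # reply, so discard the connection instead of reusing it.
            conn.disconnect()
        pool.release(conn)


pool = FakePool()
nodes = {f"node{i}": FakeConnection(f"node{i}") for i in range(3)}
# e.g. two nodes were written to, then an error stopped us before any reads
release_all(nodes, nodes_written=2, nodes_read=0, pool=pool)
assert [c.disconnected for c in nodes.values()] == [True, True, False]
assert len(pool.released) == 3   # every connection still returns to the pool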
79 changes: 79 additions & 0 deletions tests/test_cluster.py
@@ -3463,6 +3463,85 @@ def test_pipeline_discard(self, r):
assert response[0]
assert r.get(f"{hashkey}:foo") == b"bar"

def test_connection_leak_on_non_timeout_error_during_connect(self, r):
"""
Test that connections are not leaked when a non-TimeoutError/ConnectionError
is raised during get_connection(). The bugfix ensures that if an error
occurs that isn't explicitly handled, we don't leak connections.
"""
# Ensure keys map to different nodes
assert r.keyslot("a") != r.keyslot("b")

orig_func = redis.cluster.get_connection
with patch("redis.cluster.get_connection") as get_connection:

def raise_custom_error(target_node, *args, **kwargs):
# Raise a RuntimeError (not ConnectionError or TimeoutError)
# on the second call (when getting second connection)
if get_connection.call_count == 2:
raise RuntimeError("Some unexpected error during connection")
else:
return orig_func(target_node, *args, **kwargs)

get_connection.side_effect = raise_custom_error

with pytest.raises(RuntimeError):
r.pipeline().get("a").get("b").execute()

# Verify that all connections were returned to the pool
# (not leaked) even though a non-standard error was raised
for cluster_node in r.nodes_manager.nodes_cache.values():
connection_pool = cluster_node.redis_connection.connection_pool
num_of_conns = len(connection_pool._available_connections)
assert num_of_conns == connection_pool._created_connections, (
f"Connection leaked: expected {connection_pool._created_connections} "
f"available, got {num_of_conns}"
)

def test_dirty_connection_not_reused(self, r):
"""
Test that dirty connections (with unread responses) are not reused.
A dirty connection is one where we've written commands but haven't
read all responses. If such a connection is returned to the pool,
the next caller will read responses from the previous request.
"""
# Ensure we're using multiple nodes to test the dirty connection scenario
assert r.keyslot("a") != r.keyslot("b")

# Mock the write method to raise an error after writing to only some nodes
orig_write = redis.cluster.NodeCommands.write

write_count = 0

def mock_write(self):
nonlocal write_count
write_count += 1
# Allow the first write to succeed
if write_count == 1:
return orig_write(self)
# Simulate a failure after the first write (leaving connection dirty)
else:
raise RuntimeError("Simulated write error")

with patch.object(redis.cluster.NodeCommands, "write", mock_write):
with pytest.raises(RuntimeError):
r.pipeline().get("a").get("b").execute()

# After the error, verify that no connections are in the available pool
# with dirty state (unread responses). If a connection is dirty, it should
# have been disconnected before being returned to the pool.
# We verify this by checking the connections can be reused successfully.
try:
# Try to execute a command on each connection to verify
# they're clean (not holding responses from previous requests)
result = r.ping()
assert result is True
except Exception as e:
pytest.fail(
f"Connection reuse after dirty state failed: {e}. "
f"This indicates a dirty connection was returned to the pool."
)


@pytest.mark.onlycluster
class TestReadOnlyPipeline: