Skip to content

Commit

Permalink
Merge pull request #40 from google/gbg/gatt-mtu
Browse files Browse the repository at this point in the history
maintain the att mtu only at the connection level
  • Loading branch information
barbibulle authored Oct 5, 2022
2 parents 0edd6b7 + 20dedbd commit 073757d
Show file tree
Hide file tree
Showing 3 changed files with 87 additions and 45 deletions.
13 changes: 7 additions & 6 deletions bumble/gatt_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,6 @@ def from_client(cls, client):
class Client:
def __init__(self, connection):
self.connection = connection
self.mtu = ATT_DEFAULT_MTU
self.mtu_exchange_done = False
self.request_semaphore = asyncio.Semaphore(1)
self.pending_request = None
Expand Down Expand Up @@ -217,7 +216,7 @@ async def request_mtu(self, mtu):

# We can only send one request per connection
if self.mtu_exchange_done:
return
return self.connection.att_mtu

# Send the request
self.mtu_exchange_done = True
Expand All @@ -230,8 +229,10 @@ async def request_mtu(self, mtu):
response
)

self.mtu = max(ATT_DEFAULT_MTU, response.server_rx_mtu)
return self.mtu
# Compute the final MTU
self.connection.att_mtu = min(mtu, response.server_rx_mtu)

return self.connection.att_mtu

def get_services_by_uuid(self, uuid):
return [service for service in self.services if service.uuid == uuid]
Expand Down Expand Up @@ -629,7 +630,7 @@ async def read_value(self, attribute, no_long_read=False):
# If the value is the max size for the MTU, try to read more unless the caller
# specifically asked not to do that
attribute_value = response.attribute_value
if not no_long_read and len(attribute_value) == self.mtu - 1:
if not no_long_read and len(attribute_value) == self.connection.att_mtu - 1:
logger.debug('using READ BLOB to get the rest of the value')
offset = len(attribute_value)
while True:
Expand All @@ -651,7 +652,7 @@ async def read_value(self, attribute, no_long_read=False):
part = response.part_attribute_value
attribute_value += part

if len(part) < self.mtu - 1:
if len(part) < self.connection.att_mtu - 1:
break

offset += len(part)
Expand Down
59 changes: 29 additions & 30 deletions bumble/gatt_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,12 @@
logger = logging.getLogger(__name__)


# -----------------------------------------------------------------------------
# Constants
# -----------------------------------------------------------------------------
GATT_SERVER_DEFAULT_MAX_MTU = 517


# -----------------------------------------------------------------------------
# GATT Server
# -----------------------------------------------------------------------------
Expand All @@ -49,9 +55,8 @@ def __init__(self, device):
self.device = device
self.attributes = [] # Attributes, ordered by increasing handle values
self.attributes_by_handle = {} # Map for fast attribute access by handle
self.max_mtu = 23 # FIXME: 517 # The max MTU we're willing to negotiate
self.max_mtu = GATT_SERVER_DEFAULT_MAX_MTU # The max MTU we're willing to negotiate
self.subscribers = {} # Map of subscriber states by connection handle and attribute handle
self.mtus = {} # Map of ATT MTU values by connection handle
self.indication_semaphores = defaultdict(lambda: asyncio.Semaphore(1))
self.pending_confirmations = defaultdict(lambda: None)

Expand Down Expand Up @@ -188,9 +193,8 @@ async def notify_subscriber(self, connection, attribute, value=None, force=False
value = attribute.read_value(connection) if value is None else attribute.encode_value(value)

# Truncate if needed
mtu = self.get_mtu(connection)
if len(value) > mtu - 3:
value = value[:mtu - 3]
if len(value) > connection.att_mtu - 3:
value = value[:connection.att_mtu - 3]

# Notify
notification = ATT_Handle_Value_Notification(
Expand Down Expand Up @@ -219,9 +223,8 @@ async def indicate_subscriber(self, connection, attribute, value=None, force=Fal
value = attribute.read_value(connection) if value is None else attribute.encode_value(value)

# Truncate if needed
mtu = self.get_mtu(connection)
if len(value) > mtu - 3:
value = value[:mtu - 3]
if len(value) > connection.att_mtu - 3:
value = value[:connection.att_mtu - 3]

# Indicate
indication = ATT_Handle_Value_Indication(
Expand Down Expand Up @@ -272,8 +275,6 @@ async def indicate_subscribers(self, attribute, value=None, force=False):
return await self.notify_or_indicate_subscribers(True, attribute, value, force)

def on_disconnection(self, connection):
if connection.handle in self.mtus:
del self.mtus[connection.handle]
if connection.handle in self.subscribers:
del self.subscribers[connection.handle]
if connection.handle in self.indication_semaphores:
Expand Down Expand Up @@ -314,9 +315,6 @@ def on_gatt_pdu(self, connection, att_pdu):
# Just ignore
logger.warning(f'{color("--- Ignoring GATT Request from [0x{connection.handle:04X}]:", "red")} {att_pdu}')

def get_mtu(self, connection):
return self.mtus.get(connection.handle, ATT_DEFAULT_MTU)

#######################################################
# ATT handlers
#######################################################
Expand All @@ -336,12 +334,16 @@ def on_att_exchange_mtu_request(self, connection, request):
'''
See Bluetooth spec Vol 3, Part F - 3.4.2.1 Exchange MTU Request
'''
mtu = max(ATT_DEFAULT_MTU, min(self.max_mtu, request.client_rx_mtu))
self.mtus[connection.handle] = mtu
self.send_response(connection, ATT_Exchange_MTU_Response(server_rx_mtu = mtu))
self.send_response(connection, ATT_Exchange_MTU_Response(server_rx_mtu = self.max_mtu))

# Compute the final MTU
if request.client_rx_mtu >= ATT_DEFAULT_MTU:
mtu = min(self.max_mtu, request.client_rx_mtu)

# Notify the device
self.device.on_connection_att_mtu_update(connection.handle, mtu)
# Notify the device
self.device.on_connection_att_mtu_update(connection.handle, mtu)
else:
logger.warning('invalid client_rx_mtu received, MTU not changed')

def on_att_find_information_request(self, connection, request):
'''
Expand All @@ -358,7 +360,7 @@ def on_att_find_information_request(self, connection, request):
return

# Build list of returned attributes
pdu_space_available = self.get_mtu(connection) - 2
pdu_space_available = connection.att_mtu - 2
attributes = []
uuid_size = 0
for attribute in (
Expand Down Expand Up @@ -409,7 +411,7 @@ def on_att_find_by_type_value_request(self, connection, request):
'''

# Build list of returned attributes
pdu_space_available = self.get_mtu(connection) - 2
pdu_space_available = connection.att_mtu - 2
attributes = []
for attribute in (
attribute for attribute in self.attributes if
Expand Down Expand Up @@ -457,8 +459,7 @@ def on_att_read_by_type_request(self, connection, request):
See Bluetooth spec Vol 3, Part F - 3.4.4.1 Read By Type Request
'''

mtu = self.get_mtu(connection)
pdu_space_available = mtu - 2
pdu_space_available = connection.att_mtu - 2
attributes = []
for attribute in (
attribute for attribute in self.attributes if
Expand All @@ -471,7 +472,7 @@ def on_att_read_by_type_request(self, connection, request):

# Check the attribute value size
attribute_value = attribute.read_value(connection)
max_attribute_size = min(mtu - 4, 253)
max_attribute_size = min(connection.att_mtu - 4, 253)
if len(attribute_value) > max_attribute_size:
# We need to truncate
attribute_value = attribute_value[:max_attribute_size]
Expand Down Expand Up @@ -511,7 +512,7 @@ def on_att_read_request(self, connection, request):
if attribute := self.get_attribute(request.attribute_handle):
# TODO: check permissions
value = attribute.read_value(connection)
value_size = min(self.get_mtu(connection) - 1, len(value))
value_size = min(connection.att_mtu - 1, len(value))
response = ATT_Read_Response(
attribute_value = value[:value_size]
)
Expand All @@ -530,22 +531,21 @@ def on_att_read_blob_request(self, connection, request):

if attribute := self.get_attribute(request.attribute_handle):
# TODO: check permissions
mtu = self.get_mtu(connection)
value = attribute.read_value(connection)
if request.value_offset > len(value):
response = ATT_Error_Response(
request_opcode_in_error = request.op_code,
attribute_handle_in_error = request.attribute_handle,
error_code = ATT_INVALID_OFFSET_ERROR
)
elif len(value) <= mtu - 1:
elif len(value) <= connection.att_mtu - 1:
response = ATT_Error_Response(
request_opcode_in_error = request.op_code,
attribute_handle_in_error = request.attribute_handle,
error_code = ATT_ATTRIBUTE_NOT_LONG_ERROR
)
else:
part_size = min(mtu - 1, len(value) - request.value_offset)
part_size = min(connection.att_mtu - 1, len(value) - request.value_offset)
response = ATT_Read_Blob_Response(
part_attribute_value = value[request.value_offset:request.value_offset + part_size]
)
Expand Down Expand Up @@ -574,8 +574,7 @@ def on_att_read_by_group_type_request(self, connection, request):
self.send_response(connection, response)
return

mtu = self.get_mtu(connection)
pdu_space_available = mtu - 2
pdu_space_available = connection.att_mtu - 2
attributes = []
for attribute in (
attribute for attribute in self.attributes if
Expand All @@ -586,7 +585,7 @@ def on_att_read_by_group_type_request(self, connection, request):
):
# Check the attribute value size
attribute_value = attribute.read_value(connection)
max_attribute_size = min(mtu - 6, 251)
max_attribute_size = min(connection.att_mtu - 6, 251)
if len(attribute_value) > max_attribute_size:
# We need to truncate
attribute_value = attribute_value[:max_attribute_size]
Expand Down
60 changes: 51 additions & 9 deletions tests/gatt_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ def encode_value(self, value):
def decode_value(self, value_bytes):
return value_bytes[0]

[client, server] = TwoDevices().devices
[client, server] = LinkedDevices().devices[:2]

characteristic = Characteristic(
'FDB159DB-036C-49E3-B3DB-6325AC750806',
Expand Down Expand Up @@ -306,27 +306,32 @@ def test_CharacteristicValue():


# -----------------------------------------------------------------------------
class TwoDevices:
class LinkedDevices:
def __init__(self):
self.connections = [None, None]
self.connections = [None, None, None]

self.link = LocalLink()
self.controllers = [
Controller('C1', link = self.link),
Controller('C2', link = self.link)
Controller('C2', link = self.link),
Controller('C3', link = self.link)
]
self.devices = [
Device(
address = 'F0:F1:F2:F3:F4:F5',
host = Host(self.controllers[0], AsyncPipeSink(self.controllers[0]))
),
Device(
address = 'F5:F4:F3:F2:F1:F0',
address = 'F1:F2:F3:F4:F5:F6',
host = Host(self.controllers[1], AsyncPipeSink(self.controllers[1]))
),
Device(
address = 'F2:F3:F4:F5:F6:F7',
host = Host(self.controllers[2], AsyncPipeSink(self.controllers[2]))
)
]

self.paired = [None, None]
self.paired = [None, None, None]


# -----------------------------------------------------------------------------
Expand All @@ -339,7 +344,7 @@ async def async_barrier():
# -----------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_read_write():
[client, server] = TwoDevices().devices
[client, server] = LinkedDevices().devices[:2]

characteristic1 = Characteristic(
'FDB159DB-036C-49E3-B3DB-6325AC750806',
Expand Down Expand Up @@ -416,7 +421,7 @@ def on_characteristic2_write(connection, value):
# -----------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_read_write2():
[client, server] = TwoDevices().devices
[client, server] = LinkedDevices().devices[:2]

v = bytes([0x11, 0x22, 0x33, 0x44])
characteristic1 = Characteristic(
Expand Down Expand Up @@ -466,7 +471,7 @@ async def test_read_write2():
# -----------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_subscribe_notify():
[client, server] = TwoDevices().devices
[client, server] = LinkedDevices().devices[:2]

characteristic1 = Characteristic(
'FDB159DB-036C-49E3-B3DB-6325AC750806',
Expand Down Expand Up @@ -631,12 +636,49 @@ def on_c3_update_2(value):
assert not c3._called_2


# -----------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_mtu_exchange():
[d1, d2, d3] = LinkedDevices().devices[:3]

d3.gatt_server.max_mtu = 100

d3_connections = []
@d3.on('connection')
def on_d3_connection(connection):
d3_connections.append(connection)

await d1.power_on()
await d2.power_on()
await d3.power_on()

d1_connection = await d1.connect(d3.random_address)
assert len(d3_connections) == 1
assert d3_connections[0] is not None

d2_connection = await d2.connect(d3.random_address)
assert len(d3_connections) == 2
assert d3_connections[1] is not None

d1_peer = Peer(d1_connection)
d2_peer = Peer(d2_connection)

d1_client_mtu = await d1_peer.request_mtu(220)
assert d1_client_mtu == 100
assert d1_connection.att_mtu == 100

d2_client_mtu = await d2_peer.request_mtu(50)
assert d2_client_mtu == 50
assert d2_connection.att_mtu == 50


# -----------------------------------------------------------------------------
async def async_main():
await test_read_write()
await test_read_write2()
await test_subscribe_notify()
await test_characteristic_encoding()
await test_mtu_exchange()


# -----------------------------------------------------------------------------
Expand Down

0 comments on commit 073757d

Please sign in to comment.