Skip to content

Commit

Permalink
Merge pull request #266 from miek/clear_endpoint_halt
Browse files Browse the repository at this point in the history
Implement ClearFeature(ENDPOINT_HALT)
  • Loading branch information
miek authored Jul 5, 2024
2 parents 09eb1f6 + 4988b8d commit 13ad028
Show file tree
Hide file tree
Showing 8 changed files with 303 additions and 10 deletions.
223 changes: 223 additions & 0 deletions applets/clear_endpoint_halt_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,223 @@
#!/usr/bin/env python3
#
# This file is part of LUNA.
#
# Copyright (c) 2024 Great Scott Gadgets <info@greatscottgadgets.com>
# SPDX-License-Identifier: BSD-3-Clause

import logging
import os
import time
import usb1

from amaranth import Elaboratable, Module, Signal

from luna import top_level_cli, configure_default_logging
from luna.usb2 import USBDevice, USBStreamInEndpoint, USBStreamOutEndpoint
from luna.gateware.stream.generator import StreamSerializer
from luna.gateware.usb.request.control import ControlRequestHandler
from luna.gateware.usb.stream import USBInStreamInterface

from usb_protocol.types import USBRequestRecipient, USBRequestType
from usb_protocol.emitters import DeviceDescriptorCollection

# use pid.codes Test PID
VID = 0x1209
PID = 0x0001

BULK_ENDPOINT_NUMBER = 1
MAX_BULK_PACKET_SIZE = 512

COUNTER_MAX = 251
GET_OUT_COUNTER_VALID = 0

out_counter_valid = Signal(reset=1)

class VendorRequestHandler(ControlRequestHandler):

REQUEST_SET_LEDS = 0

def elaborate(self, platform):
m = Module()

interface = self.interface
setup = self.interface.setup

# Transmitter for small-constant-response requests
m.submodules.transmitter = transmitter = \
StreamSerializer(data_length=1, domain="usb", stream_type=USBInStreamInterface, max_length_width=1)
#
# Vendor request handlers.
with m.FSM(domain="usb"):
with m.State('IDLE'):
vendor = setup.type == USBRequestType.VENDOR
with m.If(
setup.received & \
(setup.type == USBRequestType.VENDOR) & \
(setup.recipient == USBRequestRecipient.INTERFACE) & \
(setup.index == 0)):
with m.Switch(setup.request):
with m.Case(GET_OUT_COUNTER_VALID):
m.d.comb += interface.claim.eq(1)
m.next = 'GET_OUT_COUNTER_VALID'
pass

with m.State('GET_OUT_COUNTER_VALID'):
m.d.comb += interface.claim.eq(1)
self.handle_simple_data_request(m, transmitter, out_counter_valid, length=1)

return m


class ClearHaltTestDevice(Elaboratable):


def create_descriptors(self):

descriptors = DeviceDescriptorCollection()

with descriptors.DeviceDescriptor() as d:
d.idVendor = VID
d.idProduct = PID

d.iManufacturer = "LUNA"
d.iProduct = "Clear Endpoint Halt Test"
d.iSerialNumber = "no serial"

d.bNumConfigurations = 1


with descriptors.ConfigurationDescriptor() as c:

with c.InterfaceDescriptor() as i:
i.bInterfaceNumber = 0

with i.EndpointDescriptor() as e:
e.bEndpointAddress = 0x80 | BULK_ENDPOINT_NUMBER
e.wMaxPacketSize = MAX_BULK_PACKET_SIZE

with i.EndpointDescriptor() as e:
e.bEndpointAddress = BULK_ENDPOINT_NUMBER
e.wMaxPacketSize = MAX_BULK_PACKET_SIZE


return descriptors


def elaborate(self, platform):
m = Module()

m.submodules.car = platform.clock_domain_generator()

ulpi = platform.request(platform.default_usb_connection)
m.submodules.usb = usb = USBDevice(bus=ulpi)

descriptors = self.create_descriptors()
control_ep = usb.add_standard_control_endpoint(descriptors)

control_ep.add_request_handler(VendorRequestHandler())

stream_in_ep = USBStreamInEndpoint(
endpoint_number=BULK_ENDPOINT_NUMBER,
max_packet_size=MAX_BULK_PACKET_SIZE
)
usb.add_endpoint(stream_in_ep)

stream_out_ep = USBStreamOutEndpoint(
endpoint_number=BULK_ENDPOINT_NUMBER,
max_packet_size=MAX_BULK_PACKET_SIZE
)
usb.add_endpoint(stream_out_ep)

# Generate a counter on the IN endpoint.
in_counter = Signal(8)
with m.If(stream_in_ep.stream.ready):
m.d.usb += in_counter.eq(in_counter + 1)
with m.If(in_counter == COUNTER_MAX):
m.d.usb += in_counter.eq(0)

# Expect a counter on the OUT endpoint, and verify that it is contiguous.
prev_out_counter = Signal(8, reset=COUNTER_MAX)
with m.If(stream_out_ep.stream.valid):
out_counter = stream_out_ep.stream.payload
counter_increase = out_counter == (prev_out_counter + 1)
counter_wrap = (out_counter == 0) & (prev_out_counter == COUNTER_MAX)
with m.If(~counter_increase & ~counter_wrap):
m.d.usb += out_counter_valid.eq(0)

m.d.usb += prev_out_counter.eq(out_counter)

m.d.comb += [
stream_in_ep.stream.valid .eq(1),
stream_in_ep.stream.payload .eq(in_counter),

stream_out_ep.stream.ready .eq(1),
]

# Connect our device as a high speed device by default.
m.d.comb += [
usb.connect .eq(1),
usb.full_speed_only .eq(1 if os.getenv('LUNA_FULL_ONLY') else 0),
]

return m

def test_clear_halt():
with usb1.USBContext() as context:
device = context.openByVendorIDAndProductID(VID, PID)

# Read the first packet which should have a DATA0 PID, next we expect DATA1.
packet = device.bulkRead(BULK_ENDPOINT_NUMBER, MAX_BULK_PACKET_SIZE)
# Send clear halt, this resets both sides to DATA0.
device.clearHalt(usb1.ENDPOINT_IN | BULK_ENDPOINT_NUMBER)
# Read another packet. If the PID doesn't match what we epxect,
# then the host will assume it was a retransmission of the last one and drop it.
packet += device.bulkRead(BULK_ENDPOINT_NUMBER, MAX_BULK_PACKET_SIZE)

# Check that the counter is contiguous across all received data, making sure we didn't drop a packet.
for i in range(1, len(packet)):
if packet[i] == packet[i-1] + 1:
pass
elif packet[i] == 0 and packet[i-1] == COUNTER_MAX:
pass
else:
print(f"IN test fail {i} {packet[i]} {packet[i-1]}")
return

print("IN OK")

# Generate three packets worth of counter data, the gateware will verify that it is contiguous.
data = bytes(i % (COUNTER_MAX+1) for i in range(MAX_BULK_PACKET_SIZE*3))
# Send DATA0, device should expect DATA1 next.
device.bulkWrite(BULK_ENDPOINT_NUMBER, data[:MAX_BULK_PACKET_SIZE])
# Reset both sides to DATA0.
device.clearHalt(usb1.ENDPOINT_OUT | BULK_ENDPOINT_NUMBER)
# Send two packets. If the first packet doesn't match,
# it'll be dropped and another is required to let the gateware check the counter.
device.bulkWrite(BULK_ENDPOINT_NUMBER, data[MAX_BULK_PACKET_SIZE:])

# Read back the out_counter_valid register to check for success.
request_type = usb1.REQUEST_TYPE_VENDOR | usb1.RECIPIENT_INTERFACE | usb1.ENDPOINT_IN
if device.controlRead(request_type, GET_OUT_COUNTER_VALID, 0, 0, 1)[0] == 1:
print("OUT OK")
else:
print("OUT FAIL")


if __name__ == "__main__":
configure_default_logging()

# If our environment is suggesting we rerun tests without rebuilding, do so.
if os.getenv('LUNA_RERUN_TEST'):
logging.info("Running speed test without rebuilding...")

# Otherwise, rebuild.
else:
device = top_level_cli(ClearHaltTestDevice)

# Give the device a moment to connect.
if device is not None:
logging.info("Giving the device time to connect...")
time.sleep(5)

test_clear_halt()
28 changes: 27 additions & 1 deletion luna/gateware/usb/request/standard.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

from amaranth import *
from amaranth.hdl.ast import Value, Const
from usb_protocol.types import USBStandardRequests, USBRequestType
from usb_protocol.types import USBStandardFeatures, USBStandardRequests, USBRequestRecipient, USBRequestType
from usb_protocol.emitters import DeviceDescriptorCollection

from ..usb2.request import RequestHandlerInterface, USBRequestHandler
Expand Down Expand Up @@ -139,6 +139,8 @@ def elaborate(self, platform):

with m.Case(USBStandardRequests.GET_STATUS):
m.next = 'GET_STATUS'
with m.Case(USBStandardRequests.CLEAR_FEATURE):
m.next = 'CLEAR_FEATURE'
with m.Case(USBStandardRequests.SET_ADDRESS):
m.next = 'SET_ADDRESS'
with m.Case(USBStandardRequests.SET_CONFIGURATION):
Expand All @@ -158,6 +160,30 @@ def elaborate(self, platform):
# TODO: copy the remote wakeup and bus-powered attributes from bmAttributes of the relevant descriptor?
self.handle_simple_data_request(m, transmitter, 0, length=2)

with m.State('CLEAR_FEATURE'):
# Provide an response to the STATUS stage.
with m.If(interface.status_requested):

# If our stall condition is met, stall; otherwise, send a ZLP [USB 8.5.3].
# For now, we only implement clearing ENDPOINT_HALT.
stall_condition = \
(setup.recipient != USBRequestRecipient.ENDPOINT) | \
(setup.value != USBStandardFeatures.ENDPOINT_HALT)
with m.If(stall_condition):
m.d.comb += handshake_generator.stall.eq(1)
with m.Else():
m.d.comb += self.send_zlp()

# Accept the relevant value after the packet is ACK'd...
with m.If(interface.handshakes_in.ack):
m.d.comb += [
interface.clear_endpoint_halt.enable .eq(1),
interface.clear_endpoint_halt.direction.eq(setup.index[7]),
interface.clear_endpoint_halt.number .eq(setup.index[0:4]),
]

# ... and then return to idle.
m.next = 'IDLE'

# SET_ADDRESS -- The host is trying to assign us an address.
with m.State('SET_ADDRESS'):
Expand Down
2 changes: 2 additions & 0 deletions luna/gateware/usb/usb2/control.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,8 @@ def elaborate(self, platform):
interface.address_changed .eq(request_handler.address_changed),
interface.new_address .eq(request_handler.new_address),

interface.clear_endpoint_halt_out .eq(request_handler.clear_endpoint_halt),

request_handler.active_config .eq(interface.active_config),
interface.config_changed .eq(request_handler.config_changed),
interface.new_config .eq(request_handler.new_config),
Expand Down
10 changes: 10 additions & 0 deletions luna/gateware/usb/usb2/endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@

from .packet import DataCRCInterface, InterpacketTimerInterface, TokenDetectorInterface
from .packet import HandshakeExchangeInterface
from .request import ClearEndpointHaltInterface
from ..stream import USBInStreamInterface, USBOutStreamInterface
from ...utils.bus import OneHotMultiplexer

Expand Down Expand Up @@ -90,6 +91,9 @@ def __init__(self):
self.config_changed = Signal()
self.new_config = Signal(8)

self.clear_endpoint_halt_out = Signal(ClearEndpointHaltInterface)
self.clear_endpoint_halt_in = Signal(ClearEndpointHaltInterface)

self.rx = USBOutStreamInterface()
self.rx_complete = Signal()
self.rx_ready_for_response = Signal()
Expand Down Expand Up @@ -213,6 +217,8 @@ def elaborate(self, platform):
shared.handshakes_in .connect(interface.handshakes_in),
shared.tokenizer .connect(interface.tokenizer),

interface.clear_endpoint_halt_in .eq(shared.clear_endpoint_halt_out),

# Rx interface.
shared.rx .connect(interface.rx),
interface.rx_complete .eq(shared.rx_complete),
Expand Down Expand Up @@ -259,6 +265,10 @@ def elaborate(self, platform):
# ... and our timer start signals.
self.or_join_interface_signals(m, lambda interface : interface.timer.start)

self.or_join_interface_signals(m, lambda interface : interface.clear_endpoint_halt_out.enable)
self.or_join_interface_signals(m, lambda interface : interface.clear_endpoint_halt_out.direction)
self.or_join_interface_signals(m, lambda interface : interface.clear_endpoint_halt_out.number)

# Finally, connect up our transmit PID select.
conditional = m.If

Expand Down
18 changes: 18 additions & 0 deletions luna/gateware/usb/usb2/endpoints/stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,11 @@ def elaborate(self, platform):
# Create our transfer manager, which will be used to sequence packet transfers for our stream.
m.submodules.tx_manager = tx_manager = USBInTransferManager(self._max_packet_size)

# Check there has been a ClearFeature(ENDPOINT_HALT) request address to this endpoint.
clear_endpoint_halt = \
interface.clear_endpoint_halt_in.enable & \
interface.clear_endpoint_halt_in.direction & \
(interface.clear_endpoint_halt_in.number == self._endpoint_number)
m.d.comb += [

# Always generate ZLPs; in order to pass along when stream packets terminate.
Expand All @@ -94,6 +99,9 @@ def elaborate(self, platform):
tx_manager.flush .eq(self.flush),
tx_manager.discard .eq(self.discard),

# ... and data-toggle reset on clear endpoint halt...
tx_manager.reset_sequence .eq(clear_endpoint_halt),

# ... and our output stream...
interface.tx .stream_eq(tx_manager.packet_stream),
interface.tx_pid_toggle .eq(tx_manager.data_pid),
Expand Down Expand Up @@ -414,6 +422,16 @@ def elaborate(self, platform):
with m.If(data_response_requested & data_accepted):
m.d.usb += expected_data_toggle.eq(~expected_data_toggle)

# If there has been a ClearFeature(ENDPOINT_HALT) request address to this endpoint...
clear_endpoint_halt = \
self.interface.clear_endpoint_halt_in.enable & \
~self.interface.clear_endpoint_halt_in.direction & \
(self.interface.clear_endpoint_halt_in.number == self._endpoint_number)

with m.If(clear_endpoint_halt):
# ... reset the expected data toggle.
m.d.usb += expected_data_toggle.eq(0)


return m

Loading

0 comments on commit 13ad028

Please sign in to comment.