diff --git a/img.png b/img.png new file mode 100644 index 0000000..2776fc9 Binary files /dev/null and b/img.png differ diff --git a/requirements.txt b/requirements.txt index 27eb11f..8862eb6 100644 --- a/requirements.txt +++ b/requirements.txt @@ -6,3 +6,7 @@ python-can==4.3.1 pywin32==306 typing_extensions==4.10.0 wrapt==1.16.0 + +typing~=3.7.4.3 +matplotlib~=3.8.3 +scipy==1.12.0 \ No newline at end of file diff --git a/src/lane_assist/image_manipulation/image_stitch/images/corner/center.jpg b/resources/images/corner/center.jpg similarity index 100% rename from src/lane_assist/image_manipulation/image_stitch/images/corner/center.jpg rename to resources/images/corner/center.jpg diff --git a/src/lane_assist/image_manipulation/image_stitch/images/corner/left.jpg b/resources/images/corner/left.jpg similarity index 100% rename from src/lane_assist/image_manipulation/image_stitch/images/corner/left.jpg rename to resources/images/corner/left.jpg diff --git a/src/lane_assist/image_manipulation/image_stitch/images/corner/right.jpg b/resources/images/corner/right.jpg similarity index 100% rename from src/lane_assist/image_manipulation/image_stitch/images/corner/right.jpg rename to resources/images/corner/right.jpg diff --git a/src/lane_assist/image_manipulation/image_stitch/images/crossing/center.jpg b/resources/images/crossing/center.jpg similarity index 100% rename from src/lane_assist/image_manipulation/image_stitch/images/crossing/center.jpg rename to resources/images/crossing/center.jpg diff --git a/src/lane_assist/image_manipulation/image_stitch/images/crossing/left.jpg b/resources/images/crossing/left.jpg similarity index 100% rename from src/lane_assist/image_manipulation/image_stitch/images/crossing/left.jpg rename to resources/images/crossing/left.jpg diff --git a/src/lane_assist/image_manipulation/image_stitch/images/crossing/right.jpg b/resources/images/crossing/right.jpg similarity index 100% rename from src/lane_assist/image_manipulation/image_stitch/images/crossing/right.jpg rename to resources/images/crossing/right.jpg diff --git a/resources/images/stopline/center.jpg b/resources/images/stopline/center.jpg new file mode 100644 index 0000000..bd52da2 Binary files /dev/null and b/resources/images/stopline/center.jpg differ diff --git a/src/lane_assist/image_manipulation/image_stitch/images/straight/left.jpg b/resources/images/stopline/left.jpg similarity index 100% rename from src/lane_assist/image_manipulation/image_stitch/images/straight/left.jpg rename to resources/images/stopline/left.jpg diff --git a/src/lane_assist/image_manipulation/image_stitch/images/straight/right.jpg b/resources/images/stopline/right.jpg similarity index 100% rename from src/lane_assist/image_manipulation/image_stitch/images/straight/right.jpg rename to resources/images/stopline/right.jpg diff --git a/src/lane_assist/image_manipulation/image_stitch/images/straight/center.jpg b/resources/images/straight/center.jpg similarity index 100% rename from src/lane_assist/image_manipulation/image_stitch/images/straight/center.jpg rename to resources/images/straight/center.jpg diff --git a/resources/images/straight/left.jpg b/resources/images/straight/left.jpg new file mode 100644 index 0000000..3f201a9 Binary files /dev/null and b/resources/images/straight/left.jpg differ diff --git a/resources/images/straight/right.jpg b/resources/images/straight/right.jpg new file mode 100644 index 0000000..cc87f13 Binary files /dev/null and b/resources/images/straight/right.jpg differ diff --git 
a/resources/stitched_images/corner.jpg b/resources/stitched_images/corner.jpg new file mode 100644 index 0000000..aae7c7d Binary files /dev/null and b/resources/stitched_images/corner.jpg differ diff --git a/resources/stitched_images/crossing.jpg b/resources/stitched_images/crossing.jpg new file mode 100644 index 0000000..2ec76c8 Binary files /dev/null and b/resources/stitched_images/crossing.jpg differ diff --git a/resources/stitched_images/stopline.jpg b/resources/stitched_images/stopline.jpg new file mode 100644 index 0000000..57f2562 Binary files /dev/null and b/resources/stitched_images/stopline.jpg differ diff --git a/resources/stitched_images/straight.jpg b/resources/stitched_images/straight.jpg new file mode 100644 index 0000000..2c45472 Binary files /dev/null and b/resources/stitched_images/straight.jpg differ diff --git a/resources/unstitched_images/corner/center.jpg b/resources/unstitched_images/corner/center.jpg new file mode 100644 index 0000000..e69de29 diff --git a/resources/unstitched_images/corner/left.jpg b/resources/unstitched_images/corner/left.jpg new file mode 100644 index 0000000..e69de29 diff --git a/resources/unstitched_images/corner/right.jpg b/resources/unstitched_images/corner/right.jpg new file mode 100644 index 0000000..e69de29 diff --git a/resources/unstitched_images/crossing/center.jpg b/resources/unstitched_images/crossing/center.jpg new file mode 100644 index 0000000..e69de29 diff --git a/resources/unstitched_images/crossing/left.jpg b/resources/unstitched_images/crossing/left.jpg new file mode 100644 index 0000000..e69de29 diff --git a/resources/unstitched_images/crossing/right.jpg b/resources/unstitched_images/crossing/right.jpg new file mode 100644 index 0000000..e69de29 diff --git a/resources/unstitched_images/stopline/center.jpg b/resources/unstitched_images/stopline/center.jpg new file mode 100644 index 0000000..e69de29 diff --git a/resources/unstitched_images/stopline/left.jpg b/resources/unstitched_images/stopline/left.jpg new file mode 100644 index 0000000..e69de29 diff --git a/resources/unstitched_images/stopline/right.jpg b/resources/unstitched_images/stopline/right.jpg new file mode 100644 index 0000000..e69de29 diff --git a/resources/unstitched_images/straight/center.jpg b/resources/unstitched_images/straight/center.jpg new file mode 100644 index 0000000..e69de29 diff --git a/resources/unstitched_images/straight/left.jpg b/resources/unstitched_images/straight/left.jpg new file mode 100644 index 0000000..e69de29 diff --git a/resources/unstitched_images/straight/right.jpg b/resources/unstitched_images/straight/right.jpg new file mode 100644 index 0000000..e69de29 diff --git a/src/lane_assist/image_manipulation/image_stitch/images/result.jpg b/result.jpg similarity index 100% rename from src/lane_assist/image_manipulation/image_stitch/images/result.jpg rename to result.jpg diff --git a/src/common/config.py b/src/common/config.py index 42c3f9f..6351ba5 100644 --- a/src/common/config.py +++ b/src/common/config.py @@ -1,11 +1,6 @@ from common.constants import SpeedMode speed_mode: SpeedMode = SpeedMode.SLOW -speed_mode_to_speed = { - SpeedMode.SLOW: 25, - SpeedMode.MEDIUM: 50, - SpeedMode.FAST: 75, - SpeedMode.VERY_FAST: 100 -} +speed_mode_to_speed = {SpeedMode.SLOW: 25, SpeedMode.MEDIUM: 50, SpeedMode.FAST: 75, SpeedMode.VERY_FAST: 100} speed = speed_mode_to_speed[speed_mode] diff --git a/src/common/constants.py b/src/common/constants.py index 6b2f6d4..69d90bd 100644 --- a/src/common/constants.py +++ b/src/common/constants.py @@ -4,18 +4,20 @@ 
class CANControlIdentifier(IntEnum): """The identifiers for the CAN messages sent to the go-kart.""" - BRAKE = 0x110 + BRAKE = 0x110 STEERING = 0x220 THROTTLE = 0x330 + class CANFeedbackIdentifier(IntEnum): """The identifiers for the CAN messages received from the go-kart.""" - BRAKE = 0x710 - SPEED_SENSOR = 0x440 - STEERING_ECU = 0x720 + BRAKE = 0x710 + SPEED_SENSOR = 0x440 + STEERING_ECU = 0x720 STEERING_SENSOR = 0x1E5 - THROTTLE = 0x730 + THROTTLE = 0x730 + class CameraResolution: """The camera resolutions that the Logitech StreamCam supports.""" @@ -24,6 +26,7 @@ class CameraResolution: HD = (1280, 720) VGA = (848, 480) + class CameraFramerate: """The camera framerates that the Logitech StreamCam supports.""" @@ -34,17 +37,19 @@ class CameraFramerate: FPS_15 = 15 FPS_10 = 10 + class Gear(IntEnum): """The gear of the go-kart.""" NEUTRAL = 0 - DRIVE = 1 + DRIVE = 1 REVERSE = 2 + class SpeedMode(IntEnum): """The mode of the speed controller.""" - SLOW = 0 - MEDIUM = 1 - FAST = 2 + SLOW = 0 + MEDIUM = 1 + FAST = 2 VERY_FAST = 3 diff --git a/src/kart_control/can_controller.py b/src/kart_control/can_controller.py index b2cd100..f7fa53c 100644 --- a/src/kart_control/can_controller.py +++ b/src/kart_control/can_controller.py @@ -28,9 +28,7 @@ def __init__(self, can_bus: can.Bus) -> None: self.__listeners = {} self.__thread = threading.Thread(target=self.__listen, daemon=True) - def add_listener( - self, message_id: CANFeedbackIdentifier, listener: callable - ) -> None: + def add_listener(self, message_id: CANFeedbackIdentifier, listener: callable) -> None: """Add a listener for a message. :param message_id: The identifier of the message. diff --git a/src/kart_control/new_controller/controller.py b/src/kart_control/new_controller/controller.py index c71d468..ba27db2 100644 --- a/src/kart_control/new_controller/controller.py +++ b/src/kart_control/new_controller/controller.py @@ -105,37 +105,35 @@ def __start(self) -> None: elif event.ev_type == "Absolute": self._handle_axis_event(event) - def vibrate(self, duration: int=1000) -> None: + def vibrate(self, duration: int = 1000) -> None: """Vibrate the controller. - + Parameters ---------- :param duration int: the duration to vibrate in milliseconds. default = 1000 - + """ try: self.gamepad.set_vibration(1, 1, duration) except Exception: # noqa: BLE001 - print("Failed to vibrate") # noqa: T201 + print("Failed to vibrate") # noqa: T201 - def add_listener(self, event_type: EventType, button_or_axis: ControllerButton | ControllerAxis, - callback: callable) -> None: + def add_listener( + self, event_type: EventType, button_or_axis: ControllerButton | ControllerAxis, callback: callable + ) -> None: """Add a listener to be executed on that event. - + Parameters ---------- :param event_type EventType: the type of event to call it on. :param button_or_axis ControllerButton | ControllerAxis: the axis or button that the event needs to be for. 
:param callback callable: the callback to call when the event occurs - + """ if event_type not in EventType: raise ValueError(f"Invalid event type: {event_type}") - if ( - button_or_axis not in ControllerButton - and button_or_axis not in ControllerAxis - ): + if button_or_axis not in ControllerButton and button_or_axis not in ControllerAxis: raise ValueError(f"Invalid button or axis: {button_or_axis}") key = (event_type, button_or_axis) if key not in self._listeners: @@ -173,7 +171,9 @@ def _handle_axis_event(self, event: EventType) -> None: self._axes[axis] = value self._check_events(EventType.AXIS_CHANGED, axis, value) - def _check_events(self, event_type: EventType, data: ControllerButton | ControllerAxis, value: float=None) -> None: + def _check_events( + self, event_type: EventType, data: ControllerButton | ControllerAxis, value: float = None + ) -> None: key = (event_type, data) if key not in self._listeners: return @@ -186,7 +186,7 @@ def _check_events(self, event_type: EventType, data: ControllerButton | Controll for callback in self._listeners[key]: callback(event_type, data) - def _start_long_press_timer(self, button: ControllerButton, timeout: float=1.5) -> None: + def _start_long_press_timer(self, button: ControllerButton, timeout: float = 1.5) -> None: def timer_callback() -> None: if self._buttons[button]: self._check_events(EventType.LONG_PRESS, button) diff --git a/src/lane_assist/image_manipulation/__init__.py b/src/lane_assist/image_manipulation/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/lane_assist/image_manipulation/image_stitch/__init__.py b/src/lane_assist/image_manipulation/image_stitch/__init__.py index 6bf1aa9..aa3ada5 100644 --- a/src/lane_assist/image_manipulation/image_stitch/__init__.py +++ b/src/lane_assist/image_manipulation/image_stitch/__init__.py @@ -10,8 +10,9 @@ MAX_WIDTH = 1280 MASK_OFFSET = np.array([[1, 1], [1, -1], [-1, -1], [-1, 1]]) -PTS_ORIGINAL = np.float32([[MIN_WIDTH, MIN_HEIGHT], [MIN_WIDTH, MAX_HEIGHT], - [MAX_WIDTH, MAX_HEIGHT], [MAX_WIDTH, MIN_HEIGHT]]) +PTS_ORIGINAL = np.float32( + [[MIN_WIDTH, MIN_HEIGHT], [MIN_WIDTH, MAX_HEIGHT], [MAX_WIDTH, MAX_HEIGHT], [MAX_WIDTH, MIN_HEIGHT]] +) RATIOS_LEFT = np.float32([[0, 1.0055555], [0.36197916, 2.6185186], [1.6765625, 0.7537037], [1.5010417, 0]]) RATIOS_RIGHT = np.float32([[0.18125, 0], [0, 0.74907407], [1.28125, 2.55833333], [1.66770833, 0.99722222]]) @@ -47,12 +48,14 @@ def relative_to_absolute(x: float, y: float, width: float, height: float) -> tup return x * width, y * height -def get_ltbr(x: int, y: int, width: int, height: int) -> tuple[int, int, int, int]: +def get_ltbr(x: float, y: float, width: int, height: int) -> tuple[int, int, int, int]: """Get the left, top, right, and bottom coordinates of the image based on xywh.""" - return (int(round(x - (width / 2))), - int(round(y - (height / 2))), - int(round(x + (width / 2))), - int(round(y + (height / 2)))) + return ( + int(round(x - (width / 2))), + int(round(y - (height / 2))), + int(round(x + (width / 2))), + int(round(y + (height / 2))), + ) def warp_image(image: np.ndarray, matrix: np.ndarray, width: int, height: int) -> np.ndarray: @@ -60,7 +63,7 @@ def warp_image(image: np.ndarray, matrix: np.ndarray, width: int, height: int) - return cv2.warpPerspective(image, matrix, (width, height), flags=cv2.INTER_LINEAR) -def merge_image(base:np.ndarray, overlay:np.ndarray, x1:int, y1:int, x2:int, y2:int) -> np.ndarray: +def merge_image(base: np.ndarray, overlay: np.ndarray, x1: int, y1: int, x2: int, y2: int) -> 
np.ndarray: """Merge two images.""" base[y1:y2, x1:x2] = overlay return base @@ -97,11 +100,22 @@ def stitch_images(left: np.ndarray, center: np.ndarray, right: np.ndarray) -> np return merge_image(result, center, cx1, cy1, cx2, cy2) +def adjust_gamma(image: np.ndarray, gamma: float = 1.0) -> np.ndarray: + """Adjust the gamma of the image.""" + inv_gamma = 1.0 / gamma + table = np.array([((i / 255.0) ** inv_gamma) * 255 for i in np.arange(0, 256)]).astype("uint8") + + return cv2.LUT(image, table) + + if __name__ == "__main__": # Load images - center_img = cv2.imread("images/crossing/center.jpg") - left_img = cv2.imread("images/crossing/left.jpg") - right_img = cv2.imread("images/crossing/right.jpg") + center_img = cv2.imread("../../../../resources/images/stopline/center.jpg") + left_img = cv2.imread("../../../../resources/images/stopline/left.jpg") + right_img = cv2.imread("../../../../resources/images/stopline/right.jpg") + + left_img = adjust_gamma(left_img, 0.62) + right_img = adjust_gamma(right_img, 0.62) # Write result result_img = stitch_images(left_img, center_img, right_img) diff --git a/src/lane_assist/image_manipulation/top_down_transfrom/__init__.py b/src/lane_assist/image_manipulation/top_down_transfrom/__init__.py index e69de29..fb937fe 100644 --- a/src/lane_assist/image_manipulation/top_down_transfrom/__init__.py +++ b/src/lane_assist/image_manipulation/top_down_transfrom/__init__.py @@ -0,0 +1,42 @@ +import cv2 +import numpy as np + +from src.utils.image import cut_image + + +def topdown(image: np.ndarray) -> np.ndarray: + """Transform stitched image to top-down view.""" + if image is None: + raise ValueError("Error: Unable to load image") + + pts = np.array([[55, 900], [1841, 253], [2067, 253], [3861, 900]], dtype=np.float32) + ipm_pts = np.array([[780, 450], [800, 1100], [600, 1100], [620, 450]], dtype=np.float32) + ipm_matrix = cv2.getPerspectiveTransform(pts, ipm_pts) + ipm = cv2.warpPerspective(image, ipm_matrix, (image.shape[1], image.shape[0]), flags=cv2.INTER_LINEAR) + ipm = cut_image(ipm, 300, 450, 800, 900) + ipm = cv2.rotate(ipm, cv2.ROTATE_90_CLOCKWISE) + return cv2.rotate(ipm, cv2.ROTATE_90_CLOCKWISE) + + +if __name__ == "__main__": + # load stitched image + image = cv2.imread("result.jpg") + + # time the topdown function and print the Iterations per second + import time + + start = time.time() + for _ in range(500): + topdown(image) + end = time.time() + fps = 500 / (end - start) + + # convert to grayscale + gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY) + + # time again + start = time.time() + for _ in range(500): + topdown(gray) + end = time.time() + fps_gray = 500 / (end - start) diff --git a/src/lane_assist/line_detection/__init__.py b/src/lane_assist/line_detection/__init__.py index e69de29..65bd2ae 100644 --- a/src/lane_assist/line_detection/__init__.py +++ b/src/lane_assist/line_detection/__init__.py @@ -0,0 +1,62 @@ +import cv2 +import numpy as np + +from lane_assist.image_manipulation.top_down_transfrom import topdown +from lane_assist.line_detection.line import Line, LineType +from lane_assist.line_detection.window import Window +from lane_assist.line_detection.window_search import window_search +from src.utils.image import list_images + + +def get_lines(image: np.ndarray) -> list[Line]: + """Get the lines in the image. + + This function will take an image and return the lines in the image. 
+ the image should be stitched and not top down + """ + gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY) + top_down = topdown(gray) + white = cv2.inRange(top_down, 200, 255) + blurred = cv2.GaussianBlur(white, (15, 15), 0) + return window_search(blurred, 50) + + +def main() -> None: + """Example usage of the line detection. + + this function is used for testing the line detection algorithm. + this is done by taking a few images and drawing the lines on the topdown image. + """ + test_images = [ + cv2.imread("../../../tests/line_detection/images/corner.jpg"), + cv2.imread("../../../tests/line_detection/images/straight.jpg"), + cv2.imread("../../../tests/line_detection/images/crossing.jpg"), + cv2.imread("../../../tests/line_detection/images/stopline.jpg"), + ] + + + colours = { + LineType.SOLID: (255, 0, 0), # red + LineType.DASHED: (0, 255, 0), # green + LineType.STOP: (0, 0, 255), # blue + } + + final_images = [] + # convert the images, so we can find the lines + for img in test_images: + lines = get_lines(img) + td_img = topdown(img) # convert to topdown to draw the lines + + # draw the points on the topdown image + for line in lines: + for point in line.points: + colour = colours[line.line_type] + cv2.circle(td_img, (point[0], point[1]), 10, colour, -1) + + final_images.append(td_img) + + list_images(final_images, rows=4, cols=2) + + +if __name__ == "__main__": + main() diff --git a/src/lane_assist/line_detection/line.py b/src/lane_assist/line_detection/line.py new file mode 100644 index 0000000..af71b74 --- /dev/null +++ b/src/lane_assist/line_detection/line.py @@ -0,0 +1,69 @@ +from enum import IntEnum + +import numpy as np + + +class LineType(IntEnum): + """The type of line.""" + + SOLID = 1 + DASHED = 2 + STOP = 3 + + +class Line: + """A line in the image. + + This class is used to represent a line in the image. + It contains the points of the line and the type of line. + """ + + points: np.ndarray + line_type: LineType + + def __init__( + self, points: np.ndarray, window_height: int = None, line_type: LineType = None, gaps_allowed: int = 2 + ) -> None: + """Initialize the line. + + Parameters + ---------- + :param points: the points of the line + :param window_height: the height of a window. used to determine if it is a solid or dashed line + :param line_type: the type of line. this can be set if it is known, for example stop lines. + + """ + self.points = points + + if line_type is not None: + self.line_type = line_type + return + + if window_height is None: + raise ValueError("window_height or line_type must be provided") + + # check if there are regular intervals greater than the height of a window + # if so we have a dashed line. 
if there are no regular intervals we have a solid line + if len(points) < 2: + self.line_type = LineType.SOLID + return + + intervals = np.diff(points[:, 1]) + if len(np.where(abs(intervals) > window_height * 2)[0]) > gaps_allowed: + self.line_type = LineType.DASHED + else: + self.line_type = LineType.SOLID + + def __eq__(self, other: object) -> bool: # noqa: N807, ANN001 + """Check if the lines are equal.""" + if not isinstance(other, Line): + return False + return np.array_equal(self.points, other.points) and self.line_type == other.line_type + + def __ne__(self, other: object) -> bool: # noqa: N807, ANN001 + """Check if the lines are not equal.""" + return not self.__eq__(other) + + def __repr__(self) -> str: + """Get the string representation of the line.""" + return f"Line(points={len(self.points)}, line_type={self.line_type})" diff --git a/src/lane_assist/line_detection/window.py b/src/lane_assist/line_detection/window.py new file mode 100644 index 0000000..2eb39b4 --- /dev/null +++ b/src/lane_assist/line_detection/window.py @@ -0,0 +1,11 @@ +class Window: + """Class to represent a window in the image.""" + + def __init__(self, x: int, y: int, margin: int) -> None: + """Initialize the window.""" + self.x = x + self.y = y + self.margin = margin + + self.collided = False + self.found_in_previous = False diff --git a/src/lane_assist/line_detection/window_search.py b/src/lane_assist/line_detection/window_search.py new file mode 100644 index 0000000..f1e28c7 --- /dev/null +++ b/src/lane_assist/line_detection/window_search.py @@ -0,0 +1,254 @@ +import numpy as np + +from lane_assist.line_detection import Line, Window +from lane_assist.line_detection.line import LineType + + +LINE_WIDTH = 90 +ZEBRA_CROSSING_THRESHOLD = 20000 +FILTER_DECAY = 10 + +LINE_THRESHOLD = 5000 +LINE_DECAY = 1 + + +def get_histogram_peaks(histogram: np.ndarray, peak_minimum: int, decay: int = 1) -> list[list[int]]: + """Get the peaks in the histogram. + + get all complete peaks in the histogram. + this is done by iterating over the histogram until the peak_minimum is reached. + after this it will traverse the peak on both sides to find the edges of it. + + # TODO: improve detection of corners + # TODO: support multiple stoplines + + Parameters + ---------- + :param histogram: the histogram to extract the peaks from + :param peak_minimum: the minimum value to be considered a peak + :param decay: the amount of times the value needs to be lower to say it is the end of the peak + + """ + peaks = [] + # loop over all points in histogram + index = 0 + while index < len(histogram): + # get beginning above the threshold + if histogram[index] > peak_minimum: + # iterate backwards till the value is increasing. + # this index will be the start of the peak. + current_backward_index = index + lower_count = 0 + while current_backward_index > 0: + if ( + histogram[current_backward_index] < histogram[current_backward_index - 1] + or histogram[current_backward_index] < 50 + ): + lower_count += 1 + + if lower_count > decay: + break + + current_backward_index -= 1 + start = current_backward_index + + # iterate forwards till the value is increasing. + # this index will be the end. we will only count it as the end if we find 5 decreasing values. 
+ current_forward_index = index + lower_count = 0 + while current_forward_index < len(histogram) - 1: + if ( + histogram[current_forward_index] < histogram[current_forward_index + 1] + or histogram[current_backward_index] < 50 + ) and histogram[current_forward_index] < peak_minimum: + lower_count += 1 + + if lower_count > decay: + break + + current_forward_index += 1 + + # add the peak to the list + peaks.append([start, current_forward_index]) + + # set the index to the end of the peak + index = current_forward_index + 1 + else: + index += 1 + + return peaks + + +def merge_peaks(peaks: list, min_distance: int) -> list: + """Merge the peaks of the histogram. + + this function will merge the peaks of the histogram if they are close to each other. + + Parameters + ---------- + :param peaks: the peaks to merge + :param min_distance: the minimum distance between the peaks + + """ + merged_peaks = [] + index = 0 + while index < len(peaks) - 1: + peak = peaks[index] + other = peaks[index + 1] + index += 1 + + peak_position = peak + + if abs(peak - other) < min_distance: + peak_position = (peak + other) // 2 + index += 1 + + merged_peaks.append(peak_position) + + if index == len(peaks) - 1: + merged_peaks.append(peaks[-1]) + + return merged_peaks + + + + +def window_search(img: np.ndarray, window_count: int, pixels_per_window: int = 1, window_width: int = 60) -> list[Line]: + """Get the lines in the image using the sliding window algorithm. + + first we take a histogram of the x axis. this is done to filter out the zebra crossings. + after that we take a histogram of the other axis. this is used for detecting the starting points of the line. + after that we detect the 2 lines closest to the center. this will be the starting point for the windows. + + Todo: + ---- + - add support for multiple horizontal lines + - improve support for stop lines + - fix position of stop line + + Parameters + ---------- + :param window_width: the width of the window to use for the window search + :param pixels_per_window: the minimum amount of pixels needed in the window to be considered part of a line + :param img: the image to get the lines from. + :param window_count: the amount of windows to check in the image + + """ + # take a histogram over the horizontal pixels. + # this is used to filter out the zebra crossing. + histogram = np.sum(img[:], axis=1) + avg = ZEBRA_CROSSING_THRESHOLD + filter_peaks = get_histogram_peaks(histogram, avg, FILTER_DECAY) + peak_widths = [peak[1] - peak[0] for peak in filter_peaks] + + stop_lines_y = [] + + # mask out these peaks if they are wider than a line + for width, peak in zip(peak_widths, filter_peaks): + if width > LINE_WIDTH: + img[peak[0] : peak[1]] = 0 + else: + stop_lines_y.append(peak[len(peak) // 2]) + + # create a histogram of the image to find the lines. + # we only use the bottom half because that should be where the lines are + histogram = np.sum(img[img.shape[0] // 2 :], axis=0) + + # get the center of the peaks. merge if they are too close to each other. + # if the peaks are too close to each other they will instantly kill each other. + # the distance between these peaks most likely means that they are part of + # the same line which has a large gap in it. 
+ peak_limits = get_histogram_peaks(histogram, LINE_THRESHOLD, LINE_DECAY) + peak_centers = [peak[0] + np.argmax(histogram[peak[0] : peak[1]]) for peak in peak_limits] + merged_peaks = merge_peaks(peak_centers, window_width) + + # create the windows + window_height = img.shape[0] // window_count # get the height of the windows based on the amount we want. + windows = [Window(center, img.shape[0] - window_height, window_width // 2) for center in merged_peaks] + line_points = [[] for _ in range(len(windows))] + for _ in range(window_count): + # check which windows overlap + overlapped_windows = [] + + for i, window in enumerate(windows): + if window.collided: + continue + for k, other_window in enumerate(windows): + if window == other_window: + continue + if ( + window.x - window.margin < other_window.x + other_window.margin + and window.x + window.margin > other_window.x - other_window.margin + ): + overlapped_windows.append((i, k)) + break + + # check if any of the collided windows have found a line in the previous section + # if they have not, we can assume that the line has finished + for i, k in overlapped_windows: + # if they don't we stop the window + if not windows[i].found_in_previous and not windows[k].found_in_previous: + windows[i].collided = True + windows[k].collided = True + + if windows[i].found_in_previous: + windows[k].collided = True + if windows[k].found_in_previous: + windows[i].collided = True + + for i, window in enumerate(windows): + if window.collided: + continue + + # set the current position of the window + win_y_low = window.y - window_height + win_y_high = window.y + win_x_low = window.x - window.margin + win_x_high = window.x + window.margin + + # check how many white pixels are in the window + coords = np.argwhere(img[win_y_low:win_y_high, win_x_low:win_x_high] >= 100) + + # if we found more than pixels_per_window pixels, recenter the next window on their mean x position + if len(coords) > pixels_per_window: + window.x = int(np.mean(coords[:, 1])) + win_x_low + line_points[i].append([window.x, window.y]) + window.found_in_previous = True + else: + window.found_in_previous = False + + window.y = win_y_low + + # create the lines. we first calculate how many gaps we have filtered out. + # this is used to determine if the line is solid or dashed. 
we also allow + # for 2 extra gaps just in case + filtered_count = len(filter_peaks) - len(stop_lines_y) + lines = [Line(np.array(points), window_height, gaps_allowed=filtered_count + 2) for points in line_points] + + if len(stop_lines_y) == 0: + return lines + + # get the y position of the stop lines + # get the solid lines in the image + solid_lines = [line for line in lines if line.line_type == LineType.SOLID] + + # get the closest point for each solid line at the y of the stop lines + closest_points = [] + for line in solid_lines: + closest_point = line.points[np.argmin(np.abs(line.points[:, 1] - stop_lines_y))] + closest_points.append(closest_point) + + if len(closest_points) == 0: + return lines + + # get the distance between the closest points and the stop lines + distances = closest_points[-1][0] - closest_points[0][0] + + # generate points in between them with a gap of window height + points = np.array( + [(closest_points[0][0] + i, stop_lines_y[0]) for i in range(0, distances, window_height)] + + [(closest_points[-1][0], stop_lines_y[0])] + ) + lines.append(Line(points, window_height, LineType.STOP)) + return lines diff --git a/src/main.py b/src/main.py new file mode 100644 index 0000000..c220b8d --- /dev/null +++ b/src/main.py @@ -0,0 +1,19 @@ +import cv2 + +from lane_assist.image_manipulation.image_stitch import adjust_gamma, stitch_images +from lane_assist.line_detection import get_lines + +if __name__ == "__main__": + # load images + center_img = cv2.imread("../resources/images/straight/center.jpg") + left_img = cv2.imread("../resources/images/straight/left.jpg") + right_img = cv2.imread("../resources/images/straight/right.jpg") + + # adjust the gamma of the images so the bright unstitched images give fewer false positives + left_img = adjust_gamma(left_img, 0.62) + right_img = adjust_gamma(right_img, 0.62) + stitched = stitch_images(left_img, center_img, right_img) + + # get the lines of the image + lines = get_lines(stitched) + print(lines) # noqa: T201 diff --git a/src/scripts/data_acquisition.py b/src/scripts/data_acquisition.py index ac10514..0b215b2 100644 --- a/src/scripts/data_acquisition.py +++ b/src/scripts/data_acquisition.py @@ -11,7 +11,8 @@ from src.common.constants import CameraFramerate, CameraResolution, CANFeedbackIdentifier -CAN_MSG_SENDING_SPEED = .040 # 25Hz +CAN_MSG_SENDING_SPEED = 0.040 # 25Hz + class CanListener: """A can listener that listens for specific messages and stores their latest values.""" @@ -28,12 +29,7 @@ def __init__(self, bus: can.Bus) -> None: self.bus = bus self.thread = threading.Thread(target=self._listen, args=(), daemon=True) self.running = False - self.data = { - "steering": None, - "throttle": None, - "brake": None, - "speed_sensor": None - } + self.data = {"steering": None, "throttle": None, "brake": None, "speed_sensor": None} def start_listening(self) -> None: """Start the can listener.""" @@ -50,11 +46,10 @@ def stop_listening(self) -> None: def get_new_values(self) -> any: """Get new values.""" return self.data - def _listen(self) -> None: while self.running: - message: Optional[can.Message] = self.bus.recv(.5) + message: Optional[can.Message] = self.bus.recv(0.5) message_id = CANFeedbackIdentifier(message.arbitration_id) if message else None if message_id in self._id_conversion: self.data[self._id_conversion[message_id]] = message.data @@ -125,8 +120,11 @@ def _process(self) -> None: [str(x) for x in values["throttle"] or []])}"|"{",".join([str(x) for x in values["brake"] or []])}"|"{",".join([str(x) for x in 
values["speed_sensor"] - or []])}"|"' + self.folder_name + f'/{timestamp}.jpg"', - file=self.file_pointer) + or []])}"|"' + + self.folder_name + + f'/{timestamp}.jpg"', + file=self.file_pointer, + ) self.queue.task_done() @@ -207,33 +205,20 @@ def initialize_can() -> Optional[can.Bus]: """Set up the can bus interface and apply filters for the messages we're interested in.""" try: bus = can.Bus(interface="socketcan", channel="can0", bitrate=500000) - bus.set_filters([ - { - "can_id": CANFeedbackIdentifier.STEERING_SENSOR, - "can_mask": 0xfff, - "extended": True - }, - { - "can_id": CANFeedbackIdentifier.THROTTLE, - "can_mask": 0xfff, - "extended": True - }, - { - "can_id": CANFeedbackIdentifier.BRAKE, - "can_mask": 0xfff, - "extended": True - }, - { - "can_id": CANFeedbackIdentifier.SPEED_SENSOR, - "can_mask": 0xfff, - "extended": True - } - ]) + bus.set_filters( + [ + {"can_id": CANFeedbackIdentifier.STEERING_SENSOR, "can_mask": 0xFFF, "extended": True}, + {"can_id": CANFeedbackIdentifier.THROTTLE, "can_mask": 0xFFF, "extended": True}, + {"can_id": CANFeedbackIdentifier.BRAKE, "can_mask": 0xFFF, "extended": True}, + {"can_id": CANFeedbackIdentifier.SPEED_SENSOR, "can_mask": 0xFFF, "extended": True} + ] + ) return bus except Exception as e: # noqa: BLE001 print(f"Error initializing CAN: {e}", file=sys.stderr) # noqa: T201 return None + def initialize_camera(device: int | str) -> Optional[cv2.VideoCapture]: """Connect a camera.""" capture = cv2.VideoCapture(device) @@ -255,7 +240,5 @@ def initialize_cameras() -> tuple[cv2.VideoCapture, cv2.VideoCapture, cv2.VideoC return initialize_camera(0), initialize_camera(2), initialize_camera(4) - if __name__ == "__main__": main() - diff --git a/src/scripts/original_controller.py b/src/scripts/original_controller.py index 95f8c24..d419642 100644 --- a/src/scripts/original_controller.py +++ b/src/scripts/original_controller.py @@ -9,7 +9,7 @@ import inputs from inputs import get_gamepad -CAN_MSG_SENDING_SPEED = .04 # 100Hz +CAN_MSG_SENDING_SPEED = 0.04 # 100Hz class XboxController: @@ -46,7 +46,7 @@ def __init__(self): def vibrate(self, duration=1000): try: - self.gamepad.set_vibration(1,1,duration) + self.gamepad.set_vibration(1, 1, duration) except Exception: print("Failed to vibrate") @@ -54,14 +54,14 @@ def read(self): def _rerange(x, in_min, in_max, out_min, out_max): return float((x - in_min) * (out_max - out_min) / (in_max - in_min) + out_min) - throttle = int(round(_rerange((self.RightTrigger * 100),0,400,0,100),0)) #reranges the values to 0-100 - brake = int(round(_rerange((self.LeftTrigger * 100),0,400,0,100),0)) #reranges the values to 0-100 + throttle = int(round(_rerange((self.RightTrigger * 100), 0, 400, 0, 100), 0)) # reranges the values to 0-100 + brake = int(round(_rerange((self.LeftTrigger * 100), 0, 400, 0, 100), 0)) # reranges the values to 0-100 steering = 0 - if self.LeftJoystickX * 100 > 10: # joystick drift compensation - steering = round(_rerange((self.LeftJoystickX * 10),1,10,0,1),5) + if self.LeftJoystickX * 100 > 10: # joystick drift compensation + steering = round(_rerange((self.LeftJoystickX * 10), 1, 10, 0, 1), 5) elif self.LeftJoystickX * 100 < -10: - steering = round(_rerange((self.LeftJoystickX * 10),-1,-10,-0,-1),5) + steering = round(_rerange((self.LeftJoystickX * 10), -1, -10, -0, -1), 5) a = self.A y = self.Y @@ -75,45 +75,45 @@ def _monitor_controller(self): try: events = get_gamepad() for event in events: - if event.code == 'ABS_Y': - self.LeftJoystickY = event.state / XboxController.MAX_JOY_VAL # normalize 
between -1 and 1 - elif event.code == 'ABS_X': - self.LeftJoystickX = event.state / XboxController.MAX_JOY_VAL # normalize between -1 and 1 - elif event.code == 'ABS_RY': - self.RightJoystickY = event.state / XboxController.MAX_JOY_VAL # normalize between -1 and 1 - elif event.code == 'ABS_RX': - self.RightJoystickX = event.state / XboxController.MAX_JOY_VAL # normalize between -1 and 1 - elif event.code == 'ABS_Z': - self.LeftTrigger = event.state / XboxController.MAX_TRIG_VAL # normalize between 0 and 1 - elif event.code == 'ABS_RZ': - self.RightTrigger = event.state / XboxController.MAX_TRIG_VAL # normalize between 0 and 1 - elif event.code == 'BTN_TL': + if event.code == "ABS_Y": + self.LeftJoystickY = event.state / XboxController.MAX_JOY_VAL # normalize between -1 and 1 + elif event.code == "ABS_X": + self.LeftJoystickX = event.state / XboxController.MAX_JOY_VAL # normalize between -1 and 1 + elif event.code == "ABS_RY": + self.RightJoystickY = event.state / XboxController.MAX_JOY_VAL # normalize between -1 and 1 + elif event.code == "ABS_RX": + self.RightJoystickX = event.state / XboxController.MAX_JOY_VAL # normalize between -1 and 1 + elif event.code == "ABS_Z": + self.LeftTrigger = event.state / XboxController.MAX_TRIG_VAL # normalize between 0 and 1 + elif event.code == "ABS_RZ": + self.RightTrigger = event.state / XboxController.MAX_TRIG_VAL # normalize between 0 and 1 + elif event.code == "BTN_TL": self.LeftBumper = event.state - elif event.code == 'BTN_TR': + elif event.code == "BTN_TR": self.RightBumper = event.state - elif event.code == 'BTN_SOUTH': + elif event.code == "BTN_SOUTH": self.A = event.state - elif event.code == 'BTN_NORTH': + elif event.code == "BTN_NORTH": self.X = event.state - elif event.code == 'BTN_WEST': + elif event.code == "BTN_WEST": self.Y = event.state - elif event.code == 'BTN_EAST': + elif event.code == "BTN_EAST": self.B = event.state - elif event.code == 'BTN_THUMBL': + elif event.code == "BTN_THUMBL": self.LeftThumb = event.state - elif event.code == 'BTN_THUMBR': + elif event.code == "BTN_THUMBR": self.RightThumb = event.state - elif event.code == 'BTN_SELECT': + elif event.code == "BTN_SELECT": self.Back = event.state - elif event.code == 'BTN_START': + elif event.code == "BTN_START": self.Start = event.state - elif event.code == 'BTN_TRIGGER_HAPPY1': + elif event.code == "BTN_TRIGGER_HAPPY1": self.LeftDPad = event.state - elif event.code == 'BTN_TRIGGER_HAPPY2': + elif event.code == "BTN_TRIGGER_HAPPY2": self.RightDPad = event.state - elif event.code == 'BTN_TRIGGER_HAPPY3': + elif event.code == "BTN_TRIGGER_HAPPY3": self.UpDPad = event.state - elif event.code == 'BTN_TRIGGER_HAPPY4': + elif event.code == "BTN_TRIGGER_HAPPY4": self.DownDPad = event.state except Exception: pass @@ -123,12 +123,12 @@ def initialize_can(): system("ip link set can0 type can bitrate 500000") system("ip link set can0 up") - bus = can.Bus(interface='socketcan', channel='can0', bitrate=500000) + bus = can.Bus(interface="socketcan", channel="can0", bitrate=500000) return bus -if __name__ == '__main__': +if __name__ == "__main__": joy = XboxController() joy.vibrate(500) @@ -143,7 +143,7 @@ def initialize_can(): started = True prev_a = a - sleep(.040) + sleep(0.040) joy.vibrate(1000) bus = initialize_can() @@ -156,8 +156,7 @@ def initialize_can(): steering_msg = can.Message(arbitration_id=0x220, data=[0, 0, 0, 0, 0, 0, 195, 0], is_extended_id=False) steering_task = bus.send_periodic(steering_msg, CAN_MSG_SENDING_SPEED) - acc_msg = can.Message(arbitration_id=0x330, 
is_extended_id=False, data=[0, 0, 1, 0, - 0, 0, 0, 0]) + acc_msg = can.Message(arbitration_id=0x330, is_extended_id=False, data=[0, 0, 1, 0, 0, 0, 0, 0]) acc_task = bus.send_periodic(acc_msg, CAN_MSG_SENDING_SPEED) switch_sequence_time = 0 @@ -199,7 +198,6 @@ def initialize_can(): joy.vibrate(750) print("Controls activated") - if y == 1: direction = 0 print("Changed gear to N") @@ -221,20 +219,22 @@ def initialize_can(): bus = initialize_can() setup_completed = True elif active: - print(f"Throttle: {joy.read()[0]:<3} -- Braking: {joy.read()[1]:<3} -- Steering: {joy.read()[2]:>8.5f}", end="\n") + print( + f"Throttle: {joy.read()[0]:<3} -- Braking: {joy.read()[1]:<3} -- Steering: {joy.read()[2]:>8.5f}", + end="\n", + ) acc_msg.data = [int(t), 0, direction, 0, 0, 0, 0, 0] acc_task.modify_data(acc_msg) - steering_msg.data = (list(bytearray(struct.pack("f", float(s)))) - + [0, 0, 195, 0]) + steering_msg.data = list(bytearray(struct.pack("f", float(s)))) + [0, 0, 195, 0] steering_task.modify_data(steering_msg) brk_msg.data = [brake, 0, 0, 0, 0, 0, 0, 0] brk_task.modify_data(brk_msg) - sleep(.040) + sleep(0.040) except KeyboardInterrupt: - pass + pass acc_task.stop() steering_task.stop() brk_task.stop() diff --git a/src/utils/calculate_distance.py b/src/utils/calculate_distance.py index 1b22170..e50a8c2 100644 --- a/src/utils/calculate_distance.py +++ b/src/utils/calculate_distance.py @@ -1,10 +1,9 @@ import logging +from typing import TypeVar + import matplotlib.pyplot as plt import numpy as np - from scipy.optimize import curve_fit -from typing import TypeVar - """Calibrated parameters""" a0 = 1.015988 @@ -16,15 +15,19 @@ POPT = (float, float, float, float) """Functions""" + + def formula(x: T, a: float, b: float, c: float, d: float) -> T: """The formula to calculate the distance from the y-value.""" return a / (x - c) ** b + d + def recalibrate_parameters(x: np.ndarray, y: np.ndarray) -> POPT: """Recalibrate the parameters of the formula to fit the given data.""" fit = curve_fit(formula, x, y, p0=[a0, b0, c0, d0]) return fit[0] + def plot_data(x: np.ndarray, y: np.ndarray, popt: POPT) -> None: """Plot the given data and the fitted line.""" x_reg = np.linspace(0.4, 1, 1000) @@ -35,6 +38,7 @@ def plot_data(x: np.ndarray, y: np.ndarray, popt: POPT) -> None: plt.show() + def y_to_meters(y: int, height: int) -> float: """Convert a y-value to meters.""" rel_y = y / height @@ -43,11 +47,13 @@ def y_to_meters(y: int, height: int) -> float: return formula(rel_y, a0, b0, c0, d0) + def meters_to_y(meters: float, height: int) -> int: """Convert meters to a y-value.""" - rel_y = ((a0 / (meters - d0)) ** (1 / b0) + c0) + rel_y = (a0 / (meters - d0)) ** (1 / b0) + c0 return int(round(rel_y * height)) + if __name__ == "__main__": X = np.array([1028, 840, 733, 621, 561, 523, 499, 481, 468, 457, 449, 443]) / 1080 Y = np.array([0, 0.5, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]) + 0.115 diff --git a/src/utils/image.py b/src/utils/image.py new file mode 100644 index 0000000..aed6975 --- /dev/null +++ b/src/utils/image.py @@ -0,0 +1,18 @@ +import numpy as np +from matplotlib import pyplot as plt + + +def list_images(images: list[np.ndarray], cols: int = 1, rows: int = 4, cmap: str = None) -> None: + """Plot a list of images.""" + plt.figure(figsize=(10, 11)) + for i, image in enumerate(images): + plt.subplot(rows, cols, i + 1) + cmap = "gray" if len(image.shape) == 2 else cmap + plt.imshow(image, cmap=cmap, aspect="equal") + plt.tight_layout(pad=0, h_pad=0, w_pad=0) + plt.show() + + +def cut_image(image: np.ndarray, x: int, 
y: int, width: int, height: int) -> np.ndarray: + """Cut image to specified width and height.""" + return image[y : y + height, x : x + width] diff --git a/tests/line_detection/__init__.py b/tests/line_detection/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/line_detection/images/corner.jpg b/tests/line_detection/images/corner.jpg new file mode 100644 index 0000000..aae7c7d Binary files /dev/null and b/tests/line_detection/images/corner.jpg differ diff --git a/tests/line_detection/images/crossing.jpg b/tests/line_detection/images/crossing.jpg new file mode 100644 index 0000000..2ec76c8 Binary files /dev/null and b/tests/line_detection/images/crossing.jpg differ diff --git a/tests/line_detection/images/stopline.jpg b/tests/line_detection/images/stopline.jpg new file mode 100644 index 0000000..57f2562 Binary files /dev/null and b/tests/line_detection/images/stopline.jpg differ diff --git a/tests/line_detection/images/straight.jpg b/tests/line_detection/images/straight.jpg new file mode 100644 index 0000000..2c45472 Binary files /dev/null and b/tests/line_detection/images/straight.jpg differ diff --git a/tests/line_detection/lines.py b/tests/line_detection/lines.py new file mode 100644 index 0000000..4ee2b30 --- /dev/null +++ b/tests/line_detection/lines.py @@ -0,0 +1,378 @@ +import numpy as np + +from lane_assist.line_detection import Line, LineType + +CORNER = [ + Line( + np.array( + [ + [312, 882], + [295, 864], + [281, 846], + [307, 720], + ] + ), + line_type=LineType.SOLID, + ), + Line( + np.array( + [ + [477, 882], + [473, 864], + [467, 846], + [457, 828], + [445, 810], + [430, 792], + [413, 774], + [393, 756], + [370, 738], + [345, 720], + ] + ), + line_type=LineType.SOLID, + ), +] + +STRAIGHT = [ + Line( + np.array( + [ + [315, 882], + [315, 864], + [316, 846], + [316, 828], + [317, 810], + [317, 792], + [317, 774], + [318, 756], + [318, 738], + [318, 720], + [319, 702], + [319, 684], + [320, 666], + [320, 648], + [321, 630], + [322, 612], + [322, 594], + [323, 576], + [324, 558], + [324, 540], + [325, 522], + [326, 504], + [326, 486], + [327, 468], + [327, 450], + [328, 432], + [328, 414], + [336, 108], + [336, 90], + [336, 72], + [336, 54], + [337, 36], + [337, 18], + ] + ), + line_type=LineType.SOLID, + ), + Line( + np.array( + [ + [468, 882], + [469, 864], + [470, 846], + [471, 828], + [472, 810], + [473, 792], + [474, 774], + [475, 756], + [476, 738], + [477, 720], + [478, 702], + [479, 684], + [480, 666], + [481, 648], + [482, 630], + [483, 612], + [484, 594], + [485, 576], + [486, 558], + [487, 540], + [488, 522], + [489, 504], + [490, 486], + [491, 468], + [492, 450], + [493, 432], + [493, 414], + [512, 108], + [513, 90], + [513, 72], + [514, 54], + [515, 36], + [516, 18], + ] + ), + line_type=LineType.SOLID, + ), +] +CROSSING = [ + Line( + np.array( + [ + [104, 558], + [101, 540], + [98, 522], + [95, 504], + [92, 486], + [89, 468], + [86, 450], + [83, 432], + [80, 414], + [73, 396], + [75, 378], + [72, 360], + [69, 342], + [66, 324], + [64, 306], + [61, 288], + [58, 270], + [56, 252], + [53, 234], + [50, 216], + [47, 198], + [45, 180], + [41, 162], + [39, 144], + [35, 126], + [32, 108], + [31, 90], + [28, 72], + ] + ), + line_type=LineType.SOLID, + ), + Line( + np.array( + [ + [328, 864], + [327, 846], + [326, 828], + [297, 486], + [293, 468], + [291, 450], + [290, 432], + [286, 378], + [284, 360], + [282, 342], + [281, 324], + [277, 270], + [276, 252], + [274, 234], + [272, 216], + [271, 198], + [268, 144], + [266, 126], + [264, 108], + [262, 90], 
+ [260, 72], + ] + ), + line_type=LineType.DASHED, + ), + Line( + np.array( + [ + [483, 882], + [482, 864], + [481, 846], + [480, 828], + [474, 558], + [473, 540], + [472, 522], + [472, 504], + [471, 486], + [471, 468], + [470, 450], + [470, 432], + [469, 414], + [469, 396], + [469, 378], + [468, 360], + [468, 342], + [467, 324], + [466, 306], + [465, 288], + [465, 270], + [464, 252], + [464, 234], + [463, 216], + [463, 198], + [462, 180], + [462, 162], + [461, 144], + [460, 126], + [460, 108], + [459, 90], + [458, 72], + [458, 54], + [457, 36], + [456, 18], + ] + ), + line_type=LineType.SOLID, + ), + Line( + np.array( + [ + [655, 504], + [645, 486], + [642, 468], + [643, 450], + [644, 432], + [644, 414], + [645, 396], + [646, 378], + [645, 360], + [647, 342], + [646, 324], + [654, 306], + [655, 288], + ] + ), + line_type=LineType.SOLID, + ), + Line( + np.array( + [ + [719, 558], + [721, 540], + [723, 522], + [715, 504], + [726, 486], + [729, 468], + [730, 450], + [732, 432], + [733, 414], + [733, 396], + [735, 360], + [733, 324], + [734, 306], + [737, 288], + [740, 270], + [738, 252], + [738, 234], + [739, 216], + [740, 198], + [741, 180], + [739, 162], + [741, 144], + [741, 126], + [743, 108], + [744, 72], + [744, 36], + ] + ), + line_type=LineType.SOLID, + ), +] + + +STOP_LINE = [ + Line( + np.array( + [ + [315, 882], + [324, 864], + [316, 846], + [316, 828], + [317, 810], + [317, 792], + [317, 774], + [318, 756], + [318, 738], + [318, 720], + [319, 702], + [319, 684], + [320, 666], + [320, 648], + [321, 630], + [322, 612], + [322, 594], + [323, 576], + [323, 558], + [324, 540], + [324, 522], + [325, 504], + [326, 486], + [326, 468], + [327, 450], + [328, 432], + [328, 414], + [336, 108], + [336, 90], + [336, 72], + [336, 54], + [336, 36], + [336, 18], + ] + ), + line_type=LineType.SOLID, + ), + Line( + np.array( + [ + [468, 882], + [459, 864], + [470, 846], + [471, 828], + [472, 810], + [473, 792], + [474, 774], + [475, 756], + [476, 738], + [477, 720], + [478, 702], + [479, 684], + [480, 666], + [481, 648], + [482, 630], + [483, 612], + [484, 594], + [485, 576], + [486, 558], + [487, 540], + [488, 522], + [489, 504], + [490, 486], + [491, 468], + [492, 450], + [493, 432], + [494, 414], + [513, 108], + [514, 90], + [514, 72], + [515, 54], + [516, 36], + [517, 18], + ] + ), + line_type=LineType.SOLID, + ), + Line( + np.array( + [ + [315, 899], + [333, 899], + [351, 899], + [369, 899], + [387, 899], + [405, 899], + [423, 899], + [441, 899], + [459, 899], + [468, 899], + ] + ), + line_type=LineType.STOP, + ), +] diff --git a/tests/line_detection/test_line_detection.py b/tests/line_detection/test_line_detection.py new file mode 100644 index 0000000..74d38c9 --- /dev/null +++ b/tests/line_detection/test_line_detection.py @@ -0,0 +1,82 @@ +import os +import time +import unittest + +import cv2 +import numpy as np +from src.lane_assist.line_detection import Line, get_lines + +from .lines import CORNER, CROSSING, STOP_LINE, STRAIGHT + +BENCHMARK_ITERATIONS = 500 + + +class TestDetectLines(unittest.TestCase): + """Tests for the line detection methods.""" + + def test_line_detection_corner(self) -> None: + """Test the line detection on the corner image.""" + image = cv2.imread("./line_detection/images/corner.jpg") + self.assertIsNotNone(image, "Image not found") + self.__test_line_detection(image, CORNER) + + def test_line_detection_straight(self) -> None: + """Test the line detection on the straight image.""" + image = cv2.imread("./line_detection/images/straight.jpg") + 
self.assertIsNotNone(image, "Image not found") + self.__test_line_detection(image, STRAIGHT) + + def test_line_detection_crossing(self) -> None: + """Test the line detection on the crossing image.""" + image = cv2.imread("./line_detection/images/crossing.jpg") + self.assertIsNotNone(image, "Image not found") + self.__test_line_detection(image, CROSSING) + + def test_line_detection_stopline(self) -> None: + """Test the line detection on the stopline image.""" + image = cv2.imread("./line_detection/images/stopline.jpg") + self.assertIsNotNone(image, "Image not found") + self.__test_line_detection(image, STOP_LINE) + + def __test_line_detection(self, image: np.ndarray, expected_lines: list[Line]) -> None: + """Test the line detection on the given image.""" + lines = get_lines(image) + with self.subTest("amount of lines detected"): + self.assertEqual(len(lines), len(expected_lines), "Number of lines is not equal") + + with self.subTest("lines"): + for i, line in enumerate(lines): + with self.subTest(f"line {i}"): + self.assertEqual(line.line_type, expected_lines[i].line_type, f"Line type of line {i} is not equal") + self.assertEqual( + len(line.points), + len(expected_lines[i].points), + f"Number of points in line {i} is not equal", + ) + self.assertEqual(line, expected_lines[i], f"Line {i} is not equal to expected line") + + @unittest.skipIf(os.environ.get("SKIP_BENCHMARK", "0") == "1", "benchmarking not enabled") + def test_benchmark(self) -> None: + """Benchmark the line detection. + + to make sure it is a realistic situation all images in the test folder are used. + this benchmark will not include stitching the images. + """ + benchmark_images = [ + cv2.imread(f"./line_detection/images/{img}") for img in os.listdir("./line_detection/images") + ] + + img_count = len(benchmark_images) + + start = time.process_time() + for i in range(BENCHMARK_ITERATIONS): + get_lines(benchmark_images[i % img_count]) + stop = time.process_time() + fps = BENCHMARK_ITERATIONS / (stop - start) + per_iteration = (stop - start) / BENCHMARK_ITERATIONS + + print(f"Line detection: FPS: {fps:.4f}, per iteration: {per_iteration:.6f}") # noqa: T201 + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/utils/test_calculate_distance.py b/tests/utils/test_calculate_distance.py index 4bbfc8d..5c21f96 100644 --- a/tests/utils/test_calculate_distance.py +++ b/tests/utils/test_calculate_distance.py @@ -17,5 +17,6 @@ def test_meters_to_y(self) -> None: self.assertEqual(meters_to_y(y_to_meters(720, 720), 720), 720) self.assertEqual(meters_to_y(y_to_meters(561, 1080), 1080), 561) + if __name__ == "__main__": unittest.main()
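
Reviewer note (not part of the diff): the solid/dashed heuristic added in src/lane_assist/line_detection/line.py counts y-intervals larger than twice the window height and compares that count against gaps_allowed. Below is a minimal sketch to sanity-check it, assuming the package layout from this PR is on the import path; WINDOW_HEIGHT is a hypothetical value chosen only for illustration.

import numpy as np

from lane_assist.line_detection.line import Line, LineType

WINDOW_HEIGHT = 18  # hypothetical window height for this sketch

# consecutive windows: every y-gap equals WINDOW_HEIGHT, so no gap exceeds
# 2 * WINDOW_HEIGHT and the line is classified as SOLID
solid = Line(np.array([[320, 900 - i * WINDOW_HEIGHT] for i in range(10)]), WINDOW_HEIGHT)
assert solid.line_type == LineType.SOLID

# four y-gaps of 60 px each exceed 2 * WINDOW_HEIGHT (36 px), which is more
# than the default gaps_allowed of 2, so the line is classified as DASHED
dashed = Line(np.array([[320, y] for y in (900, 840, 780, 720, 660)]), WINDOW_HEIGHT)
assert dashed.line_type == LineType.DASHED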
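
Reviewer note (not part of the diff): adjust_gamma in image_stitch/__init__.py builds a 256-entry lookup table and applies it with cv2.LUT. With gamma = 0.62 the exponent 1 / gamma is about 1.61, which darkens mid-tones while leaving 0 and 255 fixed; that is why it suppresses false positives from the brighter side cameras. A standalone sketch of the same table, using only numpy:

import numpy as np

gamma = 0.62
inv_gamma = 1.0 / gamma  # ~1.61; a gamma below 1 darkens the image
table = np.array([((i / 255.0) ** inv_gamma) * 255 for i in np.arange(0, 256)]).astype("uint8")

print(table[0], table[128], table[255])  # 0 83 255: mid-greys drop, the extremes stay fixed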
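
Reviewer note (not part of the diff): the round-trip assertions in tests/utils/test_calculate_distance.py hold because meters_to_y is the algebraic inverse of y_to_meters. With the calibrated parameters a, b, c, d from src/utils/calculate_distance.py and x the y-value divided by the image height:

m(x) = a / (x - c) ** b + d    <=>    x(m) = (a / (m - d)) ** (1 / b) + c

which is exactly rel_y = (a0 / (meters - d0)) ** (1 / b0) + c0 in the code, up to the final rounding to integer pixels.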