Skip to content

Commit

Permalink
#276: Fix travelling salesman solver bug
Browse files Browse the repository at this point in the history
  • Loading branch information
bgottula committed Jul 7, 2024
1 parent 02d7446 commit d855d0a
Show file tree
Hide file tree
Showing 3 changed files with 69 additions and 16 deletions.
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@
'pandas',
'point @ https://github.com/seeing-things/point/tarball/master',
'pyftdi>=0.49', # for laser pointer control
'pytest>=8.0',
'requests',
'scipy',
'skyfield',
Expand Down
41 changes: 41 additions & 0 deletions tests/test_tsp.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
"""Test travelling salesman solver module."""

from __future__ import annotations
import itertools
import numpy as np
from track import tsp


class Position(tsp.Destination):
"""A position in 2-space."""

def __init__(self, position: tuple[float, float]):
self.position = np.array(position)

def distance_to(self, other: Position) -> int:
max_error_mag = np.linalg.norm(self.position - other.position)
# Scale by 1000 to minimize precision loss when quantizing to integer.
return int(1000 * max_error_mag)


def test_solver():
"""Basic solver test.
This test case checks for regressions in the solver with a nearly trivial test. This test case
fails without the bug fix in the tsp module for this issue:
https://github.com/seeing-things/track/issues/276.
"""
# Intentionally suboptimal route with diagonal crossings between the vertices of a square.
positions = [
Position((0, 0)),
Position((0, 1)),
Position((1, 0)),
Position((1, 1)),
]

positions_sorted = tsp.solve_route(positions)

# If the solver worked properly the route should traverse the perimeter of a square rather than
# crossing over the diagonals.
dist = sum(p1.distance_to(p2) for p1, p2 in itertools.pairwise(positions_sorted))
assert dist == 4000
43 changes: 27 additions & 16 deletions track/tsp.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,25 @@
"""Travelling salesman solver for optimizing route to set of multiple positions."""

from __future__ import annotations
from abc import ABC, abstractmethod
from ortools.constraint_solver import routing_enums_pb2
import logging
import typing
from ortools.constraint_solver import pywrapcp
import numpy as np


logger = logging.getLogger(__name__)


class SolverFailedError(Exception):
"""Travelling salesman solver has failed to produce a solution."""


class Destination(ABC):
"""Abstract class representing a destination along a route."""

@abstractmethod
def distance_to(self, other_destination: "Destination") -> int:
def distance_to(self, other_destination: Destination) -> int:
"""Returns the distance (cost) of travelling from this to the other destination.
Args:
Expand All @@ -28,19 +37,23 @@ def solve_route(destinations: list[Destination]) -> list[Destination]:
Takes a list of destinations and sorts them such that the total time taken to visit each one is
minimized. This is known as the travelling salesman problem.
Much of the code in this function is borrowed from the example here:
https://developers.google.com/optimization/routing/tsp
Args:
destinations: A list of places to visit. It is assumed that the first destination in the
list is the first and last destination for the trip, otherwise known as the "depot".
Returns:
The list of places ordered such that visiting them in that order minimizes the total trip
time.
Raises:
RuntimeError if the solver fails to find a solution.
"""
# pre-compute a matrix of distances between each destination
distance_matrix = np.zeros((len(destinations),) * 2)
# Much of the code in this function is borrowed from the example here:
# https://developers.google.com/optimization/routing/tsp

# Pre-compute a matrix of distances between each destination.
distance_matrix: np.ndarray[typing.Any, np.dtype[np.int64]]
distance_matrix = np.zeros((len(destinations),) * 2, dtype=int)
for idx, dest in enumerate(destinations):
for jdx in range(idx + 1, len(destinations)):
distance_matrix[idx, jdx] = dest.distance_to(destinations[jdx])
Expand All @@ -49,22 +62,20 @@ def solve_route(destinations: list[Destination]) -> list[Destination]:
manager = pywrapcp.RoutingIndexManager(len(destinations), 1, 0)
routing = pywrapcp.RoutingModel(manager)

def distance_callback(from_index, to_index):
def distance_callback(from_index: int, to_index: int) -> int:
"""Returns the distance between the two nodes."""
from_node = manager.IndexToNode(from_index)
to_node = manager.IndexToNode(to_index)
return distance_matrix[from_node][to_node]
return int(distance_matrix[from_node][to_node])

transit_callback_index = routing.RegisterTransitCallback(distance_callback)
routing.SetArcCostEvaluatorOfAllVehicles(transit_callback_index)
solution = routing.Solve()

# Setting first solution heuristic.
search_parameters = pywrapcp.DefaultRoutingSearchParameters()
search_parameters.first_solution_strategy = (
routing_enums_pb2.FirstSolutionStrategy.PATH_CHEAPEST_ARC
) # pylint: disable=no-member

solution = routing.SolveWithParameters(search_parameters)
# Status values:
# https://developers.google.com/optimization/routing/routing_options#search_status
if (status := routing.status()) != 1:
raise SolverFailedError(f'Routing solver failed with status {status}.')

index = routing.Start(0)
route = [destinations[manager.IndexToNode(index)]]
Expand Down

0 comments on commit d855d0a

Please sign in to comment.