Skip to content

Commit

Permalink
Add snapping function
Browse files Browse the repository at this point in the history
  • Loading branch information
dfsnow committed Nov 30, 2024
1 parent 9a7f2f9 commit dfde348
Show file tree
Hide file tree
Showing 3 changed files with 98 additions and 16 deletions.
5 changes: 5 additions & 0 deletions data/params.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,11 @@ times:
# Valhalla. This is necessary because larger matrices will cause it to choke
max_split_size: 200

# Coordinates are snapped to the OSM street network before time calculation.
# Setting this to true will use the snapped coordinates directly in the
# matrix API calls
use_snapped: true

input:
# Distance in meters to buffer each state boundary by when clipping the
# national road network. Should be slightly higher than `destination_buffer_m`
Expand Down
15 changes: 13 additions & 2 deletions data/src/calculate_times.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
TravelTimeCalculator,
TravelTimeConfig,
TravelTimeInputs,
snap_df_to_osm,
)
from utils.utils import format_time, get_md5_hash

Expand Down Expand Up @@ -64,6 +65,16 @@ def main() -> None:
# Initialize the default Valhalla actor bindings
actor = valhalla.Actor((Path.cwd() / "valhalla.json").as_posix())

# Use the Vahalla Locate API to append coordinates that are snapped to OSM
if config.params["times"]["use_snapped"]:
logger.info("Snapping coordinates to OSM network")
inputs.origins_chunk = snap_df_to_osm(
inputs.origins_chunk, config.args.mode, actor
)
inputs.destinations = snap_df_to_osm(
inputs.destinations, config.args.mode, actor
)

# Calculate times for each chunk and append to a list
tt_calc = TravelTimeCalculator(actor, config, inputs)
results_df = tt_calc.get_times()
Expand Down Expand Up @@ -93,8 +104,8 @@ def main() -> None:

# Create a new input class, keeping only pairs that were unroutable
inputs_sp = TravelTimeInputs(
origins=inputs.origins[
inputs.origins["id"].isin(
origins=inputs.origins_chunk[
inputs.origins_chunk["id"].isin(
missing_pairs_df.index.get_level_values("origin_id")
)
].reset_index(drop=True),
Expand Down
94 changes: 80 additions & 14 deletions data/src/utils/times.py
Original file line number Diff line number Diff line change
Expand Up @@ -278,9 +278,9 @@ def __init__(
def _set_chunk_attributes(self) -> None:
"""Sets the origin chunk indices given the input chunk string."""
if self.chunk:
self.chunk_start_idx, self.chunk_end_idx = map(
int, self.chunk.split("-")
)
chunk_start_idx, chunk_end_idx = self.chunk.split("-")
self.chunk_start_idx = int(chunk_start_idx)
self.chunk_end_idx = int(chunk_end_idx) + 1
self.chunk_size = self.chunk_end_idx - self.chunk_start_idx

def _set_origins_chunk(self) -> None:
Expand Down Expand Up @@ -381,16 +381,21 @@ def _calculate_times(
and distances.
"""

def col_dict(x, snapped=self.config.params["times"]["use_snapped"]):
"""Use the snapped lat/lon if set."""
col_suffix = "_snapped" if snapped else ""
return {"lat": x[f"lat{col_suffix}"], "lon": x[f"lon{col_suffix}"]}

# Get the subset of origin and destination points and convert them to
# lists then squash them into the request body
# lists, then squash them into the request body
origins_list = (
self.inputs.origins.iloc[o_start_idx:o_end_idx]
.apply(lambda row: {"lat": row["lat"], "lon": row["lon"]}, axis=1)
self.inputs.origins_chunk.iloc[o_start_idx:o_end_idx]
.apply(col_dict, axis=1)
.tolist()
)
destinations_list = (
self.inputs.destinations.iloc[d_start_idx:d_end_idx]
.apply(lambda row: {"lat": row["lat"], "lon": row["lon"]}, axis=1)
.apply(col_dict, axis=1)
.tolist()
)
request_json = json.dumps(
Expand Down Expand Up @@ -512,16 +517,20 @@ def get_times(self) -> pd.DataFrame:
and distances for all inputs.
"""
results = []
msso = self.inputs.max_split_size_origins
noc = self.inputs.n_origins_chunk
mssd = self.inputs.max_split_size_destinations
ndc = self.inputs.n_destinations_chunk
max_spl_o = self.inputs.max_split_size_origins
n_oc = self.inputs.n_origins_chunk
m_spl_d = self.inputs.max_split_size_destinations
n_dc = self.inputs.n_destinations_chunk

for o in range(0, noc, msso):
for d in range(0, ndc, mssd):
for o in range(0, n_oc, max_spl_o):
for d in range(0, n_dc, m_spl_d):
results.extend(
self._binary_search(
o, d, min(o + msso, noc), min(d + mssd, ndc), True
o,
d,
min(o + max_spl_o, n_oc),
min(d + m_spl_d, n_dc),
True,
)
)

Expand All @@ -543,3 +552,60 @@ def get_times(self) -> pd.DataFrame:
)
del results
return results_df


def snap_df_to_osm(
df: pd.DataFrame, mode: str, actor: valhalla.Actor
) -> pd.DataFrame:
"""
Snap a DataFrame of lat/lon points to the OpenStreetMap network using
the Valhalla Locate API.
Args:
df: DataFrame containing the columns 'id', 'lat', and 'lon'.
mode: Travel mode to use for snapping.
actor: Valhalla Actor object for making API requests.
"""
df_list = df.apply(
lambda x: {"lat": x["lat"], "lon": x["lon"]}, axis=1
).tolist()
request_json = json.dumps(
{
"locations": df_list,
"costing": mode,
"verbose": False,
}
)

response = actor.locate(request_json)
response_data = json.loads(response)

# Use the first element of nodes to populate the snapped lat/lon, otherwise
# fallback to the correlated lat/lon from edges
response_df = pd.DataFrame(
[
{
"lon_snapped": item["nodes"][0]["lon"]
if item["nodes"]
else (
item["edges"][0]["correlated_lon"]
if item["edges"]
else None
),
"lat_snapped": item["nodes"][0]["lat"]
if item["nodes"]
else (
item["edges"][0]["correlated_lat"]
if item["edges"]
else None
),
}
for item in response_data
]
)

df = pd.concat([df, response_df], axis=1)
df.fillna({"lon_snapped": df["lon"]}, inplace=True)
df.fillna({"lat_snapped": df["lat"]}, inplace=True)
df["is_snapped"] = df["lon"] != df["lon_snapped"]
return df

0 comments on commit dfde348

Please sign in to comment.