Skip to content

Commit

Permalink
Ignore points outside image boundary (#240)
Browse files Browse the repository at this point in the history
  • Loading branch information
giswqs authored Jan 18, 2024
1 parent b1626e9 commit deb8aeb
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 8 deletions.
28 changes: 21 additions & 7 deletions samgeo/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -780,19 +780,22 @@ def geojson_to_coords(


def coords_to_xy(
src_fp: str, coords: list, coord_crs: str = "epsg:4326", **kwargs
src_fp: str, coords: list, coord_crs: str = "epsg:4326", return_out_of_bounds=False, **kwargs
) -> list:
"""Converts a list of coordinates to pixel coordinates, i.e., (col, row) coordinates.
Args:
src_fp: The source raster file path.
coords: A list of coordinates in the format of [[x1, y1], [x2, y2], ...]
coord_crs: The coordinate CRS of the input coordinates. Defaults to "epsg:4326".
return_out_of_bounds: Whether to return out of bounds coordinates. Defaults to False.
**kwargs: Additional keyword arguments to pass to rasterio.transform.rowcol.
Returns:
A list of pixel coordinates in the format of [[x1, y1], [x2, y2], ...]
"""
out_of_bounds = []

if isinstance(coords, np.ndarray):
coords = coords.tolist()

Expand All @@ -805,15 +808,26 @@ def coords_to_xy(
rows, cols = rasterio.transform.rowcol(src.transform, xs, ys, **kwargs)
result = [[col, row] for col, row in zip(cols, rows)]

result = [
[x, y] for x, y in result if x >= 0 and y >= 0 and x < width and y < height
]
if len(result) == 0:
output = []

for i, (x, y) in enumerate(result):
if x >= 0 and y >= 0 and x < width and y < height:
output.append([x, y])
else:
out_of_bounds.append(i)

# output = [
# [x, y] for x, y in result if x >= 0 and y >= 0 and x < width and y < height
# ]
if len(output) == 0:
print("No valid pixel coordinates found.")
elif len(result) < len(coords):
elif len(output) < len(coords):
print("Some coordinates are out of the image boundary.")

return result
if return_out_of_bounds:
return output, out_of_bounds
else:
return output


def boxes_to_vector(coords, src_crs, dst_crs="EPSG:4326", output=None, **kwargs):
Expand Down
10 changes: 9 additions & 1 deletion samgeo/samgeo.py
Original file line number Diff line number Diff line change
Expand Up @@ -503,6 +503,7 @@ def predict(
return_results (bool, optional): Whether to return the predicted masks, scores, and logits. Defaults to False.
"""
out_of_bounds = []

if isinstance(boxes, str):
gdf = gpd.read_file(boxes)
Expand All @@ -529,7 +530,7 @@ def predict(
point_labels = self.point_labels

if (point_crs is not None) and (point_coords is not None):
point_coords = coords_to_xy(self.source, point_coords, point_crs)
point_coords, out_of_bounds = coords_to_xy(self.source, point_coords, point_crs, return_out_of_bounds=True)

if isinstance(point_coords, list):
point_coords = np.array(point_coords)
Expand All @@ -544,6 +545,13 @@ def predict(
if len(point_labels) != len(point_coords):
if len(point_labels) == 1:
point_labels = point_labels * len(point_coords)
elif len(out_of_bounds) > 0:
print(f"Removing {len(out_of_bounds)} out-of-bound points.")
point_labels_new = []
for i, p in enumerate(point_labels):
if i not in out_of_bounds:
point_labels_new.append(p)
point_labels = point_labels_new
else:
raise ValueError(
"The length of point_labels must be equal to the length of point_coords."
Expand Down

0 comments on commit deb8aeb

Please sign in to comment.