Commit 7c798c8 (parent 3a15d2f): 4 changed files, 281 additions, 0 deletions.
buffer.py (file name inferred from `from buffer import buffer` in the evaluation script below)
@@ -0,0 +1,38 @@
import numpy as np


def buffer(pred_ske, buffer_size=1):
    """Dilate a binary skeleton by `buffer_size` pixels using its 8-neighborhood.

    Note: the array is modified in place (unless it is first rescaled) and also returned.
    """
    if np.max(pred_ske) == 0:
        print('blank img')
        return
    if np.max(pred_ske) != 1:
        pred_ske = pred_ske / np.max(pred_ske)

    h, w = pred_ske.shape
    for i in range(buffer_size):
        nonzeros_idx = np.where(pred_ske != 0)
        xs, ys = nonzeros_idx
        # clip neighbor indices so border pixels neither wrap around nor overflow the array
        # west
        d1 = (np.clip(xs - 1, 0, h - 1), ys)
        # east
        d2 = (np.clip(xs + 1, 0, h - 1), ys)
        # north
        d3 = (xs, np.clip(ys - 1, 0, w - 1))
        # south
        d4 = (xs, np.clip(ys + 1, 0, w - 1))
        # northwest
        d5 = (np.clip(xs - 1, 0, h - 1), np.clip(ys - 1, 0, w - 1))
        # northeast
        d6 = (np.clip(xs + 1, 0, h - 1), np.clip(ys - 1, 0, w - 1))
        # southwest
        d7 = (np.clip(xs - 1, 0, h - 1), np.clip(ys + 1, 0, w - 1))
        # southeast
        d8 = (np.clip(xs + 1, 0, h - 1), np.clip(ys + 1, 0, w - 1))
        pred_ske[d1] = 1
        pred_ske[d2] = 1
        pred_ske[d3] = 1
        pred_ske[d4] = 1
        pred_ske[d5] = 1
        pred_ske[d6] = 1
        pred_ske[d7] = 1
        pred_ske[d8] = 1
    return pred_ske
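A minimal usage sketch for buffer() (the toy array below is illustrative only, not part of the commit): growing a single skeleton pixel by one step fills its 3x3 neighborhood.

import numpy as np
from buffer import buffer

toy = np.zeros((5, 5))
toy[2, 2] = 1                         # one skeleton pixel in the middle
out = buffer(np.copy(toy), buffer_size=1)   # pass a copy; buffer() mutates its input
print(out[1:4, 1:4])                  # expected: a 3x3 block of ones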
skeletonize.py (file name inferred from `from skeletonize import make_skeleton` in the evaluation script below)
@@ -0,0 +1,33 @@
from skimage.morphology import skeletonize
import numpy as np
import cv2


def preprocess(img):
    # binarize: foreground pixels are the ones with value 255
    img = (img == 255).astype(bool)
    return img


def make_skeleton(root, fix_borders=True, debug=False):
    replicate = 5
    clip = 2
    rec = replicate + clip
    # open and skeletonize
    img = cv2.imread(root, cv2.IMREAD_GRAYSCALE)
    print(img.shape)

    if fix_borders:
        img = cv2.copyMakeBorder(img, replicate, replicate, replicate, replicate, cv2.BORDER_REPLICATE)
    img_copy = None
    if debug:
        if fix_borders:
            img_copy = np.copy(img[replicate:-replicate, replicate:-replicate])
        else:
            img_copy = np.copy(img)
    img = preprocess(img)
    if not np.any(img):
        return None, None
    ske = skeletonize(img).astype(np.uint16)
    if fix_borders:
        # clip a little past the replicated border, then pad back with zeros
        ske = ske[rec:-rec, rec:-rec]
        ske = cv2.copyMakeBorder(ske, clip, clip, clip, clip, cv2.BORDER_CONSTANT, value=0)
    return img_copy, ske
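A minimal usage sketch for make_skeleton() (the input and output file names are assumed placeholders, not part of the commit):

import cv2
from skeletonize import make_skeleton

# 'pred_mask.png' stands in for a binary prediction raster with foreground = 255
orig, ske = make_skeleton('pred_mask.png', fix_borders=True, debug=True)
if ske is not None:
    cv2.imwrite('pred_skeleton.png', (ske * 255).astype('uint8'))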
utils_shp.py (file name inferred from `import utils_shp` in the evaluation script below)
@@ -0,0 +1,147 @@
import subprocess
import os, cv2
import xml.etree.ElementTree as ET
from osgeo import gdal, gdalconst, ogr, osr
import zipfile
import numpy as np

# full paths to the GDAL command-line executables
gdalsrsinfo = r'C:\OSGeo4W64\bin\gdalsrsinfo.exe'
ogr2ogr = r'C:\OSGeo4W64\bin\ogr2ogr.exe'
gdal_translate = r'C:\OSGeo4W64\bin\gdal_translate.exe'
# ------------------------------------------------------------------------------


def unzip(zip_path, unzip_path):
    with zipfile.ZipFile(zip_path, "r") as zip_ref:
        zip_ref.extractall(unzip_path)
    temp = unzip_path.split('\\')
    map_name = temp[-1].replace('_tif', '.tif')
    return os.path.join(unzip_path, map_name)


def reproject(input_vec, in_crs, out_crs, ogr2ogr):
    vector_proj = input_vec.replace('.shp', '_proj.shp')
    call = ogr2ogr + ' -t_srs "' + out_crs + '" -s_srs "' + in_crs + '" "' + vector_proj + '" "' + input_vec + '"'
    print(call)
    response = subprocess.check_output(call, shell=False)
    print(response)
    return vector_proj


def get_bbox(xml_file):
    bbox = []  # west, east, north, south
    tree = ET.parse(xml_file)
    root = tree.getroot()
    for child in root:
        if child.tag == 'idinfo':
            for cchild in child:
                if cchild.tag == 'spdom':
                    for ccchild in cchild:  # ccchild is the bounding element
                        for coor in ccchild:
                            bbox.append(float(coor.text))
    # xmin, xmax, ymin, ymax = bbox
    return bbox


# convert bbox points to geocoordinates of the map
def point_convertor(x, y, input_epsg, out_srs):
    point = ogr.Geometry(ogr.wkbPoint)
    point.AddPoint(x, y)
    # create coordinate transformation
    inSpatialRef = osr.SpatialReference()
    inSpatialRef.ImportFromEPSG(input_epsg)
    coordTransform = osr.CoordinateTransformation(inSpatialRef, out_srs)
    # transform point
    point.Transform(coordTransform)
    return [point.GetX(), point.GetY()]


def clip(input_vec, xmin, ymin, xmax, ymax, ogr2ogr):
    vector_clip = input_vec.replace('.shp', '_clip.shp')
    # call = '%s -dim 2 -clipsrc %f %f %f %f -nlt POLYGON %s %s ' % (ogr2ogr, xmin, ymin, xmax, ymax, vector_clip, input_vec)
    call = '%s -dim 2 -clipsrc %f %f %f %f %s %s ' % (ogr2ogr, xmin, ymin, xmax, ymax, vector_clip, input_vec)
    print(call)
    response = subprocess.check_output(call, shell=False)
    print(response)
    return vector_clip


def create_filtered_shapefile(value, filter_field, in_shapefile, out_shapefile):
    driver = ogr.GetDriverByName("ESRI Shapefile")
    dataSource = driver.Open(in_shapefile, 0)
    input_layer = dataSource.GetLayer()
    query_str = '{} = {}'.format(filter_field, value)
    print(query_str)
    err = input_layer.SetAttributeFilter(query_str)
    print(err)
    # copy the filtered layer to the output file
    out_ds = driver.CreateDataSource(out_shapefile)
    out_layer = out_ds.CopyLayer(input_layer, str(value))
    del input_layer, out_layer, out_ds
    return out_shapefile


def vector2raster(input_vec, map_tiff_geo):
    # open data
    raster_fn = input_vec.replace('.shp', '.tif')
    raster = gdal.Open(map_tiff_geo)
    shp = ogr.Open(input_vec)
    lyr = shp.GetLayer()
    # get raster georeference info
    transform = raster.GetGeoTransform()
    # create target raster with the same size and projection as the map
    x_res, y_res = raster.RasterXSize, raster.RasterYSize
    target_ds = gdal.GetDriverByName('GTiff').Create(raster_fn, x_res, y_res, 1, gdal.GDT_Byte)
    target_ds.SetGeoTransform(transform)
    raster_srs = osr.SpatialReference()
    raster_srs.ImportFromWkt(raster.GetProjectionRef())
    target_ds.SetProjection(raster_srs.ExportToWkt())
    # rasterize: burn the vector features into band 1 with value 255
    err = gdal.RasterizeLayer(target_ds, [1], lyr, burn_values=[255])
    if err != 0:
        print(err)
    del target_ds
    return raster_fn


def tif2png(in_tif):
    png_fn = in_tif.replace('.tif', '.png')
    call = 'gdal_translate -of PNG -ot Byte "' + in_tif + '" "' + png_fn + '"'
    print(call)
    response = subprocess.check_output(call, shell=True)
    print(response)
    # remove the auxiliary metadata file that gdal_translate writes next to the PNG
    xml_fn = png_fn + '.aux.xml'
    os.remove(xml_fn)
    return png_fn


def dilated(png_fn):
    dilated_fn = png_fn.replace('.png', '_dilated.png')
    img = cv2.imread(png_fn, 0)
    kernel = np.ones((2, 2), np.uint8)
    dilation = cv2.dilate(img, kernel, iterations=1)
    cv2.imwrite(dilated_fn, dilation)
    return dilated_fn


def geo2img_coor(x, y, path):
    # invert the affine geotransform to map a geographic (x, y) to pixel (row, col)
    dataset = gdal.Open(path, gdalconst.GA_ReadOnly)
    adfGeoTransform = dataset.GetGeoTransform()
    dfGeoX = float(x)
    dfGeoY = float(y)
    det = adfGeoTransform[1] * adfGeoTransform[5] - adfGeoTransform[2] * adfGeoTransform[4]
    # simpler inverse when the geotransform has no rotation terms:
    # X = int((dfGeoX - adfGeoTransform[0]) / adfGeoTransform[1])
    # Y = int((dfGeoY - adfGeoTransform[3]) / adfGeoTransform[5])
    X = ((dfGeoX - adfGeoTransform[0]) * adfGeoTransform[5] -
         (dfGeoY - adfGeoTransform[3]) * adfGeoTransform[2]) / det
    Y = ((dfGeoY - adfGeoTransform[3]) * adfGeoTransform[1] -
         (dfGeoX - adfGeoTransform[0]) * adfGeoTransform[4]) / det
    return [int(Y), int(X)]


def points_generator(raster, save_path, n_points=400):
    img = cv2.imread(raster, 0)
    index_x, index_y = np.where(img == 255)
    print(index_x.shape)
    rand = np.random.randint(0, index_x.shape[0], n_points)
    points = np.array([index_x[rand], index_y[rand]])
    points = np.swapaxes(points, 0, 1)
    print(points.shape)
    np.savetxt(save_path, points, fmt='%d', delimiter=',')
    return save_path
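A minimal usage sketch for geo2img_coor() and points_generator() (all paths and coordinate values below are placeholders, not part of the commit):

import utils_shp

map_tif = 'CA_Bray_100414_2001_24000_geo.tif'     # placeholder GeoTIFF path
# map a projected coordinate (placeholder values) onto the image grid
row, col = utils_shp.geo2img_coor(-121.5, 41.6, map_tif)
print(row, col)

# sample 400 random foreground (value 255) pixels from a rasterized ground-truth PNG
utils_shp.points_generator('ground_truth.png', 'sample_points.txt', n_points=400)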
Evaluation script (imports the three modules above)
@@ -0,0 +1,63 @@
import numpy as np
import os, cv2
from skeletonize import make_skeleton
from buffer import buffer
import utils_shp

'''
Inputs for the evaluation:
    - ground truth shapefile
    - segmentation results (prediction raster)
'''

# raw string so the Windows backslashes are not treated as escape sequences
data_dir = r'C:\Users\weiweiduan\Documents\Map_proj_data\CA\CA_Bray_100414_2001_24000_geo_tif'
location_name = 'CA_Bray_100414_2001_24000_geo'
gt_name = 'CA_Bray_railroads_gt_buffer0.shp'
pred_name = 'CA_Bray_100414_2001_24000_geo_pointrend_unet_pred.png'
buffer_size = 2

pred_path = os.path.join(data_dir, 'res', pred_name)
gt_shp_path = os.path.join(data_dir, 'ground_truth', gt_name)
map_tif_path = os.path.join(data_dir, location_name + '.tif')

# Step 1: rasterize the ground truth shapefile
gt_tif_path = utils_shp.vector2raster(gt_shp_path, map_tif_path)
gt_png_path = utils_shp.tif2png(gt_tif_path)

# Step 2: load the bounding box coordinates to remove the white borders of the map
bbox_file = os.path.join(data_dir, 'bbox.txt')
points = np.loadtxt(bbox_file, dtype='int32', delimiter=',')
start_point, end_point = points[0], points[1]
print(start_point, end_point)

# Step 3: skeletonize the segmentation results
img_copy, pred_ske = make_skeleton(pred_path)
pred_ske = pred_ske.astype('uint')

# Step 4: buffer the segmentation results and the ground truth
# buffer() modifies its input in place, so pass copies to keep the unbuffered
# skeleton and ground truth available for the metric computation below
pred_buffer = buffer(np.copy(pred_ske), buffer_size=buffer_size)
gt_map = cv2.imread(gt_png_path, 0) / 255
gt_buffer = buffer(np.copy(gt_map), buffer_size=buffer_size)

# Step 5: calculate correctness and completeness
# correctness: fraction of predicted skeleton pixels that fall inside the buffered ground truth
overlap_map = gt_buffer * pred_ske
fp_map = pred_ske - overlap_map

tp = np.count_nonzero(overlap_map)
fp = np.count_nonzero(fp_map)

correctness = tp / (tp + fp)
print('correctness = ', correctness)

# completeness: fraction of ground-truth pixels that fall inside the buffered prediction
overlap_comp_map = gt_map * pred_buffer

fn_map = gt_map - overlap_comp_map

tp_comp = np.count_nonzero(overlap_comp_map)
fn = np.count_nonzero(fn_map)

completeness = tp_comp / (tp_comp + fn)
print('completeness = ', completeness)
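As a quick sanity check of the two metrics (illustrative counts only, not taken from the data above): with tp = 90 and fp = 10, correctness = 90 / (90 + 10) = 0.9; with tp_comp = 90 and fn = 30, completeness = 90 / (90 + 30) = 0.75.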