Skip to content

Commit

Permalink
Adding mend track correction for object tracking
Browse files Browse the repository at this point in the history
  • Loading branch information
monte-flora committed Mar 1, 2022
1 parent e8b12fd commit 32cc49f
Show file tree
Hide file tree
Showing 2 changed files with 61,324 additions and 60,698 deletions.
168 changes: 147 additions & 21 deletions monte_python/object_tracking.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,17 @@
import scipy
from scipy import spatial
import numpy as np
import pandas as pd
import skimage.measure
from skimage.measure import regionprops
from skimage.measure import regionprops, regionprops_table
import math
import collections
from datetime import datetime
import itertools

def calc_dist(x1,x2,y1,y2):
return (x1 - x2)**2 + (y1 - y2)**2

class ObjectTracker:
"""
ObjectTracker performs simple object tracking by linking together objects from time
Expand All @@ -28,7 +32,7 @@ def __init__( self, one_to_one = False, percent_overlap=0.0):
self.one_to_one = one_to_one
self.percent_overlap = percent_overlap

def track_objects(self, objects):
def track_objects(self, objects, mend_tracks=False):
""" Tracks objects in time.
Parameters:
Expand All @@ -38,6 +42,7 @@ def track_objects(self, objects):
"""
objects_copy = np.copy(objects)

# Re-label so that objects across time each have a each label.
tracked_objects = self.get_unique_labels(objects_copy)

Expand All @@ -49,15 +54,24 @@ def track_objects(self, objects):
areas_before, areas_after = self._get_area(current_objects), self._get_area(future_objects)

# Check for mergers.
self.check_for_mergers(labels_before, labels_after, areas_before)
labels_before, labels_after = self.check_for_mergers(labels_before, labels_after, areas_before)

# Check for splits.
labels_before, labels_after = self.check_for_mergers(labels_before, labels_after, areas_after)

# Re-label an object if matches (this is where the tracking is done)
for label_i, label_f in zip(labels_before, labels_after):
tracked_objects[t+1, future_objects == label_f] = label_i

# Do a final re-label so that np.max(relabel_objects) == number of tracked objects.
relabeled_objects = self.relabel(tracked_objects)
return relabeled_objects
tracks = self.relabel(tracked_objects)

if mend_tracks:
# Check if the track is within 9 km.
tracks = self.mend_broken_tracks(tracks, dist_max=3)


return tracks

def _get_area(self, arr):
"""
Expand Down Expand Up @@ -86,15 +100,14 @@ def check_for_mergers(self, labels_before, labels_after, areas_before):
--------------------
"""
# Determine if there is a merged based on non-unqique labels.
unique_labels_before, counts_before = np.unique(labels_before, return_counts=True)
unique_labels_after, counts_after = np.unique(labels_after, return_counts=True)
if any(counts_after>1):
# Get the labels that non-unique.
merged_labels = unique_labels_after[counts_after>1]

for label in merged_labels:
# This should be 2 or more labels (which are being merged together).
potential_label_for_merged_obj = [l for i, l in enumerate(labels_before) if labels_after[i] == label]
print(f'{potential_label_for_merged_obj=}')
# Sort the potential merged object labels by area. Keep the largest object and remove the
# others.
inds = np.argsort([areas_before[label] for label in potential_label_for_merged_obj])[::-1]
Expand Down Expand Up @@ -123,14 +136,13 @@ def check_for_splits(labels_before, labels_after, areas_after):
Returns
--------------------
"""
unique_labels_after, counts_after = np.unique(labels_after, return_counts=True)
unique_labels_before, counts_before = np.unique(labels_before, return_counts=True)
if any(counts_before>1):
split_labels = unique_labels_before[counts_before>1]

for label in split_labels:
# This should be 2 or more labels (which are being merged together).
potential_label_for_split_obj = [l for i, l in enumerate(labels_after) if labels_before[i] == label]
print(f'{potential_label_for_split_obj=}')
# Sort the potential split object labels by area. Keep the largest object and remove the
# others.
inds = np.argsort([areas_after[label] for label in potential_label_for_split_obj])[::-1]
Expand All @@ -145,18 +157,20 @@ def check_for_splits(labels_before, labels_after, areas_after):


def get_unique_labels(self, objects):
"""Ensure that initially, each object has a unique label"""
cumulative_objects = np.cumsum([np.max(objects[i]) for i in range(len(objects))])
"""Ensure that initially, each object for the different times has a unique label"""
if not isinstance(objects, np.ndarray):
objects = np.array(objects)

unique_track_set = [objects[0,:,:]]
for i in range(1, len(objects)):
track = objects[i,:,:]
where_zero = track==0
unique_track = track+cumulative_objects[i-1]
unique_track[where_zero]=0
unique_track_set.append(unique_track)

return np.array(unique_track_set)
unique_track_set = np.zeros(objects.shape, dtype=np.int32)

num = 1
for i in range(len(objects)):
current_obj = objects[i,:,:]
for label in np.unique(current_obj)[1:]:
unique_track_set[i, current_obj==label] += num
num+=1

return unique_track_set

def relabel(self, objects):
"""Re-label objects"""
Expand Down Expand Up @@ -251,7 +265,119 @@ def calc_duration(self, time_range, objects):

return object_duration

def get_centroid(self, df, label):
try:
df=df.loc[df['label'] == label]
x_cent, y_cent = df['centroid-0'], df['centroid-1']
x_cent=int(x_cent)
y_cent=int(y_cent)
except:
return np.nan, np.nan

return x_cent, y_cent

def get_track_path(self, tracked_objects):
""" Create track path. """
properties = ['label', 'centroid']
object_dfs = [pd.DataFrame(regionprops_table(tracks, properties=properties))
for tracks in tracked_objects]

unique_labels = np.unique(tracked_objects)[1:]
centroid_x = {l : [] for l in unique_labels}
centroid_y = {l : [] for l in unique_labels}

for df in object_dfs:
for label in unique_labels:
x,y = self.get_centroid(df, label)
centroid_x[label].append(x)
centroid_y[label].append(y)

return centroid_x, centroid_y

def find_track_start_and_end(self, data):
"""
Based on the x-centriod or y-centroid values for a track,
determine when the time index when the track starts and stops.
"""
# If the track happens to persist for all
# time steps (i.e., no nan values), then
# the start and end indices are 0 and len(data)-1
if not np.isnan(np.sum(data)):
return 0, len(data)-1

# Does the track start at the first time step?
elif not np.isnan(data[0]):
return 0, np.where(np.isnan(data))[0][0]-1

# Does the track end at the last time step?
elif not np.isnan(data[-1]):
return np.where(np.isnan(data))[0][-1]+1, len(data)-1
# Otherwise the tracks starts and stops sometime during
# the time period.
else:
data_copy = np.copy(data)
data_copy[np.isnan(data)] = 0
diff = np.absolute(np.diff(data_copy))

# This will return intersecting values in
# value order rather than chronological order.
# Need to check if the storms are moving west,
# in case, the start and env vals are switched.
vals = np.intersect1d(data, diff)

is_decreasing = np.nanmean(np.diff(data)) < 0
start_val, end_val = vals[::-1] if is_decreasing else vals

start_ind = np.where(data==start_val)[0]
end_ind = np.where(data==end_val)[0]

return start_ind[0], end_ind[0]


def mend_broken_tracks(self, tracked_objects, dist_max=3):
"""
Mend broken tracks by project track ends forward
in time based on estimated storm motion and
search for tracks that start in that projected area.
If close enough, assume that those two tracks
should be combined. Re-label that new tracks with
the projected tracks label.
"""
new_tracks = np.copy(tracked_objects)
x_cent, y_cent = self.get_track_path(tracked_objects)

# Get the start and end
track_start_end = {label : self.find_track_start_and_end(x_cent[label]) for label in x_cent.keys()}

for label in x_cent.keys():
# Compute the project storm position based on
# the estimated storm motion. Since time is
# constant, we do not need to consider it.
dx = np.mean(np.diff(x_cent[label]))
dy = np.mean(np.diff(y_cent[label]))

# Get the start and end time index for this track.
start_ind, end_ind = track_start_end[label]

x_proj = x_cent[label][end_ind] + dx
y_proj = x_cent[label][end_ind] + dy


# Given the end index of this track, we are looking for tracks that
# started when this tracked ended or during the next time step.
other_labels = [l for l in x_cent.keys() if l != label and track_start_end[l][0] in [end_ind, end_ind+1] ]
for other_label in other_labels:
x_val = x_cent[other_label][end_ind]
x = x_val if x_val is not np.nan else x_cent[other_label][end_ind+1]

y_val = y_cent[other_label][end_ind]
y = y_val if y_val is not np.nan else y_cent[other_label][end_ind+1]

# Is there an existing tracks start point that is within some
# distance on this projected end of this track. If so,
# link them together and re-label the existing track to this label.
dist = calc_dist(x_proj, x, y_proj, y)
if dist <= dist_max:
new_tracks[tracked_objects==other_label] = label

return new_tracks
Loading

0 comments on commit 32cc49f

Please sign in to comment.