From d2bf5c75dc0465086d46f6b2924037e57868cf4e Mon Sep 17 00:00:00 2001
From: Zheng Xu
Date: Mon, 26 Aug 2019 17:27:40 -0700
Subject: [PATCH] Fix segment_eval_inference.py script.

---
 segment_eval_inference.py | 30 ++++++++++++++++++------------
 1 file changed, 18 insertions(+), 12 deletions(-)

diff --git a/segment_eval_inference.py b/segment_eval_inference.py
index 4676ed62..c5813aa6 100644
--- a/segment_eval_inference.py
+++ b/segment_eval_inference.py
@@ -46,7 +46,7 @@ def labels(self):
 
   def to_file(self, file_name):
     """Materialize the GT mapping to file."""
-    with tf.gfile.Open(file_name, "w") as fobj:
+    with tf.io.gfile.GFile(file_name, "w") as fobj:
       for k, v in self._labels.items():
         seg_id, label = k
         line = "%s,%s,%s\n" % (seg_id, label, v)
@@ -79,18 +79,18 @@ def read_labels(data_pattern, cache_path=""):
       tf.logging.info("Reading cached labels from %s..." % cache_path)
       return Labels.from_file(cache_path)
   tf.enable_eager_execution()
-  data_paths = tf.gfile.Glob(data_pattern)
+  data_paths = tf.io.gfile.glob(data_pattern)
   ds = tf.data.TFRecordDataset(data_paths, num_parallel_reads=50)
   context_features = {
-      "id": tf.FixedLenFeature([], tf.string),
-      "segment_labels": tf.VarLenFeature(tf.int64),
-      "segment_start_times": tf.VarLenFeature(tf.int64),
-      "segment_scores": tf.VarLenFeature(tf.float32)
+      "id": tf.io.FixedLenFeature([], tf.string),
+      "segment_labels": tf.io.VarLenFeature(tf.int64),
+      "segment_start_times": tf.io.VarLenFeature(tf.int64),
+      "segment_scores": tf.io.VarLenFeature(tf.float32)
   }
 
   def _parse_se_func(sequence_example):
-    return tf.parse_single_sequence_example(sequence_example,
-                                            context_features=context_features)
+    return tf.io.parse_single_sequence_example(
+        sequence_example, context_features=context_features)
 
   ds = ds.map(_parse_se_func)
   rated_labels = {}
@@ -104,7 +104,8 @@ def _parse_se_func(sequence_example):
     segment_scores = cxt_feature_val["segment_scores"].values.numpy()
     for label, start_time, score in zip(segment_labels, segment_start_times,
                                         segment_scores):
-      rated_labels[("%s:%d" % (video_id, start_time), label)] = score
+      rated_labels[("%s:%d" % (video_id.decode("utf8"), start_time),
+                    label)] = score
     batch_id = len(rated_labels) // batch_size
     if batch_id != last_batch:
       tf.logging.info("%d examples processed.", len(rated_labels))
@@ -129,10 +130,13 @@ def read_segment_predictions(file_path, labels, top_n=None):
     a segment prediction list for each classes.
   """
   cls_preds = {}  # A label_id to pred list mapping.
-  with tf.gfile.Open(file_path) as fobj:
+  with tf.io.gfile.GFile(file_path) as fobj:
     tf.logging.info("Reading predictions from %s..." % file_path)
     for line in fobj:
       label_id, pred_ids_val = line.split(",")
+      if not label_id.isdigit():
+        # Skip the header line.
+        continue
       pred_ids = pred_ids_val.split(" ")
       if top_n:
         pred_ids = pred_ids[:top_n]
@@ -177,9 +181,11 @@ def main(unused_argv):
         float(x) / len(class_preds) for x in range(len(class_preds), 0, -1)
     ]
     seg_scored_preds.append(seg_scored_pred)
-    num_positives.append(positive_counter[label_id])
+    num_positives.append(positive_counter.get(label_id, 0))
+
   map_cal.accumulate(seg_scored_preds, seg_labels, num_positives)
-  map_at_n = np.mean(map_cal.peek_map_at_n())
+  aps = map_cal.peek_map_at_n()
+  map_at_n = np.mean(aps)
   tf.logging.info("Num classes: %d | mAP@%d: %.6f" % (len(seg_preds),
                                                       FLAGS.top_n, map_at_n))
 
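
For reference, the tf.io-based parsing pattern the patch switches to can be exercised on its own. The sketch below is an illustration only, not part of the patch: it assumes TF 2.x (eager execution is the default there, so no tf.enable_eager_execution() call is needed) and uses a hypothetical TFRecord glob pattern.

import tensorflow as tf

# Context features carried by each SequenceExample, as declared in the patch.
context_features = {
    "id": tf.io.FixedLenFeature([], tf.string),
    "segment_labels": tf.io.VarLenFeature(tf.int64),
    "segment_start_times": tf.io.VarLenFeature(tf.int64),
    "segment_scores": tf.io.VarLenFeature(tf.float32),
}

# Hypothetical glob pattern; point it at the real validation TFRecords.
data_paths = tf.io.gfile.glob("validate*.tfrecord")
ds = tf.data.TFRecordDataset(data_paths)

for serialized in ds.take(1):
  # parse_single_sequence_example returns (context, feature_lists) dicts.
  context, _ = tf.io.parse_single_sequence_example(
      serialized, context_features=context_features)
  video_id = context["id"].numpy().decode("utf8")  # bytes -> str, as in the fix
  labels = context["segment_labels"].values.numpy()  # SparseTensor values
  scores = context["segment_scores"].values.numpy()
  print(video_id, list(zip(labels, scores)))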