Skip to content

Commit

Permalink
Merge branch 'main' of https://github.com/aimclub/OCEANAI
Browse files Browse the repository at this point in the history
  • Loading branch information
DmitryRyumin committed Dec 12, 2023
2 parents 903c763 + d5888bc commit 7f7984e
Show file tree
Hide file tree
Showing 3 changed files with 3 additions and 29 deletions.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
28 changes: 1 addition & 27 deletions oceanai/modules/lab/keras_vggface/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@

import numpy as np
from keras import backend as K
from keras.utils.data_utils import get_file

V1_LABELS_PATH = "https://github.com/rcmalli/keras-vggface/releases/download/v2.0/rcmalli_vggface_labels_v1.npy"
V2_LABELS_PATH = "https://github.com/rcmalli/keras-vggface/releases/download/v2.0/rcmalli_vggface_labels_v2.npy"
Expand Down Expand Up @@ -62,29 +61,4 @@ def preprocess_input(x, data_format=None, version=1):
else:
raise NotImplementedError

return x_temp


def decode_predictions(preds, top=5):
LABELS = None
if len(preds.shape) == 2:
if preds.shape[1] == 2622:
fpath = get_file("rcmalli_vggface_labels_v1.npy", V1_LABELS_PATH, cache_subdir=VGGFACE_DIR)
LABELS = np.load(fpath)
elif preds.shape[1] == 8631:
fpath = get_file("rcmalli_vggface_labels_v2.npy", V2_LABELS_PATH, cache_subdir=VGGFACE_DIR)
LABELS = np.load(fpath)
else:
raise ValueError
else:
raise ValueError

results = []

for pred in preds:
top_indices = pred.argsort()[-top:][::-1]
result = [[str(LABELS[i].encode("utf8")), pred[i]] for i in top_indices]
result.sort(key=lambda x: x[1], reverse=True)
results.append(result)

return results
return x_temp

0 comments on commit 7f7984e

Please sign in to comment.