From 2be7b1beb003ed8400a0a520bd9d005de35657d1 Mon Sep 17 00:00:00 2001 From: dbuscombe-usgs Date: Mon, 10 Jul 2023 17:46:08 -0700 Subject: [PATCH] bug fix for use of all 3 bands --- doodleverse_utils/prediction_imports.py | 13 ++----------- 1 file changed, 2 insertions(+), 11 deletions(-) diff --git a/doodleverse_utils/prediction_imports.py b/doodleverse_utils/prediction_imports.py index 6452f5a..492f3e6 100755 --- a/doodleverse_utils/prediction_imports.py +++ b/doodleverse_utils/prediction_imports.py @@ -145,9 +145,6 @@ def get_image(f,N_DATA_BANDS,TARGET_SIZE,MODEL): except: pass - # print(f) - # print(image.shape) - image = standardize(image.numpy()).squeeze() if MODEL=='segformer': @@ -172,12 +169,9 @@ def est_label_multiclass(image,M,MODEL,TESTTIMEAUG,NCLASSES,TARGET_SIZE): est_label = tf.squeeze(model(tf.expand_dims(image, 0))) except: if MODEL=='segformer': - #### FIX :3 - # est_label = model(tf.expand_dims(image[:,:,0], 0)).logits est_label = model(tf.expand_dims(image[:,:,:3], 0)).logits else: - #### FIX :3 - est_label = tf.squeeze(model(tf.expand_dims(image[:,:,0], 0))) + est_label = tf.squeeze(model(tf.expand_dims(image[:,:,:3], 0))) if TESTTIMEAUG == True: # return the flipped prediction @@ -237,13 +231,10 @@ def est_label_binary(image,M,MODEL,TESTTIMEAUG,NCLASSES,TARGET_SIZE,w,h): except: if MODEL=='segformer': - #### FIX :3 - # est_label = model.predict(tf.expand_dims(image[:,:,0], 0), batch_size=1).logits est_label = model.predict(tf.expand_dims(image[:,:,:3], 0), batch_size=1).logits else: - #### FIX :3 - est_label = tf.squeeze(model.predict(tf.expand_dims(image[:,:,0], 0), batch_size=1)) + est_label = tf.squeeze(model.predict(tf.expand_dims(image[:,:,:3], 0), batch_size=1)) if TESTTIMEAUG == True: # return the flipped prediction