Skip to content

Commit

Permalink
bug fix for use of all 3 bands
Browse files Browse the repository at this point in the history
  • Loading branch information
dbuscombe-usgs committed Jul 11, 2023
1 parent 58dc3ea commit 2be7b1b
Showing 1 changed file with 2 additions and 11 deletions.
13 changes: 2 additions & 11 deletions doodleverse_utils/prediction_imports.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,9 +145,6 @@ def get_image(f,N_DATA_BANDS,TARGET_SIZE,MODEL):
except:
pass

# print(f)
# print(image.shape)

image = standardize(image.numpy()).squeeze()

if MODEL=='segformer':
Expand All @@ -172,12 +169,9 @@ def est_label_multiclass(image,M,MODEL,TESTTIMEAUG,NCLASSES,TARGET_SIZE):
est_label = tf.squeeze(model(tf.expand_dims(image, 0)))
except:
if MODEL=='segformer':
#### FIX :3
# est_label = model(tf.expand_dims(image[:,:,0], 0)).logits
est_label = model(tf.expand_dims(image[:,:,:3], 0)).logits
else:
#### FIX :3
est_label = tf.squeeze(model(tf.expand_dims(image[:,:,0], 0)))
est_label = tf.squeeze(model(tf.expand_dims(image[:,:,:3], 0)))

if TESTTIMEAUG == True:
# return the flipped prediction
Expand Down Expand Up @@ -237,13 +231,10 @@ def est_label_binary(image,M,MODEL,TESTTIMEAUG,NCLASSES,TARGET_SIZE,w,h):

except:
if MODEL=='segformer':
#### FIX :3
# est_label = model.predict(tf.expand_dims(image[:,:,0], 0), batch_size=1).logits
est_label = model.predict(tf.expand_dims(image[:,:,:3], 0), batch_size=1).logits

else:
#### FIX :3
est_label = tf.squeeze(model.predict(tf.expand_dims(image[:,:,0], 0), batch_size=1))
est_label = tf.squeeze(model.predict(tf.expand_dims(image[:,:,:3], 0), batch_size=1))

if TESTTIMEAUG == True:
# return the flipped prediction
Expand Down

0 comments on commit 2be7b1b

Please sign in to comment.