Skip to content

Commit

Permalink
Last dense layer output not hardcoded anymore
Browse files Browse the repository at this point in the history
  • Loading branch information
TimKoornstra committed Aug 16, 2023
1 parent 3d68b12 commit a785cac
Showing 1 changed file with 11 additions and 2 deletions.
13 changes: 11 additions & 2 deletions src/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -525,10 +525,19 @@ def main():
use_mask=args.use_mask
)
training_generator, validation_generator, test_generator, inference_generator, utilsObject, train_batches = loader.generators()

# Get the prediction model by taking the last dense layer of the full
# model
last_dense_layer = None
for layer in reversed(model.layers):
if layer.name.startswith('dense'):
last_dense_layer = layer
break

prediction_model = keras.models.Model(
model.get_layer(name="image").input, model.get_layer(
name="dense3").output
model.get_layer(name="image").input, last_dense_layer.output
)

prediction_model.summary(line_length=110)

inference_dataset = inference_generator
Expand Down

0 comments on commit a785cac

Please sign in to comment.