Skip to content

Commit

Permalink
Update units.py
Browse files Browse the repository at this point in the history
  • Loading branch information
david-thrower authored Apr 8, 2024
1 parent 24072bd commit 83bd06a
Showing 1 changed file with 8 additions and 3 deletions.
11 changes: 8 additions & 3 deletions cerebros/units/units.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,8 +83,13 @@ def __init__(self,
train_data_dtype=tf.float32,
*args,
**kwargs):

self.input_shape = input_shape
if isinstance(input_shape, int):
self.input_shape = (input_shape,)
elif isinstance(input_shape, str):
self.input_shape = (int(input_shape),)
else
_input_shape = [int(ax) for ax in input_shape]
self.input_shape = tuple(_input_shape)
self.neural_network_layer = []
self.base_models = base_models
self.train_data_dtype = train_data_dtype
Expand All @@ -100,7 +105,7 @@ def __init__(self,

def materialize(self):

self.raw_input = tf.keras.layers.Input(int(self.input_shape),
self.raw_input = tf.keras.layers.Input(self.input_shape,
name=f"{self.name}_inp",
dtype=self.train_data_dtype)
print(f"$$$$$$>>>>> Base model: {self.base_models[self.unit_id]}")
Expand Down

0 comments on commit 83bd06a

Please sign in to comment.