diff --git a/cerebros/units/units.py b/cerebros/units/units.py index 890d500..6641834 100644 --- a/cerebros/units/units.py +++ b/cerebros/units/units.py @@ -83,8 +83,13 @@ def __init__(self, train_data_dtype=tf.float32, *args, **kwargs): - - self.input_shape = input_shape + if isinstance(input_shape, int): + self.input_shape = (input_shape,) + elif isinstance(input_shape, str): + self.input_shape = (int(input_shape),) + else + _input_shape = [int(ax) for ax in input_shape] + self.input_shape = tuple(_input_shape) self.neural_network_layer = [] self.base_models = base_models self.train_data_dtype = train_data_dtype @@ -100,7 +105,7 @@ def __init__(self, def materialize(self): - self.raw_input = tf.keras.layers.Input(int(self.input_shape), + self.raw_input = tf.keras.layers.Input(self.input_shape, name=f"{self.name}_inp", dtype=self.train_data_dtype) print(f"$$$$$$>>>>> Base model: {self.base_models[self.unit_id]}")