Neuraxle refactor #32
base: master
Conversation
Update code since last changes.
…MRNNTensorflow fit, and transform methods
@alexbrillant Thank you for the contribution! Let's clean this up together soon before merging.
…PI call notebook.
…tivity-Recognition into neuraxle-refactor
This looks cool. Note that I haven't reviewed the DeepLearningPipeline yet, so the present PR may wait. Let's finish the seq2seq's refactor first.
from neuraxle.steps.output_handlers import InputAndOutputTransformerMixin


class FormatData(NonFittableMixin, InputAndOutputTransformerMixin, BaseStep):
Might be replaced by this?

Pipeline([
    ToNumpy(),
    OutputTransformerWrapper(ToNumpy())
])
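For reference, a self-contained version of that suggestion might look like the sketch below. The import paths are assumptions: OutputTransformerWrapper is inferred from the output_handlers import visible in this diff, and ToNumpy's module is not shown in this PR.

from neuraxle.pipeline import Pipeline
from neuraxle.steps.numpy import ToNumpy  # assumed module path
from neuraxle.steps.output_handlers import OutputTransformerWrapper

# Cast both the data inputs and the expected outputs to numpy arrays
# using existing steps instead of a custom FormatData class.
format_data = Pipeline([
    ToNumpy(),                            # casts the data inputs to np.ndarray
    OutputTransformerWrapper(ToNumpy()),  # applies the same cast to the expected outputs
])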
expected_outputs = np.array(expected_outputs)

if expected_outputs.shape != (len(data_inputs), self.n_classes):
    expected_outputs = np.reshape(expected_outputs, (len(data_inputs), self.n_classes))
This if should not be needed. Use an OutputTransformerWrapper(OneHotEncoder()) instead; see the sketch after this comment.
If you also apply the previous comment, you should end up deleting this FormatData class, as these things are already done in other existing classes. We should not need any reshape here whatsoever if the data is fed correctly, or if the OneHotEncoder works properly.
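A minimal sketch of that replacement, assuming Neuraxle's OneHotEncoder lives in neuraxle.steps.numpy and takes an nb_columns argument; the class count of 6 is a hypothetical value for the HAR dataset:

from neuraxle.steps.numpy import OneHotEncoder  # assumed module path
from neuraxle.steps.output_handlers import OutputTransformerWrapper

n_classes = 6  # hypothetical: the HAR dataset's six activity classes

# One-hot encode the expected outputs instead of reshaping them by hand;
# the wrapper applies the encoder to the expected outputs only.
label_encoder = OutputTransformerWrapper(
    OneHotEncoder(nb_columns=n_classes, name='one_hot_encoded_labels')
)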
train_and_save.py (outdated)
).set_hyperparams(
    HyperparameterSamples({
        'n_steps': self.N_STEPS,  # 128 timesteps per series
        'n_inputs': self.N_INPUTS,  # 9 input parameters per timestep
        'n_hidden': self.N_HIDDEN,  # Hidden layer num of features
        'n_classes': self.N_CLASSES,  # Total classes (should go up, or should go down)
        'learning_rate': self.LEARNING_RATE,
        'lambda_loss_amount': self.LAMBDA_LOSS_AMOUNT,
        'batch_size': self.BATCH_SIZE
    })
Here, let's only consider n_hidden, learning_rate, and lambda_loss_amount as hyperparameters per se. The others aren't planned to be changed during meta-optimization, for instance.
We could leave them there for now; however, I would have seen them as something else. Looks like this issue, perhaps: Neuraxio/Neuraxle#91
We could as well add an n_stacked hyperparam to control how many LSTMs we stack on top of each other (optional feature, not really needed for now). See the sketch below.
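A hedged sketch of that split, keeping only the three tunables as hyperparameters (the values shown are hypothetical; the HyperparameterSamples import path is taken from the Neuraxle library):

from neuraxle.hyperparams.space import HyperparameterSamples

# Only the values meant to vary during meta-optimization stay here.
step = step.set_hyperparams(HyperparameterSamples({
    'n_hidden': 32,                # hidden layer num of features
    'learning_rate': 0.001,        # optimizer step size
    'lambda_loss_amount': 0.0015,  # L2 regularization weight
}))
# n_steps, n_inputs, n_classes and batch_size would then become plain
# constructor arguments or constants rather than tunable hyperparameters.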
@alexbrillant Please also note this
def main():
    pipeline = DeepLearningPipeline(
        HumanActivityRecognitionPipeline(),
        validation_size=0.15,
Was the original project using validation data, or only train/test? I'm tempted to remove validation data here to leave the original example untouched. The simplicity was part of its success.
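A sketch of what removing the validation split might look like, assuming validation_size can simply be omitted (DeepLearningPipeline's defaults are not shown in this PR):

def main():
    # Train/test only, as in the original example: no validation split.
    pipeline = DeepLearningPipeline(
        HumanActivityRecognitionPipeline(),
    )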
Note: this code uses the neuraxle package from the latest commit in this pull request: Neuraxio/Neuraxle#182
TODO:
- Notebook for demonstration.
- Validation Split Wrapper.