Implement automatic ridge regression #124
base: main
mne_connectivity/vector_ar/var.py
Outdated
auto_reg : bool, optional
    Whether to perform automatic regularization of X matrix using RidgeCV,
    by default False. If matrix is not full rank, this will be adjusted to
    True.
Suggested change
- True.
+ True. If ``l2_reg`` is non-zero, then this must be set to `False`.
mne_connectivity/vector_ar/var.py
Outdated
@@ -31,6 +32,10 @@ def vector_auto_regression(
    Autoregressive model order, by default 1.
l2_reg : float, optional
    Ridge penalty (l2-regularization) parameter, by default 0.0.
Suggested change
- Ridge penalty (l2-regularization) parameter, by default 0.0.
+ Ridge penalty (l2-regularization) parameter, by default 0.0.
+ If ``auto_reg`` is `True`, then this must be set to 0.0.
mne_connectivity/vector_ar/var.py
Outdated
elif auto_reg and l2_reg:
    raise ValueError("If l2_reg is set, then auto_reg must be set to False")
Can you add a unit test for this case? Lmk if you need help w/ that.
mne_connectivity/vector_ar/var.py
Outdated
@@ -31,6 +32,10 @@ def vector_auto_regression(
    Autoregressive model order, by default 1.
l2_reg : float, optional
    Ridge penalty (l2-regularization) parameter, by default 0.0.
auto_reg : bool, optional
    Whether to perform automatic regularization of X matrix using RidgeCV,
    by default False. If matrix is not full rank, this will be adjusted to
Suggested change
- by default False. If matrix is not full rank, this will be adjusted to
+ by default False. If the data matrix has a condition number greater than 1e6, then this will be adjusted to
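The condition-number rule can be sketched with NumPy. The 1e6 threshold comes from the suggestion above, and the check regularizes when the condition number exceeds it, since a large condition number signals an ill-conditioned matrix. The helper name `is_ill_conditioned` and the choice to check a 2D (n_samples, n_features) matrix are assumptions for illustration, not the PR's code. The example uses a common average reference, which makes the channels linearly dependent.

```python
import numpy as np

COND_THRESHOLD = 1e6  # threshold from the suggested docstring


def is_ill_conditioned(X, threshold=COND_THRESHOLD):
    """Return True if the 2D matrix X has a condition number above threshold.

    np.linalg.cond (default 2-norm) is the ratio of the largest to the
    smallest singular value, which becomes huge (or inf) when columns are
    nearly linearly dependent.
    """
    return bool(np.linalg.cond(X) > threshold)


rng = np.random.default_rng(0)
X = rng.standard_normal((1000, 5))         # 5 independent channels
X_car = X - X.mean(axis=1, keepdims=True)  # common average reference

print(is_ill_conditioned(X))      # False: independent columns
print(is_ill_conditioned(X_car))  # True: each row now sums to zero
```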
Taking a quick look, I have another thought that might make the API easier and more intuitive.
Instead of adding an auto_reg parameter, what are your thoughts on making l2_reg = 'auto' the default? Then l2_reg can be 'auto', a float, a list of floats, or None:
- 'auto': l2_reg will compute the optimal alpha over a pre-defined grid of alphas, like you have now
- float: l2_reg will be applied with this alpha value
- list of floats: same as 'auto', but you specify the grid of alphas
- None or 0: no regularization
WDYT? This might make life easier for you and for users, because there is just one documented parameter to worry about instead of figuring out how two parameters interact.
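A sketch of how the single l2_reg parameter could be normalized before fitting, covering the four cases listed above. The function name `resolve_l2_reg` and the default grid `DEFAULT_ALPHAS` are hypothetical placeholders, not the grid used in the PR. The returned mode could then be dispatched to scikit-learn's `RidgeCV(alphas=...)`, `Ridge(alpha=...)`, or an unregularized least-squares fit.

```python
import numpy as np

# Hypothetical default grid; the PR's actual grid is not shown here.
DEFAULT_ALPHAS = np.logspace(-15, 5, 11)


def resolve_l2_reg(l2_reg):
    """Map a proposed l2_reg value to (mode, alphas).

    mode is "none" (no regularization), "fixed" (one alpha) or
    "cv" (choose among several alphas by cross-validation).
    """
    if l2_reg is None:
        return "none", None
    if isinstance(l2_reg, str):
        if l2_reg != "auto":
            raise ValueError(f"Unknown l2_reg string: {l2_reg!r}")
        return "cv", DEFAULT_ALPHAS
    if np.ndim(l2_reg) == 0:
        # A single number: 0 disables regularization, > 0 fixes alpha.
        value = float(l2_reg)
        if value < 0:
            raise ValueError("l2_reg must be non-negative")
        if value == 0:
            return "none", None
        return "fixed", np.array([value])
    # Otherwise treat it as a user-supplied grid of alphas.
    alphas = np.asarray(l2_reg, dtype=float)
    if alphas.ndim != 1 or alphas.size == 0 or np.any(alphas <= 0):
        raise ValueError("A grid of alphas must be a non-empty 1D array of positive values")
    return "cv", alphas
```

Keeping validation in one place means a single error path covers bad strings, negative penalties, and malformed grids.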
I agree that using only the l2_reg argument will make things much simpler. Do you want me to implement this and then commit the changes?
That would be great, and then I can help review the code!
I opted to keep the underlying functions unchanged w/r/t the l2_reg parameter, but I had to overwrite the l2_reg value in line 185 to avoid problems. This isn't clean code, but it will not affect the user experience. I also thought it would be helpful to include the cross-validated alpha values in the displayed model parameters, so that users can view them when verbose is set to True.
Looking good! Just some minor touchups. The next steps would be to:
- add a unit test for the new case
- add an example, or augment an existing example, demonstrating this usage and when it might be necessary
I think the CI is also showing errors in some unit tests: https://github.com/mne-tools/mne-connectivity/actions/runs/4087951173/jobs/7049413823
You can follow the CONTRIBUTING guide to install the relevant test and docs-building packages, so you can run the relevant checks locally, which is faster. You can also check coding style by running make pep. Lmk if you run into any issues.
Co-authored-by: Adam Li <adam2392@gmail.com>
I've never worked with unit tests before. What would be a good unit test to add for the new features, and how do I add it? As for the tests that are currently failing, I am getting the same problem across 10 tests:
Re adding unit tests:
Lmk if you have any questions!
Re the errors you see, I think you are right. The issue is that
So in summary:
Does this seem right to you? If so, perhaps you can walk through the code and see where there might be issues. Once you write the unit tests, those will also help in debugging the behavior in the different cases.
Thanks for the debugging help in the last comment! I was able to fix the problem where I did not account for the case when l2_reg='auto', and I'm now having some trouble writing the unit tests. I noticed how you have a correct A matrix in your
Right now, my unit test simply iterates through a set of values for l2_reg ('auto', None, 0, 0.3, np.array([0.1, 1]), [0.1, 1]) to make sure that no errors are encountered. But how can I know if the result is "correct"?
Hmm good question. Maybe you can try the following:
Lmk if steps 1-4 end up giving you at least what you need in terms of the test dataset, or if you have any coding questions.
We'll want to test that this at least works on some simulated data. You also want to create a set of tests to verify that certain behavior is carried out in each case of
Unrelated note: I saw you work with Sara Inati? She was a collaborator when I was in my PhD. Say hi :).
OK, I tried my best to follow your suggestions. I only needed to set one eigenvalue to 1e-6 to obtain an ill-conditioned matrix (as long as there are 12 channels in the dataset), so I did not also need to set 2/3 of the starting conditions to the same value. Using the ill-conditioned data matrix, I was able to test the l2_reg parameter with and without regularization. Now the problem is that RidgeCV does not seem to outperform OLS on this sample dataset. In addition, many of the estimated eigenvalues differ from the true eigenvalues used to generate the sample. What do you suggest as next steps in debugging? And yes, I do work with Sara Inati! Glad to see you have collaborated with us in the past. I'm thankful for your help in getting this improvement finished.
with warnings.catch_warnings():
    warnings.filterwarnings(
        action='ignore',
        message="Ill-conditioned matrix"
    )
Why do we need this?
RidgeCV tests an array of alpha values, and some of them do not regularize the matrix enough to avoid an ill-conditioned matrix warning. If users see many of these messages pop up, they may think that something is going wrong, when in fact the function is behaving as expected. RidgeCV will choose the best alpha value, and that will come from a fit in which this warning was not raised.
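For context on how the alpha is chosen: with its default `cv=None`, scikit-learn's RidgeCV selects alpha by efficient leave-one-out cross-validation. Below is a NumPy-only sketch of that idea via the SVD, assuming centered data and no intercept; `ridge_loo_select` is a hypothetical helper, not scikit-learn's implementation. Working with singular values keeps every candidate fit well defined for any alpha > 0 without forming an explicit inverse.

```python
import numpy as np


def ridge_loo_select(X, y, alphas):
    """Choose the ridge penalty with the lowest leave-one-out squared error.

    X : array, shape (n_samples, n_features), assumed centered (no intercept).
    y : array, shape (n_samples,) or (n_samples, n_targets).
    alphas : iterable of positive floats.
    Returns (best_alpha, coef) with coef shaped (n_features,) or
    (n_features, n_targets).
    """
    U, s, Vt = np.linalg.svd(X, full_matrices=False)
    Y = y.reshape(len(y), -1)
    UtY = U.T @ Y
    best_err, best_alpha = np.inf, None
    for alpha in alphas:
        # Hat matrix H = U diag(s^2 / (s^2 + alpha)) U^T
        d = s**2 / (s**2 + alpha)
        Y_hat = U @ (d[:, None] * UtY)
        h = (U**2) @ d  # diagonal of H
        # Exact leave-one-out residuals for ridge: (y_i - yhat_i) / (1 - H_ii)
        loo_resid = (Y - Y_hat) / (1.0 - h)[:, None]
        err = np.mean(loo_resid**2)
        if err < best_err:
            best_err, best_alpha = err, alpha
    # Ridge coefficients for the chosen alpha: V diag(s / (s^2 + alpha)) U^T y
    coef = Vt.T @ ((s / (s**2 + best_alpha))[:, None] * UtY)
    return best_alpha, coef.reshape((X.shape[1],) + y.shape[1:])
```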
Hmm I'll list some things to try systematically. Before you do so, it might be helpful to track the performance changes relative to your current attempt. The two metrics I have in mind are i) how closely you approximate the true A (so some type of matrix norm of the difference), and ii) how closely you approximate the set of true eigenvalues. Then you can see if any of these changes are at least moving in the right direction (i.e. RidgeCV better than OLS).
Intuitively, this will allow the data to mix more. Perhaps right now the problem is so easy that OLS is simply the better choice.
Lmk if that helps, or if you are stuck anywhere.
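The two metrics suggested above could be tracked with helpers like these. All names are hypothetical: `fit_var1` is a plain VAR(1) least-squares/ridge estimate rather than mne-connectivity's `vector_auto_regression`, and matching eigenvalues by sorting is a heuristic (an assignment on pairwise distances, e.g. with `scipy.optimize.linear_sum_assignment`, is more robust when eigenvalues are close together).

```python
import numpy as np


def fit_var1(x, alpha=0.0):
    """Estimate A in x[:, t + 1] = A @ x[:, t] + noise.

    x : array, shape (n_signals, n_times). alpha=0 gives ordinary least
    squares; alpha > 0 gives the ridge solution
    A = Y X^T (X X^T + alpha I)^{-1}.
    """
    X, Y = x[:, :-1], x[:, 1:]
    if alpha == 0:
        # Solve X^T A^T = Y^T in the least-squares sense.
        A_T, *_ = np.linalg.lstsq(X.T, Y.T, rcond=None)
        return A_T.T
    n = X.shape[0]
    return np.linalg.solve(X @ X.T + alpha * np.eye(n), X @ Y.T).T


def a_matrix_error(A_true, A_est):
    """Relative Frobenius-norm error of the estimated A."""
    return np.linalg.norm(A_est - A_true) / np.linalg.norm(A_true)


def eigenvalue_error(A_true, A_est):
    """Largest gap between eigenvalues after sorting both spectra."""
    ev_true = np.sort_complex(np.linalg.eigvals(A_true))
    ev_est = np.sort_complex(np.linalg.eigvals(A_est))
    return np.max(np.abs(ev_true - ev_est))
```

Comparing `a_matrix_error(A_true, fit_var1(x))` with `a_matrix_error(A_true, fit_var1(x, alpha))` over several alphas shows whether regularization moves the estimate in the right direction.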
Hey @witherscp, just checking in on this to see if you need any assistance. If it is hard to find a rigorous scenario in which to test this, we can always disable it by default, and instead you can add a short example on some simulated or real data (see existing examples) where it shows "improvement".
Hey, sorry for the delay; I got sidetracked by other tasks that I have been working on. I will try out a few scenarios where RidgeCV outperforms OLS before opting to disable the default. If this proves too difficult, though, it will be easy to show an example where regularization is necessary (as in the case of a common average reference). Can you provide a little more help with the example you suggested of using a non-upper-triangular matrix? I am not that well-versed in linear algebra, so I am starting to get confused by some of these ideas. What lines of code would create a non-upper-triangular matrix with multiple small eigenvalues?
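One possible answer, as a sketch: pick the eigenvalues you want, rotate them with a random orthogonal matrix Q, and set A = Q diag(eigenvalues) Q^T. The rotation makes A dense (not upper triangular) while keeping exactly the chosen eigenvalues; this construction only yields a symmetric A with real eigenvalues, and the helper names are hypothetical. One caveat worth checking: for a stable VAR(1) driven by full-rank isotropic noise of variance sigma^2, the stationary covariance satisfies Sigma = A Sigma A^T + sigma^2 I, so small eigenvalues of A alone do not push the data covariance below sigma^2 I. Mixing the channels, for example with a common average reference, is what makes the data rank deficient.

```python
import numpy as np


def random_a_matrix(eigenvalues, seed=None):
    """Return a dense symmetric A whose eigenvalues are `eigenvalues`.

    A = Q diag(eigenvalues) Q^T, where Q is a random orthogonal matrix from
    the QR decomposition of a Gaussian matrix (hypothetical helper).
    """
    rng = np.random.default_rng(seed)
    n = len(eigenvalues)
    Q, _ = np.linalg.qr(rng.standard_normal((n, n)))
    return Q @ np.diag(eigenvalues) @ Q.T


def simulate_var1(A, n_times, noise_std=1.0, seed=None):
    """Simulate x[:, t + 1] = A @ x[:, t] + noise, starting from zeros."""
    rng = np.random.default_rng(seed)
    n = A.shape[0]
    x = np.zeros((n, n_times))
    for t in range(n_times - 1):
        x[:, t + 1] = A @ x[:, t] + noise_std * rng.standard_normal(n)
    return x


# 12 channels: a few larger eigenvalues plus several tiny ones (all |lambda| < 1)
eigs = [0.9, 0.8, 0.7, 0.5, 0.3, 0.1] + [1e-6] * 6
A = random_a_matrix(eigs, seed=0)
x = simulate_var1(A, n_times=2000, seed=1)
x_car = x - x.mean(axis=0, keepdims=True)  # common average reference
print(np.linalg.cond(x), np.linalg.cond(x_car))
```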
PR Description
Closes #123
Adds an auto_reg parameter to vector_auto_regression. This regularizes the input matrix using scikit-learn's RidgeCV, which searches for the optimal alpha value. The auto_reg parameter is automatically set to True for ill-conditioned matrices.
Merge checklist
Maintainer, please confirm the following before merging: