-
Notifications
You must be signed in to change notification settings - Fork 7
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Mean_std Aquisition function #8
base: main
Are you sure you want to change the base?
Conversation
Pull Request Test Coverage Report for Build 6698196722
💛 - Coveralls |
# Mean-STD acquisition function | ||
# (n_nets/n_mc_samples, pool_dim, n_classes) logits shape | ||
pool_num = logits.shape[1] | ||
assert len(logits.shape) == 3, "logits shape must be 3-Dimensional" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This lines goes first
# std = torch.std(logits, dim=0) # standard deviation over model parameters, shape (pool_dim, n_classes) | ||
expectaion_of_squared = torch.mean(ab**2,dim=0) | ||
expectation_squared = torch.mean(ab,dim=0)**2 | ||
std = torch.sqrt(expectation_of_squared - expectation_squared) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There is a direct method of calculating std in torch. You can use that. Also, logits
should be converted to probs
before using this method. What is ab
?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I had used the direct method first bhaiya but it does not produce the same result as calculating (E[x**2] - E[x]**2)**0.5. I manually verified this with an example.
|
||
|
||
# maximum mean standard deviation aquisition function | ||
class Mean_std(EnsembleAcquisition,MCAcquisition): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Class names follow Camel Case i.e. MeanStd
expectation_squared = torch.mean(ab,dim=0)**2 | ||
std = torch.sqrt(expectation_of_squared - expectation_squared) | ||
scores = torch.mean(std, dim=1) # mean over classes, shape (pool_dim) | ||
assert len(scores.shape) == 1 and scores.shape[0]==pool_num, "scores shape must be 1-Dimensional and must have length equal to that of pool dataset" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is not needed because, it is done by developer (You). We should have asserts to prevent users from passing invalid arguments.
I have updated rest of the changes bhaiya. |
Implementing mean_std aquisition function using ensemble and MC strategy