Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Mean_std Aquisition function #8

Open
wants to merge 8 commits into
base: main
Choose a base branch
from

Conversation

VannshJani
Copy link

Implementing mean_std aquisition function using ensemble and MC strategy

@coveralls
Copy link

Pull Request Test Coverage Report for Build 6698196722

  • 0 of 13 (0.0%) changed or added relevant lines in 1 file are covered.
  • No unchanged relevant lines lost coverage.
  • Overall coverage decreased (-1.7%) to 58.531%

Changes Missing Coverage Covered Lines Changed/Added Lines %
astra/torch/al/acquisitions/Mean_std.py 0 13 0.0%
Totals Coverage Status
Change from base Build 6692078825: -1.7%
Covered Lines: 271
Relevant Lines: 463

💛 - Coveralls

# Mean-STD acquisition function
# (n_nets/n_mc_samples, pool_dim, n_classes) logits shape
pool_num = logits.shape[1]
assert len(logits.shape) == 3, "logits shape must be 3-Dimensional"
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This lines goes first

# std = torch.std(logits, dim=0) # standard deviation over model parameters, shape (pool_dim, n_classes)
expectaion_of_squared = torch.mean(ab**2,dim=0)
expectation_squared = torch.mean(ab,dim=0)**2
std = torch.sqrt(expectation_of_squared - expectation_squared)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There is a direct method of calculating std in torch. You can use that. Also, logits should be converted to probs before using this method. What is ab?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I had used the direct method first bhaiya but it does not produce the same result as calculating (E[x**2] - E[x]**2)**0.5. I manually verified this with an example.



# maximum mean standard deviation aquisition function
class Mean_std(EnsembleAcquisition,MCAcquisition):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Class names follow Camel Case i.e. MeanStd

expectation_squared = torch.mean(ab,dim=0)**2
std = torch.sqrt(expectation_of_squared - expectation_squared)
scores = torch.mean(std, dim=1) # mean over classes, shape (pool_dim)
assert len(scores.shape) == 1 and scores.shape[0]==pool_num, "scores shape must be 1-Dimensional and must have length equal to that of pool dataset"
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is not needed because, it is done by developer (You). We should have asserts to prevent users from passing invalid arguments.

@VannshJani
Copy link
Author

I have updated rest of the changes bhaiya.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants