
I would like to implement the MIN2Net model in PyTorch. Is this correct? #4

Open
YuminosukeSato opened this issue Nov 22, 2022 · 0 comments

YuminosukeSato commented Nov 22, 2022

I have implemented only the MI-EEG classification part in PyTorch. Does it match the original MIN2Net model?

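import torch
import torch.nn as nn
import torch.nn.functional as F


class Conv2D_Norm_Constrained(nn.Conv2d):
    """Conv2d whose kernel is rescaled in forward() so each filter's L2 norm is at most max_norm_val."""

    def __init__(self, max_norm_val, norm_dim, **kwargs):
        super().__init__(**kwargs)
        self.max_norm_val = max_norm_val
        self.norm_dim = norm_dim

    def get_constrained_weights(self, epsilon=1e-8):
        # Same formula as Keras max_norm: w * clip(||w||, 0, max) / (||w|| + eps)
        norm = self.weight.norm(2, dim=self.norm_dim, keepdim=True)
        return self.weight * (torch.clamp(norm, 0, self.max_norm_val) / (norm + epsilon))

    def forward(self, input):
        return F.conv2d(input, self.get_constrained_weights(), self.bias,
                        self.stride, self.padding, self.dilation, self.groups)


class ConstrainedLinear(nn.Linear):
    # Note: this clips each weight *value* to [-1.0, 0.5]; it is not a norm constraint.
    def forward(self, input):
        return F.linear(input, self.weight.clamp(min=-1.0, max=0.5), self.bias)


class MinNet(nn.Module):  # expects input of shape (batch, 1, C, T) = (batch, 1, 20, 400)
    def __init__(self, input_shape=(1, 400, 20)):
        super().__init__()
        self.D, self.T, self.C = input_shape
        self.subsampling_size = 100
        self.pool_size_1 = (1, self.T // self.subsampling_size)  # (1, 4) for T = 400
        self.en_conv = nn.Sequential(
            # PyTorch conv weights are (out, in, kH, kW), so the per-filter norm is over dims (1, 2, 3).
            # Keras' axis=(0, 1, 2) refers to its own (kH, kW, in, out) kernel layout.
            Conv2D_Norm_Constrained(in_channels=1, out_channels=16, kernel_size=(1, 64), padding="same",
                                    max_norm_val=2.0, norm_dim=(1, 2, 3)),
            nn.ELU(),
            nn.BatchNorm2d(16, eps=1e-05, momentum=0.1),
            nn.AvgPool2d(self.pool_size_1),  # pool_size_1 is already a tuple; don't wrap it again
            nn.Flatten(),
            # 16 filters * C * (T // 4) = 32000 for the default input_shape
            ConstrainedLinear(16 * self.C * (self.T // self.pool_size_1[1]), 64),
            nn.ELU(),
            ConstrainedLinear(64, 3),  # raw logits; pair with nn.CrossEntropyLoss
        )

    def forward(self, x):
        return self.en_conv(x)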
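For the default input_shape, the first fully connected layer's input size works out to 16 filters × 20 EEG channels × (400 / 4) pooled time steps = 32000, assuming the model is fed tensors shaped (batch, 1, 20, 400). Rather than hard-coding that number, one option is to infer it with a dummy forward pass. This is a minimal sketch: `flatten_size` is a hypothetical helper, plain `nn.Conv2d` stands in for the norm-constrained layer (the constraint does not change shapes), and PyTorch 1.9 or later is assumed for `padding="same"`.

```python
import torch
import torch.nn as nn


def flatten_size(features: nn.Module, input_shape: tuple) -> int:
    """Hypothetical helper: push one zero sample through `features`, return the flattened size."""
    was_training = features.training
    features.eval()  # avoid updating BatchNorm running stats with the dummy input
    with torch.no_grad():
        out = features(torch.zeros(1, *input_shape))
    features.train(was_training)  # restore the original mode
    return out.flatten(1).shape[1]


# Conv/pool part of MinNet, rebuilt with plain layers for the shape check.
features = nn.Sequential(
    nn.Conv2d(1, 16, kernel_size=(1, 64), padding="same"),
    nn.ELU(),
    nn.BatchNorm2d(16),
    nn.AvgPool2d((1, 4)),
)

n_flat = flatten_size(features, (1, 20, 400))  # (1 input map, 20 EEG channels, 400 samples)
print(n_flat)  # 32000
```

The returned value can then be passed as `in_features` to the first linear layer, so changing the trial length or channel count does not require editing the constant by hand.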
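`ConstrainedLinear` clips every weight value to [-1.0, 0.5], which is an element-wise bound rather than a norm constraint. If the intent is to mirror a Keras `max_norm(0.5)` constraint on the dense layers (an assumption here; the post does not say which constraint the original uses), the norm should be taken over each output unit's incoming weights. Keras stores a Dense kernel as (in, out) and `max_norm` defaults to axis=0, while PyTorch stores (out, in), so the equivalent is dim=1. A minimal sketch with a hypothetical `MaxNormLinear` class:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class MaxNormLinear(nn.Linear):
    """Hypothetical max-norm Linear: each weight row is rescaled to L2 norm <= max_norm_val.

    PyTorch's weight is (out_features, in_features), so one row holds the incoming
    weights of one output unit and the norm is taken over dim=1.
    """

    def __init__(self, in_features, out_features, max_norm_val=0.5, bias=True, eps=1e-8):
        super().__init__(in_features, out_features, bias=bias)
        self.max_norm_val = max_norm_val
        self.eps = eps

    def constrained_weight(self):
        norm = self.weight.norm(2, dim=1, keepdim=True)
        # Rows already under the limit are left (almost exactly) unchanged.
        return self.weight * (norm.clamp(max=self.max_norm_val) / (norm + self.eps))

    def forward(self, input):
        return F.linear(input, self.constrained_weight(), self.bias)


torch.manual_seed(0)
layer = MaxNormLinear(128, 64, max_norm_val=0.5)
row_norms = layer.constrained_weight().norm(dim=1)
print(bool((row_norms <= 0.5 + 1e-6).all()))  # True
```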
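Both constrained layers in the post rescale the weights inside `forward()`, so the stored parameters themselves are never constrained. Keras applies constraints differently: it projects the stored variable after each gradient update. If a closer match to that behavior is wanted, one possible sketch is to drop the reparameterization and call a projection after every `optimizer.step()`. The helper `apply_max_norm_` is hypothetical; it relies on `torch.renorm`, which rescales each sub-tensor along `dim` whose p-norm exceeds `maxnorm`. The toy loss exists only to produce a gradient.

```python
import torch
import torch.nn as nn


def apply_max_norm_(module: nn.Module, max_norm_val: float) -> None:
    """Hypothetical helper: project module.weight in place so each output filter/unit has L2 norm <= max_norm_val.

    dim=0 of a PyTorch Conv2d/Linear weight indexes output filters/units,
    so renorm over dim=0 treats each filter (or each row) independently.
    """
    with torch.no_grad():
        module.weight.copy_(torch.renorm(module.weight, 2, 0, max_norm_val))


torch.manual_seed(0)
conv = nn.Conv2d(1, 16, kernel_size=(1, 64), padding="same")
opt = torch.optim.SGD(conv.parameters(), lr=1.0)

x = torch.randn(8, 1, 20, 400)
loss = conv(x).pow(2).mean()  # toy loss, just to get a gradient
opt.zero_grad()
loss.backward()
opt.step()
apply_max_norm_(conv, max_norm_val=2.0)  # project after every step, as Keras does

print(conv.weight.flatten(1).norm(dim=1).max().item() <= 2.0 + 1e-4)  # True
```

With this approach the forward pass uses the stored (already projected) weights and the rescaling does not enter the gradient, which is the main practical difference from the `forward()`-time version.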