@@ -93,6 +93,44 @@ def forward(self, inputs: torch.Tensor) -> torch.Tensor:
93
93
return deep_input
94
94
95
95
96
+ class PredictionLayer (nn .Module ):
97
+ def __init__ (
98
+ self ,
99
+ task : Literal ["binary" , "multiclass" , "regression" ] = "binary" ,
100
+ use_bias : bool = True ,
101
+ ):
102
+ """
103
+ Model output layer
104
+ :param task: model's task in ["binary", "multiclass", "regression"]
105
+ :param use_bias: using bias
106
+ """
107
+ if task not in ["binary" , "multiclass" , "regression" ]:
108
+ raise ValueError ("task must be binary, multiclass or regression" )
109
+
110
+ super (PredictionLayer , self ).__init__ ()
111
+ self .use_bias = use_bias
112
+ self .task = task
113
+ if self .use_bias :
114
+ self .bias = nn .Parameter (torch .zeros ((1 ,)))
115
+
116
+ def forward (self , x : torch .Tensor ) -> torch .Tensor :
117
+ """
118
+ Forward pass
119
+ :param x: input tensors
120
+ :return: output tensors
121
+ """
122
+ inputs = x
123
+ if self .use_bias :
124
+ inputs += self .bias
125
+ if self .task == "binary" :
126
+ outputs = torch .sigmoid (inputs )
127
+ elif self .task == "multiclass" :
128
+ outputs = torch .softmax (inputs , dim = 0 )
129
+ else :
130
+ outputs = input
131
+ return outputs
132
+
133
+
96
134
if __name__ == "__main__" :
97
135
"""Module for Execution in Testing
98
136
python -m CATS.layers.core
0 commit comments