Skip to content

Commit 52b87bd

Browse files
committed
[feat] PredictionLayer 클래스 구현
1 parent f92e348 commit 52b87bd

File tree

1 file changed

+38
-0
lines changed

1 file changed

+38
-0
lines changed

CATS/layers/core.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,44 @@ def forward(self, inputs: torch.Tensor) -> torch.Tensor:
9393
return deep_input
9494

9595

96+
class PredictionLayer(nn.Module):
97+
def __init__(
98+
self,
99+
task: Literal["binary", "multiclass", "regression"] = "binary",
100+
use_bias: bool = True,
101+
):
102+
"""
103+
Model output layer
104+
:param task: model's task in ["binary", "multiclass", "regression"]
105+
:param use_bias: using bias
106+
"""
107+
if task not in ["binary", "multiclass", "regression"]:
108+
raise ValueError("task must be binary, multiclass or regression")
109+
110+
super(PredictionLayer, self).__init__()
111+
self.use_bias = use_bias
112+
self.task = task
113+
if self.use_bias:
114+
self.bias = nn.Parameter(torch.zeros((1,)))
115+
116+
def forward(self, x: torch.Tensor) -> torch.Tensor:
117+
"""
118+
Forward pass
119+
:param x: input tensors
120+
:return: output tensors
121+
"""
122+
inputs = x
123+
if self.use_bias:
124+
inputs += self.bias
125+
if self.task == "binary":
126+
outputs = torch.sigmoid(inputs)
127+
elif self.task == "multiclass":
128+
outputs = torch.softmax(inputs, dim=0)
129+
else:
130+
outputs = input
131+
return outputs
132+
133+
96134
if __name__ == "__main__":
97135
"""Module for Execution in Testing
98136
python -m CATS.layers.core

0 commit comments

Comments
 (0)