-
Notifications
You must be signed in to change notification settings - Fork 15
/
function.hpp
75 lines (57 loc) · 2.06 KB
/
function.hpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
#ifndef FUNCTION_HPP
#define FUNCTION_HPP
// A Function maps an input tensor to an output tensor and implements both a
// forward pass and a backward pass. Concrete functions/layers derive from it.
struct Function {
    // Function is used polymorphically (it has pure virtual members), so the
    // destructor must be virtual: otherwise deleting a derived object through
    // a Function* / std::unique_ptr<Function> is undefined behavior.
    virtual ~Function() = default;

    // Declaring the destructor suppresses the implicit move operations, so
    // default the remaining special members explicitly (Rule of Five) to keep
    // derived classes copyable and movable exactly as before.
    Function() = default;
    Function(const Function&) = default;
    Function& operator=(const Function&) = default;
    Function(Function&&) = default;
    Function& operator=(Function&&) = default;

    // Forward pass: computes `output` from `input`. Must be implemented.
    virtual void forward(const Tensor& input, Tensor& output) = 0;
    // Optional initialization hook paired with forward(); the default does
    // nothing. NOTE(review): presumably called once before forward() to set
    // up the output tensor — confirm against callers.
    virtual void init_forward(const Tensor&, Tensor&) {}

    // Backward pass: computes `dinput` from `doutput`. Must be implemented.
    virtual void backward(const Tensor& doutput, Tensor& dinput) = 0;
    // Optional initialization hook paired with backward(); the default does
    // nothing.
    virtual void init_backward(const Tensor&, Tensor&) {}

    // Returns the descriptor of the input dimensions.
    virtual const TensorDesc& getInputDesc() const = 0;
    // Returns the descriptor of the output dimensions.
    virtual const TensorDesc& getOutputDesc() const = 0;

    // Prints "<input dims>-><output dims>" to `os` and returns `os`.
    std::ostream& write_dims(std::ostream& os) const {
        return os << getInputDesc() << "->" << getOutputDesc();
    }

    // Prints the name of this function; derived classes override it.
    virtual std::ostream& write_name(std::ostream& os) const {
        return os << "Function";
    }

    // Prints "<name> <input dims>-><output dims>" to `os` and returns `os`.
    virtual std::ostream& write(std::ostream& os) const {
        return this->write_dims(this->write_name(os) << " ");
    }
};
/* a Layer is a Function for which the input and output dimensions are known
* at construction time and it buffers these
*/
struct Layer : public Function {
TensorDesc input_desc;
TensorDesc output_desc;
Layer(const Dim& input_desc, const Dim& output_desc)
: input_desc(input_desc), output_desc(output_desc) {}
Layer(const TensorDesc& input_desc, const TensorDesc& output_desc)
: input_desc(input_desc), output_desc(output_desc) {}
virtual const TensorDesc& getInputDesc() const override {
return input_desc;
}
virtual const TensorDesc& getOutputDesc() const override {
return output_desc;
}
virtual std::ostream& write_name(std::ostream& os) const override {
return os << "Layer";
}
};
// Streams a human-readable description of any Function by delegating to its
// virtual write(), so callers can simply write `os << layer`.
// `inline` is required: this non-template function is defined in a header,
// and without it every translation unit that includes the header emits its
// own definition, producing multiple-definition link errors (ODR violation).
inline std::ostream& operator<<(std::ostream& os, const Function& l) {
    return l.write(os);
}
/*
struct Model {
void init_forward();
void init_backward();
void forward();
void backward();
std::string get_name() const;
};
*/
#endif // FUNCTION_HPP