// nn.hpp
#pragma once

#include <functional>
#include <stddef.h>
#include <vector>

namespace nn {
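
// A flat float tensor: `data` holds the values, `grad` a same-sized gradient
// buffer, `shape` the dimensions, and `size` the total element count
// (presumably the product of `shape`). One constructor fills elements from a
// caller-supplied generator `gen`, the other with the constant `init`.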
struct Tensor {
  size_t size;
  std::vector<float> data;
  std::vector<float> grad;
  std::vector<size_t> shape;

  Tensor(std::vector<size_t> shape, size_t size, const std::function<float()> &gen);
  Tensor(std::vector<size_t> shape, size_t size, float init);
  Tensor();
};
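
// Base class for layers. parameters() returns the trainable tensors so the
// optimizer can walk them, and zero_grad() clears their `grad` buffers; the
// distinction between parameters() and _parameters() is not documented in
// this header and appears to be an implementation detail.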
struct Module {
  virtual ~Module() = default; // polymorphic base, so destruction must be virtual
  virtual std::vector<Tensor *> parameters();
  virtual std::vector<Tensor *> _parameters();
  void zero_grad();
};
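
// AdamW optimizer (Adam with decoupled weight decay); `m` and `v` hold the
// per-parameter first- and second-moment estimates, and `t` in update() is
// the 1-based step count used for bias correction. A minimal usage sketch
// (the hyperparameter values are illustrative, not defaults taken from this
// repository):
//
//   nn::GPT2 model(vocab_size, emb_dim, num_heads, hidden_dim, gen);
//   nn::AdamW opt(1e-3f, 0.9f, 0.999f, 1e-8f, 0.01f, model.parameters());
//   for (int t = 1; t <= num_steps; ++t) {
//     model.zero_grad();
//     // ... forward pass, loss, backward pass ...
//     opt.update(model.parameters(), t);
//   }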
class AdamW {
public:
  AdamW(float lr, float beta_1, float beta_2, float eps, float weight_decay,
        std::vector<Tensor *> parameters);
  void update(std::vector<Tensor *> parameters, int t);

private:
  float lr, beta_1, beta_2, eps, weight_decay;
  std::vector<std::vector<float>> m, v;
};
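
// Token-embedding lookup table; from the constructor, `emb` presumably has
// shape {vocab_size, emb_dim}. operator() gathers one row per token id, and
// backward() scatters `out.grad` back into the rows selected by `tokens`.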
class Embedding : public Module {
public:
  Embedding(size_t vocab_size, size_t emb_dim, const std::function<float()> &gen);
  std::vector<Tensor *> parameters() override;
  std::vector<Tensor *> _parameters() override;
  Tensor operator()(std::vector<int> tokens);
  void backward(std::vector<int> tokens, Tensor &out);

private:
  Tensor emb;
};
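
// Layer normalization with learned scale `w` and shift `b`. `mean` and
// `rstd` (apparently the reciprocal standard deviation) are cached by the
// forward pass so backward() can reuse them instead of recomputing.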
class LayerNorm : public Module {
public:
  LayerNorm(size_t input_dim, const std::function<float()> &gen);
  std::vector<Tensor *> parameters() override;
  std::vector<Tensor *> _parameters() override;
  Tensor operator()(Tensor &x);
  Tensor *backward(Tensor &x, Tensor &out);

private:
  Tensor w, b;
  Tensor mean, rstd;
};
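
// Two-layer position-wise MLP: input_dim -> hidden_dim -> output_dim, with
// `z1`, `h`, and `z2` caching intermediate activations for backward().
// `w1` and `w2` look like the layer weights and `b2` the output bias; the
// exact role of `v` is defined in the implementation file.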
class FeedForwardNN : public Module {
public:
  FeedForwardNN(size_t input_dim, size_t hidden_dim, size_t output_dim,
                const std::function<float()> &gen);
  std::vector<Tensor *> parameters() override;
  std::vector<Tensor *> _parameters() override;
  Tensor operator()(Tensor &x);
  Tensor *backward(Tensor &x, Tensor &out);

private:
  Tensor w1, v, w2, b2;
  Tensor z1, z2, h;
};
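
// Multi-head self-attention: `wq`, `wk`, `wv` project the input into
// queries, keys, and values split across `num_heads` heads, and `wo` is the
// output projection. The intermediates (`q`, `k`, `v`, the score matrix
// `qk`, and `attn_out`) are cached for backward().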
class MultiHeadAttention : public Module {
public:
  MultiHeadAttention(size_t emb_dim, size_t num_heads, const std::function<float()> &gen);
  std::vector<Tensor *> parameters() override;
  std::vector<Tensor *> _parameters() override;
  Tensor operator()(Tensor &x);
  Tensor *backward(Tensor &x, Tensor &out);

private:
  const size_t num_heads;
  Tensor wq, wk, wv, wo;
  Tensor q, k, v, qk, attn_out;
};
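
// One transformer decoder block: multi-head attention and a feed-forward
// network, each paired with its own LayerNorm (`attn_ln`, `ffnn_ln`). The
// pre-/post-norm ordering and residual connections live in the
// implementation file.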
class Decoder : public Module {
public:
  Decoder(size_t emb_dim, size_t num_heads, size_t hidden_dim, const std::function<float()> &gen);
  std::vector<Tensor *> parameters() override;
  std::vector<Tensor *> _parameters() override;
  Tensor operator()(Tensor &x);
  Tensor *backward(Tensor &x, Tensor &out);

private:
  MultiHeadAttention attn;
  FeedForwardNN ffnn;
  LayerNorm attn_ln, ffnn_ln;
};
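
// The full model: a token Embedding followed by a stack of Decoder blocks.
// The constructor takes no layer count, so the depth appears to be fixed in
// the implementation. A hedged end-to-end sketch (`x` is the model input and
// `y` the target in whatever encoding loss() expects; the exact backward
// wiring is defined in the implementation):
//
//   nn::GPT2 model(vocab_size, emb_dim, num_heads, hidden_dim, gen);
//   nn::Tensor logits = model(x);
//   nn::Tensor nll = nn::loss(logits, y);
//   nn::loss_backward(logits, y, nll); // presumably fills logits.grad
//   model.backward(x, logits);         // propagates into parameter .grad buffers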
class GPT2 : public Module {
public:
  GPT2(size_t vocab_size, size_t emb_dim, size_t num_heads, size_t hidden_dim,
       const std::function<float()> &gen);
  std::vector<Tensor *> parameters() override;
  std::vector<Tensor *> _parameters() override;
  Tensor operator()(Tensor &x);
  Tensor *backward(Tensor &x, Tensor &out);

private:
  Embedding emb;
  std::vector<Decoder> layers;
};
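
// Free functions. softmax() normalizes along the last dimension; note that
// `temp` is an int, which suggests an integer temperature divisor (temp = 1
// for a plain softmax). loss() computes the training criterion (presumably
// token-level cross-entropy) and loss_backward() its gradient with respect
// to `x`.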
Tensor softmax(Tensor &x, int temp);
Tensor loss(Tensor &x, Tensor &y);
Tensor *loss_backward(Tensor &x, Tensor &y, Tensor &out);
} // namespace nn