-
Notifications
You must be signed in to change notification settings - Fork 4
Expand file tree
/
Copy pathengram.py
More file actions
485 lines (400 loc) · 16.4 KB
/
engram.py
File metadata and controls
485 lines (400 loc) · 16.4 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
"""
================================================================================
[Engram Architecture Demo Implementation]
DISCLAIMER:
1. Demo Purpose Only:
This code is a demonstration version intended to illustrate the core logic and
data flow of the Engram module.
2. Production Readiness:
This implementation requires further optimization for actual production use
(e.g., custom CUDA kernels, distributed training support).
3. Simplifications:
Standard components (Normalization, Attention, MoE) and complex Hyper-connection
mechanisms are omitted or mocked in this version to focus exclusively on the
Engram module implementation.
================================================================================
"""
"""
pip install torch numpy transformers sympy
"""
## built-in
from typing import List
from dataclasses import dataclass, field
import math
## third-party
from sympy import isprime
import numpy as np
import torch
import torch.nn as nn
from transformers import AutoTokenizer
from tokenizers import normalizers, Regex
@dataclass
class EngramConfig:
    """Hyper-parameters of the Engram n-gram memory module.

    Attributes:
        tokenizer_name_or_path: HuggingFace tokenizer whose vocabulary is
            compressed before n-gram hashing.
        engram_vocab_size: Base hash-table size for each n-gram order,
            indexed by ``n - 2`` (entry 0 is used for bigrams).
        max_ngram_size: Largest n-gram order; orders ``2..max_ngram_size``
            are hashed.
        n_embed_per_ngram: Embedding width contributed by one n-gram order;
            each head receives ``n_embed_per_ngram // n_head_per_ngram``.
        n_head_per_ngram: Number of independent hash heads per n-gram order.
        layer_ids: Backbone layer indices that receive an Engram module.
        pad_id: Original-tokenizer id used to fill positions before the
            start of the sequence when forming n-grams.
        seed: Base seed for the per-layer hash multipliers.
        kernel_size: Kernel size of the causal ShortConv.
        warmup_steps: Global steps during which Engram gates are forced to 1.
        soft_constraint_steps: Steps after warm-up during which gates are
            clamped to at least 0.1.
    """

    tokenizer_name_or_path: str = "deepseek-ai/DeepSeek-V3"
    engram_vocab_size: List[int] = field(default_factory=lambda: [129280*5, 129280*5])
    max_ngram_size: int = 3
    n_embed_per_ngram: int = 512
    n_head_per_ngram: int = 8
    layer_ids: List[int] = field(default_factory=lambda: [1, 15])
    pad_id: int = 2
    seed: int = 0
    kernel_size: int = 4
    warmup_steps: int = 0
    soft_constraint_steps: int = 0

    @classmethod
    def from_dict(cls, config_dict):
        """Build a config from ``config_dict``, silently ignoring unknown keys."""
        known_names = set(cls.__dataclass_fields__)
        accepted = {
            name: value
            for name, value in config_dict.items()
            if name in known_names
        }
        return cls(**accepted)

    def to_dict(self):
        """Return a shallow ``{field_name: value}`` mapping in declaration order."""
        return {name: getattr(self, name) for name in self.__dataclass_fields__}
@dataclass
class BackBoneConfig:
    """Dimensions of the (mocked) backbone transformer that hosts Engram layers."""

    hidden_size: int = 1024   # width D of each hyper-connection stream
    hc_mult: int = 4          # number of parallel hyper-connection streams
    vocab_size: int = 129280  # rows of the token embedding / columns of the LM head
    num_layers: int = 30      # number of TransformerBlock layers in the demo
# Training-step counter read by Engram.forward to drive the gate warm-up
# schedule; update it through set_global_step().
_GLOBAL_STEP: int = 0

# Module-level default configurations: TransformerBlock builds its Engram from
# these, and Engram falls back to engram_cfg when given config=None.
engram_cfg: EngramConfig = EngramConfig()
backbone_config: BackBoneConfig = BackBoneConfig()
def set_global_step(step: int) -> None:
    """Publish ``step`` as the module-wide training step.

    ``Engram.forward`` compares this value with ``warmup_steps`` and
    ``soft_constraint_steps`` to decide how its gates are constrained.
    """
    global _GLOBAL_STEP
    _GLOBAL_STEP = step
class CompressedTokenizer:
    """Collapse a HuggingFace tokenizer's vocabulary into normalization classes.

    Every original token id is decoded and normalized (NFKC, NFD, accents
    stripped, lower-cased, whitespace runs collapsed to a single space,
    trimmed). Ids whose normalized text is identical share one compressed id.
    Tokens whose decoded text contains the U+FFFD replacement character
    (presumably byte-level fragments that do not decode to a full character)
    are keyed by their raw token string instead, so they are not merged.
    Calling the instance maps an array of original ids to compressed ids.
    """

    def __init__(
        self,
        tokenizer_name_or_path,
    ):
        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name_or_path, trust_remote_code=True)
        # A token that is exactly one space would be emptied by Strip(); it is
        # swapped for a private-use placeholder before stripping and back after.
        SENTINEL = "\uE000"
        self.normalizer = normalizers.Sequence([
            normalizers.NFKC(),
            normalizers.NFD(),
            normalizers.StripAccents(),
            normalizers.Lowercase(),
            normalizers.Replace(Regex(r"[ \t\r\n]+"), " "),
            normalizers.Replace(Regex(r"^ $"), SENTINEL),
            normalizers.Strip(),
            normalizers.Replace(SENTINEL, " "),
        ])
        self.lookup_table, self.num_new_token = self._build_lookup_table()

    def __len__(self):
        """Number of distinct compressed ids."""
        return self.num_new_token

    def _build_lookup_table(self):
        """Return ``(lookup, count)`` where ``lookup[old_id]`` is the compressed id.

        Compressed ids are assigned in order of first appearance while scanning
        the original vocabulary from id 0 upwards.
        """
        total = len(self.tokenizer)
        lookup = np.empty(total, dtype=np.int64)
        class_of_key = {}
        for token_id in range(total):
            decoded = self.tokenizer.decode([token_id], skip_special_tokens=False)
            if "�" in decoded:
                key = self.tokenizer.convert_ids_to_tokens(token_id)
            else:
                # Fall back to the raw text when normalization empties it.
                key = self.normalizer.normalize_str(decoded) or decoded
            # setdefault hands out the next free id only for unseen keys.
            lookup[token_id] = class_of_key.setdefault(key, len(class_of_key))
        return lookup, len(class_of_key)

    def _compress(self, input_ids):
        """Map original ids to compressed ids; negative ids are passed through.

        NOTE(review): negative ids are presumably padding/ignore markers —
        confirm with callers.
        """
        ids = np.asarray(input_ids, dtype=np.int64)
        keep = ids >= 0
        compressed = ids.copy()
        compressed[keep] = self.lookup_table[ids[keep]]
        return compressed

    def __call__(self, input_ids):
        return self._compress(input_ids)
class ShortConv(nn.Module):
def __init__(
self,
hidden_size: int,
kernel_size: int = 4,
dilation: int = 1,
norm_eps: float = 1e-5,
hc_mult: int = 4,
activation: bool = True,
):
super().__init__()
self.hc_mult = hc_mult
self.activation = activation
total_channels = hidden_size * hc_mult
self.conv = nn.Conv1d(
in_channels=total_channels,
out_channels=total_channels,
kernel_size=kernel_size,
groups=total_channels,
bias=False,
padding=(kernel_size - 1) * dilation,
dilation=dilation,
)
# Initialize weights to zero
nn.init.zeros_(self.conv.weight)
self.norms = nn.ModuleList([
nn.RMSNorm(hidden_size, eps=norm_eps)
for _ in range(hc_mult)
])
if self.activation:
self.act_fn = nn.SiLU()
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Input: (B,L,HC_MULT,D)
Output: (B,L,HC_MULT,D)
"""
B, T, G, C = x.shape
assert G == self.hc_mult, f"Input groups {G} != hc_mult {self.hc_mult}"
normed_chunks = []
for i in range(G):
chunk = x[:, :, i, :]
normed_chunks.append(self.norms[i](chunk))
x_norm = torch.cat(normed_chunks, dim=-1)
x_bct = x_norm.transpose(1, 2)
y_bct = self.conv(x_bct)
y_bct = y_bct[..., :T]
if self.activation:
y_bct = self.act_fn(y_bct)
y = y_bct.transpose(1, 2).view(B, T, G, C).contiguous()
return y
def find_next_prime(start, seen_primes):
    """Return the smallest prime strictly greater than ``start`` not in ``seen_primes``.

    ``seen_primes`` is only read here; callers record the returned prime
    themselves.
    """
    candidate = start
    while True:
        candidate += 1
        if candidate in seen_primes:
            continue
        if isprime(candidate):
            return candidate
class NgramHashMapping:
    """Deterministically hash trailing n-grams of token ids into per-head bucket ids.

    Token ids are first mapped through a CompressedTokenizer. For every layer
    in ``layer_ids`` and every order n in ``2..max_ngram_size``, the n ids
    ending at each position are mixed as ``XOR_k ids[t - k] * m_k``. The
    ``m_k`` are odd multipliers drawn per layer. The mixed value is reduced
    modulo ``n_head_per_ngram`` primes, one per head, giving one bucket id
    per (order, head) pair. Positions before the sequence start use
    ``pad_id``.
    """

    def __init__(
        self,
        engram_vocab_size,
        max_ngram_size,
        n_embed_per_ngram,
        n_head_per_ngram,
        layer_ids,
        tokenizer_name_or_path,
        pad_id,
        seed,
    ):
        # Base hash-table size for each n-gram order, indexed by n - 2.
        self.vocab_size_per_ngram = engram_vocab_size
        self.max_ngram_size = max_ngram_size
        self.n_embed_per_ngram = n_embed_per_ngram
        self.n_head_per_ngram = n_head_per_ngram
        self.layer_ids = layer_ids
        self.compressed_tokenizer = CompressedTokenizer(
            tokenizer_name_or_path=tokenizer_name_or_path
        )
        self.tokenizer_vocab_size = len(self.compressed_tokenizer)
        # Padding is hashed like a real token, so translate it into the
        # compressed id space as well.
        self.pad_id = pad_id if pad_id is None else int(self.compressed_tokenizer.lookup_table[pad_id])

        # Multipliers stay below int64_max // vocab so that id * multiplier
        # cannot overflow for compressed ids in [0, tokenizer_vocab_size).
        overflow_free_cap = int(np.iinfo(np.int64).max // self.tokenizer_vocab_size)
        draw_high = max(1, overflow_free_cap // 2)
        layer_seed_stride = 10007
        self.layer_multipliers = {}
        for lid in self.layer_ids:
            # Each layer gets its own reproducible RNG stream.
            rng = np.random.default_rng(int(seed + layer_seed_stride * int(lid)))
            halves = rng.integers(
                low=0,
                high=draw_high,
                size=(self.max_ngram_size,),
                dtype=np.int64,
            )
            self.layer_multipliers[lid] = halves * 2 + 1  # always odd
        self.vocab_size_across_layers = self.calculate_vocab_size_across_layers()

    def calculate_vocab_size_across_layers(self):
        """Assign every (layer, order, head) its own prime table size.

        For order n, each head takes the next prime at or above
        ``vocab_size_per_ngram[n - 2]`` that has not been used yet. Primes are
        never reused across heads, orders, or layers.

        Returns:
            ``{layer_id: [[prime per head] for n in 2..max_ngram_size]}``.
        """
        used = set()
        result = {}
        for lid in self.layer_ids:
            per_order = []
            for order_idx in range(self.max_ngram_size - 1):
                cursor = self.vocab_size_per_ngram[order_idx] - 1
                head_sizes = []
                for _ in range(self.n_head_per_ngram):
                    cursor = find_next_prime(cursor, used)
                    used.add(cursor)
                    head_sizes.append(cursor)
                per_order.append(head_sizes)
            result[lid] = per_order
        return result

    def _get_ngram_hashes(
        self,
        input_ids: np.ndarray,
        layer_id: int,
    ) -> np.ndarray:
        """Hash already-compressed ids of shape [B, T] for one layer.

        Returns:
            int64 array of shape [B, T, (max_ngram_size - 1) * n_head_per_ngram],
            ordered by n-gram order first, then head.
        """
        ids = np.asarray(input_ids, dtype=np.int64)
        _, seq_len = ids.shape
        coeffs = self.layer_multipliers[layer_id]
        moduli_per_order = self.vocab_size_across_layers[layer_id]
        # lagged[k][:, t] == ids[:, t - k], with pad_id where t - k < 0.
        lagged = [ids] + [
            np.pad(ids, ((0, 0), (k, 0)),
                   mode='constant', constant_values=self.pad_id)[:, :seq_len]
            for k in range(1, self.max_ngram_size)
        ]
        bucket_ids = []
        for order in range(2, self.max_ngram_size + 1):
            terms = [lagged[k] * coeffs[k] for k in range(order)]
            mixed = terms[0]
            for term in terms[1:]:
                mixed = np.bitwise_xor(mixed, term)
            head_moduli = moduli_per_order[order - 2]
            for head in range(self.n_head_per_ngram):
                bucket_ids.append((mixed % int(head_moduli[head])).astype(np.int64, copy=False))
        return np.stack(bucket_ids, axis=2)

    def hash(self, input_ids):
        """Compress ``input_ids`` and return ``{layer_id: bucket-id array}`` for every layer."""
        compressed = self.compressed_tokenizer(input_ids)
        return {
            lid: self._get_ngram_hashes(compressed, layer_id=lid)
            for lid in self.layer_ids
        }
class MultiHeadEmbedding(nn.Module):
    """Several embedding tables of sizes ``list_of_N`` packed into one ``nn.Embedding``.

    Head ``h`` owns rows ``[offsets[h], offsets[h] + list_of_N[h])`` of the
    shared table. ``forward`` takes per-head local ids whose last dimension
    indexes the head, shifts each id into its head's row range, and looks it up.
    """

    def __init__(self, list_of_N: List[int], D: int):
        super().__init__()
        self.num_heads = len(list_of_N)
        self.embedding_dim = D
        # First row of each head's table: running sum of the preceding sizes.
        head_starts = [0]
        for size in list_of_N[:-1]:
            head_starts.append(head_starts[-1] + size)
        self.register_buffer("offsets", torch.tensor(head_starts, dtype=torch.long))
        self.embedding = nn.Embedding(num_embeddings=sum(list_of_N), embedding_dim=D)

    def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Map ids of shape [..., num_heads] to embeddings of shape [..., num_heads, D]."""
        return self.embedding(input_ids + self.offsets)
class Engram(nn.Module):
    """Hashed n-gram memory attached to a single backbone layer.

    The trailing 2..max_ngram_size-grams at each position are hashed into
    per-head bucket ids (NgramHashMapping). The ids are looked up in a
    MultiHeadEmbedding and the results are concatenated. For every
    hyper-connection stream, a scalar gate is computed from the similarity
    between a projected key (from the embeddings) and that stream's hidden
    state. The module returns the gated value projection plus a causal
    ShortConv of it.

    Args:
        layer_id: Index of the backbone layer this module serves; must be in
            ``config.layer_ids``.
        config: Engram hyper-parameters; ``None`` falls back to ``engram_cfg``.
        backbone_hidden_size: Width D of each backbone stream.
        backbone_hc_mult: Number of hyper-connection streams.

    Raises:
        ValueError: If ``n_embed_per_ngram`` is not divisible by
            ``n_head_per_ngram``. Otherwise the concatenated embedding width
            would not match the projection input width.
    """

    def __init__(
        self,
        layer_id,
        config: EngramConfig,
        backbone_hidden_size: int,
        backbone_hc_mult: int,
    ):
        super().__init__()
        self.layer_id = layer_id
        # Use provided config or fall back to global default
        self.config = config if config is not None else engram_cfg
        # Backbone params are passed in (no global backbone_config dependency inside Engram)
        self.backbone_hidden_size = int(backbone_hidden_size)
        self.backbone_hc_mult = int(backbone_hc_mult)
        # Fail fast, before the expensive tokenizer/hash setup below.
        if self.config.n_embed_per_ngram % self.config.n_head_per_ngram != 0:
            raise ValueError(
                f"n_embed_per_ngram ({self.config.n_embed_per_ngram}) must be divisible "
                f"by n_head_per_ngram ({self.config.n_head_per_ngram})"
            )
        self.hash_mapping = NgramHashMapping(
            engram_vocab_size=self.config.engram_vocab_size,
            max_ngram_size=self.config.max_ngram_size,
            n_embed_per_ngram=self.config.n_embed_per_ngram,
            n_head_per_ngram=self.config.n_head_per_ngram,
            layer_ids=self.config.layer_ids,
            tokenizer_name_or_path=self.config.tokenizer_name_or_path,
            pad_id=self.config.pad_id,
            seed=self.config.seed,
        )
        # One embedding head per (n-gram order, hash head) pair of this layer.
        self.multi_head_embedding = MultiHeadEmbedding(
            list_of_N=[x for y in self.hash_mapping.vocab_size_across_layers[self.layer_id] for x in y],
            D=self.config.n_embed_per_ngram // self.config.n_head_per_ngram,
        )
        self.short_conv = ShortConv(
            hidden_size=self.backbone_hidden_size,
            kernel_size=self.config.kernel_size,
            dilation=self.config.max_ngram_size,
            hc_mult=self.backbone_hc_mult,
        )
        # Width of the concatenated embeddings over all orders and heads.
        engram_hidden_size = (self.config.max_ngram_size - 1) * self.config.n_embed_per_ngram
        self.value_proj = nn.Linear(engram_hidden_size, self.backbone_hidden_size)
        # One key projection plus one (key, query) RMSNorm pair per stream.
        self.key_projs = nn.ModuleList(
            [nn.Linear(engram_hidden_size, self.backbone_hidden_size) for _ in range(self.backbone_hc_mult)]
        )
        self.norm1 = nn.ModuleList([nn.RMSNorm(self.backbone_hidden_size) for _ in range(self.backbone_hc_mult)])
        self.norm2 = nn.ModuleList([nn.RMSNorm(self.backbone_hidden_size) for _ in range(self.backbone_hc_mult)])

    def forward(self, hidden_states, input_ids):
        """Compute this layer's Engram residual.

        Args:
            hidden_states: Tensor of shape [B, L, HC_MULT, D].
            input_ids: Original token ids of shape [B, L], as a torch tensor or
                a numpy array.

        Returns:
            Tensor of shape [B, L, HC_MULT, D]: the gated value plus the
            ShortConv of it.
        """
        # Hashing is implemented in numpy, so bring the ids to host memory.
        if isinstance(input_ids, torch.Tensor):
            input_ids_np = input_ids.detach().cpu().numpy()
        else:
            input_ids_np = input_ids
        # The embeddings are combined with hidden_states below, so the hash ids
        # must live on hidden_states' device. This holds even when input_ids
        # arrive as a numpy array or a CPU tensor.
        device = hidden_states.device
        # Hash only this layer: NgramHashMapping.hash() would compute, then
        # discard, the hashes of every other layer in config.layer_ids.
        compressed_ids = self.hash_mapping.compressed_tokenizer(input_ids_np)
        layer_hashes = self.hash_mapping._get_ngram_hashes(compressed_ids, layer_id=self.layer_id)
        hash_input_ids = torch.from_numpy(layer_hashes).to(device)
        # [B, L, num_heads, D_head] -> [B, L, num_heads * D_head]
        embeddings = self.multi_head_embedding(hash_input_ids).flatten(start_dim=-2)
        gates = []
        for hc_idx in range(self.backbone_hc_mult):
            key = self.key_projs[hc_idx](embeddings)
            normed_key = self.norm1[hc_idx](key)
            query = hidden_states[:, :, hc_idx, :]
            normed_query = self.norm2[hc_idx](query)
            # Scaled dot product of key and query at every position: [B, L].
            gate = (normed_key * normed_query).sum(dim=-1) / math.sqrt(self.backbone_hidden_size)
            # Signed square root damps large-magnitude logits before the sigmoid.
            gate = gate.abs().clamp_min(1e-6).sqrt() * gate.sign()
            gate = gate.sigmoid().unsqueeze(-1)
            gates.append(gate)
        gates = torch.stack(gates, dim=2)  # [B, L, HC_MULT, 1]
        # Warmup logic: if current step < warmup_steps, force gates to 1 (pass-through)
        if _GLOBAL_STEP < self.config.warmup_steps:
            gates = torch.ones_like(gates)
        elif _GLOBAL_STEP < self.config.warmup_steps + self.config.soft_constraint_steps:
            gates = gates.clamp_min(0.1)
        # [B, L, HC_MULT, 1] * [B, L, 1, D] -> [B, L, HC_MULT, D]
        value = gates * self.value_proj(embeddings).unsqueeze(2)
        output = value + self.short_conv(value)
        return output
class TransformerBlock(nn.Module):
    """Mocked backbone layer: optional Engram, then identity attention and MoE.

    ``attn`` and ``moe`` are identity stand-ins, so with their residual
    connections each of them simply doubles the hidden state. An Engram
    module is built only for layers listed in ``engram_cfg.layer_ids``.
    """

    def __init__(self, layer_id):
        super().__init__()
        # Placeholders for the real attention / MoE sublayers.
        self.attn = lambda x: x
        self.moe = lambda x: x
        wants_engram = layer_id in engram_cfg.layer_ids
        self.engram = (
            Engram(
                layer_id=layer_id,
                config=engram_cfg,
                backbone_hidden_size=backbone_config.hidden_size,
                backbone_hc_mult=backbone_config.hc_mult,
            )
            if wants_engram
            else None
        )

    def forward(self, input_ids, hidden_states):
        """Apply the optional Engram residual, then the attention and MoE residuals.

        hidden_states is [B, L, HC_MULT, D]; input_ids ([B, L]) is only
        forwarded to Engram.
        """
        if self.engram is not None:
            hidden_states = self.engram(hidden_states=hidden_states, input_ids=input_ids) + hidden_states
        for sublayer in (self.attn, self.moe):
            hidden_states = sublayer(hidden_states) + hidden_states
        return hidden_states
if __name__ == '__main__':
    # Mock model: token embedding -> num_layers TransformerBlocks -> LM head,
    # built in that order.
    token_embedding = nn.Embedding(backbone_config.vocab_size, backbone_config.hidden_size)
    blocks = [TransformerBlock(layer_id=layer_id) for layer_id in range(backbone_config.num_layers)]
    lm_head = nn.Linear(backbone_config.hidden_size, backbone_config.vocab_size)
    LLM = [token_embedding, *blocks, lm_head]

    text = "Only Alexander the Great could tame the horse Bucephalus."
    tokenizer = AutoTokenizer.from_pretrained(engram_cfg.tokenizer_name_or_path, trust_remote_code=True)
    input_ids = tokenizer(text, return_tensors='pt').input_ids
    B, L = input_ids.shape

    hidden_states = LLM[0](input_ids)
    ## mock hyper-connection: replicate the single stream -> [B, L, HC_MULT, D]
    hidden_states = hidden_states.unsqueeze(2).expand(-1, -1, backbone_config.hc_mult, -1)
    for block in LLM[1:-1]:
        hidden_states = block(input_ids=input_ids, hidden_states=hidden_states)
    ## mock hyper-connection: keep only the first stream for the LM head
    hidden_states = hidden_states[:, :, 0, :]
    output = LLM[-1](hidden_states)

    print("✅ Forward Complete!")
    print(f"{input_ids.shape=}\n{output.shape=}")