# test_memory.py
import torch
import argparse

from memory_utils import load_memory_state, process_text_and_update_memory, load_model_from_checkpoint


def decode_token(token):
    return chr(max(32, token))


def decode_tokens(tokens):
    return ''.join(map(decode_token, tokens))
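
# Note: tokens here are plain character codes (ord/chr), so e.g. decode_token(72)
# returns 'H', and any code below 32 (a control character) is rendered as a space.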

def test_memory_recall(
    model,
    context_text: str,
    query_text: str,
    memory_path: str = "memory_state.pt",
    chunk_size: int = 512
):
    """Test the model's ability to recall information from context."""
    print("\nProcessing context to build memory...")
    print("-" * 50)
    print("Context:", context_text[:100], "...")

    # First pass: process the context and save the resulting memory state
    _, _ = process_text_and_update_memory(
        model,
        context_text,
        chunk_size=chunk_size,
        save_memory=True,
        memory_path=memory_path
    )

    # Clear the model's live memory state so recall must come from the saved file
    for attn, _ in model.layers:
        if hasattr(attn, 'neural_mem'):
            mem = attn.neural_mem
            if mem is not None:
                mem.previous_state = None

    print("\nMemory state saved and cleared")
    print("-" * 50)

    # Load the saved memory state back into the model
    load_memory_state(model, memory_path)

    print("\nTesting recall with query...")
    print("-" * 50)
    print("Query:", query_text)

    # Encode the query as character codes and run it on the model's own device
    device = next(model.parameters()).device
    query_tokens = torch.tensor([[ord(c) for c in query_text]], device=device)
    output = model(query_tokens)

    # Inspect the distribution over the predicted next token
    predictions = output[0, -1].softmax(dim=-1).topk(5)
    print("\nTop 5 predicted next tokens:")
    for prob, idx in zip(predictions.values, predictions.indices):
        print(f"{decode_token(idx.item())}: {prob.item():.3f}")

    return predictions

def main():
    parser = argparse.ArgumentParser(description='Test neural memory recall')
    parser.add_argument('--checkpoint', type=str, required=True, help='Path to model checkpoint')
    parser.add_argument('--context', type=str, required=True, help='Context text to store in memory')
    parser.add_argument('--query', type=str, required=True, help='Query text to test recall')
    parser.add_argument('--memory_path', type=str, default='memory_state.pt', help='Path to save/load memory state')
    parser.add_argument('--chunk_size', type=int, default=512, help='Size of text chunks to process')
    args = parser.parse_args()

    # Load the model from the checkpoint (load_model_from_checkpoint is imported from memory_utils above)
    checkpoint = torch.load(args.checkpoint, map_location='cuda' if torch.cuda.is_available() else 'cpu')
    model = load_model_from_checkpoint(checkpoint)
    model.eval()

    # Run the memory recall test
    test_memory_recall(
        model,
        args.context,
        args.query,
        args.memory_path,
        args.chunk_size
    )


if __name__ == '__main__':
    main()
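
# A minimal usage sketch; the checkpoint path and texts below are illustrative,
# not files shipped with this repo:
#
#   python test_memory.py \
#       --checkpoint checkpoints/model.pt \
#       --context "The secret passphrase is mangrove." \
#       --query "The secret passphrase is" \
#       --memory_path memory_state.pt \
#       --chunk_size 512
#
# The context is processed once to build and persist the neural memory state;
# the query is then run against the reloaded state, and the top-5 next-token
# predictions are printed.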