main.py
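
"""Training and evaluation entry point for the Exchange-of-Thought agent pipeline.

Loads examples from ./data, builds DSPy agents, optionally compiles them with a
teleprompter (BootstrapFewShot or MIPROv2), evaluates the result on a held-out
validation split with Weave, and logs metrics, token usage and cost to
Weights & Biases.

Assumed invocation (requires WANDB_API_KEY in the environment):
    python main.py
"""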
import os
import certifi
os.environ['SSL_CERT_FILE'] = certifi.where()
import asyncio
import wandb
import weave
import pathlib
import time
from typing import Literal, Callable, List
import dspy
from colorama import init
from dspy.evaluate import Evaluate
from dspy.teleprompt import BootstrapFewShot
from sklearn.model_selection import train_test_split
from dataclasses import asdict
import uuid
# Ensure the models/, data/ and logs/ output directories exist
pathlib.Path("models").mkdir(parents=True, exist_ok=True)
pathlib.Path("data").mkdir(parents=True, exist_ok=True)
pathlib.Path("logs").mkdir(parents=True, exist_ok=True)
from src.config import Config, load_config
from src.util import Persona
from src.agents import Agent, AdvancedAgent
from src.dataloader import DataManager
from src.evaluation import EvaluationManager
from src.eot import ExchangeOfThought
from src.util import LanguageModel, PrefixedChatAdapter
# Initialize colorama
init(autoreset=True)
# CONSTANTS
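# DEBUG routes logging to the *-debug W&B/Weave projects and is passed to the
# data loader (presumably selecting a smaller debug subset); SEED fixes the
# train/val split; API picks the backend for the LanguageModel wrapper;
# MAX_TOKEN caps generation length; ID tags this run and its saved model file.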
DEBUG: bool = True
SEED: int = 77777
API: Literal['lambda', 'openai'] = 'lambda'
MAX_TOKEN: int = 100
ID = uuid.uuid4().hex[:8]
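
# Configure the global DSPy language model and chat adapter shared by all agents.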
lm_wrapper = LanguageModel(max_tokens=MAX_TOKEN, service=API)
custom_adapter = PrefixedChatAdapter()
dspy.configure(lm=lm_wrapper.lm, adapter=custom_adapter)
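
# Build a weave.Evaluation over the validation rows, wrap the DSPy module in a
# weave.op so each prediction is traced, run the async evaluation, and push the
# aggregate result to the current W&B run.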
def evaluate_with_weave(evaluation_dataset: List[dspy.Example], model: dspy.Module, scorer: Callable) -> None:
    dataset = [dspy.Example(**row).toDict() for row in evaluation_dataset]
    weave_eval = weave.Evaluation(
        dataset=dataset,
        scorers=[scorer],
        evaluation_name=f"run-{ID}",
    )

    @weave.op()
    def my_model(QuestionText, AnswerText, ConstructName, SubjectName, CorrectAnswer):
        print("Evaluating ...")
        model_output = model(QuestionText, AnswerText, ConstructName, SubjectName, CorrectAnswer)
        return model_output

    result = asyncio.run(weave_eval.evaluate(my_model))
    wandb.run.config.update({'weave.run.name': ID})
    wandb.log(result, step=1)
    print("Weave Result: " + str(result))
def main(args: Config):
    weave.init(project_name="llma-agents" if not DEBUG else "llma-agents-debug")
    start = time.time()
    # wandb.init(project="llma-agents" if not DEBUG else "llma-agents-debug")

    examples = DataManager.get_examples(pathlib.Path("data"), debug=DEBUG)
    # Keep 80% of the examples as validation data, as suggested in
    # https://dspy.ai/learn/optimization/overview/
    train_data, val_data = train_test_split(examples, test_size=0.8, random_state=SEED)

    # Set up agents (persona prompts are collected below and logged to W&B)
    agent_a = Agent(name="Agent A", agent_type=args.ExchangeOfThought.baseagent, persona_promt=None)
    agent_b = Agent(name="Agent B", agent_type=args.ExchangeOfThought.baseagent, persona_promt=None)
    agent_c = Agent(name="Agent C", agent_type=args.ExchangeOfThought.baseagent, persona_promt=None)
    agent_d = Agent(name="Agent D", agent_type=args.ExchangeOfThought.baseagent, persona_promt=None)
    agent_e = Agent(name="Agent E", agent_type=args.ExchangeOfThought.baseagent, persona_promt=None)
    # agent_a = AdvancedAgent(name="Agent A", persona_promt=Persona.AGENT_A_new)
    # agent_b = AdvancedAgent(name="Agent B", persona_promt=Persona.AGENT_B_new)
    # agent_c = AdvancedAgent(name="Agent C", persona_promt=Persona.AGENT_C_new)
    # agent_d = AdvancedAgent(name="Agent D", persona_promt=Persona.AGENT_D_new)
    # agent_e = AdvancedAgent(name="Agent E", persona_promt=Persona.AGENT_E_new)
    if args.ExchangeOfThought.mode != "single":
        predict = ExchangeOfThought(
            agent_a, agent_b, agent_c, agent_d, agent_e,
            rounds=args.ExchangeOfThought.rounds, mode=args.ExchangeOfThought.mode)
        persona_prompts = {
            "Agent A Persona": agent_a.prefix_promt,
            "Agent B Persona": agent_b.prefix_promt,
            "Agent C Persona": agent_c.prefix_promt,
            "Agent D Persona": agent_d.prefix_promt,
            "Agent E Persona": agent_e.prefix_promt,
            "debug": DEBUG
        }
    else:
        predict = Agent(name="Agent A", agent_type=args.ExchangeOfThought.baseagent, persona_promt=None)
        persona_prompts = {
            "Agent A Persona": None,
            "debug": DEBUG
        }
    wandb.config.update(persona_prompts)
    eval_manager = EvaluationManager(retrive_method=args.Dspy.evaluation.type)

    # Compile the predictor with the configured teleprompter, or load a previously compiled model
    print("Start training ...")
    if args.Dspy.telepropmter.type == "BootstrapFewShot":
        teleprompter = BootstrapFewShot(metric=eval_manager.metric_vector_search_weave, max_labeled_demos=3)
        compiled_predictor = teleprompter.compile(predict, trainset=train_data)
        compiled_predictor.save(pathlib.Path("models") / f'compiled_model-{ID}.dspy')
        print("Finished training BootstrapFewShot...")
    elif args.Dspy.telepropmter.type == "MIPROv2":
        teleprompter = dspy.MIPROv2(metric=eval_manager.metric_vector_search_weave, auto='medium', num_threads=6)
        compiled_predictor = teleprompter.compile(predict, trainset=train_data, requires_permission_to_run=False)
        compiled_predictor.save(pathlib.Path("models") / f'compiled_model-{ID}.dspy')
        print("Finished training MIPROv2...")
    else:
        predict.load(pathlib.Path("models") / f"compiled_model-{args.Dspy.telepropmter.type}.dspy")
        # Alias so the evaluation below works when loading a precompiled model
        compiled_predictor = predict
        print("Finished loading")
    # --- DO NOT CHANGE anything below this line ---
    # evaluate_with_weave(val_data, predict, eval_manager.metric_vector_search_weave)
    evaluate_with_weave(val_data, compiled_predictor, eval_manager.metric_vector_search_weave)

    end = time.time()
    usage = lm_wrapper.get_usage()
    wandb.log({
        "usage_cost_cents": usage[2],
        "input_tokens": usage[0],
        "output_tokens": usage[1],
        "time_taken_seconds": end - start,
        "dataset": {
            "train.size": len(train_data),
            "val.size": len(val_data)
        }
    })
    # wandb.save("models" / pathlib.Path(f'compiled_model-{ID}.dspy'))
    print(f"Usage cost (in cents) about {usage[2]}, Input Tokens: {usage[0]}, Output Tokens {usage[1]}")
    print("Time taken (in seconds)", end - start)
    print("Run ID: ", ID)

    wandb.finish()
    weave.finish()
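
# Entry point: log in to W&B, then either take the run configuration from
# wandb.config (e.g. when launched from a sweep) or fall back to the inline
# defaults below.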
if __name__ == "__main__":
wandb.login(key=os.getenv("WANDB_API_KEY"))
wandb.init(project="llma-agents" if not DEBUG else "llma-agents-debug", name=f"run-{ID}")
USE_WANDB_CONFIG = True
if USE_WANDB_CONFIG:
args = load_config(dict(wandb.config))
else:
args_dict = {
"ExchangeOfThought": {
"baseagent": "reasoning",
"mode": "multi_4",
"rounds": 1,
},
"Dspy": {
"telepropmter": { # Nested TelepropmterConfig
"type": "BootstrapFewShot" #Literal['BootstrapFewShot', 'MIPROv2'] # Example integer value for TelepropmterConfig.max_labeled_demos
},
"evaluation": {
"type": "multi"
},
}
}
args = load_config(args_dict)
main(args)