-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathmain.py
35 lines (23 loc) · 1.09 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
import utils
from data import Dataset
from llms import LLM
def run(model_name='llama13bchat', prompt_type='1',
        trajectory_mode='trajectory_split', dataset_name='nyc',
        historical_stays=15, context_stays=6, save_dir='data/processed'):
    """Build the evaluation data for one dataset and construct an LLM on it.

    Called with no arguments, this reproduces the original hard-coded
    experiment: the ``'nyc'`` dataset in ``'trajectory_split'`` mode with
    15 historical stays and 6 context stays, saved under
    ``'data/processed'``, paired with prompt type ``'1'`` and model
    ``'llama13bchat'``.

    Other configurations previously kept as commented-out code::

        run(model_name='llama7bchat', prompt_type='5')
        run(model_name='llama70bchat', prompt_type='5')
        run(model_name='gpt35turbo', prompt_type='5')

    Args:
        model_name: Model identifier forwarded to ``LLM``.
        prompt_type: Prompt identifier forwarded to ``LLM``.
        trajectory_mode: Passed to BOTH ``Dataset`` and ``LLM`` so the data
            split and the model always agree.
        dataset_name: Forwarded to ``Dataset``.
        historical_stays: Forwarded to ``Dataset``.
        context_stays: Forwarded to ``Dataset``.
        save_dir: Forwarded to ``Dataset``.

    Returns:
        The constructed ``LLM`` instance. The original discarded it;
        returning it lets callers inspect it, and existing callers that
        ignore the return value are unaffected.
    """
    # Human-readable labels that reproduce the original progress messages;
    # unknown model names fall back to the raw identifier.
    model_labels = {
        'llama7bchat': '7 B',
        'llama13bchat': '13 B',
        'llama70bchat': '70 B',
        'gpt35turbo': 'gpt 3.5 turbo',
    }

    dataset = Dataset(dataset_name=dataset_name,
                      trajectory_mode=trajectory_mode,
                      historical_stays=historical_stays,
                      context_stays=context_stays,
                      save_dir=save_dir)
    test_dictionary, true_locations = dataset.get_generated_datasets()

    print(f'Running {model_labels.get(model_name, model_name)}')
    # NOTE(review): no method is called on the LLM object here, so the
    # constructor presumably performs the run itself — confirm in llms.py.
    llm = LLM(test_dictionary, true_locations, prompt_type=prompt_type,
              model_name=model_name, trajectory_mode=trajectory_mode)
    return llm
# Script entry point: run the default experiment only when this file is
# executed directly, not when it is imported as a module.
if __name__ == '__main__':
    run()