15
15
upper_confidence_bound ,
16
16
greedy ,
17
17
)
18
- from .pool import Pool , TreePool
18
+ from .pool import Pool
19
19
from langchain .prompts .few_shot import FewShotPromptTemplate
20
20
from langchain .prompts .prompt import PromptTemplate
21
21
from langchain .vectorstores import FAISS , Chroma
@@ -53,7 +53,6 @@ def __init__(
53
53
model : str = "text-curie-001" ,
54
54
temperature : Optional [float ] = None ,
55
55
prefix : Optional [str ] = None ,
56
- inv_prefix : Optional [str ] = None ,
57
56
x_formatter : Callable [[str ], str ] = lambda x : x ,
58
57
y_formatter : Callable [[float ], str ] = lambda y : f"{ y :0.2f} " ,
59
58
y_name : str = "output" ,
@@ -94,7 +93,6 @@ def __init__(
94
93
self ._prompt_template = prompt_template
95
94
self ._suffix = suffix
96
95
self ._prefix = prefix
97
- self ._inv_prefix = inv_prefix
98
96
self ._model = model
99
97
self ._example_count = 0
100
98
self ._temperature = temperature
@@ -130,11 +128,7 @@ def _setup_inv_llm(self, model: str, temperature: Optional[float] = None):
130
128
temperature = 0.05 if temperature is None else temperature ,
131
129
)
132
130
133
- def _setup_inverse_prompt (self ,
134
- example : Dict ,
135
- prefix : Optional [str ] = None ):
136
- if prefix is None :
137
- prefix = ""
131
+ def _setup_inverse_prompt (self , example : Dict ):
138
132
prompt_template = PromptTemplate (
139
133
input_variables = ["x" , "y" , "y_name" , "x_name" ],
140
134
template = "If {y_name} is {y}, then {x_name} is @@@\n {x}###" ,
@@ -163,7 +157,6 @@ def _setup_inverse_prompt(self,
163
157
example_prompt = prompt_template ,
164
158
example_selector = example_selector ,
165
159
suffix = "If {y_name} is {y}, then {x_name} is @@@" ,
166
- prefix = prefix ,
167
160
input_variables = ["y" , "y_name" , "x_name" ],
168
161
)
169
162
@@ -270,7 +263,7 @@ def tell(self, x: str, y: float, alt_ys: Optional[List[float]] = None) -> None:
270
263
self .prompt = self ._setup_prompt (
271
264
example_dict , self ._prompt_template , self ._suffix , self ._prefix
272
265
)
273
- self .inv_prompt = self ._setup_inverse_prompt (inv_example , self . _inv_prefix )
266
+ self .inv_prompt = self ._setup_inverse_prompt (inv_example )
274
267
self .llm = self ._setup_llm (self ._model , self ._temperature )
275
268
self .inv_llm = self ._setup_inv_llm (self ._model , self ._temperature )
276
269
self ._ready = True
@@ -321,7 +314,7 @@ def predict(self, x: str) -> Union[Tuple[float, float], List[Tuple[float, float]
321
314
self .prompt = self ._setup_prompt (
322
315
None , self ._prompt_template , self ._suffix , self ._prefix
323
316
)
324
- self .inv_prompt = self ._setup_inverse_prompt (None , self . _inv_prefix )
317
+ self .inv_prompt = self ._setup_inverse_prompt (None )
325
318
self .llm = self ._setup_llm (self ._model )
326
319
self ._ready = True
327
320
@@ -367,7 +360,7 @@ def predict(self, x: str) -> Union[Tuple[float, float], List[Tuple[float, float]
367
360
368
361
def ask (
369
362
self ,
370
- possible_x : Union [Pool , List [str ], TreePool , OrderedDict [ str , Any ] ],
363
+ possible_x : Union [Pool , List [str ]],
371
364
aq_fxn : str = "upper_confidence_bound" ,
372
365
k : int = 1 ,
373
366
inv_filter : int = 16 ,
@@ -383,16 +376,14 @@ def ask(
383
376
aq_fxn: Acquisition function to use.
384
377
k: Number of x values to return.
385
378
inv_filter: Reduce pool size to this number with inverse model. If 0, not used
386
- aug_random_filter: Add this many random examples to the pool to increase diversity after reducing pool with inverse model
379
+ aug_random_filter: Add this man y random examples to the pool to increase diversity after reducing pool with inverse model
387
380
_lambda: Lambda value to use for UCB
388
381
Return:
389
382
The selected x values, their acquisition function values, and the predicted y modes.
390
383
Sorted by acquisition function value (descending)
391
384
"""
392
385
if type (possible_x ) == type ([]):
393
386
possible_x = Pool (possible_x , self .format_x )
394
- elif type (possible_x ) == type (OrderedDict ()):
395
- possible_x = TreePool (possible_x , self ._prompt_template .prompt , self .format_x ) #need to input the string for the prompt template
396
387
397
388
if aq_fxn == "probability_of_improvement" :
398
389
aq_fxn = probability_of_improvement
@@ -416,33 +407,15 @@ def ask(
416
407
else :
417
408
best = np .max (self ._ys )
418
409
419
- if isinstance (possible_x , Pool ):
420
- if inv_filter + aug_random_filter < len (possible_x ):
421
- possible_x_l = []
422
- print (inv_filter , aug_random_filter )
423
- if inv_filter :
424
- approx_x = self .inv_predict (best * np .random .normal (1.0 , 0.05 ))
425
- possible_x_l .extend (possible_x .approx_sample (approx_x , inv_filter ))
426
- if aug_random_filter :
427
- possible_x_l .extend (possible_x .sample (aug_random_filter ))
428
- else :
429
- possible_x_l = list (possible_x )
430
- elif isinstance (possible_x , TreePool ):
410
+ if inv_filter + aug_random_filter < len (possible_x ):
431
411
possible_x_l = []
432
- while len (possible_x_l ) < k :
433
- node = possible_x ._root
434
- while not node .is_leaf ():
435
- partial_possible_x = [possible_x .partial_format_prompt (child .get_branch ()) for child in node .get_children_list ()]
436
- node_retriever = dict (zip (partial_possible_x , node .get_children_list ()))
437
- selected_child = self ._ask (partial_possible_x , best , aq_fxn , 1 )
438
- selected_child = selected_child [0 ][0 ]
439
- node = node_retriever [selected_child ]
440
- selected = possible_x .format_prompt (node .get_branch ())
441
- while selected in possible_x_l :
442
- selected = possible_x .sample (1 )[0 ]
443
- possible_x_l .append (selected )
412
+ if inv_filter :
413
+ approx_x = self .inv_predict (best * np .random .normal (1.0 , 0.05 ))
414
+ possible_x_l .extend (possible_x .approx_sample (approx_x , inv_filter ))
415
+ if aug_random_filter :
416
+ possible_x_l .extend (possible_x .sample (aug_random_filter ))
444
417
else :
445
- raise ValueError ( "Unknown pool type" )
418
+ possible_x_l = list ( possible_x )
446
419
447
420
results = self ._ask (possible_x_l , best , aq_fxn , k )
448
421
if len (results [0 ]) == 0 and len (possible_x_l ) != 0 :
0 commit comments