Skip to content

Commit 1a186c2

Browse files
authored
Removed experimental features from main (#31)
* Removed treePool * Updated version * Added tests for chat models
1 parent 44bfe35 commit 1a186c2

File tree

6 files changed

+25
-391
lines changed

6 files changed

+25
-391
lines changed

bolift/__init__.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -8,5 +8,5 @@
88
print("GPR Packages not installed. Do `pip install bolift[gpr]` to install them")
99
from .asktellRidgeRegression import AskTellRidgeKernelRegression
1010
from .asktellNearestNeighbor import AskTellNearestNeighbor
11-
from .pool import Pool, TreeNode, TreePool
11+
from .pool import Pool
1212
from .tool import BOLiftTool

bolift/asktell.py

+13-40
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
upper_confidence_bound,
1616
greedy,
1717
)
18-
from .pool import Pool, TreePool
18+
from .pool import Pool
1919
from langchain.prompts.few_shot import FewShotPromptTemplate
2020
from langchain.prompts.prompt import PromptTemplate
2121
from langchain.vectorstores import FAISS, Chroma
@@ -53,7 +53,6 @@ def __init__(
5353
model: str = "text-curie-001",
5454
temperature: Optional[float] = None,
5555
prefix: Optional[str] = None,
56-
inv_prefix: Optional[str] = None,
5756
x_formatter: Callable[[str], str] = lambda x: x,
5857
y_formatter: Callable[[float], str] = lambda y: f"{y:0.2f}",
5958
y_name: str = "output",
@@ -94,7 +93,6 @@ def __init__(
9493
self._prompt_template = prompt_template
9594
self._suffix = suffix
9695
self._prefix = prefix
97-
self._inv_prefix = inv_prefix
9896
self._model = model
9997
self._example_count = 0
10098
self._temperature = temperature
@@ -130,11 +128,7 @@ def _setup_inv_llm(self, model: str, temperature: Optional[float] = None):
130128
temperature=0.05 if temperature is None else temperature,
131129
)
132130

133-
def _setup_inverse_prompt(self,
134-
example: Dict,
135-
prefix: Optional[str] = None):
136-
if prefix is None:
137-
prefix = ""
131+
def _setup_inverse_prompt(self, example: Dict):
138132
prompt_template = PromptTemplate(
139133
input_variables=["x", "y", "y_name", "x_name"],
140134
template="If {y_name} is {y}, then {x_name} is @@@\n{x}###",
@@ -163,7 +157,6 @@ def _setup_inverse_prompt(self,
163157
example_prompt=prompt_template,
164158
example_selector=example_selector,
165159
suffix="If {y_name} is {y}, then {x_name} is @@@",
166-
prefix=prefix,
167160
input_variables=["y", "y_name", "x_name"],
168161
)
169162

@@ -270,7 +263,7 @@ def tell(self, x: str, y: float, alt_ys: Optional[List[float]] = None) -> None:
270263
self.prompt = self._setup_prompt(
271264
example_dict, self._prompt_template, self._suffix, self._prefix
272265
)
273-
self.inv_prompt = self._setup_inverse_prompt(inv_example, self._inv_prefix)
266+
self.inv_prompt = self._setup_inverse_prompt(inv_example)
274267
self.llm = self._setup_llm(self._model, self._temperature)
275268
self.inv_llm = self._setup_inv_llm(self._model, self._temperature)
276269
self._ready = True
@@ -321,7 +314,7 @@ def predict(self, x: str) -> Union[Tuple[float, float], List[Tuple[float, float]
321314
self.prompt = self._setup_prompt(
322315
None, self._prompt_template, self._suffix, self._prefix
323316
)
324-
self.inv_prompt = self._setup_inverse_prompt(None, self._inv_prefix)
317+
self.inv_prompt = self._setup_inverse_prompt(None)
325318
self.llm = self._setup_llm(self._model)
326319
self._ready = True
327320

@@ -367,7 +360,7 @@ def predict(self, x: str) -> Union[Tuple[float, float], List[Tuple[float, float]
367360

368361
def ask(
369362
self,
370-
possible_x: Union[Pool, List[str], TreePool, OrderedDict[str, Any]],
363+
possible_x: Union[Pool, List[str]],
371364
aq_fxn: str = "upper_confidence_bound",
372365
k: int = 1,
373366
inv_filter: int = 16,
@@ -383,16 +376,14 @@ def ask(
383376
aq_fxn: Acquisition function to use.
384377
k: Number of x values to return.
385378
inv_filter: Reduce pool size to this number with inverse model. If 0, not used
386-
aug_random_filter: Add this many random examples to the pool to increase diversity after reducing pool with inverse model
379+
aug_random_filter: Add this man y random examples to the pool to increase diversity after reducing pool with inverse model
387380
_lambda: Lambda value to use for UCB
388381
Return:
389382
The selected x values, their acquisition function values, and the predicted y modes.
390383
Sorted by acquisition function value (descending)
391384
"""
392385
if type(possible_x) == type([]):
393386
possible_x = Pool(possible_x, self.format_x)
394-
elif type(possible_x) == type(OrderedDict()):
395-
possible_x = TreePool(possible_x, self._prompt_template.prompt, self.format_x) #need to input the string for the prompt template
396387

397388
if aq_fxn == "probability_of_improvement":
398389
aq_fxn = probability_of_improvement
@@ -416,33 +407,15 @@ def ask(
416407
else:
417408
best = np.max(self._ys)
418409

419-
if isinstance(possible_x, Pool):
420-
if inv_filter+aug_random_filter < len(possible_x):
421-
possible_x_l = []
422-
print(inv_filter, aug_random_filter)
423-
if inv_filter:
424-
approx_x = self.inv_predict(best * np.random.normal(1.0, 0.05))
425-
possible_x_l.extend(possible_x.approx_sample(approx_x, inv_filter))
426-
if aug_random_filter:
427-
possible_x_l.extend(possible_x.sample(aug_random_filter))
428-
else:
429-
possible_x_l = list(possible_x)
430-
elif isinstance(possible_x, TreePool):
410+
if inv_filter+aug_random_filter < len(possible_x):
431411
possible_x_l = []
432-
while len(possible_x_l) < k:
433-
node = possible_x._root
434-
while not node.is_leaf():
435-
partial_possible_x = [possible_x.partial_format_prompt(child.get_branch()) for child in node.get_children_list()]
436-
node_retriever = dict(zip(partial_possible_x, node.get_children_list()))
437-
selected_child = self._ask(partial_possible_x, best, aq_fxn, 1)
438-
selected_child = selected_child[0][0]
439-
node = node_retriever[selected_child]
440-
selected = possible_x.format_prompt(node.get_branch())
441-
while selected in possible_x_l:
442-
selected = possible_x.sample(1)[0]
443-
possible_x_l.append(selected)
412+
if inv_filter:
413+
approx_x = self.inv_predict(best * np.random.normal(1.0, 0.05))
414+
possible_x_l.extend(possible_x.approx_sample(approx_x, inv_filter))
415+
if aug_random_filter:
416+
possible_x_l.extend(possible_x.sample(aug_random_filter))
444417
else:
445-
raise ValueError("Unknown pool type")
418+
possible_x_l = list(possible_x)
446419

447420
results = self._ask(possible_x_l, best, aq_fxn, k)
448421
if len(results[0]) == 0 and len(possible_x_l) != 0:

0 commit comments

Comments
 (0)