Skip to content

Commit

Permalink
minor bug fixing
Browse files (browse the repository at this point in the history)
  • Loading branch information
ritugala committed Mar 30, 2024
1 parent 379a503 commit acd6e7a
Show file tree
Hide file tree
Showing 5 changed files with 94 additions and 44 deletions.
49 changes: 49 additions & 0 deletions dump.txt
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,49 @@
Input:
Output:
Input: Q: "strong themes of familial ties and spirituality that are powerful and moving without stooping to base melodrama "
A:, Q: "i 'd much rather watch teens poking their genitals into fruit pies "
A:, Q: "high-concept "
A:, Q: is both convincing and radiant
A:
Output: positive, negative, positive, positive
Input: Q: is n't merely offensive
A:, Q: "berling and b\u00e9art ... continue to impress , and isabelle huppert ... again shows uncanny skill in getting under the skin of her characters "
A:, Q: the film is , arguably , the most accomplished work to date from hong kong 's versatile stanley kwan .
A:, Q: "rubbo 's humorously tendentious intervention into the who-wrote-shakespeare controversy . "
A:, Q: "comes from its soul 's - eye view of how well-meaning patronizing masked a social injustice , at least as represented by this case "
A:, Q: we can tell what it is supposed to be , but ca n't really call it a work of art .
A:, Q: "as with so many merchandised-to-the-max movies of this type , more time appears to have gone into recruiting the right bands for the playlist and the costuming of the stars than into the script , whi...
A:, Q: a compelling allegory about the last days of germany 's democratic weimar republic .
A:, Q: easy feel-good sentiments
A:, Q: 're never quite sure where self-promotion ends and the truth begins
A:, Q: a little melodramatic , but with enough hope to keep you engaged .
A:, Q: rush to the theater
A:
Output: negative, positive, positive, positive, positive, negative, negative, positive, positive, negative, positive, positive
Input: Q: "lowbrow "
A:, Q: "wasting away "
A:, Q: "blade 2 is definitely a cut above the rest . "
A:, Q: "satisfyingly scarifying , fresh and old-fashioned at the same time "
A:, Q: "the buoyant energy level "
A:, Q: "we 've liked klein 's other work but rollerball left us cold . "
A:, Q: "raise audience 's spirits "
A:, Q: are many tense scenes in trapped
A:, Q: is having fun with it all
A:, Q: "farrelly brothers-style , down-and-dirty laugher "
A:, Q: "as shameful "
A:, Q: a dashing and absorbing
A:, Q: "shimmering cinematography that lends the setting the ethereal beauty of an asian landscape painting "
A:
Output: negative, negative, positive, positive, positive, negative, positive, positive, positive, positive, negative, positive, positive
Input: Q: "filter out the complexity "
A:, Q: "of classic romantic comedy to which it aspires "
A:, Q: "a hilarious ode to middle america and "
A:, Q: "sufficiently "
A:, Q: "plunging deeper "
A:, Q: "you 'd grab your kids and run and then probably call the police . "
A:, Q: "faith , love and power "
A:, Q: "there is no doubt that krawczyk deserves a huge amount of the credit for the film 's thoroughly winning tone . "
A:, Q: "the pretensions -- and disposable story -- sink the movie . "
A:, Q: "veers like a drunken driver through heavy traffic "
A:
Output: positive, positive, positive, positive, positive, negative, positive, positive, negative, negative
17 changes: 10 additions & 7 deletions prompt2model/dataset_retriever/description_dataset_retriever.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -53,7 +53,7 @@ def __init__(
allow_gated_datasets=False,
auto_transform_data: bool = False,
num_points_to_transform: int = 3000,
max_allowed_failed_transforms: int = 1000,
max_allowed_failed_transforms: int = None,
max_datasets_to_choose: int=3,
num_votes = 5
):
Expand Down Expand Up @@ -87,6 +87,8 @@ def __init__(
self.auto_transform_data = auto_transform_data
self.num_points_to_transform = num_points_to_transform
self.max_allowed_failed_transforms = max_allowed_failed_transforms
if max_allowed_failed_transforms is None:
self.max_allowed_failed_transforms = num_points_to_transform//3
self.max_datasets_to_choose = max_datasets_to_choose
self.num_votes = num_votes
self.initialize_search_index()
Expand Down Expand Up @@ -444,7 +446,7 @@ def make_dataset_from_samples(
return datasets.DatasetDict(dataset_dict)


def get_rerank_with_highest_votes(self, prompt, infos_dict, num_votes=3):
def get_rerank_with_highest_votes(self, prompt, infos_dict):
"""
Returns the dataset/config name with the highest number of votes based on the given prompt.
Expand All @@ -458,7 +460,7 @@ def get_rerank_with_highest_votes(self, prompt, infos_dict, num_votes=3):
"""
voting = []

for _ in range(num_votes):
for _ in range(self.num_votes):
curr_name = parse_prompt_to_fields(prompt, module_name="rerank")
if curr_name not in infos_dict:
logger.warning("LLM hallucinated dataset/config name: %s", curr_name)
Expand All @@ -473,7 +475,7 @@ def get_rerank_with_highest_votes(self, prompt, infos_dict, num_votes=3):
return chosen_one


def rerank_datasets(self, datasets_info_dict:dict, prompt_spec: PromptSpec, num_votes) -> str,str:
def rerank_datasets(self, datasets_info_dict:dict, prompt_spec: PromptSpec) -> tuple[str, str] | None:
"""Rerank datasets based on relevance to a given prompt specification.
This function takes a list of datasets and a prompt specification,
Expand All @@ -499,7 +501,7 @@ def rerank_datasets(self, datasets_info_dict:dict, prompt_spec: PromptSpec, num_
prompt_spec.instruction, prompt_spec.examples, datasets_info_dict
)

dataset_name = self.get_rerank_with_highest_votes(dataset_selection_prompt, datasets_info_dict, num_votes)
dataset_name = self.get_rerank_with_highest_votes(dataset_selection_prompt, datasets_info_dict)
if dataset_name is None: return None, None

time.sleep(10) # To avoid rate limiting
Expand All @@ -526,6 +528,7 @@ def canonicalize_dataset_automatically(
self,
top_dataset_info: dict,
prompt_spec: PromptSpec,
num_points_to_transform=0
):
"""Automatically canonicalize dataset (instead of cli).
Expand Down Expand Up @@ -589,7 +592,7 @@ def canonicalize_dataset_automatically(
)
logger.info("Unnecessary columns removed")

dataset_transformer = PromptBasedDatasetTransformer(num_points_to_transform=self.num_points_to_transform, max_allowed_failed_transforms=self.max_allowed_failed_transforms)
dataset_transformer = PromptBasedDatasetTransformer(num_points_to_transform=num_points_to_transform, max_allowed_failed_transforms=self.max_allowed_failed_transforms)
inputs, outputs = dataset_transformer.transform_data(
prompt_spec,
dataset=full_dataset["train"]
Expand Down Expand Up @@ -635,7 +638,7 @@ def get_datasets_of_required_size(self, dataset_list, prompt_spec):

dataset_info = datasets_info_dict[dataset_name]["configs"][config_name]
canonicalized_dataset = self.canonicalize_dataset_automatically(
dataset_info, prompt_spec, self.num_points_to_transform - curr_datasets_size + self.max_allowed_failed_transforms
dataset_info, prompt_spec, self.num_points_to_transform - curr_datasets_size
)
curr_datasets_size += len(canonicalized_dataset["train"]["input_col"])
inputs += canonicalized_dataset["train"]["input_col"]
Expand Down
7 changes: 4 additions & 3 deletions prompt2model/dataset_transformer/prompt_based.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -63,7 +63,7 @@ def generate_plan(self, task_explanation, dataset, prompt_spec) -> str:
return make_single_api_request(plan_prompt, max_api_calls=100)


def generate_transform_prompts(self, task_explanation:str, prompt_spec:PromptSpec, dataset:datasets.Dataset) -> List[str]:
def generate_transform_prompts(self, task_explanation:str, dataset:datasets.Dataset, prompt_spec:PromptSpec,) -> List[str]:
transform_prompts = []
for i in range(min(self.num_points_to_transform, len(dataset))):
row = dataset[i]
Expand All @@ -78,7 +78,7 @@ async def generate_responses_async(transform_prompts):
"""
Generate responses asynchronously using the specified model.
"""
responses = await api_tools.APIAgent(model_name="azure/GPT-3-5-turbo-sweden", max_tokens=4000).generate_batch_completion(
responses = await api_tools.APIAgent(model_name="azure/GPT-3-5-turbo-chat", max_tokens=4000).generate_batch_completion(
transform_prompts,
temperature=0,
responses_per_request=1,
Expand Down Expand Up @@ -110,6 +110,7 @@ def process_responses(self, responses:list, prompt_spec) -> Tuple[List[str], Lis
- outputs: A list of transformed output strings.
"""
inputs, outputs = [], []
counter=0
for response in responses:
try:
extraction = find_and_parse_json(response, ["input", "output"], [])
Expand All @@ -128,6 +129,7 @@ def process_responses(self, responses:list, prompt_spec) -> Tuple[List[str], Lis
if counter < 2:
logger.info(f"inputs\n{str1}\n\nouputs\n{str2}")
counter += 1
counter+=1

except Exception as e:
logger.error(f"Error extracting from response: {e}")
Expand All @@ -154,7 +156,6 @@ def transform_data(self, prompt_spec, dataset: datasets.Dataset) -> datasets.Dat
curr_inputs, curr_outputs = self.process_responses(responses, prompt_spec)
inputs.extend(curr_inputs)
outputs.extend(curr_outputs)

if self.curr_failed_transforms > self.max_allowed_failed_transforms:
logger.error(f"Exceeded max allowed failed transforms: {self.curr_failed_transforms}")
break
Expand Down
3 changes: 2 additions & 1 deletion prompt2model/dataset_transformer/prompt_template.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -453,7 +453,8 @@ def construct_prompt_for_transform_data(
str: Prompt for dataset transformation.
"""

incontext_tasks = [VITAMINC]
# incontext_tasks = [VITAMINC]
incontext_tasks = []
incontext_examples = []

for incontext_task in incontext_tasks:
Expand Down
62 changes: 29 additions & 33 deletions prompt2model_demo.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -132,7 +132,7 @@ def main():

while True:
line_print("Do you want to start from scratch? (y/n)")
answer = input()
answer = "y"
if answer.lower() == "n":
if os.path.isfile("status.yaml"):
with open("status.yaml", "r") as f:
Expand All @@ -154,17 +154,17 @@ def main():
dataset_has_been_generated = status.get("dataset_has_been_generated", False)
model_has_been_trained = status.get("model_has_been_trained", False)
if not propmt_has_been_parsed:
prompt = ""
line_print(
"Enter your task description and few-shot examples (or 'done' to finish):"
)
time.sleep(2)
while True:
line = input()
if line == "done":
break
prompt += line + "\n"
line_print("Parsing prompt...")
prompt = "sentiment detection, where sentiments can be [positive, negative, neutral]"
# line_print(
# "Enter your task description and few-shot examples (or 'done' to finish):"
# )
# time.sleep(2)
# while True:
# line = input()
# if line == "done":
# break
# prompt += line + "\n"
# line_print("Parsing prompt...")
prompt_spec = PromptBasedInstructionParser(task_type=TaskType.TEXT_GENERATION)
prompt_spec.parse_from_prompt(prompt)

Expand All @@ -185,26 +185,26 @@ def main():
line_print(
"Data transformation converts retrieved data into the desired format as per the prompt." # noqa E501
)
auto_transform_data = False
while True:
line = input()
if line.lower() == "y":
auto_transform_data = True
break
elif line.lower() == "n":
auto_transform_data = False
break
else:
line_print("Invalid input. Please enter y or n.")

retriever = DescriptionDatasetRetriever()
auto_transform_data = True
# while True:
# line = input()
# if line.lower() == "y":
# auto_transform_data = True
# break
# elif line.lower() == "n":
# auto_transform_data = False
# break
# else:
# line_print("Invalid input. Please enter y or n.")



if auto_transform_data:
while True:
line_print(
"Enter the number of data points you want to transform (the remaining data points in the dataset will be discarded):" # noqa E501
)
line = input()
line = 10
try:
num_points_to_transform = int(line)
except ValueError:
Expand All @@ -215,13 +215,9 @@ def main():
continue
status["num_transform"] = num_points_to_transform
break
retrieved_dataset_dict = retriever.retrieve_dataset_dict(
prompt_spec,
auto_transform_data=True,
num_points_to_transform=num_points_to_transform,
)
else:
retrieved_dataset_dict = retriever.retrieve_dataset_dict(prompt_spec)
retriever = DescriptionDatasetRetriever(auto_transform_data=auto_transform_data, num_points_to_transform=num_points_to_transform, num_votes=1)
retrieved_dataset_dict = retriever.retrieve_dataset_dict(prompt_spec)
breakpoint()

dataset_has_been_retrieved = True
if retrieved_dataset_dict is not None:
Expand Down

0 comments on commit acd6e7a

Please sign in to comment.