Commit
adding prompt template-based solution for scoring candidates with API-based models
dotpyu committed Dec 28, 2023
1 parent 236dac1 commit 0dec2d5
Showing 14 changed files with 273 additions and 79 deletions.
36 changes: 34 additions & 2 deletions alfred/fm/ai21.py
@@ -1,10 +1,10 @@
import logging
-from typing import Optional, List
+from typing import Optional, List, Union, Tuple

import requests

from .model import APIAccessFoundationModel
-from .response import CompletionResponse
+from .response import CompletionResponse, RankedResponse

logger = logging.getLogger(__name__)

@@ -113,3 +113,35 @@ def _generate_batch(
)
)
return output

    def _score_batch(
        self,
        batch_instance: Union[List[Tuple[str, str]], List[str]],
        scoring_instruction: str = "Instruction: Given the query, choose your answer from [[label_space]]:\nQuery:\n",
        **kwargs,
    ) -> List[RankedResponse]:
        """
        Tentative solution for scoring candidates.

        :param batch_instance: A list of prompts for which to generate candidate preferences.
        :type batch_instance: List[str] or List[Tuple]
        :param scoring_instruction: The instruction prompt for scoring
        :type scoring_instruction: str
        :return: A list of ranked responses, one per query
        :rtype: List[RankedResponse]
        """
        output = []
        for query in batch_instance:
            _scoring_prompt = (
                scoring_instruction.replace(
                    "[[label_space]]", ",".join(query.candidates)
                )
                + query.prompt
            )
            output.append(
                RankedResponse(
                    prediction=self._ai21_query(
                        _scoring_prompt, model=self.model_string, **kwargs
                    ),
                    scores={},
                )
            )
        return output
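
The same template mechanism recurs in every wrapper below: the [[label_space]] placeholder in scoring_instruction is replaced by the comma-joined candidate labels, and the query's prompt is appended. Note that although batch_instance is annotated as List[str] or List[Tuple[str, str]], the loop reads query.candidates and query.prompt, i.e. it expects ranked-query objects such as Alfred's RankedQuery. A minimal sketch of the expansion, using a stand-in query class (illustrative only, not part of this commit):

from dataclasses import dataclass
from typing import List

@dataclass
class FakeRankedQuery:
    # Stand-in for a query object exposing .prompt and .candidates
    prompt: str
    candidates: List[str]

template = "Instruction: Given the query, choose your answer from [[label_space]]:\nQuery:\n"
query = FakeRankedQuery(
    prompt="The movie was a masterpiece. Sentiment?",
    candidates=["positive", "negative"],
)
scoring_prompt = template.replace("[[label_space]]", ",".join(query.candidates)) + query.prompt
print(scoring_prompt)
# Instruction: Given the query, choose your answer from positive,negative:
# Query:
# The movie was a masterpiece. Sentiment?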
73 changes: 53 additions & 20 deletions alfred/fm/anthropic.py
@@ -7,7 +7,7 @@
import readline

from .model import APIAccessFoundationModel
-from .response import CompletionResponse
+from .response import CompletionResponse, RankedResponse
from .utils import colorize_str, type_print

logger = logging.getLogger(__name__)
@@ -38,12 +38,12 @@ class AnthropicModel(APIAccessFoundationModel):
"""

    def _anthropic_query(
-            self,
-            query: Union[str, List],
-            temperature: float = 0.0,
-            max_tokens: int = 3,
-            model: str = "claude-instant-1",
-            **kwargs: Any,
+        self,
+        query: Union[str, List],
+        temperature: float = 0.0,
+        max_tokens: int = 3,
+        model: str = "claude-instant-1",
+        **kwargs: Any,
    ) -> str:
"""
Run a single query through the foundation model
@@ -85,7 +85,7 @@ def _anthropic_query(
return response["completion"]

    def __init__(
-            self, model_string: str = "claude-instant-1", api_key: Optional[str] = None
+        self, model_string: str = "claude-instant-1", api_key: Optional[str] = None
    ):
"""
Initialize the Anthropic API wrapper.
@@ -100,7 +100,7 @@ def __init__(
:type api_key: Optional[str]
"""
        assert (
-                model_string in ANTHROPIC_MODELS
+            model_string in ANTHROPIC_MODELS
        ), f"Model {model_string} not found. Please choose from {ANTHROPIC_MODELS}"

if "ANTHROPIC_API_KEY" in os.environ:
@@ -124,9 +124,9 @@ def __init__(
super().__init__(model_string, {"api_key": api_key})

    def _generate_batch(
-            self,
-            batch_instance: List[str],
-            **kwargs,
+        self,
+        batch_instance: List[str],
+        **kwargs,
    ) -> List[CompletionResponse]:
"""
Generate completions for a batch of prompts using the anthropic API.
@@ -152,6 +152,38 @@ def __init__(
)
return output

    def _score_batch(
        self,
        batch_instance: Union[List[Tuple[str, str]], List[str]],
        scoring_instruction: str = "Instruction: Given the query, choose your answer from [[label_space]]:\nQuery:\n",
        **kwargs,
    ) -> List[RankedResponse]:
        """
        Tentative solution for scoring candidates.

        :param batch_instance: A list of prompts for which to generate candidate preferences.
        :type batch_instance: List[str] or List[Tuple]
        :param scoring_instruction: The instruction prompt for scoring
        :type scoring_instruction: str
        :return: A list of ranked responses, one per query
        :rtype: List[RankedResponse]
        """
        output = []
        for query in batch_instance:
            _scoring_prompt = (
                scoring_instruction.replace(
                    "[[label_space]]", ",".join(query.candidates)
                )
                + query.prompt
            )
            output.append(
                RankedResponse(
                    prediction=self._anthropic_query(
                        _scoring_prompt, model=self.model_string, **kwargs
                    ),
                    scores={},
                )
            )
        return output
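
Unlike the HuggingFace backend, which ranks candidates by token-level log-likelihood, these API wrappers can only ask the model to choose, so scores stays empty and the choice arrives as free text in prediction. A rough usage sketch, reusing the FakeRankedQuery stand-in from above (names and flow assumed for illustration, not from this diff):

model = AnthropicModel(model_string="claude-instant-1")  # reads ANTHROPIC_API_KEY if set
responses = model._score_batch([query])
print(responses[0].prediction)  # e.g. "positive" -- whatever text the model returns
print(responses[0].scores)      # {} -- no per-candidate scores are available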

def chat(self, **kwargs: Any):
"""
Launch an interactive chat session with the Anthropic API.
@@ -165,9 +197,10 @@ def _feedback(feedback: str, no_newline=False, override=False):
end="",
)
type_print(feedback)
print("",
end="\n" if not no_newline else "",
)
print(
"",
end="\n" if not no_newline else "",
)

model = kwargs.get("model", self.model_string)
c_title = colorize_str("Alfred's Anthropic Chat", "BLUE")
@@ -205,11 +238,11 @@ def _feedback(feedback: str, no_newline=False, override=False):
message_log.append({"role": "user", "content": query})
response = []
            for resp in self._anthropic_query(
-                    query,
-                    chat=True,
-                    model=model,
-                    temperature=temperature,
-                    max_tokens=max_tokens,
+                query,
+                chat=True,
+                model=model,
+                temperature=temperature,
+                max_tokens=max_tokens,
            ):
if resp["stop_reason"] in ["stop", "stop_sequence"]:
break
36 changes: 34 additions & 2 deletions alfred/fm/cohere.py
@@ -1,10 +1,10 @@
import logging
-from typing import Optional, List, Any
+from typing import Optional, List, Any, Union, Tuple

import torch

from .model import APIAccessFoundationModel
-from .response import CompletionResponse
+from .response import CompletionResponse, RankedResponse

logger = logging.getLogger(__name__)

@@ -47,6 +47,38 @@ def _cohere_query(
)
return response.generations[0].text

    def _score_batch(
        self,
        batch_instance: Union[List[Tuple[str, str]], List[str]],
        scoring_instruction: str = "Instruction: Given the query, choose your answer from [[label_space]]:\nQuery:\n",
        **kwargs,
    ) -> List[RankedResponse]:
        """
        Tentative solution for scoring candidates.

        :param batch_instance: A list of prompts for which to generate candidate preferences.
        :type batch_instance: List[str] or List[Tuple]
        :param scoring_instruction: The instruction prompt for scoring
        :type scoring_instruction: str
        :return: A list of ranked responses, one per query
        :rtype: List[RankedResponse]
        """
        output = []
        for query in batch_instance:
            _scoring_prompt = (
                scoring_instruction.replace(
                    "[[label_space]]", ",".join(query.candidates)
                )
                + query.prompt
            )
            output.append(
                RankedResponse(
                    prediction=self._cohere_query(
                        _scoring_prompt, model=self.model_string, **kwargs
                    ),
                    scores={},
                )
            )
        return output
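
This _score_batch body is copy-pasted across the AI21, Anthropic, Cohere, and Google wrappers, differing only in the single-query function it calls. A sketch of how it could be factored into one shared helper inside alfred/fm (a refactoring idea, not part of this commit):

from typing import Callable, List

from .response import RankedResponse

def score_by_instruction(
    queries: List,                    # elements must expose .prompt and .candidates
    query_fn: Callable[[str], str],   # e.g. lambda p: self._cohere_query(p, model=self.model_string)
    scoring_instruction: str = "Instruction: Given the query, choose your answer from [[label_space]]:\nQuery:\n",
) -> List[RankedResponse]:
    output = []
    for query in queries:
        # Expand the label-space placeholder, then append the query's prompt
        prompt = (
            scoring_instruction.replace("[[label_space]]", ",".join(query.candidates))
            + query.prompt
        )
        output.append(RankedResponse(prediction=query_fn(prompt), scores={}))
    return output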

def _cohere_embedding_query(
self,
query_string: str,
76 changes: 49 additions & 27 deletions alfred/fm/google.py
@@ -8,7 +8,7 @@
import torch

from .model import APIAccessFoundationModel
-from .response import CompletionResponse
+from .response import CompletionResponse, RankedResponse
from .utils import colorize_str, retry, type_print

logger = logging.getLogger(__name__)
@@ -23,17 +23,12 @@
"Google GenAI module not found. Please install google-generativeai to use the Google model."
)

-GOOGLE_GENAI_MODELS = (
-    "gemini-pro",
-)
+GOOGLE_GENAI_MODELS = ("gemini-pro",)

-GOOGLE_GENAI_VISION_MODELS = (
-    "gemini-pro-vision",
-)
+GOOGLE_GENAI_VISION_MODELS = ("gemini-pro-vision",)

-GOOGLE_GENAI_EMBEDDING_MODELS = (
-    "embedding-001",
-)
+GOOGLE_GENAI_EMBEDDING_MODELS = ("embedding-001",)

class GoogleModel(APIAccessFoundationModel):
"""
@@ -45,9 +40,7 @@ class GoogleModel(APIAccessFoundationModel):
    @retry(
        num_retries=3,
        wait_time=0.1,
-        exceptions=(
-            Exception
-        ),
+        exceptions=(Exception),
    )
def _google_genai_query(
self,
@@ -73,27 +66,25 @@ def _google_genai_query(
        if self.model_string in GOOGLE_GENAI_VISION_MODELS:
            img, prompt = query[0], query[1]
            if not isinstance(img, PIL.Image.Image):
-                raise ValueError(f"Image type {type(img)} not supported. Please use PIL.Image!")
-            query = [
-                prompt, img
-            ] if len(prompt) > 0 else [img]
+                raise ValueError(
+                    f"Image type {type(img)} not supported. Please use PIL.Image!"
+                )
+            query = [prompt, img] if len(prompt) > 0 else [img]
        response = self.model.generate_content(
            query,
            generation_config=genai.types.GenerationConfig(
                candidate_count=1,
-                stop_sequences=['x'],
+                stop_sequences=["x"],
                max_output_tokens=max_tokens,
                temperature=temperature,
-            )
+            ),
        )
return response.text

    @retry(
        num_retries=3,
        wait_time=0.1,
-        exceptions=(
-            Exception
-        ),
+        exceptions=(Exception),
    )
def _google_genai_embedding_query(
self,
@@ -114,12 +105,11 @@ def _google_genai_embedding_query(
model=f"models/{self.model_string}",
content=query_string,
task_type="retrieval_document",
title="Embedding of single string")
title="Embedding of single string",
)
)

-    def __init__(
-        self, model_string: str = "gemini-pro", api_key: Optional[str] = None
-    ):
+    def __init__(self, model_string: str = "gemini-pro", api_key: Optional[str] = None):
"""
Initialize the Google API wrapper.
@@ -138,7 +128,9 @@ def __init__(
raise RuntimeError("Google GenAI requires Python 3.9+")
        assert (
            model_string
-            in GOOGLE_GENAI_MODELS + GOOGLE_GENAI_EMBEDDING_MODELS + GOOGLE_GENAI_VISION_MODELS
+            in GOOGLE_GENAI_MODELS
+            + GOOGLE_GENAI_EMBEDDING_MODELS
+            + GOOGLE_GENAI_VISION_MODELS
), (
f"Model {model_string} not found. "
f"Please choose from {GOOGLE_GENAI_MODELS + GOOGLE_GENAI_EMBEDDING_MODELS + GOOGLE_GENAI_VISION_MODELS}"
@@ -189,6 +181,36 @@ def _generate_batch(
)
return output

    def _score_batch(
        self,
        batch_instance: Union[List[Tuple[str, str]], List[str]],
        scoring_instruction: str = "Instruction: Given the query, choose your answer from [[label_space]]:\nQuery:\n",
        **kwargs,
    ) -> List[RankedResponse]:
        """
        Tentative solution for scoring candidates.

        :param batch_instance: A list of prompts for which to generate candidate preferences.
        :type batch_instance: List[str] or List[Tuple]
        :param scoring_instruction: The instruction prompt for scoring
        :type scoring_instruction: str
        :return: A list of ranked responses, one per query
        :rtype: List[RankedResponse]
        """
        output = []
        for query in batch_instance:
            _scoring_prompt = (
                scoring_instruction.replace(
                    "[[label_space]]", ",".join(query.candidates)
                )
                + query.prompt
            )
            output.append(
                RankedResponse(
                    prediction=self._google_genai_query(_scoring_prompt, **kwargs),
                    scores={},
                )
            )
        return output
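
Because the prediction comes back as free text, a caller that needs a hard label has to map the output onto the candidate set itself. A tolerant matching sketch (an assumption about downstream use, not code from this commit):

def match_label(prediction: str, candidates: list) -> str:
    # Case-insensitive containment match between model output and candidates
    pred = prediction.strip().lower()
    for cand in candidates:
        if cand.lower() in pred or pred in cand.lower():
            return cand
    return prediction  # no candidate matched; surface the raw output

match_label(" Positive", ["positive", "negative"])  # -> "positive"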

def _encode_batch(
self,
batch_instance: [List[str]],
6 changes: 4 additions & 2 deletions alfred/fm/huggingface.py
@@ -249,7 +249,6 @@ def _score_batch(
max_length=self.max_position_embeddings,
)


end_device = list(self.model.hf_device_map.values())[-1]

logger.log(logging.INFO, f"Ranking {len(candidate)} instances")
@@ -276,7 +275,10 @@ def _score_batch(
            padding=True,
            truncation=True,
            max_length=self.max_position_embeddings,
-            add_special_tokens=not (isinstance(self.model, LlamaPreTrainedModel) or isinstance(self.model, MistralPreTrainedModel)),
+            add_special_tokens=not (
+                isinstance(self.model, LlamaPreTrainedModel)
+                or isinstance(self.model, MistralPreTrainedModel)
+            ),
return_tensors="pt",
)
candidate_token_ids = candidate_tokens.input_ids.to(end_device)
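
Llama and Mistral tokenizers prepend a BOS token by default; when candidate continuations are tokenized for ranking, a stray BOS spliced into the middle of the concatenated sequence would distort the log-likelihoods, hence the opt-out above. A quick illustration (the checkpoint name is an assumption for the example):

from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
print(tok("positive").input_ids)                            # BOS id prepended
print(tok("positive", add_special_tokens=False).input_ids)  # candidate tokens only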