refactored the handler
PiperOrigin-RevId: 687652961
Change-Id: Ice298ceba7e0e7474351488d40eaffc5a764737e
minsukchang authored and copybara-github committed Oct 19, 2024
1 parent e24a50d commit 9f317f5
Showing 1 changed file with 46 additions and 90 deletions.
136 changes: 46 additions & 90 deletions concordia/language_model/google_cloud_custom_model.py
@@ -12,19 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

"""Language Model that uses Google Cloud Vertex AI API.
Recommended model names are:
'gemma-2-9b-it'
'gemma-2-27b-it'
"""

# Before running:
# 1. Create a Google Cloud project and enable the Vertex AI API.
# 2. Authenticate with Google Cloud credentials.
# 3. Find the model you want to use and upload it to the project.
# 4. Deploy the model to an endpoint in the region.
# 5. Replace model_name, project, and location with your actual values.
"""Language Model that uses Google Cloud Vertex AI API."""

from collections.abc import Collection, Sequence
import random
@@ -51,46 +39,40 @@ class VertexAI(language_model.LanguageModel):
you can find the endpoint_id and region in the Vertex AI model registry page.
The quickest way to find this info at once is to go to the Vertex AI model
registry page, and click on the model you want to use.
registry page in Google Cloud, and click on the model you want to use.
Then click on the Sample Request link; it'll open a panel on the right.
Click on the PYTHON tab, and under instruction number 3 you'll see all three:
project_id, endpoint_id and region info there.
the project_id, endpoint_id, and region.
"""

def __init__(
self,
model_name: str, # endpoint ID, all numbers
endpoint_id: str, # all numbers
*,
project: str, # project ID, all numbers
location: str = "us-central1", # e.g., "us-central1"
project_id: str, # all numbers
location: str = "us-central1",
measurements: measurements_lib.Measurements | None = None,
channel: str = language_model.DEFAULT_STATS_CHANNEL,
):
"""Initializes the instance.
Args:
model_name: The endpoint ID of the language model to use.
project: Your Google Cloud project ID.
location: The region where the model is deployed.
measurements: The measurements object to log usage statistics to.
channel: The channel to write the statistics to.
"""
self._model_name = model_name
self._project = project
"""Initializes the instance."""
self._endpoint_id = endpoint_id
self._project_id = project_id
self._location = location
self._measurements = measurements
self._channel = channel

# Initialize the client and endpoint name *once*
self._endpoint_name = (
    f"projects/{self._project_id}/locations/{self._location}"
    f"/endpoints/{self._endpoint_id}"
)
self._api_endpoint = f"{self._location}-aiplatform.googleapis.com"
self._client = aiplatform.gapic.PredictionServiceClient(
client_options={"api_endpoint": self._api_endpoint}
)
self._parameters = {
"top_p": 0.95,
"top_k": 40,
}
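  # Usage: a minimal sketch, assuming a model is already deployed to a Vertex
  # AI endpoint. The endpoint and project IDs below are hypothetical
  # placeholders; substitute your own values from the model registry page.
  #
  #   model = VertexAI(
  #       endpoint_id="1234567890123456789",
  #       project_id="123456789012",
  #       location="us-central1",
  #   )
  #   text = model.sample_text("Write a one-sentence greeting.")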

@override
# sample_text:
# Uses aiplatform.gapic.PredictionServiceClient().predict to make
# prediction requests.
# (Might need to adjust response[0]["content"] based on the actual structure.)
# Includes top_p and top_k parameters as examples.
# Adds max_output_tokens, which corresponds to Together's max_tokens.
# Applies terminators after generation as a post-processing step, because
# Vertex AI doesn't support terminators during generation: it truncates at
# the first occurrence of a terminator.
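# For example, with terminators=("\n",), a raw completion of
# "Alice waves.\nBob waves back." is truncated to "Alice waves.\n".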
def sample_text(
self,
prompt: str,
@@ -99,22 +81,12 @@ def sample_text(
terminators: Collection[str] = language_model.DEFAULT_TERMINATORS,
temperature: float = language_model.DEFAULT_TEMPERATURE,
timeout: float = language_model.DEFAULT_TIMEOUT_SECONDS,
seed: int | None = None, # Vertex doesn't directly support seed.
seed: int | None = None,
) -> str:

endpoint_name = (
"projects/"
+ self._project
+ "locations/"
+ self._location
+ "/endpoints/"
+ self._model_name
)
api_endpoint = self._location + "-aiplatform.googleapis.com"

max_tokens = min(max_tokens, _DEFAULT_MAX_TOKENS)
self._parameters["temperature"] = temperature
self._parameters["max_output_tokens"] = max_tokens

result = ""
for attempts in range(_MAX_ATTEMPTS):
if attempts > 0:
seconds_to_sleep = _SECONDS_TO_SLEEP_WHEN_RATE_LIMITED + random.uniform(
@@ -127,25 +99,13 @@
)
time.sleep(seconds_to_sleep)

client_options = {"api_endpoint": api_endpoint}
client = aiplatform.gapic.PredictionServiceClient(
client_options=client_options
)

try:
response = (
client.predict(
endpoint=endpoint_name,
instances=[{"inputs": prompt}],
parameters={
"temperature": temperature,
"max_output_tokens": max_tokens,
"top_p": 0.95, # Example: Add other parameters as needed
"top_k": 40,
},
)
.predictions[0]
)
response = self._client.predict(
endpoint=self._endpoint_name,
instances=[{"inputs": prompt}],
parameters=self._parameters,
).predictions[0]

result = response

# Apply terminators
@@ -154,27 +114,29 @@
result = result[: result.index(terminator)] + terminator
break

break # Success, exit the retry loop
if self._measurements is not None:
self._measurements.publish_datum(
self._channel,
{"raw_text_length": len(result)},
)

return result

except Exception as err: # pylint: disable=broad-exception-caught
if attempts >= _NUM_SILENT_ATTEMPTS:
print(f" Exception: {err}")
continue

if self._measurements is not None:
self._measurements.publish_datum(
self._channel,
{"raw_text_length": len(result)},
)

return result
raise RuntimeError(
f"Failed to get a response after {_MAX_ATTEMPTS} attempts."
)
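  # The loop above retries with a fixed base delay plus random jitter. A
  # minimal standalone sketch of the same pattern (illustrative constants;
  # call() is a hypothetical stand-in for the predict request):
  #
  #   for attempt in range(_MAX_ATTEMPTS):
  #     if attempt > 0:
  #       time.sleep(_SECONDS_TO_SLEEP_WHEN_RATE_LIMITED + random.uniform(0, 5))
  #     try:
  #       return call()
  #     except Exception:  # pylint: disable=broad-exception-caught
  #       continue
  #   raise RuntimeError("all attempts failed")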

@override
def sample_choice(
self,
prompt: str,
responses: Sequence[str],
*,
seed: int | None = None, # Seed is not supported by Vertex AI.
seed: int | None = None,
) -> tuple[int, str, dict[str, float]]:
prompt = (
prompt
@@ -190,25 +152,19 @@ def sample_choice(
attempts, _MAX_MULTIPLE_CHOICE_ATTEMPTS
)

sample = self.sample_text(
prompt,
temperature=temperature,
seed=seed,
)
sample = self.sample_text(prompt, temperature=temperature, seed=seed)
answer = sampling.extract_choice_response(sample)
try:
idx = responses.index(answer)
except ValueError:
continue
else:
if self._measurements is not None:
self._measurements.publish_datum(
self._channel, {"choices_calls": attempts}
)
debug = {}
return idx, responses[idx], debug
return idx, responses[idx], {}
except ValueError:
pass

raise language_model.InvalidResponseError((
f"Too many multiple choice attempts.\nLast attempt: {sample}, "
+ f"extracted: {answer}"
f"extracted: {answer}"
))
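  # Usage: a minimal sketch of sample_choice; `model` is a VertexAI instance
  # constructed as in the earlier sketch.
  #
  #   idx, choice, _ = model.sample_choice(
  #       "What is the capital of France?",
  #       responses=["London", "Paris", "Berlin"],
  #   )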
