From 9f317f53248e4905e336be89780d991f39bd5171 Mon Sep 17 00:00:00 2001
From: Minsuk Chang
Date: Sat, 19 Oct 2024 11:08:56 -0700
Subject: [PATCH] Refactor the Vertex AI handler to reuse the prediction client

PiperOrigin-RevId: 687652961
Change-Id: Ice298ceba7e0e7474351488d40eaffc5a764737e
---
 .../google_cloud_custom_model.py | 136 ++++++------
 1 file changed, 46 insertions(+), 90 deletions(-)

diff --git a/concordia/language_model/google_cloud_custom_model.py b/concordia/language_model/google_cloud_custom_model.py
index c5aa877c..c63731e2 100644
--- a/concordia/language_model/google_cloud_custom_model.py
+++ b/concordia/language_model/google_cloud_custom_model.py
@@ -12,19 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-"""Language Model that uses Google Cloud Vertex AI API.
-
-Recommended model names are:
-  'gemma-2-9b-it'
-  'gemma-2-27b-it'
-"""
-
-# Before running:
-# 1. Create a Google Cloud project and enable the Vertex AI API.
-# 2. Authenticate with Google Cloud credentials
-# 3. Find the model you want to use and upload it to the project.
-# 4. Deploy the model to an endpoint in the region.
-# 5. Replace model_name, project, and location with your actual values.
+"""Language Model that uses the Google Cloud Vertex AI API."""
 
 from collections.abc import Collection, Sequence
 import random
@@ -51,46 +39,40 @@ class VertexAI(language_model.LanguageModel):
 
   you can find the endpoint_id, and region in the Vertex AI model registry
   page. the quickest way to find these info at once is to go to the Vertex AI model
-  registry page, and click on the model you want to use.
+  registry page in Google Cloud and click on the model you want to use.
   Then, click on the Sample Request link, it'll open a panel on the right,
   click on the PYTHON tab, under instruction number 3, you'll see all three
-  project_id, endpoint_id and region info there.
+  project_id, endpoint_id, and region.
   """
 
   def __init__(
       self,
-      model_name: str,  # endpoint ID, all numbers
+      endpoint_id: str,  # numeric endpoint ID
       *,
-      project: str,  # project ID, all numbers
-      location: str = "us-central1",  # e.g., "us-central1"
+      project_id: str,  # Google Cloud project ID or number
+      location: str = "us-central1",
       measurements: measurements_lib.Measurements | None = None,
       channel: str = language_model.DEFAULT_STATS_CHANNEL,
   ):
-    """Initializes the instance.
-
-    Args:
-      model_name: The endpoint ID of the language model to use.
-      project: Your Google Cloud project ID.
-      location: The region where the model is deployed.
-      measurements: The measurements object to log usage statistics to.
-      channel: The channel to write the statistics to.
-    """
-    self._model_name = model_name
-    self._project = project
+    """Initializes the instance.
+
+    Args:
+      endpoint_id: The numeric ID of the Vertex AI endpoint to call.
+      project_id: The Google Cloud project that hosts the endpoint.
+      location: The region where the endpoint is deployed.
+      measurements: The measurements object to log usage statistics to.
+      channel: The channel to write the statistics to.
+    """
+    self._endpoint_id = endpoint_id
+    self._project_id = project_id
     self._location = location
     self._measurements = measurements
     self._channel = channel
 
+    # Build the endpoint name and prediction client once, at construction
+    # time, instead of on every request.
+    self._endpoint_name = (
+        f"projects/{self._project_id}/locations/{self._location}"
+        f"/endpoints/{self._endpoint_id}"
+    )
+    self._api_endpoint = f"{self._location}-aiplatform.googleapis.com"
+    self._client = aiplatform.gapic.PredictionServiceClient(
+        client_options={"api_endpoint": self._api_endpoint}
+    )
+    # Default sampling parameters; per-call settings are merged in
+    # sample_text.
+    self._parameters = {
+        "top_p": 0.95,
+        "top_k": 40,
+    }
+
   @override
-  # sample_text:
-  #   Uses aiplatform.gapic.PredictionServiceClient().predict to make
-  # prediction requests.
-  # Might need to adjust response[0]["content"] based on the actual structure).
-  # Includes top_p and top_k parameters as examples
-  # Add max_output_tokens which corresponds to Together's max_tokens.
-  # Applies terminators after generation. Because Vertex AI doesn't support
-  # terminators during generation as post-processing step.
-  # It will find the last instance of a terminator and truncate there.
   def sample_text(
       self,
       prompt: str,
@@ -99,22 +81,12 @@ def sample_text(
       terminators: Collection[str] = language_model.DEFAULT_TERMINATORS,
       temperature: float = language_model.DEFAULT_TEMPERATURE,
       timeout: float = language_model.DEFAULT_TIMEOUT_SECONDS,
-      seed: int | None = None,  # Vertex doesn't directly support seed.
+      seed: int | None = None,  # Unused; Vertex AI does not support seeding.
   ) -> str:
-
-    endpoint_name = (
-        "projects/"
-        + self._project
-        + "locations/"
-        + self._location
-        + "/endpoints/"
-        + self._model_name
-    )
-    api_endpoint = self._location + "-aiplatform.googleapis.com"
-
     max_tokens = min(max_tokens, _DEFAULT_MAX_TOKENS)
+    # Copy the shared defaults so per-call settings do not leak between
+    # calls.
+    parameters = {
+        **self._parameters,
+        "temperature": temperature,
+        "max_output_tokens": max_tokens,
+    }
 
-    result = ""
     for attempts in range(_MAX_ATTEMPTS):
       if attempts > 0:
         seconds_to_sleep = _SECONDS_TO_SLEEP_WHEN_RATE_LIMITED + random.uniform(
@@ -127,25 +99,13 @@ def sample_text(
         )
         time.sleep(seconds_to_sleep)
 
-      client_options = {"api_endpoint": api_endpoint}
-      client = aiplatform.gapic.PredictionServiceClient(
-          client_options=client_options
-      )
-
       try:
-        response = (
-            client.predict(
-                endpoint=endpoint_name,
-                instances=[{"inputs": prompt}],
-                parameters={
-                    "temperature": temperature,
-                    "max_output_tokens": max_tokens,
-                    "top_p": 0.95,  # Example: Add other parameters as needed
-                    "top_k": 40,
-                },
-            )
-            .predictions[0]
-        )
+        response = self._client.predict(
+            endpoint=self._endpoint_name,
+            instances=[{"inputs": prompt}],
+            parameters=parameters,
+        ).predictions[0]
+
         result = response
 
         # Apply terminators
@@ -154,19 +114,21 @@ def sample_text(
             result = result[: result.index(terminator)] + terminator
             break
 
-        break  # Success, exit the retry loop
+        if self._measurements is not None:
+          self._measurements.publish_datum(
+              self._channel,
+              {"raw_text_length": len(result)},
+          )
+
+        return result
+
       except Exception as err:  # pylint: disable=broad-exception-caught
         if attempts >= _NUM_SILENT_ATTEMPTS:
           print(f"  Exception: {err}")
-        continue
 
-    if self._measurements is not None:
-      self._measurements.publish_datum(
-          self._channel,
-          {"raw_text_length": len(result)},
-      )
-
-    return result
+    raise RuntimeError(
+        f"Failed to get a response after {_MAX_ATTEMPTS} attempts."
+    )
 
   @override
   def sample_choice(
@@ -174,7 +136,7 @@ def sample_choice(
       prompt: str,
       responses: Sequence[str],
       *,
-      seed: int | None = None,  # Seed is not supported by Vertex AI.
+      seed: int | None = None,  # Unused; seed is not supported by Vertex AI.
   ) -> tuple[int, str, dict[str, float]]:
     prompt = (
         prompt
@@ -190,25 +152,19 @@ def sample_choice(
           attempts, _MAX_MULTIPLE_CHOICE_ATTEMPTS
       )
-      sample = self.sample_text(
-          prompt,
-          temperature=temperature,
-          seed=seed,
-      )
+      sample = self.sample_text(prompt, temperature=temperature, seed=seed)
       answer = sampling.extract_choice_response(sample)
       try:
         idx = responses.index(answer)
       except ValueError:
         continue
-      else:
-        if self._measurements is not None:
-          self._measurements.publish_datum(
-              self._channel, {"choices_calls": attempts}
-          )
-        debug = {}
-        return idx, responses[idx], debug
+      # Keep the try block narrow so only the responses.index lookup is
+      # guarded; a ValueError from elsewhere should not be swallowed.
+      if self._measurements is not None:
+        self._measurements.publish_datum(
+            self._channel, {"choices_calls": attempts}
+        )
+      return idx, responses[idx], {}
 
     raise language_model.InvalidResponseError((
         f"Too many multiple choice attempts.\nLast attempt: {sample}, "
-        + f"extracted: {answer}"
+        f"extracted: {answer}"
     ))
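
Below is a minimal usage sketch of the refactored class (not part of the
diff), assuming the constructor signature introduced above; the endpoint_id,
project_id, and prompt values are placeholders, and the import path follows
the file touched by this patch:

    from concordia.language_model import google_cloud_custom_model

    # Placeholder IDs; take the real values from the Sample Request panel in
    # the Vertex AI model registry, as described in the class docstring.
    model = google_cloud_custom_model.VertexAI(
        endpoint_id="1234567890",
        project_id="my-project",
        location="us-central1",
    )

    # Free-form sampling; terminators are applied as a post-processing step.
    text = model.sample_text("Write a one-line greeting.", max_tokens=128)

    # Multiple-choice sampling; retries until the extracted answer matches
    # one of the listed responses.
    idx, choice, _ = model.sample_choice(
        "Is the sky blue on a clear day?",
        responses=["yes", "no"],
    )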