From 5d25a823f5b27450277b8b227566d48ca1dfde11 Mon Sep 17 00:00:00 2001
From: Samuel Monson
Date: Wed, 16 Oct 2024 20:12:15 -0400
Subject: [PATCH] Refactor OpenAI Plugin (#57)

* fix: type annotate RequestResult
* feat: refactor streaming requests for OpenAI plugin
* fix: return if the server response is blank
* fix: flatten usage conditional
* feat: switch to config var to determine api type
* fix: finish making non-streaming api robust against bad responses
* fix: make RequestResult.calculate_results type safe
* fix: off-by-one error in output_tokens_before_timeout count
* fix: check all tokens for output_tokens_before_timeout
* fix: organize misc pieces
* feat: clean up the Plugin ABC
* feat: don't record blank tokens that come before first non-blank
* fix: move deepget back to plugin

  This avoids a circular dependency.

* Revert "feat: clean up the Plugin ABC"

  This reverts commit 3e5ff9b21d1bfa14040a196d41dd5a4392cd239b.

  We can't do this without changing how the plugins handle request_func,
  since @abstractmethods must be implemented, not assigned.

* fix: don't rely on usage data for token count
* feat: set sampling defaults for all request types
* fix: implement feedback
* fix: standardize max_output across all request types
* feat: count tokens in each response
* fix: log missing special responses
* fix: track first token by usage
* fix: omit non-tokens as they break latency tracking
* fix: improve documentation and logging
---
 plugins/openai_plugin.py | 344 ++++++++++++++++++++++++---------------
 result.py                |  68 ++++----
 2 files changed, 251 insertions(+), 161 deletions(-)

diff --git a/plugins/openai_plugin.py b/plugins/openai_plugin.py
index 64e9b0e9..e4d9c19a 100644
--- a/plugins/openai_plugin.py
+++ b/plugins/openai_plugin.py
@@ -1,6 +1,7 @@
 import json
 import logging
 import time
+from typing import Any, Optional, Union
 
 import requests
 import urllib3
@@ -21,9 +22,31 @@
 """
 
 required_args = ["host", "streaming", "endpoint"]
+APIS = ["legacy", "chat"]
 
 logger = logging.getLogger("user")
 
 
+def deepget(obj: Union[dict, list], *path: Any, default: Any = None) -> Any:
+    """
+    Acts like .get() but for nested objects.
+
+    Each item in path is recursively indexed on obj, so for a path of
+    length N the result is obj[path[0]][path[1]]...[path[N-1]].
+
+    :param obj: root object to index
+    :param path: ordered list of keys to index recursively
+    :param default: the default value to return if an indexing step fails
+    :returns: result of the final index, or default if a Key/Index Error occurs
+    """
+    current = obj
+    for pos in path:
+        try:
+            current = current[pos]
+        except (KeyError, IndexError):
+            return default
+    return current
+
+
 # This plugin is written primarily for testing vLLM, though it can be made
 # to work for other runtimes which conform to the OpenAI API, as required.
 class OpenAIPlugin(plugin.Plugin):
@@ -48,6 +71,33 @@ def _parse_args(self, args):
 
         logger.debug("Model name: %s", self.model_name)
 
+        self.api = args.get('api')
+
+        if not self.api:
+            self.api = 'chat' if "/v1/chat/completions" in self.host else 'legacy'
+
+        if self.api not in APIS:
+            logger.error("Invalid api type: %s", self.api)
+
+        # TODO Make this configurable
+        self.request_defaults = dict(
+            temperature = 0.0,
+            seed = 42,
+        )
+
+    def _process_resp(self, resp: bytes) -> Optional[dict]:
+        try:
+            _, found, data = resp.partition(b"data: ")
+            if not found:
+                return None
+            message = json.loads(data)
+            logger.debug("Message: %s", message)
+        except json.JSONDecodeError:
+            logger.exception("Response line could not be json decoded: %s", resp)
+            return None
+
+        return message
+
     def request_http(self, query: dict, user_id: int, test_end_time: float = 0):
         result = RequestResult(user_id, query.get("text"), query.get("input_tokens"))
 
@@ -56,25 +106,23 @@ def request_http(self, query: dict, user_id: int, test_end_time: float = 0):
 
         headers = {"Content-Type": "application/json"}
 
-        if "/v1/chat/completions" in self.host:
-            data = {
-                "messages": [
-                    {"role": "user", "content": query["text"]}
-                ],
-                "max_tokens": query["output_tokens"],
-                "temperature": 0.1,
-            }
-        else:
-            data = {
-                "prompt": query["text"],
-                "max_tokens": query["output_tokens"],
-                "min_tokens": query["output_tokens"],
-                "temperature": 1.0,
-                "top_p": 0.9,
-                "seed": 10,
-            }
+        request = {
+            "max_tokens": query["output_tokens"],
+            "min_tokens": query["output_tokens"],
+        }
+
+        if self.api == 'chat':
+            request["messages"] = [
+                { "role": "user", "content": query["text"] }
+            ]
+        else:  # self.api == 'legacy'
+            request["prompt"] = query["text"]
+
         if self.model_name is not None:
-            data["model"] = self.model_name
+            request["model"] = self.model_name
+
+        # Merge request and defaults
+        data = self.request_defaults | request
 
         response = None
         try:
@@ -97,21 +145,24 @@ def request_http(self, query: dict, user_id: int, test_end_time: float = 0):
 
         result.end_time = time.time()
 
+        ###########################################
+        # DO NOT CALL time.time BEYOND THIS POINT #
+        ###########################################
+
         logger.debug("Response: %s", json.dumps(response.text))
 
         try:
             message = json.loads(response.text)
             error = message.get("error")
             if error is None:
-                if "/v1/chat/completions" in self.host:
-                    #result.output_text = message["choices"][0]['delta']['content']
-                    result.output_text = message["choices"][0]['message']['content']
-                else:
-                    result.output_text = message["choices"][0]["text"]
-
-                result.output_tokens = message["usage"]["completion_tokens"]
-                result.input_tokens = message["usage"]["prompt_tokens"]
-                result.stop_reason = message["choices"][0]["finish_reason"]
+                if self.api == 'chat':
+                    result.output_text = deepget(message, "choices", 0, 'message', 'content')
+                else:  # self.api == 'legacy'
+                    result.output_text = deepget(message, "choices", 0, 'text')
+
+                result.output_tokens = deepget(message, "usage", "completion_tokens")
+                result.input_tokens = deepget(message, "usage", "prompt_tokens")
+                result.stop_reason = deepget(message, "choices", 0, "finish_reason")
             else:
                 result.error_code = response.status_code
                 result.error_text = error
@@ -119,9 +170,6 @@ def request_http(self, query: dict, user_id: int, test_end_time: float = 0):
         except json.JSONDecodeError:
             logger.exception("Response could not be json decoded: %s", response.text)
             result.error_text = f"Response could not be json decoded {response.text}"
-        except KeyError:
-            logger.exception("KeyError, unexpected response format: %s", response.text)
unexpected response format: %s", response.text) - result.error_text = f"KeyError, unexpected response format: {response.text}" # For non-streaming requests we are keeping output_tokens_before_timeout and output_tokens same. result.output_tokens_before_timeout = result.output_tokens @@ -133,29 +181,31 @@ def request_http(self, query: dict, user_id: int, test_end_time: float = 0): def streaming_request_http(self, query: dict, user_id: int, test_end_time: float): headers = {"Content-Type": "application/json"} - data = { - "max_tokens": query["output_tokens"], - "temperature": 0.1, - "stream": True, - "stream_options": { - "include_usage": True - } + request = { + "max_tokens": query["output_tokens"], + "min_tokens": query["output_tokens"], + "stream": True, + "stream_options": { + "include_usage": True } - if "/v1/chat/completions" in self.host: - data["messages"] = [ - {"role": "user", "content": query["text"]} - ] - else: - data["prompt"] = query["text"], - data["min_tokens"] = query["output_tokens"] + } + + if self.api == 'chat': + request["messages"] = [ + { "role": "user", "content": query["text"] } + ] + else: # self.api == 'legacy' + request["prompt"] = query["text"], # some runtimes only serve one model, won't check this. if self.model_name is not None: - data["model"] = self.model_name + request["model"] = self.model_name + + # Merge request and defaults + data = self.request_defaults | request result = RequestResult(user_id, query.get("input_id")) - tokens = [] response = None result.start_time = time.time() try: @@ -163,113 +213,147 @@ def streaming_request_http(self, query: dict, user_id: int, test_end_time: float self.host, headers=headers, json=data, verify=False, stream=True ) response.raise_for_status() - except requests.exceptions.ConnectionError as err: + except ( + requests.exceptions.ConnectionError, + requests.exceptions.HTTPError + ) as err: result.end_time = time.time() result.error_text = repr(err) if response is not None: result.error_code = response.status_code logger.exception("Connection error") return result - except requests.exceptions.HTTPError as err: - result.end_time = time.time() - result.error_text = repr(err) - if response is not None: - result.error_code = response.status_code - logger.exception("HTTP error") - return result - logger.debug("Response: %s", response) - message = None + resps = [] try: for line in response.iter_lines(): - logger.debug("response line: %s", line) - _, found, data = line.partition(b"data: ") - if found and data != b"[DONE]": - try: - message = json.loads(data) - logger.debug("Message: %s", message) - if "/v1/chat/completions" in self.host and not message["choices"][0]['delta'].get('content'): - message["choices"][0]['delta']['content']="" - error = message.get("error") - if error is None: - # If stream_options.include_usage == True then the final - # message contains only token stats - if not message.get("choices") and message.get('usage'): - result.output_tokens = message["usage"]["completion_tokens"] - result.input_tokens = message["usage"]["prompt_tokens"] - # We don't want to record this message - continue - if "/v1/chat/completions" in self.host: - token = message["choices"][0]['delta']['content'] - else: - token = message["choices"][0]["text"] - logger.debug("Token: %s", token) - else: - result.error_code = response.status_code - result.error_text = error - logger.error("Error received in response message: %s", error) - break - except json.JSONDecodeError: - logger.exception("Response line could not be json decoded: %s", 
-                    except KeyError:
-                        logger.exception(
-                            "KeyError, unexpected response format in line: %s", line
-                        )
-                        continue
-                else:
-                    continue
-
-                try:
-                    # First chunk may not be a token, just a connection ack
-                    if not result.ack_time:
-                        result.ack_time = time.time()
-
-                    # First non empty token is the first token
-                    if not result.first_token_time and token != "":
-                        result.first_token_time = time.time()
-
-                    # If the current token time is outside the test duration, record the total tokens received before
-                    # the current token.
-                    if (
-                        not result.output_tokens_before_timeout
-                        and time.time() > test_end_time
-                    ):
-                        result.output_tokens_before_timeout = len(tokens)
-
-                    tokens.append(token)
-
-                    # Last token comes with finish_reason set.
-                    if message.get("choices", [])[0].get("finish_reason", None):
-                        result.stop_reason = message["choices"][0]["finish_reason"]
-
-                    # If test duration timeout didn't happen before the last token is received,
-                    # total tokens before the timeout will be equal to the total tokens in the response.
-                    if not result.output_tokens_before_timeout:
-                        result.output_tokens_before_timeout = result.output_tokens
-
-                except KeyError:
-                    logger.exception("KeyError, unexpected response format in line: %s", line)
+                recv_time = time.time()  # Record time asap
+                # Only record lines with data
+                if line:
+                    logger.debug("response line: %s", line)
+                    resps.append(dict(
+                        time = recv_time,
+                        data = line
+                    ))
+            # Full response received
+            result.end_time = time.time()
         except requests.exceptions.ChunkedEncodingError as err:
             result.end_time = time.time()
             result.error_text = repr(err)
-            result.output_text = "".join(tokens)
-            result.output_tokens = len(tokens)
+            #result.output_text = "".join([])
+            result.output_tokens = len(resps)
            if response is not None:
                 result.error_code = response.status_code
             logger.exception("ChunkedEncodingError while streaming response")
             return result
 
+        ###########################################
+        # DO NOT CALL time.time BEYOND THIS POINT #
+        ###########################################
+
+        # If no data was received return early
+        if not resps:
+            result.output_tokens = 0
+            result.error_code = response.status_code
+            return result
+
+        # Check for end of request marker
+        if resps[-1]['data'] == b"data: [DONE]":
+            result.end_time = resps[-1]['time']
+            resps.pop()  # Drop the end indicator
+        else:
+            logger.warning("End of response marker missing, response may be incomplete")
+
+        # Check for usage statistics
+        message = self._process_resp(resps[-1]['data'])
+        # If stream_options.include_usage == True then the final
+        # message contains only token stats
+        expected_output_tokens = None
+        if message and not message.get("choices") and message.get('usage'):
+            # We want to count output tokens ourselves, but we can check our work with usage data.
+            expected_output_tokens = deepget(message, "usage", "completion_tokens")
+            result.input_tokens = deepget(message, "usage", "prompt_tokens")
+            # We don't want to record this message
+            resps.pop()
+        else:
+            logger.warning("Usage statistics are missing, token count may be inaccurate")
+
+        # Iterate through all responses.
+        # A response can carry more than one token in certain scenarios,
+        # such as speculative decoding, so an item in this list
+        # represents one or more tokens.
+        tokens = []
+        prev_time = 0
+        total_usage = 0
+        for resp in resps:
+            message = self._process_resp(resp['data'])
+            if not message:
+                result.error_code = response.status_code
+                result.error_text = 'bad_response'
+                logger.error("Skipping a response line that failed to parse, token count may be inaccurate")
+                continue
+
+            if message.get('error'):
+                result.error_code = response.status_code
+                result.error_text = message['error']
+                logger.error("Error received in response message: %s", result.error_text)
+                continue
+
+            token = {}
+
+            if self.api == 'chat':
+                token["text"] = deepget(message, "choices", 0, 'delta', 'content')
+            else:  # self.api == 'legacy'
+                token["text"] = deepget(message, "choices", 0, 'text')
+
+            # If the message reports the current usage, record the number of
+            # tokens; otherwise assume 1 token
+            current_usage = deepget(message, "usage", "completion_tokens")
+            if current_usage is not None:
+                token['count'] = current_usage - total_usage
+            else:
+                token['count'] = 1
+
+            # Omit responses that don't contain any
+            # tokens (or somehow contain negative tokens)
+            if token['count'] < 1:
+                logger.debug("Omitting response '%s' because it contains %d tokens",
+                             token["text"], token['count'])
+                continue
+
+            # Update the total token count
+            total_usage += token['count']
+
+            token['time'] = resp['time']
+            token['lat'] = token['time'] - prev_time
+            prev_time = token['time']
+
+            # Append our valid token
+            tokens.append(token)
+
+        # First chunk may not be a token, just a connection ack
+        result.ack_time = resps[0]['time']
+
+        # First non-empty token is the first token
+        result.first_token_time = tokens[0]['time']
+
+        # If the test duration ended before the last token arrived, only count
+        # the tokens received before the test end time.
+        result.output_tokens_before_timeout = sum(t['count'] for t in tokens if t['time'] <= test_end_time)
+
+        # Last token comes with finish_reason set.
+        result.stop_reason = deepget(self._process_resp(resps[-1]['data']) or {}, "choices", 0, "finish_reason")
+
         # Full response received, return
-        result.end_time = time.time()
-        result.output_text = "".join(tokens)
+        result.output_text = "".join([token['text'] for token in tokens])
 
         if not result.input_tokens:
             logger.warning("Input token count not found in response, using dataset input_tokens")
             result.input_tokens = query.get("input_tokens")
 
-        if not result.output_tokens:
-            logger.warning("Output token count not found in response, length of token list")
-            result.output_tokens = len(tokens)
+        result.output_tokens = total_usage  # Just reuse our count from the loop
+        if expected_output_tokens and result.output_tokens != expected_output_tokens:
+            logger.warning("Received %s tokens but expected %s tokens", result.output_tokens, expected_output_tokens)
 
         # If test duration timeout didn't happen before the last token is received,
         # total tokens before the timeout will be equal to the total tokens in the response.
diff --git a/result.py b/result.py
index eaef37ca..98882e6e 100644
--- a/result.py
+++ b/result.py
@@ -1,29 +1,31 @@
 """Main result class."""
 
+from typing import Optional
+
 
 class RequestResult:
     """Request result class."""
 
     def __init__(self, user_id, input_id, input_tokens=None):
         """Init method."""
-        self.user_id = user_id
-        self.input_id = input_id
-        self.input_tokens = input_tokens
-        self.output_text = None
-        self.output_tokens = None
-        self.output_tokens_before_timeout = None
-        self.start_time = None
-        self.ack_time = None
-        self.first_token_time = None
-        self.end_time = None
-        self.response_time = None
-        self.tt_ack = None
-        self.ttft = None
-        self.itl = None
-        self.tpot = None
-        self.stop_reason = None
-        self.error_code = None
-        self.error_text = None
+        self.user_id: int = user_id
+        self.input_id: int = input_id
+        self.input_tokens: Optional[int] = input_tokens
+        self.output_text: Optional[str] = None
+        self.output_tokens: Optional[int] = None
+        self.output_tokens_before_timeout: Optional[int] = None
+        self.start_time: Optional[float] = None
+        self.ack_time: Optional[float] = None
+        self.first_token_time: Optional[float] = None
+        self.end_time: Optional[float] = None
+        self.response_time: Optional[float] = None
+        self.tt_ack: Optional[float] = None
+        self.ttft: Optional[float] = None
+        self.itl: Optional[float] = None
+        self.tpot: Optional[float] = None
+        self.stop_reason: Optional[str] = None
+        self.error_code: Optional[int] = None
+        self.error_text: Optional[str] = None
 
     def asdict(self):
         """Return a dictionary."""
@@ -36,20 +38,24 @@ def calculate_results(self):
         """Calculate the results."""
         # Only calculate results if response is error-free.
         if self.error_code is None and self.error_text is None:
-            # response_time in seconds
-            self.response_time = 1000 * (self.end_time - self.start_time)
+            if self.end_time is not None and self.start_time is not None:
+                # response_time in ms
+                self.response_time = 1000 * (self.end_time - self.start_time)
 
-            if self.ack_time is not None:
+            if self.ack_time is not None and self.start_time is not None:
                 self.tt_ack = 1000 * (self.ack_time - self.start_time)
 
             if self.first_token_time is not None:
-                self.ttft = 1000 * (
-                    self.first_token_time - self.start_time
-                )  # Time to first token in ms
-                self.itl = (1000 * (self.end_time - self.first_token_time)) / (
-                    self.output_tokens - 1
-                )  # Inter-token latency in ms. Distinct from TPOT as it excludes the first token time.
-
-                self.tpot = (
-                    self.response_time / self.output_tokens
-                )  # Time per output token in ms
+                if self.start_time is not None:
+                    self.ttft = 1000 * (
+                        self.first_token_time - self.start_time
+                    )  # Time to first token in ms
+                if self.end_time is not None and self.output_tokens is not None:
+                    self.itl = (1000 * (self.end_time - self.first_token_time)) / (
+                        self.output_tokens - 1
+                    )  # Inter-token latency in ms. Distinct from TPOT as it excludes the first token time.
+
+            if self.response_time is not None and self.output_tokens is not None:
+                self.tpot = (
+                    self.response_time / self.output_tokens
+                )  # Time per output token in ms
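
For reference, a minimal usage sketch of the deepget helper introduced in this patch. The message dict below is a made-up example shaped like an OpenAI-style streaming chunk; it is not taken from the patch or its tests.

    # Hypothetical response chunk, shown only to illustrate deepget's behavior
    message = {"choices": [{"delta": {"content": "Hello"}}]}

    deepget(message, "choices", 0, "delta", "content")         # -> "Hello"
    deepget(message, "choices", 1, "delta", "content")         # -> None (IndexError is caught)
    deepget(message, "usage", "completion_tokens", default=0)  # -> 0 (KeyError is caught)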