[OPIK-64] SDK missing openai metadata in opik openai integration (#667)
* Update the OpenAI integration and its tests

* Fix lint errors

* Re-implement chat completions chunks aggregator

* Remove an unused piece of code

* Make _generators_handler method abstract

* Add engineering docstrings

* Extract backend emulator setup into a separate fixture; refactor the first batch of library integration tests accordingly

* Update type hints to the old style for older Python versions

* Improve usage logging

* Add protection against errors raised during OpenAI stream chunk aggregation
alexkuzmik authored Nov 19, 2024
1 parent b11a2a9 commit 4a3970d
Showing 11 changed files with 1,026 additions and 1,149 deletions.
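Taken together, the diff below means that extra call kwargs and non-choice response fields now survive into span metadata. A minimal sketch of the integration after this change (model name and prompt are placeholders; the metadata keys are the ones set in openai_decorator.py below):

```python
import openai
from opik.integrations.openai import track_openai

client = track_openai(openai.OpenAI())

response = client.chat.completions.create(
    model="gpt-4o",  # placeholder model name
    messages=[{"role": "user", "content": "Hello!"}],  # logged as span input
    temperature=0.3,  # not an input key, so it is logged as span metadata
)
# The span metadata now also carries "created_from": "openai" and
# "type": "openai_chat", plus response fields other than "choices"
# (id, model, usage, ...) once the call returns.
```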
31 changes: 24 additions & 7 deletions sdks/python/src/opik/decorator/base_track_decorator.py
@@ -32,12 +32,10 @@ class BaseTrackDecorator(abc.ABC):
All TrackDecorator instances share the same context and can be
used together simultaneously.
The following methods MUST be implemented in the subclass:
The following methods must be implemented in the subclass:
* _start_span_inputs_preprocessor
* _end_span_inputs_preprocessor
The following methods CAN be overridden in the subclass:
* _generators_handler
* _generators_handler (the default implementation is provided but still needs to be called via `super()`)
Overriding other methods of this class is not recommended.
"""
@@ -487,14 +485,23 @@ def _after_call(
exc_info=True,
)

@abc.abstractmethod
def _generators_handler(
self,
output: Any,
capture_output: bool,
generations_aggregator: Optional[Callable[[List[Any]], str]],
) -> Optional[Union[Generator, AsyncGenerator]]:
"""
Subclasses can override this method to customize the handling of generator objects.
Subclasses must override this method to customize the handling of generator objects.
This is the implementation for regular generators and async generators that
uses the aggregator function passed to track.
However, sometimes the function might return an instance of some specific class which
is not a Python generator itself but implements an API for iterating through data chunks.
In that case `_generators_handler` must be fully overridden in the subclass.
This is usually the case when creating an integration with some LLM library.
"""
if inspect.isgenerator(output):
span_to_end, trace_to_end = pop_end_candidates()
@@ -536,14 +543,24 @@ def _start_span_inputs_preprocessor(
args: Tuple,
kwargs: Dict[str, Any],
project_name: Optional[str],
) -> arguments_helpers.StartSpanParameters: ...
) -> arguments_helpers.StartSpanParameters:
"""
Subclasses must override this method to customize generating
span/trace parameters from the function's input arguments.
"""
pass

@abc.abstractmethod
def _end_span_inputs_preprocessor(
self,
output: Optional[Any],
capture_output: bool,
) -> arguments_helpers.EndSpanParameters: ...
) -> arguments_helpers.EndSpanParameters:
"""
Subclasses must override this method to customize generating
span/trace parameters from the function's return value.
"""
pass


def pop_end_candidates() -> Tuple[span.SpanData, Optional[trace.TraceData]]:
16 changes: 15 additions & 1 deletion sdks/python/src/opik/decorator/tracker.py
@@ -1,6 +1,6 @@
import logging

from typing import List, Any, Dict, Optional, Callable, Tuple
from typing import List, Any, Dict, Optional, Callable, Tuple, Union

from ..types import SpanType

@@ -59,6 +59,20 @@ def _end_span_inputs_preprocessor(

return result

def _generators_handler(
self,
output: Any,
capture_output: bool,
generations_aggregator: Optional[Callable[[List[Any]], str]],
) -> Union[
base_track_decorator.Generator[Any, None, None],
base_track_decorator.AsyncGenerator[Any, None],
None,
]:
return super()._generators_handler(
output, capture_output, generations_aggregator
)


def flush_tracker(timeout: Optional[int] = None) -> None:
opik_ = opik_client.get_client_cached()
59 changes: 59 additions & 0 deletions sdks/python/src/opik/integrations/openai/chat_completion_chunks_aggregator.py
@@ -0,0 +1,59 @@
import logging
from typing import List, Optional
from openai.types.chat import chat_completion_chunk, chat_completion

from opik import logging_messages

LOGGER = logging.getLogger(__name__)


def aggregate(
items: List[chat_completion_chunk.ChatCompletionChunk],
) -> Optional[chat_completion.ChatCompletion]:
# TODO: check if there are scenarios when stream contains more than one choice
try:
first_chunk = items[0]

aggregated_response = {
"choices": [{"index": 0, "message": {"role": "", "content": ""}}],
"id": first_chunk.id,
"created": first_chunk.created,
"model": first_chunk.model,
"object": "chat.completion",
"system_fingerprint": first_chunk.system_fingerprint,
}

text_chunks: List[str] = []

for chunk in items:
if chunk.choices and chunk.choices[0].delta:
delta = chunk.choices[0].delta

if (
delta.role
and not aggregated_response["choices"][0]["message"]["role"]
):
aggregated_response["choices"][0]["message"]["role"] = delta.role

if delta.content:
text_chunks.append(delta.content)

if chunk.choices and chunk.choices[0].finish_reason:
aggregated_response["choices"][0]["finish_reason"] = chunk.choices[
0
].finish_reason

if chunk.usage:
aggregated_response["usage"] = chunk.usage.model_dump()

aggregated_response["choices"][0]["message"]["content"] = "".join(text_chunks)
result = chat_completion.ChatCompletion(**aggregated_response)

return result
except Exception as exception:
LOGGER.error(
logging_messages.FAILED_TO_PARSE_OPENAI_STREAM_CONTENT,
str(exception),
exc_info=True,
)
return None
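A hedged usage sketch for `aggregate` (in real use the chunks come from a `chat.completions.create(..., stream=True)` call; note that on any aggregation error the function logs and returns None rather than raising, per the last bullet in the commit message):

```python
from typing import List

from openai.types.chat import chat_completion_chunk

# Populate from a real streamed response; left empty here for illustration.
chunks: List[chat_completion_chunk.ChatCompletionChunk] = []

completion = aggregate(chunks)
if completion is not None:  # None signals that aggregation failed
    print(completion.choices[0].message.content)
```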
64 changes: 0 additions & 64 deletions sdks/python/src/opik/integrations/openai/chunks_aggregator.py

This file was deleted.

91 changes: 27 additions & 64 deletions sdks/python/src/opik/integrations/openai/openai_decorator.py
@@ -1,17 +1,21 @@
import logging
import json
from typing import List, Any, Dict, Optional, Callable, Tuple, Union

from opik import dict_utils
from opik.types import SpanType
from opik.decorator import base_track_decorator, arguments_helpers
from . import stream_wrappers, chunks_aggregator
from . import stream_wrappers

import openai
from openai.types.chat import chat_completion, chat_completion_message
from openai.types.chat import chat_completion

LOGGER = logging.getLogger(__name__)

CreateCallResult = Union[chat_completion.ChatCompletion, List[Any]]

KWARGS_KEYS_TO_LOG_AS_INPUTS = ["messages", "function_call"]
RESPONSE_KEYS_TO_LOG_AS_OUTPUT = ["choices"]


class OpenaiTrackDecorator(base_track_decorator.BaseTrackDecorator):
"""
@@ -38,20 +42,19 @@ def _start_span_inputs_preprocessor(
assert (
kwargs is not None
), "Expected kwargs to be not None in OpenAI().chat.completion.create(**kwargs)"
kwargs_copy = kwargs.copy()

name = name if name is not None else func.__name__
metadata = metadata if metadata is not None else {}

input = {}
input["messages"] = _parse_messages_list(kwargs_copy.pop("messages"))
if "function_call" in kwargs_copy:
input["function_call"] = kwargs_copy.pop("function_call")

metadata = {
"created_from": "openai",
"type": "openai_chat",
**kwargs_copy,
}
input, new_metadata = dict_utils.split_dict_by_keys(
kwargs, keys=KWARGS_KEYS_TO_LOG_AS_INPUTS
)
metadata = dict_utils.deepmerge(metadata, new_metadata)
metadata.update(
{
"created_from": "openai",
"type": "openai_chat",
}
)

tags = ["openai"]
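For context, a plain-Python sketch of the splitting semantics implied by the `dict_utils.split_dict_by_keys` call above (the real helper lives in `opik.dict_utils` and may differ in details; the sketch function is hypothetical):

```python
from typing import Any, Dict, List, Tuple

def split_dict_by_keys_sketch(
    d: Dict[str, Any], keys: List[str]
) -> Tuple[Dict[str, Any], Dict[str, Any]]:
    # First dict: the requested keys; second dict: everything else.
    selected = {k: v for k, v in d.items() if k in keys}
    rest = {k: v for k, v in d.items() if k not in keys}
    return selected, rest

kwargs = {"messages": [{"role": "user", "content": "Hi"}], "temperature": 0.3}
inputs, rest = split_dict_by_keys_sketch(kwargs, keys=["messages", "function_call"])
# inputs -> {"messages": [...]}   (logged as span input)
# rest   -> {"temperature": 0.3}  (merged into span metadata)
```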

@@ -71,22 +74,18 @@
) -> arguments_helpers.EndSpanParameters:
assert isinstance(
output,
(chat_completion.ChatCompletion, chunks_aggregator.ExtractedStreamContent),
chat_completion.ChatCompletion,
)

usage = None

if isinstance(output, chat_completion.ChatCompletion):
result_dict = output.model_dump(mode="json")
choices: List[Dict[str, Any]] = result_dict.pop("choices") # type: ignore
output = {"choices": choices}

usage = result_dict["usage"]
elif isinstance(output, chunks_aggregator.ExtractedStreamContent):
usage = output.usage
output = {"choices": output.choices}
result_dict = output.model_dump(mode="json")
output, metadata = dict_utils.split_dict_by_keys(result_dict, ["choices"])
usage = result_dict["usage"]

result = arguments_helpers.EndSpanParameters(output=output, usage=usage)
result = arguments_helpers.EndSpanParameters(
output=output,
usage=usage,
metadata=metadata,
)

return result
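Concretely, the new output handling splits a dumped completion like this (a worked sketch; field values are illustrative):

```python
from opik import dict_utils

result_dict = {
    "id": "chatcmpl-123",
    "model": "gpt-4o",
    "choices": [{"index": 0, "message": {"role": "assistant", "content": "Hi"}}],
    "usage": {"prompt_tokens": 9, "completion_tokens": 1, "total_tokens": 10},
}
output, metadata = dict_utils.split_dict_by_keys(result_dict, ["choices"])
# output   -> {"choices": [...]}                       (span output)
# metadata -> {"id": ..., "model": ..., "usage": ...}  (span metadata)
usage = result_dict["usage"]  # logged separately as span usage
```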

@@ -129,39 +128,3 @@ def _generators_handler(
NOT_A_STREAM = None

return NOT_A_STREAM


def _parse_messages_list(
messages: List[
Union[Dict[str, Any], chat_completion_message.ChatCompletionMessage]
],
) -> List[Dict[str, Any]]:
if _is_jsonable(messages):
return messages

result = []

for message in messages:
if _is_jsonable(message):
result.append(message)
continue

if isinstance(message, chat_completion_message.ChatCompletionMessage):
result.append(message.model_dump(mode="json"))
continue

LOGGER.debug("Message %s is not json serializable", message)

result.append(
str(message)
) # TODO: replace with Opik serializer when it is implemented

return result


def _is_jsonable(x: Any) -> bool:
try:
json.dumps(x)
return True
except (TypeError, OverflowError):
return False
4 changes: 2 additions & 2 deletions sdks/python/src/opik/integrations/openai/opik_tracker.py
@@ -2,7 +2,7 @@

import openai

from . import openai_decorator, chunks_aggregator
from . import chat_completion_chunks_aggregator, openai_decorator


def track_openai(
@@ -26,7 +26,7 @@ def track_openai(
completions_create_decorator = decorator_factory.track(
type="llm",
name="chat_completion_create",
generations_aggregator=chunks_aggregator.aggregate,
generations_aggregator=chat_completion_chunks_aggregator.aggregate,
project_name=project_name,
)
openai_client.chat.completions.create = completions_create_decorator(
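With the aggregator wired in as `generations_aggregator`, streamed calls are captured too; a hedged sketch (model and prompt are placeholders):

```python
import openai
from opik.integrations.openai import track_openai

client = track_openai(openai.OpenAI())
stream = client.chat.completions.create(
    model="gpt-4o",  # placeholder
    messages=[{"role": "user", "content": "Stream, please"}],
    stream=True,
)
for chunk in stream:  # chunks pass through to the caller unchanged...
    pass
# ...and once the stream is exhausted, chat_completion_chunks_aggregator.aggregate
# assembles a single ChatCompletion from the collected chunks for span logging.
```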
19 changes: 19 additions & 0 deletions sdks/python/tests/conftest.py
@@ -1,4 +1,5 @@
import pytest
import mock
from opik import context_storage
from opik.api_objects import opik_client
from .testlib import backend_emulator_message_processor
@@ -34,3 +35,21 @@ def fake_streamer():
yield streamer, fake_message_processor_
finally:
streamer.close(timeout=5)


@pytest.fixture
def fake_backend(fake_streamer):
fake_message_processor_: (
backend_emulator_message_processor.BackendEmulatorMessageProcessor
)
streamer, fake_message_processor_ = fake_streamer

mock_construct_online_streamer = mock.Mock()
mock_construct_online_streamer.return_value = streamer

with mock.patch.object(
streamer_constructors,
"construct_online_streamer",
mock_construct_online_streamer,
):
yield fake_message_processor_
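A hedged sketch of a test built on this fixture (`opik.track` is the SDK's public decorator and `flush_tracker` is defined in opik/decorator/tracker.py above; how captured traces are inspected on the emulator is an assumption about BackendEmulatorMessageProcessor's interface):

```python
import opik

def test_tracked_function__output_is_logged(fake_backend):
    @opik.track
    def f(x):
        return x + 1

    assert f(1) == 2
    opik.flush_tracker()  # push buffered messages to the fake backend

    # fake_backend is the BackendEmulatorMessageProcessor yielded above; its
    # collected trace/span data can now be asserted against (interface assumed).
```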
