diff --git a/sdks/python/src/opik/decorator/base_track_decorator.py b/sdks/python/src/opik/decorator/base_track_decorator.py
index db25b69323..0d034a4aee 100644
--- a/sdks/python/src/opik/decorator/base_track_decorator.py
+++ b/sdks/python/src/opik/decorator/base_track_decorator.py
@@ -32,12 +32,10 @@ class BaseTrackDecorator(abc.ABC):
     All TrackDecorator instances share the same context and can be used together simultaneously.

-    The following methods MUST be implemented in the subclass:
+    The following methods must be implemented in the subclass:
         * _start_span_inputs_preprocessor
         * _end_span_inputs_preprocessor
-
-    The following methods CAN be overriden in the subclass:
-        * _generators_handler
+        * _generators_handler (a default implementation is provided, but it must still be called via `super()`)

     Overriding other methods of this class is not recommended.
     """
@@ -487,6 +485,7 @@ def _after_call(
                 exc_info=True,
             )

+    @abc.abstractmethod
     def _generators_handler(
         self,
         output: Any,
@@ -494,7 +493,15 @@ def _generators_handler(
         generations_aggregator: Optional[Callable[[List[Any]], str]],
     ) -> Optional[Union[Generator, AsyncGenerator]]:
         """
-        Subclasses can override this method to customize generator objects handling
+        Subclasses must override this method to customize the handling of generator objects.
+        This default implementation handles regular generators and async generators and
+        uses the aggregator function passed to `track`.
+
+        However, sometimes the tracked function returns an instance of a specific class which
+        is not a Python generator itself but implements its own API for iterating through data chunks.
+        In that case `_generators_handler` must be fully overridden in the subclass.
+
+        This is usually the case when creating an integration with an LLM library.
         """
         if inspect.isgenerator(output):
             span_to_end, trace_to_end = pop_end_candidates()
@@ -536,14 +543,24 @@ def _start_span_inputs_preprocessor(
         args: Tuple,
         kwargs: Dict[str, Any],
         project_name: Optional[str],
-    ) -> arguments_helpers.StartSpanParameters: ...
+    ) -> arguments_helpers.StartSpanParameters:
+        """
+        Subclasses must override this method to customize how span/trace
+        parameters are generated from the function's input arguments.
+        """
+        pass

     @abc.abstractmethod
     def _end_span_inputs_preprocessor(
         self,
         output: Optional[Any],
         capture_output: bool,
-    ) -> arguments_helpers.EndSpanParameters: ...
+    ) -> arguments_helpers.EndSpanParameters:
+        """
+        Subclasses must override this method to customize how span/trace
+        parameters are generated from the function's return value.
+        """
+        pass


 def pop_end_candidates() -> Tuple[span.SpanData, Optional[trace.TraceData]]:
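A minimal usage sketch of the contract described above (illustrative, not part of this patch): `generations_aggregator` receives the list of items yielded by a tracked generator once it is exhausted, and its return value is what gets logged for the span. This assumes the public `opik.track` decorator forwards the aggregator to the base decorator, as the docstring states; `join_chunks` and `stream_words` are hypothetical names.

    import opik

    def join_chunks(chunks):
        # Receives every yielded item once the generator is consumed.
        return "".join(chunks)

    @opik.track(generations_aggregator=join_chunks)
    def stream_words():
        yield "stream"
        yield "ing"

    # The span/trace ends only after the generator is fully consumed;
    # the aggregated string "streaming" becomes the logged output.
    for _ in stream_words():
        pass

The tracker.py change below wires this path explicitly: its `_generators_handler` override satisfies the new abstract method by delegating to the base implementation via `super()`.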
diff --git a/sdks/python/src/opik/decorator/tracker.py b/sdks/python/src/opik/decorator/tracker.py
index e20c5f8e0b..1770f89b1c 100644
--- a/sdks/python/src/opik/decorator/tracker.py
+++ b/sdks/python/src/opik/decorator/tracker.py
@@ -1,6 +1,6 @@
 import logging

-from typing import List, Any, Dict, Optional, Callable, Tuple
+from typing import List, Any, Dict, Optional, Callable, Tuple, Union

 from ..types import SpanType
@@ -59,6 +59,20 @@ def _end_span_inputs_preprocessor(
         return result

+    def _generators_handler(
+        self,
+        output: Any,
+        capture_output: bool,
+        generations_aggregator: Optional[Callable[[List[Any]], str]],
+    ) -> Union[
+        base_track_decorator.Generator[Any, None, None],
+        base_track_decorator.AsyncGenerator[Any, None],
+        None,
+    ]:
+        return super()._generators_handler(
+            output, capture_output, generations_aggregator
+        )
+

 def flush_tracker(timeout: Optional[int] = None) -> None:
     opik_ = opik_client.get_client_cached()
diff --git a/sdks/python/src/opik/integrations/openai/chat_completion_chunks_aggregator.py b/sdks/python/src/opik/integrations/openai/chat_completion_chunks_aggregator.py
new file mode 100644
index 0000000000..bb0a3ad8cd
--- /dev/null
+++ b/sdks/python/src/opik/integrations/openai/chat_completion_chunks_aggregator.py
@@ -0,0 +1,59 @@
+import logging
+from typing import List, Optional
+from openai.types.chat import chat_completion_chunk, chat_completion
+
+from opik import logging_messages
+
+LOGGER = logging.getLogger(__name__)
+
+
+def aggregate(
+    items: List[chat_completion_chunk.ChatCompletionChunk],
+) -> Optional[chat_completion.ChatCompletion]:
+    # TODO: check if there are scenarios when stream contains more than one choice
+    try:
+        first_chunk = items[0]
+
+        aggregated_response = {
+            "choices": [{"index": 0, "message": {"role": "", "content": ""}}],
+            "id": first_chunk.id,
+            "created": first_chunk.created,
+            "model": first_chunk.model,
+            "object": "chat.completion",
+            "system_fingerprint": first_chunk.system_fingerprint,
+        }
+
+        text_chunks: List[str] = []
+
+        for chunk in items:
+            if chunk.choices and chunk.choices[0].delta:
+                delta = chunk.choices[0].delta
+
+                if (
+                    delta.role
+                    and not aggregated_response["choices"][0]["message"]["role"]
+                ):
+                    aggregated_response["choices"][0]["message"]["role"] = delta.role
+
+                if delta.content:
+                    text_chunks.append(delta.content)
+
+            if chunk.choices and chunk.choices[0].finish_reason:
+                aggregated_response["choices"][0]["finish_reason"] = chunk.choices[
+                    0
+                ].finish_reason
+
+            if chunk.usage:
+                aggregated_response["usage"] = chunk.usage.model_dump()
+
+        aggregated_response["choices"][0]["message"]["content"] = "".join(text_chunks)
+        result = chat_completion.ChatCompletion(**aggregated_response)
+
+        return result
+    except Exception as exception:
+        LOGGER.error(
+            logging_messages.FAILED_TO_PARSE_OPENAI_STREAM_CONTENT,
+            str(exception),
+            exc_info=True,
+        )
+        return None
diff --git a/sdks/python/src/opik/integrations/openai/chunks_aggregator.py b/sdks/python/src/opik/integrations/openai/chunks_aggregator.py
deleted file mode 100644
index 0bf48bc85e..0000000000
--- a/sdks/python/src/opik/integrations/openai/chunks_aggregator.py
+++ /dev/null
@@ -1,64 +0,0 @@
-import logging
-import dataclasses
-from typing import Any, List, Optional, Dict
-from
openai.types.chat import chat_completion_chunk - -from opik import logging_messages - -LOGGER = logging.getLogger(__name__) - - -@dataclasses.dataclass -class ExtractedStreamContent: - choices: List[Dict[str, Any]] = dataclasses.field(default_factory=list) - usage: Optional[Dict[str, Any]] = None - - -def aggregate( - items: List[chat_completion_chunk.ChatCompletionChunk], -) -> ExtractedStreamContent: - extracted_content = ExtractedStreamContent() - - # TODO: check if there are scenarios when stream contains more than one choice - try: - content_items = [_get_item_content(item) for item in items] - choices = [_construct_choice_dict(items[0], content_items)] - usage = items[-1].usage - if usage is not None: - usage = usage.model_dump() - - extracted_content = ExtractedStreamContent( - choices=choices, - usage=usage, - ) - except Exception as exception: - LOGGER.error( - logging_messages.FAILED_TO_PARSE_OPENAI_STREAM_CONTENT, - str(exception), - exc_info=True, - ) - - return extracted_content - - -def _get_item_content(item: chat_completion_chunk.ChatCompletionChunk) -> str: - result: Optional[str] = None - if len(item.choices) > 0: - result = item.choices[0].delta.content - - return "" if result is None else result - - -def _construct_choice_dict( - first_item: chat_completion_chunk.ChatCompletionChunk, items_content: List[str] -) -> Dict[str, Any]: - if len(first_item.choices) > 0: - role = first_item.choices[0].delta.role - else: - role = None - - choice_info = { - "message": {"content": "".join(items_content), "role": role}, - } - - return choice_info diff --git a/sdks/python/src/opik/integrations/openai/openai_decorator.py b/sdks/python/src/opik/integrations/openai/openai_decorator.py index 2afa5989b2..a065327ea6 100644 --- a/sdks/python/src/opik/integrations/openai/openai_decorator.py +++ b/sdks/python/src/opik/integrations/openai/openai_decorator.py @@ -1,17 +1,21 @@ import logging -import json from typing import List, Any, Dict, Optional, Callable, Tuple, Union + +from opik import dict_utils from opik.types import SpanType from opik.decorator import base_track_decorator, arguments_helpers -from . import stream_wrappers, chunks_aggregator +from . 
import stream_wrappers import openai -from openai.types.chat import chat_completion, chat_completion_message +from openai.types.chat import chat_completion LOGGER = logging.getLogger(__name__) CreateCallResult = Union[chat_completion.ChatCompletion, List[Any]] +KWARGS_KEYS_TO_LOG_AS_INPUTS = ["messages", "function_call"] +RESPONSE_KEYS_TO_LOG_AS_OUTPUT = ["choices"] + class OpenaiTrackDecorator(base_track_decorator.BaseTrackDecorator): """ @@ -38,20 +42,19 @@ def _start_span_inputs_preprocessor( assert ( kwargs is not None ), "Expected kwargs to be not None in OpenAI().chat.completion.create(**kwargs)" - kwargs_copy = kwargs.copy() - name = name if name is not None else func.__name__ + metadata = metadata if metadata is not None else {} - input = {} - input["messages"] = _parse_messages_list(kwargs_copy.pop("messages")) - if "function_call" in kwargs_copy: - input["function_call"] = kwargs_copy.pop("function_call") - - metadata = { - "created_from": "openai", - "type": "openai_chat", - **kwargs_copy, - } + input, new_metadata = dict_utils.split_dict_by_keys( + kwargs, keys=KWARGS_KEYS_TO_LOG_AS_INPUTS + ) + metadata = dict_utils.deepmerge(metadata, new_metadata) + metadata.update( + { + "created_from": "openai", + "type": "openai_chat", + } + ) tags = ["openai"] @@ -71,22 +74,18 @@ def _end_span_inputs_preprocessor( ) -> arguments_helpers.EndSpanParameters: assert isinstance( output, - (chat_completion.ChatCompletion, chunks_aggregator.ExtractedStreamContent), + chat_completion.ChatCompletion, ) - usage = None - - if isinstance(output, chat_completion.ChatCompletion): - result_dict = output.model_dump(mode="json") - choices: List[Dict[str, Any]] = result_dict.pop("choices") # type: ignore - output = {"choices": choices} - - usage = result_dict["usage"] - elif isinstance(output, chunks_aggregator.ExtractedStreamContent): - usage = output.usage - output = {"choices": output.choices} + result_dict = output.model_dump(mode="json") + output, metadata = dict_utils.split_dict_by_keys(result_dict, ["choices"]) + usage = result_dict["usage"] - result = arguments_helpers.EndSpanParameters(output=output, usage=usage) + result = arguments_helpers.EndSpanParameters( + output=output, + usage=usage, + metadata=metadata, + ) return result @@ -129,39 +128,3 @@ def _generators_handler( NOT_A_STREAM = None return NOT_A_STREAM - - -def _parse_messages_list( - messages: List[ - Union[Dict[str, Any], chat_completion_message.ChatCompletionMessage] - ], -) -> List[Dict[str, Any]]: - if _is_jsonable(messages): - return messages - - result = [] - - for message in messages: - if _is_jsonable(message): - result.append(message) - continue - - if isinstance(message, chat_completion_message.ChatCompletionMessage): - result.append(message.model_dump(mode="json")) - continue - - LOGGER.debug("Message %s is not json serializable", message) - - result.append( - str(message) - ) # TODO: replace with Opik serializer when it is implemented - - return result - - -def _is_jsonable(x: Any) -> bool: - try: - json.dumps(x) - return True - except (TypeError, OverflowError): - return False diff --git a/sdks/python/src/opik/integrations/openai/opik_tracker.py b/sdks/python/src/opik/integrations/openai/opik_tracker.py index e7f3d29d33..229ced39ea 100644 --- a/sdks/python/src/opik/integrations/openai/opik_tracker.py +++ b/sdks/python/src/opik/integrations/openai/opik_tracker.py @@ -2,7 +2,7 @@ import openai -from . import openai_decorator, chunks_aggregator +from . 
import chat_completion_chunks_aggregator, openai_decorator def track_openai( @@ -26,7 +26,7 @@ def track_openai( completions_create_decorator = decorator_factory.track( type="llm", name="chat_completion_create", - generations_aggregator=chunks_aggregator.aggregate, + generations_aggregator=chat_completion_chunks_aggregator.aggregate, project_name=project_name, ) openai_client.chat.completions.create = completions_create_decorator( diff --git a/sdks/python/tests/conftest.py b/sdks/python/tests/conftest.py index 5410df9ead..1fc8a96d5c 100644 --- a/sdks/python/tests/conftest.py +++ b/sdks/python/tests/conftest.py @@ -1,4 +1,5 @@ import pytest +import mock from opik import context_storage from opik.api_objects import opik_client from .testlib import backend_emulator_message_processor @@ -34,3 +35,21 @@ def fake_streamer(): yield streamer, fake_message_processor_ finally: streamer.close(timeout=5) + + +@pytest.fixture +def fake_backend(fake_streamer): + fake_message_processor_: ( + backend_emulator_message_processor.BackendEmulatorMessageProcessor + ) + streamer, fake_message_processor_ = fake_streamer + + mock_construct_online_streamer = mock.Mock() + mock_construct_online_streamer.return_value = streamer + + with mock.patch.object( + streamer_constructors, + "construct_online_streamer", + mock_construct_online_streamer, + ): + yield fake_message_processor_ diff --git a/sdks/python/tests/library_integration/anthropic/test_anthropic.py b/sdks/python/tests/library_integration/anthropic/test_anthropic.py index 5fc09ec89a..8254296dae 100644 --- a/sdks/python/tests/library_integration/anthropic/test_anthropic.py +++ b/sdks/python/tests/library_integration/anthropic/test_anthropic.py @@ -1,13 +1,10 @@ import pytest -import mock import os import asyncio import opik -from opik.message_processing import streamer_constructors from opik.integrations.anthropic import track_anthropic from opik.config import OPIK_PROJECT_DEFAULT_NAME -from ...testlib import backend_emulator_message_processor from ...testlib import ( SpanModel, TraceModel, @@ -37,336 +34,317 @@ def ensure_anthropic_configured(): ], ) def test_anthropic_messages_create__happyflow( - fake_streamer, project_name, expected_project_name + fake_backend, project_name, expected_project_name ): - fake_message_processor_: ( - backend_emulator_message_processor.BackendEmulatorMessageProcessor + client = anthropic.Anthropic() + wrapped_client = track_anthropic( + anthropic_client=client, + project_name=project_name, ) - streamer, fake_message_processor_ = fake_streamer + messages = [{"role": "user", "content": "Tell a short fact"}] - mock_construct_online_streamer = mock.Mock() - mock_construct_online_streamer.return_value = streamer - - with mock.patch.object( - streamer_constructors, - "construct_online_streamer", - mock_construct_online_streamer, - ): - client = anthropic.Anthropic() - wrapped_client = track_anthropic( - anthropic_client=client, - project_name=project_name, - ) - messages = [{"role": "user", "content": "Tell a short fact"}] - - response = wrapped_client.messages.create( - model="claude-3-opus-20240229", - messages=messages, - max_tokens=10, - system="You are a helpful assistant", - ) + response = wrapped_client.messages.create( + model="claude-3-opus-20240229", + messages=messages, + max_tokens=10, + system="You are a helpful assistant", + ) - opik.flush_tracker() - mock_construct_online_streamer.assert_called_once() - - EXPECTED_TRACE_TREE = TraceModel( - id=ANY_BUT_NONE, - name="anthropic_messages_create", - input={"messages": 
messages, "system": "You are a helpful assistant"}, - output={"content": response.model_dump()["content"]}, - tags=["anthropic"], - metadata=ANY_DICT, - start_time=ANY_BUT_NONE, - end_time=ANY_BUT_NONE, - project_name=expected_project_name, - spans=[ - SpanModel( - id=ANY_BUT_NONE, - name="anthropic_messages_create", - input={ - "messages": messages, - "system": "You are a helpful assistant", - }, - output={"content": response.model_dump()["content"]}, - tags=["anthropic"], - metadata=ANY_DICT, - start_time=ANY_BUT_NONE, - end_time=ANY_BUT_NONE, - project_name=expected_project_name, - type="llm", - usage=ANY_DICT, - spans=[], - ) - ], - ) + opik.flush_tracker() + + EXPECTED_TRACE_TREE = TraceModel( + id=ANY_BUT_NONE, + name="anthropic_messages_create", + input={"messages": messages, "system": "You are a helpful assistant"}, + output={"content": response.model_dump()["content"]}, + tags=["anthropic"], + metadata=ANY_DICT, + start_time=ANY_BUT_NONE, + end_time=ANY_BUT_NONE, + project_name=expected_project_name, + spans=[ + SpanModel( + id=ANY_BUT_NONE, + name="anthropic_messages_create", + input={ + "messages": messages, + "system": "You are a helpful assistant", + }, + output={"content": response.model_dump()["content"]}, + tags=["anthropic"], + metadata=ANY_DICT, + start_time=ANY_BUT_NONE, + end_time=ANY_BUT_NONE, + project_name=expected_project_name, + type="llm", + usage=ANY_DICT, + spans=[], + ) + ], + ) - assert len(fake_message_processor_.trace_trees) == 1 + assert len(fake_backend.trace_trees) == 1 - assert_equal(EXPECTED_TRACE_TREE, fake_message_processor_.trace_trees[0]) + assert_equal(EXPECTED_TRACE_TREE, fake_backend.trace_trees[0]) def test_anthropic_messages_create__create_raises_an_error__span_and_trace_finished_gracefully( - fake_streamer, + fake_backend, ): - fake_message_processor_: ( - backend_emulator_message_processor.BackendEmulatorMessageProcessor - ) - streamer, fake_message_processor_ = fake_streamer + client = anthropic.Anthropic() + wrapped_client = track_anthropic(client) - mock_construct_online_streamer = mock.Mock() - mock_construct_online_streamer.return_value = streamer - - with mock.patch.object( - streamer_constructors, - "construct_online_streamer", - mock_construct_online_streamer, - ): - client = anthropic.Anthropic() - wrapped_client = track_anthropic(client) + with pytest.raises(Exception): + _ = wrapped_client.messages.create( + messages=None, + model=None, + ) - with pytest.raises(Exception): - _ = wrapped_client.messages.create( - messages=None, - model=None, + opik.flush_tracker() + + EXPECTED_TRACE_TREE = TraceModel( + id=ANY_BUT_NONE, + name="anthropic_messages_create", + input={"messages": None}, + output=None, + tags=["anthropic"], + metadata={"created_from": "anthropic", "model": None, "base_url": ANY}, + start_time=ANY_BUT_NONE, + end_time=ANY_BUT_NONE, + project_name=ANY_BUT_NONE, + spans=[ + SpanModel( + id=ANY_BUT_NONE, + type="llm", + name="anthropic_messages_create", + input={"messages": None}, + output=None, + tags=["anthropic"], + metadata={ + "created_from": "anthropic", + "model": None, + "base_url": ANY, + }, + usage=None, + start_time=ANY_BUT_NONE, + end_time=ANY_BUT_NONE, + project_name=ANY_BUT_NONE, + spans=[], ) + ], + ) - opik.flush_tracker() - mock_construct_online_streamer.assert_called_once() - - EXPECTED_TRACE_TREE = TraceModel( - id=ANY_BUT_NONE, - name="anthropic_messages_create", - input={"messages": None}, - output=None, - tags=["anthropic"], - metadata={"created_from": "anthropic", "model": None, "base_url": ANY}, - 
start_time=ANY_BUT_NONE, - end_time=ANY_BUT_NONE, - project_name=ANY_BUT_NONE, - spans=[ - SpanModel( - id=ANY_BUT_NONE, - type="llm", - name="anthropic_messages_create", - input={"messages": None}, - output=None, - tags=["anthropic"], - metadata={ - "created_from": "anthropic", - "model": None, - "base_url": ANY, - }, - usage=None, - start_time=ANY_BUT_NONE, - end_time=ANY_BUT_NONE, - project_name=ANY_BUT_NONE, - spans=[], - ) - ], - ) - - assert len(fake_message_processor_.trace_trees) == 1 + assert len(fake_backend.trace_trees) == 1 - assert_equal(EXPECTED_TRACE_TREE, fake_message_processor_.trace_trees[0]) + assert_equal(EXPECTED_TRACE_TREE, fake_backend.trace_trees[0]) def test_anthropic_messages_create__create_call_made_in_another_tracked_function__anthropic_span_attached_to_existing_trace( - fake_streamer, + fake_backend, ): - fake_message_processor_: ( - backend_emulator_message_processor.BackendEmulatorMessageProcessor - ) - streamer, fake_message_processor_ = fake_streamer + messages = [ + {"role": "user", "content": "Tell a short fact"}, + ] - mock_construct_online_streamer = mock.Mock() - mock_construct_online_streamer.return_value = streamer - - with mock.patch.object( - streamer_constructors, - "construct_online_streamer", - mock_construct_online_streamer, - ): + @opik.track(project_name="anthropic-integration-test") + def f(): + client = anthropic.Anthropic() + wrapped_client = track_anthropic( + anthropic_client=client, + project_name="anthropic-integration-test", + ) messages = [ - {"role": "user", "content": "Tell a short fact"}, + { + "role": "user", + "content": "Tell a short fact", + } ] - @opik.track(project_name="anthropic-integration-test") - def f(): - client = anthropic.Anthropic() - wrapped_client = track_anthropic( - anthropic_client=client, - project_name="anthropic-integration-test", - ) - messages = [ - { - "role": "user", - "content": "Tell a short fact", - } - ] + _ = wrapped_client.messages.create( + model="claude-3-opus-20240229", + messages=messages, + max_tokens=10, + system="You are a helpful assistant", + ) - _ = wrapped_client.messages.create( - model="claude-3-opus-20240229", - messages=messages, - max_tokens=10, - system="You are a helpful assistant", + f() + + opik.flush_tracker() + + EXPECTED_TRACE_TREE = TraceModel( + id=ANY_BUT_NONE, + name="f", + input={}, + output=None, + start_time=ANY_BUT_NONE, + end_time=ANY_BUT_NONE, + project_name="anthropic-integration-test", + spans=[ + SpanModel( + id=ANY_BUT_NONE, + name="f", + input={}, + output=None, + start_time=ANY_BUT_NONE, + end_time=ANY_BUT_NONE, + project_name="anthropic-integration-test", + spans=[ + SpanModel( + id=ANY_BUT_NONE, + name="anthropic_messages_create", + input={ + "messages": messages, + "system": "You are a helpful assistant", + }, + output={"content": ANY_LIST}, + tags=["anthropic"], + metadata=ANY_DICT, + start_time=ANY_BUT_NONE, + end_time=ANY_BUT_NONE, + project_name="anthropic-integration-test", + type="llm", + usage=ANY_DICT, + spans=[], + ) + ], ) + ], + ) - f() - - opik.flush_tracker() - mock_construct_online_streamer.assert_called_once() - - EXPECTED_TRACE_TREE = TraceModel( - id=ANY_BUT_NONE, - name="f", - input={}, - output=None, - start_time=ANY_BUT_NONE, - end_time=ANY_BUT_NONE, - project_name="anthropic-integration-test", - spans=[ - SpanModel( - id=ANY_BUT_NONE, - name="f", - input={}, - output=None, - start_time=ANY_BUT_NONE, - end_time=ANY_BUT_NONE, - project_name="anthropic-integration-test", - spans=[ - SpanModel( - id=ANY_BUT_NONE, - 
name="anthropic_messages_create", - input={ - "messages": messages, - "system": "You are a helpful assistant", - }, - output={"content": ANY_LIST}, - tags=["anthropic"], - metadata=ANY_DICT, - start_time=ANY_BUT_NONE, - end_time=ANY_BUT_NONE, - project_name="anthropic-integration-test", - type="llm", - usage=ANY_DICT, - spans=[], - ) - ], - ) - ], - ) - - assert len(fake_message_processor_.trace_trees) == 1 + assert len(fake_backend.trace_trees) == 1 - assert_equal(EXPECTED_TRACE_TREE, fake_message_processor_.trace_trees[0]) + assert_equal(EXPECTED_TRACE_TREE, fake_backend.trace_trees[0]) def test_async_anthropic_messages_create_call_made_in_another_tracked_async_function__anthropic_span_attached_to_existing_trace( - fake_streamer, + fake_backend, ): - fake_message_processor_: ( - backend_emulator_message_processor.BackendEmulatorMessageProcessor - ) - streamer, fake_message_processor_ = fake_streamer - - mock_construct_online_streamer = mock.Mock() - mock_construct_online_streamer.return_value = streamer + messages = [ + {"role": "user", "content": "Tell a short fact"}, + ] - with mock.patch.object( - streamer_constructors, - "construct_online_streamer", - mock_construct_online_streamer, - ): - messages = [ - {"role": "user", "content": "Tell a short fact"}, - ] + @opik.track() + async def async_f(): + client = anthropic.AsyncAnthropic() + wrapped_client = track_anthropic(client) + _ = await wrapped_client.messages.create( + model="claude-3-opus-20240229", + messages=messages, + max_tokens=10, + system="You are a helpful assistant", + ) - @opik.track() - async def async_f(): - client = anthropic.AsyncAnthropic() - wrapped_client = track_anthropic(client) - _ = await wrapped_client.messages.create( - model="claude-3-opus-20240229", - messages=messages, - max_tokens=10, - system="You are a helpful assistant", + asyncio.run(async_f()) + + opik.flush_tracker() + + EXPECTED_TRACE_TREE = TraceModel( + id=ANY_BUT_NONE, + name="async_f", + input={}, + output=None, + start_time=ANY_BUT_NONE, + end_time=ANY_BUT_NONE, + project_name=ANY_BUT_NONE, + spans=[ + SpanModel( + id=ANY_BUT_NONE, + name="async_f", + input={}, + output=None, + start_time=ANY_BUT_NONE, + end_time=ANY_BUT_NONE, + project_name=ANY_BUT_NONE, + spans=[ + SpanModel( + id=ANY_BUT_NONE, + name="anthropic_messages_create", + input={ + "messages": messages, + "system": "You are a helpful assistant", + }, + output={"content": ANY_LIST}, + tags=["anthropic"], + metadata=ANY_DICT, + start_time=ANY_BUT_NONE, + end_time=ANY_BUT_NONE, + project_name=ANY_BUT_NONE, + type="llm", + usage=ANY_DICT, + spans=[], + ) + ], ) + ], + ) - asyncio.run(async_f()) - - opik.flush_tracker() - mock_construct_online_streamer.assert_called_once() - - EXPECTED_TRACE_TREE = TraceModel( - id=ANY_BUT_NONE, - name="async_f", - input={}, - output=None, - start_time=ANY_BUT_NONE, - end_time=ANY_BUT_NONE, - project_name=ANY_BUT_NONE, - spans=[ - SpanModel( - id=ANY_BUT_NONE, - name="async_f", - input={}, - output=None, - start_time=ANY_BUT_NONE, - end_time=ANY_BUT_NONE, - project_name=ANY_BUT_NONE, - spans=[ - SpanModel( - id=ANY_BUT_NONE, - name="anthropic_messages_create", - input={ - "messages": messages, - "system": "You are a helpful assistant", - }, - output={"content": ANY_LIST}, - tags=["anthropic"], - metadata=ANY_DICT, - start_time=ANY_BUT_NONE, - end_time=ANY_BUT_NONE, - project_name=ANY_BUT_NONE, - type="llm", - usage=ANY_DICT, - spans=[], - ) - ], - ) - ], - ) - - assert len(fake_message_processor_.trace_trees) == 1 + assert len(fake_backend.trace_trees) 
== 1 - assert_equal(EXPECTED_TRACE_TREE, fake_message_processor_.trace_trees[0]) + assert_equal(EXPECTED_TRACE_TREE, fake_backend.trace_trees[0]) def test_anthropic_messages_stream__generator_tracked_correctly( - fake_streamer, + fake_backend, ): - fake_message_processor_: ( - backend_emulator_message_processor.BackendEmulatorMessageProcessor + client = anthropic.Anthropic() + wrapped_client = track_anthropic(client) + messages = [ + { + "role": "user", + "content": "Tell a short fact", + } + ] + + message_stream_manager = wrapped_client.messages.stream( + model="claude-3-opus-20240229", + messages=messages, + max_tokens=10, + system="You are a helpful assistant", ) - streamer, fake_message_processor_ = fake_streamer + with message_stream_manager as stream: + for _ in stream: + pass - mock_construct_online_streamer = mock.Mock() - mock_construct_online_streamer.return_value = streamer + opik.flush_tracker() + + EXPECTED_TRACE_TREE = TraceModel( + id=ANY_BUT_NONE, + name="anthropic_messages_stream", + input={"messages": messages, "system": "You are a helpful assistant"}, + output={"content": ANY_LIST}, + tags=["anthropic"], + metadata=ANY_DICT, + start_time=ANY_BUT_NONE, + end_time=ANY_BUT_NONE, + spans=[ + SpanModel( + id=ANY_BUT_NONE, + name="anthropic_messages_stream", + input={ + "messages": messages, + "system": "You are a helpful assistant", + }, + output={"content": ANY_LIST}, + tags=["anthropic"], + metadata=ANY_DICT, + start_time=ANY_BUT_NONE, + end_time=ANY_BUT_NONE, + type="llm", + usage=ANY_DICT, + spans=[], + ) + ], + ) - with mock.patch.object( - streamer_constructors, - "construct_online_streamer", - mock_construct_online_streamer, - ): - client = anthropic.Anthropic() - wrapped_client = track_anthropic(client) - messages = [ - { - "role": "user", - "content": "Tell a short fact", - } - ] + assert len(fake_backend.trace_trees) == 1 + + assert_equal(EXPECTED_TRACE_TREE, fake_backend.trace_trees[0]) + +def test_anthropic_messages_stream__stream_called_2_times__generator_tracked_correctly( + fake_backend, +): + def run_stream(client, messages): message_stream_manager = wrapped_client.messages.stream( model="claude-3-opus-20240229", messages=messages, @@ -377,598 +355,500 @@ def test_anthropic_messages_stream__generator_tracked_correctly( for _ in stream: pass - opik.flush_tracker() - mock_construct_online_streamer.assert_called_once() - - EXPECTED_TRACE_TREE = TraceModel( - id=ANY_BUT_NONE, - name="anthropic_messages_stream", - input={"messages": messages, "system": "You are a helpful assistant"}, - output={"content": ANY_LIST}, - tags=["anthropic"], - metadata=ANY_DICT, - start_time=ANY_BUT_NONE, - end_time=ANY_BUT_NONE, - spans=[ - SpanModel( - id=ANY_BUT_NONE, - name="anthropic_messages_stream", - input={ - "messages": messages, - "system": "You are a helpful assistant", - }, - output={"content": ANY_LIST}, - tags=["anthropic"], - metadata=ANY_DICT, - start_time=ANY_BUT_NONE, - end_time=ANY_BUT_NONE, - type="llm", - usage=ANY_DICT, - spans=[], - ) - ], - ) - - assert len(fake_message_processor_.trace_trees) == 1 - - assert_equal(EXPECTED_TRACE_TREE, fake_message_processor_.trace_trees[0]) - - -def test_anthropic_messages_stream__stream_called_2_times__generator_tracked_correctly( - fake_streamer, -): - fake_message_processor_: ( - backend_emulator_message_processor.BackendEmulatorMessageProcessor + client = anthropic.Anthropic() + wrapped_client = track_anthropic(client) + + SHORT_FACT_MESSAGES = [ + { + "role": "user", + "content": "Tell a short fact", + } + ] + 
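# Use two different prompts so that each stream call produces its own, distinguishable trace.
+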
JOKE_MESSAGES = [ + { + "role": "user", + "content": "Tell a short joke", + } + ] + run_stream(wrapped_client, messages=SHORT_FACT_MESSAGES) + run_stream(wrapped_client, messages=JOKE_MESSAGES) + + opik.flush_tracker() + + EXPECTED_TRACE_TREE_WITH_SHORT_FACT = TraceModel( + id=ANY_BUT_NONE, + name="anthropic_messages_stream", + input={ + "messages": SHORT_FACT_MESSAGES, + "system": "You are a helpful assistant", + }, + output={"content": ANY_LIST}, + tags=["anthropic"], + metadata=ANY_DICT, + start_time=ANY_BUT_NONE, + end_time=ANY_BUT_NONE, + spans=[ + SpanModel( + id=ANY_BUT_NONE, + name="anthropic_messages_stream", + input={ + "messages": SHORT_FACT_MESSAGES, + "system": "You are a helpful assistant", + }, + output={"content": ANY_LIST}, + tags=["anthropic"], + metadata=ANY_DICT, + start_time=ANY_BUT_NONE, + end_time=ANY_BUT_NONE, + type="llm", + usage=ANY_DICT, + spans=[], + ) + ], ) - streamer, fake_message_processor_ = fake_streamer - - mock_construct_online_streamer = mock.Mock() - mock_construct_online_streamer.return_value = streamer - - with mock.patch.object( - streamer_constructors, - "construct_online_streamer", - mock_construct_online_streamer, - ): - - def run_stream(client, messages): - message_stream_manager = wrapped_client.messages.stream( - model="claude-3-opus-20240229", - messages=messages, - max_tokens=10, - system="You are a helpful assistant", + EXPECTED_TRACE_TREE_WITH_JOKE = TraceModel( + id=ANY_BUT_NONE, + name="anthropic_messages_stream", + input={"messages": JOKE_MESSAGES, "system": "You are a helpful assistant"}, + output={"content": ANY_LIST}, + tags=["anthropic"], + metadata=ANY_DICT, + start_time=ANY_BUT_NONE, + end_time=ANY_BUT_NONE, + spans=[ + SpanModel( + id=ANY_BUT_NONE, + name="anthropic_messages_stream", + input={ + "messages": JOKE_MESSAGES, + "system": "You are a helpful assistant", + }, + output={"content": ANY_LIST}, + tags=["anthropic"], + metadata=ANY_DICT, + start_time=ANY_BUT_NONE, + end_time=ANY_BUT_NONE, + type="llm", + usage=ANY_DICT, + spans=[], ) - with message_stream_manager as stream: - for _ in stream: - pass - - client = anthropic.Anthropic() - wrapped_client = track_anthropic(client) + ], + ) - SHORT_FACT_MESSAGES = [ - { - "role": "user", - "content": "Tell a short fact", - } - ] - JOKE_MESSAGES = [ - { - "role": "user", - "content": "Tell a short joke", - } - ] - run_stream(wrapped_client, messages=SHORT_FACT_MESSAGES) - run_stream(wrapped_client, messages=JOKE_MESSAGES) - - opik.flush_tracker() - mock_construct_online_streamer.assert_called_once() - - EXPECTED_TRACE_TREE_WITH_SHORT_FACT = TraceModel( - id=ANY_BUT_NONE, - name="anthropic_messages_stream", - input={ - "messages": SHORT_FACT_MESSAGES, - "system": "You are a helpful assistant", - }, - output={"content": ANY_LIST}, - tags=["anthropic"], - metadata=ANY_DICT, - start_time=ANY_BUT_NONE, - end_time=ANY_BUT_NONE, - spans=[ - SpanModel( - id=ANY_BUT_NONE, - name="anthropic_messages_stream", - input={ - "messages": SHORT_FACT_MESSAGES, - "system": "You are a helpful assistant", - }, - output={"content": ANY_LIST}, - tags=["anthropic"], - metadata=ANY_DICT, - start_time=ANY_BUT_NONE, - end_time=ANY_BUT_NONE, - type="llm", - usage=ANY_DICT, - spans=[], - ) - ], - ) - EXPECTED_TRACE_TREE_WITH_JOKE = TraceModel( - id=ANY_BUT_NONE, - name="anthropic_messages_stream", - input={"messages": JOKE_MESSAGES, "system": "You are a helpful assistant"}, - output={"content": ANY_LIST}, - tags=["anthropic"], - metadata=ANY_DICT, - start_time=ANY_BUT_NONE, - end_time=ANY_BUT_NONE, - spans=[ 
- SpanModel( - id=ANY_BUT_NONE, - name="anthropic_messages_stream", - input={ - "messages": JOKE_MESSAGES, - "system": "You are a helpful assistant", - }, - output={"content": ANY_LIST}, - tags=["anthropic"], - metadata=ANY_DICT, - start_time=ANY_BUT_NONE, - end_time=ANY_BUT_NONE, - type="llm", - usage=ANY_DICT, - spans=[], - ) - ], - ) + assert len(fake_backend.trace_trees) == 2 - assert len(fake_message_processor_.trace_trees) == 2 - - assert_equal( - EXPECTED_TRACE_TREE_WITH_SHORT_FACT, fake_message_processor_.trace_trees[0] - ) - assert_equal( - EXPECTED_TRACE_TREE_WITH_JOKE, fake_message_processor_.trace_trees[1] - ) + assert_equal(EXPECTED_TRACE_TREE_WITH_SHORT_FACT, fake_backend.trace_trees[0]) + assert_equal(EXPECTED_TRACE_TREE_WITH_JOKE, fake_backend.trace_trees[1]) def test_anthropic_messages_stream__stream_called_2_times__second_stream_is_being_read_first__both_are_tracked__the_one_that_was_read_first_is_logged_first( - fake_streamer, + fake_backend, ): - fake_message_processor_: ( - backend_emulator_message_processor.BackendEmulatorMessageProcessor + client = anthropic.Anthropic() + wrapped_client = track_anthropic(client) + + SHORT_FACT_MESSAGES = [ + { + "role": "user", + "content": "Tell a short fact", + } + ] + JOKE_MESSAGES = [ + { + "role": "user", + "content": "Tell a short joke", + } + ] + + joke_message_stream_manager = wrapped_client.messages.stream( + model="claude-3-opus-20240229", + messages=JOKE_MESSAGES, + max_tokens=10, + system="You are a helpful assistant", + ) + fact_message_stream_manager = wrapped_client.messages.stream( + model="claude-3-opus-20240229", + messages=SHORT_FACT_MESSAGES, + max_tokens=10, + system="You are a helpful assistant", ) - streamer, fake_message_processor_ = fake_streamer + with fact_message_stream_manager as fact_stream: + for _ in fact_stream: + pass - mock_construct_online_streamer = mock.Mock() - mock_construct_online_streamer.return_value = streamer + with joke_message_stream_manager as joke_stream: + for _ in joke_stream: + pass - with mock.patch.object( - streamer_constructors, - "construct_online_streamer", - mock_construct_online_streamer, - ): - client = anthropic.Anthropic() - wrapped_client = track_anthropic(client) + opik.flush_tracker() + + EXPECTED_TRACE_TREE_WITH_SHORT_FACT = TraceModel( + id=ANY_BUT_NONE, + name="anthropic_messages_stream", + input={ + "messages": SHORT_FACT_MESSAGES, + "system": "You are a helpful assistant", + }, + output={"content": ANY_LIST}, + tags=["anthropic"], + metadata=ANY_DICT, + start_time=ANY_BUT_NONE, + end_time=ANY_BUT_NONE, + spans=[ + SpanModel( + id=ANY_BUT_NONE, + name="anthropic_messages_stream", + input={ + "messages": SHORT_FACT_MESSAGES, + "system": "You are a helpful assistant", + }, + output={"content": ANY_LIST}, + tags=["anthropic"], + metadata=ANY_DICT, + start_time=ANY_BUT_NONE, + end_time=ANY_BUT_NONE, + type="llm", + usage=ANY_DICT, + spans=[], + ) + ], + ) + EXPECTED_TRACE_TREE_WITH_JOKE = TraceModel( + id=ANY_BUT_NONE, + name="anthropic_messages_stream", + input={"messages": JOKE_MESSAGES, "system": "You are a helpful assistant"}, + output={"content": ANY_LIST}, + tags=["anthropic"], + metadata=ANY_DICT, + start_time=ANY_BUT_NONE, + end_time=ANY_BUT_NONE, + spans=[ + SpanModel( + id=ANY_BUT_NONE, + name="anthropic_messages_stream", + input={ + "messages": JOKE_MESSAGES, + "system": "You are a helpful assistant", + }, + output={"content": ANY_LIST}, + tags=["anthropic"], + metadata=ANY_DICT, + start_time=ANY_BUT_NONE, + end_time=ANY_BUT_NONE, + type="llm", + 
usage=ANY_DICT, + spans=[], + ) + ], + ) - SHORT_FACT_MESSAGES = [ - { - "role": "user", - "content": "Tell a short fact", - } - ] - JOKE_MESSAGES = [ - { - "role": "user", - "content": "Tell a short joke", - } - ] + assert len(fake_backend.trace_trees) == 2 - joke_message_stream_manager = wrapped_client.messages.stream( - model="claude-3-opus-20240229", - messages=JOKE_MESSAGES, - max_tokens=10, - system="You are a helpful assistant", - ) - fact_message_stream_manager = wrapped_client.messages.stream( - model="claude-3-opus-20240229", - messages=SHORT_FACT_MESSAGES, - max_tokens=10, - system="You are a helpful assistant", - ) - with fact_message_stream_manager as fact_stream: - for _ in fact_stream: - pass + assert_equal(EXPECTED_TRACE_TREE_WITH_SHORT_FACT, fake_backend.trace_trees[1]) + assert_equal(EXPECTED_TRACE_TREE_WITH_JOKE, fake_backend.trace_trees[0]) - with joke_message_stream_manager as joke_stream: - for _ in joke_stream: - pass - opik.flush_tracker() - mock_construct_online_streamer.assert_called_once() - - EXPECTED_TRACE_TREE_WITH_SHORT_FACT = TraceModel( - id=ANY_BUT_NONE, - name="anthropic_messages_stream", - input={ - "messages": SHORT_FACT_MESSAGES, - "system": "You are a helpful assistant", - }, - output={"content": ANY_LIST}, - tags=["anthropic"], - metadata=ANY_DICT, - start_time=ANY_BUT_NONE, - end_time=ANY_BUT_NONE, - spans=[ - SpanModel( - id=ANY_BUT_NONE, - name="anthropic_messages_stream", - input={ - "messages": SHORT_FACT_MESSAGES, - "system": "You are a helpful assistant", - }, - output={"content": ANY_LIST}, - tags=["anthropic"], - metadata=ANY_DICT, - start_time=ANY_BUT_NONE, - end_time=ANY_BUT_NONE, - type="llm", - usage=ANY_DICT, - spans=[], - ) - ], - ) - EXPECTED_TRACE_TREE_WITH_JOKE = TraceModel( - id=ANY_BUT_NONE, - name="anthropic_messages_stream", - input={"messages": JOKE_MESSAGES, "system": "You are a helpful assistant"}, - output={"content": ANY_LIST}, - tags=["anthropic"], - metadata=ANY_DICT, - start_time=ANY_BUT_NONE, - end_time=ANY_BUT_NONE, - spans=[ - SpanModel( - id=ANY_BUT_NONE, - name="anthropic_messages_stream", - input={ - "messages": JOKE_MESSAGES, - "system": "You are a helpful assistant", - }, - output={"content": ANY_LIST}, - tags=["anthropic"], - metadata=ANY_DICT, - start_time=ANY_BUT_NONE, - end_time=ANY_BUT_NONE, - type="llm", - usage=ANY_DICT, - spans=[], - ) - ], - ) +def test_anthropic_messages_stream__get_final_message_called__generator_tracked_correctly( + fake_backend, +): + client = anthropic.Anthropic() + wrapped_client = track_anthropic(client) + messages = [ + { + "role": "user", + "content": "Tell a short fact", + } + ] + + message_stream_manager = wrapped_client.messages.stream( + model="claude-3-opus-20240229", + messages=messages, + max_tokens=10, + system="You are a helpful assistant", + ) + with message_stream_manager as stream: + stream.get_final_message() + + opik.flush_tracker() + + EXPECTED_TRACE_TREE = TraceModel( + id=ANY_BUT_NONE, + name="anthropic_messages_stream", + input={"messages": messages, "system": "You are a helpful assistant"}, + output={"content": ANY_LIST}, + tags=["anthropic"], + metadata=ANY_DICT, + start_time=ANY_BUT_NONE, + end_time=ANY_BUT_NONE, + spans=[ + SpanModel( + id=ANY_BUT_NONE, + name="anthropic_messages_stream", + input={ + "messages": messages, + "system": "You are a helpful assistant", + }, + output={"content": ANY_LIST}, + tags=["anthropic"], + metadata=ANY_DICT, + start_time=ANY_BUT_NONE, + end_time=ANY_BUT_NONE, + type="llm", + usage=ANY_DICT, + spans=[], + ) + ], + ) - 
assert len(fake_message_processor_.trace_trees) == 2 + assert len(fake_backend.trace_trees) == 1 - assert_equal( - EXPECTED_TRACE_TREE_WITH_SHORT_FACT, fake_message_processor_.trace_trees[1] - ) - assert_equal( - EXPECTED_TRACE_TREE_WITH_JOKE, fake_message_processor_.trace_trees[0] - ) + assert_equal(EXPECTED_TRACE_TREE, fake_backend.trace_trees[0]) -def test_anthropic_messages_stream__get_final_message_called__generator_tracked_correctly( - fake_streamer, +def test_anthropic_messages_stream__get_final_message_called_after_stream_iteration_loop__generator_tracked_correctly_only_once( + fake_backend, ): - fake_message_processor_: ( - backend_emulator_message_processor.BackendEmulatorMessageProcessor + client = anthropic.Anthropic() + wrapped_client = track_anthropic(client) + messages = [ + { + "role": "user", + "content": "Tell a short fact", + } + ] + + message_stream_manager = wrapped_client.messages.stream( + model="claude-3-opus-20240229", + messages=messages, + max_tokens=10, + system="You are a helpful assistant", + ) + with message_stream_manager as stream: + for _ in stream: + pass + stream.get_final_message() + + opik.flush_tracker() + + EXPECTED_TRACE_TREE = TraceModel( + id=ANY_BUT_NONE, + name="anthropic_messages_stream", + input={"messages": messages, "system": "You are a helpful assistant"}, + output={"content": ANY_LIST}, + tags=["anthropic"], + metadata=ANY_DICT, + start_time=ANY_BUT_NONE, + end_time=ANY_BUT_NONE, + spans=[ + SpanModel( + id=ANY_BUT_NONE, + name="anthropic_messages_stream", + input={ + "messages": messages, + "system": "You are a helpful assistant", + }, + output={"content": ANY_LIST}, + tags=["anthropic"], + metadata=ANY_DICT, + start_time=ANY_BUT_NONE, + end_time=ANY_BUT_NONE, + type="llm", + usage=ANY_DICT, + spans=[], + ) + ], ) - streamer, fake_message_processor_ = fake_streamer - mock_construct_online_streamer = mock.Mock() - mock_construct_online_streamer.return_value = streamer + assert len(fake_backend.trace_trees) == 1 + + assert_equal(EXPECTED_TRACE_TREE, fake_backend.trace_trees[0]) - with mock.patch.object( - streamer_constructors, - "construct_online_streamer", - mock_construct_online_streamer, - ): - client = anthropic.Anthropic() - wrapped_client = track_anthropic(client) - messages = [ - { - "role": "user", - "content": "Tell a short fact", - } - ] +def test_async_anthropic_messages_stream__data_tracked_correctly( + fake_backend, +): + client = anthropic.AsyncAnthropic() + wrapped_client = track_anthropic(client) + messages = [ + { + "role": "user", + "content": "Tell a short fact", + } + ] + + async def async_f(): message_stream_manager = wrapped_client.messages.stream( model="claude-3-opus-20240229", messages=messages, max_tokens=10, system="You are a helpful assistant", ) - with message_stream_manager as stream: - stream.get_final_message() - - opik.flush_tracker() - mock_construct_online_streamer.assert_called_once() - - EXPECTED_TRACE_TREE = TraceModel( - id=ANY_BUT_NONE, - name="anthropic_messages_stream", - input={"messages": messages, "system": "You are a helpful assistant"}, - output={"content": ANY_LIST}, - tags=["anthropic"], - metadata=ANY_DICT, - start_time=ANY_BUT_NONE, - end_time=ANY_BUT_NONE, - spans=[ - SpanModel( - id=ANY_BUT_NONE, - name="anthropic_messages_stream", - input={ - "messages": messages, - "system": "You are a helpful assistant", - }, - output={"content": ANY_LIST}, - tags=["anthropic"], - metadata=ANY_DICT, - start_time=ANY_BUT_NONE, - end_time=ANY_BUT_NONE, - type="llm", - usage=ANY_DICT, - spans=[], - ) - 
], - ) - - assert len(fake_message_processor_.trace_trees) == 1 - - assert_equal(EXPECTED_TRACE_TREE, fake_message_processor_.trace_trees[0]) - + async with message_stream_manager as stream: + async for _ in stream: + pass -def test_anthropic_messages_stream__get_final_message_called_after_stream_iteration_loop__generator_tracked_correctly_only_once( - fake_streamer, -): - fake_message_processor_: ( - backend_emulator_message_processor.BackendEmulatorMessageProcessor + asyncio.run(async_f()) + + opik.flush_tracker() + + EXPECTED_TRACE_TREE = TraceModel( + id=ANY_BUT_NONE, + name="anthropic_messages_stream", + input={"messages": messages, "system": "You are a helpful assistant"}, + output={"content": ANY_LIST}, + tags=["anthropic"], + metadata=ANY_DICT, + start_time=ANY_BUT_NONE, + end_time=ANY_BUT_NONE, + spans=[ + SpanModel( + id=ANY_BUT_NONE, + name="anthropic_messages_stream", + input={ + "messages": messages, + "system": "You are a helpful assistant", + }, + output={"content": ANY_LIST}, + tags=["anthropic"], + metadata=ANY_DICT, + start_time=ANY_BUT_NONE, + end_time=ANY_BUT_NONE, + type="llm", + usage=ANY_DICT, + spans=[], + ) + ], ) - streamer, fake_message_processor_ = fake_streamer - mock_construct_online_streamer = mock.Mock() - mock_construct_online_streamer.return_value = streamer + assert len(fake_backend.trace_trees) == 1 + + assert_equal(EXPECTED_TRACE_TREE, fake_backend.trace_trees[0]) - with mock.patch.object( - streamer_constructors, - "construct_online_streamer", - mock_construct_online_streamer, - ): - client = anthropic.Anthropic() - wrapped_client = track_anthropic(client) - messages = [ - { - "role": "user", - "content": "Tell a short fact", - } - ] +def test_async_anthropic_messages_stream__get_final_message_called_twice__data_tracked_correctly_once( + fake_backend, +): + client = anthropic.AsyncAnthropic() + wrapped_client = track_anthropic(client) + messages = [ + { + "role": "user", + "content": "Tell a short fact", + } + ] + + async def async_f(): message_stream_manager = wrapped_client.messages.stream( model="claude-3-opus-20240229", messages=messages, max_tokens=10, system="You are a helpful assistant", ) - with message_stream_manager as stream: - for _ in stream: - pass - stream.get_final_message() - - opik.flush_tracker() - mock_construct_online_streamer.assert_called_once() - - EXPECTED_TRACE_TREE = TraceModel( - id=ANY_BUT_NONE, - name="anthropic_messages_stream", - input={"messages": messages, "system": "You are a helpful assistant"}, - output={"content": ANY_LIST}, - tags=["anthropic"], - metadata=ANY_DICT, - start_time=ANY_BUT_NONE, - end_time=ANY_BUT_NONE, - spans=[ - SpanModel( - id=ANY_BUT_NONE, - name="anthropic_messages_stream", - input={ - "messages": messages, - "system": "You are a helpful assistant", - }, - output={"content": ANY_LIST}, - tags=["anthropic"], - metadata=ANY_DICT, - start_time=ANY_BUT_NONE, - end_time=ANY_BUT_NONE, - type="llm", - usage=ANY_DICT, - spans=[], - ) - ], - ) - - assert len(fake_message_processor_.trace_trees) == 1 - - assert_equal(EXPECTED_TRACE_TREE, fake_message_processor_.trace_trees[0]) - - -def test_async_anthropic_messages_stream__data_tracked_correctly( - fake_streamer, -): - fake_message_processor_: ( - backend_emulator_message_processor.BackendEmulatorMessageProcessor - ) - streamer, fake_message_processor_ = fake_streamer - - mock_construct_online_streamer = mock.Mock() - mock_construct_online_streamer.return_value = streamer - - with mock.patch.object( - streamer_constructors, - "construct_online_streamer", 
- mock_construct_online_streamer, - ): - client = anthropic.AsyncAnthropic() - wrapped_client = track_anthropic(client) - messages = [ - { - "role": "user", - "content": "Tell a short fact", - } - ] - - async def async_f(): - message_stream_manager = wrapped_client.messages.stream( - model="claude-3-opus-20240229", - messages=messages, - max_tokens=10, - system="You are a helpful assistant", + async with message_stream_manager as stream: + await stream.get_final_message() + await stream.get_final_message() + + asyncio.run(async_f()) + + opik.flush_tracker() + + EXPECTED_TRACE_TREE = TraceModel( + id=ANY_BUT_NONE, + name="anthropic_messages_stream", + input={"messages": messages, "system": "You are a helpful assistant"}, + output={"content": ANY_LIST}, + tags=["anthropic"], + metadata=ANY_DICT, + start_time=ANY_BUT_NONE, + end_time=ANY_BUT_NONE, + spans=[ + SpanModel( + id=ANY_BUT_NONE, + name="anthropic_messages_stream", + input={ + "messages": messages, + "system": "You are a helpful assistant", + }, + output={"content": ANY_LIST}, + tags=["anthropic"], + metadata=ANY_DICT, + start_time=ANY_BUT_NONE, + end_time=ANY_BUT_NONE, + type="llm", + usage=ANY_DICT, + spans=[], ) - async with message_stream_manager as stream: - async for _ in stream: - pass - - asyncio.run(async_f()) - - opik.flush_tracker() - mock_construct_online_streamer.assert_called_once() - - EXPECTED_TRACE_TREE = TraceModel( - id=ANY_BUT_NONE, - name="anthropic_messages_stream", - input={"messages": messages, "system": "You are a helpful assistant"}, - output={"content": ANY_LIST}, - tags=["anthropic"], - metadata=ANY_DICT, - start_time=ANY_BUT_NONE, - end_time=ANY_BUT_NONE, - spans=[ - SpanModel( - id=ANY_BUT_NONE, - name="anthropic_messages_stream", - input={ - "messages": messages, - "system": "You are a helpful assistant", - }, - output={"content": ANY_LIST}, - tags=["anthropic"], - metadata=ANY_DICT, - start_time=ANY_BUT_NONE, - end_time=ANY_BUT_NONE, - type="llm", - usage=ANY_DICT, - spans=[], - ) - ], - ) + ], + ) - assert len(fake_message_processor_.trace_trees) == 1 + assert len(fake_backend.trace_trees) == 1 - assert_equal(EXPECTED_TRACE_TREE, fake_message_processor_.trace_trees[0]) + assert_equal(EXPECTED_TRACE_TREE, fake_backend.trace_trees[0]) -def test_async_anthropic_messages_stream__get_final_message_called_twice__data_tracked_correctly_once( - fake_streamer, +def test_anthropic_messages_create__stream_argument_is_True__Stream_object_returned__generations_tracked_correctly( + fake_backend, ): - fake_message_processor_: ( - backend_emulator_message_processor.BackendEmulatorMessageProcessor + client = anthropic.Anthropic() + wrapped_client = track_anthropic(client) + messages = [ + { + "role": "user", + "content": "Tell a short fact", + } + ] + + stream = wrapped_client.messages.create( + model="claude-3-opus-20240229", + messages=messages, + max_tokens=10, + system="You are a helpful assistant", + stream=True, ) - streamer, fake_message_processor_ = fake_streamer - - mock_construct_online_streamer = mock.Mock() - mock_construct_online_streamer.return_value = streamer - - with mock.patch.object( - streamer_constructors, - "construct_online_streamer", - mock_construct_online_streamer, - ): - client = anthropic.AsyncAnthropic() - wrapped_client = track_anthropic(client) - messages = [ - { - "role": "user", - "content": "Tell a short fact", - } - ] - - async def async_f(): - message_stream_manager = wrapped_client.messages.stream( - model="claude-3-opus-20240229", - messages=messages, - max_tokens=10, - 
system="You are a helpful assistant", + for _ in stream: + pass + + opik.flush_tracker() + + EXPECTED_TRACE_TREE = TraceModel( + id=ANY_BUT_NONE, + name="anthropic_messages_create", + input={"messages": messages, "system": "You are a helpful assistant"}, + output={"content": ANY_LIST}, + tags=["anthropic"], + metadata=ANY_DICT, + start_time=ANY_BUT_NONE, + end_time=ANY_BUT_NONE, + spans=[ + SpanModel( + id=ANY_BUT_NONE, + name="anthropic_messages_create", + input={ + "messages": messages, + "system": "You are a helpful assistant", + }, + output={"content": ANY_LIST}, + tags=["anthropic"], + metadata=ANY_DICT, + start_time=ANY_BUT_NONE, + end_time=ANY_BUT_NONE, + type="llm", + usage=ANY_DICT, + spans=[], ) - async with message_stream_manager as stream: - await stream.get_final_message() - await stream.get_final_message() - - asyncio.run(async_f()) - - opik.flush_tracker() - mock_construct_online_streamer.assert_called_once() - - EXPECTED_TRACE_TREE = TraceModel( - id=ANY_BUT_NONE, - name="anthropic_messages_stream", - input={"messages": messages, "system": "You are a helpful assistant"}, - output={"content": ANY_LIST}, - tags=["anthropic"], - metadata=ANY_DICT, - start_time=ANY_BUT_NONE, - end_time=ANY_BUT_NONE, - spans=[ - SpanModel( - id=ANY_BUT_NONE, - name="anthropic_messages_stream", - input={ - "messages": messages, - "system": "You are a helpful assistant", - }, - output={"content": ANY_LIST}, - tags=["anthropic"], - metadata=ANY_DICT, - start_time=ANY_BUT_NONE, - end_time=ANY_BUT_NONE, - type="llm", - usage=ANY_DICT, - spans=[], - ) - ], - ) + ], + ) - assert len(fake_message_processor_.trace_trees) == 1 + assert len(fake_backend.trace_trees) == 1 - assert_equal(EXPECTED_TRACE_TREE, fake_message_processor_.trace_trees[0]) + assert_equal(EXPECTED_TRACE_TREE, fake_backend.trace_trees[0]) -def test_anthropic_messages_create__stream_argument_is_True__Stream_object_returned__generations_tracked_correctly( - fake_streamer, +def test_async_anthropic_messages_create__stream_argument_is_True__AsyncStream_object_returned__generations_tracked_correctly( + fake_backend, ): - fake_message_processor_: ( - backend_emulator_message_processor.BackendEmulatorMessageProcessor - ) - streamer, fake_message_processor_ = fake_streamer - - mock_construct_online_streamer = mock.Mock() - mock_construct_online_streamer.return_value = streamer - - with mock.patch.object( - streamer_constructors, - "construct_online_streamer", - mock_construct_online_streamer, - ): - client = anthropic.Anthropic() + async def async_f(): + client = anthropic.AsyncAnthropic() wrapped_client = track_anthropic(client) messages = [ { @@ -977,136 +857,61 @@ def test_anthropic_messages_create__stream_argument_is_True__Stream_object_retur } ] - stream = wrapped_client.messages.create( + stream = await wrapped_client.messages.create( model="claude-3-opus-20240229", messages=messages, max_tokens=10, system="You are a helpful assistant", stream=True, ) - for _ in stream: + async for _ in stream: pass - opik.flush_tracker() - mock_construct_online_streamer.assert_called_once() - - EXPECTED_TRACE_TREE = TraceModel( - id=ANY_BUT_NONE, - name="anthropic_messages_create", - input={"messages": messages, "system": "You are a helpful assistant"}, - output={"content": ANY_LIST}, - tags=["anthropic"], - metadata=ANY_DICT, - start_time=ANY_BUT_NONE, - end_time=ANY_BUT_NONE, - spans=[ - SpanModel( - id=ANY_BUT_NONE, - name="anthropic_messages_create", - input={ - "messages": messages, - "system": "You are a helpful assistant", - }, - 
output={"content": ANY_LIST}, - tags=["anthropic"], - metadata=ANY_DICT, - start_time=ANY_BUT_NONE, - end_time=ANY_BUT_NONE, - type="llm", - usage=ANY_DICT, - spans=[], - ) - ], - ) - - assert len(fake_message_processor_.trace_trees) == 1 - - assert_equal(EXPECTED_TRACE_TREE, fake_message_processor_.trace_trees[0]) + asyncio.run(async_f()) + opik.flush_tracker() - -def test_async_anthropic_messages_create__stream_argument_is_True__AsyncStream_object_returned__generations_tracked_correctly( - fake_streamer, -): - fake_message_processor_: ( - backend_emulator_message_processor.BackendEmulatorMessageProcessor - ) - streamer, fake_message_processor_ = fake_streamer - - mock_construct_online_streamer = mock.Mock() - mock_construct_online_streamer.return_value = streamer - - with mock.patch.object( - streamer_constructors, - "construct_online_streamer", - mock_construct_online_streamer, - ): - - async def async_f(): - client = anthropic.AsyncAnthropic() - wrapped_client = track_anthropic(client) - messages = [ + EXPECTED_TRACE_TREE = TraceModel( + id=ANY_BUT_NONE, + name="anthropic_messages_create", + input={ + "messages": [ { "role": "user", "content": "Tell a short fact", } - ] - - stream = await wrapped_client.messages.create( - model="claude-3-opus-20240229", - messages=messages, - max_tokens=10, - system="You are a helpful assistant", - stream=True, - ) - async for _ in stream: - pass - - asyncio.run(async_f()) - opik.flush_tracker() - mock_construct_online_streamer.assert_called_once() - - EXPECTED_TRACE_TREE = TraceModel( - id=ANY_BUT_NONE, - name="anthropic_messages_create", - input={ - "messages": [ - { - "role": "user", - "content": "Tell a short fact", - } - ], - "system": "You are a helpful assistant", - }, - output={"content": ANY_LIST}, - tags=["anthropic"], - metadata=ANY_DICT, - start_time=ANY_BUT_NONE, - end_time=ANY_BUT_NONE, - spans=[ - SpanModel( - id=ANY_BUT_NONE, - name="anthropic_messages_create", - input={ - "messages": [ - { - "role": "user", - "content": "Tell a short fact", - } - ], - "system": "You are a helpful assistant", - }, - output={"content": ANY_LIST}, - tags=["anthropic"], - metadata=ANY_DICT, - start_time=ANY_BUT_NONE, - end_time=ANY_BUT_NONE, - type="llm", - usage=ANY_DICT, - spans=[], - ) ], - ) + "system": "You are a helpful assistant", + }, + output={"content": ANY_LIST}, + tags=["anthropic"], + metadata=ANY_DICT, + start_time=ANY_BUT_NONE, + end_time=ANY_BUT_NONE, + spans=[ + SpanModel( + id=ANY_BUT_NONE, + name="anthropic_messages_create", + input={ + "messages": [ + { + "role": "user", + "content": "Tell a short fact", + } + ], + "system": "You are a helpful assistant", + }, + output={"content": ANY_LIST}, + tags=["anthropic"], + metadata=ANY_DICT, + start_time=ANY_BUT_NONE, + end_time=ANY_BUT_NONE, + type="llm", + usage=ANY_DICT, + spans=[], + ) + ], + ) - assert len(fake_message_processor_.trace_trees) == 1 + assert len(fake_backend.trace_trees) == 1 - assert_equal(EXPECTED_TRACE_TREE, fake_message_processor_.trace_trees[0]) + assert_equal(EXPECTED_TRACE_TREE, fake_backend.trace_trees[0]) diff --git a/sdks/python/tests/library_integration/openai/test_openai.py b/sdks/python/tests/library_integration/openai/test_openai.py index 1be10fb69d..a345f0a221 100644 --- a/sdks/python/tests/library_integration/openai/test_openai.py +++ b/sdks/python/tests/library_integration/openai/test_openai.py @@ -13,12 +13,13 @@ SpanModel, TraceModel, ANY_BUT_NONE, + ANY_DICT, assert_equal, + assert_dict_has_keys, ) -# TODO: make sure that the output logged to Comet is 
exactly as from the response? -# Existing tests only check that output is logged and its structure is {choices: ANY_BUT_NONE} +# TODO: improve metadata checks @pytest.fixture(autouse=True) @@ -77,12 +78,7 @@ def test_openai_client_chat_completions_create__happyflow( input={"messages": messages}, output={"choices": ANY_BUT_NONE}, tags=["openai"], - metadata={ - "created_from": "openai", - "type": "openai_chat", - "model": "gpt-3.5-turbo", - "max_tokens": 10, - }, + metadata=ANY_DICT, start_time=ANY_BUT_NONE, end_time=ANY_BUT_NONE, project_name=expected_project_name, @@ -94,14 +90,12 @@ def test_openai_client_chat_completions_create__happyflow( input={"messages": messages}, output={"choices": ANY_BUT_NONE}, tags=["openai"], - metadata={ - "created_from": "openai", - "type": "openai_chat", - "model": "gpt-3.5-turbo", - "max_tokens": 10, - "usage": ANY_BUT_NONE, + metadata=ANY_DICT, + usage={ + "prompt_tokens": ANY_BUT_NONE, + "completion_tokens": ANY_BUT_NONE, + "total_tokens": ANY_BUT_NONE, }, - usage=ANY_BUT_NONE, start_time=ANY_BUT_NONE, end_time=ANY_BUT_NONE, project_name=expected_project_name, @@ -111,8 +105,22 @@ def test_openai_client_chat_completions_create__happyflow( ) assert len(fake_message_processor_.trace_trees) == 1 - - assert_equal(EXPECTED_TRACE_TREE, fake_message_processor_.trace_trees[0]) + trace_tree = fake_message_processor_.trace_trees[0] + + assert_equal(EXPECTED_TRACE_TREE, trace_tree) + + llm_span_metadata = trace_tree.spans[0].metadata + REQUIRED_METADATA_KEYS = [ + "usage", + "model", + "max_tokens", + "created_from", + "type", + "id", + "created", + "object", + ] + assert_dict_has_keys(llm_span_metadata, REQUIRED_METADATA_KEYS) def test_openai_client_chat_completions_create__create_raises_an_error__span_and_trace_finished_gracefully( @@ -181,7 +189,8 @@ def test_openai_client_chat_completions_create__create_raises_an_error__span_and assert len(fake_message_processor_.trace_trees) == 1 - assert_equal(EXPECTED_TRACE_TREE, fake_message_processor_.trace_trees[0]) + trace_tree = fake_message_processor_.trace_trees[0] + assert_equal(EXPECTED_TRACE_TREE, trace_tree) def test_openai_client_chat_completions_create__openai_call_made_in_another_tracked_function__openai_span_attached_to_existing_trace( @@ -252,14 +261,12 @@ def f(): input={"messages": messages}, output={"choices": ANY_BUT_NONE}, tags=["openai"], - metadata={ - "created_from": "openai", - "type": "openai_chat", - "model": "gpt-3.5-turbo", - "max_tokens": 10, - "usage": ANY_BUT_NONE, + metadata=ANY_DICT, + usage={ + "prompt_tokens": ANY_BUT_NONE, + "completion_tokens": ANY_BUT_NONE, + "total_tokens": ANY_BUT_NONE, }, - usage=ANY_BUT_NONE, start_time=ANY_BUT_NONE, end_time=ANY_BUT_NONE, project_name=project_name, @@ -272,7 +279,22 @@ def f(): assert len(fake_message_processor_.trace_trees) == 1 - assert_equal(EXPECTED_TRACE_TREE, fake_message_processor_.trace_trees[0]) + trace_tree = fake_message_processor_.trace_trees[0] + + assert_equal(EXPECTED_TRACE_TREE, trace_tree) + + llm_span_metadata = trace_tree.spans[0].spans[0].metadata + REQUIRED_METADATA_KEYS = [ + "usage", + "model", + "max_tokens", + "created_from", + "type", + "id", + "created", + "object", + ] + assert_dict_has_keys(llm_span_metadata, REQUIRED_METADATA_KEYS) def test_openai_client_chat_completions_create__async_openai_call_made_in_another_tracked_async_function__openai_span_attached_to_existing_trace( @@ -336,14 +358,12 @@ async def async_f(): input={"messages": messages}, output={"choices": ANY_BUT_NONE}, tags=["openai"], - metadata={ - 
"created_from": "openai", - "type": "openai_chat", - "model": "gpt-3.5-turbo", - "max_tokens": 10, - "usage": ANY_BUT_NONE, + metadata=ANY_DICT, + usage={ + "prompt_tokens": ANY_BUT_NONE, + "completion_tokens": ANY_BUT_NONE, + "total_tokens": ANY_BUT_NONE, }, - usage=ANY_BUT_NONE, start_time=ANY_BUT_NONE, end_time=ANY_BUT_NONE, project_name=ANY_BUT_NONE, @@ -355,8 +375,22 @@ async def async_f(): ) assert len(fake_message_processor_.trace_trees) == 1 - - assert_equal(EXPECTED_TRACE_TREE, fake_message_processor_.trace_trees[0]) + trace_tree = fake_message_processor_.trace_trees[0] + + assert_equal(EXPECTED_TRACE_TREE, trace_tree) + + llm_span_metadata = trace_tree.spans[0].spans[0].metadata + REQUIRED_METADATA_KEYS = [ + "usage", + "model", + "max_tokens", + "created_from", + "type", + "id", + "created", + "object", + ] + assert_dict_has_keys(llm_span_metadata, REQUIRED_METADATA_KEYS) def test_openai_client_chat_completions_create__stream_mode_is_on__generator_tracked_correctly( @@ -402,14 +436,7 @@ def test_openai_client_chat_completions_create__stream_mode_is_on__generator_tra input={"messages": messages}, output={"choices": ANY_BUT_NONE}, tags=["openai"], - metadata={ - "created_from": "openai", - "type": "openai_chat", - "model": "gpt-3.5-turbo", - "max_tokens": 10, - "stream": True, - "stream_options": {"include_usage": True}, - }, + metadata=ANY_DICT, start_time=ANY_BUT_NONE, end_time=ANY_BUT_NONE, project_name=ANY_BUT_NONE, @@ -421,16 +448,12 @@ def test_openai_client_chat_completions_create__stream_mode_is_on__generator_tra input={"messages": messages}, output={"choices": ANY_BUT_NONE}, tags=["openai"], - metadata={ - "created_from": "openai", - "type": "openai_chat", - "model": "gpt-3.5-turbo", - "max_tokens": 10, - "stream": True, - "stream_options": {"include_usage": True}, - "usage": ANY_BUT_NONE, + metadata=ANY_DICT, + usage={ + "prompt_tokens": ANY_BUT_NONE, + "completion_tokens": ANY_BUT_NONE, + "total_tokens": ANY_BUT_NONE, }, - usage=ANY_BUT_NONE, start_time=ANY_BUT_NONE, end_time=ANY_BUT_NONE, project_name=ANY_BUT_NONE, @@ -440,8 +463,22 @@ def test_openai_client_chat_completions_create__stream_mode_is_on__generator_tra ) assert len(fake_message_processor_.trace_trees) == 1 - - assert_equal(EXPECTED_TRACE_TREE, fake_message_processor_.trace_trees[0]) + trace_tree = fake_message_processor_.trace_trees[0] + + assert_equal(EXPECTED_TRACE_TREE, trace_tree) + + llm_span_metadata = trace_tree.spans[0].metadata + REQUIRED_METADATA_KEYS = [ + "usage", + "model", + "max_tokens", + "created_from", + "type", + "id", + "created", + "object", + ] + assert_dict_has_keys(llm_span_metadata, REQUIRED_METADATA_KEYS) def test_openai_client_chat_completions_create__async_openai_call_made_in_another_tracked_async_function__streaming_mode_enabled__openai_span_attached_to_existing_trace( @@ -509,16 +546,12 @@ async def async_f(): input={"messages": messages}, output={"choices": ANY_BUT_NONE}, tags=["openai"], - metadata={ - "created_from": "openai", - "type": "openai_chat", - "model": "gpt-3.5-turbo", - "max_tokens": 10, - "stream": True, - "stream_options": {"include_usage": True}, - "usage": ANY_BUT_NONE, + metadata=ANY_DICT, + usage={ + "prompt_tokens": ANY_BUT_NONE, + "completion_tokens": ANY_BUT_NONE, + "total_tokens": ANY_BUT_NONE, }, - usage=ANY_BUT_NONE, start_time=ANY_BUT_NONE, end_time=ANY_BUT_NONE, project_name=ANY_BUT_NONE, @@ -530,5 +563,19 @@ async def async_f(): ) assert len(fake_message_processor_.trace_trees) == 1 - - assert_equal(EXPECTED_TRACE_TREE, 
fake_message_processor_.trace_trees[0])
+    trace_tree = fake_message_processor_.trace_trees[0]
+
+    assert_equal(EXPECTED_TRACE_TREE, trace_tree)
+
+    llm_span_metadata = trace_tree.spans[0].spans[0].metadata
+    REQUIRED_METADATA_KEYS = [
+        "usage",
+        "model",
+        "max_tokens",
+        "created_from",
+        "type",
+        "id",
+        "created",
+        "object",
+    ]
+    assert_dict_has_keys(llm_span_metadata, REQUIRED_METADATA_KEYS)
diff --git a/sdks/python/tests/testlib/__init__.py b/sdks/python/tests/testlib/__init__.py
index f7b3561358..9346380ce4 100644
--- a/sdks/python/tests/testlib/__init__.py
+++ b/sdks/python/tests/testlib/__init__.py
@@ -1,6 +1,11 @@
 from .backend_emulator_message_processor import BackendEmulatorMessageProcessor
 from .models import SpanModel, TraceModel, FeedbackScoreModel
-from .assert_helpers import assert_dicts_equal, prepare_difference_report, assert_equal
+from .assert_helpers import (
+    assert_dicts_equal,
+    prepare_difference_report,
+    assert_equal,
+    assert_dict_has_keys,
+)
 from .any_compare_helpers import ANY_BUT_NONE, ANY_DICT, ANY_LIST, ANY
 from .patch_helpers import patch_environ

@@ -14,6 +19,7 @@
     "ANY",
     "assert_equal",
     "assert_dicts_equal",
+    "assert_dict_has_keys",
     "prepare_difference_report",
     "BackendEmulatorMessageProcessor",
     "patch_environ",
diff --git a/sdks/python/tests/testlib/assert_helpers.py b/sdks/python/tests/testlib/assert_helpers.py
index 0ff959bd23..3d996d41fb 100644
--- a/sdks/python/tests/testlib/assert_helpers.py
+++ b/sdks/python/tests/testlib/assert_helpers.py
@@ -46,7 +46,7 @@ def assert_dicts_equal(
     dict1: Dict[str, Any],
     dict2: Dict[str, Any],
     ignore_keys: Optional[List[str]] = None,
-) -> bool:
+) -> None:
     dict1_copy, dict2_copy = {**dict1}, {**dict2}

     ignore_keys = [] if ignore_keys is None else ignore_keys
@@ -56,3 +56,14 @@
         dict2_copy.pop(key, None)

     assert dict1_copy == dict2_copy, prepare_difference_report(dict1_copy, dict2_copy)
+
+
+def assert_dict_has_keys(dic: Dict[str, Any], keys: List[str]) -> None:
+    dict_has_keys = all(key in dic for key in keys)
+
+    if dict_has_keys:
+        return
+
+    raise AssertionError(
+        f"Dict doesn't contain all the required keys. Dict keys: {dic.keys()}, required keys: {keys}"
+    )
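
# Reviewer note (not part of the patch): a minimal sketch of how the new
# assert_dict_has_keys helper behaves, assuming the test suite's import layout
# where testlib re-exports it. The metadata dict below is hypothetical, shaped
# like the LLM span metadata the tests above inspect; values are invented.
from typing import Any, Dict

from testlib import assert_dict_has_keys  # import path assumed from this patch

llm_span_metadata: Dict[str, Any] = {
    "usage": {"prompt_tokens": 9, "completion_tokens": 10, "total_tokens": 19},
    "model": "gpt-3.5-turbo",
    "max_tokens": 10,
    "created_from": "openai",
    "type": "openai_chat",
    "id": "chatcmpl-abc123",
    "created": 1700000000,
    "object": "chat.completion",
}

# Passes: every required key is present (values themselves are not checked).
assert_dict_has_keys(llm_span_metadata, ["usage", "model", "max_tokens"])

# Raises AssertionError: "temperature" is missing from the metadata keys.
assert_dict_has_keys(llm_span_metadata, ["temperature"])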