Skip to content

Commit 3f16cb8

Browse files
authored
[CM-10021] update comet llm to use new endpoint for logging chains and prompts (#128)
* Initial implementation of new endpoint usage * Draft changes * Add compatibility for old format of messages in offline mode * Fix some lint errors * Add senders dispatching based on backend version * Add backward compatible send_v2 methods to chain and prompt senders * Fix lint errors, remove redundant proxy payload class * Update client * Update tests * Update log_chain request function to follow new response format * Fix lint errors * Remove exception catching in offline and online message processors * Add constants.py * Fix lint errors * Update outdated tests * Bump v2 backend version
1 parent 6ce3ab4 commit 3f16cb8

23 files changed

+635
-103
lines changed

src/comet_llm/chains/api.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,7 @@ def log_chain(chain: chain.Chain) -> Optional[llm_result.LLMResult]:
105105
chain_data = chain.as_dict()
106106

107107
message = messages.ChainMessage(
108+
id=messages.generate_id(),
108109
experiment_info_=chain.experiment_info,
109110
tags=chain.tags,
110111
chain_data=chain_data,

src/comet_llm/experiment_api/comet_api_client.py

Lines changed: 66 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,26 +13,39 @@
1313
# *******************************************************
1414

1515
import functools
16+
import logging
1617
import urllib.parse
1718
import warnings
18-
from typing import IO, List, Optional
19+
from typing import IO, Any, Dict, List, Optional
1920

2021
import requests # type: ignore
2122
import urllib3.exceptions
2223

23-
from .. import config
24+
from .. import config, exceptions, semantic_version
2425
from ..types import JSONEncodable
25-
from . import request_exception_wrapper
26+
from . import error_codes_mapping, payload_constructor
2627

2728
ResponseContent = JSONEncodable
2829

30+
LOGGER = logging.getLogger(__name__)
31+
2932

3033
class CometAPIClient:
3134
def __init__(self, api_key: str, comet_url: str, session: requests.Session):
3235
self._headers = {"Authorization": api_key}
3336
self._comet_url = comet_url
3437
self._session = session
3538

39+
self.backend_version = semantic_version.SemanticVersion.parse(
40+
self.is_alive_ver()["version"]
41+
)
42+
43+
def is_alive_ver(self) -> ResponseContent:
44+
return self._request(
45+
"GET",
46+
"api/isAlive/ver",
47+
)
48+
3649
def create_experiment(
3750
self,
3851
type_: str,
@@ -122,6 +135,56 @@ def log_experiment_other(
122135
},
123136
)
124137

138+
def log_chain(
139+
self,
140+
experiment_key: str,
141+
chain_asset: Dict[str, JSONEncodable],
142+
workspace: Optional[str] = None,
143+
project: Optional[str] = None,
144+
parameters: Optional[Dict[str, JSONEncodable]] = None,
145+
metrics: Optional[Dict[str, JSONEncodable]] = None,
146+
tags: Optional[List[str]] = None,
147+
others: Optional[Dict[str, JSONEncodable]] = None,
148+
) -> ResponseContent:
149+
json = [
150+
{
151+
"experimentKey": experiment_key,
152+
"createExperimentRequest": {
153+
"workspaceName": workspace,
154+
"projectName": project,
155+
"type": "LLM",
156+
},
157+
"parameters": payload_constructor.chain_parameters_payload(parameters),
158+
"metrics": payload_constructor.chain_metrics_payload(metrics),
159+
"others": payload_constructor.chain_others_payload(others),
160+
"tags": tags,
161+
"jsonAsset": {
162+
"extension": "json",
163+
"type": "llm_data",
164+
"fileName": "comet_llm_data.json",
165+
"file": chain_asset,
166+
},
167+
}
168+
] # we make a list because endpoint is designed for batches
169+
170+
batched_response: Dict[str, Dict[str, Any]] = self._request(
171+
"POST",
172+
"api/rest/v2/write/experiment/llm",
173+
json=json,
174+
)
175+
sub_response = list(batched_response.values())[0]
176+
status = sub_response["status"]
177+
if status != 200:
178+
LOGGER.debug(
179+
"Failed to send trace: \nPayload %s, Response %s",
180+
str(json),
181+
str(batched_response),
182+
)
183+
error_code = sub_response["content"]["sdk_error_code"]
184+
raise exceptions.CometLLMException(error_codes_mapping.MESSAGES[error_code])
185+
186+
return sub_response["content"]
187+
125188
def _request(self, method: str, path: str, *args, **kwargs) -> ResponseContent: # type: ignore
126189
url = urllib.parse.urljoin(self._comet_url, path)
127190
response = self._session.request(

src/comet_llm/experiment_api/failed_response_handler.py renamed to src/comet_llm/experiment_api/error_codes_mapping.py

Lines changed: 2 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -13,24 +13,12 @@
1313
# *******************************************************
1414

1515
import collections
16-
import json
17-
from typing import NoReturn
1816

19-
import requests # type: ignore
17+
from .. import backend_error_codes, logging_messages
2018

21-
from .. import backend_error_codes, exceptions, logging_messages
22-
23-
_SDK_ERROR_CODES_LOGGING_MESSAGE = collections.defaultdict(
19+
MESSAGES = collections.defaultdict(
2420
lambda: logging_messages.FAILED_TO_SEND_DATA_TO_SERVER,
2521
{
2622
backend_error_codes.UNABLE_TO_LOG_TO_NON_LLM_PROJECT: logging_messages.UNABLE_TO_LOG_TO_NON_LLM_PROJECT
2723
},
2824
)
29-
30-
31-
def handle(exception: requests.RequestException) -> NoReturn:
32-
response = exception.response
33-
sdk_error_code = json.loads(response.text)["sdk_error_code"]
34-
error_message = _SDK_ERROR_CODES_LOGGING_MESSAGE[sdk_error_code]
35-
36-
raise exceptions.CometLLMException(error_message) from exception
Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
# -*- coding: utf-8 -*-
2+
# *******************************************************
3+
# ____ _ _
4+
# / ___|___ _ __ ___ ___| |_ _ __ ___ | |
5+
# | | / _ \| '_ ` _ \ / _ \ __| | '_ ` _ \| |
6+
# | |__| (_) | | | | | | __/ |_ _| | | | | | |
7+
# \____\___/|_| |_| |_|\___|\__(_)_| |_| |_|_|
8+
#
9+
# Sign up for free at https://www.comet.com
10+
# Copyright (C) 2015-2023 Comet ML INC
11+
# This source code is licensed under the MIT license found in the
12+
# LICENSE file in the root directory of this package.
13+
# *******************************************************
14+
15+
from typing import Dict, List, Optional
16+
17+
from ..types import JSONEncodable
18+
19+
20+
def chain_parameters_payload(
21+
parameters: Optional[Dict[str, JSONEncodable]]
22+
) -> List[Dict[str, JSONEncodable]]:
23+
return _dict_to_payload_format(parameters, "parameterName", "parameterValue")
24+
25+
26+
def chain_metrics_payload(
27+
metrics: Optional[Dict[str, JSONEncodable]]
28+
) -> List[Dict[str, JSONEncodable]]:
29+
return _dict_to_payload_format(metrics, "metricName", "metricValue")
30+
31+
32+
def chain_others_payload(
33+
others: Optional[Dict[str, JSONEncodable]]
34+
) -> List[Dict[str, JSONEncodable]]:
35+
return _dict_to_payload_format(others, "key", "value")
36+
37+
38+
def _dict_to_payload_format(
39+
source: Optional[Dict[str, JSONEncodable]], key_name: str, value_name: str
40+
) -> List[Dict[str, JSONEncodable]]:
41+
if source is None:
42+
return []
43+
44+
result = [{key_name: key, value_name: value} for key, value in source.items()]
45+
46+
return result

src/comet_llm/experiment_api/request_exception_wrapper.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,15 +13,16 @@
1313
# *******************************************************
1414

1515
import functools
16+
import json
1617
import logging
1718
import urllib.parse
1819
from pprint import pformat
19-
from typing import Any, Callable, List
20+
from typing import Any, Callable, List, NoReturn
2021

2122
import requests # type: ignore
2223

2324
from .. import config, exceptions, logging_messages
24-
from . import failed_response_handler
25+
from . import error_codes_mapping
2526

2627
LOGGER = logging.getLogger(__name__)
2728

@@ -49,7 +50,7 @@ def wrapper(*args, **kwargs) -> Any: # type: ignore
4950
logging_messages.FAILED_TO_SEND_DATA_TO_SERVER
5051
) from exception
5152

52-
failed_response_handler.handle(exception)
53+
_handle_request_exception(exception)
5354

5455
return wrapper
5556

@@ -73,3 +74,11 @@ def _debug_log(exception: requests.RequestException) -> None:
7374
# Make sure we won't fail on attempt to debug.
7475
# It's mainly for tests when response object can be mocked
7576
pass
77+
78+
79+
def _handle_request_exception(exception: requests.RequestException) -> NoReturn:
80+
response = exception.response
81+
sdk_error_code = json.loads(response.text)["sdk_error_code"]
82+
error_message = error_codes_mapping.MESSAGES[sdk_error_code]
83+
84+
raise exceptions.CometLLMException(error_message) from exception

src/comet_llm/message_processing/messages.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,24 +13,32 @@
1313
# *******************************************************
1414

1515
import dataclasses
16-
import inspect
16+
import uuid
1717
from typing import Any, ClassVar, Dict, List, Optional, Union
1818

1919
from comet_llm.types import JSONEncodable
2020

2121
from .. import experiment_info, logging_messages
2222

2323

24+
def generate_id() -> str:
25+
return uuid.uuid4().hex
26+
27+
2428
@dataclasses.dataclass
2529
class BaseMessage:
2630
experiment_info_: experiment_info.ExperimentInfo
31+
id: str
2732
VERSION: ClassVar[int]
2833

2934
@classmethod
3035
def from_dict(
3136
cls, d: Dict[str, Any], api_key: Optional[str] = None
3237
) -> "BaseMessage":
33-
d.pop("VERSION") #
38+
version = d.pop("VERSION")
39+
if version == 1:
40+
# Message was dumped before id was introduced. We can generate it now.
41+
d["id"] = generate_id()
3442

3543
experiment_info_dict: Dict[str, Optional[str]] = d.pop("experiment_info_")
3644
experiment_info_ = experiment_info.get(
@@ -57,7 +65,7 @@ class PromptMessage(BaseMessage):
5765
metadata: Optional[Dict[str, Union[str, bool, float, None]]]
5866
tags: Optional[List[str]]
5967

60-
VERSION: ClassVar[int] = 1
68+
VERSION: ClassVar[int] = 2
6169

6270

6371
@dataclasses.dataclass
@@ -69,4 +77,4 @@ class ChainMessage(BaseMessage):
6977
others: Dict[str, JSONEncodable]
7078
# 'other' - is a name of an attribute of experiment, logged via log_other
7179

72-
VERSION: ClassVar[int] = 1
80+
VERSION: ClassVar[int] = 2

src/comet_llm/message_processing/offline_message_processor.py

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -44,15 +44,9 @@ def process(self, message: messages.BaseMessage) -> None:
4444
file_path = pathlib.Path(self._offline_directory, self._current_file_name)
4545

4646
if isinstance(message, messages.PromptMessage):
47-
try:
48-
return prompt.send(message, str(file_path))
49-
except Exception:
50-
LOGGER.error("Failed to log prompt", exc_info=True)
47+
return prompt.send(message, str(file_path))
5148
elif isinstance(message, messages.ChainMessage):
52-
try:
53-
return chain.send(message, str(file_path))
54-
except Exception:
55-
LOGGER.error("Failed to log chain", exc_info=True)
49+
return chain.send(message, str(file_path))
5650

5751
LOGGER.debug(f"Unsupported message type {message}")
5852
return None

src/comet_llm/message_processing/online_message_processor.py

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -28,15 +28,9 @@ def __init__(self) -> None:
2828

2929
def process(self, message: messages.BaseMessage) -> Optional[llm_result.LLMResult]:
3030
if isinstance(message, messages.PromptMessage):
31-
try:
32-
return prompt.send(message)
33-
except Exception:
34-
LOGGER.error("Failed to log prompt", exc_info=True)
31+
return prompt.send(message)
3532
elif isinstance(message, messages.ChainMessage):
36-
try:
37-
return chain.send(message)
38-
except Exception:
39-
LOGGER.error("Failed to log chain", exc_info=True)
33+
return chain.send(message)
4034

4135
LOGGER.debug(f"Unsupported message type {message}")
4236
return None

src/comet_llm/message_processing/online_senders/chain.py

Lines changed: 33 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,12 +15,23 @@
1515
import io
1616
import json
1717

18-
from comet_llm import app, convert, experiment_api, llm_result
18+
from comet_llm import app, convert, experiment_api, llm_result, url_helpers
19+
from comet_llm.experiment_api import comet_api_client
1920

2021
from .. import messages
22+
from . import constants
2123

2224

2325
def send(message: messages.ChainMessage) -> llm_result.LLMResult:
26+
client = comet_api_client.get(message.experiment_info_.api_key)
27+
28+
if client.backend_version >= constants.V2_BACKEND_VERSION:
29+
return _send_v2(message, client)
30+
31+
return _send_v1(message)
32+
33+
34+
def _send_v1(message: messages.ChainMessage) -> llm_result.LLMResult:
2435
experiment_api_ = experiment_api.ExperimentAPI.create_new(
2536
api_key=message.experiment_info_.api_key,
2637
workspace=message.experiment_info_.workspace,
@@ -48,3 +59,24 @@ def send(message: messages.ChainMessage) -> llm_result.LLMResult:
4859
return llm_result.LLMResult(
4960
id=experiment_api_.id, project_url=experiment_api_.project_url
5061
)
62+
63+
64+
def _send_v2(
65+
message: messages.ChainMessage, client: comet_api_client.CometAPIClient
66+
) -> llm_result.LLMResult:
67+
metrics = {"chain_duration": message.duration}
68+
parameters = convert.chain_metadata_to_flat_parameters(message.metadata)
69+
70+
response = client.log_chain(
71+
experiment_key=message.id,
72+
chain_asset=message.chain_data,
73+
workspace=message.experiment_info_.workspace,
74+
project=message.experiment_info_.project_name,
75+
tags=message.tags,
76+
metrics=metrics,
77+
parameters=parameters,
78+
others=message.others,
79+
)
80+
project_url: str = url_helpers.experiment_to_project_url(response["link"])
81+
82+
return llm_result.LLMResult(id=message.id, project_url=project_url)
Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
# -*- coding: utf-8 -*-
2+
# *******************************************************
3+
# ____ _ _
4+
# / ___|___ _ __ ___ ___| |_ _ __ ___ | |
5+
# | | / _ \| '_ ` _ \ / _ \ __| | '_ ` _ \| |
6+
# | |__| (_) | | | | | | __/ |_ _| | | | | | |
7+
# \____\___/|_| |_| |_|\___|\__(_)_| |_| |_|_|
8+
#
9+
# Sign up for free at https://www.comet.com
10+
# Copyright (C) 2015-2023 Comet ML INC
11+
# This source code is licensed under the MIT license found in the
12+
# LICENSE file in the root directory of this package.
13+
# *******************************************************
14+
15+
V2_BACKEND_VERSION = "3.25.80"

0 commit comments

Comments
 (0)