Skip to content

Commit 5da0e02

Browse files
authored
[CM-9113] Introduce api object (#109)
* Add initial implementation of API and LLMTraceAPI
* Add copyright comments
* Fix lint errors
* Expose new api to comet_llm namespace
* Refactor if-else statements
* Add API.query
* Add type hint
* Update log_metadata
* Refactor code
* Refactor some code
* Add reading api key from environment
* Fix lint errors
1 parent d5a6b57 commit 5da0e02

File tree

5 files changed

+278
-0
lines changed

5 files changed

+278
-0
lines changed

src/comet_llm/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
# *******************************************************
1414

1515
from . import app, autologgers, config, logging
16+
from .api_objects.api import API
1617
from .config import init, is_ready
1718

1819
if config.comet_disabled():
@@ -33,6 +34,7 @@
3334
"is_ready",
3435
"log_user_feedback",
3536
"flush",
37+
"API",
3638
]
3739

3840
logging.setup()

src/comet_llm/api_objects/__init__.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
# -*- coding: utf-8 -*-
2+
# *******************************************************
3+
# ____ _ _
4+
# / ___|___ _ __ ___ ___| |_ _ __ ___ | |
5+
# | | / _ \| '_ ` _ \ / _ \ __| | '_ ` _ \| |
6+
# | |__| (_) | | | | | | __/ |_ _| | | | | | |
7+
# \____\___/|_| |_| |_|\___|\__(_)_| |_| |_|_|
8+
#
9+
# Sign up for free at https://www.comet.com
10+
# Copyright (C) 2015-2023 Comet ML INC
11+
# This source code is licensed under the MIT license found in the
12+
# LICENSE file in the root directory of this package.
13+
# *******************************************************

src/comet_llm/api_objects/api.py

Lines changed: 125 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,125 @@
1+
# -*- coding: utf-8 -*-
2+
# *******************************************************
3+
# ____ _ _
4+
# / ___|___ _ __ ___ ___| |_ _ __ ___ | |
5+
# | | / _ \| '_ ` _ \ / _ \ __| | '_ ` _ \| |
6+
# | |__| (_) | | | | | | __/ |_ _| | | | | | |
7+
# \____\___/|_| |_| |_|\___|\__(_)_| |_| |_|_|
8+
#
9+
# Sign up for free at https://www.comet.com
10+
# Copyright (C) 2015-2023 Comet ML INC
11+
# This source code is licensed under the MIT license found in the
12+
# LICENSE file in the root directory of this package.
13+
# *******************************************************
14+
15+
from typing import List, Optional
16+
17+
import comet_ml
18+
19+
from .. import experiment_info, logging_messages, query_dsl
20+
from . import llm_trace_api
21+
22+
# TODO: make the decision about dependencies from comet-ml. Address testing.
23+
24+
25+
class API:
    """
    Entry point for fetching LLM traces that were already logged to Comet.

    Wraps a `comet_ml.API` client and returns `LLMTraceAPI` objects built from
    the experiments that client finds.
    """

    def __init__(self, api_key: Optional[str] = None) -> None:
        """
        Create the underlying `comet_ml.API` client (with caching disabled).

        Args:
            api_key: str (optional), Comet API key. It is resolved through
                `experiment_info.get`, which also receives the message to use
                when no API key can be found.
        """
        not_found_message = logging_messages.API_KEY_NOT_FOUND_MESSAGE % "API"
        resolved_info = experiment_info.get(
            api_key,
            api_key_not_found_message=not_found_message,
        )
        self._api = comet_ml.API(api_key=resolved_info.api_key, cache=False)

    def get_llm_trace_by_key(self, trace_key: str) -> llm_trace_api.LLMTraceAPI:
        """
        Get an API Trace object by key.

        Args:
            trace_key: str, key of the prompt or chain

        Returns: An LLMTraceAPI object that can be used to get or update trace data

        Raises:
            ValueError: if no trace is found for the given key
        """
        matching_trace = self._api.get_experiment_by_key(trace_key)

        if matching_trace is None:
            raise ValueError(
                f"Failed to find any matching traces with the key {trace_key}"
            )

        return llm_trace_api.LLMTraceAPI.__api__from_api_experiment__(matching_trace)

    def get_llm_trace_by_name(
        self, workspace: str, project_name: str, trace_name: str
    ) -> llm_trace_api.LLMTraceAPI:
        """
        Get an API Trace object by name.

        Args:
            workspace: str, name of the workspace
            project_name: str, name of the project
            trace_name: str, name of the prompt or chain

        Returns: An LLMTraceAPI object that can be used to get or update trace data

        Raises:
            ValueError: if no trace, or more than one trace, has this name
                in the project
        """
        matching_traces = self._api.query(
            workspace, project_name, query_dsl.Other("Name") == trace_name
        )

        # Exactly one match is required: the name must identify a single trace.
        if not matching_traces:
            raise ValueError(
                f"Failed to find any matching traces with the name {trace_name} in the project {project_name}"
            )
        if len(matching_traces) > 1:
            raise ValueError(
                f"Found multiple traces with the name {trace_name} in the project {project_name}"
            )

        return llm_trace_api.LLMTraceAPI.__api__from_api_experiment__(
            matching_traces[0]
        )

    def query(
        self, workspace: str, project_name: str, query: str
    ) -> List[llm_trace_api.LLMTraceAPI]:
        """
        Fetch LLM Traces based on a query. Currently it is only possible to use
        trace metadata or details fields to filter the traces.

        Args:
            workspace: str, name of the workspace
            project_name: str, name of the project
            query: the query expression used to filter the traces (see Notes).
                NOTE(review): the parameter is annotated as `str`, but the
                examples below pass expression objects built from the query
                variables - confirm and adjust the annotation.

        Returns: A list of LLMTraceAPI objects, one per matching trace

        Notes:
            The `query` object takes the form of (QUERY_VARIABLE OPERATOR VALUE) with:

            * QUERY_VARIABLE is either TraceMetadata, Duration, Timestamp.
            * OPERATOR is any standard mathematical operators `==`, `<=`, `>=`, `!=`, `<`, `>`.

            It is also possible to add multiple query conditions using `&`.

            If you are querying nested parameters, you should flatten the parameter name using the
            `.` operator.

            To query the duration, you can use Duration().

        Example:
            ```python
            # Find all traces where the metadata field `token` is greater than 50
            api.query("workspace", "project", TraceMetadata("token") > 50)

            # Find all traces where the duration field is between 1 second and 2 seconds
            api.query("workspace", "project", (Duration() > 1) & (Duration() <= 2))

            # Find all traces based on the timestamp
            api.query("workspace", "project", Timestamp() > datetime(2023, 9, 10))
            ```
        """
        matching_api_objects = self._api.query(workspace, project_name, query)

        return [
            llm_trace_api.LLMTraceAPI.__api__from_api_experiment__(api_object)
            for api_object in matching_api_objects
        ]
src/comet_llm/api_objects/llm_trace_api.py

Lines changed: 117 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,117 @@
1+
# -*- coding: utf-8 -*-
2+
# *******************************************************
3+
# ____ _ _
4+
# / ___|___ _ __ ___ ___| |_ _ __ ___ | |
5+
# | | / _ \| '_ ` _ \ / _ \ __| | '_ ` _ \| |
6+
# | |__| (_) | | | | | | __/ |_ _| | | | | | |
7+
# \____\___/|_| |_| |_|\___|\__(_)_| |_| |_|_|
8+
#
9+
# Sign up for free at https://www.comet.com
10+
# Copyright (C) 2015-2023 Comet ML INC
11+
# This source code is licensed under the MIT license found in the
12+
# LICENSE file in the root directory of this package.
13+
# *******************************************************
14+
15+
import io
16+
import json
17+
from typing import Dict, Optional
18+
19+
import comet_ml
20+
21+
from .. import convert
22+
from ..chains import deepmerge
23+
from ..types import JSONEncodable
24+
25+
26+
class LLMTraceAPI:
    """
    Handle to a single LLM trace, backed by a `comet_ml.APIExperiment`.

    Instances are not created directly: `__init__` always raises. They are
    built through `__api__from_api_experiment__`.
    """

    # The API experiment that stores this trace's data.
    _api_experiment: comet_ml.APIExperiment

    def __init__(self) -> None:
        raise NotImplementedError(
            "Please use API.get_llm_trace_by_key or API.get_llm_trace_by_name methods to get the instance"
        )

    @classmethod
    def __api__from_api_experiment__(
        cls, api_experiment: comet_ml.APIExperiment
    ) -> "LLMTraceAPI":
        """
        Build an instance around an existing API experiment.

        Args:
            api_experiment: the experiment holding the trace data

        Returns: a new LLMTraceAPI wrapping `api_experiment`
        """
        # object.__new__ bypasses __init__, which unconditionally raises.
        instance = object.__new__(cls)
        instance._api_experiment = api_experiment

        return instance

    def get_name(self) -> Optional[str]:
        """
        Get the name of the trace
        """
        return self._api_experiment.get_name()  # type: ignore

    def get_key(self) -> str:
        """
        Get the unique identifier for this trace
        """
        return self._api_experiment.key  # type: ignore

    def log_user_feedback(self, score: float) -> None:
        """
        Log user feedback as the `user_feedback` metric of the trace

        Args:
            score: float, the feedback score. Can be either 0, 0.0, 1 or 1.0

        Raises:
            ValueError: if the score is not one of the allowed values
        """
        ALLOWED_SCORES = [0.0, 1.0]
        if score not in ALLOWED_SCORES:
            raise ValueError(
                f"Score is not valid, should be {ALLOWED_SCORES} but got {score}"
            )

        self._api_experiment.log_metric("user_feedback", score)

    def _get_trace_data(self) -> Dict[str, JSONEncodable]:
        """
        Load the content of the `comet_llm_data.json` asset of the trace.

        Returns: the decoded JSON content of the asset

        Raises:
            ValueError: if the asset cannot be located
        """
        try:
            # The broad except covers both a failure while listing the assets
            # and the absence of the asset (next() on an exhausted generator).
            trace_asset = next(
                asset
                for asset in self._api_experiment.get_asset_list()
                if asset["fileName"] == "comet_llm_data.json"
            )
            asset_id = trace_asset["assetId"]
        except Exception as exception:
            raise ValueError(
                "Failed to get trace data for this trace, metadata is not available"
            ) from exception

        trace_data = json.loads(self._api_experiment.get_asset(asset_id))

        return trace_data  # type: ignore

    def get_metadata(self) -> Dict[str, JSONEncodable]:
        """
        Get trace metadata

        Returns: the value stored under the "metadata" key of the trace data

        Raises:
            ValueError: if the trace data asset cannot be located
            KeyError: if the trace data has no "metadata" key
        """
        trace_data = self._get_trace_data()

        return trace_data["metadata"]  # type: ignore

    def log_metadata(self, metadata: Dict[str, JSONEncodable]) -> None:
        """
        Update the metadata field for a trace, can be used to set or update metadata fields

        Args:
            metadata: dict, dict in the form of {"metadata_name": value, ...}. Nested metadata is supported

        Raises:
            ValueError: if the trace data asset cannot be located
        """
        trace_data = self._get_trace_data()
        trace_data["metadata"] = deepmerge.deepmerge(
            trace_data.get("metadata", {}), metadata
        )

        # Re-upload the whole trace data document over the existing asset.
        stream = io.StringIO(json.dumps(trace_data))
        self._api_experiment.log_asset(
            stream, overwrite=True, name="comet_llm_data.json"
        )

        # Also log the new metadata as flattened experiment parameters.
        self._api_experiment.log_parameters(
            convert.chain_metadata_to_flat_parameters(metadata)
        )

src/comet_llm/query_dsl.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
# -*- coding: utf-8 -*-
2+
# *******************************************************
3+
# ____ _ _
4+
# / ___|___ _ __ ___ ___| |_ _ __ ___ | |
5+
# | | / _ \| '_ ` _ \ / _ \ __| | '_ ` _ \| |
6+
# | |__| (_) | | | | | | __/ |_ _| | | | | | |
7+
# \____\___/|_| |_| |_|\___|\__(_)_| |_| |_|_|
8+
#
9+
# Sign up for free at https://www.comet.com
10+
# Copyright (C) 2015-2023 Comet ML INC
11+
# This source code is licensed under the MIT license found in the
12+
# LICENSE file in the root directory of this package.
13+
# *******************************************************
14+
15+
from comet_ml import api
16+
17+
# The two factories below are PascalCase on purpose: they are used in query
# expressions next to the class aliases defined after them.
def Duration():
    """Return the query variable for the "duration" metric."""
    return api.Metric("duration")


def Timestamp():
    """Return the query variable for the "start_server_timestamp" metadata field."""
    return api.Metadata("start_server_timestamp")


# Aliases of the comet_ml query classes, exposed under trace-oriented names.
TraceMetadata = api.Parameter
TraceDetail = api.Metadata
Other = api.Other

0 commit comments

Comments
 (0)