Skip to content
This repository has been archived by the owner on Jan 7, 2025. It is now read-only.

Commit

Permalink
Update spice (#556)
Browse files Browse the repository at this point in the history
  • Loading branch information
PCSwingle authored Apr 5, 2024
1 parent f60f362 commit 8a70f70
Show file tree
Hide file tree
Showing 8 changed files with 90 additions and 53 deletions.
4 changes: 3 additions & 1 deletion benchmarks/benchmark_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,11 @@
import re
from datetime import datetime
from pathlib import Path
from typing import List
from uuid import uuid4

from openai.types.chat.completion_create_params import ResponseFormat
from spice import SpiceMessage

from benchmarks.arg_parser import common_benchmark_parser
from benchmarks.benchmark_result import BenchmarkResult
Expand Down Expand Up @@ -43,7 +45,7 @@ def git_diff_from_comparison_commit(sample: Sample, comparison_commit: str) -> s

async def grade(to_grade, prompt, model="gpt-4-1106-preview"):
try:
messages = [
messages: List[SpiceMessage] = [
{"role": "system", "content": prompt},
{"role": "user", "content": to_grade},
]
Expand Down
4 changes: 3 additions & 1 deletion benchmarks/exercism_practice.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,11 @@
from datetime import datetime
from functools import partial
from pathlib import Path
from typing import List

import tqdm
from openai import BadRequestError
from spice import SpiceMessage

from benchmarks.arg_parser import common_benchmark_parser
from benchmarks.benchmark_result import BenchmarkResult
Expand Down Expand Up @@ -56,7 +58,7 @@ async def failure_analysis(exercise_runner, language):
test_results = exercise_runner.read_test_results()

final_message = f"All instructions:\n{instructions}\nCode to review:\n{code}\nTest" f" results:\n{test_results}"
messages = [
messages: List[SpiceMessage] = [
{"role": "system", "content": prompt},
{"role": "user", "content": final_message},
]
Expand Down
8 changes: 2 additions & 6 deletions mentat/conversation.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,17 +198,13 @@ async def _stream_model_response(

stream.send("Streaming...\n")
async with stream.interrupt_catcher(parser.shutdown):
parsed_llm_response = await parser.stream_and_parse_llm_response(add_newline(response.stream()))
parsed_llm_response = await parser.stream_and_parse_llm_response(add_newline(response))

# Sampler and History require previous_file_lines
for file_edit in parsed_llm_response.file_edits:
file_edit.previous_file_lines = code_file_manager.file_lines.get(file_edit.file_path, []).copy()

# TODO: this is janky come up with better solution
# if the stream was interrupted, then the finally block in the response.stream() async generator
# will wait for an opportunity to run. This sleep call gives it that opportunity.
# the finally block runs the logging callback
await asyncio.sleep(0.01)
cost_tracker.log_api_call_stats(response.current_response())
cost_tracker.display_last_api_call()

messages.append(
Expand Down
91 changes: 59 additions & 32 deletions mentat/llm_api_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,10 @@
Callable,
Dict,
List,
Literal,
Optional,
cast,
overload,
)

import attr
Expand All @@ -29,7 +31,10 @@
)
from openai.types.chat.completion_create_params import ResponseFormat
from PIL import Image
from spice import APIConnectionError, Spice, SpiceEmbeddings, SpiceError, SpiceResponse, SpiceWhisper
from spice import APIConnectionError, Spice, SpiceError, SpiceMessage, SpiceResponse, StreamingSpiceResponse
from spice.errors import NoAPIKeyError
from spice.models import WHISPER_1
from spice.providers import OPEN_AI

from mentat.errors import MentatError, ReturnToUser
from mentat.session_context import SESSION_CONTEXT
Expand Down Expand Up @@ -323,61 +328,83 @@ async def initialize_client(self):
if not load_dotenv(mentat_dir_path / ".env") and not load_dotenv(ctx.cwd / ".env"):
load_dotenv()

if os.getenv("AZURE_OPENAI_KEY") is not None:
embedding_and_whisper_provider = "azure"
else:
embedding_and_whisper_provider = "openai"
self.spice = Spice()

try:
self.spice.load_provider(OPEN_AI)
except NoAPIKeyError:
from mentat.session_input import collect_user_input

ctx.stream.send(
"No OpenAI api key detected. To avoid entering your api key on startup, create a .env file in"
" ~/.mentat/.env or in your workspace root.",
style="warning",
)
ctx.stream.send("Enter your api key:", style="info")
key = (await collect_user_input(log_input=False)).data
os.environ["OPENAI_API_KEY"] = key

self.spice_client = Spice()
self.spice_embedding_client = SpiceEmbeddings(provider=embedding_and_whisper_provider)
self.spice_whisper_client = SpiceWhisper(provider=embedding_and_whisper_provider)
@overload
async def call_llm_api(
self,
messages: List[SpiceMessage],
model: str,
stream: Literal[False],
response_format: ResponseFormat = ResponseFormat(type="text"),
) -> SpiceResponse:
...

@overload
async def call_llm_api(
self,
messages: List[SpiceMessage],
model: str,
stream: Literal[True],
response_format: ResponseFormat = ResponseFormat(type="text"),
) -> StreamingSpiceResponse:
...

@api_guard
async def call_llm_api(
self,
messages: list[ChatCompletionMessageParam],
messages: List[SpiceMessage],
model: str,
stream: bool,
response_format: ResponseFormat = ResponseFormat(type="text"),
) -> SpiceResponse:
) -> SpiceResponse | StreamingSpiceResponse:
session_context = SESSION_CONTEXT.get()
config = session_context.config
cost_tracker = session_context.cost_tracker

if "claude" in config.model:
messages = normalize_messages_for_anthropic(messages)

# Confirm that model has enough tokens remaining.
tokens = prompt_tokens(messages, model)
raise_if_context_exceeds_max(tokens)

# TODO: make spice message format and use across codebase consistently
_messages = [
{"role": cast(str, message["role"]), "content": cast(str, message["content"])} for message in messages
]
if "type" in response_format and response_format["type"] == "json_object":
_response_format = {"type": "json_object"}
else:
_response_format = {"type": "text"}

with sentry_sdk.start_span(description="LLM Call") as span:
span.set_tag("model", model)

response = await self.spice_client.call_llm(
model=model,
messages=_messages,
stream=stream,
temperature=config.temperature,
response_format=_response_format,
logging_callback=cost_tracker.log_api_call_stats,
)
if not stream:
response = await self.spice.get_response(
model=model,
messages=messages,
temperature=config.temperature,
response_format=response_format, # pyright: ignore
)
cost_tracker.log_api_call_stats(response)
else:
response = await self.spice.stream_response(
model=model,
messages=messages,
temperature=config.temperature,
response_format=response_format, # pyright: ignore
)

return response

@api_guard
def call_embedding_api(self, input_texts: list[str], model: str = "text-embedding-ada-002") -> Embeddings:
return self.spice_embedding_client.get_embeddings(input_texts, model)
return self.spice.get_embeddings_sync(input_texts, model) # pyright: ignore

@api_guard
async def call_whisper_api(self, audio_path: Path) -> str:
return await self.spice_whisper_client.get_whisper_transcription(audio_path)
return await self.spice.get_transcription(audio_path, model=WHISPER_1)
2 changes: 1 addition & 1 deletion mentat/server/mentat_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ class MentatServer:
def __init__(self, cwd: Path, config: Config) -> None:
self.cwd = cwd
self.stopped = Event()
self.session = Session(self.cwd, config=config, apply_edits=False)
self.session = Session(self.cwd, config=config, apply_edits=False, show_update=False)

async def _client_listener(self):
with open(3) as fd_input:
Expand Down
4 changes: 3 additions & 1 deletion mentat/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ def __init__(
config: Config = Config(),
# Set to false for clients that apply the edits themselves (like vscode)
apply_edits: bool = True,
show_update: bool = True,
):
# All errors thrown here need to be caught here
self.stopped = Event()
Expand Down Expand Up @@ -112,7 +113,8 @@ def __init__(
self.error = None

# Functions that require session_context
check_version()
if show_update:
check_version()
config.send_errors_to_stream()
for path in paths:
code_context.include(path, exclude_patterns=exclude_paths)
Expand Down
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ selenium==4.15.2
sentry-sdk==1.34.0
sounddevice==0.4.6
soundfile==0.12.1
spiceai==0.1.8
termcolor==2.3.0
textual==0.47.1
textual-autocomplete==2.1.0b0
Expand All @@ -29,4 +30,3 @@ typing_extensions==4.8.0
tqdm==4.66.1
webdriver_manager==4.0.1
watchfiles==0.21.0
spiceai==0.1.7
28 changes: 18 additions & 10 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
from uuid import uuid4

import pytest
from spice import SpiceResponse
from spice.spice import SpiceCallArgs

from mentat import config
from mentat.agent_handler import AgentHandler
Expand Down Expand Up @@ -96,20 +98,26 @@ def mock_call_llm_api(mocker):
completion_mock = mocker.patch.object(LlmApiHandler, "call_llm_api")

def wrap_unstreamed_string(value):
mock_spice_response = MagicMock()
mock_spice_response.text = value

return mock_spice_response
return SpiceResponse(SpiceCallArgs("gpt-3.5-turbo", [], False), value, 1, 0, 0, True)

def wrap_streamed_strings(values):
async def _async_generator():
for value in values:
yield value
class MockStreamingSpiceResponse:
def __init__(self):
self.cur_value = 0

def __aiter__(self):
return self

async def __anext__(self):
if self.cur_value >= len(values):
raise StopAsyncIteration
self.cur_value += 1
return values[self.cur_value - 1]

mock_spice_response = MagicMock()
mock_spice_response.stream = _async_generator
mock_spice_response.text = "".join(values)
def current_response(self):
return SpiceResponse(SpiceCallArgs("gpt-3.5-turbo", [], True), "".join(values), 1, 0, 0, True)

mock_spice_response = MockStreamingSpiceResponse()
return mock_spice_response

def set_streamed_values(values):
Expand Down

0 comments on commit 8a70f70

Please sign in to comment.