Commit

Fix bugs

s-jse committed May 24, 2024
1 parent 389c452 commit 09e2b04
Showing 4 changed files with 53 additions and 38 deletions.
43 changes: 20 additions & 23 deletions chainlite/llm_config.py
@@ -20,44 +20,43 @@
 )


-all_llm_endpoints = None
-prompt_dirs = None
-prompt_log_file = None
-prompt_logs = {}
-prompts_to_skip_for_debugging = None
-local_engine_set = None
+class GlobalVars:
+    prompt_logs = {}
+    all_llm_endpoints = None
+    prompt_dirs = None
+    prompt_log_file = None
+    prompts_to_skip_for_debugging = None
+    local_engine_set = None


 def load_config_from_file(config_file: str):
-    global all_llm_endpoints, prompt_dirs, prompt_log_file, prompts_to_skip_for_debugging, local_engine_set
-    # TODO raise errors if these values are not set, use pydantic v2
-
     with open(config_file, "r") as config_file:
         config = yaml.unsafe_load(config_file)

-    prompt_dirs = config.get("prompt_dirs", ["./"])
-    prompt_log_file = config.get("prompt_logging", {}).get(
+    # TODO raise errors if these values are not set, use pydantic v2
+    GlobalVars.prompt_dirs = config.get("prompt_dirs", ["./"])
+    GlobalVars.prompt_log_file = config.get("prompt_logging", {}).get(
         "log_file", "./prompt_logs.jsonl"
     )
-    prompts_to_skip_for_debugging = set(
+    GlobalVars.prompts_to_skip_for_debugging = set(
         config.get("prompt_logging", {}).get("prompts_to_skip", [])
     )

     litellm.set_verbose = config.get("litellm_set_verbose", False)

-    all_llm_endpoints = config.get("llm_endpoints", [])
-    for a in all_llm_endpoints:
+    GlobalVars.all_llm_endpoints = config.get("llm_endpoints", [])
+    for a in GlobalVars.all_llm_endpoints:
         if "api_key" in a:
             a["api_key"] = os.getenv(a["api_key"])

-    all_llm_endpoints = [
+    GlobalVars.all_llm_endpoints = [
         a
-        for a in all_llm_endpoints
+        for a in GlobalVars.all_llm_endpoints
         if "api_key" not in a or (a["api_key"] is not None and len(a["api_key"]) > 0)
     ]  # remove resources for which we don't have a key

     # tell LiteLLM how we want to map the messages to a prompt string for these non-chat models
-    for endpoint in all_llm_endpoints:
+    for endpoint in GlobalVars.all_llm_endpoints:
         if "prompt_format" in endpoint:
             if endpoint["prompt_format"] == "distilled":
                 # {instruction}\n\n{input}\n
@@ -85,16 +84,14 @@ def load_config_from_file(config_file: str):
                 raise ValueError(
                     f"Unsupported prompt format: {endpoint['prompt_format']}"
                 )
-    all_configured_engines = []
-    local_engine_set = set()
+    GlobalVars.local_engine_set = set()

-    for endpoint in all_llm_endpoints:
+    for endpoint in GlobalVars.all_llm_endpoints:
         for engine, model in endpoint["engine_map"].items():
-            all_configured_engines.append(engine)
             if model.startswith("huggingface/"):
-                local_engine_set.add(engine)
+                GlobalVars.local_engine_set.add(engine)

-    set_custom_template_paths(prompt_dirs)
+    set_custom_template_paths(GlobalVars.prompt_dirs)


 # this code is not safe to use with multiprocessing, only multithreading
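
Note: a minimal usage sketch of the refactor above, assuming a config file named "llm_config.yaml" (the file name is illustrative; the imports match what tests/test_llm_generate.py below uses):

# Sketch only: configuration is loaded once, then read from the GlobalVars class
# instead of module-level globals.
from chainlite import load_config_from_file
from chainlite.llm_config import GlobalVars

load_config_from_file("llm_config.yaml")  # hypothetical path to a chainlite config

print(GlobalVars.prompt_dirs)       # defaults to ["./"] when not set in the config
print(GlobalVars.local_engine_set)  # engines whose model starts with "huggingface/"
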
40 changes: 25 additions & 15 deletions chainlite/llm_generate.py
@@ -8,12 +8,13 @@
 from pprint import pprint
 import random
 import re
-from typing import Iterator, Optional, Any
+from typing import AsyncIterator, Optional, Any
 from uuid import UUID

-from . import llm_config
 from langchain_core.output_parsers import StrOutputParser

+from chainlite.llm_config import GlobalVars
+
 from .load_prompt import load_fewshot_prompt_template
 from langchain_community.chat_models import ChatLiteLLM
 from langchain_core.callbacks import AsyncCallbackHandler
@@ -45,12 +46,21 @@ def pprint_chain(_dict: Any) -> Any:
     return _dict


+def is_same_prompt(template_name_1: str, template_name_2: str) -> bool:
+    return os.path.basename(template_name_1) == os.path.basename(template_name_2)
+
+
 def write_prompt_logs_to_file(log_file: Optional[str] = None):
     if not log_file:
-        log_file = llm_config.prompt_log_file
+        log_file = GlobalVars.prompt_log_file
     with open(log_file, "w") as f:
-        for item in llm_config.prompt_logs.values():
-            if item["template_name"] in llm_config.prompts_to_skip_for_debugging:
+        for item in GlobalVars.prompt_logs.values():
+            should_skip = False
+            for t in GlobalVars.prompts_to_skip_for_debugging:
+                if is_same_prompt(t, item["template_name"]):
+                    should_skip = True
+                    break
+            if should_skip:
                 continue
             f.write(
                 json.dumps(
@@ -88,11 +98,11 @@ async def on_chat_model_start(
             else "<no distillation instruction is specified for this prompt>"
         )
         llm_input = messages[0][-1].content
-        if run_id not in llm_config.prompt_logs:
-            llm_config.prompt_logs[run_id] = {}
-        llm_config.prompt_logs[run_id]["instruction"] = distillation_instruction
-        llm_config.prompt_logs[run_id]["input"] = llm_input
-        llm_config.prompt_logs[run_id]["template_name"] = metadata["template_name"]
+        if run_id not in GlobalVars.prompt_logs:
+            GlobalVars.prompt_logs[run_id] = {}
+        GlobalVars.prompt_logs[run_id]["instruction"] = distillation_instruction
+        GlobalVars.prompt_logs[run_id]["input"] = llm_input
+        GlobalVars.prompt_logs[run_id]["template_name"] = metadata["template_name"]

     async def on_llm_end(
         self,
@@ -106,13 +116,13 @@ async def on_llm_end(
         """Run when LLM ends running."""
         run_id = str(run_id)
         llm_output = response.generations[0][0].text
-        llm_config.prompt_logs[run_id]["output"] = llm_output
+        GlobalVars.prompt_logs[run_id]["output"] = llm_output


 prompt_log_handler = PromptLogHandler()


-async def strip(input_: Iterator[str]) -> Iterator[str]:
+async def strip(input_: AsyncIterator[str]) -> AsyncIterator[str]:
     """
     Strips whitespace from a string, but supports streaming in a LangChain chain
     """
@@ -138,7 +148,7 @@ def extract_until_last_full_sentence(text):
     return ""


-async def postprocess_generations(input_: Iterator[str]) -> Iterator[str]:
+async def postprocess_generations(input_: AsyncIterator[str]) -> AsyncIterator[str]:
     buffer = ""
     yielded = False
     async for chunk in input_:
@@ -226,7 +236,7 @@ def llm_generation_chain(
     # Decide which LLM resource to send this request to.
     potential_llm_resources = [
         resource
-        for resource in llm_config.all_llm_endpoints
+        for resource in GlobalVars.all_llm_endpoints
         if engine in resource["engine_map"]
     ]
     if len(potential_llm_resources) == 0:
@@ -245,7 +255,7 @@
         model_kwargs["api_version"] = llm_resource["api_version"]

     # TODO remove these if LiteLLM fixes their HuggingFace TGI interface
-    if engine in llm_config.local_engine_set:
+    if engine in GlobalVars.local_engine_set:
         if temperature > 0:
             model_kwargs["do_sample"] = True
         else:
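
Note: the new skip logic above compares prompts by file basename rather than by the exact template string, so a configured "test.prompt" also matches a log entry recorded under a longer path. A small sketch of the behavior (paths are illustrative):

import os

def is_same_prompt(template_name_1: str, template_name_2: str) -> bool:
    # Same comparison as the helper added above: ignore directories, compare file names.
    return os.path.basename(template_name_1) == os.path.basename(template_name_2)

assert is_same_prompt("prompts/test.prompt", "test.prompt")       # would be skipped
assert not is_same_prompt("prompts/other.prompt", "test.prompt")  # would still be logged
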
1 change: 1 addition & 0 deletions setup.py
@@ -17,6 +17,7 @@
         "langchain-core==0.1.50",
         "langchain-text-splitters==0.0.1",
         "langgraph==0.0.41",
+        "grandalf",  # to visualize LangGraph graphs
         "langsmith==0.1.53",
         "litellm==1.35.38",
         "pydantic>=2.5",
7 changes: 7 additions & 0 deletions tests/test_llm_generate.py
@@ -1,6 +1,7 @@
 import pytest

 from chainlite import llm_generation_chain, load_config_from_file
+from chainlite.llm_config import GlobalVars
 from chainlite.llm_generate import write_prompt_logs_to_file
 from chainlite.utils import get_logger

@@ -11,6 +12,12 @@

 @pytest.mark.asyncio(scope="session")
 async def test_llm_generate():
+    # Check that the config file has been loaded properly
+    assert GlobalVars.all_llm_endpoints
+    assert GlobalVars.prompt_dirs
+    assert GlobalVars.prompt_log_file
+    assert GlobalVars.prompts_to_skip_for_debugging
+    assert GlobalVars.local_engine_set

     response = await llm_generation_chain(
         template_file="test.prompt",  # prompt path relative to one of the paths specified in `prompt_dirs`
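
Note: the test imports write_prompt_logs_to_file; a hedged sketch of how it could be called after the chain runs (with no argument it falls back to GlobalVars.prompt_log_file, "./prompt_logs.jsonl" by default):

write_prompt_logs_to_file()                       # writes to GlobalVars.prompt_log_file
write_prompt_logs_to_file("debug_prompts.jsonl")  # or to an explicit, hypothetical path
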
