Commit
Merge pull request #171 from mit-submit/development/anthropic_claude
Development/anthropic claude
pmlugato authored Oct 23, 2024
2 parents 148c6b5 + 6612dfc commit ccb9203
Showing 10 changed files with 191 additions and 37 deletions.
32 changes: 30 additions & 2 deletions a2rchi/chains/base.py
@@ -1,5 +1,6 @@
"""Chain for chatting with a vector database."""
from __future__ import annotations
from pydantic import BaseModel
from loguru import logger
from langchain.callbacks import FileCallbackHandler

@@ -12,7 +13,10 @@
from langchain.chains.llm import LLMChain
from langchain.schema import BaseRetriever, Document
from langchain.schema.prompt_template import BasePromptTemplate
from langchain_core.runnables import RunnableSequence, RunnablePassthrough
from typing import Any, Dict, List, Optional, Tuple
from typing import Callable
# from pydantic import model_rebuild
import os


Expand Down Expand Up @@ -46,7 +50,20 @@ class BaseSubMITChain(BaseConversationalRetrievalChain):
"""
retriever: BaseRetriever # Index to connect to
max_tokens_limit: Optional[int] = None # restrict doc length to return from store, enforced only for StuffDocumentChain
get_chat_history: Optional[function] = _get_chat_history
get_chat_history: Optional[Callable[[List[Tuple[str, str]]], str]] = _get_chat_history

    # kuangfei: is the default model_rebuild logic sufficient, or does it need to be rewritten?
    # @classmethod
    # def model_rebuild(
    #     cls,
    #     *,
    #     force: bool = False,
    #     raise_errors: bool = True,
    #     _parent_namespace_depth: int = 2,
    #     _types_namespace: dict[str, Any] | None = None,
    # ) -> bool | None:
    #     return False


    def _reduce_tokens_below_limit(self, docs: List[Document]) -> List[Document]:
        num_docs = len(docs)
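For reference, the Callable annotation above replaces the invalid Optional[function] type. A minimal formatter matching that signature; the name and rendering are a hypothetical sketch, not necessarily the repo's actual _get_chat_history:

from typing import List, Tuple

def format_chat_history(chat_history: List[Tuple[str, str]]) -> str:
    # Render (speaker, message) pairs as plain text, one turn per line.
    return "\n".join(f"{speaker}: {message}" for speaker, message in chat_history)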
@@ -104,6 +121,17 @@ def from_llm(
            callbacks=[handler],
            verbose=verbose,
        )

        # llm_chain = RunnableSequence(
        #     {
        #         "sentence": RunnablePassthrough(),
        #         "language": RunnablePassthrough()
        #     }
        #     | _prompt
        #     | llm
        #     # | output_parser
        # )

        doc_chain = StuffDocumentsChain(
            llm_chain=llm_chain,
            document_variable_name=document_variable_name,
@@ -114,7 +142,7 @@ def from_llm(
        condense_question_chain = LLMChain(
            llm=_llm, prompt=condense_question_prompt, callbacks=[handler], verbose=verbose
        )

        return cls(
            retriever=retriever,
            combine_docs_chain=doc_chain,
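The commented-out RunnableSequence above gestures at LangChain's pipe-style (LCEL) composition. A minimal runnable sketch of that pattern under assumed names (the prompt, its variables, and FakeListLLM as a stand-in model are illustrative, not the repo's code); note it uses itemgetter in the input map, since a bare RunnablePassthrough inside a dict would forward the entire input to each key:

from operator import itemgetter

from langchain_community.llms import FakeListLLM
from langchain_core.prompts import ChatPromptTemplate

llm = FakeListLLM(responses=["Bonjour"])  # stand-in model so the sketch runs without API keys
prompt = ChatPromptTemplate.from_template("Translate this sentence into {language}: {sentence}")

# Pipe-style equivalent of the commented RunnableSequence: the dict map
# routes each input key to the matching prompt variable.
chain = (
    {"sentence": itemgetter("sentence"), "language": itemgetter("language")}
    | prompt
    | llm
)

print(chain.invoke({"sentence": "Hello", "language": "French"}))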
2 changes: 1 addition & 1 deletion a2rchi/chains/chain.py
@@ -1,7 +1,7 @@
from a2rchi.chains.base import BaseSubMITChain as BaseChain

from chromadb.config import Settings
from langchain.vectorstores import Chroma
from langchain_chroma import Chroma

import chromadb
import time
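The import swap tracks Chroma's move into the langchain-chroma integration package (pinned in pyproject.toml below). A minimal construction sketch through the new import; the persistence path, collection name, and embedding model are illustrative assumptions, not the repo's settings:

import chromadb
from chromadb.config import Settings
from langchain_chroma import Chroma
from langchain_community.embeddings import HuggingFaceEmbeddings

client = chromadb.PersistentClient(
    path="./chroma-data",  # hypothetical local persistence path
    settings=Settings(anonymized_telemetry=False),
)
vectorstore = Chroma(
    client=client,
    collection_name="a2rchi-demo",  # hypothetical collection name
    embedding_function=HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2"),
)
print(vectorstore.similarity_search("What does A2rchi do?", k=2))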
83 changes: 81 additions & 2 deletions a2rchi/chains/models.py
@@ -7,9 +7,15 @@
from langchain.callbacks.manager import CallbackManagerForLLMRun
from langchain.llms.base import LLM
from langchain.chat_models import ChatOpenAI
from langchain.chat_models import ChatAnthropic
from langchain_anthropic import ChatAnthropic
from langchain.llms import LlamaCpp


import requests
from typing import Optional, List



class BaseCustomLLM(LLM):
"""
Abstract class used to load a custom LLM
@@ -58,7 +64,6 @@ class AnthropicLLM(ChatAnthropic):
"""

model: str = "claude-3-opus-20240229"

temp: int = 1


@@ -234,3 +239,77 @@ def __call__(self, output_text):
        report += "|" + "|".join(f"{n:^10}" for n in scores.keys()) + "|\n"
        report += "|" + "|".join(f"{n:^10}" for n in scores.values()) + "|\n"
        return "Salesforce Content Safety Flan T5 Base", is_safe, report

class BaseCustomLLM(LLM):
    """
    Abstract class used to load a custom LLM
    """
    n_tokens: int = 100  # this has to be here for the parent LLM class

    @property
    def _llm_type(self) -> str:
        return "custom"

    @abstractmethod
    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
    ) -> str:
        pass


class ClaudeLLM(BaseCustomLLM):
    """
    An LLM class that uses Anthropic's Claude model.
    """
    # TODO: load the API key from the environment instead of hard-coding it before production
    api_key: str = "INSERT KEY HERE!!!"  # Claude API key
    base_url: str = "https://api.anthropic.com/v1/messages"  # Anthropic Messages API endpoint
    model_name: str = "claude-3-5-sonnet-20240620"  # model version to use

    verbose: bool = False

    def _call(
        self,
        prompt: Optional[str] = None,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        max_tokens: int = 1024,
    ) -> str:

        if stop is not None:
            print("WARNING: this model does not currently support stop tokens")

        if self.verbose:
            print(f"INFO: starting call to Claude with prompt: {prompt}")

        headers = {
            "x-api-key": self.api_key,  # the API key goes in the x-api-key header
            "anthropic-version": "2023-06-01",  # required version header
            "Content-Type": "application/json"
        }

        # Build the payload in the Messages API structure
        payload = {
            "model": self.model_name,
            "max_tokens": max_tokens,
            "messages": [  # a list of messages, each with a role and content
                {"role": "user", "content": prompt}  # the prompt becomes the user message content
            ]
        }

        if self.verbose:
            print("INFO: sending request to Claude API")

        # Send the request to the Claude API
        response = requests.post(self.base_url, headers=headers, json=payload)

        if response.status_code == 200:
            completion = response.json()["content"][0]["text"]
            if self.verbose:
                print(f"INFO: received response from Claude API: {completion}")
            return completion
        else:
            raise Exception(f"API request to Claude failed with status {response.status_code}, {response.text}")
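A minimal usage sketch for the new class, assuming the key is supplied via an ANTHROPIC_API_KEY environment variable rather than the hard-coded placeholder; invoke is inherited from LangChain's LLM base class:

import os

from a2rchi.chains.models import ClaudeLLM

llm = ClaudeLLM(api_key=os.environ["ANTHROPIC_API_KEY"], verbose=True)
print(llm.invoke("In one sentence, what does a vector store do?"))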
2 changes: 1 addition & 1 deletion a2rchi/utils/config_loader.py
@@ -23,7 +23,7 @@ def load_config(self):

        # change the model class parameter from a string to an actual class
        MODEL_MAPPING = {
            "AnthropicLLM": AnthropicLLM
            "AnthropicLLM": AnthropicLLM,
            "OpenAILLM": OpenAILLM,
            "DumbLLM": DumbLLM,
            "LlamaLLM": LlamaLLM
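MODEL_MAPPING lets the YAML name a model class as a string. A sketch of the lookup-and-instantiate step, with a hypothetical parsed-config snippet mirroring prod-root-config.yaml below; the repo's actual loader code may differ:

from a2rchi.chains.models import AnthropicLLM, DumbLLM, LlamaLLM, OpenAILLM

MODEL_MAPPING = {
    "AnthropicLLM": AnthropicLLM,
    "OpenAILLM": OpenAILLM,
    "DumbLLM": DumbLLM,
    "LlamaLLM": LlamaLLM,
}

model_config = {"class": "AnthropicLLM", "kwargs": {"model": "claude-3-opus-20240229", "temperature": 1}}

ModelClass = MODEL_MAPPING[model_config["class"]]  # string -> class
llm = ModelClass(**model_config["kwargs"])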
1 change: 0 additions & 1 deletion a2rchi/utils/scraper.py
@@ -1,4 +1,3 @@
from piazza_api import Piazza

import hashlib
import os
6 changes: 3 additions & 3 deletions config/prod-root-config.yaml
@@ -28,7 +28,7 @@ chains:
    - empty.list
    - miscellanea.list
    - root-docs.list
    - root-tutorial.list
    # - root-tutorial.list
    # - root-forum.list
  base:
    # roles that A2rchi knows about
@@ -53,12 +53,12 @@ chains:
    AnthropicLLM:
      class: AnthropicLLM
      kwargs:
        model_name: claude-3-opus-20240229
        model: claude-3-opus-20240229
        temperature: 1  # not sure 1 is the best value; keeping it consistent with prior settings for now
    OpenAILLM:
      class: OpenAILLM
      kwargs:
        model_name: gpt-4
        model_name: gpt-4o-mini
        temperature: 1
    DumbLLM:
      class: DumbLLM
62 changes: 35 additions & 27 deletions pyproject.toml
@@ -12,52 +12,60 @@ authors = [
{name="Matthew Russo", email="mdrusso@mit.edu"},
]
dependencies = [
"accelerate==0.23.0",
"accelerate",
"backoff==2.2.1",
"beautifulsoup4==4.12.2",
"beautifulsoup4==4.12.3",
"bitsandbytes==0.41.1",
"chromadb==0.4.12",
"chromadb==0.4.22",
"clickhouse-connect==0.6.6",
"coloredlogs==15.0.1",
"duckdb==0.8.1",
"fastapi==0.99.1",
"flask==2.3.3",
"flask-cors==4.0.0",
"fastapi==0.109.0",
"flask==3.0.3",
"flask-cors==5.0.0",
"flatbuffers==23.5.26",
"hnswlib==0.7.0",
"httptools==0.6.0",
"httptools==0.6.1",
"humanfriendly==10.0",
"langchain==0.0.268",
"langchain==0.3.0",
"loguru==0.7.2",
"lz4==4.3.2",
"mistune==3.0.1",
"mistune==3.0.2",
"monotonic==1.6",
"onnxruntime==1.15.1",
"openai==0.27.9",
"overrides==7.3.1",
"pandas==2.1.0",
"openai==1.8.0",
"overrides==7.4.0",
"peft==0.5.0",
"piazza-api==0.14.0",
"posthog==3.0.1",
"psycopg2==2.9.9",
"pulsar-client==3.2.0",
"pygments==2.16.1",
"posthog==3.3.1",
"pygments==2.18.0",
"pypdf==3.16.1",
"python-dotenv==1.0.0",
"python-redmine==2.4.0",
"regex==2023.6.3",
"requests==2.31.0",
"scipy==1.11.2",
"sentence-transformers==2.2.2",
"regex==2023.12.25",
"requests==2.32.3",
"sentence-transformers",
"sentencepiece==0.1.99",
"sympy==1.12",
"tiktoken==0.4.0",
"tokenizers==0.13.3",
"torch==2.0.1",
"transformers==4.33.1",
"uvloop==0.17.0",
"watchfiles==0.19.0",
"tiktoken==0.7.0",
"uvloop==0.19.0",
"watchfiles==0.21.0",
"zstandard==0.21.0",
"langchain-anthropic==0.2.1",
"langchain-chroma==0.1.4",
"langchain-community==0.3.0",
"langchain-core==0.3.5",
"langchain-text-splitters==0.3.0",
"pydantic==2.9.2",
"pydantic-core==2.23.4",
"pydantic-settings==2.5.2",
"chroma-hnswlib==0.7.3",
"anthropic==0.34.2",
"torch",
"scipy",
"pulsar-client",
"psycopg2-binary",
"requests",
"pyyaml"
]

[tool.setuptools]
1 change: 1 addition & 0 deletions test/__init__.py
@@ -0,0 +1 @@

2 changes: 2 additions & 0 deletions test/test_chains.py
@@ -73,3 +73,5 @@ def test_chain_call_prevhistory():
    result = c1([("User", question), ("A2rchi", answer), ("User", follow_up)])
    c1.kill = True
    assert result["answer"] is not None


37 changes: 37 additions & 0 deletions test_langchain.py
@@ -0,0 +1,37 @@
from a2rchi.utils.data_manager import DataManager
from a2rchi.chains.chain import Chain

# Places to fix:
# 1. Update Chroma in LangChain via 'pip install -U langchain-chroma',
#    then replace 'from langchain.vectorstores import Chroma' with 'from langchain_chroma import Chroma'.
# 2.

class DiscourseAIWrapper:
    def __init__(self):
        self.chain = Chain()
        self.data_manager = DataManager()
        self.data_manager.update_vectorstore()

    def __call__(self, post):
        # form the formatted history using the post
        formatted_history = [["User", post]]

        self.data_manager.update_vectorstore()

        answer = self.chain(formatted_history)
        return answer


question = 'Hello, I am trying to make a box over a TF1 in the Legend. I have something like this: where fstat is a TF1 legendcell = new TLegend(0.67,0.53,0.98,0.7+0.015*(years.size()));fstat->SetLineColor(429);fstat->SetFillColor(5);fstat->SetFillStyle(1001); egendcell->AddEntry(fstat, "#bf{" + legend + "}" ,"fl"); legendcell->Draw(); The problem is that this also creates a filling in the TGraph. Is there any way to create a box on the TLegend without changing the drawing on the canvas? PS: I also tried to do fstat->SetFillStyle(0); after the drawing of the legend but this also removes the box from the TLegend My root version is 6.28/00 I really appreciate any help you can provide.'


archi = DiscourseAIWrapper()
answer = archi(question)
print("\n\n\n")
print(answer)
print("\n\n\n")
