diff --git a/a2rchi/chains/base.py b/a2rchi/chains/base.py
index f5b67d9..768e5b7 100644
--- a/a2rchi/chains/base.py
+++ b/a2rchi/chains/base.py
@@ -1,5 +1,6 @@
 """Chain for chatting with a vector database."""
 from __future__ import annotations
+from pydantic import BaseModel
 
 from loguru import logger
 from langchain.callbacks import FileCallbackHandler
@@ -12,7 +13,10 @@ from langchain.chains.llm import LLMChain
 from langchain.schema import BaseRetriever, Document
 from langchain.schema.prompt_template import BasePromptTemplate
+from langchain_core.runnables import RunnableSequence, RunnablePassthrough
 from typing import Any, Dict, List, Optional, Tuple
+from typing import Callable
+# from pydantic import model_rebuild
 
 import os
 
@@ -46,7 +50,20 @@ class BaseSubMITChain(BaseConversationalRetrievalChain):
     """
     retriever: BaseRetriever                # Index to connect to
     max_tokens_limit: Optional[int] = None  # restrict doc length to return from store, enforced only for StuffDocumentChain
-    get_chat_history: Optional[function] = _get_chat_history
+    get_chat_history: Optional[Callable[[List[Tuple[str, str]]], str]] = _get_chat_history
+
+    # kuangfei: decide whether to keep the default pydantic rebuild logic or override it as sketched below
+    # @classmethod
+    # def model_rebuild(
+    #     cls,
+    #     *,
+    #     force: bool = False,
+    #     raise_errors: bool = True,
+    #     _parent_namespace_depth: int = 2,
+    #     _types_namespace: dict[str, Any] | None = None,
+    # ) -> bool | None:
+    #     return False
+
 
     def _reduce_tokens_below_limit(self, docs: List[Document]) -> List[Document]:
         num_docs = len(docs)
@@ -104,6 +121,17 @@ def from_llm(
             callbacks = [handler],
             verbose=verbose,
         )
+
+        # llm_chain = RunnableSequence(
+        #     {
+        #         "sentence": RunnablePassthrough(),
+        #         "language": RunnablePassthrough()
+        #     }
+        #     | _prompt
+        #     | llm
+        #     # | output_parser
+        # )
+
         doc_chain = StuffDocumentsChain(
             llm_chain=llm_chain,
             document_variable_name=document_variable_name,
@@ -114,7 +142,7 @@
         condense_question_chain = LLMChain(
             llm=_llm, prompt=condense_question_prompt, callbacks = [handler], verbose=verbose
         )
-
+
         return cls(
             retriever=retriever,
             combine_docs_chain=doc_chain,
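For reference, the new Optional[Callable[[List[Tuple[str, str]]], str]] annotation on get_chat_history expects a formatter that collapses A2rchi's (role, message) tuples into a single string. A minimal sketch of such a callable, assuming a plain "role: message" transcript format; this is illustrative and not the exact _get_chat_history implementation:

# Illustrative formatter matching the new get_chat_history annotation;
# not the _get_chat_history helper the chain actually uses.
from typing import List, Tuple

def format_chat_history(chat_history: List[Tuple[str, str]]) -> str:
    # Each tuple is (role, message); the result is one transcript string.
    return "\n".join(f"{role}: {message}" for role, message in chat_history)

print(format_chat_history([("User", "What is A2rchi?"), ("A2rchi", "A retrieval-augmented assistant.")]))
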
diff --git a/a2rchi/chains/chain.py b/a2rchi/chains/chain.py
index 8ab6ee9..6103de8 100644
--- a/a2rchi/chains/chain.py
+++ b/a2rchi/chains/chain.py
@@ -1,7 +1,7 @@
 from a2rchi.chains.base import BaseSubMITChain as BaseChain
 from chromadb.config import Settings
-from langchain.vectorstores import Chroma
+from langchain_chroma import Chroma
 
 import chromadb
 import time
 
diff --git a/a2rchi/chains/models.py b/a2rchi/chains/models.py
index 91fcd3d..e3773dd 100644
--- a/a2rchi/chains/models.py
+++ b/a2rchi/chains/models.py
@@ -7,9 +7,15 @@
 from langchain.callbacks.manager import CallbackManagerForLLMRun
 from langchain.llms.base import LLM
 from langchain.chat_models import ChatOpenAI
-from langchain.chat_models import ChatAnthropic
+from langchain_anthropic import ChatAnthropic
 from langchain.llms import LlamaCpp
+
+import requests
+from typing import Optional, List
+
+
+
 class BaseCustomLLM(LLM):
     """
     Abstract class used to load a custom LLM
     """
@@ -58,7 +64,6 @@ class AnthropicLLM(ChatAnthropic):
     """
 
     model: str = "claude-3-opus-20240229"
-    temp: int = 1
 
 
 
@@ -234,3 +239,77 @@ def __call__(self, output_text):
         report += "|" + "|".join(f"{n:^10}" for n in scores.keys()) + "|\n"
         report += "|" + "|".join(f"{n:^10}" for n in scores.values()) + "|\n"
         return "Salesforce Content Safety Flan T5 Base", is_safe, report
+
+class BaseCustomLLM(LLM):
+    """
+    Abstract class used to load a custom LLM
+    """
+    n_tokens: int = 100  # this has to be here for parent LLM class
+
+    @property
+    def _llm_type(self) -> str:
+        return "custom"
+
+    @abstractmethod
+    def _call(
+        self,
+        prompt: str,
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[CallbackManagerForLLMRun] = None,
+    ) -> str:
+        pass
+
+
+class ClaudeLLM(BaseCustomLLM):
+    """
+    An LLM class that uses Anthropic's Claude model.
+    """
+    # TODO: obscure api key in final production version
+    api_key: str = "INSERT KEY HERE!!!"                       # Claude API key
+    base_url: str = "https://api.anthropic.com/v1/messages"   # Anthropic API endpoint
+    model_name: str = "claude-3-5-sonnet-20240620"            # Specify the model version to use
+
+    verbose: bool = False
+
+    def _call(
+        self,
+        prompt: str = None,
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[CallbackManagerForLLMRun] = None,
+        max_tokens: int = 1024,
+    ) -> str:
+
+        if stop is not None:
+            print("WARNING : currently this model does not support stop tokens")
+
+        if self.verbose:
+            print(f"INFO : Starting call to Claude with prompt: {prompt}")
+
+        headers = {
+            "x-api-key": self.api_key,          # Use the API key for the x-api-key header
+            "anthropic-version": "2023-06-01",  # Add the required version header
+            "Content-Type": "application/json"
+        }
+
+        # Modify the payload to match the required structure
+        payload = {
+            "model": self.model_name,                 # You can keep this dynamic based on your code
+            "max_tokens": max_tokens,                 # Update to match the required max_tokens
+            "messages": [                             # Use a list of messages where each message has a role and content
+                {"role": "user", "content": prompt}   # Prompt becomes part of the message content
+            ]
+        }
+
+        if self.verbose:
+            print("INFO: Sending request to Claude API")
+
+        # Send request to Claude API
+        response = requests.post(self.base_url, headers=headers, json=payload)
+
+        if response.status_code == 200:
+            completion = response.json()["content"][0]["text"]
+            if self.verbose:
+                print(f"INFO : received response from Claude API: {completion}")
+            return completion
+        else:
+            raise Exception(f"API request to Claude failed with status {response.status_code}, {response.text}")
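The new ClaudeLLM wrapper can be smoke-tested on its own. A minimal usage sketch, assuming the key is supplied via an ANTHROPIC_API_KEY environment variable rather than the hard-coded placeholder above:

# Hedged usage sketch for the ClaudeLLM class added above; the environment
# variable name is an assumption, since the diff itself hard-codes the key field.
import os
from a2rchi.chains.models import ClaudeLLM

llm = ClaudeLLM(
    api_key=os.environ["ANTHROPIC_API_KEY"],
    model_name="claude-3-5-sonnet-20240620",
    verbose=True,
)
# invoke() goes through the standard LangChain LLM interface and lands in _call().
print(llm.invoke("Summarize what A2rchi does in one sentence."))
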
diff --git a/a2rchi/utils/config_loader.py b/a2rchi/utils/config_loader.py
index 80933b3..b363a5b 100644
--- a/a2rchi/utils/config_loader.py
+++ b/a2rchi/utils/config_loader.py
@@ -23,7 +23,7 @@ def load_config(self):
 
         # change the model class parameter from a string to an actual class
         MODEL_MAPPING = {
-            "AnthropicLLM": AnthropicLLM
+            "AnthropicLLM": AnthropicLLM,
             "OpenAILLM": OpenAILLM,
             "DumbLLM": DumbLLM,
             "LlamaLLM": LlamaLLM
diff --git a/a2rchi/utils/scraper.py b/a2rchi/utils/scraper.py
index a608f3c..1e1a87d 100644
--- a/a2rchi/utils/scraper.py
+++ b/a2rchi/utils/scraper.py
@@ -1,4 +1,3 @@
-from piazza_api import Piazza
 import hashlib
 import os
 
diff --git a/config/prod-root-config.yaml b/config/prod-root-config.yaml
index 926b408..a30f173 100644
--- a/config/prod-root-config.yaml
+++ b/config/prod-root-config.yaml
@@ -28,7 +28,7 @@ chains:
       - empty.list
       - miscellanea.list
       - root-docs.list
-      - root-tutorial.list
+      # - root-tutorial.list
       # - root-forum.list
   base:
     # roles that A2rchi knows about
@@ -53,12 +53,12 @@ chains:
     AnthropicLLM:
       class: AnthropicLLM
      kwargs:
-        model_name: claude-3-opus-20240229
+        model: claude-3-opus-20240229
        temperature: 1 # not sure if 1 is best value? for now keeping consistent with prior settings.
     OpenAILLM:
      class: OpenAILLM
      kwargs:
-        model_name: gpt-4
+        model_name: gpt-4o-mini
        temperature: 1
     DumbLLM:
      class: DumbLLM
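The trailing comma added to MODEL_MAPPING and the model kwarg rename above work together when the YAML is loaded. A rough sketch of that resolution step, mirroring (but not copied from) the config_loader logic; it assumes ANTHROPIC_API_KEY is set in the environment, since ChatAnthropic validates credentials at construction:

# Illustrative resolution of a YAML model block into an instantiated class.
from a2rchi.chains.models import AnthropicLLM, OpenAILLM, DumbLLM, LlamaLLM

MODEL_MAPPING = {
    "AnthropicLLM": AnthropicLLM,
    "OpenAILLM": OpenAILLM,
    "DumbLLM": DumbLLM,
    "LlamaLLM": LlamaLLM,
}

# Roughly what the AnthropicLLM block in prod-root-config.yaml deserializes to.
model_config = {"class": "AnthropicLLM", "kwargs": {"model": "claude-3-opus-20240229", "temperature": 1}}

model_class = MODEL_MAPPING[model_config["class"]]
llm = model_class(**model_config["kwargs"])  # needs ANTHROPIC_API_KEY in the environment
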
diff --git a/pyproject.toml b/pyproject.toml
index fe8c9d9..400a7c2 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -12,52 +12,60 @@ authors = [
     {name="Matthew Russo", email="mdrusso@mit.edu"},
 ]
 dependencies = [
-    "accelerate==0.23.0",
+    "accelerate",
     "backoff==2.2.1",
-    "beautifulsoup4==4.12.2",
+    "beautifulsoup4==4.12.3",
     "bitsandbytes==0.41.1",
-    "chromadb==0.4.12",
+    "chromadb==0.4.22",
     "clickhouse-connect==0.6.6",
     "coloredlogs==15.0.1",
     "duckdb==0.8.1",
-    "fastapi==0.99.1",
-    "flask==2.3.3",
-    "flask-cors==4.0.0",
+    "fastapi==0.109.0",
+    "flask==3.0.3",
+    "flask-cors==5.0.0",
     "flatbuffers==23.5.26",
     "hnswlib==0.7.0",
-    "httptools==0.6.0",
+    "httptools==0.6.1",
     "humanfriendly==10.0",
-    "langchain==0.0.268",
+    "langchain==0.3.0",
     "loguru==0.7.2",
     "lz4==4.3.2",
-    "mistune==3.0.1",
+    "mistune==3.0.2",
     "monotonic==1.6",
-    "onnxruntime==1.15.1",
-    "openai==0.27.9",
-    "overrides==7.3.1",
-    "pandas==2.1.0",
+    "openai==1.8.0",
+    "overrides==7.4.0",
     "peft==0.5.0",
     "piazza-api==0.14.0",
-    "posthog==3.0.1",
-    "psycopg2==2.9.9",
-    "pulsar-client==3.2.0",
-    "pygments==2.16.1",
+    "posthog==3.3.1",
+    "pygments==2.18.0",
     "pypdf==3.16.1",
     "python-dotenv==1.0.0",
     "python-redmine==2.4.0",
-    "regex==2023.6.3",
-    "requests==2.31.0",
-    "scipy==1.11.2",
-    "sentence-transformers==2.2.2",
+    "regex==2023.12.25",
+    "requests==2.32.3",
+    "sentence-transformers",
     "sentencepiece==0.1.99",
     "sympy==1.12",
-    "tiktoken==0.4.0",
-    "tokenizers==0.13.3",
-    "torch==2.0.1",
-    "transformers==4.33.1",
-    "uvloop==0.17.0",
-    "watchfiles==0.19.0",
+    "tiktoken==0.7.0",
+    "uvloop==0.19.0",
+    "watchfiles==0.21.0",
     "zstandard==0.21.0",
+    "langchain-anthropic==0.2.1",
+    "langchain-chroma==0.1.4",
+    "langchain-community==0.3.0",
+    "langchain-core==0.3.5",
+    "langchain-text-splitters==0.3.0",
+    "pydantic==2.9.2",
+    "pydantic-core==2.23.4",
+    "pydantic-settings==2.5.2",
+    "chroma-hnswlib==0.7.3",
+    "anthropic==0.34.2",
+    "torch",
+    "scipy",
+    "pulsar-client",
+    "psycopg2-binary",
+    "requests",
+    "pyyaml"
 ]
 
 [tool.setuptools]
diff --git a/test/__init__.py b/test/__init__.py
index e69de29..8b13789 100644
--- a/test/__init__.py
+++ b/test/__init__.py
@@ -0,0 +1 @@
+
diff --git a/test/test_chains.py b/test/test_chains.py
index ec7fd6b..cae0257 100644
--- a/test/test_chains.py
+++ b/test/test_chains.py
@@ -73,3 +73,5 @@ def test_chain_call_prevhistory():
     result = c1([("User", question), ("A2rchi", answer), ("User", follow_up)])
     c1.kill = True
     assert result["answer"] is not None
+
+
diff --git a/test_langchain.py b/test_langchain.py
new file mode 100644
index 0000000..b44d951
--- /dev/null
+++ b/test_langchain.py
@@ -0,0 +1,37 @@
+from a2rchi.utils.data_manager import DataManager
+from a2rchi.chains.chain import Chain
+
+# Places to fix:
+# 1. Update Chroma in LangChain via 'pip install -U langchain-chroma'
+#    and replace 'from langchain.vectorstores import Chroma' with 'from langchain_chroma import Chroma'.
+# 2.
+
+class DiscourseAIWrapper:
+    def __init__(self):
+        self.chain = Chain()
+        self.data_manager = DataManager()
+        self.data_manager.update_vectorstore()
+
+    def __call__(self, post):
+
+        formatted_history = []
+
+        formatted_history = [['User', post]]
+
+        # form the formatted history directly from the incoming post
+
+        self.data_manager.update_vectorstore()
+
+        answer = self.chain(formatted_history)
+        return answer
+
+
+# question = 'Hello, I am trying to make a box over a TF1 in the Legend. I have something like this: where fstat is a TF1 legendcell = new TLegend(0.67,0.53,0.98,0.7+0.015*(years.size()));fstat->SetLineColor(429);fstat->SetFillColor(5);fstat->SetFillStyle(1001); egendcell->AddEntry(fstat, "#bf{" + legend + "}" ,"fl"); legendcell->Draw(); The problem is that this also creates a filling in the TGraph. Is there any way to create a box on the TLegend without changing the drawing on the canvas? PS: I also tried to do fstat->SetFillStyle(0); after the drawing of the legend but this also removes the box from the TLegend My root version is 6.28/00 I really appreciate any help you can provide.'
+question = 'Hello, I am trying to make a box over a TF1 in the Legend. I have something like this: where fstat is a TF1 legendcell = new TLegend(0.67,0.53,0.98,0.7+0.015*(years.size()));fstat->SetLineColor(429);fstat->SetFillColor(5);fstat->SetFillStyle(1001); egendcell->AddEntry(fstat, "#bf{" + legend + "}" ,"fl"); legendcell->Draw(); The problem is that this also creates a filling in the TGraph. Is there any way to create a box on the TLegend without changing the drawing on the canvas? PS: I also tried to do fstat->SetFillStyle(0); after the drawing of the legend but this also removes the box from the TLegend My root version is 6.28/00 I really appreciate any help you can provide.'
+
+
+archi = DiscourseAIWrapper()
+answer = archi(question)
+print("\n\n\n")
+print(answer)
+print("\n\n\n")
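Finally, the ad-hoc driver in test_langchain.py could be folded into the existing test suite. A sketch of that, assuming the module-level calls at the bottom of test_langchain.py are moved under an if __name__ == "__main__" guard so that importing the module does not trigger a full run, and that the vectorstore and API keys are configured as in test_chains.py:

# Hedged pytest-style check for the new wrapper; the guard assumption above is
# required, otherwise importing test_langchain executes the whole script.
from test_langchain import DiscourseAIWrapper

def test_discourse_wrapper_returns_answer():
    wrapper = DiscourseAIWrapper()
    result = wrapper("How can I draw a filled legend box for a TF1 without filling the TGraph?")
    assert result["answer"] is not None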