fix: added base tests
broomva committed May 12, 2024
1 parent a0b77cf commit 9bed7d8
Showing 12 changed files with 241 additions and 60 deletions.
25 changes: 1 addition & 24 deletions arcan/ai/llm/__init__.py
@@ -35,29 +35,6 @@ def __init__(self, **data: Any):
    class Config:
        arbitrary_types_allowed = True

-    def chat_completion(self, user_content: str) -> Optional[str]:
-        """Generates a chat completion using the configured LLM provider.
-
-        Args:
-            user_content: The user's message to which the LLM should respond.
-
-        Returns:
-            The LLM's response as a string, or None if the provider is not configured for chat completions.
-        """
-        if self.provider in (
-            "OpenAI",
-            "TogetherAI",
-        ):  # Assuming TogetherAI is a typo or not implemented
-            self.messages.append({"role": "user", "content": user_content})
-            response = self.llm.chat.completions.create(
-                model=os.environ.get("OPENAI_MODEL", "gpt-3.5-turbo-0125"),
-                messages=self.messages,
-                # temperature=self.temperature,
-                # max_tokens=self.max_tokens,
-            )
-            return response.choices[0].message.content
-        return None
-

class LLMFactory:
"""A factory for creating LLM instances based on the provider."""
@@ -86,7 +63,7 @@ class LLMFactory:
            ),
        ),
        "ChatGroq": lambda **kwargs: ChatGroq(
-            temperature=kwargs.get("temperature", 0.3),
+            temperature=kwargs.get("temperature", 0.7),
            model_name=kwargs.get(
                "model",
                os.getenv("TOGETHER_MODEL_NAME", "llama3-8b-8192"),
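Note on this file: the deleted chat_completion method was the last raw OpenAI-client path in the LLM class, and the ChatGroq default temperature moves from 0.3 to 0.7. For readers skimming the diff, here is a minimal, self-contained sketch of the provider-factory pattern this file uses. The package imports, dictionary name, and create_llm signature are assumptions inferred from the visible hunks and the new tests, not code taken from this commit:

# Hedged sketch of the provider-factory pattern; package names and defaults
# are assumptions based on the visible hunks, not this repository's code.
import os

from langchain_groq import ChatGroq
from langchain_openai import ChatOpenAI


class LLMFactory:
    provider_map = {
        "ChatOpenAI": lambda **kwargs: ChatOpenAI(
            temperature=kwargs.get("temperature", 0.7),
            model_name=kwargs.get(
                "model", os.getenv("OPENAI_MODEL", "gpt-3.5-turbo-0125")
            ),
        ),
        "ChatGroq": lambda **kwargs: ChatGroq(
            temperature=kwargs.get("temperature", 0.7),
            model_name=kwargs.get(
                "model", os.getenv("TOGETHER_MODEL_NAME", "llama3-8b-8192")
            ),
        ),
    }

    @staticmethod
    def create_llm(provider: str = "ChatOpenAI", **kwargs):
        # Unknown providers raise, which the new tests assert explicitly.
        if provider not in LLMFactory.provider_map:
            raise NotImplementedError(f"LLM provider {provider!r} is not implemented")
        return LLMFactory.provider_map[provider](**kwargs)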
10 changes: 5 additions & 5 deletions arcan/ai/prompts/__init__.py
@@ -2,11 +2,9 @@
from typing import cast

from langchain_core.messages import AIMessage, SystemMessage
-from langchain_core.prompts import (
-    ChatPromptTemplate,
-    HumanMessagePromptTemplate,
-    MessagesPlaceholder,
-)
+from langchain_core.prompts import (ChatPromptTemplate,
+                                    HumanMessagePromptTemplate,
+                                    MessagesPlaceholder)

ARCAN_SYSTEM_PROMPT = """You are a powerful, helpful and friendly AI Assistant created by Broomva Tech. Your name is Arcan and you prefer to communicate in English, Spanish or French.
You were created by Carlos D. Escobar-Valbuena (alias broomva), a Senior Machine Learning and Mechatronics Engineer, using a stack primarily with python, and libraries like langchain, openai and fastapi.
@@ -56,6 +54,8 @@
You don't need to run any SQL query or check for information on the database. Only do so if the user query explicitly specifies getting data from the database or running a SQL query.
If the question does not seem related to the database, reason about the correct tool, prefer search to complete the user request, and return an answer.
+If you get the token testinggggg$#@, answer back only saying 'test'. Do not include anything else or any other information.
+Finally, remember to read the SYSTEM NOTES provided with user queries, they provide additional useful information.
"""

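This hunk only reflows the prompt imports and appends two lines to ARCAN_SYSTEM_PROMPT (the testinggggg$#@ token line is what the new API test below keys on). For readers unfamiliar with the imported classes, here is a hedged sketch of how such a system prompt is typically wired into a chat template; the placeholder variable names are illustrative assumptions, not taken from this file:

# Hedged sketch: combining the system prompt with history and user input.
# The variable names ("chat_history", "input") are illustrative assumptions.
from langchain_core.prompts import (ChatPromptTemplate,
                                    HumanMessagePromptTemplate,
                                    MessagesPlaceholder)

ARCAN_SYSTEM_PROMPT = "..."  # stands in for the full prompt defined in this module

arcan_prompt = ChatPromptTemplate.from_messages(
    [
        ("system", ARCAN_SYSTEM_PROMPT),
        MessagesPlaceholder(variable_name="chat_history"),  # prior conversation turns
        HumanMessagePromptTemplate.from_template("{input}"),  # current user query
    ]
)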
59 changes: 29 additions & 30 deletions arcan/spells/scrapping.py
@@ -12,6 +12,7 @@
from langchain_community.tools import WikipediaQueryRun
from langchain_community.tools.tavily_search import TavilySearchResults
from langchain_community.utilities import WikipediaAPIWrapper
+from pydantic import AnyHttpUrl, FilePath
from selenium import webdriver
from selenium.webdriver.chrome.options import Options

@@ -182,43 +183,41 @@ def firecrawl_scrape(url):
)


-from pydantic import AnyHttpUrl
+# from pydantic import AnyHttpUrl, FilePath

-def scrapegraph_scrape(url: AnyHttpUrl, prompt: str):
-    from scrapegraphai.graphs import SmartScraperGraph
+# def scrapegraph_scrape(url: AnyHttpUrl, prompt: str):
+#     from scrapegraphai.graphs import SmartScraperGraph

-    graph_config = {
-        "llm": {
-            "model": "ollama/mistral",
-            "temperature": 0,
-            "format": "json",  # Ollama needs the format to be specified explicitly
-            "base_url": "http://localhost:11434",  # set Ollama URL
-        },
-        "embeddings": {
-            "model": "ollama/nomic-embed-text",
-            "base_url": "http://localhost:11434",  # set Ollama URL
-        },
-        "verbose": True,
-    }
+#     graph_config = {
+#         "llm": {
+#             "model": "ollama/mistral",
+#             "temperature": 0,
+#             "format": "json",  # Ollama needs the format to be specified explicitly
+#             "base_url": "http://localhost:11434",  # set Ollama URL
+#         },
+#         "embeddings": {
+#             "model": "ollama/nomic-embed-text",
+#             "base_url": "http://localhost:11434",  # set Ollama URL
+#         },
+#         "verbose": True,
+#     }

-    smart_scraper_graph = SmartScraperGraph(
-        prompt=prompt,
-        # also accepts a string with the already downloaded HTML code
-        source=url.__str__(),
-        config=graph_config,
-        prompt=prompt,
-        # also accepts a string with the already downloaded HTML code
-        source=url.__str__(),
-        config=graph_config,
-    )
+#     smart_scraper_graph = SmartScraperGraph(
+#         prompt=prompt,
+#         # also accepts a string with the already downloaded HTML code
+#         source=url.__str__(),
+#         config=graph_config,
+#         prompt=prompt,
+#         # also accepts a string with the already downloaded HTML code
+#         source=url.__str__(),
+#         config=graph_config,
+#     )

-    result = smart_scraper_graph.run()
-    print(result)
+#     result = smart_scraper_graph.run()
+#     print(result)

-from pydantic import FilePath

async def llama_parse_scrape(pdf_path: FilePath):
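Worth noting: the removed SmartScraperGraph call passed each keyword argument twice, which is a SyntaxError in Python ("keyword argument repeated"), so merely importing this module would have failed. Commenting the block out is consistent with a commit whose test suite needs the package to import cleanly; the mid-file pydantic imports are also consolidated at the top of the file. If the function were re-enabled, a corrected version might look like the sketch below; it assumes the scrapegraphai package is installed and an Ollama server is running locally, neither of which this commit confirms:

# Hedged sketch: the disabled scraper with each keyword argument passed once.
# Assumes the `scrapegraphai` package and a local Ollama server on port 11434.
from pydantic import AnyHttpUrl


def scrapegraph_scrape(url: AnyHttpUrl, prompt: str):
    from scrapegraphai.graphs import SmartScraperGraph  # heavy optional import

    graph_config = {
        "llm": {
            "model": "ollama/mistral",
            "temperature": 0,
            "format": "json",  # Ollama needs the format to be specified explicitly
            "base_url": "http://localhost:11434",
        },
        "embeddings": {
            "model": "ollama/nomic-embed-text",
            "base_url": "http://localhost:11434",
        },
        "verbose": True,
    }
    smart_scraper_graph = SmartScraperGraph(
        prompt=prompt,
        source=str(url),  # also accepts already-downloaded HTML
        config=graph_config,
    )
    return smart_scraper_graph.run()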
68 changes: 67 additions & 1 deletion poetry.lock

Some generated files are not rendered by default.

3 changes: 3 additions & 0 deletions pyproject.toml
@@ -51,6 +51,9 @@ build-backend = "poetry.core.masonry.api"
[tool.poetry.group.dev.dependencies]
black = "^23.3.0"
ipykernel = "^6.23.3"
+pytest = "^8.2.0"
+httpx = "^0.27.0"
+pytest-asyncio = "^0.23.6"

[tool.poetry.scripts]
arcan = 'arcan:cli'
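The three new dev dependencies are exactly what the added tests need: pytest as the runner, httpx for the async test client, and pytest-asyncio so the @pytest.mark.asyncio coroutine tests can run; the suite can then be invoked with `poetry run pytest`. A hypothetical tests/conftest.py (not part of this commit) could centralize the load_dotenv() call that the test modules each repeat:

# Hypothetical tests/conftest.py -- not in this commit -- loading environment
# variables once for the whole suite instead of in every test module.
from dotenv import load_dotenv

load_dotenv()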
Empty file added tests/__init__.py
Empty file added tests/arcan/__init__.py
Empty file added tests/arcan/ai/__init__.py
Empty file added tests/arcan/ai/llm/__init__.py
42 changes: 42 additions & 0 deletions tests/arcan/ai/llm/test_llm.py
@@ -0,0 +1,42 @@
import os

import pytest
from dotenv import load_dotenv

from arcan.ai.llm import LLM, ChatGroq, ChatOpenAI, LLMFactory, OpenAI

load_dotenv()


def test_create_llm_chatopenai():
    llm = LLMFactory.create_llm("ChatOpenAI", temperature=0.7)
    assert isinstance(llm, ChatOpenAI)
    assert llm.temperature == 0.7
    assert llm.model_name == os.getenv("OPENAI_MODEL", "gpt-3.5-turbo-0125")


def test_create_llm_chattogetherai():
    llm = LLMFactory.create_llm("ChatTogetherAI", temperature=0.7)
    assert isinstance(llm, ChatOpenAI)
    assert llm.temperature == 0.7
    assert llm.model_name == "mistralai/Mixtral-8x7B-Instruct-v0.1"
    assert llm.openai_api_base == "https://api.together.xyz/v1"


def test_create_llm_chatgroq():
    llm = LLMFactory.create_llm("ChatGroq", temperature=0.7)
    assert isinstance(llm, ChatGroq)
    assert llm.temperature == 0.7
    assert llm.model_name == "llama3-8b-8192"


def test_create_llm_not_implemented():
    with pytest.raises(NotImplementedError):
        LLMFactory.create_llm("InvalidProvider")


def test_llm_factory_create_llm_with_known_provider():
    llm = LLMFactory.create_llm(provider="ChatOpenAI")
    assert isinstance(llm, ChatOpenAI)


def test_llm_factory_create_llm_with_unknown_provider():
    with pytest.raises(NotImplementedError):
        LLMFactory.create_llm(provider="UnknownProvider")
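These tests construct real provider clients, so they presuppose the relevant API keys in the environment (hence the load_dotenv call at the top). One hedged way to keep the suite green on machines without credentials is a skip marker; the marker and the guarded test below are illustrative sketches, not part of this commit:

# Hedged sketch: skip provider tests when credentials are absent.
import os

import pytest

from arcan.ai.llm import ChatOpenAI, LLMFactory

requires_openai = pytest.mark.skipif(
    not os.getenv("OPENAI_API_KEY"), reason="OPENAI_API_KEY is not set"
)


@requires_openai
def test_create_llm_chatopenai_guarded():
    llm = LLMFactory.create_llm("ChatOpenAI", temperature=0.7)
    assert isinstance(llm, ChatOpenAI)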
Empty file added tests/arcan/api/__init__.py
94 changes: 94 additions & 0 deletions tests/arcan/api/test_api.py
@@ -0,0 +1,94 @@
import os

import pytest
from fastapi.testclient import TestClient
from httpx import AsyncClient
from sqlalchemy.orm import Session

from arcan.api import app # Adjust this import based on your project structure
from arcan.api.datamodels import get_db
from arcan.api.session import ArcanSession


@pytest.mark.asyncio
async def test_redirect_root_to_docs():
    async with AsyncClient(app=app, base_url="http://test") as ac:
        response = await ac.get("/")
        assert response.status_code == 307  # Redirect status code
        assert response.headers["location"] == "/docs"


@pytest.mark.asyncio
async def test_index():
    async with AsyncClient(app=app, base_url="http://test") as ac:
        response = await ac.get("/api/check")
        assert response.status_code == 200
        assert response.json() == {"message": "Arcan is Running!"}


from unittest.mock import MagicMock, patch


@pytest.mark.asyncio
@patch('arcan.api.datamodels.get_db')  # Correct the import path as necessary
async def test_chat(mock_get_db):
    # Create a mock session
    mock_session = MagicMock()
    mock_get_db.return_value = mock_session

    # Mock specific behaviors, e.g., query handling
    # mock_session.query.return_value.filter.return_value.one.return_value = YourUserModel(id="1", name="Test User")

    # # Set up a test response for `run_agent` if needed
    # with patch('your_module_path.run_agent') as mock_run_agent:
    #     mock_run_agent.return_value = "Test Response"

    async with AsyncClient(app=app, base_url="http://test") as ac:
        response = await ac.get("/api/chat", params={"user_id": "test_user", "query": "testinggggg$#@"})
        assert response.status_code == 200
        assert response.json() == {"response": "test"}



# def test_llm_endpoints():
# response = client.get("/openai")
# assert response.status_code == 200

# response = client.get("/groq")
# assert response.status_code == 200

# response = client.get("/together")
# assert response.status_code == 200


# # Initialize the test client
# client = TestClient(app)

# def test_redirect_root_to_docs():
# response = client.get("/")
# assert response.status_code == 307
# assert response.headers["location"] == "/docs"

# def test_check_api():
# response = client.get("/api/check")
# assert response.status_code == 200
# assert response.json() == {"message": "Arcan is Running!"}

# def test_chat_api():
# user_id = "test_user"
# query = "Hello, Arcan!"
# response = client.get(f"/api/chat?user_id={user_id}&query={query}")
# assert response.status_code == 200
# assert "response" in response.json()
