Skip to content

Commit

Permalink
Fix test CI (#8)
Browse files Browse the repository at this point in the history
* update ci

* fixes

* fixes tests

* update

* fixes

* fixes

* fixes
  • Loading branch information
aniketmaurya authored May 9, 2024
1 parent 0effb2c commit 0c0c02f
Show file tree
Hide file tree
Showing 8 changed files with 33 additions and 25 deletions.
31 changes: 17 additions & 14 deletions .github/workflows/main.yml
Original file line number Diff line number Diff line change
Expand Up @@ -12,43 +12,46 @@ jobs:
strategy:
matrix:
os: [ ubuntu-latest, macos-latest ]
python-version: [3.8, 3.9]
python-version: ["3.8", "3.9", "3.10"]
include:
- os: ubuntu-latest
path: ~/.cache/pip
- os: macos-latest
path: ~/Library/Caches/pip

timeout-minutes: 35
env:
OS: ${{ matrix.os }}
PYTHON: '3.9'

TORCH_URL: "https://download.pytorch.org/whl/cpu/torch_stable.html"

steps:
- uses: actions/checkout@v2
- uses: actions/checkout@v4
with:
fetch-depth: 0 # Shallow clones should be disabled for a better relevancy of analysis

- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v2
uses: actions/setup-python@v5
with:
python-version: ${{ matrix.python-version }}

- name: Get pip cache dir
id: pip-cache
run: echo "dir=$(pip cache dir)" >> $GITHUB_OUTPUT

- name: Cache pip
uses: actions/cache@v2
uses: actions/cache@v4
with:
path: ${{ matrix.path }}
key: ${{ runner.os }}-pip-${{ hashFiles('**/requirements.txt') }}
restore-keys: |
${{ runner.os }}-pip-
${{ runner.os }}-
path: ${{ steps.pip-cache.outputs.dir }}
key: ${{ runner.os }}-${{ matrix.python-version }}-${{ matrix.requires }}-pip-${{ hashFiles('requirements.txt') }}
restore-keys: ${{ runner.os }}-${{ matrix.python-version }}-${{ matrix.requires }}-pip-

- name: Install
- name: Install package & dependencies
run: |
python --version
pip --version
python -m pip install --upgrade pip
pip install coverage pytest
pip install .
pip install . -U -q --find-links $TORCH_URL
pip install -r tests/requirements.txt
pip list
shell: bash

Expand Down
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -6,3 +6,4 @@ pandas
loguru
wikipedia
googlesearch-python
langchain-cohere
2 changes: 1 addition & 1 deletion src/agents/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,6 @@

logger.remove(0)
LOGURU_LEVEL = env("LOGURU_LEVEL", str, "INFO")
logger.start(sys.stderr, level=LOGURU_LEVEL)
logger.add(sys.stderr, level=LOGURU_LEVEL)

__version__ = "0.0.1"
4 changes: 2 additions & 2 deletions src/agents/llms/_cohere.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from langchain_cohere import ChatCohere


from typing import List, Optional, Any
from typing import List, Optional, Any, Dict

from langchain_core.messages import AIMessage
from loguru import logger
Expand Down Expand Up @@ -59,5 +59,5 @@ def chat_completion(
logger.debug(output)
return self._format_cohere_to_openai(output)

def run_tool(self, chat_completion: ChatCompletion) -> list[dict[str, Any]]:
def run_tool(self, chat_completion: ChatCompletion) -> List[Dict[str, Any]]:
return self.tool_registry.call_tool(chat_completion)
4 changes: 2 additions & 2 deletions src/agents/llms/llm.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import List, Optional, Any
from typing import List, Optional, Any, Dict
from loguru import logger

from agents.specs import ChatCompletion
Expand Down Expand Up @@ -78,5 +78,5 @@ def chat_completion(
logger.debug(output)
return ChatCompletion(**output)

def run_tool(self, chat_completion: ChatCompletion) -> list[dict[str, Any]]:
def run_tool(self, chat_completion: ChatCompletion) -> List[Dict[str, Any]]:
return self.tool_registry.call_tool(chat_completion)
10 changes: 5 additions & 5 deletions src/agents/tool_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
"""

import json
from typing import Any, List
from typing import Any, List, Union, Dict

from langchain_community.tools import StructuredTool

Expand All @@ -42,8 +42,8 @@

class ToolRegistry:
def __init__(self):
self._tools: dict[str, StructuredTool] = {}
self._formatted_tools: dict[str, Any] = {}
self._tools: Dict[str, StructuredTool] = {}
self._formatted_tools: Dict[str, Any] = {}

def register_tool(self, tool: StructuredTool):
self._tools[tool.name] = tool
Expand All @@ -59,15 +59,15 @@ def pop(self, name: str) -> StructuredTool:
return self._tools.pop(name)

@property
def openai_tools(self) -> List[dict[str, Any]]:
def openai_tools(self) -> List[Dict[str, Any]]:
# [{"type": "function", "function": registry.openai_tools[0]}],
result = []
for oai_tool in self._formatted_tools.values():
result.append({"type": "function", "function": oai_tool})

return result

def call_tool(self, output: ChatCompletion | dict) -> list[dict[str, str]]:
def call_tool(self, output: Union[ChatCompletion, Dict]) -> List[Dict[str, str]]:
if isinstance(output, dict):
output = ChatCompletion(**output)

Expand Down
2 changes: 2 additions & 0 deletions tests/requirements.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
pytest
coverage
4 changes: 3 additions & 1 deletion tests/test_tools.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
import json

from agents.tools import get_current_weather


def test_get_current_weather():
current_weather = get_current_weather("San Francisco")
current_weather = json.loads(get_current_weather("San Francisco"))
assert isinstance(current_weather, dict)
assert "FeelsLikeC" in current_weather

0 comments on commit 0c0c02f

Please sign in to comment.