Skip to content
This repository has been archived by the owner on Apr 24, 2024. It is now read-only.

Commit

Permalink
feat: add docstring and type hint
Browse files Browse the repository at this point in the history
  • Loading branch information
dsdanielpark committed Jan 18, 2024
1 parent 6f17cc3 commit 721c1da
Showing 1 changed file with 111 additions and 16 deletions.
127 changes: 111 additions & 16 deletions bardapi/core_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,9 @@
import string
import random
import base64
from typing import Optional
import logging
from re import search
from typing import Optional, Tuple, Union
from httpx import AsyncClient

try:
Expand Down Expand Up @@ -36,6 +36,18 @@
)

class BardAsync:
"""
BardAsync is a class for interacting with the Bard API asynchronously.
Attributes:
token (str): Bard API token.
client (AsyncClient): Asynchronous client for making HTTP requests.
conversation_id (str): ID to maintain the context of the conversation.
google_translator_api_key (str): API key for Google Cloud Translation.
language (str): The language in which the input text is written.
run_code (bool): A flag to determine whether to execute code snippets.
token_from_browser (bool): A flag to determine whether to get the token from the browser.
"""
def __init__(
self,
token: Optional[str] = None,
Expand All @@ -48,6 +60,23 @@ def __init__(
run_code: bool = False,
token_from_browser: bool = False,
):
"""
Initialize the BardAsync class.
Args:
token (Optional[str]): Bard API token.
timeout (int): Timeout for HTTP requests.
proxies (Optional[dict]): Proxies for HTTP requests.
client (Optional[AsyncClient]): Asynchronous client for making HTTP requests.
conversation_id (Optional[str]): ID to maintain the context of the conversation.
google_translator_api_key (Optional[str]): API key for Google Cloud Translation.
language (Optional[str]): The language in which the input text is written.
run_code (bool): A flag to determine whether to execute code snippets.
token_from_browser (bool): A flag to determine whether to get the token from the browser.
Raises:
Exception: If the token is not provided and cannot be extracted from the environment variable.
"""
self.token = self._get_token(token, token_from_browser)
if not self.token:
raise Exception("Token must be provided either directly or through _BARD_API_KEY environment variable.")
Expand All @@ -71,12 +100,21 @@ def __init__(
self.google_translator_api_key = google_translator_api_key
self.SNlM0e = None

async def async_setup(self):
async def async_setup(self) -> None:
"""
Set up the BardAsync instance asynchronously.
"""
self.SNlM0e = await self._get_snim0e()
if not self.client:
self.client = await self._get_client() # Ensure this is awaited

async def _get_snim0e(self):
async def _get_snim0e(self) -> Optional[str]:
"""
Get the SNlM0e value from the Bard website.
Returns:
Optional[str]: The SNlM0e value if found, otherwise None.
"""
if isinstance(self.SNlM0e, str):
return self.SNlM0e

Expand All @@ -102,6 +140,12 @@ async def _get_snim0e(self):
return self.SNlM0e

async def _initialize_client(self) -> AsyncClient:
"""
Initialize the AsyncClient instance.
Returns:
AsyncClient: The initialized AsyncClient instance.
"""
return AsyncClient(
http2=True,
headers=SESSION_HEADERS,
Expand All @@ -113,12 +157,14 @@ async def _initialize_client(self) -> AsyncClient:
def _get_token(self, token: str, token_from_browser: bool) -> str:
"""
Get the Bard API token either from the provided token or from the browser cookie.
Args:
token (str): Bard API token.
token_from_browser (bool): Whether to extract the token from the browser cookie.
Returns:
dict: The Bard API tokens.
str: The Bard API token.
Raises:
Exception: If the token is not provided and can't be extracted from the browser.
"""
Expand All @@ -138,13 +184,13 @@ def _get_token(self, token: str, token_from_browser: bool) -> str:

async def _get_client(self, session: Optional[AsyncClient]) -> AsyncClient:
"""
The _get_snim0e function is used to get the SNlM0e value from the Bard website.
The function uses a regular expression to search for the SNlM0e value in the response text.
If it finds it, then it returns that value.
:param self: Represent the instance of the class
:return: (`str`) The SNlM0e value
Get or initialize the AsyncClient instance.
Args:
session (Optional[AsyncClient]): Existing AsyncClient instance.
Returns:
AsyncClient: The AsyncClient instance.
"""
if session is None:
async_client = AsyncClient(
Expand All @@ -160,6 +206,15 @@ async def _get_client(self, session: Optional[AsyncClient]) -> AsyncClient:
return session

async def get_answer(self, input_text: str) -> dict:
"""
Get the answer from the Bard API for the input text.
Args:
input_text (str): Text input for which the answer is sought.
Returns:
dict: The response from the Bard API.
"""
params, data = self._prepare_request(input_text)
resp = await self.client.post(
POST_ENDPOINT,
Expand All @@ -172,7 +227,16 @@ async def get_answer(self, input_text: str) -> dict:
)
return self._process_response(resp)

def _prepare_request(self, input_text: str):
def _prepare_request(self, input_text: str) -> Tuple[dict, dict]:
"""
Prepare the request for the Bard API.
Args:
input_text (str): Text input for which the answer is sought.
Returns:
Tuple[dict, dict]: The parameters and data for the POST request.
"""
# Translate the input text if the language is not allowed and a translator is available
if self.language not in ALLOWED_LANGUAGES:
if self.google_translator_api_key:
Expand Down Expand Up @@ -204,6 +268,15 @@ def _prepare_request(self, input_text: str):
return params, data

def _process_response(self, resp) -> dict:
"""
Process the response from the Bard API.
Args:
resp: The response from the Bard API.
Returns:
dict: The processed response.
"""
if resp.status_code != 200:
logging.error(f"Response status code: {resp.status_code}")
return {"content": f"Response Error: {resp.content}."}
Expand All @@ -219,6 +292,16 @@ def _process_response(self, resp) -> dict:
return bard_answer

def _extract_answer(self, parsed_answer, resp) -> dict:
"""
Extract the answer from the parsed response.
Args:
parsed_answer: The parsed response from the Bard API.
resp: The original response from the Bard API.
Returns:
dict: The extracted answer.
"""
# Assuming 'parsed_answer' is a dictionary that contains the required information
bard_answer = {
"content": parsed_answer[4][0][1][0],
Expand All @@ -235,14 +318,26 @@ def _extract_answer(self, parsed_answer, resp) -> dict:
}
return bard_answer

def _update_state(self, bard_answer):
def _update_state(self, bard_answer) -> None:
"""
Update the state with the latest conversation ID, response ID, and choice ID.
Args:
bard_answer (dict): The bard answer containing the state information.
"""
self.conversation_id = bard_answer.get("conversation_id", "")
self.response_id = bard_answer.get("response_id", "")
choice_id = bard_answer.get("choices", [{}])[0].get("id", "")
self.choice_id = choice_id if choice_id else self.choice_id
self._reqid += 100000 # Increment _reqid for the next request

def _execute_code_if_needed(self, bard_answer):
def _execute_code_if_needed(self, bard_answer) -> None:
"""
Execute code snippets if the `run_code` flag is set and code is present in the bard answer.
Args:
bard_answer (dict): The bard answer containing the code to execute.
"""
if self.run_code and bard_answer.get("code"):
code = bard_answer["code"]
try:
Expand Down

0 comments on commit 721c1da

Please sign in to comment.