Commit: Add internet processor

yorevs committed Mar 10, 2024
1 parent 4126129 commit 626fba5
Showing 16 changed files with 198 additions and 88 deletions.
16 changes: 16 additions & 0 deletions docs/devel/snippets/snippet-3-internet.txt
@@ -0,0 +1,16 @@
if __name__ == '__main__':
    from googleapiclient.discovery import build
    import os
    import pprint

    my_api_key = os.environ.get("GOOGLE_API_KEY")
    my_cse_id = os.environ.get("GOOGLE_CSE_ID")

    def google_search(search_term, api_key, cse_id, **kwargs):
        service = build("customsearch", "v1", developerKey=api_key)
        res = service.cse().list(q=search_term, cx=cse_id, **kwargs).execute()
        return res['items']

    results = google_search(
        'stackoverflow site:en.wikipedia.org', my_api_key, my_cse_id, num=10)
    for result in results:
        pprint.pprint(result)
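
Side note: the snippet above calls the Custom Search API through googleapiclient directly, while the InternetService introduced below reaches the same API through LangChain's GoogleSearchAPIWrapper. A minimal sketch of the wrapper-based equivalent, assuming the same GOOGLE_API_KEY and GOOGLE_CSE_ID environment variables are exported:

# Sketch only: GoogleSearchAPIWrapper reads GOOGLE_API_KEY and
# GOOGLE_CSE_ID from the environment when not passed explicitly.
from langchain_community.utilities import GoogleSearchAPIWrapper

search = GoogleSearchAPIWrapper()
print(search.run("stackoverflow site:en.wikipedia.org"))         # result snippets as one string
print(search.results("stackoverflow site:en.wikipedia.org", 5))  # list of result metadata dicts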
42 changes: 19 additions & 23 deletions src/main/askai/core/askai.py
@@ -35,14 +35,14 @@
 from askai.core.askai_messages import msg
 from askai.core.askai_prompt import prompt
 from askai.core.component.audio_player import AudioPlayer
-from askai.core.component.cache_service import CacheService
-from askai.core.support.object_mapper import ObjectMapper
+from askai.core.component.cache_service import cache
 from askai.core.component.recorder import recorder
 from askai.core.engine.ai_engine import AIEngine
 from askai.core.model.chat_context import ChatContext
 from askai.core.model.query_response import QueryResponse
 from askai.core.processor.ai_processor import AIProcessor
 from askai.core.processor.processor_proxy import proxy
+from askai.core.support.object_mapper import object_mapper
 from askai.core.support.shared_instances import shared
 from askai.core.support.utilities import display_text

@@ -88,7 +88,7 @@ def __str__(self) -> str:
             f"{'--' * 40} %EOL%"
             f"Interactive: ON %EOL%"
             f" Speaking: {'ON' if self.is_speak else 'OFF'}{device_info} %EOL%"
-            f" Caching: {'ON' if CacheService.is_cache_enabled() else 'OFF'} %EOL%"
+            f" Caching: {'ON' if cache.is_cache_enabled() else 'OFF'} %EOL%"
             f" Tempo: {configs.tempo} %EOL%"
             f"{'--' * 40} %EOL%%NC%"
         )
@@ -186,8 +186,8 @@ def _startup(self) -> None:
         splash_thread.start()
         if configs.is_speak:
             AudioPlayer.INSTANCE.start_delay()
-        CacheService.set_cache_enable(self.cache_enabled)
-        CacheService.read_query_history()
+        cache.set_cache_enable(self.cache_enabled)
+        cache.read_query_history()
         askai_bus = AskAiEvents.get_bus(ASKAI_BUS_NAME)
         askai_bus.subscribe(REPLY_EVENT, self._cb_reply_event)
         self._ready = True
@@ -229,7 +229,7 @@ def _ask_and_reply(self, question: str) -> bool:
         """Ask the question and provide the reply.
         :param question: The question to ask to the AI engine.
         """
-        if not (reply := CacheService.read_reply(question)):
+        if not (reply := cache.read_reply(question)):
             log.debug('Response not found for "%s" in cache. Querying from %s.', question, self.engine.nickname())
             status, response = proxy.process(question)
             if status:
@@ -251,28 +251,24 @@ def _process_response(self, proxy_response: QueryResponse) -> bool:
         elif proxy_response.terminating:
             log.info("User wants to terminate the conversation.")
             return False
-        elif proxy_response.require_internet:
-            log.info("Internet is required to fulfill the request.")
-            pass

         if q_type := proxy_response.query_type:
-            processor: AIProcessor = AIProcessor.get_by_query_type(q_type)
-            if not processor:
+            if not (processor := AIProcessor.get_by_query_type(q_type)):
                 log.error(f"Unable to find a proper processor for query type: {q_type}")
                 self.reply_error(str(proxy_response))
-            else:
-                log.info("%s::Processing response for '%s'", processor, proxy_response.question)
-                status, output = processor.process(proxy_response)
-                if status and processor.next_in_chain():
-                    mapped_response = ObjectMapper.INSTANCE.of_json(output, QueryResponse)
-                    if isinstance(mapped_response, QueryResponse):
-                        self._process_response(mapped_response)
-                    else:
-                        self.reply(str(mapped_response))
-                elif status:
-                    self.reply(str(output))
-                return False
+            log.info("%s::Processing response for '%s'", processor, proxy_response.question)
+            status, output = processor.process(proxy_response)
+            if status and processor.next_in_chain():
+                mapped_response = object_mapper.of_json(output, QueryResponse)
+                if isinstance(mapped_response, QueryResponse):
+                    self._process_response(mapped_response)
+                else:
+                    self.reply(str(mapped_response))
+            elif status:
+                self.reply(str(output))
+            else:
+                self.reply_error(str(output))
         else:
             self.reply_error(msg.invalid_response(proxy_response))

2 changes: 1 addition & 1 deletion src/main/askai/core/component/cache_service.py
@@ -104,4 +104,4 @@ def get_audio_file(cls, text: str, audio_format: str = "mp3") -> Tuple[str, bool]:
         return audio_file_path, file_is_not_empty(audio_file_path)


-assert CacheService().INSTANCE is not None
+assert (cache := CacheService().INSTANCE) is not None
48 changes: 48 additions & 0 deletions src/main/askai/core/component/internet_service.py
@@ -0,0 +1,48 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

"""
   @project: HsPyLib-AskAI
   @package: askai.utils
      @file: internet_service.py
   @created: Tue, 16 Jan 2024
    @author: <B>H</B>ugo <B>S</B>aporetti <B>J</B>unior"
      @site: https://github.com/yorevs/hspylib
   @license: MIT - Please refer to <https://opensource.org/licenses/MIT>
   Copyright·(c)·2024,·HSPyLib
"""
import logging as log
import os
from typing import List, Optional

from hspylib.core.metaclass.singleton import Singleton
from langchain_community.utilities import GoogleSearchAPIWrapper
from langchain_core.tools import Tool


class InternetService(metaclass=Singleton):
    """Provide an internet search service used to complete queries that require real-time data."""

    INSTANCE: 'InternetService' = None

    ASKAI_INTERNET_DATA_KEY = "askai-internet-data"

    def __init__(self):
        self._search = GoogleSearchAPIWrapper()
        self._tool = Tool(
            name="google_search", description="Search Google for recent results.", func=self._search.run,
        )

    def _top_results(self, query: str, max_results: int = 5) -> List[str]:
        """Return the top results of the given search query."""
        return self._search.results(query, max_results)

    def search(self, query: str) -> Optional[str]:
        """Search the internet and return the results as a single string."""
        search_results = self._tool.run(query)
        log.debug("Internet search returned: %s", search_results)
        return os.linesep.join(search_results) if isinstance(search_results, list) else search_results


assert (internet := InternetService().INSTANCE) is not None
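
A quick usage sketch of the module-level singleton defined above (hypothetical query; the wrapper still requires the Google API credentials in the environment):

from askai.core.component.internet_service import internet

# search() runs the google_search Tool and joins list results into a
# newline-separated string; non-list results are returned unchanged.
answer = internet.search("what is the weather in Santos, SP right now")
print(answer)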
2 changes: 1 addition & 1 deletion src/main/askai/core/model/chat_context.py
@@ -73,7 +73,7 @@ def get_many(self, *keys: str) -> List[dict]:
         for key in keys:
             if (content := self.get(key)) and (token_length + len(content)) > self._token_limit:
                 raise TokenLengthExceeded(f"Required token length={token_length} limit={self._token_limit}")
-            context += content
+            context += content or ''
         return context

     def clear(self, key: str) -> int:
42 changes: 0 additions & 42 deletions src/main/askai/core/model/internet_research.py

This file was deleted.

15 changes: 15 additions & 0 deletions src/main/askai/core/model/search_result.py
@@ -0,0 +1,15 @@
import json
from dataclasses import dataclass
from typing import List


@dataclass
class SearchResult:
    """Keep track of the internet search responses."""

    query: str = None
    urls: str | List[str] = None
    results: str = None

    def __str__(self):
        return f"Internet search results: {json.dumps(self.__dict__, default=lambda obj: obj.__dict__)}"
16 changes: 11 additions & 5 deletions src/main/askai/core/processor/ai_processor.py
@@ -43,9 +43,10 @@ def find_query_types(cls) -> str:
             proc_name = os.path.splitext(proc)[0]
             proc_pkg = import_module(f"{__package__}.{proc_name}")
             proc_class = getattr(proc_pkg, camelcase(proc_name, capitalized=True))
-            proc_inst = proc_class()
+            proc_inst: 'AIProcessor' = proc_class()
             cls._PROCESSORS[proc_inst.processor_id()] = proc_inst
-            q_types.append(str(proc_inst))
+            if proc_inst.query_desc():
+                q_types.append(str(proc_inst))
         return os.linesep.join(q_types)

     @classmethod
@@ -67,6 +68,7 @@ def get_by_name(cls, name: str) -> Optional['AIProcessor']:
     def __init__(self, template_file: str | Path, persona_file: str | Path):
         self._template_file = str(template_file)
         self._persona_file = str(persona_file)
+        self._next_in_chain = None

     def __str__(self):
         return f"'{self.query_type()}': {self.query_desc()}"
@@ -87,18 +89,22 @@ def processor_id(self) -> str:

     def query_type(self) -> str:
         """Get the query type this processor can handle. By default, it's the name of the processor itself."""
-        return self.name
+        return self.processor_id()

     def query_desc(self) -> str:
         """Get the description of the queries this processor can handle."""
-        ...
+        return ''

     def template(self) -> str:
         return prompt.read_prompt(self._template_file, self._persona_file)

     def next_in_chain(self) -> Optional['AIProcessor']:
         """Return the next processor in the chain to call. Defaults to None."""
-        return None
+        return self._next_in_chain

+    def bind(self, next_in_chain: 'AIProcessor'):
+        """Bind a processor to be the next in chain."""
+        self._next_in_chain = next_in_chain
+
     def process(self, query_response: QueryResponse) -> Tuple[bool, Optional[str]]:
         """Process the given query response."""
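
The bind()/next_in_chain() pair added above gives processors a simple chain of responsibility: a processor can hand its output to a successor, which askai.py then maps back into a QueryResponse. A minimal sketch of the intended wiring, using hypothetical processor names:

# Hypothetical wiring, assuming two registered AIProcessor subclasses.
first = AIProcessor.get_by_name('GenericProcessor')
second = AIProcessor.get_by_name('OutputProcessor')
first.bind(second)  # first's output now feeds the second processor
assert first.next_in_chain() is second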
3 changes: 3 additions & 0 deletions src/main/askai/core/processor/command_processor.py
@@ -44,6 +44,9 @@ def query_desc(self) -> str:
             "file, folder and application management, listing, device assessment or inquiries."
         )

+    def bind(self, next_in_chain: 'AIProcessor'):
+        pass  # Avoid re-binding the next-in-chain processor.
+
     def next_in_chain(self) -> AIProcessor:
         return AIProcessor.get_by_name(OutputProcessor.__name__)
22 changes: 14 additions & 8 deletions src/main/askai/core/processor/generic_processor.py
@@ -19,9 +19,10 @@

 from askai.core.askai_messages import msg
 from askai.core.askai_prompt import prompt
-from askai.core.component.cache_service import CacheService
+from askai.core.component.cache_service import cache
 from askai.core.model.query_response import QueryResponse
 from askai.core.processor.ai_processor import AIProcessor
+from askai.core.processor.internet_processor import InternetProcessor
 from askai.core.support.shared_instances import shared


@@ -40,23 +41,28 @@ def query_desc(self) -> str:
     def process(self, query_response: QueryResponse) -> Tuple[bool, Optional[str]]:
         status = False
         output = None
-        template = PromptTemplate(
-            input_variables=['user'], template=self.template())
-        final_prompt: str = msg.translate(
-            template.format(user=prompt.user))
+        template = PromptTemplate(input_variables=['user'], template=self.template())
+        final_prompt: str = msg.translate(template.format(user=prompt.user))
         shared.context.set("SETUP", final_prompt, 'system')
         shared.context.set("QUESTION", query_response.question)
         context: List[dict] = shared.context.get_many("GENERAL", "SETUP", "QUESTION")
         log.info("Setup::[GENERIC] '%s' context=%s", query_response.question, context)
         try:
+            if query_response.require_internet:
+                log.info("Internet is required to fulfill the request.")
+                i_processor = AIProcessor.get_by_name(InternetProcessor.__name__)
+                status, output = i_processor.process(query_response)
+                i_ctx = shared.context.get("INTERNET")
+                list(map(lambda c: context.insert(len(context) - 2, c), i_ctx))
             if (response := shared.engine.ask(context, temperature=1, top_p=1)) and response.is_success:
                 output = response.message
-                CacheService.save_reply(query_response.question, query_response.question)
                 shared.context.push("GENERAL", output, 'assistant')
-                CacheService.save_reply(query_response.question, output)
-                CacheService.save_query_history()
+                cache.save_reply(query_response.question, output)
+                cache.save_query_history()
                 status = True
             else:
                 output = msg.llm_error(response.message)
         except Exception as err:
             output = msg.llm_error(str(err))
         finally:
             return status, output
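
The insert at len(context) - 2 above splices each internet result into the message list just before the last two entries (the SETUP and QUESTION messages), so the engine sees the fresh data as prior context. A toy illustration of that list behavior:

# Toy illustration of the context splicing (plain role/content dicts).
context = [{'role': 'user', 'content': 'history'},
           {'role': 'system', 'content': 'setup'},
           {'role': 'user', 'content': 'question'}]
i_ctx = [{'role': 'assistant', 'content': 'internet data'}]
list(map(lambda c: context.insert(len(context) - 2, c), i_ctx))
# context is now: history, internet data, setup, question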
63 changes: 63 additions & 0 deletions src/main/askai/core/processor/internet_processor.py
@@ -0,0 +1,63 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

"""
   @project: HsPyLib-AskAI
   @package: askai.core.processor
      @file: internet_processor.py
   @created: Fri, 23 Feb 2024
    @author: <B>H</B>ugo <B>S</B>aporetti <B>J</B>unior"
      @site: https://github.com/yorevs/hspylib
   @license: MIT - Please refer to <https://opensource.org/licenses/MIT>
   Copyright·(c)·2024,·HSPyLib
"""
import logging as log
from typing import Tuple, Optional, List

from langchain_core.prompts import PromptTemplate

from askai.core.askai_messages import msg
from askai.core.askai_prompt import prompt
from askai.core.component.cache_service import cache
from askai.core.component.internet_service import internet
from askai.core.model.query_response import QueryResponse
from askai.core.model.search_result import SearchResult
from askai.core.processor.ai_processor import AIProcessor
from askai.core.support.object_mapper import object_mapper
from askai.core.support.shared_instances import shared


class InternetProcessor(AIProcessor):
    """Process prompts that require an internet search to be fulfilled."""

    def __init__(self):
        super().__init__('internet-prompt', 'internet-persona')

    def process(self, query_response: QueryResponse) -> Tuple[bool, Optional[str]]:
        status = False
        output = None
        template = PromptTemplate(input_variables=['user'], template=self.template())
        final_prompt: str = msg.translate(template.format(user=prompt.user))
        shared.context.set("SETUP", final_prompt, 'system')
        shared.context.set("QUESTION", query_response.question)
        context: List[dict] = shared.context.get_many("SETUP", "QUESTION")
        log.info("Setup::[INTERNET] '%s' context=%s", query_response.question, context)
        try:
            if not (response := cache.read_reply(query_response.question)):
                if (response := shared.engine.ask(context, temperature=0.0, top_p=0.0)) and response.is_success:
                    search_result: SearchResult = object_mapper.of_json(response.message, SearchResult)
                    if results := internet.search(search_result.query):
                        search_result.results = results
                        output = str(search_result)
                        shared.context.set("INTERNET", output, 'assistant')
                        cache.save_reply(query_response.question, output)
                        status = True
                else:
                    output = msg.llm_error(response.message)
            else:
                log.debug('Reply found for "%s" in cache.', query_response.question)
                output = response
                status = True
        finally:
            return status, output
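
The processor above works in two steps: a deterministic engine call (temperature and top_p at 0.0) turns the question into a SearchResult JSON payload, and the real search results are then folded back into that object. A hypothetical example of the intermediate JSON the mapper consumes:

# Hypothetical LLM reply consumed by object_mapper.of_json above.
reply = '{"query": "dollar exchange rate today", "urls": []}'
search_result = object_mapper.of_json(reply, SearchResult)
assert search_result.query == "dollar exchange rate today"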