Skip to content

Commit 88cc997

Browse files
committed
Improve initial recorder device select
1 parent 35b6b05 commit 88cc997

File tree

7 files changed

+127
-107
lines changed

7 files changed

+127
-107
lines changed

src/demo/components/vision_demo.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,13 @@
1-
from askai.core.engine.ai_vision import AIVision
2-
from askai.core.support.shared_instances import shared
3-
from utils import init_context
1+
import os
2+
3+
from askai.core.features.router.tools.vision import offline_captioner
44

55
if __name__ == "__main__":
6-
init_context("vision-demo")
7-
vision: AIVision = shared.engine.vision()
6+
# init_context("vision-demo")
7+
# vision: AIVision = shared.engine.vision()
88
load_dir: str = "${HOME}/.config/hhs/askai/cache/pictures/photos"
99
image_file: str = "eu-edvaldo-suecia.jpg"
10-
result = vision.caption(image_file, load_dir)
11-
print(result)
10+
# result = vision.caption(image_file, load_dir)
11+
# print(result)
12+
result2 = offline_captioner(os.path.join(load_dir, image_file))
13+
print(result2)

src/main/askai/__classpath__.py

Lines changed: 15 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -12,25 +12,31 @@
1212
1313
Copyright (c) 2024, HomeSetup
1414
"""
15-
16-
from askai.core.model.api_keys import ApiKeys
17-
from hspylib.core.metaclass.classpath import Classpath
18-
from hspylib.core.tools.commons import parent_path, root_dir
19-
2015
import logging as log
2116
import os
22-
import pydantic
2317
import sys
2418
import warnings
2519

26-
warnings.simplefilter(action="ignore", category=FutureWarning)
20+
import pydantic
21+
from clitt.core.term.commons import is_a_tty
22+
from hspylib.core.metaclass.classpath import Classpath
23+
from hspylib.core.tools.commons import parent_path, root_dir, is_debugging
24+
25+
from askai.core.model.api_keys import ApiKeys
26+
27+
if not is_debugging():
28+
warnings.simplefilter("ignore", category=FutureWarning)
29+
warnings.simplefilter("ignore", category=UserWarning)
30+
warnings.simplefilter("ignore", category=DeprecationWarning)
31+
32+
if not is_a_tty():
33+
log.getLogger().setLevel(log.ERROR)
2734

2835
if not os.environ.get("USER_AGENT"):
2936
# The AskAI User Agent, required by the langchain framework
3037
ASKAI_USER_AGENT: str = "AskAI-User-Agent"
3138
os.environ["USER_AGENT"] = ASKAI_USER_AGENT
3239

33-
3440
try:
3541
API_KEYS: ApiKeys = ApiKeys()
3642
except pydantic.v1.error_wrappers.ValidationError as err:
@@ -40,7 +46,7 @@
4046

4147

4248
class _Classpath(Classpath):
43-
"""TODO"""
49+
"""A class for managing classpath-related operations. Uses the Classpath metaclass."""
4450

4551
def __init__(self):
4652
super().__init__(parent_path(__file__), parent_path(root_dir()), (parent_path(__file__) / "resources"))

src/main/askai/__main__.py

Lines changed: 12 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -12,13 +12,12 @@
1212
1313
Copyright (c) 2024, HomeSetup
1414
"""
15-
from askai.__classpath__ import classpath
16-
from askai.core.askai import AskAi
17-
from askai.core.askai_cli import AskAiCli
18-
from askai.core.askai_configs import configs
19-
from askai.core.support.shared_instances import shared
20-
from askai.tui.askai_app import AskAiApp
21-
from clitt.core.term.commons import is_a_tty
15+
import logging as log
16+
import os
17+
import sys
18+
from textwrap import dedent
19+
from typing import Any, Optional
20+
2221
from clitt.core.tui.tui_application import TUIApplication
2322
from hspylib.core.enums.charset import Charset
2423
from hspylib.core.tools.commons import to_bool
@@ -27,15 +26,13 @@
2726
from hspylib.modules.application.argparse.parser_action import ParserAction
2827
from hspylib.modules.application.exit_status import ExitStatus
2928
from hspylib.modules.application.version import Version
30-
from textwrap import dedent
31-
from typing import Any, Optional
32-
33-
import logging as log
34-
import os
35-
import sys
3629

37-
if not is_a_tty():
38-
log.getLogger().setLevel(log.ERROR)
30+
from askai.__classpath__ import classpath
31+
from askai.core.askai import AskAi
32+
from askai.core.askai_cli import AskAiCli
33+
from askai.core.askai_configs import configs
34+
from askai.core.support.shared_instances import shared
35+
from askai.tui.askai_app import AskAiApp
3936

4037

4138
class Main(TUIApplication):

src/main/askai/core/askai_cli.py

Lines changed: 17 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,21 @@
1212
1313
Copyright (c) 2024, HomeSetup
1414
"""
15+
import logging as log
16+
import os
17+
from functools import partial
18+
from pathlib import Path
19+
from threading import Thread
20+
from typing import List, TypeAlias
21+
22+
import nltk
23+
import pause
24+
from clitt.core.term.cursor import cursor
25+
from clitt.core.term.screen import screen
26+
from clitt.core.tui.line_input.keyboard_input import KeyboardInput
27+
from hspylib.modules.eventbus.event import Event
28+
from rich.progress import Progress
29+
1530
from askai.core.askai import AskAi
1631
from askai.core.askai_configs import configs
1732
from askai.core.askai_events import *
@@ -24,20 +39,6 @@
2439
from askai.core.support.shared_instances import shared
2540
from askai.core.support.text_formatter import text_formatter
2641
from askai.core.support.utilities import display_text
27-
from clitt.core.term.cursor import cursor
28-
from clitt.core.term.screen import screen
29-
from clitt.core.tui.line_input.keyboard_input import KeyboardInput
30-
from functools import partial
31-
from hspylib.modules.eventbus.event import Event
32-
from pathlib import Path
33-
from rich.progress import Progress
34-
from threading import Thread
35-
from typing import List, TypeAlias
36-
37-
import logging as log
38-
import nltk
39-
import os
40-
import pause
4142

4243
QueryString: TypeAlias = str | List[str] | None
4344

@@ -161,7 +162,7 @@ def _startup(self) -> None:
161162
# List of tasks for progress tracking
162163
tasks = [
163164
"Downloading nltk data",
164-
"Preloading input history",
165+
"Loading input history",
165166
"Starting scheduler",
166167
"Setting up recorder",
167168
"Starting player delay",
@@ -180,7 +181,7 @@ def _startup(self) -> None:
180181
self._progress.update(task, advance=1, description="[green]Downloading nltk data")
181182
nltk.download("averaged_perceptron_tagger", quiet=True, download_dir=CACHE_DIR)
182183
cache.cache_enable = configs.is_cache
183-
self._progress.update(task, advance=1, description="[green]Preloading input history")
184+
self._progress.update(task, advance=1, description="[green]Loading input history")
184185
KeyboardInput.preload_history(cache.load_input_history(commands()))
185186
self._progress.update(task, advance=1, description="[green]Starting scheduler")
186187
scheduler.start()

src/main/askai/core/commander/commands/tts_stt_cmd.py

Lines changed: 24 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -12,17 +12,20 @@
1212
1313
Copyright (c) 2024, HomeSetup
1414
"""
15+
import os
1516
from abc import ABC
17+
from pathlib import Path
18+
19+
import pause
20+
from clitt.core.tui.mselect.mselect import mselect
21+
1622
from askai.core.askai_configs import configs
1723
from askai.core.askai_settings import settings
1824
from askai.core.component.audio_player import player
19-
from askai.core.component.recorder import recorder
25+
from askai.core.component.recorder import recorder, InputDevice
2026
from askai.core.support.shared_instances import shared
2127
from askai.core.support.text_formatter import text_formatter
2228
from askai.core.support.utilities import copy_file
23-
from pathlib import Path
24-
25-
import os
2629

2730

2831
class TtsSttCmd(ABC):
@@ -101,15 +104,25 @@ def device_set(name_or_index: str | int | None = None) -> None:
101104
:param name_or_index: The name or index of the audio input device to set. If None, the default device will be
102105
used.
103106
"""
107+
device: InputDevice | None = None
104108
all_devices = recorder.devices
105-
if name_or_index.isdecimal() and 0 <= int(name_or_index) <= len(all_devices):
109+
110+
def _set_device(_device) -> bool:
111+
if recorder.set_device(_device):
112+
text_formatter.cmd_print(f"`Text-To-Speech` device changed to %GREEN%{_device}%NC%")
113+
return True
114+
text_formatter.cmd_print(f"%HOM%%ED2%Error: '{_device}' is not an Audio Input device!%NC%")
115+
all_devices.remove(_device)
116+
pause.seconds(2)
117+
return False
118+
119+
if not name_or_index:
120+
device: InputDevice = mselect(
121+
all_devices, f"{'-=' * 40}%EOL%AskAI::Select the Audio Input device%EOL%{'=-' * 40}%EOL%")
122+
elif name_or_index.isdecimal() and 0 <= int(name_or_index) <= len(all_devices):
106123
name_or_index = all_devices[int(name_or_index)][1]
107-
if device := next((dev for dev in all_devices if dev[1] == name_or_index), None):
108-
if recorder.set_device(device):
109-
text_formatter.cmd_print(f"`Text-To-Speech` device changed to %GREEN%{device[1]}%NC%")
110-
else:
111-
text_formatter.cmd_print(f"%RED%Device: '{name_or_index}' failed to initialize!%NC%")
112-
else:
124+
device = next((dev for dev in all_devices if dev[1] == name_or_index), None)
125+
if not (device and _set_device(device)):
113126
text_formatter.cmd_print(f"%RED%Invalid audio input device: '{name_or_index}'%NC%")
114127

115128
@staticmethod

src/main/askai/core/component/recorder.py

Lines changed: 20 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -12,29 +12,28 @@
1212
1313
Copyright (c) 2024, HomeSetup
1414
"""
15-
from askai.core.askai_configs import configs
16-
from askai.core.askai_events import events
17-
from askai.core.askai_messages import msg
18-
from askai.core.component.cache_service import REC_DIR
19-
from askai.core.component.scheduler import Scheduler
20-
from askai.core.support.utilities import display_text, seconds
21-
from askai.exception.exceptions import InvalidInputDevice, InvalidRecognitionApiError
22-
from askai.language.language import Language
23-
from clitt.core.tui.mselect.mselect import mselect
15+
import logging as log
16+
import operator
17+
import sys
18+
from pathlib import Path
19+
from typing import Callable, Optional, TypeAlias
20+
2421
from hspylib.core.enums.enumeration import Enumeration
2522
from hspylib.core.metaclass.classpath import AnyPath
2623
from hspylib.core.metaclass.singleton import Singleton
2724
from hspylib.core.preconditions import check_argument, check_state
2825
from hspylib.core.zoned_datetime import now_ms
2926
from hspylib.modules.application.exit_status import ExitStatus
30-
from pathlib import Path
3127
from speech_recognition import AudioData, Microphone, Recognizer, RequestError, UnknownValueError, WaitTimeoutError
32-
from typing import Callable, Optional, TypeAlias
3328

34-
import logging as log
35-
import operator
36-
import pause
37-
import sys
29+
from askai.core.askai_configs import configs
30+
from askai.core.askai_events import events
31+
from askai.core.askai_messages import msg
32+
from askai.core.component.cache_service import REC_DIR
33+
from askai.core.component.scheduler import Scheduler
34+
from askai.core.support.utilities import seconds, display_text
35+
from askai.exception.exceptions import InvalidInputDevice, InvalidRecognitionApiError
36+
from askai.language.language import Language
3837

3938
InputDevice: TypeAlias = tuple[int, str]
4039

@@ -240,23 +239,19 @@ def _select_device(self) -> None:
240239
available: list[str] = list(filter(lambda d: d, map(str.strip, configs.recorder_devices)))
241240
device: InputDevice | None = None
242241
devices: list[InputDevice] = list(reversed(self.devices))
243-
while not device:
242+
while devices and not device:
244243
if available:
245244
for dev in devices:
246245
if dev[1] in available and self.set_device(dev):
247246
device = dev
248247
break
249248
if not device:
250-
device: InputDevice = mselect(
251-
devices, f"{'-=' * 40}%EOL%AskAI::Select the Audio Input device%EOL%{'=-' * 40}%EOL%"
252-
)
253-
if not device:
249+
if not (device := next(devices.__iter__(), None)):
250+
display_text(f"%HOM%%ED2%Error: Unable to setup an Audio Input device!%NC%")
254251
sys.exit(ExitStatus.FAILED.val)
255-
elif not self.set_device(device):
256-
display_text(f"%HOM%%ED2%Error: '{device[1]}' is not an Audio Input device!%NC%")
257-
devices.remove(device)
258-
device = None
259-
pause.seconds(2)
252+
if device and not self.set_device(device):
253+
devices.remove(device)
254+
device = None
260255

261256

262257
assert (recorder := Recorder().INSTANCE) is not None

src/main/askai/core/features/router/tools/vision.py

Lines changed: 30 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1,39 +1,41 @@
1+
import os
2+
from textwrap import indent
3+
4+
import torch
5+
from PIL import Image
6+
from hspylib.core.config.path_object import PathObject
7+
from hspylib.core.enums.enumeration import Enumeration
8+
from hspylib.core.metaclass.classpath import AnyPath
9+
from transformers import BlipForConditionalGeneration, BlipProcessor
10+
111
from askai.core.askai_events import events
212
from askai.core.askai_messages import msg
313
from askai.core.component.cache_service import PICTURE_DIR
414
from askai.core.engine.ai_vision import AIVision
515
from askai.core.features.validation.accuracy import resolve_x_refs
616
from askai.core.model.image_result import ImageResult
717
from askai.core.support.shared_instances import shared
8-
from hspylib.core.config.path_object import PathObject
9-
from hspylib.core.enums.enumeration import Enumeration
10-
from hspylib.core.metaclass.classpath import AnyPath
11-
from PIL import Image
12-
from textwrap import indent
13-
from transformers import BlipForConditionalGeneration, BlipProcessor
1418

15-
import os
16-
import torch
19+
20+
class HFModel(Enumeration):
21+
"""Available Hugging Face models"""
22+
23+
# fmt: off
24+
SF_BLIP_BASE = "Salesforce/blip-image-captioning-base"
25+
SF_BLIP_LARGE = "Salesforce/blip-image-captioning-large"
26+
# fmt: on
27+
28+
@staticmethod
29+
def default() -> "HFModel":
30+
"""Return the default HF model."""
31+
return HFModel.SF_BLIP_BASE
1732

1833

1934
def offline_captioner(path_name: AnyPath) -> str:
2035
"""This tool is used to describe an image.
2136
:param path_name: The path of the image to describe.
2237
"""
2338

24-
class HFModel(Enumeration):
25-
"""Available Hugging Face models"""
26-
27-
# fmt: off
28-
SF_BLIP_BASE = "Salesforce/blip-image-captioning-base"
29-
SF_BLIP_LARGE = "Salesforce/blip-image-captioning-large"
30-
# fmt: on
31-
32-
@staticmethod
33-
def default() -> "HFModel":
34-
"""Return the default HF model."""
35-
return HFModel.SF_BLIP_LARGE
36-
3739
caption: str = "Not available"
3840

3941
posix_path: PathObject = PathObject.of(path_name)
@@ -46,12 +48,16 @@ def default() -> "HFModel":
4648

4749
if posix_path.exists:
4850
events.reply.emit(message=msg.describe_image(str(posix_path)))
49-
hf_model: HFModel = HFModel.default()
5051
# Use GPU if it's available
5152
device = "cuda" if torch.cuda.is_available() else "cpu"
5253
image = Image.open(str(posix_path)).convert("RGB")
53-
model = BlipForConditionalGeneration.from_pretrained(hf_model.value).to(device)
54-
processor = BlipProcessor.from_pretrained(hf_model.value)
54+
model_id: str = HFModel.default().value
55+
match model_id.casefold():
56+
case model if "blip-" in model:
57+
model = BlipForConditionalGeneration.from_pretrained(model_id).to(device)
58+
processor = BlipProcessor.from_pretrained(model_id)
59+
case _:
60+
raise ValueError(f"Unsupported model: '{model_id}'")
5561
inputs = processor(images=image, return_tensors="pt").to(device)
5662
outputs = model.generate(**inputs)
5763
caption = processor.decode(outputs[0], skip_special_tokens=True)

0 commit comments

Comments
 (0)