Version 0.1.2
Labbeti committed Nov 17, 2023
1 parent 61a1424 commit 587350a
Showing 5 changed files with 29 additions and 7 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.md
@@ -2,6 +2,10 @@

All notable changes to this project will be documented in this file.

## [0.1.2] 2023-11-17
### Fixed
- Task embeddings inputs `wavcaps_audioset_sl` and `wavcaps_bbc_sound_effects`.

## [0.1.1] 2023-11-09
### Added
- Unittests for hf model.
Binary file removed data/sample.wav
6 changes: 4 additions & 2 deletions src/conette/__init__.py
@@ -10,7 +10,7 @@
__license__ = "MIT"
__maintainer__ = "Etienne Labbé (Labbeti)"
__status__ = "Development"
__version__ = "0.1.1"
__version__ = "0.1.2"


from pathlib import Path
@@ -19,9 +19,11 @@
from conette.huggingface.config import CoNeTTEConfig # noqa: F401
from conette.huggingface.model import CoNeTTEModel # noqa: F401

DEFAULT_MODEL_NAME = "Labbeti/conette"


def conette(
pretrained_model_name_or_path: Optional[str] = "Labbeti/conette",
pretrained_model_name_or_path: Optional[str] = DEFAULT_MODEL_NAME,
config_kwds: Optional[dict[str, Any]] = None,
model_kwds: Optional[dict[str, Any]] = None,
) -> CoNeTTEModel:
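Note: a minimal usage sketch of the `conette()` factory after this change, assuming the package is installed as `conette`; the checkpoint name now comes from the new `DEFAULT_MODEL_NAME` constant rather than a hard-coded literal in the signature.

from conette import conette

# Loads the default checkpoint, DEFAULT_MODEL_NAME = "Labbeti/conette".
model = conette()

# An explicit checkpoint or extra keyword dicts can still be passed.
model = conette(pretrained_model_name_or_path="Labbeti/conette", model_kwds={})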
6 changes: 3 additions & 3 deletions src/conette/huggingface/model.py
@@ -78,7 +78,7 @@ def __init__(
self._register_load_state_dict_pre_hook(self._pre_hook_load_state_dict)

device = get_device(device)
self.to(device=device)
self.to(device=device) # type: ignore

if inference:
self.eval_and_detach()
@@ -204,8 +204,8 @@ def forward(
for i, task in enumerate(tasks):
task = task.split("_")
dataset_lst[i] = task[0]
if len(task) == 2:
source_lst[i] = task[1]
if len(task) >= 2:
source_lst[i] = "_".join(task[1:])

batch["dataset"] = dataset_lst
batch["source"] = source_lst
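Note: a short sketch of what the fixed parsing does, using illustrative task names. With the old `len(task) == 2` check, multi-underscore sources such as `wavcaps_audioset_sl` never populated `source_lst`; the new `"_".join(task[1:])` keeps everything after the dataset prefix.

# Illustrative reproduction of the fixed split logic (task names are examples).
tasks = ["clotho", "wavcaps_audioset_sl", "wavcaps_bbc_sound_effects"]
dataset_lst = [None] * len(tasks)
source_lst = [None] * len(tasks)
for i, task in enumerate(tasks):
    parts = task.split("_")
    dataset_lst[i] = parts[0]                # dataset = first token
    if len(parts) >= 2:
        source_lst[i] = "_".join(parts[1:])  # source = remaining tokens rejoined

# dataset_lst == ["clotho", "wavcaps", "wavcaps"]
# source_lst  == [None, "audioset_sl", "bbc_sound_effects"]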
20 changes: 18 additions & 2 deletions src/conette/nn/functional/get.py
@@ -32,20 +32,36 @@ def get_activation_fn(name: str) -> Callable[[Tensor], Tensor]:


def get_device(
device: Union[str, torch.device, None] = "auto"
device: Union[str, torch.device, None] = "auto",
safe_auto: bool = True,
) -> Optional[torch.device]:
if device == "auto":
device = "cuda" if torch.cuda.is_available() else "cpu"

if device == "cuda" and safe_auto:
try:
_ = torch.empty((1,), device=device)
except RuntimeError:
device = "cpu"

if isinstance(device, str):
device = torch.device(device)
return device


def get_device_name(
device_name: Union[str, torch.device, None] = "auto"
device_name: Union[str, torch.device, None] = "auto",
safe_auto: bool = True,
) -> Optional[str]:
if device_name == "auto":
device_name = "cuda" if torch.cuda.is_available() else "cpu"

if device_name == "cuda" and safe_auto:
try:
_ = torch.empty((1,), device=device_name)
except RuntimeError:
device_name = "cpu"

if isinstance(device_name, torch.device):
device_name = f"{device_name.type}:{device_name.index}"
return device_name
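
Note: a minimal sketch of how the new `safe_auto` fallback behaves from the caller's side, assuming the import path `conette.nn.functional.get` implied by the file location above.

from conette.nn.functional.get import get_device, get_device_name

# "auto" resolves to CUDA when torch.cuda.is_available(); with safe_auto=True
# a one-element test tensor is allocated first, and any RuntimeError (e.g. a
# broken or out-of-memory CUDA runtime) falls back to CPU.
device = get_device("auto")       # torch.device("cuda") or torch.device("cpu")
name = get_device_name("auto")    # "cuda" or "cpu"

# The allocation check can be skipped to keep the previous behaviour.
device = get_device("auto", safe_auto=False)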
